[
  {
    "path": ".clang-format",
    "content": "Language: Cpp\nAccessModifierOffset: -4\nAlignAfterOpenBracket: Align\nAllowShortEnumsOnASingleLine: false\nAlignConsecutiveAssignments: true\nAlignConsecutiveDeclarations: true\nAlignEscapedNewlines: Right\nAlignOperands: true\nAlignTrailingComments: true\nAllowAllParametersOfDeclarationOnNextLine: true\nAllowAllArgumentsOnNextLine: true\nAllowShortBlocksOnASingleLine: Empty\nAllowShortCaseLabelsOnASingleLine: false\nAllowShortFunctionsOnASingleLine: Empty\nAllowShortIfStatementsOnASingleLine: Never\nAllowShortLoopsOnASingleLine: false\nAlwaysBreakAfterReturnType: None\nAlwaysBreakBeforeMultilineStrings: false\nAlwaysBreakTemplateDeclarations: true\nBinPackArguments: false\nBinPackParameters: false\nBreakBeforeBinaryOperators: NonAssignment\nBreakBeforeBraces: Stroustrup\nBreakBeforeTernaryOperators: false\nBreakConstructorInitializers: AfterColon\nBreakInheritanceList: AfterColon\nBreakStringLiterals: false\nColumnLimit: 120\nCompactNamespaces: false\nConstructorInitializerAllOnOneLineOrOnePerLine: true\nConstructorInitializerIndentWidth: 4\nContinuationIndentWidth: 4\nCpp11BracedListStyle: true\nDerivePointerAlignment: false\nFixNamespaceComments: true\nIndentCaseLabels: true\nIndentPPDirectives: None\nIndentWidth: 4\nIndentWrappedFunctionNames: false\nKeepEmptyLinesAtTheStartOfBlocks: true\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\nPointerAlignment: Left\nReflowComments: true\nSortIncludes: true\nSortUsingDeclarations: false\nSpaceAfterCStyleCast: false\nSpaceAfterTemplateKeyword: false\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeCtorInitializerColon: false\nSpaceBeforeInheritanceColon: false\nSpaceBeforeParens: ControlStatements\nSpaceInEmptyParentheses: false\nSpacesBeforeTrailingComments: 2\nSpacesInAngles: false\nSpacesInCStyleCastParentheses: false\nSpacesInContainerLiterals: false\nSpacesInParentheses: false\nSpacesInSquareBrackets: false\nStandard: c++17\nTabWidth: 4\nUseTab: Never\n"
  },
  {
    "path": ".claude/skills/check-env/SKILL.md",
    "content": "---\nname: check-env\ndescription: Check if the LMDeploy dev environment is properly set up.\n---\n\n# Check LMDeploy Dev Environment\n\n## 1. Find and activate the conda env\n\n```bash\nconda env list                        # starred = currently active\nconda activate <env-name>             # pick the right env for this project\n```\n\n## 2. Verify editable install\n\n```bash\npython -c \"import lmdeploy; print(lmdeploy.__file__)\"\n# Must point into the repo dir, e.g. /path/to/lmdeploy_vl/lmdeploy/__init__.py\n```\n\nIf it doesn't:\n\n```bash\npip install -e .                      # run from repo root\n```\n\n## 3. Confirm python and CUDA\n\n```bash\nwhich python                          # must show conda env path, not /usr/bin/python\npython -c \"import torch; print(torch.__version__, torch.version.cuda, torch.cuda.device_count())\"\n```\n\n## Troubleshooting\n\n| Problem              | Fix                                             |\n| -------------------- | ----------------------------------------------- |\n| `conda: not found`   | `source ~/miniconda3/etc/profile.d/conda.sh`    |\n| Wrong Python         | `conda deactivate && conda activate <env-name>` |\n| `lmdeploy` not found | `pip install -e .` from repo root               |\n"
  },
  {
    "path": ".claude/skills/code-navigation/SKILL.md",
    "content": "---\nname: code-navigation\ndescription: LMDeploy codebase directory map for fast orientation.\n---\n\n# LMDeploy Project Structure\n\n```text\nlmdeploy/\n├── cli/                        # Command line interface implementations\n├── lib/                        # Shared libraries/binary assets\n├── lite/                       # Quantization Toolkit\n│   ├── apis/                   # Calibration, AWQ, and SmoothQuant entry points\n│   ├── modeling/               # GPTQ/quantized model specific logic\n│   ├── quantization/           # Scaling calculation (activations/weights)\n│   └── utils/                  # Quantization helper functions (cal_qparams.py)\n├── metrics/                    # Statistics and performance monitoring\n├── monitoring/                 # Monitoring configs (Docker/Grafana)\n├── pytorch/                    # PyTorch inference backend\n│   ├── adapter/                # LoRA and adapter logic\n│   ├── backends/               # Kernel/Operator Dispatchers (FP8, AWQ, CUDA)\n│   ├── check_env/              # Environment/GPU capability sanity checks\n│   ├── configurations/         # Per-model engine configurations (Llama, etc.)\n│   ├── devices/                # Device management (CUDA)\n│   ├── disagg/                 # Disaggregated prefill/decode logic\n│   ├── engine/                 # Main Scheduler and Execution Loop\n│   ├── kernels/                # Triton/CUDA Kernels (w8a8_triton_kernels.py)\n│   ├── models/                 # Model Patches: Replacing HF layers with kernels\n│   ├── multimodal/             # Multi-modal input types for Pytorch engine\n│   ├── nn/                     # Reusable PyTorch modules\n│   ├── paging/                 # PagedAttention: KV cache block management\n│   ├── spec_decode/            # Speculative decoding logic\n│   ├── strategies/             # Execution and dispatch strategies\n│   ├── third_party/            # External dependencies/repos\n│   ├── tools/                  # Internal engine 
debugging tools\n│   ├── transformers/           # HF Transformers integration\n│   └── weight_loader/          # Sharded/quantized weight loading engine\n├── serve/                      # Serving: OpenAI-compatible API and gRPC\n├── turbomind/                  # C++ TurboMind inference backend\n├── vl/                         # Vision-Language (VL) Support and Image Processing\n│   ├── media/                  # Image/Video/... loaders and base classes\n│   └── model/                  # VL Archs (InternVL, Qwen-VL, LLaVA, etc.) and preprocessing\n├── api.py                      # High-level entry for model interaction\n├── archs.py                    # Registry: Maps architectures to runtime patches\n├── messages.py                 # Core Types: GenerationConfig, EngineConfig\n├── model.py                    # Chat Templates: CRITICAL for conversation logic\n├── pipeline.py                 # Main Orchestrator: Engine + Tokenizer\n└── tokenizer.py                # Wrapper for HF/SentencePiece tokenizers\n```\n"
  },
  {
    "path": ".claude/skills/resolve-review/SKILL.md",
    "content": "---\nname: resolve-review\ndescription: Fetch and resolve PR review comments, then push fixes.\n---\n\n# Resolve PR Review Comments\n\n## 1. Fetch comments\n\n```bash\ngh api repos/InternLM/lmdeploy/pulls/<PR>/comments \\\n  | python3 -c \"\nimport json, sys\nfor c in json.load(sys.stdin):\n    print(f'[{c[\\\"path\\\"]}:{c.get(\\\"line\\\",\\\"?\\\")}]')\n    print(c['body'])\n    print()\n\"\n```\n\n## 2. Fix each issue\n\nRead the flagged file, understand the comment, edit the file.\n\n## 3. Lint\n\n```bash\npre-commit run --all-files\n```\n\n## 4. Stage & commit\n\n```bash\ngit add <fixed files>\ngit commit -m \"fix: address PR review comments\"\n```\n\n## 5. Push\n\n```bash\ngit push\n```\n"
  },
  {
    "path": ".claude/skills/submit-pr/SKILL.md",
    "content": "---\nname: submit-pr\ndescription: Submit a GitHub pull request for LMDeploy.\n---\n\n# Submit a PR for LMDeploy\n\n## 1. Create branch (off main)\n\nSkip this step if already on a feature branch.\n\n```bash\ngit checkout main && git pull\ngit checkout -b <type>/<short-description>   # e.g. feat/qwen3-omni\n```\n\n## 2. Lint\n\n```bash\npre-commit run --all-files\n```\n\n## 3. Stage\n\n```bash\ngit add lmdeploy/path/to/changed_file.py     # specific files only, never git add .\ngit status                                   # verify staged set\n```\n\n## 4. Commit\n\n```bash\ngit commit -m \"feat: add Qwen3-Omni support\"\n# Conventional prefixes: feat | fix | refactor | docs | test | chore\n```\n\n## 5. Push\n\n```bash\ngit push -u origin <branch>\n```\n\n## 6. Create PR\n\n```bash\ngh pr create --title \"<type>: <short description>\" --body \"$(cat <<'EOF'\n## Summary\n- <bullet 1>\n- <bullet 2>\n\n## Test plan\n- [ ] `pre-commit run --all-files` passes\n- [ ] unit tests pass: `pytest tests/test_lmdeploy/`\n- [ ] manual smoke test with pipeline\n\n🤖 Generated with [Claude Code](https://claude.com/claude-code)\nEOF\n)\"\n```\n"
  },
  {
    "path": ".claude/skills/support-new-model/SKILL.md",
    "content": "---\nname: support-new-model\ndescription: Add a new LLM or VLM to LMDeploy's PyTorch backend.\n---\n\n# Tutorial: Adding a New Model to LMDeploy (PyTorch Backend)\n\nThis guide walks through adding a new LLM or VLM to LMDeploy's PyTorch backend.\n\n______________________________________________________________________\n\n## Before Writing Any Code\n\n**Study the reference implementations before touching any files.**\n\n1. Read the HF model's `config.json` to understand: `model_type`, `architectures`, layer counts, hidden dims, number of attention heads, MoE parameters (if applicable).\n2. Identify which category the model falls into:\n   - **LLM only** — pure text model\n   - **VLM** — text + vision (needs an additional preprocessor in `vl/model/`)\n3. Find the closest existing model in LMDeploy and read it thoroughly:\n\n| Reference model        | File(s)                                                                                                                                   |\n| ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- |\n| LLM (dense)            | `lmdeploy/pytorch/models/qwen3.py`                                                                                                        |\n| LLM (MoE)              | `lmdeploy/pytorch/models/qwen3_moe.py`                                                                                                    |\n| VLM preprocessor       | `lmdeploy/vl/model/qwen3.py`                                                                                                              |\n| VLM (composite config) | `lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py` + `lmdeploy/pytorch/configurations/qwen3_omni.py` + `lmdeploy/vl/model/qwen3_omni.py` |\n\n______________________________________________________________________\n\n## Key Files Quick Reference\n\n| File                           
              | Purpose                                                         |\n| -------------------------------------------- | --------------------------------------------------------------- |\n| `lmdeploy/pytorch/models/<model>.py`         | Attention, MLP, DecoderLayer, Model, ForCausalLM                |\n| `lmdeploy/pytorch/models/module_map.py`      | HF class name → LMDeploy class path mapping                     |\n| `lmdeploy/pytorch/configurations/<model>.py` | Config builder — only needed for non-standard/nested HF configs |\n| `lmdeploy/vl/model/<model>.py`               | VLM: image/video preprocessing *(VLM only)*                     |\n| `lmdeploy/vl/model/base.py`                  | `VisionModel` base class + `VISION_MODELS` registry             |\n| `lmdeploy/archs.py`                          | VLM: arch name → task mapping *(VLM only)*                      |\n| `lmdeploy/lite/apis/calibrate.py`            | Quantization: layer/norm/head mappings *(optional)*             |\n| `lmdeploy/lite/quantization/awq.py`          | Quantization: AWQ scale mappings *(optional)*                   |\n\n______________________________________________________________________\n\n## Step-by-Step: LLM (PyTorch Backend)\n\n### Step 1 — Create the PyTorch model file\n\n**File:** `lmdeploy/pytorch/models/<model_name>.py`\n\nImplement the following class hierarchy (innermost → outermost):\n\n1. **`<Model>Attention`** — QKV projection, rotary embedding, attention forward\n2. **`<Model>MLP`** — gate-up linear, activation, down projection\n3. **`<Model>DecoderLayer`** — wraps Attention + MLP with layer norms and residual connections\n4. **`<Model>Model`** — embedding table, all decoder layers, final norm, rotary embedding\n5. 
**`<Model>ForCausalLM`** — top-level class; inherits `nn.Module`, `DeployModelMixinV1`, `CudaGraphMixin`\n\n**Required imports:**\n\n```python\nimport torch\nimport torch.nn as nn\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul,\n                                  build_rotary_embedding_from_config)\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear,\n                                         build_o_proj, build_qkv_proj)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\nfrom .patch import add_prefix\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n```\n\n**Attention skeleton:**\n\n```python\nclass MyModelAttention(nn.Module):\n    def __init__(self, config, dtype=None, device=None, prefix=''):\n        super().__init__()\n        self.qkv_proj = build_qkv_proj(\n            config.hidden_size,\n            num_q_heads=config.num_attention_heads,\n            num_kv_heads=config.num_key_value_heads,\n            head_size=config.hidden_size // config.num_attention_heads,\n            bias=False,\n            dtype=dtype, device=device, prefix=add_prefix('qkv_proj', prefix))\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n        self.attn_fwd = Attention(\n            config.num_attention_heads,\n            config.hidden_size // config.num_attention_heads,\n            num_kv_heads=config.num_key_value_heads)\n        self.o_proj = build_o_proj(\n            config.num_attention_heads,\n            config.hidden_size // config.num_attention_heads,\n            config.hidden_size,\n            bias=False,\n            dtype=dtype, device=device, prefix=add_prefix('o_proj', prefix))\n\n    def forward(self, hidden_states, rotary_pos_emb, past_key_value, attn_metadata):\n        qkv_states = self.qkv_proj(hidden_states)\n        # split q, k, 
v; apply rotary; call attn_fwd; project output\n        ...\n```\n\n**MLP skeleton:**\n\n```python\nclass MyModelMLP(nn.Module):\n    def __init__(self, config, dtype=None, device=None, prefix=''):\n        super().__init__()\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size, config.intermediate_size,\n            bias=False, dtype=dtype, device=device,\n            prefix=add_prefix('gate_up_proj', prefix))\n        self.down_proj = build_down_linear(\n            config.intermediate_size, config.hidden_size,\n            bias=False, dtype=dtype, device=device,\n            prefix=add_prefix('down_proj', prefix))\n        self.act_fn = SiluAndMul()\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_up_proj(x)))\n```\n\n**ForCausalLM skeleton (critical fields):**\n\n```python\nclass MyModelForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    # Maps packed param name → list of original HF param suffixes\n    packed_modules_mapping = {\n        'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],\n        'gate_up_proj': ['gate_proj', 'up_proj'],\n    }\n\n    def __init__(self, config, ctx_mgr=None, prefix='', **kwargs):\n        super().__init__()\n        self.model = MyModelModel(config, ...)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.ctx_mgr = ctx_mgr\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def forward(self, input_ids, inputs_embeds, past_key_values, attn_metadata, **kwargs):\n        hidden_states = self.model(input_ids, inputs_embeds, past_key_values, attn_metadata)\n        return hidden_states\n\n    def get_logits(self, hidden_states):\n        return self.lm_head(hidden_states)\n\n    # prepare_inputs_for_generation and load_weights: copy from qwen3.py,\n    # update stacked_params_mapping to match this model's HF weight 
names.\n```\n\n______________________________________________________________________\n\n### Step 2 — Register in `module_map.py`\n\n**File:** `lmdeploy/pytorch/models/module_map.py`\n\nAdd an entry to `MODULE_MAP`. The key is the exact HF architecture class name from `config.json`'s `architectures` field:\n\n```python\nMODULE_MAP.update({\n    'MyModelForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.my_model.MyModelForCausalLM',\n})\n```\n\n______________________________________________________________________\n\n### Step 3 — Add config builder (if needed)\n\n**File:** `lmdeploy/pytorch/configurations/<model_name>.py`\n\n**Skip this step** for models with a standard flat HF config — `DefaultModelConfigBuilder` handles them automatically.\n\nOnly create this file when the HF config is non-standard, e.g.:\n\n- Nested config (e.g., Qwen3-Omni has `hf_config.thinker_config.text_config`)\n- Unusual `model_type` that needs special field remapping\n\n```python\nfrom .builder import AutoModelConfigBuilder, DefaultModelConfigBuilder\n\nclass MyModelConfigBuilder(AutoModelConfigBuilder):\n    @classmethod\n    def condition(cls, hf_config):\n        # Must match model_type from config.json exactly\n        return hf_config.model_type == 'my_model'\n\n    @classmethod\n    def build(cls, hf_config, model_path=None, **kwargs):\n        # Extract the text config if nested; patch fields if needed\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n        cfg.hf_config = hf_config  # keep full config for VLM layers\n        return cfg\n```\n\nAuto-discovery: subclasses of `AutoModelConfigBuilder` register themselves automatically via `__init_subclass__()` — no import needed elsewhere.\n\n______________________________________________________________________\n\n### Step 4 — Add quantization mappings (optional)\n\nOnly needed to support AWQ/SmoothQuant calibration for this model family.\n\n**`lmdeploy/lite/apis/calibrate.py`** — add layer name, norm 
name, and head name mappings for the new model type.\n\n**`lmdeploy/lite/quantization/awq.py`** — add entries to `NORM_FCS_MAP` (norm → downstream FC layers) and `FC_FCS_MAP` (FC → downstream FC layers) following the existing patterns.\n\n______________________________________________________________________\n\n## Step-by-Step: VLM (additional steps)\n\n### Step 5 — Create the VL preprocessor\n\n**File:** `lmdeploy/vl/model/<model_name>.py`\n\nThe preprocessor handles image/video decoding and feature extraction before the LLM backbone sees the input.\n\n```python\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\n@VISION_MODELS.register_module()\nclass MyModelVLModel(VisionModel):\n    # Must match hf_config.architectures exactly (can be a list for variants)\n    _arch = ['MyModelForConditionalGeneration']\n\n    def build_preprocessor(self):\n        \"\"\"Load the vision processor from the model checkpoint.\"\"\"\n        from transformers import AutoProcessor\n        self.processor = AutoProcessor.from_pretrained(self.model_path)\n        # Set image_token_id to the token ID of the image placeholder\n        # (used by the engine to know where to inject image features)\n        tokenizer = self.processor.tokenizer\n        self.image_token = '<image>'  # model-specific placeholder token\n        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)\n\n    # preprocess and to_pytorch: copy from vl/model/qwen3.py and adapt\n    # image token handling (image_token, image_token_id, image_tokens count).\n```\n\nKey points:\n\n- `collect_images()`, `proc_messages()`, `to_pytorch_aux()` are all provided by `VisionModel` — do not re-implement them.\n- `image_tokens` tells the engine how many token slots to reserve for each image.\n- Auto-registered via `@VISION_MODELS.register_module()` when the module is imported. 
**Add an explicit import** in `lmdeploy/vl/model/builder.py` alongside the existing imports so the decorator runs at startup:\n\n```python\nfrom .my_model import MyModelVLModel  # noqa F401\n```\n\n______________________________________________________________________\n\n### Step 6 — Register VLM arch in `archs.py`\n\n**File:** `lmdeploy/archs.py`\n\nAdd the architecture name to the `supported_archs` set inside `check_vl_llm()` so the engine routes the model through the VLM code path:\n\n```python\n# lmdeploy/archs.py — inside check_vl_llm()\nsupported_archs = set([\n    ...\n    'MyModelForConditionalGeneration',  # add this line\n])\n```\n\n______________________________________________________________________\n\n## Checklist\n\n**LLM (PyTorch backend):**\n\n- [ ] `pytorch/models/<model>.py` — all 5 classes implemented (`Attention`, `MLP`, `DecoderLayer`, `Model`, `ForCausalLM`)\n- [ ] `module_map.py` — HF architecture class name registered\n- [ ] `packed_modules_mapping` matches HF parameter naming scheme\n- [ ] `stacked_params_mapping` in `load_weights()` has correct shard indices\n- [ ] `pytorch/configurations/<model>.py` — added only if HF config is non-standard\n- [ ] Weights load cleanly from HF checkpoint (no missing/unexpected key errors)\n\n**VLM (additional):**\n\n- [ ] `vl/model/<model>.py` — `build_preprocessor`, `preprocess`, `to_pytorch` implemented\n- [ ] `_arch` matches `config.json` `architectures[0]` exactly\n- [ ] `image_token_id` correctly resolved from the tokenizer\n- [ ] `image_tokens` count is correct for the image resolution/encoding scheme\n- [ ] `vl/model/builder.py` — explicit import added for new model\n- [ ] `archs.py` entry added\n\n**Quantization (optional):**\n\n- [ ] `calibrate.py` — layer/norm/head name mappings added\n- [ ] `awq.py` — `NORM_FCS_MAP` / `FC_FCS_MAP` entries added\n\n______________________________________________________________________\n\n## Common Pitfalls\n\n1. 
**Weight name mismatches** — `packed_modules_mapping` keys must match HF param name suffixes exactly. Check actual HF weight names with `list(model.state_dict().keys())[:20]` before coding.\n2. **Wrong shard index order** — `stacked_params_mapping` for QKV must follow Q→0, K→1, V→2. Wrong order silently produces bad outputs.\n3. **Wrong `_arch`** — must match `hf_config.architectures[0]` literally (e.g., `'Qwen3VLForConditionalGeneration'`, not `'Qwen3VL'`).\n4. **`image_token_id` is None** — causes the engine to silently skip image feature injection. Always verify with `tokenizer.convert_tokens_to_ids(image_token)` returning a real token ID.\n5. **Missing `role='preprocess'` append** — `to_pytorch_aux()` searches messages for exactly `role='preprocess'`; if `preprocess()` does not append it, inference will fail with a confusing error.\n6. **Config builder `condition()` mismatch** — `model_type` in `condition()` must match the exact string in `config.json`, not a display name or alias.\n7. **MoE routing** — MoE models need `num_experts`, `num_experts_per_tok`, and a TopK gating mechanism in the MLP. Reference `qwen3_moe.py` for the pattern.\n8. **CUDA graph + dynamic control flow** — models with data-dependent branching (e.g., conditional expert dispatch) may break CUDA graph capture. 
Use `_no_cudagraph` guards in `CudaGraphMixin` if needed.\n\n______________________________________________________________________\n\n## Verification\n\n**LLM basic test:**\n\n```bash\npython -m lmdeploy.pytorch.chat <model_path> --backend pytorch\n```\n\n**VLM basic test:**\n\n```python\nfrom lmdeploy import pipeline\npipe = pipeline('<model_path>')\nresult = pipe(('Describe this image.', 'path/to/image.jpg'))\nprint(result.text)\n```\n\n**Unit tests:**\n\n```bash\npytest tests/test_lmdeploy/test_vl/     # VLM tests\npytest tests/test_lmdeploy/             # all unit tests\n```\n\n**Debug weight loading:**\n\n```bash\nLMDEPLOY_LOG_LEVEL=DEBUG python -m lmdeploy.pytorch.chat <model_path> --backend pytorch 2>&1 | grep -E \"load|weight|miss\"\n```\n"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "content": "## Contributing to LMDeploy\n\nWelcome to the LMDeploy community! All kinds of contributions are welcome, including but not limited to:\n\n**Fix bug**\n\nYou can directly post a Pull Request to fix typos in code or documents.\n\nTo fix a bug in the code implementation, follow these steps:\n\n1. If the modification involves significant changes, first create an issue that describes the error and how to trigger the bug. Other developers will discuss it with you and propose a proper solution.\n\n2. Post a pull request after fixing the bug and adding a corresponding unit test.\n\n**New Feature or Enhancement**\n\n1. If the modification involves significant changes, create an issue to discuss a proper design with our developers.\n2. Post a Pull Request after implementing the new feature or enhancement, and add a corresponding unit test.\n\n**Document**\n\nYou can directly post a pull request to fix documents. If you want to add a document, first create an issue to check whether it is appropriate.\n\n### Pull Request Workflow\n\nIf you're not familiar with Pull Requests, don't worry! The following guidance walks you through creating a Pull Request step by step. To learn more about the Pull Request development model, refer to the [official documents](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests).\n\n#### 1. 
Fork and clone\n\nIf you are posting a pull request for the first time, fork the LMDeploy repository by clicking the **Fork** button in the top right corner of the GitHub page. The forked repository will appear under your GitHub profile.\n\n<img src=\"https://user-images.githubusercontent.com/57566630/167305749-43c7f4e9-449b-4e98-ade5-0c9276d5c9ce.png\" width=\"1200\">\n\nThen, clone the repository locally:\n\n```shell\ngit clone git@github.com:{username}/lmdeploy.git\n```\n\nAfter that, add the official repository as the upstream remote:\n\n```bash\ngit remote add upstream git@github.com:InternLM/lmdeploy.git\n```\n\nCheck whether the remote repository has been added successfully with `git remote -v`:\n\n```bash\norigin\tgit@github.com:{username}/lmdeploy.git (fetch)\norigin\tgit@github.com:{username}/lmdeploy.git (push)\nupstream\tgit@github.com:InternLM/lmdeploy.git (fetch)\nupstream\tgit@github.com:InternLM/lmdeploy.git (push)\n```\n\n> Here's a brief introduction to origin and upstream. When we use \"git clone\", an \"origin\" remote is created by default, pointing to the repository we cloned from. \"upstream\" is a remote we add ourselves to point to the official repository. If you don't like the name \"upstream\", you can choose any name you wish. Usually, we push code to \"origin\". If the pushed code conflicts with the latest code in the official repository (\"upstream\"), pull the latest code from upstream to resolve the conflicts, and then push to \"origin\" again. The posted Pull Request will be updated automatically.\n\n#### 2. Configure pre-commit\n\nYou should configure [pre-commit](https://pre-commit.com/#intro) in your local development environment to make sure your code style matches that of LMDeploy. 
**Note**: Run the following commands from the lmdeploy repository directory.\n\n```shell\npip install -U pre-commit\npre-commit install\n```\n\nThen run the following command to check that pre-commit is configured successfully and to install the hooks defined in `.pre-commit-config.yaml`.\n\n```shell\npre-commit run --all-files\n```\n\n<img src=\"https://user-images.githubusercontent.com/57566630/173660750-3df20a63-cb66-4d33-a986-1f643f1d8aaf.png\" width=\"1200\">\n\n<img src=\"https://user-images.githubusercontent.com/57566630/202368856-0465a90d-8fce-4345-918e-67b8b9c82614.png\" width=\"1200\">\n\nIf the installation process is interrupted, you can rerun `pre-commit run ...` to resume the installation.\n\nIf the code does not conform to the code style specification, pre-commit will raise a warning and fix some of the errors automatically.\n\n<img src=\"https://user-images.githubusercontent.com/57566630/202369176-67642454-0025-4023-a095-263529107aa3.png\" width=\"1200\">\n\nTo commit code while bypassing the pre-commit hook, use the `--no-verify` option (**only for temporary commits**).\n\n```shell\ngit commit -m \"xxx\" --no-verify\n```\n\n#### 3. Create a development branch\n\nAfter configuring pre-commit, create a branch based on the main branch to develop the new feature or fix the bug. The recommended branch name format is `username/pr_name`:\n\n```shell\ngit checkout -b yhc/refactor_contributing_doc\n```\n\nDuring subsequent development, if the main branch of your local repository falls behind the main branch of \"upstream\", pull from upstream to synchronize, and then run the branch-creation command above:\n\n```shell\ngit pull upstream main\n```\n\n#### 4. Commit the code and pass the unit test\n\n- LMDeploy uses mypy for static type checking to improve code robustness, so add Type Hints to your code and make sure it passes the mypy check. 
If you are not familiar with Type Hints, refer to [this tutorial](https://docs.python.org/3/library/typing.html).\n\n- The committed code should pass the unit tests:\n\n  ```shell\n  # Pass all unit tests\n  pytest tests\n\n  # Pass the unit test of runner\n  pytest tests/test_runner/test_runner.py\n  ```\n\n  If the unit tests fail due to missing dependencies, install the dependencies by following the [guidance](#unit-test).\n\n- If documents are modified or added, check the rendering result by following the [guidance](#document-rendering).\n\n#### 5. Push the code to remote\n\nPush the local commits to the remote after they pass the unit tests and pre-commit checks. Adding the `-u` option associates the local branch with the remote branch.\n\n```shell\ngit push -u origin {branch_name}\n```\n\nThis lets you push code with a plain `git push` next time, without specifying a branch or the remote repository.\n\n#### 6. Create a Pull Request\n\n(1) Create a pull request in GitHub's Pull Request interface.\n\n<img src=\"https://user-images.githubusercontent.com/57566630/201533288-516f7ac4-0b14-4dc8-afbd-912475c368b5.png\" width=\"1200\">\n\n(2) Modify the PR description according to the guidelines so that other developers can better understand your changes.\n\n<img src=\"https://user-images.githubusercontent.com/57566630/202242953-c91a18ff-e388-4ff9-8591-5fae0ead6c1e.png\" width=\"1200\">\n\nFind more details about the Pull Request description in the [pull request guidelines](#pr-specs).\n\n**Note**\n\n(a) The Pull Request description should state the reason for, the content of, and the impact of the change, and should be linked to the relevant Issue (see the [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue)).\n\n(b) If this is your first contribution, please sign the CLA.\n\n<img 
src=\"https://user-images.githubusercontent.com/57566630/167307569-a794b967-6e28-4eac-a942-00deb657815f.png\" width=\"1200\">\n\n(c) Check whether the Pull Request passes the CI.\n\n<img src=\"https://user-images.githubusercontent.com/57566630/167307490-f9ebf9fa-63c0-4d83-8ba1-081ea169eb3a.png\" width=\"1200\">\n\nLMDeploy runs unit tests for the posted Pull Request on different platforms (Linux, Windows, Mac) and with different versions of Python, PyTorch, and CUDA to make sure the code is correct. Click `Details` in the image above to see the specific test information, then modify the code as needed.\n\n(3) If the Pull Request passes the CI, wait for review from other developers. Modify the code based on the reviewers' comments and repeat steps [4](#4-commit-the-code-and-pass-the-unit-test)-[5](#5-push-the-code-to-remote) until all reviewers approve it. We will then merge it as soon as possible.\n\n<img src=\"https://user-images.githubusercontent.com/57566630/202145400-cc2cd8c4-10b0-472f-ba37-07e6f50acc67.png\" width=\"1200\">\n\n#### 7. Resolve conflicts\n\nIf your local branch conflicts with the latest main branch of \"upstream\", you'll need to resolve the conflicts. There are two ways to do this:\n\n```shell\ngit fetch --all --prune\ngit rebase upstream/main\n```\n\nor\n\n```shell\ngit fetch --all --prune\ngit merge upstream/main\n```\n\nIf you are comfortable handling conflicts, use `rebase`, as it keeps your commit history tidy. If you are not familiar with `rebase`, use `merge` instead.\n\n### Guidance\n\n#### Document rendering\n\nIf documents are modified or added, check the rendering result. 
Install the dependencies and run the following commands to render the documents and check the results:\n\n```shell\npip install -r requirements/docs.txt\ncd docs/zh_cn/\n# or docs/en\nmake html\n# then open ./_build/html/index.html to check the result\n```\n\n### Code style\n\n#### Python\n\nWe adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.\n\nWe use the following tools for linting and formatting:\n\n- [flake8](https://github.com/PyCQA/flake8): A wrapper around several linter tools.\n- [isort](https://github.com/timothycrosley/isort): A Python utility to sort imports.\n- [yapf](https://github.com/google/yapf): A formatter for Python files.\n- [codespell](https://github.com/codespell-project/codespell): A Python utility to fix common misspellings in text files.\n- [mdformat](https://github.com/executablebooks/mdformat): An opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files.\n- [docformatter](https://github.com/myint/docformatter): A formatter for docstrings.\n\nWe use a [pre-commit hook](https://pre-commit.com/) that runs automatically on every commit: it checks and formats code with `flake8`, `yapf`, `isort`, `trailing whitespaces` and `markdown files`,\nfixes `end-of-files`, `double-quoted-strings`, `python-encoding-pragma` and `mixed-line-ending`, and sorts `requirements.txt`.\nThe pre-commit configuration is stored in [.pre-commit-config](../.pre-commit-config.yaml).\n\n#### C++ and CUDA\n\nThe clang-format config is stored in [.clang-format](../.clang-format). We recommend clang-format version **11**. Please do not use older or newer versions, as they may format code differently and cause the [lint](https://github.com/InternLM/lmdeploy/blob/main/.github/workflows/lint.yml#L25) check to fail.\n\n### PR Specs\n\n1. Use the [pre-commit](https://pre-commit.com) hook to avoid code style issues\n\n2. Each short-lived branch should correspond to exactly one PR\n\n3. 
Keep each PR focused on one specific change. Avoid large PRs\n\n   - Bad: Support Faster R-CNN\n   - Acceptable: Add a box head to Faster R-CNN\n   - Good: Add a parameter to box head to support custom conv-layer number\n\n4. Provide clear and meaningful commit messages\n\n5. Provide a clear and meaningful PR description\n\n   - The title should state the task. The general format is: \\[Prefix\\] Short description of the PR (Suffix)\n   - Prefix: new feature \\[Feature\\], bug fix \\[Fix\\], documentation \\[Docs\\], work in progress \\[WIP\\] (will not be reviewed for now)\n   - Summarize the main changes, results and impact on other modules in the description\n   - Associate related issues and pull requests with a milestone\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/1-bug-report.yml",
    "content": "name: 🐞 Bug report\ndescription: Create a report to help us reproduce and fix the bug\ntitle: \"[Bug] \"\nlabels: ['Bug']\n\nbody:\n- type: checkboxes\n  attributes:\n    label: Checklist\n    options:\n    - label: 1. I have searched related issues but cannot get the expected help.\n    - label: 2. The bug has not been fixed in the latest version.\n    - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.\n- type: textarea\n  attributes:\n    label: Describe the bug\n    description: A clear and concise description of what the bug is.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Reproduction\n    description: |\n      1. What command or script did you run?\n    placeholder: |\n      A placeholder for the command.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Environment\n    description: |\n      1. Please run `lmdeploy check_env` to collect necessary environment information and paste it here.\n      2. You may add additional information that may help locate the problem, such as\n         - Which **model** are you using?\n         - How you installed PyTorch \\[e.g., pip, conda, source\\]\n         - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)\n    placeholder: Environment here.\n    render: Shell\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Error traceback\n    description: |\n      If applicable, paste the error traceback here.\n    placeholder: Logs and traceback here.\n    render: Shell\n- type: markdown\n  attributes:\n    value: >\n     If you have already identified the reason, you can provide the information here. 
If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!\n\n     Thanks for your bug report. We appreciate it a lot.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/2-feature-request.yml",
    "content": "name: 🚀 Feature request\ndescription: Suggest an idea for this project\ntitle: \"[Feature] \"\n\nbody:\n- type: markdown\n  attributes:\n    value: |\n      We strongly appreciate you creating a PR to implement this feature [here](https://github.com/InternLM/lmdeploy/pulls)!\n      If you need our help, please fill in as much of the following form as you're able to.\n\n      **The less clear the description, the longer it will take to solve it.**\n- type: textarea\n  attributes:\n    label: Motivation\n    description: |\n      A clear and concise description of the motivation of the feature.\n      Ex1. It is inconvenient when \\[....\\].\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Related resources\n    description: |\n      If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.\n- type: textarea\n  attributes:\n    label: Additional context\n    description: |\n      Add any other context or screenshots about the feature request here.\n      If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/3-documentation.yml",
    "content": "name: 📚 Documentation\ndescription: Report an issue related to the documentation.\nlabels: \"kind/doc,status/unconfirmed\"\ntitle: \"[Docs] \"\n\nbody:\n- type: textarea\n  attributes:\n    label: 📚 The doc issue\n    description: >\n      A clear and concise description of the issue.\n  validations:\n    required: true\n\n- type: textarea\n  attributes:\n    label: Suggest a potential alternative/fix\n    description: >\n      Tell us how we could improve the documentation in this regard.\n- type: markdown\n  attributes:\n    value: >\n      Thanks for contributing 🎉!\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it receive feedback more easily. If you do not understand some items, don't worry: just make the pull request and seek help from the maintainers.\n\n## Motivation\n\nPlease describe the motivation of this PR and the goal you want to achieve through this PR.\n\n## Modification\n\nPlease briefly describe what modification is made in this PR.\n\n## BC-breaking (Optional)\n\nDoes the modification introduce changes that break the backward-compatibility of the downstream repositories?\nIf so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.\n\n## Use cases (Optional)\n\nIf this PR introduces a new feature, it is better to list some use cases here and update the documentation.\n\n## Checklist\n\n1. Pre-commit or other linting tools are used to fix the potential lint issues.\n2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.\n3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.\n4. The documentation has been modified accordingly, such as docstrings or example tutorials.\n"
  },
  {
    "path": ".github/release.yml",
    "content": "changelog:\n  categories:\n    - title: 🚀 Features\n      labels:\n        - feature\n        - enhancement\n    - title: 💥 Improvements\n      labels:\n        - improvement\n    - title: 🐞 Bug fixes\n      labels:\n        - bug\n        - Bug:P0\n        - Bug:P1\n        - Bug:P2\n        - Bug:P3\n    - title: 📚 Documentation\n      labels:\n        - documentation\n    - title: 🌐 Other\n      labels:\n        - '*'\n      exclude:\n        labels:\n          - feature\n          - enhancement\n          - improvement\n          - bug\n          - documentation\n          - Bug:P0\n          - Bug:P1\n          - Bug:P2\n          - Bug:P3\n"
  },
  {
    "path": ".github/scripts/action_tools.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport glob\nimport json\nimport logging\nimport os\nimport shutil\nimport subprocess\nimport time\nfrom collections import OrderedDict\nfrom typing import List\n\nimport fire\nimport pandas as pd\nfrom mmengine.config import Config\n\n\ndef run_cmd(cmd_lines: List[str], log_path: str, cwd: str = None):\n    \"\"\"Run a shell command and redirect its stdout and stderr to a log file.\n\n    Args:\n        cmd_lines (list[str]): A command split into multiple lines, joined with spaces before running.\n        log_path (str): Path to log file.\n        cwd (str): Working directory in which to run the command.\n\n    Returns:\n        int: The return code of the command.\n    \"\"\"\n    import platform\n\n    system = platform.system().lower()\n\n    if system == 'windows':\n        sep = r'`'\n    else:  # 'Linux', 'Darwin'\n        sep = '\\\\'\n    cmd_for_run = ' '.join(cmd_lines)\n    cmd_for_log = f' {sep}\\n'.join(cmd_lines) + '\\n'\n    with open(log_path, 'w', encoding='utf-8') as file_handler:\n        file_handler.write(f'Command: {cmd_for_log}\\n')\n        file_handler.flush()\n        process_res = subprocess.Popen(cmd_for_run, shell=True, cwd=cwd, stdout=file_handler, stderr=file_handler)\n        process_res.wait()\n        return_code = process_res.returncode\n\n    if return_code != 0:\n        logging.error(f'Got shell abnormal return code={return_code}')\n        with open(log_path, 'r') as f:\n            content = f.read()\n            logging.error(f'Log error message\\n{content}')\n    return return_code\n\n\ndef _append_summary(content):\n    summary_file = os.environ['GITHUB_STEP_SUMMARY']\n    with open(summary_file, 'a') as f:\n        f.write(content + '\\n')\n\n\ndef add_summary(csv_path: str):\n    \"\"\"Add a csv file to the GitHub step summary as a Markdown table.\n\n    Args:\n        csv_path (str): Input csv file.\n    \"\"\"\n    with open(csv_path, 'r') as fr:\n        lines = fr.readlines()\n        header = lines[0].strip().split(',')\n        n_col = len(header)\n        header = '|' + '|'.join(header) + '|'\n
        aligner = '|' + '|'.join([':-:'] * n_col) + '|'\n        _append_summary(header)\n        _append_summary(aligner)\n        for line in lines[1:]:\n            line = '|' + line.strip().replace(',', '|') + '|'\n            _append_summary(line)\n        _append_summary('\\n')\n\n\ndef evaluate(models: List[str],\n             datasets: List[str],\n             workspace: str,\n             evaluate_type: str,\n             max_num_workers: int = 8,\n             is_smoke: bool = False):\n    \"\"\"Evaluate models from lmdeploy using opencompass.\n\n    Args:\n        models: Input models. Each lowercased name must be defined in the config file.\n        datasets: Datasets to evaluate, written as `datasets = {datasets}` into the generated config.\n        workspace: Working directory.\n        evaluate_type: Selects the config template `.github/scripts/eval_{evaluate_type}_config.py`.\n        max_num_workers: Passed to opencompass as `--max-num-workers`.\n        is_smoke: If True, limit each dataset to its first 50 samples.\n    \"\"\"\n    os.makedirs(workspace, exist_ok=True)\n    output_csv = os.path.join(workspace, f'results_{evaluate_type}.csv')\n    if os.path.exists(output_csv):\n        os.remove(output_csv)\n    num_model = len(models)\n    for idx, ori_model in enumerate(models):\n        print()\n        print(50 * '==')\n        print(f'Start evaluating {idx+1}/{num_model} {ori_model} ...')\n        model = ori_model.lower()\n\n        lmdeploy_dir = os.path.abspath(os.environ['LMDEPLOY_DIR'])\n        config_path = os.path.join(lmdeploy_dir, f'.github/scripts/eval_{evaluate_type}_config.py')\n        config_path_new = os.path.join(lmdeploy_dir, 'eval_lmdeploy.py')\n        if os.path.exists(config_path_new):\n            os.remove(config_path_new)\n        shutil.copy(config_path, config_path_new)\n\n        cfg = Config.fromfile(config_path_new)\n        if not hasattr(cfg, model):\n            logging.error(f'Model {model} not in configuration file')\n            continue\n\n        model_cfg = cfg[model]\n        logging.info(f'Start evaluating {model} ...\\n\\n{model_cfg}\\n\\n')\n\n        with open(config_path_new, 'a') as f:\n            f.write(f'\\ndatasets = {datasets}\\n')\n            if is_smoke:\n                f.write('\\nfor d in datasets:\\n')\n                f.write(\"    if d['reader_cfg'] is not None:\\n\")\n
                f.write(\"        d['reader_cfg']['test_range'] = '[0:50]'\\n\")\n            if model.startswith('hf'):\n                f.write(f'\\nmodels = [*{model}]\\n')\n            else:\n                f.write(f'\\nmodels = [{model}]\\n')\n\n        work_dir = os.path.join(workspace, model)\n        cmd_eval = [\n            f'opencompass {config_path_new} -w {work_dir} --reuse --max-num-workers {max_num_workers} --dump-res-length'  # noqa: E501\n        ]\n        eval_log = os.path.join(workspace, f'eval.{ori_model}.txt')\n        start_time = time.time()\n        ret = run_cmd(cmd_eval, log_path=eval_log, cwd=lmdeploy_dir)\n        end_time = time.time()\n        task_duration_seconds = round(end_time - start_time, 2)\n        logging.info(f'task_duration_seconds: {task_duration_seconds}\\n')\n        if ret != 0:\n            continue\n        csv_files = glob.glob(f'{work_dir}/*/summary/summary_*.csv')\n\n        if len(csv_files) < 1:\n            logging.error(f'Did not find summary csv file {csv_files}')\n            continue\n        else:\n            csv_file = max(csv_files, key=os.path.getctime)\n        # print csv_txt to screen\n        csv_txt = csv_file.replace('.csv', '.txt')\n        if os.path.exists(csv_txt):\n            with open(csv_txt, 'r') as f:\n                print(f.read())\n\n        # parse evaluation results from csv file\n        model_results = OrderedDict()\n        with open(csv_file, 'r') as f:\n            lines = f.readlines()\n            for line in lines[1:]:\n                row = line.strip().split(',')\n                row = [_.strip() for _ in row]\n                if row[-1] != '-':\n                    model_results[row[0]] = row[-1]\n        crows_pairs_json = glob.glob(os.path.join(work_dir, '*/results/*/crows_pairs.json'), recursive=True)\n        if len(crows_pairs_json) == 1:\n            with open(crows_pairs_json[0], 'r') as f:\n                acc = json.load(f)['accuracy']\n                acc = 
f'{float(acc):.2f}'  # noqa E231\n                model_results['crows_pairs'] = acc\n        logging.info(f'\\n{model}\\n{model_results}')\n        dataset_names = list(model_results.keys())\n\n        row = ','.join([model, str(task_duration_seconds)] + [model_results[_] for _ in dataset_names])\n\n        if not os.path.exists(output_csv):\n            with open(output_csv, 'w') as f:\n                header = ','.join(['Model', 'task_duration_secs'] + dataset_names)\n                f.write(header + '\\n')\n                f.write(row + '\\n')\n        else:\n            with open(output_csv, 'a') as f:\n                f.write(row + '\\n')\n\n    # write to github action summary\n    _append_summary('## Evaluation Results')\n    if os.path.exists(output_csv):\n        add_summary(output_csv)\n\n\ndef create_model_links(src_dir: str, dst_dir: str):\n    \"\"\"Create softlinks for models.\"\"\"\n    paths = glob.glob(os.path.join(src_dir, '*'))\n    model_paths = [os.path.abspath(p) for p in paths if os.path.isdir(p)]\n    os.makedirs(dst_dir, exist_ok=True)\n    for src in model_paths:\n        _, model_name = os.path.split(src)\n        dst = os.path.join(dst_dir, model_name)\n        if not os.path.exists(dst):\n            os.symlink(src, dst)\n        else:\n            logging.warning(f'Model_path exists: {dst}')\n\n\ndef generate_benchmark_report(report_path: str):\n    # write to github action summary\n    _append_summary('## Benchmark Results Start')\n    subfolders = [f.path for f in os.scandir(report_path) if f.is_dir()]\n    for dir_path in subfolders:\n        second_subfolders = [f.path for f in sorted(os.scandir(dir_path), key=lambda x: x.name) if f.is_dir()]\n        for sec_dir_path in second_subfolders:\n            model = sec_dir_path.replace(report_path + '/', '')\n            print('-' * 25, model, '-' * 25)\n            _append_summary('-' * 25 + model + '-' * 25 + '\\n')\n\n            benchmark_subfolders = [\n                f.path for 
f in sorted(os.scandir(sec_dir_path), key=lambda x: x.name) if f.is_dir()\n            ]\n            for backend_subfolder in benchmark_subfolders:\n                benchmark_type = backend_subfolder.replace(sec_dir_path + '/', '')\n                print('*' * 10, benchmark_type, '*' * 10)\n                _append_summary('-' * 10 + benchmark_type + '-' * 10 + '\\n')\n                merged_csv_path = os.path.join(backend_subfolder, 'summary.csv')\n                csv_files = glob.glob(os.path.join(backend_subfolder, '*.csv'))\n                average_csv_path = os.path.join(backend_subfolder, 'average.csv')\n                if merged_csv_path in csv_files:\n                    csv_files.remove(merged_csv_path)\n                if average_csv_path in csv_files:\n                    csv_files.remove(average_csv_path)\n                merged_df = pd.DataFrame()\n\n                if len(csv_files) > 0:\n                    for f in csv_files:\n                        df = pd.read_csv(f)\n                        merged_df = pd.concat([merged_df, df], ignore_index=True)\n                    if 'throughput' in backend_subfolder or 'longtext' in backend_subfolder:\n                        merged_df = merged_df.sort_values(by=merged_df.columns[1])\n\n                        grouped_df = merged_df.groupby(merged_df.columns[1])\n                    else:\n                        merged_df = merged_df.sort_values(by=merged_df.columns[0])\n\n                        grouped_df = merged_df.groupby(merged_df.columns[0])\n                    if 'generation' not in backend_subfolder:\n                        average_values = grouped_df.pipe((lambda group: {\n                            'mean': group.mean(numeric_only=True).round(decimals=3)\n                        }))['mean']\n                        average_values.to_csv(average_csv_path, index=True)\n                        avg_df = pd.read_csv(average_csv_path)\n                        merged_df = pd.concat([merged_df, 
avg_df], ignore_index=True)\n                        add_summary(average_csv_path)\n                    merged_df.to_csv(merged_csv_path, index=False)\n                    if 'generation' in backend_subfolder:\n                        add_summary(merged_csv_path)\n\n    _append_summary('## Benchmark Results End')\n\n\ndef generate_csv_from_profile_result(file_path: str, out_path: str):\n    with open(file_path, 'r') as f:\n        data = f.readlines()\n        data = [json.loads(line) for line in data]\n\n        data_csv = []\n        for item in data:\n            row = [\n                item.get('request_rate'),\n                item.get('completed'),\n                round(item.get('completed') / item.get('duration'), 3),\n                round(item.get('median_ttft_ms'), 3),\n                round(item.get('output_throughput'), 3)\n            ]\n            data_csv.append(row)\n        import csv\n        with open(out_path, 'w', newline='') as f:\n            writer = csv.writer(f)\n            writer.writerow(['request_rate', 'completed', 'RPM', 'median_ttft_ms', 'output_throughput'])\n            writer.writerows(data_csv)\n\n\ndef generate_output_for_evaluation(result_dir: str):\n    # find latest result\n    latest_csv_file = find_csv_files(result_dir)\n    df = pd.read_csv(latest_csv_file)\n    transposed_df = df.T\n    head_part = transposed_df.head(4)\n    tail_part = transposed_df[4:]\n    sorted_tail_part = tail_part.sort_index()\n    transposed_df = pd.concat([head_part, sorted_tail_part])\n    transposed_df.to_csv('transposed_output.csv', header=False, index=True)\n    # output to github action summary\n    add_summary('transposed_output.csv')\n\n\ndef find_csv_files(directory):\n    csv_files = []\n    for root, dirs, files in os.walk(directory):\n        for file in files:\n            if file.endswith('.csv') and file.startswith('summary'):\n                csv_files.append(os.path.join(root, file))\n\n    csv_files_with_time = {f: 
os.path.getctime(f) for f in csv_files}\n    sorted_csv_files = sorted(csv_files_with_time.items(), key=lambda x: x[1])\n    latest_csv_file = sorted_csv_files[-1][0]\n    return latest_csv_file\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
  {
    "path": ".github/scripts/check_lmdeploy.py",
    "content": "# Copyright (c) MegFlow. All rights reserved.\nimport glob\nimport os\n\nimport fire\n\n\ndef check_module_init(root: str):\n    \"\"\"Check if a module has __init__.py file.\"\"\"\n    all_files = glob.glob(os.path.join(root, '**/*'), recursive=True)\n    not_exist = []\n    for d in all_files:\n        if not os.path.isdir(d):\n            continue\n        if '__pycache__' in d:\n            continue\n        elif d.startswith('lmdeploy/bin'):\n            continue\n        elif d.startswith('lmdeploy/lib'):\n            continue\n        elif d.startswith('lmdeploy/monitoring'):\n            continue\n        elif d.startswith('lmdeploy/serve/turbomind/triton_models'):\n            continue\n        elif d.startswith('lmdeploy/serve/turbomind/triton_python_backend'):\n            continue\n        init_file = os.path.join(d, '__init__.py')\n        if not os.path.exists(init_file):\n            not_exist.append(init_file)\n\n    assert len(not_exist) == 0, f'Missing files: {not_exist}'\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
  {
    "path": ".github/scripts/doc_link_checker.py",
    "content": "# Copyright (c) MegFlow. All rights reserved.\n# /bin/python3\n\nimport argparse\nimport os\nimport re\n\n\ndef make_parser():\n    parser = argparse.ArgumentParser('Doc link checker')\n    parser.add_argument('--http', default=False, type=bool, help='check http or not')\n    parser.add_argument('--target', default='./docs', type=str, help='the directory or file to check')\n    return parser\n\n\npattern = re.compile(r'\\[.*?\\]\\(.*?\\)')\n\n\ndef analyze_doc(home, path):\n    print('analyze {}'.format(path))\n    problem_list = []\n    code_block = 0\n    with open(path) as f:\n        lines = f.readlines()\n        for line in lines:\n            line = line.strip()\n            if line.startswith('```'):\n                code_block = 1 - code_block\n\n            if code_block > 0:\n                continue\n\n            if '[' in line and ']' in line and '(' in line and ')' in line:\n                matches = pattern.findall(line)\n                for item in matches:\n                    # skip  ![]()\n                    if item.find('[') == item.find(']') - 1:\n                        continue\n\n                    # process the case [text()]()\n                    offset = item.find('](')\n                    if offset == -1:\n                        continue\n                    item = item[offset:]\n                    start = item.find('(')\n                    end = item.find(')')\n                    ref = item[start + 1:end]\n\n                    if ref.startswith('http') or ref.startswith('#'):\n                        continue\n                    if '.md#' in ref:\n                        # drop the anchor and check only the file path\n                        ref = ref[:ref.find('#')]\n                    fullpath = os.path.join(home, ref)\n                    if not os.path.exists(fullpath):\n                        problem_list.append(ref)\n            else:\n                continue\n    if len(problem_list) > 0:\n        print(f'{path}:')\n        for item in problem_list:\n            print(f'\\t {item}')\n
        print('\\n')\n        raise Exception('found link error')\n\n\ndef traverse(target):\n    if os.path.isfile(target):\n        analyze_doc(os.path.dirname(target), target)\n        return\n    for home, dirs, files in os.walk(target):\n        for filename in files:\n            if filename.endswith('.md'):\n                path = os.path.join(home, filename)\n                if not os.path.islink(path):\n                    analyze_doc(home, path)\n\n\nif __name__ == '__main__':\n    args = make_parser().parse_args()\n    traverse(args.target)\n"
  },
  {
    "path": ".github/scripts/eval_base_config.py",
    "content": "from copy import deepcopy\n\nfrom mmengine.config import read_base\nfrom opencompass.models import TurboMindModel\n\nwith read_base():\n    # choose a list of datasets\n    from opencompass.configs.datasets.ARC_c.ARC_c_few_shot_ppl import ARC_c_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.bbh.bbh_gen_98fba6 import bbh_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.ceval.ceval_ppl import ceval_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.cmmlu.cmmlu_ppl_041cbf import cmmlu_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.crowspairs.crowspairs_ppl import crowspairs_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.drop.drop_gen_a2697c import drop_datasets  # noqa: F401, E501\n    # Corebench v1.7\n    from opencompass.configs.datasets.GaokaoBench.GaokaoBench_no_subjective_gen_d21e37 import \\\n        GaokaoBench_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.gpqa.gpqa_few_shot_ppl_4b5a83 import gpqa_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.gsm8k.gsm8k_gen_17d0dc import gsm8k_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.hellaswag.hellaswag_10shot_ppl_59c85e import \\\n        hellaswag_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.humaneval.internal_humaneval_gen_ce6b06 import \\\n        humaneval_datasets as humaneval_v2_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.humaneval.internal_humaneval_gen_d2537e import \\\n        humaneval_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.math.math_4shot_base_gen_43d5b6 import math_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.MathBench.mathbench_2024_few_shot_mixed_4a3fd4 import \\\n        mathbench_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.mbpp.sanitized_mbpp_gen_742f0c import sanitized_mbpp_datasets  # noqa: F401, E501\n  
  from opencompass.configs.datasets.mmlu.mmlu_ppl_ac766d import mmlu_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.mmlu_pro.mmlu_pro_few_shot_gen_bfaf90 import mmlu_pro_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.nq.nq_open_1shot_gen_20a989 import nq_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.race.race_few_shot_ppl import race_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.SuperGLUE_BoolQ.SuperGLUE_BoolQ_few_shot_ppl import \\\n        BoolQ_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.TheoremQA.TheoremQA_5shot_gen_6f0af8 import TheoremQA_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.triviaqa.triviaqa_wiki_1shot_gen_20a989 import \\\n        triviaqa_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.wikibench.wikibench_few_shot_ppl_c23d79 import \\\n        wikibench_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.winogrande.winogrande_5shot_ll_252f01 import \\\n        winogrande_datasets  # noqa: F401, E501\n    # Summary Groups\n    from opencompass.configs.summarizers.groups.cmmlu import cmmlu_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.GaokaoBench import GaokaoBench_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.mathbench_v1_2024 import \\\n        mathbench_2024_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.mmlu import mmlu_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups  # noqa: F401, E501\n\n    # read models\nrace_datasets = [race_datasets[1]]\nmmlu_datasets = [\n    x for x in mmlu_datasets if x['abbr'].replace('lukaemon_mmlu_', '') in [\n        'business_ethics', 'clinical_knowledge', 'college_medicine', 'global_facts', 'human_aging', 'management',\n        'marketing', 'medical_genetics', 
'miscellaneous', 'nutrition', 'professional_accounting',\n        'professional_medicine', 'virology'\n    ]\n]\n\nsummarizer = dict(\n    dataset_abbrs=[\n        ['race-high', 'accuracy'],\n        ['ARC-c', 'accuracy'],\n        ['BoolQ', 'accuracy'],\n        ['mmlu_pro', 'naive_average'],\n        ['GPQA_diamond', 'accuracy'],\n        ['cmmlu', 'naive_average'],\n        ['mmlu', 'naive_average'],\n        ['drop', 'accuracy'],\n        ['bbh', 'naive_average'],\n        ['math', 'accuracy'],\n        ['openai_humaneval', 'humaneval_pass@1'],\n        ['openai_humaneval_v2', 'humaneval_pass@1'],\n        ['sanitized_mbpp', 'score'],\n        ['wikibench-wiki-single_choice_cncircular', 'perf_4'],\n        ['gsm8k', 'accuracy'],\n        ['GaokaoBench', 'weighted_average'],\n        ['triviaqa_wiki_1shot', 'score'],\n        ['nq_open_1shot', 'score'],\n        ['winogrande', 'accuracy'],\n        ['hellaswag', 'accuracy'],\n        ['TheoremQA', 'score'],\n        '###### MathBench-A: Application Part ######',\n        'college',\n        'high',\n        'middle',\n        'primary',\n        'arithmetic',\n        'mathbench-a (average)',\n        '###### MathBench-T: Theory Part ######',\n        'college_knowledge',\n        'high_knowledge',\n        'middle_knowledge',\n        'primary_knowledge',\n        'mathbench-t (average)',\n        '###### Overall: Average between MathBench-A and MathBench-T ######',\n        'Overall',\n        '',\n        'mmlu',\n        'mmlu-stem',\n        'mmlu-social-science',\n        'mmlu-humanities',\n        'mmlu-other',\n        'cmmlu',\n        'cmmlu-stem',\n        'cmmlu-social-science',\n        'cmmlu-humanities',\n        'cmmlu-other',\n        'cmmlu-china-specific',\n        'mmlu_pro',\n        'mmlu_pro_biology',\n        'mmlu_pro_business',\n        'mmlu_pro_chemistry',\n        'mmlu_pro_computer_science',\n        'mmlu_pro_economics',\n        'mmlu_pro_engineering',\n        
'mmlu_pro_health',\n        'mmlu_pro_history',\n        'mmlu_pro_law',\n        'mmlu_pro_math',\n        'mmlu_pro_philosophy',\n        'mmlu_pro_physics',\n        'mmlu_pro_psychology',\n        'mmlu_pro_other',\n    ],\n    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),\n)\n\nbase_model = dict(\n    type=TurboMindModel,\n    engine_config=dict(session_len=7168, tp=1),\n    gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=1024),\n    max_seq_len=7168,\n    max_out_len=1024,\n    batch_size=32,\n    run_cfg=dict(num_gpus=1),\n)\n\nturbomind_qwen2_5_1_5b = deepcopy(base_model)\nturbomind_qwen2_5_1_5b['path'] = 'Qwen/Qwen2.5-1.5B'\nturbomind_qwen2_5_1_5b['abbr'] = 'turbomind_qwen2_5_1_5b'\nturbomind_qwen2_5_7b = deepcopy(base_model)\nturbomind_qwen2_5_7b['path'] = 'Qwen/Qwen2.5-7B'\nturbomind_qwen2_5_7b['abbr'] = 'turbomind_qwen2_5_7b'\nturbomind_qwen2_5_32b = deepcopy(base_model)\nturbomind_qwen2_5_32b['path'] = 'Qwen/Qwen2.5-32B'\nturbomind_qwen2_5_32b['abbr'] = 'turbomind_qwen2_5_32b'\nturbomind_qwen2_5_32b['run_cfg']['num_gpus'] = 2\nturbomind_qwen2_5_32b['engine_config']['tp'] = 2\nturbomind_internlm2_5_7b = deepcopy(base_model)\nturbomind_internlm2_5_7b['path'] = 'internlm/internlm2_5-7b-chat'\nturbomind_internlm2_5_7b['abbr'] = 'turbomind_internlm2_5_7b'\nturbomind_glm_4_9b = deepcopy(base_model)\nturbomind_glm_4_9b['path'] = 'THUDM/glm-4-9b'\nturbomind_glm_4_9b['abbr'] = 'turbomind_glm_4_9b'\nturbomind_llama_3_70b = deepcopy(base_model)\nturbomind_llama_3_70b['path'] = 'meta-llama/Meta-Llama-3-70B'\nturbomind_llama_3_70b['abbr'] = 'turbomind_llama_3_70b'\nturbomind_llama_3_70b['run_cfg']['num_gpus'] = 4\nturbomind_llama_3_70b['engine_config']['tp'] = 4\nturbomind_llama_3_1_8b = deepcopy(base_model)\nturbomind_llama_3_1_8b['path'] = 'meta-llama/Llama-3.1-8B'\nturbomind_llama_3_1_8b['abbr'] = 'turbomind_llama_3_1_8b'\nturbomind_qwen3_0_6b_base = 
deepcopy(base_model)\nturbomind_qwen3_0_6b_base['path'] = 'Qwen/Qwen3-0.6B-Base'\nturbomind_qwen3_0_6b_base['abbr'] = 'turbomind_qwen3_0_6b_base'\nturbomind_qwen3_8b_base = deepcopy(base_model)\nturbomind_qwen3_8b_base['path'] = 'Qwen/Qwen3-8B-Base'\nturbomind_qwen3_8b_base['abbr'] = 'turbomind_qwen3_8b_base'\nturbomind_qwen3_30b_A3B_base = deepcopy(base_model)\nturbomind_qwen3_30b_A3B_base['path'] = 'Qwen/Qwen3-30B-A3B-Base'\nturbomind_qwen3_30b_A3B_base['abbr'] = 'turbomind_qwen3_30b_A3B_base'\nturbomind_qwen3_30b_A3B_base['run_cfg']['num_gpus'] = 2\nturbomind_qwen3_30b_A3B_base['engine_config']['tp'] = 2\n\npytorch_qwen2_5_1_5b = deepcopy(base_model)\npytorch_qwen2_5_1_5b['path'] = 'Qwen/Qwen2.5-1.5B'\npytorch_qwen2_5_1_5b['abbr'] = 'pytorch_qwen2_5_1_5b'\npytorch_qwen2_5_7b = deepcopy(base_model)\npytorch_qwen2_5_7b['path'] = 'Qwen/Qwen2.5-7B'\npytorch_qwen2_5_7b['abbr'] = 'pytorch_qwen2_5_7b'\npytorch_qwen2_5_32b = deepcopy(base_model)\npytorch_qwen2_5_32b['path'] = 'Qwen/Qwen2.5-32B'\npytorch_qwen2_5_32b['abbr'] = 'pytorch_qwen2_5_32b'\npytorch_qwen2_5_32b['run_cfg']['num_gpus'] = 2\npytorch_qwen2_5_32b['engine_config']['tp'] = 2\npytorch_internlm2_5_7b = deepcopy(base_model)\npytorch_internlm2_5_7b['path'] = 'internlm/internlm2_5-7b-chat'\npytorch_internlm2_5_7b['abbr'] = 'pytorch_internlm2_5_7b'\npytorch_gemma_2_9b = deepcopy(base_model)\npytorch_gemma_2_9b['path'] = 'google/gemma-2-9b'\npytorch_gemma_2_9b['abbr'] = 'pytorch_gemma_2_9b'\npytorch_llama_3_70b = deepcopy(base_model)\npytorch_llama_3_70b['path'] = 'meta-llama/Meta-Llama-3-70B'\npytorch_llama_3_70b['abbr'] = 'pytorch_llama_3_70b'\npytorch_llama_3_70b['run_cfg']['num_gpus'] = 4\npytorch_llama_3_70b['engine_config']['tp'] = 4\npytorch_llama_3_1_8b = deepcopy(base_model)\npytorch_llama_3_1_8b['path'] = 'meta-llama/Llama-3.1-8B'\npytorch_llama_3_1_8b['abbr'] = 'pytorch_llama_3_1_8b'\npytorch_qwen3_0_6b_base = deepcopy(base_model)\npytorch_qwen3_0_6b_base['path'] = 
'Qwen/Qwen3-0.6B-Base'\npytorch_qwen3_0_6b_base['abbr'] = 'pytorch_qwen3_0_6b_base'\npytorch_qwen3_8b_base = deepcopy(base_model)\npytorch_qwen3_8b_base['path'] = 'Qwen/Qwen3-8B-Base'\npytorch_qwen3_8b_base['abbr'] = 'pytorch_qwen3_8b_base'\npytorch_qwen3_30b_A3B_base = deepcopy(base_model)\npytorch_qwen3_30b_A3B_base['path'] = 'Qwen/Qwen3-30B-A3B-Base'\npytorch_qwen3_30b_A3B_base['abbr'] = 'pytorch_qwen3_30b_A3B_base'\npytorch_qwen3_30b_A3B_base['run_cfg']['num_gpus'] = 2\npytorch_qwen3_30b_A3B_base['engine_config']['tp'] = 2\n\nfor model in [v for k, v in locals().items() if k.startswith('pytorch_')]:\n    model['backend'] = 'pytorch'\n"
  },
  {
    "path": ".github/scripts/eval_chat_config.py",
    "content": "from copy import deepcopy\n\nfrom mmengine.config import read_base\nfrom opencompass.models import TurboMindModelwithChatTemplate\nfrom opencompass.utils.text_postprocessors import extract_non_reasoning_content\n\nwith read_base():\n    # choose a list of datasets\n    from opencompass.configs.datasets.bbh.bbh_gen_5b92b0 import bbh_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.ceval.ceval_gen_2daf24 import ceval_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.cmmlu.cmmlu_gen_c13365 import cmmlu_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.crowspairs.crowspairs_gen_381af0 import crowspairs_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.GaokaoBench.GaokaoBench_no_subjective_gen_4c31db import \\\n        GaokaoBench_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.gpqa.gpqa_gen_4baadb import gpqa_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import \\\n        hellaswag_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.IFEval.IFEval_gen_3321a3 import ifeval_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.math.math_0shot_gen_393424 import math_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.mbpp.sanitized_mbpp_gen_a0fc46 import sanitized_mbpp_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.mmlu.mmlu_gen_4d595a import mmlu_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.mmlu_pro.mmlu_pro_0shot_cot_gen_08c1de import \\\n        mmlu_pro_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.nq.nq_open_1shot_gen_01cf41 import nq_datasets  # noqa: F401, E501\n    from 
opencompass.configs.datasets.race.race_gen_69ee4f import race_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.TheoremQA.TheoremQA_5shot_gen_6f0af8 import TheoremQA_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.triviaqa.triviaqa_wiki_1shot_gen_eaf81e import \\\n        triviaqa_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.winogrande.winogrande_5shot_gen_b36770 import \\\n        winogrande_datasets  # noqa: F401, E501\n    # read models\n    from opencompass.configs.models.baichuan.hf_baichuan2_7b_chat import \\\n        models as hf_baichuan2_chat_7b  # noqa: F401, E501\n    from opencompass.configs.models.gemma.hf_gemma2_9b_it import models as hf_gemma2_9b_it  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.hf_internlm2_5_7b_chat import \\\n        models as hf_internlm2_5_7b_chat  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.hf_internlm2_5_20b_chat import \\\n        models as hf_internlm2_5_20b_chat  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.hf_internlm2_chat_7b import \\\n        models as hf_internlm2_chat_7b  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.hf_internlm2_chat_20b import \\\n        models as hf_internlm2_chat_20b  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import \\\n        models as lmdeploy_internlm2_5_7b_chat  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_20b_chat import \\\n        models as lmdeploy_internlm2_5_20b_chat  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_7b import \\\n        models as lmdeploy_internlm2_chat_7b  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_20b import \\\n        models as lmdeploy_internlm2_chat_20b  # noqa: F401, E501\n    from 
opencompass.configs.models.hf_internlm.lmdeploy_internlm3_8b_instruct import \\\n        models as lmdeploy_internlm3_8b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm_chat_7b import \\\n        models as lmdeploy_internlm_chat_7b  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.hf_llama2_7b_chat import models as hf_llama2_chat_7b  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.hf_llama3_1_8b_instruct import \\\n        models as hf_llama3_1_8b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.hf_llama3_8b_instruct import \\\n        models as hf_llama_3_8b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama2_7b_chat import \\\n        models as lmdeploy_llama2_7b_chat  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b_instruct import \\\n        models as lmdeploy_llama3_1_8b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b_instruct import \\\n        models as lmdeploy_llama3_8b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.mistral.hf_mistral_7b_instruct_v0_1 import \\\n        models as hf_mistral_chat_7b  # noqa: F401, E501\n    from opencompass.configs.models.mistral.hf_mixtral_8x7b_instruct_v0_1 import \\\n        models as hf_mixtral_chat_8x7b  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import \\\n        models as lmdeploy_qwen2_5_7b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_32b_instruct import \\\n        models as lmdeploy_qwen2_5_32b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.qwen.hf_qwen1_5_7b_chat import models as hf_qwen1_5_chat_7b  # noqa: F401, E501\n    from opencompass.configs.models.qwen.hf_qwen1_5_moe_a2_7b_chat import \\\n        models as hf_qwen1_5_moe_a2_7b_chat  # noqa: F401, 
E501\n    from opencompass.configs.models.qwen.hf_qwen2_7b_instruct import models as hf_qwen2_7b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.qwen.hf_qwen_7b_chat import models as hf_qwen_chat_7b  # noqa: F401, E501\n    from opencompass.configs.models.qwen.lmdeploy_qwen1_5_7b_chat import \\\n        models as lmdeploy_qwen1_5_7b_chat  # noqa: F401, E501\n    from opencompass.configs.models.qwen.lmdeploy_qwen2_7b_instruct import \\\n        models as lmdeploy_qwen2_7b_instruct  # noqa: F401, E501\n    from opencompass.configs.models.qwen.lmdeploy_qwen_7b_chat import \\\n        models as lmdeploy_qwen_7b_chat  # noqa: F401, E501\n    # Summary Groups\n    from opencompass.configs.summarizers.groups.bbh import bbh_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.cmmlu import cmmlu_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.ds1000 import ds1000_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.GaokaoBench import GaokaoBench_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.humanevalx import humanevalx_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.mathbench_v1_2024 import \\\n        mathbench_2024_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.mmlu import mmlu_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.scicode import scicode_summary_groups  # noqa: F401, E501\n    from opencompass.configs.summarizers.groups.teval import teval_summary_groups  # noqa: F401, E501\n\nllama2_meta_template = dict(round=[\n    dict(role='HUMAN', begin='[INST] ', end=' [/INST]'),\n    dict(role='BOT', begin='', end='', generate=True),\n],\n                            eos_token_id=2)\n\nMAX_SESSION_LEN = 
2048\nMAX_NEW_TOKENS = 1024\n\n# ===== Configs for internlm/internlm2-chat-7b =====\nturbomind_internlm2_chat_7b = deepcopy(*lmdeploy_internlm2_chat_7b)\nturbomind_internlm2_chat_7b_4bits = deepcopy(*lmdeploy_internlm2_chat_7b)\nturbomind_internlm2_chat_7b_kvint4 = deepcopy(*lmdeploy_internlm2_chat_7b)\nturbomind_internlm2_chat_7b_kvint8 = deepcopy(*lmdeploy_internlm2_chat_7b)\npytorch_internlm2_chat_7b = deepcopy(*lmdeploy_internlm2_chat_7b)\n\n# ===== Configs for internlm/internlm2_5_7b_chat =====\nturbomind_internlm2_5_7b_chat = deepcopy(*lmdeploy_internlm2_5_7b_chat)\nturbomind_internlm2_5_7b_chat_4bits = deepcopy(*lmdeploy_internlm2_5_7b_chat)\nturbomind_internlm2_5_7b_chat_kvint4 = deepcopy(*lmdeploy_internlm2_5_7b_chat)\nturbomind_internlm2_5_7b_chat_kvint8 = deepcopy(*lmdeploy_internlm2_5_7b_chat)\npytorch_internlm2_5_7b_chat = deepcopy(*lmdeploy_internlm2_5_7b_chat)\npytorch_internlm2_5_7b_chat_w8a8 = deepcopy(*lmdeploy_internlm2_5_7b_chat)\nturbomind_internlm2_5_7b_chat_batch1 = deepcopy(*lmdeploy_internlm2_5_7b_chat)\nturbomind_internlm2_5_7b_chat_batch1_4bits = deepcopy(*lmdeploy_internlm2_5_7b_chat)\n\nturbomind_internlm3_8b_instruct = deepcopy(*lmdeploy_internlm3_8b_instruct)\nturbomind_internlm3_8b_instruct_4bits = deepcopy(*lmdeploy_internlm3_8b_instruct)\nturbomind_internlm3_8b_instruct_kvint4 = deepcopy(*lmdeploy_internlm3_8b_instruct)\nturbomind_internlm3_8b_instruct_kvint8 = deepcopy(*lmdeploy_internlm3_8b_instruct)\npytorch_internlm3_8b_instruct = deepcopy(*lmdeploy_internlm3_8b_instruct)\npytorch_internlm3_8b_instruct_w8a8 = deepcopy(*lmdeploy_internlm3_8b_instruct)\n\n# ===== Configs for internlm/internlm2_5_20b_chat =====\nturbomind_internlm2_5_20b_chat = deepcopy(*lmdeploy_internlm2_5_20b_chat)\nturbomind_internlm2_5_20b_chat_4bits = deepcopy(*lmdeploy_internlm2_5_20b_chat)\nturbomind_internlm2_5_20b_chat_kvint4 = deepcopy(*lmdeploy_internlm2_5_20b_chat)\nturbomind_internlm2_5_20b_chat_kvint8 = 
deepcopy(*lmdeploy_internlm2_5_20b_chat)\npytorch_internlm2_5_20b_chat = deepcopy(*lmdeploy_internlm2_5_20b_chat)\n\n# ===== Configs for internlm/internlm2_chat_20b =====\nturbomind_internlm2_chat_20b = deepcopy(*lmdeploy_internlm2_chat_20b)\nturbomind_internlm2_chat_20b_4bits = deepcopy(*lmdeploy_internlm2_chat_20b)\nturbomind_internlm2_chat_20b_kvint4 = deepcopy(*lmdeploy_internlm2_chat_20b)\nturbomind_internlm2_chat_20b_kvint8 = deepcopy(*lmdeploy_internlm2_chat_20b)\npytorch_internlm2_chat_20b = deepcopy(*lmdeploy_internlm2_chat_20b)\n\n# ===== Configs for Qwen/Qwen1.5-7B-Chat =====\nturbomind_qwen1_5_7b_chat = deepcopy(*lmdeploy_qwen1_5_7b_chat)\nturbomind_qwen1_5_7b_chat_4bits = deepcopy(*lmdeploy_qwen1_5_7b_chat)\nturbomind_qwen1_5_7b_chat_kvint4 = deepcopy(*lmdeploy_qwen1_5_7b_chat)\nturbomind_qwen1_5_7b_chat_kvint8 = deepcopy(*lmdeploy_qwen1_5_7b_chat)\npytorch_qwen1_5_7b_chat = deepcopy(*lmdeploy_qwen1_5_7b_chat)\n\n# ===== Configs for Qwen/Qwen-7B-Chat =====\nturbomind_qwen_7b_chat = deepcopy(*lmdeploy_qwen_7b_chat)\nturbomind_qwen_7b_chat_4bits = deepcopy(*lmdeploy_qwen_7b_chat)\nturbomind_qwen_7b_chat_kvint4 = deepcopy(*lmdeploy_qwen_7b_chat)\nturbomind_qwen_7b_chat_kvint8 = deepcopy(*lmdeploy_qwen_7b_chat)\npytorch_qwen_7b_chat = deepcopy(*lmdeploy_qwen_7b_chat)\n\n# ===== Configs for meta-llama/Meta-Llama-3-8B-Instruct =====\nturbomind_llama3_8b_instruct = deepcopy(*lmdeploy_llama3_8b_instruct)\nturbomind_llama3_8b_instruct_4bits = deepcopy(*lmdeploy_llama3_8b_instruct)\nturbomind_llama3_8b_instruct_kvint4 = deepcopy(*lmdeploy_llama3_8b_instruct)\nturbomind_llama3_8b_instruct_kvint8 = deepcopy(*lmdeploy_llama3_8b_instruct)\npytorch_llama3_8b_instruct = deepcopy(*lmdeploy_llama3_8b_instruct)\n\n# ===== Configs for meta-llama/Meta-Llama-3.1-8B-Instruct =====\nturbomind_llama3_1_8b_instruct = deepcopy(*lmdeploy_llama3_1_8b_instruct)\nturbomind_llama3_1_8b_instruct['path'] = 'meta-llama/Meta-Llama-3-1-8B-Instruct'\nturbomind_llama3_1_8b_instruct_4bits = 
deepcopy(turbomind_llama3_1_8b_instruct)\nturbomind_llama3_1_8b_instruct_kvint4 = deepcopy(turbomind_llama3_1_8b_instruct)\nturbomind_llama3_1_8b_instruct_kvint8 = deepcopy(turbomind_llama3_1_8b_instruct)\npytorch_llama3_1_8b_instruct = deepcopy(turbomind_llama3_1_8b_instruct)\npytorch_llama3_1_8b_instruct_w8a8 = deepcopy(turbomind_llama3_1_8b_instruct)\n\n# ===== Configs for Qwen/Qwen2-7B-Instruct =====\nturbomind_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct)\nturbomind_qwen2_7b_instruct_4bits = deepcopy(*lmdeploy_qwen2_7b_instruct)\nturbomind_qwen2_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_7b_instruct)\nturbomind_qwen2_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_7b_instruct)\npytorch_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct)\npytorch_qwen2_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_7b_instruct)\n\n# ===== Configs for Qwen/Qwen2.5-7B-Instruct =====\nturbomind_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct)\nturbomind_qwen2_5_7b_instruct_4bits = deepcopy(*lmdeploy_qwen2_5_7b_instruct)\nturbomind_qwen2_5_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)\nturbomind_qwen2_5_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)\npytorch_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct)\npytorch_qwen2_5_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)\n\n# ===== Configs for Qwen/Qwen2.5-32B-Instruct =====\nturbomind_qwen2_5_32b_instruct = deepcopy(*lmdeploy_qwen2_5_32b_instruct)\nturbomind_qwen2_5_32b_instruct_4bits = deepcopy(*lmdeploy_qwen2_5_32b_instruct)\nturbomind_qwen2_5_32b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_5_32b_instruct)\nturbomind_qwen2_5_32b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_5_32b_instruct)\npytorch_qwen2_5_32b_instruct = deepcopy(*lmdeploy_qwen2_5_32b_instruct)\npytorch_qwen2_5_32b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_5_32b_instruct)\n\n# ===== Configs for meta-llama/Llama-2-7b-chat-hf =====\nturbomind_llama2_7b_chat = 
deepcopy(*lmdeploy_llama2_7b_chat)\nturbomind_llama2_7b_chat_4bits = deepcopy(*lmdeploy_llama2_7b_chat)\nturbomind_llama2_7b_chat_kvint4 = deepcopy(*lmdeploy_llama2_7b_chat)\nturbomind_llama2_7b_chat_kvint8 = deepcopy(*lmdeploy_llama2_7b_chat)\n\nbase_model = dict(type=TurboMindModelwithChatTemplate,\n                  engine_config=dict(session_len=32768, max_batch_size=256),\n                  gen_config=dict(top_k=1, temperature=1e-6, top_p=0.9, max_new_tokens=32768),\n                  max_seq_len=32768,\n                  max_out_len=32768,\n                  batch_size=500,\n                  pred_postprocessor=dict(type=extract_non_reasoning_content),\n                  run_cfg=dict(num_gpus=1))\n\nturbomind_qwen3_32b = deepcopy(base_model)\npytorch_qwen3_32b = deepcopy(base_model)\nturbomind_qwen3_32b_4bits = deepcopy(base_model)\nturbomind_qwen3_32b_kvint8 = deepcopy(base_model)\n\nturbomind_qwen3_30b_a3b = deepcopy(base_model)\npytorch_qwen3_30b_a3b = deepcopy(base_model)\nturbomind_qwen3_30b_a3b_4bits = deepcopy(base_model)\nturbomind_qwen3_30b_a3b_kvint8 = deepcopy(base_model)\nturbomind_qwen3_30b_a3b_fp8 = deepcopy(base_model)\npytorch_qwen3_30b_a3b_fp8 = deepcopy(base_model)\nturbomind_qwen3_30b_a3b_fp8['engine_config']['cache_max_entry_count'] = 0.6\n\nturbomind_qwen3_235b_a22b = deepcopy(base_model)\npytorch_qwen3_235b_a22b = deepcopy(base_model)\nturbomind_qwen3_235b_a22b_4bits = deepcopy(base_model)\nturbomind_qwen3_235b_a22b_kvint8 = deepcopy(base_model)\nturbomind_qwen3_235b_a22b_fp8 = deepcopy(base_model)\npytorch_qwen3_235b_a22b_fp8 = deepcopy(base_model)\n\n# update config for Qwen3-32B, Qwen3-30B-A3B, Qwen3-235B-A22B\nfor model in [\n        v for k, v in locals().items() if k.startswith('turbomind_qwen3_32b') or k.startswith('pytorch_qwen3_32b')\n]:\n    model['abbr'] = 'qwen3_32b_turbomind'\n    model['path'] = 'Qwen/Qwen3-32B'\n\nfor model in [\n        v for k, v in locals().items()\n        if k.startswith('turbomind_qwen3_30b_a3b') or 
k.startswith('pytorch_qwen3_30b_a3b')\n]:\n    model['abbr'] = 'qwen3_30b_a3b_turbomind'\n    model['path'] = 'Qwen/Qwen3-30B-A3B'\n\nfor model in [\n        v for k, v in locals().items()\n        if k.startswith('turbomind_qwen3_30b_a3b_fp8') or k.startswith('pytorch_qwen3_30b_a3b_fp8')\n]:\n    model['abbr'] = 'qwen3_30b_a3b_fp8_turbomind'\n    model['path'] = 'Qwen/Qwen3-30B-A3B-FP8'\n\nfor model in [\n        v for k, v in locals().items()\n        if k.startswith('turbomind_qwen3_235b_a22b') or k.startswith('pytorch_qwen3_235b_a22b')\n]:\n    model['abbr'] = 'qwen3_235b_a22b_turbomind'\n    model['path'] = 'Qwen/Qwen3-235B-A22B'\n\nfor model in [\n        v for k, v in locals().items()\n        if k.startswith('turbomind_qwen3_235b_a22b_fp8') or k.startswith('pytorch_qwen3_235b_a22b_fp8')\n]:\n    model['abbr'] = 'qwen3_235b_a22b_fp8_turbomind'\n    model['path'] = 'Qwen/Qwen3-235B-A22B-FP8'\n\n# update config for turbomind, w4a4, w8a8, kvint4, kvint8, pytorch models\nfor model in [v for k, v in locals().items() if k.startswith('turbomind_')]:\n    model['engine_config']['max_batch_size'] = 512\n    model['gen_config']['do_sample'] = False\n    model['batch_size'] = 1000\n\nfor model in [v for k, v in locals().items() if k.endswith('_4bits')]:\n    model['engine_config']['model_format'] = 'awq'\n    model['abbr'] = model['abbr'] + '_4bits'\n    model['path'] = model['path'] + '-inner-4bits'\n\nfor model in [v for k, v in locals().items() if k.endswith('_w8a8')]:\n    model['abbr'] = model['abbr'] + '_w8a8'\n    model['path'] = model['path'] + '-inner-w8a8'\n\nfor model in [v for k, v in locals().items() if k.endswith('_kvint4')]:\n    model['engine_config']['quant_policy'] = 4\n    model['abbr'] = model['abbr'] + '_kvint4'\n\nfor model in [v for k, v in locals().items() if k.endswith('_kvint8')]:\n    model['engine_config']['quant_policy'] = 8\n    model['abbr'] = model['abbr'] + '_kvint8'\n\nfor model in [v for k, v in locals().items() if 
k.startswith('pytorch_')]:\n    model['abbr'] = model['abbr'].replace('turbomind', 'pytorch')\n    model['backend'] = 'pytorch'\n    model['engine_config']['max_batch_size'] = 512\n    model['gen_config']['do_sample'] = False\n    model['batch_size'] = 1000\n\nfor model in [v for k, v in locals().items() if '_batch1' in k]:\n    model['abbr'] = model['abbr'] + '_batch1'\n    model['engine_config']['max_batch_size'] = 1\n    model['batch_size'] = 1\n\n# update config for Qwen3-32B, Qwen3-30B-A3B, Qwen3-235B-A22B\nfor model in [\n        v for k, v in locals().items() if k.startswith('turbomind_qwen3_32b') or k.startswith('pytorch_qwen3_32b')\n]:\n    model['run_cfg']['num_gpus'] = 2\n    model['engine_config']['tp'] = 2\n    model['engine_config']['max_batch_size'] = 1024\n    model['batch_size'] = 2048\n\nfor model in [\n        v for k, v in locals().items()\n        if k.startswith('turbomind_qwen3_30b_a3b') or k.startswith('pytorch_qwen3_30b_a3b')\n]:\n    model['run_cfg']['num_gpus'] = 2\n    model['engine_config']['tp'] = 2\n    model['engine_config']['max_batch_size'] = 1024\n    model['batch_size'] = 2048\n\nfor model in [\n        v for k, v in locals().items()\n        if k.startswith('turbomind_qwen3_235b_a22b') or k.startswith('pytorch_qwen3_235b_a22b')\n]:\n    model['run_cfg']['num_gpus'] = 8\n    model['engine_config']['tp'] = 8\n    model['engine_config']['max_batch_size'] = 1024\n    model['batch_size'] = 2048\n\nturbomind_qwen3_235b_a22b_fp8['engine_config']['cache_max_entry_count'] = 0.6\nturbomind_qwen3_235b_a22b_fp8['engine_config']['tp'] = 4\nturbomind_qwen3_235b_a22b_fp8['run_cfg']['num_gpus'] = 4\npytorch_qwen3_235b_a22b_fp8['engine_config']['tp'] = 4\npytorch_qwen3_235b_a22b_fp8['run_cfg']['num_gpus'] = 4\n\nbasic_pytorch_chat_tp1 = dict(type=TurboMindModelwithChatTemplate,\n                              engine_config=dict(session_len=MAX_SESSION_LEN, max_batch_size=512, tp=1),\n                              gen_config=dict(do_sample=False, 
max_new_tokens=MAX_NEW_TOKENS),\n                              max_out_len=MAX_NEW_TOKENS,\n                              max_seq_len=MAX_SESSION_LEN,\n                              batch_size=1000,\n                              run_cfg=dict(num_gpus=1))\n\n# ===== Configs for Qwen/Qwen1.5-MoE-A2.7B-Chat =====\npytorch_qwen1_5_moe_2_7b_chat = deepcopy(basic_pytorch_chat_tp1)\npytorch_qwen1_5_moe_2_7b_chat['abbr'] = 'pytorch_qwen1_5_moe_2_7b_chat'\npytorch_qwen1_5_moe_2_7b_chat['path'] = 'Qwen/Qwen1.5-MoE-A2.7B-Chat'\n\n# ===== Configs for google/gemma-2-9b-it =====\npytorch_gemma_2_9b_it = deepcopy(basic_pytorch_chat_tp1)\npytorch_gemma_2_9b_it['abbr'] = 'pytorch_gemma_2_9b_it'\npytorch_gemma_2_9b_it['path'] = 'google/gemma-2-9b-it'\n\n# ===== Configs for google/gemma-2-27b-it =====\npytorch_gemma_2_27b_it = deepcopy(basic_pytorch_chat_tp1)\npytorch_gemma_2_27b_it['abbr'] = 'pytorch_gemma_2_27b_it'\npytorch_gemma_2_27b_it['path'] = 'google/gemma-2-27b-it'\npytorch_gemma_2_27b_it['run_cfg']['num_gpus'] = 2\npytorch_gemma_2_27b_it['engine_config']['tp'] = 2\n\nrace_datasets = [race_datasets[1]]\n\n# Summarizer\nsummarizer = dict(\n    dataset_abbrs=[\n        ['race-high', 'accuracy'],\n        ['ARC-c', 'accuracy'],\n        ['BoolQ', 'accuracy'],\n        ['mmlu_pro', 'naive_average'],\n        ['drop', 'accuracy'],\n        ['bbh', 'naive_average'],\n        ['GPQA_diamond', 'accuracy'],\n        ['math', 'accuracy'],\n        ['wikibench-wiki-single_choice_cncircular', 'perf_4'],\n        ['openai_humaneval', 'humaneval_pass@1'],\n        ['sanitized_mbpp', 'score'],\n        ['cmmlu', 'naive_average'],\n        ['mmlu', 'naive_average'],\n        ['teval', 'naive_average'],\n        ['SciCode', 'accuracy'],\n        ['SciCode', 'sub_accuracy'],\n        ['humanevalx', 'naive_average'],\n        ['ds1000', 'naive_average'],\n        ['IFEval', 'Prompt-level-strict-accuracy'],\n        ['gsm8k', 'accuracy'],\n        ['GaokaoBench', 'weighted_average'],\n        ['triviaqa_wiki_1shot', 'score'],\n        ['nq_open_1shot', 'score'],\n        ['hellaswag', 'accuracy'],\n        ['TheoremQA', 'score'],\n        '###### MathBench-A: Application Part ######',\n        'college',\n        'high',\n        'middle',\n        'primary',\n        'arithmetic',\n        'mathbench-a (average)',\n        '###### MathBench-T: Theory Part ######',\n        'college_knowledge',\n        'high_knowledge',\n        'middle_knowledge',\n        'primary_knowledge',\n        'mathbench-t (average)',\n        '###### Overall: Average between MathBench-A and MathBench-T ######',\n        'Overall',\n        '',\n        'mmlu',\n        'mmlu-stem',\n        'mmlu-social-science',\n        'mmlu-humanities',\n        'mmlu-other',\n        '',\n        'cmmlu',\n        'cmmlu-stem',\n        'cmmlu-social-science',\n        'cmmlu-humanities',\n        'cmmlu-other',\n        'cmmlu-china-specific',\n        '',\n        'mmlu_pro',\n        'mmlu_pro_biology',\n        'mmlu_pro_business',\n        'mmlu_pro_chemistry',\n        'mmlu_pro_computer_science',\n        'mmlu_pro_economics',\n        'mmlu_pro_engineering',\n        'mmlu_pro_health',\n        'mmlu_pro_history',\n        'mmlu_pro_law',\n        'mmlu_pro_math',\n        'mmlu_pro_philosophy',\n        'mmlu_pro_physics',\n        'mmlu_pro_psychology',\n        'mmlu_pro_other',\n        '',\n        'humanevalx-python',\n        'humanevalx-cpp',\n        'humanevalx-go',\n        'humanevalx-java',\n        'humanevalx-js',\n        '',\n        'ds1000_Pandas',\n        'ds1000_Numpy',\n        'ds1000_Tensorflow',\n        'ds1000_Scipy',\n        'ds1000_Sklearn',\n        'ds1000_Pytorch',\n        'ds1000_Matplotlib',\n    ],\n    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),\n)\n"
  },
  {
    "path": ".github/scripts/eval_regression_base_models.py",
    "content": "from copy import deepcopy\n\nfrom mmengine.config import read_base\n\nwith read_base():\n    # choose a list of datasets\n    from opencompass.configs.datasets.gpqa.gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.gsm8k.gsm8k_gen_17d0dc import gsm8k_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.race.race_ppl import race_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.winogrande.winogrande_5shot_ll_252f01 import \\\n        winogrande_datasets  # noqa: F401, E501\n    # read hf models - chat models\n    from opencompass.configs.models.chatglm.lmdeploy_glm4_9b import models as lmdeploy_glm4_9b_model  # noqa: F401, E501\n    from opencompass.configs.models.deepseek.lmdeploy_deepseek_7b_base import \\\n        models as lmdeploy_deepseek_7b_base_model  # noqa: F401, E501\n    from opencompass.configs.models.deepseek.lmdeploy_deepseek_67b_base import \\\n        models as lmdeploy_deepseek_67b_base_model  # noqa: F401, E501\n    from opencompass.configs.models.deepseek.lmdeploy_deepseek_v2 import lmdeploy_deepseek_v2_model  # noqa: F401, E501\n    from opencompass.configs.models.gemma.lmdeploy_gemma_9b import models as pytorch_gemma_9b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_1_8b import \\\n        models as lmdeploy_internlm2_1_8b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b import \\\n        models as lmdeploy_internlm2_5_7b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_20b import \\\n        models as lmdeploy_internlm2_20b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_base_7b import \\\n        models as lmdeploy_internlm2_base_7b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b import \\\n        
models as lmdeploy_llama3_1_8b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b import \\\n        models as lmdeploy_llama3_8b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_70b import \\\n        models as lmdeploy_llama3_70b_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_1_5b import \\\n        models as lmdeploy_qwen2_5_1_5b_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b import \\\n        models as lmdeploy_qwen2_5_7b_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_32b import \\\n        models as lmdeploy_qwen2_5_32b_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_72b import \\\n        models as lmdeploy_qwen2_5_72b_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen.lmdeploy_qwen2_1_5b import \\\n        models as lmdeploy_qwen2_1_5b_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen.lmdeploy_qwen2_7b import models as lmdeploy_qwen2_7b_model  # noqa: F401, E501\n    from opencompass.configs.models.yi.lmdeploy_yi_1_5_9b import models as lmdeploy_yi_1_5_9b_model  # noqa: F401, E501\n\n    from .volc import infer as volc_infer  # noqa: F401, E501\n\nrace_datasets = [race_datasets[1]]\ndatasets = sum([v for k, v in locals().items() if k.endswith('_datasets')], [])\n\npytorch_glm4_9b_model = deepcopy(lmdeploy_glm4_9b_model)\npytorch_deepseek_7b_base_model = deepcopy(lmdeploy_deepseek_7b_base_model)\npytorch_deepseek_67b_base_model = deepcopy(lmdeploy_deepseek_67b_base_model)\npytorch_deepseek_v2_model = deepcopy(lmdeploy_deepseek_v2_model)\npytorch_internlm2_5_7b_model = deepcopy(lmdeploy_internlm2_5_7b_model)\npytorch_internlm2_20b_model = deepcopy(lmdeploy_internlm2_20b_model)\npytorch_internlm2_base_7b_model = deepcopy(lmdeploy_internlm2_base_7b_model)\npytorch_llama3_1_8b_model = 
deepcopy(lmdeploy_llama3_1_8b_model)\npytorch_llama3_70b_model = deepcopy(lmdeploy_llama3_70b_model)\npytorch_qwen2_5_1_5b_model = deepcopy(lmdeploy_qwen2_5_1_5b_model)\npytorch_qwen2_5_72b_model = deepcopy(lmdeploy_qwen2_5_72b_model)\npytorch_qwen2_7b_model = deepcopy(lmdeploy_qwen2_7b_model)\npytorch_yi_1_5_9b_model = deepcopy(lmdeploy_yi_1_5_9b_model)\npytorch_deepseek_v2_model['engine_config']['cache_max_entry_count'] = 0.6\n\nlmdeploy_glm4_9b_model_native = deepcopy(lmdeploy_glm4_9b_model)\nlmdeploy_deepseek_7b_base_model_native = deepcopy(lmdeploy_deepseek_7b_base_model)\nlmdeploy_deepseek_67b_base_model_native = deepcopy(lmdeploy_deepseek_67b_base_model)\nlmdeploy_deepseek_v2_model_native = deepcopy(lmdeploy_deepseek_v2_model)\nlmdeploy_internlm2_5_7b_model_native = deepcopy(lmdeploy_internlm2_5_7b_model)\nlmdeploy_internlm2_20b_model_native = deepcopy(lmdeploy_internlm2_20b_model)\nlmdeploy_internlm2_base_7b_model_native = deepcopy(lmdeploy_internlm2_base_7b_model)\nlmdeploy_llama3_1_8b_model_native = deepcopy(lmdeploy_llama3_1_8b_model)\nlmdeploy_llama3_70b_model_native = deepcopy(lmdeploy_llama3_70b_model)\nlmdeploy_qwen2_5_1_5b_model_native = deepcopy(lmdeploy_qwen2_5_1_5b_model)\nlmdeploy_qwen2_5_72b_model_native = deepcopy(lmdeploy_qwen2_5_72b_model)\nlmdeploy_qwen2_7b_model_native = deepcopy(lmdeploy_qwen2_7b_model)\nlmdeploy_yi_1_5_9b_model_native = deepcopy(lmdeploy_yi_1_5_9b_model)\n\nfor model in [v for k, v in locals().items() if k.startswith('lmdeploy_') or k.startswith('pytorch_')]:\n    for m in model:\n        m['engine_config']['max_batch_size'] = 512\n        m['gen_config']['do_sample'] = False\n        m['batch_size'] = 5000\n\nfor model in [v for k, v in locals().items() if k.startswith('lmdeploy_')]:\n    for m in model:\n        m['backend'] = 'turbomind'\n\nfor model in [v for k, v in locals().items() if k.startswith('pytorch_')]:\n    for m in model:\n        m['abbr'] = m['abbr'].replace('turbomind', 'pytorch').replace('lmdeploy', 
'pytorch')\n        m['backend'] = 'pytorch'\n\nfor model in [v for k, v in locals().items() if k.endswith('_native')]:\n    for m in model:\n        m['abbr'] = m['abbr'] + '_native'\n        m['engine_config']['communicator'] = 'native'\n\n# models = sum([v for k, v in locals().items() if  k.startswith('lmdeploy_') or k.startswith('pytorch_')], [])\n# models = sorted(models, key=lambda x: x['run_cfg']['num_gpus'])\n\nsummarizer = dict(\n    dataset_abbrs=[\n        ['gsm8k', 'accuracy'],\n        ['GPQA_diamond', 'accuracy'],\n        ['race-high', 'accuracy'],\n        ['winogrande', 'accuracy'],\n    ],\n    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),\n)\n"
  },
  {
    "path": ".github/scripts/eval_regression_chat_models.py",
    "content": "from copy import deepcopy\n\nfrom mmengine.config import read_base\n\nwith read_base():\n    # choose a list of datasets\n    from opencompass.configs.datasets.gpqa.gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.IFEval.IFEval_gen_353ae7 import ifeval_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.math.math_0shot_gen_11c4b5 import math_datasets  # noqa: F401, E501\n    # read hf models - chat models\n    from opencompass.configs.models.chatglm.lmdeploy_glm4_9b_chat import \\\n        models as lmdeploy_glm4_9b_chat_model  # noqa: F401, E501\n    from opencompass.configs.models.deepseek.lmdeploy_deepseek_r1_distill_qwen_32b import \\\n        models as lmdeploy_deepseek_r1_distill_qwen_32b_model  # noqa: F401, E501\n    from opencompass.configs.models.deepseek.lmdeploy_deepseek_v2_5_1210 import \\\n        models as lmdeploy_deepseek_v2_5_1210_model  # noqa: F401, E501\n    from opencompass.configs.models.deepseek.lmdeploy_deepseek_v2_lite import \\\n        models as lmdeploy_deepseek_v2_lite_model  # noqa: F401, E501\n    from opencompass.configs.models.gemma.lmdeploy_gemma_9b_it import \\\n        models as pytorch_gemma_9b_it_model  # noqa: F401, E501\n    from opencompass.configs.models.gemma.lmdeploy_gemma_27b_it import \\\n        models as pytorch_gemma_27b_it_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import \\\n        models as lmdeploy_internlm2_5_7b_chat_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_20b_chat import \\\n        models as lmdeploy_internlm2_5_20b_chat_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_1_8b import \\\n        models as lmdeploy_internlm2_chat_1_8b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_1_8b_sft import 
\\\n        models as lmdeploy_internlm2_chat_1_8b_sft_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_7b import \\\n        models as lmdeploy_internlm2_chat_7b_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_chat_7b_sft import \\\n        models as lmdeploy_internlm2_chat_7b_sft_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_internlm.lmdeploy_internlm3_8b_instruct import \\\n        models as lmdeploy_internlm3_8b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama2_7b_chat import \\\n        models as lmdeploy_llama2_7b_chat_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_1_8b_instruct import \\\n        models as lmdeploy_llama3_1_8b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_2_3b_instruct import \\\n        models as lmdeploy_llama3_2_3b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_3_70b_instruct import \\\n        models as lmdeploy_llama3_3_70b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b_instruct import \\\n        models as lmdeploy_llama3_8b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.mistral.lmdeploy_mistral_large_instruct_2411 import \\\n        models as lmdeploy_mistral_large_instruct_2411_model  # noqa: F401, E501\n    from opencompass.configs.models.mistral.lmdeploy_mistral_nemo_instruct_2407 import \\\n        models as lmdeploy_mistral_nemo_instruct_2407_model  # noqa: F401, E501\n    from opencompass.configs.models.mistral.lmdeploy_mistral_small_instruct_2409 import \\\n        models as lmdeploy_mistral_small_instruct_2409_model  # noqa: F401, E501\n    from opencompass.configs.models.nvidia.lmdeploy_nemotron_70b_instruct_hf import \\\n        models as 
lmdeploy_nemotron_70b_instruct_hf_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_0_5b_instruct import \\\n        models as lmdeploy_qwen2_5_0_5b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_3b_instruct import \\\n        models as lmdeploy_qwen2_5_3b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import \\\n        models as lmdeploy_qwen2_5_14b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_32b_instruct import \\\n        models as lmdeploy_qwen2_5_32b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_72b_instruct import \\\n        models as lmdeploy_qwen2_5_72b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen.lmdeploy_qwen2_1_5b_instruct import \\\n        models as lmdeploy_qwen2_1_5b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.qwen.lmdeploy_qwen2_7b_instruct import \\\n        models as lmdeploy_qwen2_7b_instruct_model  # noqa: F401, E501\n    from opencompass.configs.models.yi.lmdeploy_yi_1_5_6b_chat import \\\n        models as lmdeploy_yi_1_5_6b_chat_model  # noqa: F401, E501\n    from opencompass.configs.models.yi.lmdeploy_yi_1_5_9b_chat import \\\n        models as lmdeploy_yi_1_5_9b_chat_model  # noqa: F401, E501\n    from opencompass.configs.models.yi.lmdeploy_yi_1_5_34b_chat import \\\n        models as lmdeploy_yi_1_5_34b_chat_model  # noqa: F401, E501\n\n    from .volc import infer as volc_infer  # noqa: F401, E501\n\ndatasets = sum([v for k, v in locals().items() if k.endswith('_datasets')], [])\n\npytorch_glm4_9b_chat_model = deepcopy(lmdeploy_glm4_9b_chat_model)\npytorch_deepseek_v2_lite_model = deepcopy(lmdeploy_deepseek_v2_lite_model)\npytorch_deepseek_v2_5_1210_model = 
deepcopy(lmdeploy_deepseek_v2_5_1210_model)\npytorch_internlm3_8b_instruct_model = deepcopy(lmdeploy_internlm3_8b_instruct_model)\npytorch_internlm2_5_7b_chat_model = deepcopy(lmdeploy_internlm2_5_7b_chat_model)\npytorch_internlm2_5_20b_chat_model = deepcopy(lmdeploy_internlm2_5_20b_chat_model)\npytorch_llama3_2_3b_instruct_model = deepcopy(lmdeploy_llama3_2_3b_instruct_model)\npytorch_llama3_3_70b_instruct_model = deepcopy(lmdeploy_llama3_3_70b_instruct_model)\npytorch_mistral_nemo_instruct_2407_model = deepcopy(lmdeploy_mistral_nemo_instruct_2407_model)\npytorch_mistral_small_instruct_2409_model = deepcopy(lmdeploy_mistral_small_instruct_2409_model)\npytorch_qwen2_5_72b_instruct_model = deepcopy(lmdeploy_qwen2_5_72b_instruct_model)\npytorch_qwen2_5_32b_instruct_model = deepcopy(lmdeploy_qwen2_5_32b_instruct_model)\npytorch_qwen2_7b_instruct_model = deepcopy(lmdeploy_qwen2_7b_instruct_model)\npytorch_yi_1_5_34b_chat_model = deepcopy(lmdeploy_yi_1_5_34b_chat_model)\nfor m in pytorch_deepseek_v2_5_1210_model:\n    m['engine_config']['cache_max_entry_count'] = 0.6\n\nlmdeploy_glm4_9b_chat_model_native = deepcopy(lmdeploy_glm4_9b_chat_model)\nlmdeploy_deepseek_r1_distill_qwen_32b_model_native = deepcopy(lmdeploy_deepseek_r1_distill_qwen_32b_model)\nlmdeploy_deepseek_v2_lite_model_native = deepcopy(lmdeploy_deepseek_v2_lite_model)\nlmdeploy_deepseek_v2_5_1210_model_native = deepcopy(lmdeploy_deepseek_v2_5_1210_model)\nlmdeploy_internlm3_8b_instruct_model_native = deepcopy(lmdeploy_internlm3_8b_instruct_model)\nlmdeploy_internlm2_5_7b_chat_model_native = deepcopy(lmdeploy_internlm2_5_7b_chat_model)\nlmdeploy_internlm2_5_20b_chat_model_native = deepcopy(lmdeploy_internlm2_5_20b_chat_model)\nlmdeploy_llama3_1_8b_instruct_model_native = deepcopy(lmdeploy_llama3_1_8b_instruct_model)\nlmdeploy_llama3_2_3b_instruct_model_native = deepcopy(lmdeploy_llama3_2_3b_instruct_model)\nlmdeploy_llama3_8b_instruct_model_native = 
deepcopy(lmdeploy_llama3_8b_instruct_model)\nlmdeploy_llama3_3_70b_instruct_model_native = deepcopy(lmdeploy_llama3_3_70b_instruct_model)\nlmdeploy_mistral_large_instruct_2411_model_native = deepcopy(lmdeploy_mistral_large_instruct_2411_model)\nlmdeploy_mistral_nemo_instruct_2407_model_native = deepcopy(lmdeploy_mistral_nemo_instruct_2407_model)\nlmdeploy_mistral_small_instruct_2409_model_native = deepcopy(lmdeploy_mistral_small_instruct_2409_model)\nlmdeploy_nemotron_70b_instruct_hf_model_native = deepcopy(lmdeploy_nemotron_70b_instruct_hf_model)\nlmdeploy_qwen2_5_0_5b_instruct_model_native = deepcopy(lmdeploy_qwen2_5_0_5b_instruct_model)\nlmdeploy_qwen2_5_14b_instruct_model_native = deepcopy(lmdeploy_qwen2_5_14b_instruct_model)\nlmdeploy_qwen2_5_32b_instruct_model_native = deepcopy(lmdeploy_qwen2_5_32b_instruct_model)\nlmdeploy_qwen2_5_72b_instruct_model_native = deepcopy(lmdeploy_qwen2_5_72b_instruct_model)\nlmdeploy_qwen2_7b_instruct_model_native = deepcopy(lmdeploy_qwen2_7b_instruct_model)\nlmdeploy_yi_1_5_6b_chat_model_native = deepcopy(lmdeploy_yi_1_5_6b_chat_model)\nlmdeploy_yi_1_5_34b_chat_model_native = deepcopy(lmdeploy_yi_1_5_34b_chat_model)\n\nfor model in [v for k, v in locals().items() if k.startswith('lmdeploy_') or k.startswith('pytorch_')]:\n    for m in model:\n        m['engine_config']['max_batch_size'] = 512\n        m['gen_config']['do_sample'] = False\n        m['batch_size'] = 5000\n\nfor model in [v for k, v in locals().items() if k.startswith('lmdeploy_')]:\n    for m in model:\n        m['backend'] = 'turbomind'\n\nfor model in [v for k, v in locals().items() if k.startswith('pytorch_')]:\n    for m in model:\n        m['abbr'] = m['abbr'].replace('turbomind', 'pytorch').replace('lmdeploy', 'pytorch')\n        m['backend'] = 'pytorch'\n\nfor model in [v for k, v in locals().items() if k.endswith('_native')]:\n    for m in model:\n        m['abbr'] = m['abbr'] + '_native'\n        m['engine_config']['communicator'] = 'native'\n\n# models 
= sum([v for k, v in locals().items() if  k.startswith('lmdeploy_') or k.startswith('pytorch_')], [])\n# models = sorted(models, key=lambda x: x['run_cfg']['num_gpus'])\n\nsummarizer = dict(\n    dataset_abbrs=[\n        ['GPQA_diamond', 'accuracy'],\n        ['math', 'accuracy'],\n        ['IFEval', 'Prompt-level-strict-accuracy'],\n    ],\n    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),\n)\n"
  },
  {
    "path": ".github/scripts/eval_stable_object_config.py",
    "content": "from mmengine.config import read_base\nfrom opencompass.models import OpenAISDK\n\nwith read_base():\n    # choose a list of datasets\n    from opencompass.configs.datasets.ARC_c.ARC_c_cot_gen_926652 import ARC_c_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.bbh.bbh_gen_5b92b0 import bbh_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.CHARM.charm_reason_cot_only_gen_f7b7d3 import \\\n        charm_reason_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.cmmlu.cmmlu_0shot_cot_gen_305931 import cmmlu_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.drop.drop_openai_simple_evals_gen_3857b0 import drop_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.ds1000.ds1000_service_eval_gen_cbc84f import ds1000_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.gpqa.gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import gsm8k_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import \\\n        hellaswag_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.humaneval.humaneval_openai_sample_evals_gen_159614 import \\\n        humaneval_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.humanevalx.humanevalx_gen_620cfa import humanevalx_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.IFEval.IFEval_gen_3321a3 import ifeval_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.LCBench.lcbench_gen_5ff288 import LCBench_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.math.math_0shot_gen_393424 import math_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.MathBench.mathbench_2024_gen_50a320 import mathbench_datasets  # noqa: F401, E501\n    from 
opencompass.configs.datasets.mbpp.sanitized_mbpp_mdblock_gen_a447ff import \\\n        sanitized_mbpp_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.mmlu.mmlu_openai_simple_evals_gen_b618ea import mmlu_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.mmlu_pro.mmlu_pro_0shot_cot_gen_08c1de import \\\n        mmlu_pro_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.race.race_cot_gen_d95929 import race_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.scicode.scicode_gen_085b98 import SciCode_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.SuperGLUE_BoolQ.SuperGLUE_BoolQ_cot_gen_1d56df import \\\n        BoolQ_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.teval.teval_en_gen_1ac254 import \\\n        teval_datasets as teval_en_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.teval.teval_zh_gen_1ac254 import \\\n        teval_datasets as teval_zh_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.TheoremQA.TheoremQA_5shot_gen_6f0af8 import TheoremQA_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.wikibench.wikibench_gen_0978ad import wikibench_datasets  # noqa: F401, E501\n\ndatasets = sum(\n    (v for k, v in locals().items() if k.endswith('_datasets') and 'scicode' not in k.lower() and 'teval' not in k), [])\ndatasets += teval_en_datasets\ndatasets += teval_zh_datasets\ndatasets += SciCode_datasets\n\napi_meta_template = dict(\n    round=[\n        dict(role='HUMAN', api_role='HUMAN'),\n        dict(role='BOT', api_role='BOT', generate=True),\n    ],\n    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],\n)\n\nmodels = [\n    dict(\n        abbr='lmdeploy-api-test',\n        type=OpenAISDK,\n        key='EMPTY',\n        openai_api_base='http://localhost:23344/v1',\n        path='/nvme/qa_test_models/internlm/internlm2_5-20b-chat',\n        
tokenizer_path='/nvme/qa_test_models/internlm/internlm2_5-20b-chat',\n        rpm_verbose=True,\n        meta_template=api_meta_template,\n        query_per_second=100,\n        max_out_len=1024,\n        max_seq_len=4096,\n        temperature=0.01,\n        batch_size=128,\n        retry=3,\n    )\n]\n"
  },
  {
    "path": ".github/scripts/eval_stable_subject_config.py",
    "content": "from mmengine.config import read_base\nfrom opencompass.models import OpenAISDK\nfrom opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner\nfrom opencompass.runners import LocalRunner\nfrom opencompass.tasks.subjective_eval import SubjectiveEvalTask\n\nwith read_base():\n    # choose a list of datasets\n    from opencompass.configs.datasets.subjective.alignbench.alignbench_judgeby_critiquellm import \\\n        alignbench_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import \\\n        alpacav2_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.subjective.arena_hard.arena_hard_compare import \\\n        arenahard_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.subjective.compassarena.compassarena_compare import \\\n        compassarena_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.subjective.fofo.fofo_bilingual_judge import fofo_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.subjective.multiround.mtbench101_judge import \\\n        mtbench101_datasets  # noqa: F401, E501\n    from opencompass.configs.datasets.subjective.wildbench.wildbench_pair_judge import \\\n        wildbench_datasets  # noqa: F401, E501\n\ndatasets = sum((v for k, v in locals().items() if k.endswith('_datasets') and 'wildbench' not in k), [])\ndatasets += wildbench_datasets\n\napi_meta_template = dict(\n    round=[\n        dict(role='HUMAN', api_role='HUMAN'),\n        dict(role='BOT', api_role='BOT', generate=True),\n    ],\n    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],\n)\n\nmodels = [\n    dict(\n        abbr='lmdeploy-api-test',\n        type=OpenAISDK,\n        key='EMPTY',\n        openai_api_base='http://localhost:23344/v1',\n        path='/nvme/qa_test_models/internlm/internlm2_5-20b-chat',\n        tokenizer_path='/nvme/qa_test_models/internlm/internlm2_5-20b-chat',\n        rpm_verbose=True,\n 
       meta_template=api_meta_template,\n        query_per_second=100,\n        max_out_len=1024,\n        max_seq_len=4096,\n        temperature=0.01,\n        batch_size=128,\n        retry=3,\n    )\n]\n\njudge_models = models\n\neval = dict(\n    partitioner=dict(\n        type=SubjectiveNaivePartitioner,\n        models=models,\n        judge_models=judge_models,\n    ),\n    runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=SubjectiveEvalTask)),\n)\n"
  },
  {
    "path": ".github/workflows/api_eval.yml",
    "content": "name: api_eval\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_org:\n        required: false\n        description: 'Tested repository organization name. Default is InternLM/lmdeploy'\n        type: string\n        default: 'InternLM/lmdeploy'\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n      backend:\n        required: true\n        description: 'Set backend filter. Default is \"[\"turbomind\", \"pytorch\"]\"'\n        type: string\n        default: \"['turbomind', 'pytorch']\"\n      execution_mode:\n        required: false\n        description: 'Select execution mode: infer, eval, or both. Default is \"both\"'\n        type: choice\n        options:\n          - both\n          - infer\n          - eval\n        default: 'both'\n      run_id:\n        required: false\n        description: 'Set custom run ID. If not provided, github.run_id will be used'\n        type: string\n        default: ''\n      offline_mode:\n        required: true\n        description: 'Whether to start in offline mode. If true, you must prepare the code and the whl package yourself'\n        type: boolean\n        default: false\n\nenv:\n  HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache\n  HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  REPORT_DIR: /nvme/qa_test_models/evaluation_report/allure_report/${{ inputs.repo_ref }}_${{ github.run_id }}\n  COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy\n  TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref }}_${{ github.run_id }}\n  OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy\n  COMPASS_DATA_CACHE: /nvme/qa_test_models/compass_data_cache\n  HF_DATASETS_OFFLINE: 1\n  HF_DATASETS_CACHE: /nvme/qa_test_models/hf_datasets\n  HF_HUB_OFFLINE: 1\n  HF_EVALUATE_OFFLINE: 1\n  RUN_ID: ${{ inputs.repo_ref }}_${{ 
github.run_id }}\n\njobs:\n  linux-build:\n    if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}}\n    strategy:\n      matrix:\n        pyver: [py310]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.8\n      OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }}\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          echo ${GITHUB_RUN_ID}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}\n          retention-days: 1\n          name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }}\n\n\n  download_pkgs:\n    needs: linux-build\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, linux-a100]\n    timeout-minutes: 50\n    container:\n      
image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v2\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Copy repository\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}}\n      - name: Copy repository - offline\n        if: ${{inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}}\n      - name: Download Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        uses: actions/download-artifact@v4\n        with:\n          name: my-artifact-${{ github.run_id }}-py310\n      - name: Copy Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Copy Artifacts - offline\n        if: ${{inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Mark as start\n        run: |\n          chmod -R 777 ${{env.TEST_CODE_PATH}}\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n\n  test_evaluation:\n    needs: download_pkgs\n    if: ${{ !cancelled() }}\n    runs-on: [self-hosted, linux-a100]\n    timeout-minutes: 7200\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        gpu_num: ['gpu_num_1', 'gpu_num_2', 'gpu_num_4', 'gpu_num_8']\n        transformers: [\"\", \"legacy\"]\n    env:\n      TEST_ENV: ${{ matrix.transformers }}\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/github-actions/resources:/root/resources\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy 
repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. .\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Install opencompass\n        run: |\n          git clone https://github.com/open-compass/opencompass.git --depth 1\n          cd opencompass\n          python3 -m pip install .\n          python3 -m pip install langdetect\n      - name: Downgrade transformers\n        if: ${{matrix.transformers == 'legacy'}}\n        run: |\n          pip install transformers==4.57.6\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Setup paths for evaluation\n        if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind')\n        run: |\n          overall_exit=0\n          ln -s /mnt/104/opencompass-data/data ./data\n          ln -s /nvme/qa_test_models/resource/nltk_data /usr/share/nltk_data\n          execution_mode=\"${{ github.event.inputs.execution_mode || 'both' }}\"\n          ulimit -n 65535\n          if [ \"$execution_mode\" = \"both\" ] || [ \"$execution_mode\" = \"infer\" ]; then\n            pytest autotest/evaluate/test_api_evaluate.py -m \"${{matrix.gpu_num}} and ${{matrix.backend}} and infer\" --alluredir=${{env.REPORT_DIR}} || overall_exit=$?\n          fi\n          if [ \"$execution_mode\" = \"both\" ] || [ \"$execution_mode\" = \"eval\" ]; then\n            pytest autotest/evaluate/test_api_evaluate.py -m 
\"${{matrix.gpu_num}} and ${{matrix.backend}} and eval\" -n 4 --alluredir=${{env.REPORT_DIR}} || overall_exit=$?\n          fi\n          exit $overall_exit\n      - name: Clear workspace\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 ${{env.REPORT_DIR}}\n          export workdir=$(pwd)\n          rm -rf $workdir/*\n"
  },
  {
    "path": ".github/workflows/benchmark.yml",
    "content": "name: benchmark_test\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_org:\n        required: false\n        description: 'Tested repository organization name. Default is InternLM/lmdeploy'\n        type: string\n        default: 'InternLM/lmdeploy'\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n      benchmark_type:\n        required: true\n        description: 'Set benchmark type. Default is \"[\"apiserver\", \"mllm_apiserver\", \"throughput\", \"longtext\", \"prefixcache\"]\"'\n        type: string\n        default: \"['apiserver', 'mllm_apiserver', 'throughput', 'longtext', 'prefixcache']\"\n      backend:\n        required: true\n        description: 'Set backend filter. Default is \"[\"turbomind\", \"pytorch\"]\"'\n        type: string\n        default: \"['turbomind', 'pytorch']\"\n      offline_mode:\n        required: true\n        description: 'Whether to start in offline mode. If true, you must prepare the code and the whl package yourself'\n        type: boolean\n        default: false\n\nenv:\n  HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache\n  HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai\n  OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }}\n  REPORT_DIR: /nvme/qa_test_models/benchmark_report/${{ inputs.repo_ref }}_${{ github.run_id }}\n  ALLURE_REPORT_DIR: /nvme/qa_test_models/benchmark_report/allure_report/${{ inputs.repo_ref }}_${{ github.run_id }}\n  TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref }}_${{ github.run_id }}\n  OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  RUN_ID: ${{ inputs.repo_ref }}_${{ github.run_id }}\n\njobs:\n  linux-build:\n    if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}}\n    strategy:\n      matrix:\n        pyver: [py310]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: 
${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.8\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          echo ${GITHUB_RUN_ID}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}\n          retention-days: 1\n          name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }}\n\n  download_pkgs:\n    needs: linux-build\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, linux-a100]\n    timeout-minutes: 50\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v2\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Copy repository\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}}\n      - name: Copy repository - offline\n        if: ${{inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. ${{env.TEST_CODE_PATH}}\n      - name: Download Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        uses: actions/download-artifact@v4\n        with:\n          name: my-artifact-${{ github.run_id }}-py310\n      - name: Copy Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Copy Artifacts - offline\n        if: ${{inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Mark as start\n        run: |\n          chmod -R 777 ${{env.TEST_CODE_PATH}}\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n\n  benchmark:\n    needs: download_pkgs\n    if: ${{github.event_name == 'schedule' || !cancelled()}}\n    runs-on: [self-hosted, linux-a100]\n    strategy:\n      fail-fast: false\n      matrix:\n        benchmark_type: ${{fromJSON(github.event.inputs.benchmark_type)}}\n   
     gpu_num: ['gpu_num_1', 'gpu_num_2', 'gpu_num_4', 'gpu_num_8']\n        transformers: [\"\", \"legacy\"]\n        include:\n          - n: 8\n            gpu_num: gpu_num_1\n          - n: 4\n            gpu_num: gpu_num_2\n          - n: 2\n            gpu_num: gpu_num_4\n          - n: 1\n            gpu_num: gpu_num_8\n    env:\n      TEST_ENV: ${{ matrix.transformers }}\n    timeout-minutes: 480\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. 
.\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Downgrade transformers\n        if: ${{matrix.transformers == 'legacy'}}\n        run: |\n          pip install transformers==4.57.6\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n      - name: Run other benchmark - all\n        if: contains(fromJson(github.event.inputs.backend), 'turbomind') && contains(fromJson(github.event.inputs.backend), 'pytorch')\n        run: |\n            pytest autotest/benchmark/test_${{matrix.benchmark_type}}_performance.py -n ${{matrix.n}} -m '${{matrix.gpu_num}} and not pr_test and not function' --alluredir=${{env.ALLURE_REPORT_DIR}}\n      - name: Run other benchmark - turbomind\n        if: contains(fromJson(github.event.inputs.backend), 'turbomind') && !contains(fromJson(github.event.inputs.backend), 'pytorch')\n        run: |\n            pytest autotest/benchmark/test_${{matrix.benchmark_type}}_performance.py -n ${{matrix.n}} -m '${{matrix.gpu_num}} and not pr_test and not function and turbomind' --alluredir=${{env.ALLURE_REPORT_DIR}}\n      - name: Run other benchmark - pytorch\n        if: contains(fromJson(github.event.inputs.backend), 'pytorch') && !contains(fromJson(github.event.inputs.backend), 'turbomind')\n        run: |\n            pytest autotest/benchmark/test_${{matrix.benchmark_type}}_performance.py -n ${{matrix.n}} -m '${{matrix.gpu_num}} and not pr_test and not function and pytorch' --alluredir=${{env.ALLURE_REPORT_DIR}}\n      - name: Clear workfile\n        if: always()\n        
run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n"
  },
  {
    "path": ".github/workflows/cuda12.8_whl_release.yml",
    "content": "name: cuda12.8-whl-release\n\non:\n  push:\n    tags:\n      - '*'\n  workflow_dispatch:\n\npermissions:\n  contents: write\n\njobs:\n  linux-build:\n    strategy:\n      matrix:\n        pyver: [py310, py311, py312, py313]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.8\n      OUTPUT_FOLDER: cuda12.8_dist\n      CUDA_VER: 12.8\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}/*\n          retention-days: 1\n          name: linux-${{ matrix.pyver }}\n\n  windows-build:\n    strategy:\n      matrix:\n        pyver: ['3.10', '3.11', '3.12', '3.13']\n    runs-on: windows-latest\n    steps:\n      - name: Set git for windows\n        run: |\n          git config --global core.longpaths true\n      - name: Checkout repository\n        uses: actions/checkout@v3\n      - 
name: Set up python\n        uses: actions/setup-python@v4\n        with:\n          python-version: ${{ matrix.pyver }}\n      - name: Install python packages\n        run: |\n          pip install build change-wheel-version\n      - name: Setup CUDA Toolkit\n        id: cuda-toolkit\n        shell: pwsh\n        run: ./builder/windows/setup_cuda.ps1\n        env:\n            INPUT_CUDA_VERSION: '12.8.1'\n      - name: Build wheel\n        run: |\n          python -m build --wheel -o build/wheel\n          Get-ChildItem -Path \"build\" -Filter \"*.whl\" | ForEach-Object { change_wheel_version $_.FullName --local-version cu128 --delete-old-wheel }\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: build/wheel/*\n          retention-days: 1\n          name: windows-${{ matrix.pyver }}\n\n  publish:\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    needs:\n      - linux-build\n      - windows-build\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: artifact\n          merge-multiple: true\n      - name: Add cuda version to package name\n        run: |\n          ver=$(cat lmdeploy/version.py | grep '__version__ =' | cut -d\\' -f2)\n          cuver=$ver+cu128\n          ls -lh\n          cd artifact\n          for file in *; do\n            mv \"$file\" \"`echo $file | sed \"s/$ver/$cuver/g\"`\";\n          done\n      - name: Display artifacts\n        run: ls artifact/ -lh\n      - name: Publish\n        uses: softprops/action-gh-release@v1\n        if: startsWith(github.ref, 'refs/tags/')\n        with:\n          files: artifact/*\n"
  },
  {
    "path": ".github/workflows/daily_ete_test.yml",
    "content": "name: daily_ete_test\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_org:\n        required: false\n        description: 'Repository to test, as org/name. Default is InternLM/lmdeploy'\n        type: string\n        default: 'InternLM/lmdeploy'\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n      backend:\n        required: true\n        description: 'Set backend filter. Default is \"[\"turbomind\", \"pytorch\"]\"'\n        type: string\n        default: \"['turbomind', 'pytorch']\"\n      model:\n        required: true\n        description: 'Set testcase module filter: llm, mllm. Default contains all models'\n        type: string\n        default: \"['llm','mllm']\"\n      function:\n        required: true\n        description: 'Set testcase function filter: chat, restful, pipeline. Default contains all functions'\n        type: string\n        default: '[\"pipeline\", \"restful\", \"chat\"]'\n      offline_mode:\n        required: true\n        description: 'Whether to run in offline mode. If true, you must prepare the code and whl package yourself'\n        type: boolean\n        default: false\n      regression_func:\n        required: true\n        description: 'Regression test categories to run'\n        type: string\n        default: \"['quant', 'tools','restful','pipeline','benchmark','evaluation']\"\n  schedule:\n    - cron:  '00 14 * * 0-4'\n\nenv:\n  HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache\n  HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai\n  OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }}\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  ROOT_DIR: /nvme/qa_test_models\n  REPORT_DIR: /nvme/qa_test_models/test-reports/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }}\n  COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy\n  TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref || 
'main' }}_${{ github.run_id }}\n  OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy\n  OFFLINE_REQUIREMENTS: /nvme/qa_test_models/offline_pkg/requirements.txt\n  DEEPSEEK_VL: /nvme/qa_test_models/offline_pkg/DeepSeek-VL\n  RUN_ID: ${{ inputs.repo_ref || 'main' }}_${{ github.run_id }}\n\njobs:\n  linux-build:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || !inputs.offline_mode)}}\n    strategy:\n      matrix:\n        pyver: [py310]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.8\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          echo ${GITHUB_RUN_ID}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}\n          retention-days: 1\n          
name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }}\n\n\n  download_pkgs:\n    needs: linux-build\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, linux-a100]\n    timeout-minutes: 50\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v2\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Copy repository\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}}\n      - name: Copy repository - offline\n        if: ${{inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}}\n      - name: Download Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        uses: actions/download-artifact@v4\n        with:\n          name: my-artifact-${{ github.run_id }}-py310\n      - name: Copy Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Copy Artifacts - offline\n        if: ${{inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Mark as start\n        run: |\n          chmod -R 777 ${{env.TEST_CODE_PATH}}\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n\n  test_quantization:\n    needs: download_pkgs\n    if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}}\n    runs-on: [self-hosted, linux-a100]\n    timeout-minutes: 150\n    strategy:\n      matrix:\n        transformers: [\"\", \"legacy\"]\n    env:\n      PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA\n      MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub\n      MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules\n      TEST_ENV: ${{ matrix.transformers }}\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - 
/usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. .\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install auto_gptq matplotlib attrdict\n          python3 -m pip install -r requirements/lite.txt\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n          pip install ${{env.DEEPSEEK_VL}} --no-deps\n          rm -rf ${{env.DEEPSEEK_VL}}/build\n      - name: Check env\n        run: |\n          pip install transformers==4.57.6\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test lmdeploy - quantization w4a16\n        continue-on-error: true\n        if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'turbomind')\n        run: |\n          pytest autotest/tools/quantization/test_quantization_awq.py -m 'not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} --clean-alluredir ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - quantization w8a8\n        continue-on-error: true\n        if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'pytorch')\n        run: |\n          pytest autotest/tools/quantization/test_quantization_w8a8.py -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date 
+'%Y%m%d%H%M%S')\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 ${{env.ROOT_DIR}}\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  test_tools:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'tools'))}}\n    runs-on: [self-hosted, linux-a100]\n    needs: test_quantization\n    timeout-minutes: 300\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        model: ${{ fromJSON(inputs.model || '[\"llm\", \"mllm\"]')}}\n        transformers: [\"\", \"legacy\"]\n        function: ${{ fromJSON(inputs.function || '[\"pipeline\",\"restful\",\"chat\"]')}}\n        exclude:\n          - backend: turbomind\n            model: mllm\n            function: chat\n          - backend: pytorch\n            model: mllm\n            function: chat\n        include:\n          - backend: turbomind\n            model: llm\n            function: other\n    env:\n      PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA\n      MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub\n      MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules\n      TEST_ENV: ${{ matrix.transformers }}\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - 
/usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. .\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n          pip install ${{env.DEEPSEEK_VL}} --no-deps\n          rm -rf ${{env.DEEPSEEK_VL}}/build\n      - name: Downgrade transformers\n        if: ${{matrix.transformers == 'legacy'}}\n        run: |\n          pip install transformers==4.57.6\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          cp -r /nvme/qa_test_models/offline_pkg/lora .\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test lmdeploy - chat\n        continue-on-error: true\n        if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind') && matrix.model == 'llm' && matrix.function == 'chat'\n        run: |\n          pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n          pytest 
autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n          pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_8 and not pr_test' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - pipeline\n        continue-on-error: true\n        if: matrix.function == 'pipeline'\n        run: |\n          pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n          pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n          pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_8 and not pr_test' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - restful\n        continue-on-error: true\n        if: matrix.function == 'restful'\n        run: |\n          pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test' -n 8 
--alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n          pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n          pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_8 and not pr_test' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - local testcase\n        if: matrix.backend == 'turbomind' && matrix.model == 'llm' && matrix.function == 'other'\n        run: |\n          pytest autotest/toolchain --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 ${{env.ROOT_DIR}}\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  test_restful:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'restful'))}}\n    runs-on: [self-hosted, linux-a100]\n    needs: test_quantization\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        model_path: 
['Qwen/Qwen3-8B-Base', 'Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-32B', 'OpenGVLab/InternVL3_5-30B-A3B', 'OpenGVLab/InternVL3-38B', 'Qwen/Qwen3-VL-8B-Instruct', 'Qwen/Qwen3-VL-30B-A3B-Instruct']\n        include:\n          - tp: 2\n            model: Qwen3-8B-Base\n            model_path: Qwen/Qwen3-8B-Base\n            case_info: ['completions_v1']\n            generate_type: base\n          - tp: 2\n            model: Qwen3-30B-A3B\n            model_path: Qwen/Qwen3-30B-A3B\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: all\n            extra: '--logprobs-mode raw_logprobs --enable-return-routed-experts'\n            backend: pytorch\n          - tp: 2\n            model: Qwen3-30B-A3B\n            model_path: Qwen/Qwen3-30B-A3B\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs'\n            backend: turbomind\n          - tp: 2\n            model: InternVL3_5-30B-A3B\n            model_path: OpenGVLab/InternVL3_5-30B-A3B\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs --enable-return-routed-experts'\n            backend: pytorch\n          - tp: 2\n            model: InternVL3_5-30B-A3B\n            model_path: OpenGVLab/InternVL3_5-30B-A3B\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs'\n            backend: turbomind\n          - tp: 2\n            model: Qwen3-VL-30B-A3B-Instruct\n            model_path: Qwen/Qwen3-VL-30B-A3B-Instruct\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs --enable-return-routed-experts'\n            backend: pytorch\n          - tp: 2\n            model: Qwen3-VL-30B-A3B-Instruct\n            model_path: 
Qwen/Qwen3-VL-30B-A3B-Instruct\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs'\n            backend: turbomind\n          - tp: 2\n            model: Qwen3-32B\n            model_path: Qwen/Qwen3-32B\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs'\n          - tp: 1\n            model: Qwen3-VL-8B-Instruct\n            model_path: Qwen/Qwen3-VL-8B-Instruct\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs --enable-return-routed-experts'\n            backend: pytorch\n          - tp: 1\n            model: Qwen3-VL-8B-Instruct\n            model_path: Qwen/Qwen3-VL-8B-Instruct\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs'\n            backend: turbomind\n          - tp: 2\n            model: InternVL3-38B\n            model_path: OpenGVLab/InternVL3-38B\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs'\n    timeout-minutes: 60\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n        
  cp -r ${{env.TEST_CODE_PATH}}/. .\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Start restful api\n        run: |\n          lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model_path}} --tp ${{matrix.tp}} --backend ${{matrix.backend}} ${{matrix.extra}} --allow-terminate-by-client > ${{env.REPORT_DIR}}/${{matrix.backend}}_${{matrix.model}}_${{matrix.generate_type}}_start_restful.log 2>&1 &\n          echo \"restful_pid=$!\"\n          for i in $(seq 1 240)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              echo \"health check success\"\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n          exit 1\n      - name: Test lmdeploy - chat_completions_v1\n        if:  matrix.model != 'internlm2_5-20b-chat' && matrix.model != 'Intern-S1' && contains(matrix.case_info, 'chat_completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not 
not_${{matrix.backend}} and not internlm2_5 and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - chat_completions_v1\n        if: matrix.model == 'Intern-S1' && contains(matrix.case_info, 'chat_completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - chat_completions_v1 - internlm2_5-20b-chat\n        if:  matrix.model == 'internlm2_5-20b-chat' && contains(matrix.case_info, 'chat_completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - completions_v1 - internlm2_5-20b\n        if: matrix.model == 'internlm2_5-20b' && contains(matrix.case_info, 'completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - completions_v1 - other\n        if: matrix.model != 'internlm2_5-20b' && contains(matrix.case_info, 'completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest 
autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test generate - base\n        if:  matrix.generate_type == 'base' && contains(matrix.case_info, 'generate')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not logprob and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test generate - logprob\n        if:  matrix.generate_type == 'logprob' && contains(matrix.case_info, 'generate')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test generate - all\n        if:  matrix.generate_type == 'all' && contains(matrix.case_info, 'generate')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}}' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Kill api server\n        if: always()\n        run: |\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n 
         chmod -R 777 ${{env.ROOT_DIR}}\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  test_pipeline:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'pipeline'))}}\n    runs-on: [self-hosted, linux-a100]\n    needs: test_quantization\n    timeout-minutes: 240\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. 
.\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n          pip install ${{env.DEEPSEEK_VL}} --no-deps\n          rm -rf ${{env.DEEPSEEK_VL}}/build\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test lmdeploy - interface pipeline case\n        run: |\n          pytest autotest/interface/pipeline/test_pipeline_func.py -m 'not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n          pytest 
autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_8 and not pr_test' -n 1 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 ${{env.ROOT_DIR}}\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n\n  test_benchmark:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'benchmark'))}}\n    runs-on: [self-hosted, linux-a100]\n    needs: test_quantization\n    timeout-minutes: 120\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. 
.\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n          pip install ${{env.DEEPSEEK_VL}} --no-deps\n          rm -rf ${{env.DEEPSEEK_VL}}/build\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test benchmark script\n        run: |\n          pytest autotest/benchmark -n 4 -m function --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 ${{env.ROOT_DIR}}\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n\n  test_restful_legacy:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'restful'))}}\n    runs-on: [self-hosted, linux-a100]\n    needs: test_quantization\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        model_path: ['internlm/Intern-S1']\n        include:\n          - tp: 8\n            model: Intern-S1\n            model_path: internlm/Intern-S1\n            case_info: ['chat_completions_v1', 
'generate']\n            generate_type: base\n    timeout-minutes: 60\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. .\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Check env\n        run: |\n          pip install transformers==4.57.6\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Start restful api\n        run: |\n          lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model_path}} --tp ${{matrix.tp}} --backend ${{matrix.backend}} ${{matrix.extra}} --allow-terminate-by-client > ${{env.REPORT_DIR}}/${{matrix.backend}}_${{matrix.model}}_${{matrix.generate_type}}_start_restful.log 2>&1 &\n          echo \"restful_pid=$!\"\n          for i in $(seq 1 240)\n   
       do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              echo \"health check success\"\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n          exit 1\n      - name: Test lmdeploy - chat_completions_v1\n        if:  matrix.model != 'internlm2_5-20b-chat' && matrix.model != 'Intern-S1' && contains(matrix.case_info, 'chat_completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5 and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - chat_completions_v1\n        if: matrix.model == 'Intern-S1' && contains(matrix.case_info, 'chat_completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - chat_completions_v1 - internlm2_5-20b-chat\n        if:  matrix.model == 'internlm2_5-20b-chat' && contains(matrix.case_info, 'chat_completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage 
${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - completions_v1 - internlm2_5-20b\n        if: matrix.model == 'internlm2_5-20b' && contains(matrix.case_info, 'completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - completions_v1 - other\n        if: matrix.model != 'internlm2_5-20b' && contains(matrix.case_info, 'completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test generate - base\n        if:  matrix.generate_type == 'base' && contains(matrix.case_info, 'generate')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not logprob and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test generate - logprob\n        if:  matrix.generate_type == 'logprob' && contains(matrix.case_info, 'generate')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date 
+'%Y%m%d%H%M%S')\n      - name: Test generate - all\n        if:  matrix.generate_type == 'all' && contains(matrix.case_info, 'generate')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}}' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Kill api server\n        if: always()\n        run: |\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 ${{env.ROOT_DIR}}\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  test_pipeline_legacy:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'pipeline'))}}\n    runs-on: [self-hosted, linux-a100]\n    needs: test_quantization\n    timeout-minutes: 240\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. 
.\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n          pip install ${{env.DEEPSEEK_VL}} --no-deps\n          rm -rf ${{env.DEEPSEEK_VL}}/build\n      - name: Check env\n        run: |\n          pip install transformers==4.57.6\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test lmdeploy - interface pipeline case\n        run: |\n          pytest autotest/interface/pipeline/test_pipeline_func.py -m 'not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n          pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n          
pytest autotest/interface/pipeline/test_pipeline_longtext_func.py -m 'gpu_num_8 and not pr_test' -n 1 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 ${{env.ROOT_DIR}}\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  get_coverage_report:\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, linux-a100]\n    needs: [test_tools, test_restful, test_pipeline, test_benchmark]\n    timeout-minutes: 5\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: cp -r ${{env.TEST_CODE_PATH}}/. 
.\n      - name: Install lmdeploy\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Get coverage report\n        run: |\n          pip install coverage\n          coverage combine ${{env.REPORT_DIR}}\n          coverage xml -o ${{env.REPORT_DIR}}/coverage.xml\n          coverage report -m\n          mv .coverage ${{env.REPORT_DIR}}/.coverage\n      - name: Clear workfile\n        if: always()\n        run: |\n          chmod -R 777 ${{env.ROOT_DIR}}\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n"
  },
  {
    "path": ".github/workflows/daily_ete_test_3090.yml",
    "content": "name: daily_ete_test_3090\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_org:\n        required: false\n        description: 'Tested repository organization name. Default is InternLM'\n        type: string\n        default: 'InternLM/lmdeploy'\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n      backend:\n        required: true\n        description: 'Set backend filter. Default is \"[\"turbomind\", \"pytorch\"]\"'\n        type: string\n        default: \"['turbomind', 'pytorch']\"\n      model:\n        required: true\n        description: 'Set testcase module filter: llm, mllm. Default contains all models'\n        type: string\n        default: \"['llm','mllm']\"\n      function:\n        required: true\n        description: 'Set testcase function filter: chat, restful, pipeline. Default contains all functions'\n        type: string\n        default: '[\"pipeline\", \"restful\", \"chat\"]'\n      offline_mode:\n        required: true\n        description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself'\n        type: boolean\n        default: false\n      regression_func:\n        required: true\n        description: 'regression functions'\n        type: string\n        default: \"['quant', 'tools', 'restful']\"\n  schedule:\n    - cron:  '00 14 * * 0-4'\n\nenv:\n  HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache\n  HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai\n  OUTPUT_FOLDER: cuda12.4_dist_${{ github.run_id }}\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  REPORT_DIR: /nvme/qa_test_models/test-reports/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }}\n  COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy\n  FAIL_CONFIG: ${{ github.event_name == 'schedule' && github.run_attempt != 1 && '--lf --lfnf none' || '--lf'}}\n  TEST_CODE_PATH: 
/nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }}\n  OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy\n  OFFLINE_REQUIREMENTS: /nvme/qa_test_models/offline_pkg/requirements.txt\n  RUN_ID: ${{ inputs.repo_ref || 'main' }}_${{ github.run_id }}\n\njobs:\n  linux-build:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || !inputs.offline_mode)}}\n    strategy:\n      matrix:\n        pyver: [py310]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.4\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          echo ${GITHUB_RUN_ID}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}\n          retention-days: 1\n         
 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }}\n\n\n  download_pkgs:\n    needs: linux-build\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, 3090-r1]\n    timeout-minutes: 50\n    container:\n      image: openmmlab/lmdeploy:latest-cu12\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /data1:/data1\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v2\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Copy repository\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}}\n      - name: Copy repository - offline\n        if: ${{inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}}\n      - name: Download Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        uses: actions/download-artifact@v4\n        with:\n          name: my-artifact-${{ github.run_id }}-py310\n      - name: Copy Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Copy Artifacts - offline\n        if: ${{inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Mark as start\n        run: |\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n\n  test_quantization:\n    needs: download_pkgs\n    if: ${{!cancelled() && contains(needs.download_pkgs.result, 'success') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}}\n    runs-on: [self-hosted, 3090-r1]\n    timeout-minutes: 150\n    env:\n      PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA\n      MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub\n      MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules\n      TEST_ENV: 3090_legacy\n    container:\n      image: openmmlab/lmdeploy:latest-cu12\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /data1:/data1\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. 
.\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install auto_gptq matplotlib\n          python3 -m pip install -r requirements/lite.txt\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Check env\n        run: |\n          python3 -m pip list\n          pip install transformers==4.57.6\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test lmdeploy - quantization w4a16\n        continue-on-error: true\n        if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'turbomind')\n        run: |\n          pytest autotest/tools/quantization/test_quantization_awq.py -m 'not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} --clean-alluredir ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - quantization w8a8\n        continue-on-error: true\n        if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'pytorch')\n        run: |\n          pytest autotest/tools/quantization/test_quantization_w8a8.py --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          
mkdir $workdir\n          chmod -R 777 $workdir\n\n  test_tools:\n    if: ${{!cancelled() && !contains(needs.test_quantization.result, 'fail') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'tools'))}}\n    runs-on: [self-hosted, 3090-r1]\n    needs: test_quantization\n    timeout-minutes: 300\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        transformers: [\"3090\", \"3090_legacy\"]\n        model: ${{ fromJSON(inputs.model || '[\"llm\", \"mllm\"]')}}\n        function: ${{ fromJSON(inputs.function || '[\"pipeline\",\"restful\",\"chat\"]')}}\n        exclude:\n          - backend: turbomind\n            model: mllm\n            function: chat\n          - backend: pytorch\n            model: mllm\n            function: chat\n    env:\n      PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA\n      MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub\n      MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules\n      TEST_ENV: ${{matrix.transformers}}\n    container:\n      image: openmmlab/lmdeploy:latest-cu12\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /data1:/data1\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. 
.\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Downgrade transformers\n        if: ${{matrix.transformers == '3090_legacy'}}\n        run: |\n          pip install transformers==4.57.6\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test lmdeploy - chat\n        continue-on-error: true\n        if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind') && matrix.model == 'llm' && matrix.function == 'chat'\n        run: |\n          pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n      - name: Test lmdeploy - pipeline\n        continue-on-error: true\n        if: matrix.function == 'pipeline'\n        run: |\n          pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n      - name: Test lmdeploy - restful\n        continue-on-error: true\n        if: matrix.function == 'restful'\n        run: |\n          pytest 
autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  test_restful:\n    if: ${{!cancelled() && !contains(needs.test_quantization.result, 'fail') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'restful'))}}\n    runs-on: [self-hosted, 3090-r1]\n    needs: test_quantization\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        transformers: [\"3090\", \"3090_legacy\"]\n        model_path: ['internlm/internlm3-8b-instruct', 'Qwen/Qwen3-8B']\n        include:\n          - tp: 1\n            model: internlm3-8b-instruct\n            model_path: internlm/internlm3-8b-instruct\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs'\n          - tp: 1\n            model: Qwen3-8B\n            model_path: Qwen/Qwen3-8B\n            case_info: ['completions_v1']\n            generate_type: base\n    timeout-minutes: 60\n    container:\n      image: openmmlab/lmdeploy:latest-cu12\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    env:\n      TEST_ENV: 
${{matrix.transformers}}\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. .\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Downgrade transformers\n        if: ${{matrix.transformers == '3090_legacy'}}\n        run: |\n          pip install transformers==4.57.6\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Start restful api\n        run: |\n          lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model_path}} --tp ${{matrix.tp}} --backend ${{matrix.backend}} ${{matrix.extra}} > ${{env.REPORT_DIR}}/${{matrix.backend}}_${{matrix.model}}_${{matrix.generate_type}}_start_restful.log 2>&1 &\n          restful_pid=$!\n          echo \"restful_pid=$restful_pid\" >> \"$GITHUB_ENV\"\n          for i in $(seq 1 180)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              echo \"health check success\"\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          kill -15 $restful_pid 2>/dev/null || true\n          exit 1\n      - name: Test lmdeploy - chat_completions_v1\n        if:  contains(matrix.case_info, 'chat_completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5 and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - completions_v1 - other\n        if: contains(matrix.case_info, 'completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test generate - logprob\n        if:  matrix.generate_type == 'logprob' && contains(matrix.case_info, 'generate')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Kill api server\n        if: always()\n        run: |\n          kill -15 \"$restful_pid\"\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  get_coverage_report:\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, 3090-r1]\n    needs: [test_tools, test_restful]\n    timeout-minutes: 5\n    container:\n      image: openmmlab/lmdeploy:latest-cu12\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: cp -r ${{env.TEST_CODE_PATH}}/. .\n      - name: Install lmdeploy\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Get coverage report\n        run: |\n          pip install coverage\n          coverage combine ${{env.REPORT_DIR}}\n          coverage xml -o ${{env.REPORT_DIR}}/coverage.xml\n          coverage report -m\n          mv .coverage ${{env.REPORT_DIR}}/.coverage\n      - name: Clear workfile\n        if: always()\n        run: |\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n"
  },
  {
    "path": ".github/workflows/daily_ete_test_5080.yml",
    "content": "name: daily_ete_test_5080\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_org:\n        required: false\n        description: 'Tested repository organization name. Default is InternLM'\n        type: string\n        default: 'InternLM/lmdeploy'\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n      backend:\n        required: true\n        description: 'Set backend filter. Default is \"[\"turbomind\", \"pytorch\"]\"'\n        type: string\n        default: \"['turbomind', 'pytorch']\"\n      model:\n        required: true\n        description: 'Set testcase module filter: llm, mllm. Default contains all models'\n        type: string\n        default: \"['llm','mllm']\"\n      function:\n        required: true\n        description: 'Set testcase function filter: chat, restful, pipeline. Default contains all functions'\n        type: string\n        default: '[\"pipeline\", \"restful\", \"chat\"]'\n      offline_mode:\n        required: true\n        description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself'\n        type: boolean\n        default: false\n      regression_func:\n        required: true\n        description: 'regression functions'\n        type: string\n        default: \"['quant', 'tools', 'restful']\"\n  schedule:\n    - cron:  '00 14 * * 0-4'\n\nenv:\n  HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache\n  HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai\n  OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }}\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  REPORT_DIR: /nvme/qa_test_models/test-reports/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }}\n  COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy\n  FAIL_CONFIG: ${{ github.event_name == 'schedule' && github.run_attempt != 1 && '--lf --lfnf none' || '--lf'}}\n  TEST_CODE_PATH: 
/nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref || 'main' }}_${{ github.run_id }}\n  OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy\n  OFFLINE_REQUIREMENTS: /nvme/qa_test_models/offline_pkg/requirements.txt\n  RUN_ID: ${{ inputs.repo_ref || 'main' }}_${{ github.run_id }}\n\njobs:\n  linux-build:\n    if: ${{!cancelled() && (github.event_name == 'schedule' || !inputs.offline_mode)}}\n    strategy:\n      matrix:\n        pyver: [py310]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.8\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          echo ${GITHUB_RUN_ID}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}\n          retention-days: 1\n         
 name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }}\n\n\n  download_pkgs:\n    needs: linux-build\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, 5080-r1]\n    timeout-minutes: 50\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/3090:/mnt/3090\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v2\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Copy repository\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}}\n      - name: Copy repository - offline\n        if: ${{inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}}\n      - name: Download Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        uses: actions/download-artifact@v4\n        with:\n          name: my-artifact-${{ github.run_id }}-py310\n      - name: Copy Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Copy Artifacts - offline\n        if: ${{inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Mark as start\n        run: |\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n\n  test_quantization:\n    needs: download_pkgs\n    if: ${{!cancelled() && contains(needs.download_pkgs.result, 'success') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}}\n    runs-on: [self-hosted, 5080-r1]\n    timeout-minutes: 150\n    env:\n      PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA\n      MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub\n      MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules\n      TEST_ENV: 5080\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/3090:/mnt/3090\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. 
.\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install auto_gptq matplotlib\n          python3 -m pip install -r requirements/lite.txt\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Check env\n        run: |\n          for i in $(seq 1 10); do\n            output=$(lmdeploy check_env 2>&1)\n            if echo \"$output\" | grep -q \"CUDA available: False\"; then\n              echo \"CUDA not available (attempt $i/10), retrying in 5 seconds...\"\n              sleep 5\n            else\n              echo \"CUDA check passed\"\n              break\n            fi\n          done\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test lmdeploy - quantization w4a16\n        continue-on-error: true\n        if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'turbomind')\n        run: |\n          pytest autotest/tools/quantization/test_quantization_awq.py -m 'not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} --clean-alluredir ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - quantization w8a8\n        continue-on-error: true\n        if: github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.backend), 'pytorch')\n        run: |\n          pytest autotest/tools/quantization/test_quantization_w8a8.py --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || 
true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Clear workfile\n        if: always()\n        run: |\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  test_tools:\n    if: ${{!cancelled() && !contains(needs.test_quantization.result, 'fail') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'tools'))}}\n    runs-on: [self-hosted, 5080-r1]\n    needs: test_quantization\n    timeout-minutes: 300\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        model: ${{ fromJSON(inputs.model || '[\"llm\", \"mllm\"]')}}\n        transformers: [\"5080\", \"5080_legacy\"]\n        function: ${{ fromJSON(inputs.function || '[\"pipeline\",\"restful\",\"chat\"]')}}\n        exclude:\n          - backend: turbomind\n            model: mllm\n            function: chat\n          - backend: pytorch\n            model: mllm\n            function: chat\n    env:\n      PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA\n      MODELSCOPE_CACHE: /nvme/qa_test_models/modelscope_hub\n      MODELSCOPE_MODULES_CACHE: /nvme/qa_test_models/modelscope_modules\n      TEST_ENV: ${{ matrix.transformers }}\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/3090:/mnt/3090\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. 
.\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Downgrade transformers\n        if: ${{matrix.transformers == '5080_legacy'}}\n        run: |\n          pip install transformers==4.57.6\n      - name: Check env\n        run: |\n          for i in $(seq 1 10); do\n            output=$(lmdeploy check_env 2>&1)\n            if echo \"$output\" | grep -q \"CUDA available: False\"; then\n              echo \"CUDA not available (attempt $i/10), retrying in 5 seconds...\"\n              sleep 5\n            else\n              echo \"CUDA check passed\"\n              break\n            fi\n          done\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Test lmdeploy - chat\n        continue-on-error: true\n        if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind') && matrix.model == 'llm' && matrix.function == 'chat'\n        run: |\n          pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n      - name: Test lmdeploy - pipeline\n        continue-on-error: true\n        if: matrix.function == 'pipeline'\n        run: |\n          pytest 
autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n      - name: Test lmdeploy - restful\n        continue-on-error: true\n        if: matrix.function == 'restful'\n        run: |\n          pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_1 and not pr_test and test_3090' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true\n      - name: Clear workfile\n        if: always()\n        run: |\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  test_restful:\n    if: ${{!cancelled() && !contains(needs.test_quantization.result, 'fail') && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'restful'))}}\n    runs-on: [self-hosted, 5080-r1]\n    needs: test_quantization\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        model_path: ['meta-llama/Llama-3.2-3B-Instruct', 'Qwen/Qwen3-4B']\n        transformers: [\"5080\", \"5080_legacy\"]\n        include:\n          - tp: 1\n            model: Llama-3.2-3B-Instruct\n            model_path: meta-llama/Llama-3.2-3B-Instruct\n            case_info: ['chat_completions_v1', 'generate']\n            generate_type: logprob\n            extra: '--logprobs-mode raw_logprobs'\n          - tp: 1\n            model: Qwen3-4B\n            model_path: Qwen/Qwen3-4B\n            case_info: ['completions_v1']\n            generate_type: base\n    timeout-minutes: 60\n    container:\n      image: 
openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/3090:/mnt/3090\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    env:\n      TEST_ENV: ${{ matrix.transformers }}\n    steps:\n      - name: Copy repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. .\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Downgrade transformers\n        if: ${{matrix.transformers == '5080_legacy'}}\n        run: |\n          pip install transformers==4.57.6\n      - name: Check env\n        run: |\n          for i in $(seq 1 10); do\n            output=$(lmdeploy check_env 2>&1)\n            if echo \"$output\" | grep -q \"CUDA available: False\"; then\n              echo \"CUDA not available (attempt $i/10), retrying in 5 seconds...\"\n              sleep 5\n            else\n              echo \"CUDA check passed\"\n              break\n            fi\n          done\n          python3 -m pip list\n          lmdeploy check_env\n          rm -rf allure-results\n          # remove tmp log in testcase\n          mkdir ${{env.REPORT_DIR}}/.pytest_cache -p && rm autotest/.pytest_cache -f\n          ln -s ${{env.REPORT_DIR}}/.pytest_cache autotest\n      - name: Start restful api\n        run: |\n          lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model_path}} --tp 
${{matrix.tp}} --backend ${{matrix.backend}} ${{matrix.extra}} > ${{env.REPORT_DIR}}/${{matrix.backend}}_${{matrix.model}}_${{matrix.generate_type}}_start_restful.log 2>&1 &\n          restful_pid=$!\n          echo \"restful_pid=$restful_pid\" >> \"$GITHUB_ENV\"\n          for i in $(seq 1 50)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              echo \"health check success\"\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          kill -15 $restful_pid 2>/dev/null || true\n          exit 1\n      - name: Test lmdeploy - chat_completions_v1\n        if:  contains(matrix.case_info, 'chat_completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not internlm2_5 and not interns1' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test lmdeploy - completions_v1 - other\n        if: contains(matrix.case_info, 'completions_v1')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_completions_v1.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}} and not internlm2_5' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Test generate - logprob\n        if:  matrix.generate_type == 'logprob' && contains(matrix.case_info, 'generate')\n        timeout-minutes: 60\n        run: |\n          pytest autotest/interface/restful/test_restful_generate.py -n 20 -k '${{matrix.model_path}} and ${{matrix.backend}}' -m 'not not_${{matrix.backend}} and not experts' --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true\n          mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')\n      - name: Kill api server\n        if: always()\n        run: |\n          kill -15 \"$restful_pid\"\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n\n  get_coverage_report:\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, 5080-r1]\n    needs: [test_tools, test_restful]\n    timeout-minutes: 5\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/3090:/mnt/3090\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy repository and Artifacts\n        run: cp -r ${{env.TEST_CODE_PATH}}/. .\n      - name: Install lmdeploy\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Get coverage report\n        run: |\n          pip install coverage\n          coverage combine ${{env.REPORT_DIR}}\n          coverage xml -o ${{env.REPORT_DIR}}/coverage.xml\n          coverage report -m\n          mv .coverage ${{env.REPORT_DIR}}/.coverage\n      - name: Clear workfile\n        if: always()\n        run: |\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n"
  },
  {
    "path": ".github/workflows/docker.yml",
    "content": "name: publish-docker\n\non:\n  push:\n    paths-ignore:\n      - \"!.github/workflows/docker.yml\"\n      - \".github/**\"\n      - \"docs/**\"\n      - \"resources/**\"\n      - \"benchmark/**\"\n      - \"tests/**\"\n      - \"**/*.md\"\n      - \"autotest/**\"\n      - \"builder/**\"\n      - \"k8s/**\"\n\n    branches:\n      - main\n    tags:\n      - \"v*.*.*\"\n  workflow_dispatch:\n    inputs:\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"\"'\n        type: string\n        default: 'main'\n      image_tag:\n        required: true\n        description: 'Set docker image tag. Default is \"latest\"'\n        type: string\n        default: latest\n\njobs:\n  publish_docker_image:\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    strategy:\n      fail-fast: false\n      matrix:\n        cuda_version: ['cu12.8', 'cu12']\n    env:\n      CUDA_VERSION: ${{ matrix.cuda_version }}\n      TAG_PREFIX: \"openmmlab/lmdeploy\"\n      TAG: \"openmmlab/lmdeploy:latest-${{matrix.cuda_version}}\"\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          ref: ${{github.event.inputs.repo_ref}}\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Get docker info\n        run: |\n          docker info\n          # remove http extraheader\n          git config --local --unset \"http.https://github.com/.extraheader\"\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n   
     with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Update docker TAG from workflow input\n        if: github.event_name == 'workflow_dispatch'\n        run: |\n          export TAG=$TAG_PREFIX:${{github.event.inputs.image_tag}}-${CUDA_VERSION}\n          echo $TAG\n          echo \"TAG=${TAG}\" >> $GITHUB_ENV\n      - name: Build and push Docker image\n        run: |\n          echo $TAG\n          docker build . -f docker/Dockerfile -t ${TAG} --build-arg CUDA_VERSION=${CUDA_VERSION}\n          docker push $TAG\n      - name: Push Docker image as latest\n        if: endsWith(env.TAG, 'latest-cu12') == true\n        run: |\n          export latest_TAG=${TAG_PREFIX}:latest\n          echo $latest_TAG\n          docker tag $TAG $latest_TAG\n          docker push $latest_TAG\n      - name: Push docker image with released tag\n        if: startsWith(github.ref, 'refs/tags/') == true\n        run: |\n          export RELEASE_TAG=${TAG_PREFIX}:${{github.ref_name}}-${CUDA_VERSION}\n          echo $RELEASE_TAG\n          docker tag $TAG $RELEASE_TAG\n          docker push $RELEASE_TAG\n\n  publish_ascend_docker_image:\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    env:\n      TAG_PREFIX: \"openmmlab/lmdeploy\"\n      TAG: \"openmmlab/lmdeploy:ascend\"\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          ref: ${{github.event.inputs.repo_ref}}\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: 
false\n      - name: Set up QEMU\n        uses: docker/setup-qemu-action@v3\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Get docker info\n        run: |\n          docker info\n          # remove http extraheader\n          git config --local --unset \"http.https://github.com/.extraheader\"\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Update docker TAG from workflow input\n        if: github.event_name == 'workflow_dispatch'\n        run: |\n          export TAG=$TAG_PREFIX:${{github.event.inputs.image_tag}}-ascend\n          echo $TAG\n          echo \"TAG=${TAG}\" >> $GITHUB_ENV\n      - name: Build and push Docker image\n        run: |\n          echo $TAG\n          docker build . -t ${TAG} -f docker/Dockerfile_ascend_a3 --platform linux/arm64\n          docker push $TAG\n"
  },
  {
    "path": ".github/workflows/docker_dev.yml",
    "content": "name: publish-dev-docker\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n\njobs:\n  publish_dev_docker_image:\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    env:\n      TAG: \"openmmlab/lmdeploy:dev-cu12.8\"\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ github.event.inputs.repo_ref }}\n\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@v1.3.1\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n\n      - name: Get docker info\n        run: |\n          docker info\n          # remove http extraheader\n          git config --local --unset \"http.https://github.com/.extraheader\"\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v2\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Build and push Docker image\n        run: |\n          echo $TAG\n          docker build . -f docker/Dockerfile_dev -t ${TAG}\n          docker push $TAG\n"
  },
  {
    "path": ".github/workflows/evaluate.yml",
    "content": "name: evaluate\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_org:\n        required: false\n        description: 'Tested repository organization name. Default is InternLM/lmdeploy'\n        type: string\n        default: 'InternLM/lmdeploy'\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n      base_models:\n        required: true\n        description: 'Tested TurboMind models list. eg. [turbomind_qwen2_5_1_5b, turbomind_qwen2_5_7b, turbomind_qwen2_5_32b, turbomind_glm_4_9b, turbomind_llama_3_1_8b, turbomind_llama_3_70b, turbomind_qwen3_0_6b_base, turbomind_qwen3_8b_base, turbomind_qwen3_30b_A3B_base, pytorch_qwen2_5_1_5b, pytorch_qwen2_5_7b, pytorch_qwen2_5_32b, pytorch_gemma_2_9b, pytorch_llama_3_70b, pytorch_llama_3_1_8b, pytorch_qwen3_0_6b_base, pytorch_qwen3_8b_base, pytorch_qwen3_30b_A3B_base]'\n        type: string\n        default: '[turbomind_qwen2_5_1_5b, turbomind_qwen2_5_7b, turbomind_qwen2_5_32b, turbomind_glm_4_9b, turbomind_llama_3_1_8b, turbomind_llama_3_70b, turbomind_qwen3_0_6b_base, turbomind_qwen3_8b_base, turbomind_qwen3_30b_A3B_base, pytorch_qwen2_5_1_5b, pytorch_qwen2_5_7b, pytorch_qwen2_5_32b, pytorch_gemma_2_9b, pytorch_llama_3_70b, pytorch_llama_3_1_8b, pytorch_qwen3_0_6b_base, pytorch_qwen3_8b_base, pytorch_qwen3_30b_A3B_base]'\n      baes_datasets:\n        required: true\n        description: 'Tested datasets list. eg. [*mmlu_datasets, *gsm8k_datasets]'\n        type: string\n        default: '[*mmlu_datasets, *gsm8k_datasets, *gpqa_datasets, *winogrande_datasets]'\n      oc_repo_org:\n        required: false\n        description: 'Tested repository organization name. Default is open-compass/opencompass'\n        type: string\n        default: 'open-compass/opencompass'\n      oc_repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. 
Default is \"main\"'\n        type: string\n        default: 'main'\n      offline_mode:\n        required: true\n        description: 'Whether to start in offline mode. If true, you must prepare the code and whl package yourself'\n        type: boolean\n        default: false\n\nenv:\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  COMPASS_DATA_CACHE: /nvme/qa_test_models/compass_data_cache\n\njobs:\n  linux-build:\n    if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}}\n    strategy:\n      matrix:\n        pyver: [py310]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.8\n      OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }}\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v6\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          echo ${GITHUB_RUN_ID}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: 
actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}\n          retention-days: 1\n          name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }}\n\n  evaluate:\n    needs: linux-build\n    if: ${{github.event_name == 'schedule' || !cancelled()}}\n    runs-on: [self-hosted, linux-a100]\n    timeout-minutes: 4320 # 72hours\n    strategy:\n      fail-fast: false\n      matrix:\n        evaluate_type: ['base']\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/github-actions/resources:/root/resources\n        - /nvme/qa_test_models/evaluation_report:/root/evaluation_report\n        - /nvme/qa_test_models:/root/models\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Setup systems\n        run: |\n          export TIME_STAMP=\"$(date +'%Y%m%d-%H%M%S')\"\n          echo \"TIME_STAMP=$TIME_STAMP\" >> $GITHUB_ENV\n      - name: Clone repository\n        uses: actions/checkout@v2\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Copy repository - offline\n        if: ${{inputs.offline_mode}}\n        run: cp -r /root/models/offline_pkg/lmdeploy/. 
.\n      - name: Download Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        uses: actions/download-artifact@v4\n        with:\n          name: my-artifact-${{ github.run_id }}-py310\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r /root/models/offline_pkg/requirements.txt\n      - name: Install lmdeploy\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Install lmdeploy - offline\n        if: ${{inputs.offline_mode}}\n        run: |\n          python3 -m pip install /root/models/offline_pkg/py310/lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Install opencompass\n        run: |\n          git clone https://github.com/${{ github.event.inputs.oc_repo_org}}.git\n          cd opencompass\n          git checkout ${{ github.event.inputs.oc_repo_ref}}\n          python3 -m pip install .\n          echo \"OPENCOMPASS_DIR=$(pwd)\" >> $GITHUB_ENV\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n      - name: Setup paths for evaluation\n        run: |\n          ln -s /root/opencompass-data ./data\n          python3 .github/scripts/action_tools.py create_model_links /root/models .\n      - name: Evaluate base models\n        if: matrix.evaluate_type == 'base'\n        run: |\n          echo ${{github.event.inputs.base_models}}\n          echo ${{github.event.inputs.baes_datasets}}\n          export LMDEPLOY_DIR=$(pwd)\n          python3 .github/scripts/action_tools.py evaluate \"${{github.event.inputs.base_models}}\" \"${{github.event.inputs.baes_datasets}}\" /root/evaluation_report/${{ github.run_id }} base\n      - name: Clear workspace\n        if: always()\n        run: 
|\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n"
  },
  {
    "path": ".github/workflows/lint.yml",
    "content": "name: lint\n\non: [push, pull_request]\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python 3.10\n        uses: actions/setup-python@v4\n        with:\n          python-version: '3.10'\n      - name: Install pre-commit hook\n        run: |\n          python -m pip install pre-commit\n          pre-commit install\n      - name: Linting\n        run: pre-commit run --all-files\n      - name: Check markdown link\n        uses: gaurav-nelson/github-action-markdown-link-check@v1\n        with:\n          use-quiet-mode: 'yes'\n          use-verbose-mode: 'yes'\n#          check-modified-files-only: 'yes'\n          config-file: '.github/md-link-config.json'\n          file-path: './README.md, ./LICENSE, ./README_zh-CN.md'\n      - name: Check module init files\n        run: |\n          python -m pip install fire\n          python .github/scripts/check_lmdeploy.py check_module_init lmdeploy\n      - name: Check doc link\n        run: |\n          python .github/scripts/doc_link_checker.py --target README_zh-CN.md\n          python .github/scripts/doc_link_checker.py --target README.md\n      - name: Check docstring coverage\n        run: |\n          python -m pip install interrogate\n          interrogate -v --exclude ./lmdeploy/pytorch_poc/modeling/ --ignore-init-method --ignore-magic --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 70 lmdeploy\n      - name: Check pylint score\n        run: |\n          python -m pip install pylint\n          pylint lmdeploy\n"
  },
  {
    "path": ".github/workflows/linux_x64_gpu.yml",
    "content": "name: linux-x64-gpu\non:\n  push:\n    paths:\n      - '.github/workflows/linux_x64_gpu.yml'\n      - 'src/**'\n      - 'CMakeLists.txt'\n      - 'cmake/**'\n      - 'examples/**'\n      - '3rdparty/**'\n      - 'tests/csrc/**'\n  pull_request:\n    paths:\n      - '.github/workflows/linux_x64_gpu.yml'\n      - 'src/**'\n      - 'CMakeLists.txt'\n      - 'cmake/**'\n      - 'examples/**'\n      - '3rdparty/**'\n      - 'tests/csrc/**'\nconcurrency:\n  group: linux-x64-gpu-${{ github.ref }}\n  cancel-in-progress: true\npermissions:\n  contents: read\n\njobs:\n  build:\n    strategy:\n      fail-fast: false\n      matrix:\n        cudaver: [12.4, 12.8]\n    name: cuda-${{ matrix.cudaver }}\n    runs-on: ubuntu-latest\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n      - name: Build\n        run: |\n          docker run --rm \\\n            -v ${{ github.workspace }}:/work \\\n            -w /work \\\n            openmmlab/lmdeploy-builder:cuda${{ matrix.cudaver }} \\\n            bash -c \"\n              source /opt/conda/bin/activate && \\\n              conda activate py310 && \\\n              pip install build && \\\n              python -m build --wheel\n            \"\n"
  },
  {
    "path": ".github/workflows/mllm_api_eval.yml",
    "content": "name: mllm_api_eval\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_org:\n        required: false\n        description: 'Tested repository organization name. Default is InternLM/lmdeploy'\n        type: string\n        default: 'InternLM/lmdeploy'\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n      backend:\n        required: true\n        description: 'Set backend filter. Default is \"[\"turbomind\", \"pytorch\"]\"'\n        type: string\n        default: \"['turbomind', 'pytorch']\"\n      execution_mode:\n        required: false\n        description: 'Select execution mode: infer, eval, or both. Default is \"both\"'\n        type: choice\n        options:\n          - both\n          - infer\n          - eval\n        default: 'both'\n      run_id:\n        required: false\n        description: 'Set custom run ID. If not provided, github.run_id will be used'\n        type: string\n        default: ''\n\n\nenv:\n  HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache\n  HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  REPORT_DIR: /nvme/qa_test_models/mllm_evaluation_report/allure_report/${{ inputs.repo_ref }}_${{ github.run_id }}\n  COV_PARAM: --cov /opt/py3/lib/python3.10/site-packages/lmdeploy\n  TEST_CODE_PATH: /nvme/qa_test_models/test_pkg/lmdeploy/${{ inputs.repo_ref }}_${{ github.run_id }}\n  OFFLINE_CODE_PATH: /nvme/qa_test_models/offline_pkg/lmdeploy\n  OFFLINE_REQUIREMENTS: /nvme/qa_test_models/offline_pkg/requirements.txt\n  DEEPSEEK_VL: /nvme/qa_test_models/offline_pkg/DeepSeek-VL\n  LMUData: /nvme/qa_test_models/LMUData\n  LOCAL_LLM: turbomind_Qwen2.5-32B-Instruct_nccl_tp2_0\n  OPENAI_API_KEY: sk-empty\n  HF_DATASETS_OFFLINE: 1\n  HF_DATASETS_CACHE: /nvme/qa_test_models/hf_datasets\n  HF_HUB_OFFLINE: 1\n  HF_EVALUATE_OFFLINE: 1\n  RUN_ID: ${{ inputs.repo_ref 
}}_${{ github.run_id }}\n\njobs:\n  linux-build:\n    if: ${{ !cancelled() }}\n    strategy:\n      matrix:\n        pyver: [py310]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.8\n      OUTPUT_FOLDER: cuda12.8_dist_${{ github.run_id }}\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          echo ${GITHUB_RUN_ID}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}\n          retention-days: 1\n          name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }}\n\n  download_pkgs:\n    needs: linux-build\n    if: ${{!cancelled()}}\n    runs-on: [self-hosted, linux-a100]\n    timeout-minutes: 50\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: 
\"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v2\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Copy repository\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r . ${{env.TEST_CODE_PATH}}\n      - name: Copy repository - offline\n        if: ${{inputs.offline_mode}}\n        run: rm -rf ${{env.TEST_CODE_PATH}} && mkdir ${{env.TEST_CODE_PATH}} && chmod 777 ${{env.TEST_CODE_PATH}} && cp -r ${{env.OFFLINE_CODE_PATH}}/. 
${{env.TEST_CODE_PATH}}\n      - name: Download Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        uses: actions/download-artifact@v4\n        with:\n          name: my-artifact-${{ github.run_id }}-py310\n      - name: Copy Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Copy Artifacts - offline\n        if: ${{inputs.offline_mode}}\n        run: rm ${{env.TEST_CODE_PATH}}/lmdeploy-*.whl -f && cp ${{env.OFFLINE_CODE_PATH}}/lmdeploy-*.whl ${{env.TEST_CODE_PATH}}\n      - name: Mark as start\n        run: |\n          chmod -R 777 ${{env.TEST_CODE_PATH}}\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n\n  test_evaluation:\n    needs: download_pkgs\n    if: ${{ !cancelled() }}\n    runs-on: [self-hosted, linux-a100]\n    timeout-minutes: 2400\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: ${{ fromJSON(inputs.backend || '[\"turbomind\", \"pytorch\"]')}}\n        gpu_num: ['gpu_num_1', 'gpu_num_2', 'gpu_num_4', 'gpu_num_8']\n        transformers: [\"\", \"legacy\"]\n    env:\n      TEST_ENV: ${{ matrix.transformers }}\n    container:\n      image: openmmlab/lmdeploy:latest-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/github-actions/resources:/root/resources\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /nvme/huggingface_hub:/nvme/huggingface_hub\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Copy 
repository and Artifacts\n        run: |\n          cp -r ${{env.TEST_CODE_PATH}}/. .\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Install lmdeploy - dependency\n        run: |\n          python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip uninstall lmdeploy -y && python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Install vlmeval\n        run: |\n          python3 -m pip install pandas datasets scikit-learn pylatexenc math_verify\n          apt update && apt install -y libgl1 libglib2.0-0\n          cp -r /nvme/qa_test_models/offline_pkg/VLMEvalKit .\n          cd VLMEvalKit && pip install .\n      - name: Downgrade transformers\n        if: ${{matrix.transformers == 'legacy'}}\n        run: |\n          pip install transformers==4.57.6\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          mkdir ${{env.REPORT_DIR}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Setup paths for evaluation\n        if: (matrix.backend == 'pytorch' || matrix.backend == 'turbomind')\n        run: |\n          unset HTTP_PROXY;unset HTTPS_PROXY;unset http_proxy;unset https_proxy;\n          cd VLMEvalKit && cp -r ../autotest .\n          execution_mode=\"${{ github.event.inputs.execution_mode || 'both' }}\"\n          ulimit -n 65535\n          if [ \"$execution_mode\" = \"both\" ] || [ \"$execution_mode\" = \"infer\" ]; then\n            pytest autotest/evaluate/test_mllm_api_evaluate.py -m \"${{matrix.gpu_num}} and ${{matrix.backend}} and infer\" --alluredir=${{env.REPORT_DIR}} || overall_exit=$?\n          fi\n          if [ \"$execution_mode\" = \"both\" ] || [ \"$execution_mode\" = \"eval\" ]; then\n            pytest 
autotest/evaluate/test_mllm_api_evaluate.py -m \"${{matrix.gpu_num}} and ${{matrix.backend}} and eval\" -n 4 --alluredir=${{env.REPORT_DIR}} || overall_exit=$?\n          fi\n          exit $overall_exit\n      - name: Clear workspace\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          chmod -R 777 ${{env.REPORT_DIR}}\n          export workdir=$(pwd)\n          rm -rf $workdir/*\n"
  },
  {
    "path": ".github/workflows/pr_ete_test.yml",
    "content": "name: pr_ete_test\n\non:\n  pull_request:\n    paths:\n      - \".github/workflows/pr_ete_test.yml\"\n      - \"cmake/**\"\n      - \"src/**\"\n      - \"autotest/**\"\n      - \"3rdparty/**\"\n      - \"lmdeploy/**\"\n      - \"requirements/**\"\n      - \"requirements_cuda.txt\"\n      - \"CMakeLists.txt\"\n      - \"setup.py\"\n  workflow_dispatch:\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\n\nenv:\n  HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache\n  HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA\n\n\njobs:\n  pr_functions_test:\n    runs-on: [self-hosted, linux-a100-pr]\n    timeout-minutes: 120\n    env:\n      REPORT_DIR: /nvme/qa_test_models/test-reports/${{ github.head_ref }}_${{ github.run_id }}\n      SERVER_LOG: /nvme/qa_test_models/server_log/${{ github.head_ref }}_${{ github.run_id }}\n    container:\n      image: openmmlab/lmdeploy:dev-cu12.8\n      options: --gpus all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip --pull never\n      volumes:\n        - /nvme/share_data/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/share_data/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/121:/mnt/121\n        - /mnt/104:/mnt/104\n        - /mnt/bigdisk:/mnt/bigdisk\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v2\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip install -r requirements/lite.txt\n          python3 -m pip install -r requirements/test.txt\n          python3 -m pip install -e .\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n          mkdir ${{env.REPORT_DIR}} -p\n          mkdir 
${{env.SERVER_LOG}} -p\n          echo \"starttime=$(date +%s)\" > ${{env.REPORT_DIR}}/status.txt\n      - name: Test lmdeploy - func\n        run: |\n          pytest autotest -m 'pr_test and gpu_num_2' -x --alluredir=${{env.REPORT_DIR}} --clean-alluredir\n          pytest autotest -m 'pr_test and gpu_num_1' -n 2 -x --alluredir=${{env.REPORT_DIR}}\n      - name: Update transformers\n        run: |\n          pip install transformers==4.57.3\n      - name: Test restful server - turbomind Qwen3-32B\n        run: |\n          CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/Qwen/Qwen3-32B --tp 2 --backend turbomind --logprobs-mode raw_logprobs --allow-terminate-by-client > ${{env.SERVER_LOG}}/turbomind_Qwen3-32B_start_restful.log 2>&1 &\n          echo \"restful_pid=$!\"\n          for i in $(seq 1 180)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'Qwen/Qwen3-32B and turbomind' -m 'not not_turbomind and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}}\n              pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'Qwen/Qwen3-32B and turbomind' -m 'not not_turbomind and not experts' --alluredir=${{env.REPORT_DIR}}\n              curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n          cat ${{env.SERVER_LOG}}/turbomind_Qwen3-32B_start_restful.log\n          exit 1\n      - name: Test restful server - turbomind InternVL3-38B\n        run: |\n          CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/OpenGVLab/InternVL3-38B --tp 2 --backend turbomind --logprobs-mode raw_logprobs 
--allow-terminate-by-client > ${{env.SERVER_LOG}}/turbomind_InternVL3-38B_start_restful.log 2>&1 &\n          echo \"restful_pid=$!\"\n          for i in $(seq 1 180)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'OpenGVLab/InternVL3-38B and turbomind' -m 'not not_turbomind and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}}\n              pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'OpenGVLab/InternVL3-38B and turbomind' -m 'not not_turbomind and not experts' --alluredir=${{env.REPORT_DIR}}\n              curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n          cat ${{env.SERVER_LOG}}/turbomind_InternVL3-38B_start_restful.log\n          exit 1\n      - name: Test restful server - turbomind Qwen3-30B-A3B\n        run: |\n          CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/Qwen/Qwen3-30B-A3B --tp 2 --backend turbomind --logprobs-mode raw_logprobs  --allow-terminate-by-client> ${{env.SERVER_LOG}}/turbomind_Qwen3-30B-A3B_start_restful.log 2>&1 &\n          echo \"restful_pid=$!\"\n          for i in $(seq 1 180)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'Qwen/Qwen3-30B-A3B and turbomind' -m 'not not_turbomind and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}}\n              pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'Qwen/Qwen3-30B-A3B and 
turbomind' -m 'not not_turbomind and not experts' --alluredir=${{env.REPORT_DIR}}\n              curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n          cat ${{env.SERVER_LOG}}/turbomind_Qwen3-30B-A3B_start_restful.log\n          exit 1\n      - name: Test restful server - pytorch Qwen3-30B-A3B\n        run: |\n          CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/Qwen/Qwen3-30B-A3B --tp 2 --backend pytorch --logprobs-mode raw_logprobs --enable-return-routed-experts --allow-terminate-by-client > ${{env.SERVER_LOG}}/pytorch_Qwen3-30B-A3B_start_restful.log 2>&1 &\n          echo \"restful_pid=$!\"\n          for i in $(seq 1 180)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'Qwen/Qwen3-30B-A3B and pytorch' -m 'not not_pytorch and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}}\n              pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'Qwen/Qwen3-30B-A3B and pytorch' -m 'not not_pytorch' --alluredir=${{env.REPORT_DIR}}\n              curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n          cat ${{env.SERVER_LOG}}/pytorch_Qwen3-30B-A3B_start_restful.log\n          exit 1\n      - name: Test restful server - pytorch Qwen3-VL-30B-A3B-Instruct\n        run: |\n          CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/Qwen/Qwen3-VL-30B-A3B-Instruct --tp 2 --backend pytorch --logprobs-mode raw_logprobs 
--allow-terminate-by-client > ${{env.SERVER_LOG}}/pytorch_Qwen3-VL-30B-A3B-Instruct_start_restful.log 2>&1 &\n          echo \"restful_pid=$!\"\n          for i in $(seq 1 180)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'Qwen/Qwen3-VL-30B-A3B-Instruct and pytorch' -m 'not not_pytorch and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}}\n              pytest autotest/interface/restful/test_restful_generate.py -n 20 -k 'Qwen/Qwen3-VL-30B-A3B-Instruct and pytorch' -m 'not not_pytorch and not experts' --alluredir=${{env.REPORT_DIR}}\n              curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n          cat ${{env.SERVER_LOG}}/pytorch_Qwen3-VL-30B-A3B-Instruct_start_restful.log\n          exit 1\n      - name: Test restful server - pytorch InternVL3_5-30B-A3B\n        run: |\n          CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/OpenGVLab/InternVL3_5-30B-A3B --tp 2 --backend pytorch --logprobs-mode raw_logprobs  --allow-terminate-by-client > ${{env.SERVER_LOG}}/pytorch_InternVL3_5-30B-A3B_start_restful.log 2>&1 &\n          echo \"restful_pid=$!\"\n          for i in $(seq 1 180)\n          do\n            sleep 5\n            echo \"health check try $i\"\n            if curl -f -s http://127.0.0.1:23333/health > /dev/null 2>&1; then\n              pytest autotest/interface/restful/test_restful_chat_completions_v1.py -n 20 -k 'OpenGVLab/InternVL3_5-30B-A3B and pytorch' -m 'not not_pytorch and not internlm2_5 and not interns1 and pr_test' --alluredir=${{env.REPORT_DIR}}\n              pytest 
autotest/interface/restful/test_restful_generate.py -n 20 -k 'OpenGVLab/InternVL3_5-30B-A3B and pytorch' -m 'not not_pytorch and not experts' --alluredir=${{env.REPORT_DIR}}\n              curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n              exit 0\n            fi\n          done\n\n          echo \"health check fail\"\n          curl -f -s http://127.0.0.1:23333/terminate > /dev/null 2>&1\n          cat ${{env.SERVER_LOG}}/pytorch_InternVL3_5-30B-A3B_start_restful.log\n          exit 1\n      - name: Clear workfile\n        if: always()\n        run: |\n          echo \"status=done\" >> ${{env.REPORT_DIR}}/status.txt\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n"
  },
  {
    "path": ".github/workflows/pypi.yml",
    "content": "name: publish to pypi\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"lmdeploy/version.py\"\n  workflow_dispatch:\n\n\njobs:\n  linux-build:\n    strategy:\n      matrix:\n        pyver: [py310, py311, py312, py313]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda12.4\n      OUTPUT_FOLDER: cuda12_dist\n    steps:\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Checkout repository\n        uses: actions/checkout@v3\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}/*\n          retention-days: 1\n          name: linux-${{ matrix.pyver }}\n\n  windows-build:\n    strategy:\n      matrix:\n        pyver: ['3.10', '3.11', '3.12', '3.13']\n    runs-on: windows-latest\n    steps:\n      - name: Set git for windows\n        run: |\n          git config --global core.longpaths true\n      - name: Checkout repository\n        uses: actions/checkout@v3\n      - name: Set 
up python\n        uses: actions/setup-python@v4\n        with:\n          python-version: ${{ matrix.pyver }}\n      - name: Install python packages\n        run: |\n          pip install build change-wheel-version\n      - name: Setup CUDA Toolkit\n        id: cuda-toolkit\n        shell: pwsh\n        run: ./builder/windows/setup_cuda.ps1\n        env:\n            INPUT_CUDA_VERSION: '12.6.2'\n      - name: Build wheel\n        run: |\n          python -m build --wheel -o build/wheel\n          Get-ChildItem -Path \"build\" -Filter \"*.whl\" | ForEach-Object { change_wheel_version $_.FullName --local-version cu121 --delete-old-wheel }\n      - name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: build/wheel/*\n          retention-days: 1\n          name: windows-${{ matrix.pyver }}\n\n  publish:\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    needs:\n      - linux-build\n      - windows-build\n    steps:\n      - name: Download artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: artifact\n          merge-multiple: true\n      - name: Display artifacts\n        run: ls artifact/ -lh\n      - name: Set up python 3.10\n        uses: actions/setup-python@v4\n        with:\n          python-version: '3.10'\n      - name: Upload to pypi\n        run: |\n          pip install twine\n          twine upload artifact/* -u __token__ -p ${{ secrets.pypi_password }}\n"
  },
  {
    "path": ".github/workflows/stable.yml",
    "content": "name: stable_test\n\non:\n  workflow_dispatch:\n    inputs:\n      repo_org:\n        required: false\n        description: 'Tested repository organization name. Default is InternLM'\n        type: string\n        default: 'InternLM/lmdeploy'\n      repo_ref:\n        required: false\n        description: 'Set branch or tag or commit id. Default is \"main\"'\n        type: string\n        default: 'main'\n      offline_mode:\n        required: true\n        description: 'Whether start a offline mode, if true, you should prepare code and whl package by yourself'\n        type: boolean\n        default: false\n  schedule:\n    - cron:  '00 8 * * 1'\n\nenv:\n  HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache\n  HOST_LOCALTIME: /usr/share/zoneinfo/Asia/Shanghai\n  OUTPUT_FOLDER: cuda11.8_dist_${{ github.run_id }}\n  REPORT_DIR: /nvme/qa_test_models/stable_reports/${{ github.run_id }}\n  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true\n  COMPASS_DATA_CACHE: /nvme/qa_test_models/dataset\n\njobs:\n  linux-build:\n    if: ${{github.event_name == 'schedule' || (!cancelled() && !inputs.offline_mode)}}\n    strategy:\n      matrix:\n        pyver: [py310]\n    runs-on: ubuntu-latest\n    env:\n      PYTHON_VERSION: ${{ matrix.pyver }}\n      PLAT_NAME: manylinux2014_x86_64\n      DOCKER_TAG: cuda11.8\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Build\n        run: |\n          echo ${PYTHON_VERSION}\n          echo ${PLAT_NAME}\n          echo ${DOCKER_TAG}\n          echo ${OUTPUT_FOLDER}\n          echo ${GITHUB_RUN_ID}\n          # remove -it\n          sed -i 's/docker run --rm -it/docker run --rm/g' builder/manywheel/build_wheel.sh\n          bash builder/manywheel/build_wheel.sh ${PYTHON_VERSION} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER}\n      - 
name: Upload Artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          if-no-files-found: error\n          path: builder/manywheel/${{ env.OUTPUT_FOLDER }}\n          retention-days: 1\n          name: my-artifact-${{ github.run_id }}-${{ matrix.pyver }}\n\n\n  benchmark:\n    needs: linux-build\n    if: ${{github.event_name == 'schedule' || !cancelled()}}\n    runs-on: [self-hosted, lmdeploy-stable]\n    timeout-minutes: 10080\n    strategy:\n      fail-fast: false\n      matrix:\n        model: ['internlm/internlm2_5-20b-chat']\n    container:\n      image: openmmlab/lmdeploy:latest-cu11\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 -e NO_PROXY=localhost,127.0.0.1 -e no_proxy=localhost,127.0.0.1 --pull never\"\n      volumes:\n        - /nvme/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/github-actions/packages:/root/packages\n        - /nvme/qa_test_models:/nvme/qa_test_models\n        - /mnt/187:/mnt/187\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v3\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        with:\n          repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}\n          ref: ${{github.event.inputs.repo_ref || 'main'}}\n      - name: Copy repository - offline\n        if: ${{inputs.offline_mode}}\n        run: cp -r /nvme/qa_test_models/offline_pkg/lmdeploy/. .\n      - name: Download Artifacts\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        uses: actions/download-artifact@v4\n        with:\n          name: my-artifact-${{ github.run_id }}-py310\n      - name: Install lmdeploy - dependency\n        run: |\n          # manually install flash attn\n          # the install packeage from. 
https://github.com/Dao-AILab/flash-attention/releases\n          python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n          python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps\n          python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt\n      - name: Install lmdeploy\n        if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}\n        run: |\n          python3 -m pip install lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Install lmdeploy - offline\n        if: ${{inputs.offline_mode}}\n        run: |\n          python3 -m pip install /nvme/qa_test_models/offline_pkg/py310/lmdeploy-*.whl --no-deps\n          python3 -m pip install -r requirements/test.txt\n      - name: Install opencompass\n        run: |\n          git clone --depth=1 https://github.com/open-compass/opencompass.git\n          cd opencompass\n          python3 -m pip install -e .\n          cd ..\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n      - name: Start restful api turbomind\n        run: |\n          mkdir ${{env.REPORT_DIR}} -p\n          CUDA_VISIBLE_DEVICES=6,7 lmdeploy serve api_server /nvme/qa_test_models/${{matrix.model}} --tp 2 --max-batch-size 256 --cache-max-entry-count 0.9 --server-port 23344 > ${{env.REPORT_DIR}}/restful.log 2>&1  &\n          echo \"restful_pid=$!\" >> \"$GITHUB_ENV\"\n          sleep 120s\n      - name: Run OC result\n        continue-on-error: true\n        run: |\n          ln -s /nvme/qa_test_models/dataset/data .\n          opencompass .github/scripts/eval_stable_object_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-object-1\n          opencompass .github/scripts/eval_stable_subject_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-subject-1\n          
opencompass .github/scripts/eval_stable_object_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-object-2\n          opencompass .github/scripts/eval_stable_subject_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-subject-2\n          opencompass .github/scripts/eval_stable_object_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-object-3\n          opencompass .github/scripts/eval_stable_subject_config.py --reuse --dump-eval-details --work-dir ${{env.REPORT_DIR}}-subject-3\n      - name: Test lmdeploy - restful api\n        run: |\n          python3 benchmark/profile_restful_api.py --backend lmdeploy --base-url http://localhost:23344 --dataset-path /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 10000 --output-file ${{env.REPORT_DIR}}/stable.jsonl > ${{env.REPORT_DIR}}/stable.log\n          python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-1.csv > ${{env.REPORT_DIR}}/stable-internal-1.log\n          python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-2.csv > ${{env.REPORT_DIR}}/stable-internal-2.log\n          python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-3.csv > ${{env.REPORT_DIR}}/stable-internal-3.log\n          python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} 
/nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-4.csv > ${{env.REPORT_DIR}}/stable-internal-4.log\n          python3 /nvme/qa_test_models/offline_pkg/profile_restful_api_internal.py localhost:23344 /nvme/qa_test_models/${{matrix.model}} /nvme/qa_test_models/datasets/Mixed.json --stream-output True --num-prompts 100000 --csv ${{env.REPORT_DIR}}/stable-internal-5.csv > ${{env.REPORT_DIR}}/stable-internal-5.log\n      - name: Attach result\n        if: always()\n        run: |\n          python3 .github/scripts/action_tools.py generate_csv_from_profile_result ${{env.REPORT_DIR}}/stable.jsonl ${{env.REPORT_DIR}}/stable.csv\n          python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable.csv\n          python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-1.csv\n          python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-2.csv\n          python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-3.csv\n          python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-4.csv\n          python3 .github/scripts/action_tools.py add_summary ${{env.REPORT_DIR}}/stable-internal-5.csv\n      - name: Kill api server\n        if: always()\n        run: |\n          kill -15 \"$restful_pid\"\n      - name: Clear workfile\n        if: always()\n        run: |\n          chmod -R 777 $REPORT_DIR\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n          chmod -R 777 $workdir\n"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "name: 'Close stale issues and PRs'\n\non:\n  schedule:\n    # check issues and pull requests once a day at 01:30 (UTC)\n    - cron: '30 1 * * *'\n\npermissions:\n  contents: read\n\njobs:\n  stale:\n    permissions:\n      issues: write\n      pull-requests: write\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/stale@v7\n        with:\n          stale-issue-message: 'This issue is marked as stale because it has been labeled as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.'\n          stale-pr-message: 'This PR is marked as stale because there has been no activity in the past 45 days. It will be closed in 10 days if the stale label is not removed or if there are no further updates.'\n          close-issue-message: 'This issue is closed because it has been stale for 5 days. Please open a new issue if you have a similar problem or any new updates.'\n          close-pr-message: 'This PR is closed because it has been stale for 10 days. Please reopen this PR if you have any updates and want to keep contributing.'\n          # only issues/PRs with the following labels are checked\n          any-of-labels: 'invalid, awaiting response, duplicate'\n          days-before-issue-stale: 7\n          days-before-pr-stale: 45\n          days-before-issue-close: 5\n          days-before-pr-close: 10\n          # automatically remove the stale label when issues or pull requests are updated or commented on\n          remove-stale-when-updated: true\n          operations-per-run: 50\n"
  },
  {
    "path": ".github/workflows/test_docker.yml",
    "content": "name: test-docker\n\non:\n  push:\n    paths:\n      - 'docker/**'\n      - '.github/workflows/*docker.yml'\n  pull_request:\n    paths:\n      - 'docker/**'\n      - '.github/workflows/*docker.yml'\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\njobs:\n  test_docker_image:\n    permissions:\n      pull-requests: write\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        cuda_version: [cu13, cu12]\n        python_version: ['3.10', '3.11', '3.12', '3.13']\n    env:\n      CUDA_VERSION: ${{ matrix.cuda_version }}\n      PYTHON_VERSION: ${{ matrix.python_version }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          ref: ${{github.event.inputs.repo_ref}}\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Get docker info\n        run: |\n          docker info\n          # remove http extraheader\n          git config --local --unset \"http.https://github.com/.extraheader\"\n      - name: Build Docker image\n        run: |\n          docker build . 
-t lmdeploy:latest -f docker/Dockerfile --build-arg CUDA_VERSION=${CUDA_VERSION} --build-arg PYTHON_VERSION=${PYTHON_VERSION}\n      - name: Test image with lmdeploy check_env\n        run: |\n          docker images\n          docker run --rm lmdeploy:latest lmdeploy check_env\n      - name: Dive\n        if: ${{ matrix.cuda_version == 'cu12' }}\n        uses: MaxymVlasov/dive-action@v1.5.0\n        with:\n          image: lmdeploy:latest\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n\n  test_ascend_docker_image:\n    permissions:\n      pull-requests: write\n    runs-on: ubuntu-22.04-arm\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          ref: ${{github.event.inputs.repo_ref}}\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Get docker info\n        run: |\n          docker info\n          # remove http extraheader\n          git config --local --unset \"http.https://github.com/.extraheader\"\n      - name: Build Docker image\n        run: |\n          docker build . 
-t lmdeploy:ascend -f docker/Dockerfile_ascend_a3\n#      - name: Test image with lmdeploy check_env\n#        run: |\n#          docker images\n#          docker run --rm lmdeploy:ascend lmdeploy check_env\n      - name: Dive\n        uses: MaxymVlasov/dive-action@v1.5.0\n        with:\n          image: lmdeploy:ascend\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n\n  test_jetson_docker_image:\n    permissions:\n      pull-requests: write\n    runs-on: ubuntu-22.04-arm\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n        with:\n          ref: ${{github.event.inputs.repo_ref}}\n      - name: Free disk space\n        uses: jlumbroso/free-disk-space@main\n        with:\n          # This might remove tools that are actually needed, if set to \"true\" but frees about 6 GB\n          tool-cache: false\n          docker-images: false\n          # All of these default to true, but feel free to set to \"false\" if necessary for your workflow\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: false\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Get docker info\n        run: |\n          docker info\n          # remove http extraheader\n          git config --local --unset \"http.https://github.com/.extraheader\"\n      - name: Build Docker image\n        run: |\n          docker build . -t lmdeploy:jetson -f docker/Dockerfile.jetson\n      - name: Test image with lmdeploy check_env\n        run: |\n          docker images\n          docker run --rm lmdeploy:jetson lmdeploy check_env\n      - name: Dive\n        uses: MaxymVlasov/dive-action@v1.5.0\n        with:\n          image: lmdeploy:jetson\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/unit_test.yml",
    "content": "name: unit-test\n\non:\n  pull_request:\n    paths:\n      - \".github/workflows/unit_test.yml\"\n      - \"cmake/**\"\n      - \"src/**\"\n      - \"tests/**\"\n      - \"3rdparty/**\"\n      - \"lmdeploy/**\"\n      - \"requirements/**\"\n      - \"requirements_cuda.txt\"\n      - \"CMakeLists.txt\"\n      - \"setup.py\"\n  push:\n    branches:\n      - main\n    paths:\n      - \".github/workflows/unit_test.yml\"\n      - \"cmake/**\"\n      - \"src/**\"\n      - \"tests/**\"\n      - \"3rdparty/**\"\n      - \"lmdeploy/**\"\n      - \"requirements/**\"\n      - \"requirements_cuda.txt\"\n      - \"CMakeLists.txt\"\n      - \"setup.py\"\n    tags:\n      - \"v*.*.*\"\n\njobs:\n  unit_test:\n    runs-on: [self-hosted, linux-a100-s2]\n    timeout-minutes: 4320 # 72hours\n    container:\n      image: openmmlab/lmdeploy:dev-cu12.8\n      options: \"--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e CUDA_VISIBLE_DEVICES=2,3 -e HF_HOME=/root/.cache/huggingface --pull never\"\n      volumes:\n        - /nvme/share_data/github-actions/pip-cache:/root/.cache/pip\n        - /nvme/share_data/github-actions/hf_home:/root/.cache/huggingface\n        - /nvme/share_data/github-actions/packages:/root/packages\n        - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro\n    steps:\n      - name: Clone repository\n        uses: actions/checkout@v5\n      - name: Install lmdeploy\n        run: |\n          python3 -m pip install -r requirements/test.txt\n          python3 -m pip install -e .\n      - name: Check env\n        run: |\n          python3 -m pip list\n          lmdeploy check_env\n      - name: Test lmdeploy python UT\n        run: |\n          coverage run --branch --source lmdeploy -m pytest -rsE tests\n          coverage xml\n          coverage report -m\n      - name: Clear workfile\n        if: always()\n        run: |\n          export workdir=$(pwd)\n          cd ..\n          rm -rf $workdir\n          mkdir $workdir\n     
     chmod -R 777 $workdir\n"
  },
  {
    "path": ".github/workflows/windows_x64_gpu.yml",
    "content": "name: windows-x64-gpu\non:\n  push:\n    paths:\n      - '.github/workflows/windows_x64_gpu.yml'\n      - 'src/**'\n      - 'CMakeLists.txt'\n      - 'cmake/**'\n      - 'examples/**'\n      - '3rdparty/**'\n      - 'tests/csrc/**'\n  pull_request:\n    paths:\n      - '.github/workflows/windows_x64_gpu.yml'\n      - 'src/**'\n      - 'CMakeLists.txt'\n      - 'cmake/**'\n      - 'examples/**'\n      - '3rdparty/**'\n      - 'tests/csrc/**'\nconcurrency:\n  group: windows-x64-gpu-${{ github.ref }}\n  cancel-in-progress: true\npermissions:\n  contents: read\n\njobs:\n  build:\n    strategy:\n      fail-fast: false\n      matrix:\n        cudaver: [12.6.2, 12.8.1]\n    name: cuda-${{ matrix.cudaver }}\n    runs-on: windows-latest\n    steps:\n      - name: Set git for windows\n        run: |\n          git config --global core.longpaths true\n      - name: Checkout repository\n        uses: actions/checkout@v3\n      - name: Set up python\n        uses: actions/setup-python@v4\n        with:\n          python-version: '3.10'\n      - name: Install python packages\n        run: |\n          pip install build\n      - name: Setup CUDA Toolkit\n        id: cuda-toolkit\n        shell: pwsh\n        run: ./builder/windows/setup_cuda.ps1\n        env:\n            INPUT_CUDA_VERSION: ${{ matrix.cudaver }}\n      - name: Build wheel\n        run: |\n          python -m build --wheel\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n.vscode/\n.idea/\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\ntriton-rerope/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\n.venv/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\ntmp/\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\n*build*/\n!builder/\nlmdeploy/lib/\nlmdeploy/bin/\ndist/\nexamples/cpp/llama/*.csv\n*.npy\n*.weight\ninstall/\n/docs/*/_static/*.yaml\n\n# LMDeploy\nworkspace/\nwork_dir*/\n\n# Huggingface\n*.bin\n*config.json\n*generate_config.json\n!lmdeploy/turbomind/hf_repo/config.json\n\n# Pytorch\n*.pt\n*.pth\n*.py~\n*.sh~\n*.pyc\n**/src/pytorch-sphinx-theme/\n\n# Outputs and logs\n*.txt\n*.log\n*.out\n*.csv\n!start_ids.csv\n*.pkl\n\n!CMakeLists.txt\nproxy_config.yml\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/PyCQA/flake8\n    rev: 5.0.4\n    hooks:\n      - id: flake8\n        args: ['--extend-ignore=E231', \"--max-line-length=120\"]\n  - repo: https://github.com/PyCQA/isort\n    rev: 5.11.5\n    hooks:\n      - id: isort\n        args: [\"--line-length=120\"]\n  - repo: https://github.com/google/yapf\n    rev: v0.43.0\n    hooks:\n      - id: yapf\n        args: ['-i', '--style={based_on_style: pep8, column_limit: 120}']\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.3.0\n    hooks:\n      - id: trailing-whitespace\n      - id: check-yaml\n      - id: end-of-file-fixer\n      - id: requirements-txt-fixer\n      - id: double-quote-string-fixer\n      - id: check-merge-conflict\n      - id: fix-encoding-pragma\n        args: [\"--remove\"]\n      - id: mixed-line-ending\n        args: [\"--fix=lf\"]\n\n  - repo: https://github.com/executablebooks/mdformat\n    rev: 0.7.9\n    hooks:\n      - id: mdformat\n        args: [\"--number\"]\n        additional_dependencies:\n          - mdformat-openmmlab\n          - mdformat_frontmatter\n          - linkify-it-py\n  - repo: https://github.com/codespell-project/codespell\n    rev: v2.1.0\n    hooks:\n      - id: codespell\n        args: [\"--skip=third_party/*,*.ipynb,*.proto,src/turbomind/*,docker/Dockerfile_ascend*,docs/en/get_started/ascend/get_started.md,docs/zh_cn/get_started/ascend/get_started.md\"]\n\n\n  - repo: https://github.com/myint/docformatter\n    rev: v1.7.7\n    hooks:\n      - id: docformatter\n        language_version: python3.10\n        args: [\"--in-place\", \"--wrap-descriptions\", \"120\"]\n\n  - repo: https://github.com/open-mmlab/pre-commit-hooks\n    rev: v0.2.0\n    hooks:\n    -   id: check-copyright\n        args: [\"lmdeploy\"]\n\n  - repo: https://github.com/pre-commit/mirrors-clang-format\n    rev: v11.1.0\n    hooks:\n      - id: clang-format\n        files: ^src/\n        types_or: [c, c++, cuda]\n\nexclude: |\n  
(?x)(\n    ^cmake/.*\\.patch$\n  )\n"
  },
  {
    "path": ".pylintrc",
    "content": "[MASTER]\n\n# A comma-separated list of package or module names from where C extensions may\n# be loaded. Extensions are loading into the active Python interpreter and may\n# run arbitrary code.\nextension-pkg-whitelist=\n\n# Specify a score threshold to be exceeded before program exits with error.\nfail-under=8.5\n\n# Add files or directories to the blacklist. They should be base names, not\n# paths.\nignore=CVS,configs\n\n# Add files or directories matching the regex patterns to the blacklist. The\n# regex matches against base names, not paths.\nignore-patterns=\n\n# Python code to execute, usually for sys.path manipulation such as\n# pygtk.require().\n#init-hook=\n\n# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the\n# number of processors available to use.\njobs=1\n\n# Control the amount of potential inferred values when inferring a single\n# object. This can help the performance when dealing with large functions or\n# complex, nested conditions.\nlimit-inference-results=100\n\n# List of plugins (as comma separated values of python module names) to load,\n# usually to register additional checkers.\nload-plugins=\n\n# Pickle collected data for later comparisons.\npersistent=yes\n\n# When enabled, pylint would attempt to guess common misconfiguration and emit\n# user-friendly hints instead of false-positive error messages.\nsuggestion-mode=yes\n\n# Allow loading of arbitrary C extensions. Extensions are imported into the\n# active Python interpreter and may run arbitrary code.\nunsafe-load-any-extension=no\n\n\n[MESSAGES CONTROL]\n\n# Only show warnings with the listed confidence levels. Leave empty to show\n# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.\nconfidence=\n\n# Disable the message, report, category or checker with the given id(s). 
You\n# can either give multiple identifiers separated by comma (,) or put this\n# option multiple times (only on the command line, not in the configuration\n# file where it should appear only once). You can also use \"--disable=all\" to\n# disable everything first and then reenable specific checks. For example, if\n# you want to run only the similarities checker, you can use \"--disable=all\n# --enable=similarities\". If you want to run only the classes checker, but have\n# no Warning level messages displayed, use \"--disable=all --enable=classes\n# --disable=W\".\ndisable=print-statement,\n        parameter-unpacking,\n        unpacking-in-except,\n        old-raise-syntax,\n        backtick,\n        long-suffix,\n        old-ne-operator,\n        old-octal-literal,\n        import-star-module-level,\n        non-ascii-bytes-literal,\n        raw-checker-failed,\n        bad-inline-option,\n        locally-disabled,\n        file-ignored,\n        suppressed-message,\n        useless-suppression,\n        deprecated-pragma,\n        use-symbolic-message-instead,\n        apply-builtin,\n        basestring-builtin,\n        buffer-builtin,\n        cmp-builtin,\n        coerce-builtin,\n        execfile-builtin,\n        file-builtin,\n        long-builtin,\n        raw_input-builtin,\n        reduce-builtin,\n        standarderror-builtin,\n        unicode-builtin,\n        xrange-builtin,\n        coerce-method,\n        delslice-method,\n        getslice-method,\n        setslice-method,\n        no-absolute-import,\n        old-division,\n        dict-iter-method,\n        dict-view-method,\n        next-method-called,\n        metaclass-assignment,\n        indexing-exception,\n        raising-string,\n        reload-builtin,\n        oct-method,\n        hex-method,\n        nonzero-method,\n        cmp-method,\n        input-builtin,\n        round-builtin,\n        intern-builtin,\n        unichr-builtin,\n        map-builtin-not-iterating,\n        
zip-builtin-not-iterating,\n        range-builtin-not-iterating,\n        filter-builtin-not-iterating,\n        using-cmp-argument,\n        eq-without-hash,\n        div-method,\n        idiv-method,\n        rdiv-method,\n        exception-message-attribute,\n        invalid-str-codec,\n        sys-max-int,\n        bad-python3-import,\n        deprecated-string-function,\n        deprecated-str-translate-call,\n        deprecated-itertools-function,\n        deprecated-types-field,\n        next-method-defined,\n        dict-items-not-iterating,\n        dict-keys-not-iterating,\n        dict-values-not-iterating,\n        deprecated-operator-function,\n        deprecated-urllib-function,\n        xreadlines-attribute,\n        deprecated-sys-function,\n        exception-escape,\n        comprehension-escape,\n        no-member,\n        invalid-name,\n        too-many-branches,\n        wrong-import-order,\n        too-many-arguments,\n        missing-function-docstring,\n        missing-module-docstring,\n        too-many-locals,\n        too-few-public-methods,\n        abstract-method,\n        broad-except,\n        too-many-nested-blocks,\n        too-many-instance-attributes,\n        missing-class-docstring,\n        duplicate-code,\n        not-callable,\n        protected-access,\n        dangerous-default-value,\n        no-name-in-module,\n        logging-fstring-interpolation,\n        super-init-not-called,\n        redefined-builtin,\n        attribute-defined-outside-init,\n        arguments-differ,\n        cyclic-import,\n        bad-super-call,\n        too-many-statements,\n        unused-argument,\n        import-outside-toplevel,\n        import-error,\n        super-with-arguments\n\n# Enable the message, report, category or checker with the given id(s). 
You can\n# either give multiple identifier separated by comma (,) or put this option\n# multiple time (only on the command line, not in the configuration file where\n# it should appear only once). See also the \"--disable\" option for examples.\nenable=c-extension-no-member\n\n\n[REPORTS]\n\n# Python expression which should return a score less than or equal to 10. You\n# have access to the variables 'error', 'warning', 'refactor', and 'convention'\n# which contain the number of messages in each category, as well as 'statement'\n# which is the total number of statements analyzed. This score is used by the\n# global evaluation report (RP0004).\nevaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)\n\n# Template used to display messages. This is a python new-style format string\n# used to format the message information. See doc for all details.\n#msg-template=\n\n# Set the output format. Available formats are text, parseable, colorized, json\n# and msvs (visual studio). You can also give a reporter class, e.g.\n# mypackage.mymodule.MyReporterClass.\noutput-format=text\n\n# Tells whether to display a full report or only the messages.\nreports=yes\n\n# Activate the evaluation score.\nscore=yes\n\n\n[REFACTORING]\n\n# Maximum number of nested blocks for function / method body\nmax-nested-blocks=5\n\n# Complete name of functions that never returns. When checking for\n# inconsistent-return-statements if a never returning function is called then\n# it will be considered as an explicit return statement and no message will be\n# printed.\nnever-returning-functions=sys.exit\n\n\n[TYPECHECK]\n\n# List of decorators that produce context managers, such as\n# contextlib.contextmanager. 
Add to this list to register other decorators that\n# produce valid context managers.\ncontextmanager-decorators=contextlib.contextmanager\n\n# List of members which are set dynamically and missed by pylint inference\n# system, and so shouldn't trigger E1101 when accessed. Python regular\n# expressions are accepted.\ngenerated-members=\n\n# Tells whether missing members accessed in mixin class should be ignored. A\n# mixin class is detected if its name ends with \"mixin\" (case insensitive).\nignore-mixin-members=yes\n\n# Tells whether to warn about missing members when the owner of the attribute\n# is inferred to be None.\nignore-none=yes\n\n# This flag controls whether pylint should warn about no-member and similar\n# checks whenever an opaque object is returned when inferring. The inference\n# can return multiple potential results while evaluating a Python object, but\n# some branches might not be evaluated, which results in partial inference. In\n# that case, it might be useful to still emit no-member and other checks for\n# the rest of the inferred objects.\nignore-on-opaque-inference=yes\n\n# List of class names for which member attributes should not be checked (useful\n# for classes with dynamically set attributes). This supports the use of\n# qualified names.\nignored-classes=optparse.Values,thread._local,_thread._local\n\n# List of module names for which member attributes should not be checked\n# (useful for modules/projects where namespaces are manipulated during runtime\n# and thus existing member attributes cannot be deduced by static analysis). It\n# supports qualified module names, as well as Unix pattern matching.\nignored-modules=\n\n# Show a hint with possible names when a member name was not found. 
The aspect\n# of finding the hint is based on edit distance.\nmissing-member-hint=yes\n\n# The minimum edit distance a name should have in order to be considered a\n# similar match for a missing member name.\nmissing-member-hint-distance=1\n\n# The total number of similar names that should be taken in consideration when\n# showing a hint for a missing member.\nmissing-member-max-choices=1\n\n# List of decorators that change the signature of a decorated function.\nsignature-mutators=\n\n\n[SPELLING]\n\n# Limits count of emitted suggestions for spelling mistakes.\nmax-spelling-suggestions=4\n\n# Spelling dictionary name. Available dictionaries: none. To make it work,\n# install the python-enchant package.\nspelling-dict=\n\n# List of comma separated words that should not be checked.\nspelling-ignore-words=\n\n# A path to a file that contains the private dictionary; one word per line.\nspelling-private-dict-file=\n\n# Tells whether to store unknown words to the private dictionary (see the\n# --spelling-private-dict-file option) instead of raising a message.\nspelling-store-unknown-words=no\n\n\n[LOGGING]\n\n# The type of string formatting that logging methods do. `old` means using %\n# formatting, `new` is for `{}` formatting.\nlogging-format-style=old\n\n# Logging modules to check that the string format arguments are in logging\n# function parameter format.\nlogging-modules=logging\n\n\n[VARIABLES]\n\n# List of additional names supposed to be defined in builtins. Remember that\n# you should avoid defining new builtins when possible.\nadditional-builtins=\n\n# Tells whether unused global variables should be treated as a violation.\nallow-global-unused-variables=yes\n\n# List of strings which can identify a callback function by name. A callback\n# name must start or end with one of those strings.\ncallbacks=cb_,\n          _cb\n\n# A regular expression matching the name of dummy variables (i.e. 
expected to\n# not be used).\ndummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_\n\n# Argument names that match this expression will be ignored. Default to name\n# with leading underscore.\nignored-argument-names=_.*|^ignored_|^unused_\n\n# Tells whether we should check for unused import in __init__ files.\ninit-import=no\n\n# List of qualified module names which can have objects that can redefine\n# builtins.\nredefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io\n\n\n[FORMAT]\n\n# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.\nexpected-line-ending-format=\n\n# Regexp for a line that is allowed to be longer than the limit.\nignore-long-lines=^\\s*(# )?<?https?://\\S+>?$\n\n# Number of spaces of indent required inside a hanging or continued line.\nindent-after-paren=4\n\n# String used as indentation unit. This is usually \"    \" (4 spaces) or \"\\t\" (1\n# tab).\nindent-string='    '\n\n# Maximum number of characters on a single line.\nmax-line-length=100\n\n# Maximum number of lines in a module.\nmax-module-lines=1000\n\n# Allow the body of a class to be on the same line as the declaration if body\n# contains single statement.\nsingle-line-class-stmt=no\n\n# Allow the body of an if to be on the same line as the test if there is no\n# else.\nsingle-line-if-stmt=no\n\n\n[STRING]\n\n# This flag controls whether inconsistent-quotes generates a warning when the\n# character used as a quote delimiter is used inconsistently within a module.\ncheck-quote-consistency=no\n\n# This flag controls whether the implicit-str-concat should generate a warning\n# on implicit string concatenation in sequences defined over several lines.\ncheck-str-concat-over-line-jumps=no\n\n\n[SIMILARITIES]\n\n# Ignore comments when computing similarities.\nignore-comments=yes\n\n# Ignore docstrings when computing similarities.\nignore-docstrings=yes\n\n# Ignore imports when computing 
similarities.\nignore-imports=no\n\n# Minimum lines number of a similarity.\nmin-similarity-lines=4\n\n\n[MISCELLANEOUS]\n\n# List of note tags to take in consideration, separated by a comma.\nnotes=FIXME,\n      XXX,\n      TODO\n\n# Regular expression of note tags to take in consideration.\n#notes-rgx=\n\n\n[BASIC]\n\n# Naming style matching correct argument names.\nargument-naming-style=snake_case\n\n# Regular expression matching correct argument names. Overrides argument-\n# naming-style.\n#argument-rgx=\n\n# Naming style matching correct attribute names.\nattr-naming-style=snake_case\n\n# Regular expression matching correct attribute names. Overrides attr-naming-\n# style.\n#attr-rgx=\n\n# Bad variable names which should always be refused, separated by a comma.\nbad-names=foo,\n          bar,\n          baz,\n          toto,\n          tutu,\n          tata\n\n# Bad variable names regexes, separated by a comma. If names match any regex,\n# they will always be refused\nbad-names-rgxs=\n\n# Naming style matching correct class attribute names.\nclass-attribute-naming-style=any\n\n# Regular expression matching correct class attribute names. Overrides class-\n# attribute-naming-style.\n#class-attribute-rgx=\n\n# Naming style matching correct class names.\nclass-naming-style=PascalCase\n\n# Regular expression matching correct class names. Overrides class-naming-\n# style.\n#class-rgx=\n\n# Naming style matching correct constant names.\nconst-naming-style=UPPER_CASE\n\n# Regular expression matching correct constant names. Overrides const-naming-\n# style.\n#const-rgx=\n\n# Minimum line length for functions/classes that require docstrings, shorter\n# ones are exempt.\ndocstring-min-length=-1\n\n# Naming style matching correct function names.\nfunction-naming-style=snake_case\n\n# Regular expression matching correct function names. 
Overrides function-\n# naming-style.\n#function-rgx=\n\n# Good variable names which should always be accepted, separated by a comma.\ngood-names=i,\n           j,\n           k,\n           ex,\n           Run,\n           _,\n           x,\n           y,\n           w,\n           h,\n           a,\n           b\n\n# Good variable names regexes, separated by a comma. If names match any regex,\n# they will always be accepted\ngood-names-rgxs=\n\n# Include a hint for the correct naming format with invalid-name.\ninclude-naming-hint=no\n\n# Naming style matching correct inline iteration names.\ninlinevar-naming-style=any\n\n# Regular expression matching correct inline iteration names. Overrides\n# inlinevar-naming-style.\n#inlinevar-rgx=\n\n# Naming style matching correct method names.\nmethod-naming-style=snake_case\n\n# Regular expression matching correct method names. Overrides method-naming-\n# style.\n#method-rgx=\n\n# Naming style matching correct module names.\nmodule-naming-style=snake_case\n\n# Regular expression matching correct module names. Overrides module-naming-\n# style.\n#module-rgx=\n\n# Colon-delimited sets of names that determine each other's naming style when\n# the name regexes allow several styles.\nname-group=\n\n# Regular expression which should only match function or class names that do\n# not require a docstring.\nno-docstring-rgx=^_\n\n# List of decorators that produce properties, such as abc.abstractproperty. Add\n# to this list to register other decorators that produce valid properties.\n# These decorators are taken in consideration only for invalid-name.\nproperty-classes=abc.abstractproperty\n\n# Naming style matching correct variable names.\nvariable-naming-style=snake_case\n\n# Regular expression matching correct variable names. 
Overrides variable-\n# naming-style.\n#variable-rgx=\n\n\n[DESIGN]\n\n# Maximum number of arguments for function / method.\nmax-args=5\n\n# Maximum number of attributes for a class (see R0902).\nmax-attributes=7\n\n# Maximum number of boolean expressions in an if statement (see R0916).\nmax-bool-expr=5\n\n# Maximum number of branch for function / method body.\nmax-branches=12\n\n# Maximum number of locals for function / method body.\nmax-locals=15\n\n# Maximum number of parents for a class (see R0901).\nmax-parents=7\n\n# Maximum number of public methods for a class (see R0904).\nmax-public-methods=20\n\n# Maximum number of return / yield for function / method body.\nmax-returns=6\n\n# Maximum number of statements in function / method body.\nmax-statements=50\n\n# Minimum number of public methods for a class (see R0903).\nmin-public-methods=2\n\n\n[IMPORTS]\n\n# List of modules that can be imported at any level, not just the top level\n# one.\nallow-any-import-level=\n\n# Allow wildcard imports from modules that define __all__.\nallow-wildcard-with-all=no\n\n# Analyse import fallback blocks. This can be used to support both Python 2 and\n# 3 compatible code, which means that the block might have code that exists\n# only in one or another interpreter, leading to false positives when analysed.\nanalyse-fallback-blocks=no\n\n# Deprecated modules which should not be used, separated by a comma.\ndeprecated-modules=optparse,tkinter.tix\n\n# Create a graph of external dependencies in the given file (report RP0402 must\n# not be disabled).\next-import-graph=\n\n# Create a graph of every (i.e. 
internal and external) dependencies in the\n# given file (report RP0402 must not be disabled).\nimport-graph=\n\n# Create a graph of internal dependencies in the given file (report RP0402 must\n# not be disabled).\nint-import-graph=\n\n# Force import order to recognize a module as part of the standard\n# compatibility libraries.\nknown-standard-library=\n\n# Force import order to recognize a module as part of a third party library.\nknown-third-party=enchant\n\n# Couples of modules and preferred modules, separated by a comma.\npreferred-modules=\n\n\n[CLASSES]\n\n# List of method names used to declare (i.e. assign) instance attributes.\ndefining-attr-methods=__init__,\n                      __new__,\n                      setUp,\n                      __post_init__\n\n# List of member names, which should be excluded from the protected access\n# warning.\nexclude-protected=_asdict,\n                  _fields,\n                  _replace,\n                  _source,\n                  _make\n\n# List of valid names for the first argument in a class method.\nvalid-classmethod-first-arg=cls\n\n# List of valid names for the first argument in a metaclass class method.\nvalid-metaclass-classmethod-first-arg=cls\n\n\n[EXCEPTIONS]\n\n# Exceptions that will emit a warning when being caught. Defaults to\n# \"BaseException, Exception\".\novergeneral-exceptions=BaseException,\n                       Exception\n"
  },
  {
    "path": "CLAUDE.md",
    "content": "# CLAUDE.md\n\nThis file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.\n\n## Commands\n\n**Linting:**\n\n```bash\npre-commit run --all-files\n```\n\nStyle: PEP8, max line length 120, double quotes, LF endings. C++ source under `src/` uses clang-format.\n\n**Tests:**\n\n```bash\npytest tests/test_lmdeploy                          # all unit tests\npytest tests/test_lmdeploy/test_model.py            # specific file\npytest tests/test_lmdeploy/test_lite/               # quantization tests\npytest tests/test_lmdeploy/test_vl/                 # vision-language tests\n```\n\n**Debug logging:**\n\n```bash\nLMDEPLOY_LOG_LEVEL=DEBUG python ...\n```\n\n**Build (TurboMind C++ extension):**\n\n- Controlled via `setup.py` + CMake. Relevant env vars: `LMDEPLOY_TARGET_DEVICE` (default `cuda`), `DISABLE_TURBOMIND`, `CMAKE_BUILD_TYPE`, `CUDACXX`.\n- Requirements split by device: `requirements/runtime_cuda.txt`, `runtime_ascend.txt`, etc.\n\n## Architecture\n\n### Two Backends, One Pipeline\n\n`lmdeploy/pipeline.py` is the main user-facing entry point (`pipeline()` in `api.py`). It instantiates either the **PyTorch engine** (`lmdeploy/pytorch/`) or the **TurboMind engine** (`lmdeploy/turbomind/`) based on config.\n\n### PyTorch Backend\n\n**Model patching** is the core mechanism: HuggingFace models are loaded normally, then their layers are dynamically replaced with optimized LMDeploy implementations.\n\n- `lmdeploy/pytorch/models/module_map.py` — registry mapping HF class names → LMDeploy replacement classes. Device-specific overrides in `DEVICE_SPECIAL_MODULE_MAP`.\n- `lmdeploy/pytorch/models/patch.py` — applies the substitutions at runtime via `_get_rewrite_qualname()` / `_class_from_qualname()`.\n- `lmdeploy/pytorch/models/` — 40+ per-model files (e.g., `llama.py`, `qwen.py`, `deepseek_v2.py`). 
Each reimplements attention, MLP, and embeddings using custom kernels.\n- `lmdeploy/pytorch/nn/` — reusable optimized modules: `linear/` (AWQ, W8A8, blocked-FP8, LoRA variants), `attention.py`, `norm.py`, `rotary_embedding.py`, `moe/`.\n- `lmdeploy/pytorch/kernels/` — Triton/CUDA kernels (e.g., `w8a8_triton_kernels.py`).\n- `lmdeploy/pytorch/backends/` — kernel/operator dispatchers per quantization type (FP8, AWQ, CUDA).\n\n**Engine execution flow (key files):**\n\n- `engine.py` — main PyTorch engine.\n- `paging/scheduler.py` — sequences → batches; prefill/decode, block eviction, prefix caching (`BlockTrie`).\n- `engine/engine_loop.py` — async inference loop.\n- (See `pytorch/engine/` and `pytorch/paging/` for full execution detail.)\n\n**Configuration dataclasses** (`lmdeploy/pytorch/config.py`): `ModelConfig`, `CacheConfig`, `SchedulerConfig`, `BackendConfig`, `DistConfig`, `MiscConfig`.\n\n### TurboMind Backend\n\n- Python wrapper: `lmdeploy/turbomind/turbomind.py` (~800 lines). Bridges into `lmdeploy/lib/_turbomind` (pybind11 extension built from `src/turbomind/`).\n- Tensor interop via `torch.from_dlpack()` / `_tm.from_dlpack()`.\n- Config and model conversion: `lmdeploy/turbomind/deploy/config.py`, `supported_models.py`.\n- Parallel config helpers: `update_parallel_config()`, `complete_parallel_config()` in `messages.py`.\n\n### Lite / Quantization\n\nEntrypoints in `lmdeploy/lite/apis/`: `calibrate.py` (main), `auto_awq.py`, `gptq.py`, `smooth_quant.py`.\n\n**Flow:** load HF model → `CalibrationContext` collects activation statistics → scale computation (`lmdeploy/lite/quantization/`) → write quantized weights.\n\n- `lite/quantization/awq.py` — AWQ (NORM_FCS_MAP, FC_FCS_MAP define per-model layer structure).\n- `lite/quantization/weight/quantizer.py` — weight quantizer.\n- `lite/quantization/activation/observer.py` — activation statistics.\n- `lite/modeling/` — model-specific GPTQ implementations (e.g., `internlm2_gptq.py`).\n- `lite/utils/cal_qparams.py` — 
quantization parameter calculation utilities.\n\nLayer/norm/head mappings per model family are defined directly in `calibrate.py` and `awq.py`.\n\n### Vision-Language Models\n\n- `lmdeploy/vl/model/` — VLM preprocessing (InternVL, Qwen-VL, LLaVA, CogVLM, etc.).\n- `lmdeploy/vl/media/` — image/video loaders and base classes.\n- `lmdeploy/pytorch/multimodal/` — multimodal input handling for the PyTorch engine.\n- Reference VLM implementation: `lmdeploy/vl/model/qwen3.py`.\n\n### Other Key Files\n\n- `lmdeploy/messages.py` — core types: `GenerationConfig`, `EngineConfig`, `TurbomindEngineConfig`, `SchedulerSequence`, `MessageStatus`.\n- `lmdeploy/model.py` — chat templates; critical for correct conversation formatting.\n- `lmdeploy/archs.py` — architecture registry mapping model arch names to runtime patches.\n- `lmdeploy/tokenizer.py` — HuggingFace/SentencePiece tokenizer wrapper.\n- `lmdeploy/serve/openai/` — OpenAI-compatible API server.\n\n## Adding a New PyTorch Model\n\nUse the `/support-new-model` skill for a complete step-by-step guide.\n"
  },
  {
    "path": "CMakeLists.txt",
    "content": "# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ncmake_minimum_required(VERSION 3.11 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13\ncmake_policy(SET CMP0074 NEW)\nproject(TurboMind LANGUAGES CXX CUDA)\n\nif (MSVC)\n    # use standard conformant preprocessor\n    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:preprocessor>)\n    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:__cplusplus>)\n    set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=/Zc:preprocessor -Xcompiler=/Zc:__cplusplus\")\nendif ()\n\nfind_package(CUDAToolkit REQUIRED)\n\nif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL \"11\")\n  add_definitions(\"-DENABLE_BF16\")\nendif()\n\nset(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)\n\noption(BUILD_MULTI_GPU \"Build multi-gpu support\" ON)\noption(BUILD_PY_FFI \"Build python ffi\" ON)\noption(BUILD_TEST \"Build tests\" OFF)\noption(SPARSITY_SUPPORT \"Build project with Ampere sparsity feature support\" OFF)\noption(BUILD_FAST_MATH \"Build in fast math mode\" ON)\n\ninclude(FetchContent)\n\nif (BUILD_TEST)\n  FetchContent_Declare(\n    Catch2\n    GIT_REPOSITORY https://github.com/catchorg/Catch2.git\n    GIT_TAG        v3.8.0\n    GIT_SHALLOW ON\n    GIT_PROGRESS            TRUE\n    USES_TERMINAL_DOWNLOAD  TRUE\n    EXCLUDE_FROM_ALL\n  )\n  FetchContent_MakeAvailable(Catch2)\nendif()\n\n\nFetchContent_Declare(\n 
 repo-cutlass\n  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git\n  GIT_TAG                 v3.9.2\n  GIT_SHALLOW             ON\n  GIT_PROGRESS            TRUE\n  USES_TERMINAL_DOWNLOAD  TRUE\n  EXCLUDE_FROM_ALL\n)\n\nset(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES ON CACHE BOOL \"Enable extended GMMA shapes\")\nset(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL \"Enable only the header library\")\n\nFetchContent_MakeAvailable(repo-cutlass)\n\nFetchContent_Declare(\n  yaml-cpp\n  GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git\n  GIT_TAG                 65c1c270dbe7eec37b2df2531d7497c4eea79aee\n  GIT_PROGRESS            TRUE\n  USES_TERMINAL_DOWNLOAD  TRUE\n)\nset(YAML_BUILD_SHARED_LIBS OFF CACHE BOOL \"Build static library of yaml-cpp\")\nFetchContent_MakeAvailable(yaml-cpp)\n\nFetchContent_Declare(\n  xgrammar\n  GIT_REPOSITORY          https://github.com/mlc-ai/xgrammar.git\n  GIT_TAG                 v0.1.27\n  GIT_SUBMODULES          \"3rdparty/dlpack\"\n  GIT_PROGRESS            TRUE\n  USES_TERMINAL_DOWNLOAD  TRUE\n)\n\nFetchContent_GetProperties(xgrammar)\nif(NOT xgrammar_POPULATED)\n  # Fetch the content using previously declared details\n  FetchContent_Populate(xgrammar)\n\n  file(WRITE ${xgrammar_SOURCE_DIR}/config.cmake \"set(XGRAMMAR_BUILD_PYTHON_BINDINGS OFF)\\n\")\n  if(NOT MSVC)\n    file(APPEND ${xgrammar_SOURCE_DIR}/config.cmake \"set(CMAKE_CXX_FLAGS \\\"-Wno-error\\\")\\n\")\n  endif()\n\n  # Bring the populated content into the build\n  add_subdirectory(${xgrammar_SOURCE_DIR} ${xgrammar_BINARY_DIR})\n  if(TARGET xgrammar)\n    target_compile_options(xgrammar PRIVATE $<$<CXX_COMPILER_ID:MSVC>:/utf-8>)\n    target_compile_options(xgrammar PRIVATE $<$<C_COMPILER_ID:MSVC>:/utf-8>)\n  endif()\nendif()\n\n# the environment variable\n#   ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0\n#   LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libasan.so.6:/usr/lib/x86_64-linux-gnu/libstdc++.so.6\n# must be set at runtime\n# 
https://github.com/google/sanitizers/issues/1322\nif (LMDEPLOY_ASAN_ENABLE)\n    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>)\n    add_link_options(-fsanitize=address)\nendif ()\n\n# notice that ubsan has linker issues for ubuntu < 18.04, see\n# https://stackoverflow.com/questions/50024731/ld-unrecognized-option-push-state-no-as-needed\nif (LMDEPLOY_UBSAN_ENABLE)\n    add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=undefined>)\n    add_link_options(-fsanitize=undefined)\nendif ()\n\nif(BUILD_MULTI_GPU)\n    execute_process(\n      COMMAND python -c \"import importlib.util; print(importlib.util.find_spec('nvidia.nccl').submodule_search_locations[0])\"\n      RESULT_VARIABLE result\n      OUTPUT_VARIABLE nccl_path\n      ERROR_QUIET\n      OUTPUT_STRIP_TRAILING_WHITESPACE\n    )\n\n    if(result EQUAL 0 AND NOT nccl_path STREQUAL \"\")\n      set(NCCL_ROOT ${nccl_path})\n      message(STATUS \"Found NCCL at: ${nccl_path}\")\n\n      if(result EQUAL 0 AND NOT nccl_path STREQUAL \"\")\n        file(GLOB nccl_lib_files \"${nccl_path}/lib/libnccl.so.*\")\n        if(nccl_lib_files)\n          list(GET nccl_lib_files -1 latest_lib)\n          string(REGEX MATCH \"\\\\.([0-9]+)$\" version_match ${latest_lib})\n          if(version_match)\n            set(NCCL_ROOT ${nccl_path})\n            set(ENV{NCCL_VERSION} ${CMAKE_MATCH_1})\n          endif()\n        endif()\n      endif()\n    endif()\n\n    add_definitions(\"-DBUILD_MULTI_GPU=1\")\n    set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)\n    find_package(NCCL)\n    if (NCCL_FOUND)\n        set(USE_NCCL ON)\n        add_definitions(\"-DUSE_NCCL=1\")\n    endif ()\nendif()\n\n\nset(CXX_STD \"17\" CACHE STRING \"C++ standard\")\n# enable gold linker for binary and .so\nif(NOT MSVC)\n  find_program(GOLD_PATH ld.gold REQUIRED)\n  if(NOT GOLD_PATH)\n    message(FATAL_ERROR \"GNU gold linker is required but not found. 
\"\n                         \"Please install binutils-gold package.\")\n  endif()\n  set(CMAKE_EXE_LINKER_FLAGS \"${CMAKE_EXE_LINKER_FLAGS} -fuse-ld=gold\")\n  set(CMAKE_SHARED_LINKER_FLAGS \"${CMAKE_SHARED_LINKER_FLAGS} -fuse-ld=gold\")\nendif()\nset(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})\n\nset(CUSPARSELT_PATH \"\" CACHE STRING \"cuSPARSELt path\")\n\nlist(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64)\n\n# profiling\noption(USE_NVTX \"Whether or not to use nvtx\" ON)\nif(USE_NVTX)\n  message(STATUS \"NVTX is enabled.\")\n  add_definitions(\"-DUSE_NVTX\")\nendif()\n\n# setting compiler flags\nset(CMAKE_C_FLAGS    \"${CMAKE_C_FLAGS}\")\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall -ldl\") # -Xptxas -v\n\nif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"x86_64\")\n  set(ARCH \"x86_64\")\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"amd64\")\n  set(ARCH \"x86_64\")\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"AMD64\")\n  # cmake reports AMD64 on Windows, but we might be building for 32-bit.\n  if(CMAKE_SIZEOF_VOID_P EQUAL 8)\n    set(ARCH \"x86_64\")\n  else()\n    set(ARCH \"x86\")\n  endif()\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"x86\")\n  set(ARCH \"x86\")\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"i386\")\n  set(ARCH \"x86\")\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"i686\")\n  set(ARCH \"x86\")\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"aarch64\")\n  set(ARCH \"aarch64\")\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64\")\n  set(ARCH \"aarch64\")\n# Apple A12 Bionic chipset which is added in iPhone XS/XS Max/XR uses arm64e architecture.\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"arm64e\")\n  set(ARCH \"aarch64\")\nelseif(CMAKE_SYSTEM_PROCESSOR MATCHES \"^arm*\")\n  set(ARCH \"arm\")\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"mips\")\n  # Just to avoid the “unknown processor” error.\n  set(ARCH \"generic\")\nelseif(CMAKE_SYSTEM_PROCESSOR STREQUAL \"ppc64le\")\n  set(ARCH \"ppc64le\")\nelse()\n  message(FATAL_ERROR \"Unknown processor:\" 
${CMAKE_SYSTEM_PROCESSOR})\nendif()\n\n\nif(ARCH STREQUAL \"x86_64\")\n  if (NOT CMAKE_CUDA_ARCHITECTURES)\n    set(CMAKE_CUDA_ARCHITECTURES \"\")\n    if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS \"13.0\")\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 70-real 75-real)  # V100, 2080\n    endif()\n    if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL \"11\")\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 80-real) # A100\n    endif ()\n    if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL \"11.1\")\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 86-real) # 3090\n    endif ()\n    if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL \"11.8\")\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) # 4090\n    endif ()\n    if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL \"12.0\")\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 90a-real) # H100\n    endif ()\n    if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL \"12.8\")\n      list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real) # 5090\n    endif ()\n    if (MSVC)\n      list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real)\n    endif ()\n  endif ()\nelseif(ARCH STREQUAL \"aarch64\")\n  if (NOT CMAKE_CUDA_ARCHITECTURES)\n    set(CMAKE_CUDA_ARCHITECTURES 72-real 87-real)  # Jetson\n  endif()\nelse()\n  message(FATAL_ERROR \"Unsupported Architecture:\" ${ARCH})\nendif()\n\nmessage(STATUS \"Building with CUDA archs: ${CMAKE_CUDA_ARCHITECTURES}\")\n\nset(CMAKE_CUDA_RUNTIME_LIBRARY Shared)\nset(CMAKE_C_FLAGS_DEBUG    \"${CMAKE_C_FLAGS_DEBUG}    -Wall -O0\")\nset(CMAKE_CXX_FLAGS_DEBUG  \"${CMAKE_CXX_FLAGS_DEBUG}  -Wall -O0\")\n# set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall  --ptxas-options=-v --resource-usage\")\nset(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\")\n\nset(CMAKE_CXX_STANDARD \"${CXX_STD}\")\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\nset(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\")\nset(CMAKE_CUDA_FLAGS 
\"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\")\nset(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}\")\n\nstring(REPLACE \"-O2\" \"\" CMAKE_CXX_FLAGS_RELEASE         \"${CMAKE_CXX_FLAGS_RELEASE}\")\nstring(REPLACE \"-O2\" \"\" CMAKE_CUDA_FLAGS_RELEASE        \"${CMAKE_CUDA_FLAGS_RELEASE}\")\nstring(REPLACE \"-O2\" \"\" CMAKE_CXX_FLAGS_RELWITHDEBINFO  \"${CMAKE_CXX_FLAGS_RELWITHDEBINFO}\")\nstring(REPLACE \"-O2\" \"\" CMAKE_CUDA_FLAGS_RELWITHDEBINFO \"${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}\")\n\nset(CMAKE_CXX_FLAGS_RELEASE         \"${CMAKE_CXX_FLAGS_RELEASE}         -O3\")\nset(CMAKE_CXX_FLAGS_RELWITHDEBINFO  \"${CMAKE_CXX_FLAGS_RELWITHDEBINFO}  -O3\")\nset(CMAKE_CUDA_FLAGS_RELEASE        \"${CMAKE_CUDA_FLAGS_RELEASE}        -O3\")\nset(CMAKE_CUDA_FLAGS_RELWITHDEBINFO \"${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -O3\")\n\nif(BUILD_FAST_MATH)\n    set(CMAKE_CUDA_FLAGS_RELEASE        \"${CMAKE_CUDA_FLAGS_RELEASE}        --use_fast_math\")\n    set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO \"${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --use_fast_math\")\n    message(\"Release build CUDA flags: ${CMAKE_CUDA_FLAGS_RELEASE}\")\nendif()\n\nset(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)\nset(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)\nset(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)\n\nset(COMMON_HEADER_DIRS\n  ${PROJECT_SOURCE_DIR}\n  ${CUDA_PATH}/include\n  ${CUTLASS_HEADER_DIR}\n)\nmessage(\"-- COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}\")\n\nset(COMMON_LIB_DIRS\n  ${CUDA_PATH}/lib64\n)\n\nif (SPARSITY_SUPPORT)\n  list(APPEND COMMON_HEADER_DIRS ${CUSPARSELT_PATH}/include)\n  list(APPEND COMMON_LIB_DIRS ${CUSPARSELT_PATH}/lib64)\n  add_definitions(-DSPARSITY_ENABLED=1)\nendif()\n\n\nset(PYTHON_PATH \"python\" CACHE STRING \"Python path\")\n\n# turn off warnings on windows\nif (MSVC)\n  foreach(\n    flag_var\n    CMAKE_CXX_FLAGS\n    CMAKE_CXX_FLAGS_DEBUG\n    CMAKE_CXX_FLAGS_RELEASE\n    CMAKE_CXX_FLAGS_MINSIZEREL\n    CMAKE_CXX_FLAGS_RELWITHDEBINFO\n 
   CMAKE_C_FLAGS\n    CMAKE_C_FLAGS_DEBUG\n    CMAKE_C_FLAGS_RELEASE\n    CMAKE_C_FLAGS_MINSIZEREL\n    CMAKE_C_FLAGS_RELWITHDEBINFO\n    CMAKE_CUDA_FLAGS\n    CMAKE_CUDA_FLAGS_DEBUG\n    CMAKE_CUDA_FLAGS_RELEASE\n    CMAKE_CUDA_FLAGS_MINSIZEREL\n    CMAKE_CUDA_FLAGS_RELWITHDEBINFO)\n    string(REGEX REPLACE \"-Wall\" \" /W0 \" ${flag_var} \"${${flag_var}}\")\n  endforeach()\n  # avoid min/max macro in \"windows.h\" conflict with std::min/std::max\n  add_definitions(-DNOMINMAX=1)\nendif()\n\ninclude_directories(\n  ${COMMON_HEADER_DIRS}\n)\n\nlink_directories(\n  ${COMMON_LIB_DIRS}\n)\n\nadd_subdirectory(src)\n\n# if(BUILD_TEST)\n#     add_subdirectory(tests/csrc)\n# endif()\n\n# install python api\nif (BUILD_PY_FFI)\n  if (CALL_FROM_SETUP_PY)\n    install(TARGETS _turbomind DESTINATION ${CMAKE_INSTALL_PREFIX})\n    install(TARGETS _xgrammar DESTINATION ${CMAKE_INSTALL_PREFIX})\n  else()\n    install(TARGETS _turbomind DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/lib)\n    install(TARGETS _xgrammar DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/lib)\n  endif()\nendif ()\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2023-2024 Shanghai AI Laboratory. All rights reserved.\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "\ninclude lmdeploy/lib/*.so\ninclude lmdeploy/lib/*.so*\ninclude lmdeploy/lib/*.dll\ninclude lmdeploy/lib/*.pyd\ninclude lmdeploy/bin/*\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n  <img src=\"docs/en/_static/image/lmdeploy-logo.svg\" width=\"450\"/>\n\n[![PyPI](https://img.shields.io/pypi/v/lmdeploy)](https://pypi.org/project/lmdeploy)\n![PyPI - Downloads](https://img.shields.io/pypi/dm/lmdeploy)\n[![license](https://img.shields.io/github/license/InternLM/lmdeploy.svg)](https://github.com/InternLM/lmdeploy/tree/main/LICENSE)\n[![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues)\n[![open issues](https://img.shields.io/github/issues-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues)\n\n[📘Documentation](https://lmdeploy.readthedocs.io/en/latest/) |\n[🛠️Quick Start](https://lmdeploy.readthedocs.io/en/latest/get_started/get_started.html) |\n[🤔Reporting Issues](https://github.com/InternLM/lmdeploy/issues/new/choose)\n\nEnglish | [简体中文](README_zh-CN.md) | [日本語](README_ja.md)\n\n👋 join us on [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=wechat&label=WeChat)](https://cdn.vansin.top/internlm/lmdeploy.jpg)\n[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=twitter&label=Twitter)](https://twitter.com/intern_lm)\n[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=discord&label=Discord)](https://discord.gg/xa29JuW87d)\n\n</div>\n\n______________________________________________________________________\n\n## Latest News 🎉\n\n<details open>\n<summary><b>2026</b></summary>\n\n- \\[2026/02\\] Support [Qwen3.5](https://huggingface.co/collections/Qwen/qwen35)\n- \\[2026/02\\] Support [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) 4bit symmetric/asymmetric quantization. 
Refer to [this guide](./docs/en/quantization/llm_compressor.md) for details\n\n</details>\n\n<details close>\n<summary><b>2025</b></summary>\n\n- \\[2025/09\\] TurboMind supports MXFP4 on NVIDIA GPUs starting from V100, achieving 1.5x the performance of vLLM on H800 for OpenAI gpt-oss models!\n- \\[2025/06\\] Comprehensive inference optimization for FP8 MoE models\n- \\[2025/06\\] DeepSeek PD Disaggregation deployment is now supported through integration with [DLSlime](https://github.com/DeepLink-org/DLSlime) and [Mooncake](https://github.com/kvcache-ai/Mooncake). Huge thanks to both teams!\n- \\[2025/04\\] Enhance DeepSeek inference performance by integrating deepseek-ai techniques: FlashMLA, DeepGemm, DeepEP, MicroBatch and eplb\n- \\[2025/01\\] Support DeepSeek V3 and R1\n\n</details>\n\n<details close>\n<summary><b>2024</b></summary>\n\n- \\[2024/11\\] Support Mono-InternVL with PyTorch engine\n- \\[2024/10\\] PyTorchEngine supports graph mode on the Ascend platform, doubling the inference speed\n- \\[2024/09\\] LMDeploy PyTorchEngine adds support for [Huawei Ascend](./docs/en/get_started/ascend/get_started.md). 
See the supported models [here](docs/en/supported_models/supported_models.md)\n- \\[2024/09\\] LMDeploy PyTorchEngine achieves 1.3x faster Llama3-8B inference by introducing CUDA graphs\n- \\[2024/08\\] LMDeploy is integrated into [modelscope/swift](https://github.com/modelscope/swift) as the default accelerator for VLM inference\n- \\[2024/07\\] Support Llama3.1 8B, 70B and their tool calling\n- \\[2024/07\\] Support [InternVL2](docs/en/multi_modal/internvl.md) full-series models, [InternLM-XComposer2.5](docs/en/multi_modal/xcomposer2d5.md) and [function calling](docs/en/llm/api_server_tools.md) of InternLM2.5\n- \\[2024/06\\] PyTorch engine supports DeepSeek-V2 and several VLMs, such as CogVLM2, Mini-InternVL, LlaVA-Next\n- \\[2024/05\\] Balance the vision model when deploying VLMs with multiple GPUs\n- \\[2024/05\\] Support 4-bit weight-only quantization and inference on VLMs, such as InternVL v1.5, LLaVa, InternLMXComposer2\n- \\[2024/04\\] Support Llama3 and more VLMs, such as InternVL v1.1, v1.2, MiniGemini, InternLMXComposer2.\n- \\[2024/04\\] TurboMind adds online int8/int4 KV cache quantization and inference for all supported devices. Refer to [this guide](docs/en/quantization/kv_quant.md) for details\n- \\[2024/04\\] TurboMind's latest upgrade boosts GQA, rocketing the [internlm2-20b](https://huggingface.co/internlm/internlm2-20b) model inference to 16+ RPS, about 1.8x faster than vLLM.\n- \\[2024/04\\] Support Qwen1.5-MOE and dbrx.\n- \\[2024/03\\] Support DeepSeek-VL offline inference pipeline and serving.\n- \\[2024/03\\] Support VLM offline inference pipeline and serving.\n- \\[2024/02\\] Support Qwen 1.5, Gemma, Mistral, Mixtral, Deepseek-MOE and so on.\n- \\[2024/01\\] [OpenAOE](https://github.com/InternLM/OpenAOE) seamlessly integrates with [LMDeploy Serving Service](docs/en/llm/api_server.md).\n- \\[2024/01\\] Support for multi-model, multi-machine, multi-card inference services. 
For usage instructions, please refer to [this guide](docs/en/llm/proxy_server.md)\n- \\[2024/01\\] Support [PyTorch inference engine](./docs/en/inference/pytorch.md), developed entirely in Python, helping to lower the barrier for developers and enable rapid experimentation with new features and technologies.\n\n</details>\n\n<details close>\n<summary><b>2023</b></summary>\n\n- \\[2023/12\\] TurboMind supports multimodal input.\n- \\[2023/11\\] TurboMind supports loading HF models directly. Click [here](docs/en/inference/load_hf.md) for details.\n- \\[2023/11\\] TurboMind major upgrades, including Paged Attention, faster attention kernels without sequence length limitation, 2x faster KV8 kernels, Split-K decoding (Flash Decoding), and W4A16 inference for sm_75\n- \\[2023/09\\] TurboMind supports Qwen-14B\n- \\[2023/09\\] TurboMind supports InternLM-20B\n- \\[2023/09\\] TurboMind supports all features of Code Llama: code completion, infilling, chat / instruct, and python specialist. Click [here](./docs/en/llm/codellama.md) for the deployment guide\n- \\[2023/09\\] TurboMind supports Baichuan2-7B\n- \\[2023/08\\] TurboMind supports flash-attention2.\n- \\[2023/08\\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling\n- \\[2023/08\\] TurboMind supports Windows (tp=1)\n- \\[2023/08\\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation. 
Check [this guide](docs/en/quantization/w4a16.md) for detailed info\n- \\[2023/08\\] LMDeploy has launched on the [HuggingFace Hub](https://huggingface.co/lmdeploy), providing ready-to-use 4-bit models.\n- \\[2023/08\\] LMDeploy supports 4-bit quantization using the [AWQ](https://arxiv.org/abs/2306.00978) algorithm.\n- \\[2023/07\\] TurboMind supports Llama-2 70B with GQA.\n- \\[2023/07\\] TurboMind supports Llama-2 7B/13B.\n- \\[2023/07\\] TurboMind supports tensor-parallel inference of InternLM.\n\n</details>\n\n______________________________________________________________________\n\n# Introduction\n\nLMDeploy is a toolkit for compressing, deploying, and serving LLMs, developed by the [MMRazor](https://github.com/open-mmlab/mmrazor) and [MMDeploy](https://github.com/open-mmlab/mmdeploy) teams. It has the following core features:\n\n- **Efficient Inference**: LMDeploy delivers up to 1.8x higher request throughput than vLLM by introducing key features such as persistent batch (a.k.a. continuous batching), blocked KV cache, dynamic split&fuse, tensor parallelism, high-performance CUDA kernels and so on.\n\n- **Effective Quantization**: LMDeploy supports weight-only and k/v quantization, and the 4-bit inference performance is 2.4x higher than FP16. 
The quantization quality has been confirmed via OpenCompass evaluation.\n\n- **Effortless Distribution Server**: Leveraging the request distribution service, LMDeploy facilitates an easy and efficient deployment of multi-model services across multiple machines and cards.\n\n- **Excellent Compatibility**: LMDeploy allows [KV Cache Quant](docs/en/quantization/kv_quant.md), [AWQ](docs/en/quantization/w4a16.md) and [Automatic Prefix Caching](docs/en/inference/turbomind_config.md) to be used simultaneously.\n\n# Performance\n\n![v0 1 0-benchmark](https://github.com/InternLM/lmdeploy/assets/4560679/8e455cf1-a792-4fa8-91a2-75df96a2a5ba)\n\n# Supported Models\n\n<table>\n<tbody>\n<tr align=\"center\" valign=\"middle\">\n<td>\n  <b>LLMs</b>\n</td>\n<td>\n  <b>VLMs</b>\n</td>\n</tr>\n<tr valign=\"top\">\n<td align=\"left\" valign=\"top\">\n<ul>\n  <li>Llama (7B - 65B)</li>\n  <li>Llama2 (7B - 70B)</li>\n  <li>Llama3 (8B, 70B)</li>\n  <li>Llama3.1 (8B, 70B)</li>\n  <li>Llama3.2 (1B, 3B)</li>\n  <li>InternLM (7B - 20B)</li>\n  <li>InternLM2 (7B - 20B)</li>\n  <li>InternLM3 (8B)</li>\n  <li>InternLM2.5 (7B)</li>\n  <li>Qwen (1.8B - 72B)</li>\n  <li>Qwen1.5 (0.5B - 110B)</li>\n  <li>Qwen1.5 - MoE (0.5B - 72B)</li>\n  <li>Qwen2 (0.5B - 72B)</li>\n  <li>Qwen2-MoE (57BA14B)</li>\n  <li>Qwen2.5 (0.5B - 32B)</li>\n  <li>Qwen3, Qwen3-MoE</li>\n  <li>Qwen3-Next(80B)</li>\n  <li>Baichuan (7B)</li>\n  <li>Baichuan2 (7B-13B)</li>\n  <li>Code Llama (7B - 34B)</li>\n  <li>ChatGLM2 (6B)</li>\n  <li>GLM-4 (9B)</li>\n  <li>GLM-4-0414 (9B, 32B)</li>\n  <li>CodeGeeX4 (9B)</li>\n  <li>YI (6B-34B)</li>\n  <li>Mistral (7B)</li>\n  <li>DeepSeek-MoE (16B)</li>\n  <li>DeepSeek-V2 (16B, 236B)</li>\n  <li>DeepSeek-V2.5 (236B)</li>\n  <li>DeepSeek-V3 (685B)</li>\n  <li>DeepSeek-V3.2 (685B)</li>\n  <li>Mixtral (8x7B, 8x22B)</li>\n  <li>Gemma (2B - 7B)</li>\n  <li>StarCoder2 (3B - 15B)</li>\n  <li>Phi-3-mini (3.8B)</li>\n  <li>Phi-3.5-mini (3.8B)</li>\n  <li>Phi-3.5-MoE (16x3.8B)</li>\n  <li>Phi-4-mini 
(3.8B)</li>\n  <li>MiniCPM3 (4B)</li>\n  <li>SDAR (1.7B-30B)</li>\n  <li>gpt-oss (20B, 120B)</li>\n  <li>GLM-4.7-Flash (30B)</li>\n  <li>GLM-5 (754B)</li>\n</ul>\n</td>\n<td>\n<ul>\n  <li>LLaVA(1.5,1.6) (7B-34B)</li>\n  <li>InternLM-XComposer2 (7B, 4khd-7B)</li>\n  <li>InternLM-XComposer2.5 (7B)</li>\n  <li>Qwen-VL (7B)</li>\n  <li>Qwen2-VL (2B, 7B, 72B)</li>\n  <li>Qwen2.5-VL (3B, 7B, 72B)</li>\n  <li>Qwen3-VL (2B - 235B)</li>\n  <li>Qwen3.5 (0.8B - 397B)</li>\n  <li>DeepSeek-VL (7B)</li>\n  <li>DeepSeek-VL2 (3B, 16B, 27B)</li>\n  <li>InternVL-Chat (v1.1-v1.5)</li>\n  <li>InternVL2 (1B-76B)</li>\n  <li>InternVL2.5(MPO) (1B-78B)</li>\n  <li>InternVL3 (1B-78B)</li>\n  <li>InternVL3.5 (1B-241BA28B)</li>\n  <li>Intern-S1 (241B)</li>\n  <li>Intern-S1-mini (8.3B)</li>\n  <li>Intern-S1-Pro (1TB)</li>\n  <li>Mono-InternVL (2B)</li>\n  <li>ChemVLM (8B-26B)</li>\n  <li>CogVLM-Chat (17B)</li>\n  <li>CogVLM2-Chat (19B)</li>\n  <li>MiniCPM-Llama3-V-2_5</li>\n  <li>MiniCPM-V-2_6</li>\n  <li>Phi-3-vision (4.2B)</li>\n  <li>Phi-3.5-vision (4.2B)</li>\n  <li>GLM-4V (9B)</li>\n  <li>GLM-4.1V-Thinking (9B)</li>\n  <li>Llama3.2-vision (11B, 90B)</li>\n  <li>Molmo (7B-D,72B)</li>\n  <li>Gemma3 (1B - 27B)</li>\n  <li>Llama4 (Scout, Maverick)</li>\n</ul>\n</td>\n</tr>\n</tbody>\n</table>\n\nLMDeploy has developed two inference engines - [TurboMind](./docs/en/inference/turbomind.md) and [PyTorch](./docs/en/inference/pytorch.md), each with a different focus. The former strives for ultimate optimization of inference performance, while the latter, developed purely in Python, aims to decrease the barriers for developers.\n\nThey differ in the types of supported models and the inference data type. 
Please refer to [this table](./docs/en/supported_models/supported_models.md) for each engine's capabilities and choose the one that best fits your needs.\n\n# Quick Start [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ)\n\n## Installation\n\nWe recommend installing lmdeploy with pip in a conda environment (Python 3.10 - 3.13):\n\n```shell\nconda create -n lmdeploy python=3.10 -y\nconda activate lmdeploy\npip install lmdeploy\n```\n\nSince v0.3.0, the default prebuilt package has been compiled with **CUDA 12**.\n\nFor the GeForce RTX 50 series, please install the LMDeploy prebuilt package compiled with **CUDA 12.8**:\n\n```shell\nexport LMDEPLOY_VERSION=0.12.2\nexport PYTHON_VERSION=310\npip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu128-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu128\n```\n\nFor more information on installing on CUDA 11+ platforms, or for instructions on building from source, please refer to the [installation guide](docs/en/get_started/installation.md).\n\n## Offline Batch Inference\n\n```python\nimport lmdeploy\nwith lmdeploy.pipeline(\"internlm/internlm3-8b-instruct\") as pipe:\n    response = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\n    print(response)\n```\n\n> \\[!NOTE\\]\n> By default, LMDeploy downloads models from HuggingFace. 
If you would like to use models from ModelScope, please install ModelScope with `pip install modelscope` and set the environment variable:\n>\n> `export LMDEPLOY_USE_MODELSCOPE=True`\n>\n> If you would like to use models from openMind Hub, please install openMind Hub with `pip install openmind_hub` and set the environment variable:\n>\n> `export LMDEPLOY_USE_OPENMIND_HUB=True`\n\nFor more information about the inference pipeline, please refer to [this guide](docs/en/llm/pipeline.md).\n\n# Tutorials\n\nPlease review the [getting_started](docs/en/get_started/get_started.md) section for the basic usage of LMDeploy.\n\nFor detailed user guides and advanced guides, please refer to our [tutorials](https://lmdeploy.readthedocs.io/en/latest/):\n\n- User Guide\n  - [LLM Inference pipeline](docs/en/llm/pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ)\n  - [VLM Inference pipeline](docs/en/multi_modal/vl_pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1nKLfnPeDA3p-FMNw2NhI-KOpk7-nlNjF?usp=sharing)\n  - [LLM Serving](docs/en/llm/api_server.md)\n  - [VLM Serving](docs/en/multi_modal/api_server_vl.md)\n  - [Quantization](docs/en/quantization)\n- Advanced Guide\n  - [Inference Engine - TurboMind](docs/en/inference/turbomind.md)\n  - [Inference Engine - PyTorch](docs/en/inference/pytorch.md)\n  - [Customize chat templates](docs/en/advance/chat_template.md)\n  - [Add a new model](docs/en/advance/pytorch_new_model.md)\n  - gemm tuning\n  - [Long context inference](docs/en/advance/long_context.md)\n  - [Multi-model inference service](docs/en/llm/proxy_server.md)\n\n# Third-party projects\n\n- Deploying LLMs offline on the NVIDIA Jetson platform with LMDeploy: [LMDeploy-Jetson](https://github.com/BestAnHongjun/LMDeploy-Jetson)\n\n- Example project for deploying LLMs using LMDeploy and 
BentoML: [BentoLMDeploy](https://github.com/bentoml/BentoLMDeploy)\n\n# Contributing\n\nWe appreciate all contributions to LMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guidelines.\n\n# Acknowledgement\n\n- [FasterTransformer](https://github.com/NVIDIA/FasterTransformer)\n- [llm-awq](https://github.com/mit-han-lab/llm-awq)\n- [vLLM](https://github.com/vllm-project/vllm)\n- [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII)\n\n# Citation\n\n```bibtex\n@misc{2023lmdeploy,\n    title={LMDeploy: A Toolkit for Compressing, Deploying, and Serving LLM},\n    author={LMDeploy Contributors},\n    howpublished = {\\url{https://github.com/InternLM/lmdeploy}},\n    year={2023}\n}\n```\n\n```bibtex\n@article{zhang2025efficient,\n  title={Efficient Mixed-Precision Large Language Model Inference with TurboMind},\n  author={Zhang, Li and Jiang, Youhe and He, Guoliang and Chen, Xin and Lv, Han and Yao, Qian and Fu, Fangcheng and Chen, Kai},\n  journal={arXiv preprint arXiv:2508.15601},\n  year={2025}\n}\n```\n\n# License\n\nThis project is released under the [Apache 2.0 license](LICENSE).\n"
  },
  {
    "path": "README_ja.md",
    "content": "<div align=\"center\">\n  <img src=\"docs/en/_static/image/lmdeploy-logo.svg\" width=\"450\"/>\n\n[![PyPI](https://img.shields.io/pypi/v/lmdeploy)](https://pypi.org/project/lmdeploy)\n![PyPI - Downloads](https://img.shields.io/pypi/dm/lmdeploy)\n[![license](https://img.shields.io/github/license/InternLM/lmdeploy.svg)](https://github.com/InternLM/lmdeploy/tree/main/LICENSE)\n[![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues)\n[![open issues](https://img.shields.io/github/issues-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues)\n\n[📘Documentation](https://lmdeploy.readthedocs.io/en/latest/) |\n[🛠️Quick Start](https://lmdeploy.readthedocs.io/en/latest/get_started/get_started.html) |\n[🤔Reporting Issues](https://github.com/InternLM/lmdeploy/issues/new/choose)\n\n[English](README.md) | [简体中文](README_zh-CN.md) | 日本語\n\n👋 join us on [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=wechat&label=WeChat)](https://cdn.vansin.top/internlm/lmdeploy.jpg)\n[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=twitter&label=Twitter)](https://twitter.com/intern_lm)\n[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=discord&label=Discord)](https://discord.gg/xa29JuW87d)\n\n</div>\n\n______________________________________________________________________\n\n## 最新ニュース 🎉\n\n<details close>\n<summary><b>2024</b></summary>\n\n- \\[2024/08\\] 🔥🔥 LMDeployは[modelscope/swift](https://github.com/modelscope/swift)に統合され、VLMs推論のデフォルトアクセラレータとなりました\n- \\[2024/07\\] 🎉🎉 Llama3.1 8B、70Bおよびそのツールコールをサポート\n- \\[2024/07\\] [InternVL2](https://huggingface.co/collections/OpenGVLab/internvl-20-667d3961ab5eb12c7ed1463e)全シリーズモデル、[InternLM-XComposer2.5](docs/en/multi_modal/xcomposer2d5.md)およびInternLM2.5の[ファンクションコール](docs/en/llm/api_server_tools.md)をサポート\n- \\[2024/06\\] 
PyTorchエンジンはDeepSeek-V2およびいくつかのVLMs、例えばCogVLM2、Mini-InternVL、LlaVA-Nextをサポート\n- \\[2024/05\\] 複数のGPUでVLMsをデプロイする際にビジョンモデルをバランスさせる\n- \\[2024/05\\] InternVL v1.5、LLaVa、InternLMXComposer2などのVLMsで4ビットの重みのみの量子化と推論をサポート\n- \\[2024/04\\] Llama3およびInternVL v1.1、v1.2、MiniGemini、InternLMXComposer2などのVLMモデルをサポート\n- \\[2024/04\\] TurboMindはすべてのサポートされているデバイスでのオンラインint8/int4 KVキャッシュ量子化と推論を追加しました。詳細なガイドは[こちら](docs/en/quantization/kv_quant.md)を参照してください\n- \\[2024/04\\] TurboMindの最新アップグレードによりGQAが強化され、[internlm2-20b](https://huggingface.co/internlm/internlm2-20b)モデルの推論が16+ RPSに達し、vLLMの約1.8倍の速さになりました\n- \\[2024/04\\] Qwen1.5-MOEおよびdbrxをサポート\n- \\[2024/03\\] DeepSeek-VLのオフライン推論パイプラインとサービングをサポート\n- \\[2024/03\\] VLMのオフライン推論パイプラインとサービングをサポート\n- \\[2024/02\\] Qwen 1.5、Gemma、Mistral、Mixtral、Deepseek-MOEなどをサポート\n- \\[2024/01\\] [OpenAOE](https://github.com/InternLM/OpenAOE)が[LMDeployサービングサービス](./docs/en/llm/api_server.md)とシームレスに統合されました\n- \\[2024/01\\] 複数モデル、複数マシン、複数カードの推論サービスをサポート。使用方法は[こちら](./docs/en/llm/proxy_server.md)を参照してください\n- \\[2024/01\\] [PyTorch推論エンジン](./docs/en/inference/pytorch.md)をサポートし、完全にPythonで開発されており、開発者の障壁を下げ、新機能や技術の迅速な実験を可能にします\n\n</details>\n\n<details close>\n<summary><b>2023</b></summary>\n\n- \\[2023/12\\] Turbomindはマルチモーダル入力をサポート\n- \\[2023/11\\] Turbomindはhfモデルの直接読み込みをサポート。詳細は[こちら](docs/en/inference/load_hf.md)をクリックしてください\n- \\[2023/11\\] TurboMindの主要なアップグレード、包括的なPaged Attention、シーケンス長制限のない高速なアテンションカーネル、2倍速いKV8カーネル、Split-Kデコーディング（Flash Decoding）、およびsm_75のW4A16推論\n- \\[2023/09\\] TurboMindはQwen-14Bをサポート\n- \\[2023/09\\] TurboMindはInternLM-20Bをサポート\n- \\[2023/09\\] TurboMindはCode Llamaのすべての機能をサポート：コード補完、インフィリング、チャット/インストラクト、Pythonスペシャリスト。デプロイメントガイドは[こちら](./docs/en/llm/codellama.md)をクリックしてください\n- \\[2023/09\\] TurboMindはBaichuan2-7Bをサポート\n- \\[2023/08\\] TurboMindはflash-attention2をサポート\n- \\[2023/08\\] TurboMindはQwen-7B、動的NTK-RoPEスケーリング、動的logNスケーリングをサポート\n- \\[2023/08\\] TurboMindはWindowsをサポート（tp=1）\n- \\[2023/08\\] 
TurboMindは4ビット推論をサポートし、FP16の2.4倍の速さで、最速のオープンソース実装です。詳細な情報は[こちら](docs/en/quantization/w4a16.md)のガイドを確認してください\n- \\[2023/08\\] LMDeployは[HuggingFace Hub](https://huggingface.co/lmdeploy)で提供され、すぐに使用できる4ビットモデルを提供します\n- \\[2023/08\\] LMDeployは[AWQ](https://arxiv.org/abs/2306.00978)アルゴリズムを使用した4ビット量子化をサポート\n- \\[2023/07\\] TurboMindはGQAを使用したLlama-2 70Bをサポート\n- \\[2023/07\\] TurboMindはLlama-2 7B/13Bをサポート\n- \\[2023/07\\] TurboMindはInternLMのテンソル並列推論をサポート\n\n</details>\n\n______________________________________________________________________\n\n# 紹介\n\nLMDeployは、[MMRazor](https://github.com/open-mmlab/mmrazor)および[MMDeploy](https://github.com/open-mmlab/mmdeploy)チームによって開発された、LLMの圧縮、デプロイ、およびサービングのためのツールキットです。以下の主要な機能を備えています：\n\n- **効率的な推論**：LMDeployは、persistent batch（連続バッチ）、ブロック化されたKVキャッシュ、動的分割と融合、テンソル並列、高性能なCUDAカーネルなどの主要な機能を導入し、vLLMよりも最大1.8倍のリクエストスループットを提供します。\n\n- **効果的な量子化**：LMDeployは、重みのみおよびk/vの量子化をサポートし、4ビットの推論性能はFP16の2.4倍です。量子化の品質はOpenCompassの評価を通じて確認されています。\n\n- **簡単な分散サーバー**：リクエスト分散サービスを活用することで、LMDeployは複数のマシンおよびカードにわたるマルチモデルサービスのデプロイを容易にします。\n\n- **優れた互換性**：LMDeployは、[KV Cache Quant](docs/en/quantization/kv_quant.md)、[AWQ](docs/en/quantization/w4a16.md)、および[Automatic Prefix Caching](docs/en/inference/turbomind_config.md)を同時に使用することをサポートします。\n\n# パフォーマンス\n\nLMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざまな規模のモデルで、vLLMの1.36〜1.85倍のリクエストを毎秒処理します。静的推論能力の面では、TurboMind 4ビットモデルの推論速度（out token/s）はFP16/BF16推論をはるかに上回ります。小さなバッチでは、2.4倍に向上します。\n\n![v0 1 0-benchmark](https://github.com/InternLM/lmdeploy/assets/4560679/8e455cf1-a792-4fa8-91a2-75df96a2a5ba)\n\n# サポートされているモデル\n\n<table>\n<tbody>\n<tr align=\"center\" valign=\"middle\">\n<td>\n  <b>LLMs</b>\n</td>\n<td>\n  <b>VLMs</b>\n</td>\n<tr valign=\"top\">\n<td align=\"left\" valign=\"top\">\n<ul>\n  <li>Llama (7B - 65B)</li>\n  <li>Llama2 (7B - 70B)</li>\n  <li>Llama3 (8B, 70B)</li>\n  <li>Llama3.1 (8B, 70B)</li>\n  <li>Llama3.2 (1B, 3B)</li>\n  <li>InternLM (7B - 20B)</li>\n  <li>InternLM2 (7B - 20B)</li>\n  <li>InternLM3 (8B)</li>\n  
<li>InternLM2.5 (7B)</li>\n  <li>Qwen (1.8B - 72B)</li>\n  <li>Qwen1.5 (0.5B - 110B)</li>\n  <li>Qwen1.5 - MoE (0.5B - 72B)</li>\n  <li>Qwen2 (0.5B - 72B)</li>\n  <li>Qwen2-MoE (57BA14B)</li>\n  <li>Qwen2.5 (0.5B - 32B)</li>\n  <li>Qwen3, Qwen3-MoE</li>\n  <li>Qwen3-Next(80B)</li>\n  <li>Baichuan (7B)</li>\n  <li>Baichuan2 (7B-13B)</li>\n  <li>Code Llama (7B - 34B)</li>\n  <li>ChatGLM2 (6B)</li>\n  <li>GLM-4 (9B)</li>\n  <li>GLM-4-0414 (9B, 32B)</li>\n  <li>CodeGeeX4 (9B)</li>\n  <li>YI (6B-34B)</li>\n  <li>Mistral (7B)</li>\n  <li>DeepSeek-MoE (16B)</li>\n  <li>DeepSeek-V2 (16B, 236B)</li>\n  <li>DeepSeek-V2.5 (236B)</li>\n  <li>DeepSeek-V3 (685B)</li>\n  <li>DeepSeek-V3.2 (685B)</li>\n  <li>Mixtral (8x7B, 8x22B)</li>\n  <li>Gemma (2B - 7B)</li>\n  <li>StarCoder2 (3B - 15B)</li>\n  <li>Phi-3-mini (3.8B)</li>\n  <li>Phi-3.5-mini (3.8B)</li>\n  <li>Phi-3.5-MoE (16x3.8B)</li>\n  <li>Phi-4-mini (3.8B)</li>\n  <li>MiniCPM3 (4B)</li>\n  <li>SDAR (1.7B-30B)</li>\n  <li>gpt-oss (20B, 120B)</li>\n  <li>GLM-4.7-Flash (30B)</li>\n  <li>GLM-5 (754B)</li>\n</ul>\n</td>\n<td>\n<ul>\n  <li>LLaVA(1.5,1.6) (7B-34B)</li>\n  <li>InternLM-XComposer2 (7B, 4khd-7B)</li>\n  <li>InternLM-XComposer2.5 (7B)</li>\n  <li>Qwen-VL (7B)</li>\n  <li>Qwen2-VL (2B, 7B, 72B)</li>\n  <li>Qwen2.5-VL (3B, 7B, 72B)</li>\n  <li>Qwen3-VL (2B - 235B)</li>\n  <li>Qwen3.5 (0.8B - 397B)</li>\n  <li>DeepSeek-VL (7B)</li>\n  <li>DeepSeek-VL2 (3B, 16B, 27B)</li>\n  <li>InternVL-Chat (v1.1-v1.5)</li>\n  <li>InternVL2 (1B-76B)</li>\n  <li>InternVL2.5(MPO) (1B-78B)</li>\n  <li>InternVL3 (1B-78B)</li>\n  <li>InternVL3.5 (1B-241BA28B)</li>\n  <li>Intern-S1 (241B)</li>\n  <li>Intern-S1-mini (8.3B)</li>\n  <li>Mono-InternVL (2B)</li>\n  <li>ChemVLM (8B-26B)</li>\n  <li>CogVLM-Chat (17B)</li>\n  <li>CogVLM2-Chat (19B)</li>\n  <li>MiniCPM-Llama3-V-2_5</li>\n  <li>MiniCPM-V-2_6</li>\n  <li>Phi-3-vision (4.2B)</li>\n  <li>Phi-3.5-vision (4.2B)</li>\n  <li>GLM-4V (9B)</li>\n  <li>GLM-4.1V-Thinking (9B)</li>\n  
<li>Llama3.2-vision (11B, 90B)</li>\n  <li>Molmo (7B-D,72B)</li>\n  <li>Gemma3 (1B - 27B)</li>\n  <li>Llama4 (Scout, Maverick)</li>\n</ul>\n</td>\n</tr>\n</tbody>\n</table>\n\nLMDeployは、[TurboMind](./docs/en/inference/turbomind.md)および[PyTorch](./docs/en/inference/pytorch.md)の2つの推論エンジンを開発しており、それぞれ重点が異なります。前者は推論性能の究極の最適化を目指し、後者は完全にPythonで開発されており、開発者の障壁を下げることを目指しています。\n\n両エンジンは、サポートするモデルの種類や推論データタイプが異なります。各エンジンの能力については[この表](./docs/en/supported_models/supported_models.md)を参照し、実際のニーズに最適なものを選択してください。\n\n# クイックスタート [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ)\n\n## インストール\n\nクリーンなconda環境（Python 3.10 - 3.13）でlmdeployをインストールすることをお勧めします。\n\n```shell\nconda create -n lmdeploy python=3.10 -y\nconda activate lmdeploy\npip install lmdeploy\n```\n\nv0.3.0から、デフォルトの事前構築済みパッケージはCUDA 12でコンパイルされています。\nCUDA 11+プラットフォームでのインストール方法、またはソースからのビルド手順については、[インストールガイド](docs/en/get_started/installation.md)を参照してください。\n\n## オフラインバッチ推論\n\n```python\nimport lmdeploy\nwith lmdeploy.pipeline(\"internlm/internlm3-8b-instruct\") as pipe:\n    response = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\n    print(response)\n```\n\n> \\[!NOTE\\]\n> デフォルトでは、LMDeployはHuggingFaceからモデルをダウンロードします。ModelScopeのモデルを使用する場合は、`pip install modelscope`コマンドでModelScopeをインストールし、次の環境変数を設定してください：\n>\n> `export LMDEPLOY_USE_MODELSCOPE=True`\n>\n> openMind Hubのモデルを使用する場合は、`pip install openmind_hub`コマンドでopenMind Hubをインストールし、次の環境変数を設定してください：\n>\n> `export LMDEPLOY_USE_OPENMIND_HUB=True`\n\n推論パイプラインに関する詳細情報は[こちら](./docs/en/llm/pipeline.md)を参照してください。\n\n# チュートリアル\n\nLMDeployの基本的な使用方法については、[getting_started](docs/en/get_started/get_started.md)セクションを参照してください。\n\n詳細なユーザーガイドと高度なガイドについては、[チュートリアル](https://lmdeploy.readthedocs.io/en/latest/)を参照してください：\n\n- ユーザーガイド\n  - [LLM推論パイプライン](./docs/en/llm/pipeline.md) [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ)\n  - [VLM推論パイプライン](./docs/en/multi_modal/vl_pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1nKLfnPeDA3p-FMNw2NhI-KOpk7-nlNjF?usp=sharing)\n  - [LLMサービング](docs/en/llm/api_server.md)\n  - [VLMサービング](docs/en/multi_modal/api_server_vl.md)\n  - [量子化](docs/en/quantization)\n- 高度なガイド\n  - [推論エンジン - TurboMind](docs/en/inference/turbomind.md)\n  - [推論エンジン - PyTorch](docs/en/inference/pytorch.md)\n  - [カスタムチャットテンプレート](docs/en/advance/chat_template.md)\n  - [新しいモデルの追加](docs/en/advance/pytorch_new_model.md)\n  - gemmチューニング\n  - [長文推論](docs/en/advance/long_context.md)\n  - [マルチモデル推論サービス](docs/en/llm/proxy_server.md)\n\n# サードパーティプロジェクト\n\n- LMDeployを使用してNVIDIA JetsonプラットフォームでLLMをオフラインでデプロイ：[LMDeploy-Jetson](https://github.com/BestAnHongjun/LMDeploy-Jetson)\n- LMDeployとBentoMLを使用してLLMをデプロイするためのサンプルプロジェクト：[BentoLMDeploy](https://github.com/bentoml/BentoLMDeploy)\n\n# 貢献\n\nLMDeployへのすべての貢献に感謝します。貢献ガイドラインについては、[CONTRIBUTING.md](.github/CONTRIBUTING.md)を参照してください。\n\n# 謝辞\n\n- [FasterTransformer](https://github.com/NVIDIA/FasterTransformer)\n- [llm-awq](https://github.com/mit-han-lab/llm-awq)\n- [vLLM](https://github.com/vllm-project/vllm)\n- [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII)\n\n# 引用\n\n```bibtex\n@misc{2023lmdeploy,\n    title={LMDeploy: A Toolkit for Compressing, Deploying, and Serving LLM},\n    author={LMDeploy Contributors},\n    howpublished = {\\url{https://github.com/InternLM/lmdeploy}},\n    year={2023}\n}\n```\n\n# ライセンス\n\nこのプロジェクトは[Apache 2.0ライセンス](LICENSE)の下でリリースされています。\n"
  },
  {
    "path": "README_zh-CN.md",
    "content": "<div align=\"center\">\n  <img src=\"docs/en/_static/image/lmdeploy-logo.svg\" width=\"450\"/>\n\n[![PyPI](https://img.shields.io/pypi/v/lmdeploy)](https://pypi.org/project/lmdeploy)\n![PyPI - Downloads](https://img.shields.io/pypi/dm/lmdeploy)\n[![license](https://img.shields.io/github/license/InternLM/lmdeploy.svg)](https://github.com/InternLM/lmdeploy/tree/main/LICENSE)\n[![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues)\n[![open issues](https://img.shields.io/github/issues-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues)\n\n[📘Documentation](https://lmdeploy.readthedocs.io/zh-cn/latest/) |\n[🛠️Quick Start](https://lmdeploy.readthedocs.io/zh-cn/latest/get_started/get_started.html) |\n[🤔Reporting Issues](https://github.com/InternLM/lmdeploy/issues/new/choose)\n\n[English](README.md) | 简体中文 | [日本語](README_ja.md)\n\n👋 join us on [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=wechat&label=WeChat)](https://cdn.vansin.top/internlm/lmdeploy.jpg)\n[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=twitter&label=Twitter)](https://twitter.com/intern_lm)\n[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=discord&label=Discord)](https://discord.gg/xa29JuW87d)\n\n</div>\n\n______________________________________________________________________\n\n## 最新进展 🎉\n\n<details open>\n<summary><b>2026</b></summary>\n\n- \\[2026/02\\] 支持 [Qwen3.5](https://huggingface.co/collections/Qwen/qwen35)\n- \\[2026/02\\] 支持 [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) 4bit 对称和非对称量化。 具体操作指南详见[此处](./docs/zh_cn/quantization/llm_compressor.md)\n\n</details>\n\n<details close>\n<summary><b>2025</b></summary>\n\n- 【2025年9月】TurboMind 引擎支持 MXFP4，适用于 NVIDIA V100 及以上 GPU。在 H800 上推理 openai gpt-oss 模型，性能可达 vLLM 的 1.5倍！\n- 【2025年6月】深度优化 FP8 MoE 模型推理\n- 
【2025年6月】集成[DLSlime](https://github.com/DeepLink-org/DLSlime)和[Mooncake](https://github.com/kvcache-ai/Mooncake)，实现DeepSeek PD分离部署，向两个团队表示诚挚的感谢！\n- 【2025年4月】集成deepseek-ai组件FlashMLA、DeepGemm、DeepEP、MicroBatch、eplb等，提升DeepSeek推理性能\n- 【2025年1月】新增对DeepSeek V3及R1的支持\n\n</details>\n\n<details close>\n<summary><b>2024</b></summary>\n\n- \\[2024/11\\] PyTorch engine 支持 Mono-InternVL 模型\n- \\[2024/10\\] PyTorchEngine 在 ascend 平台上支持了图模式，推理性能提高了 1 倍\n- \\[2024/09\\] LMDeploy PyTorchEngine 增加了对 [华为 Ascend](docs/zh_cn/get_started/ascend/get_started.md) 的支持。支持的模型请见[这里](docs/zh_cn/supported_models/supported_models.md)\n- \\[2024/09\\] 通过引入 CUDA Graph，LMDeploy PyTorchEngine 在 Llama3-8B 推理上实现了 1.3 倍的加速\n- \\[2024/08\\] LMDeploy现已集成至 [modelscope/swift](https://github.com/modelscope/swift)，成为 VLMs 推理的默认加速引擎\n- \\[2024/07\\] 支持 Llama3.1 8B 和 70B 模型，以及工具调用功能\n- \\[2024/07\\] 支持 [InternVL2](docs/zh_cn/multi_modal/internvl.md) 全系列模型，[InternLM-XComposer2.5](docs/zh_cn/multi_modal/xcomposer2d5.md) 模型和 InternLM2.5 的 [function call 功能](docs/zh_cn/llm/api_server_tools.md)\n- \\[2024/06\\] PyTorch engine 支持了 DeepSeek-V2 和若干 VLM 模型推理, 比如 CogVLM2，Mini-InternVL，LlaVA-Next\n- \\[2024/05\\] 在多 GPU 上部署 VLM 模型时，支持把视觉部分的模型均分到多卡上\n- \\[2024/05\\] 支持InternVL v1.5, LLaVa, InternLMXComposer2 等 VLMs 模型的 4bit 权重量化和推理\n- \\[2024/04\\] 支持 Llama3 和 InternVL v1.1, v1.2，MiniGemini，InternLM-XComposer2 等 VLM 模型\n- \\[2024/04\\] TurboMind 支持 kv cache int4/int8 在线量化和推理，适用已支持的所有型号显卡。详情请参考[这里](docs/zh_cn/quantization/kv_quant.md)\n- \\[2024/04\\] TurboMind 引擎升级，优化 GQA 推理。[internlm2-20b](https://huggingface.co/internlm/internlm2-20b) 推理速度达 16+ RPS，约是 vLLM 的 1.8 倍\n- \\[2024/04\\] 支持 Qwen1.5-MOE 和 dbrx.\n- \\[2024/03\\] 支持 DeepSeek-VL 的离线推理 pipeline 和推理服务\n- \\[2024/03\\] 支持视觉-语言模型（VLM）的离线推理 pipeline 和推理服务\n- \\[2024/02\\] 支持 Qwen 1.5、Gemma、Mistral、Mixtral、Deepseek-MOE 等模型\n- \\[2024/01\\] [OpenAOE](https://github.com/InternLM/OpenAOE) 发布，支持无缝接入[LMDeploy Serving Service](docs/zh_cn/llm/api_server.md)\n- \\[2024/01\\] 
支持多模型、多机、多卡推理服务。使用方法请参考[此处](docs/zh_cn/llm/proxy_server.md)\n- \\[2024/01\\] 增加 [PyTorch 推理引擎](./docs/zh_cn/inference/pytorch.md)，作为 TurboMind 引擎的补充，帮助降低开发门槛，并便于快速实验新特性、新技术\n\n</details>\n\n<details close>\n<summary><b>2023</b></summary>\n\n- \\[2023/12\\] Turbomind 支持多模态输入\n- \\[2023/11\\] Turbomind 支持直接读取 Huggingface 模型。点击[这里](docs/zh_cn/inference/load_hf.md)查看使用方法\n- \\[2023/11\\] TurboMind 重磅升级。包括：Paged Attention、更快的且不受序列最大长度限制的 attention kernel、2+倍快的 KV8 kernels、Split-K decoding (Flash Decoding) 和支持 sm_75 架构的 W4A16\n- \\[2023/09\\] TurboMind 支持 Qwen-14B\n- \\[2023/09\\] TurboMind 支持 InternLM-20B 模型\n- \\[2023/09\\] TurboMind 支持 Code Llama 所有功能：代码续写、填空、对话、Python专项。点击[这里](./docs/zh_cn/llm/codellama.md)阅读部署方法\n- \\[2023/09\\] TurboMind 支持 Baichuan2-7B\n- \\[2023/08\\] TurboMind 支持 flash-attention2\n- \\[2023/08\\] TurboMind 支持 Qwen-7B，动态NTK-RoPE缩放，动态logN缩放\n- \\[2023/08\\] TurboMind 支持 Windows (tp=1)\n- \\[2023/08\\] TurboMind 支持 4-bit 推理，速度是 FP16 的 2.4 倍，是目前最快的开源实现。部署方式请看[这里](docs/zh_cn/quantization/w4a16.md)\n- \\[2023/08\\] LMDeploy 开通了 [HuggingFace Hub](https://huggingface.co/lmdeploy)，提供开箱即用的 4-bit 模型\n- \\[2023/08\\] LMDeploy 支持使用 [AWQ](https://arxiv.org/abs/2306.00978) 算法进行 4-bit 量化\n- \\[2023/07\\] TurboMind 支持使用 GQA 的 Llama-2 70B 模型\n- \\[2023/07\\] TurboMind 支持 Llama-2 7B/13B 模型\n- \\[2023/07\\] TurboMind 支持 InternLM 的 Tensor Parallel 推理\n\n</details>\n\n______________________________________________________________________\n\n# 简介\n\nLMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](https://github.com/open-mmlab/mmrazor) 团队联合开发，是涵盖了 LLM 任务的全套轻量化、部署和服务解决方案。\n这个强大的工具箱提供以下核心功能：\n\n- **高效的推理**：LMDeploy 开发了 Persistent Batch(即 Continuous Batch)，Blocked K/V Cache，动态拆分和融合，张量并行，高效的计算 kernel等重要特性。推理性能是 vLLM 的 1.8 倍。\n\n- **可靠的量化**：LMDeploy 支持权重量化和 k/v 量化。4bit 模型推理效率是 FP16 下的 2.4 倍。量化模型的可靠性已通过 OpenCompass 评测得到充分验证。\n\n- **便捷的服务**：通过请求分发服务，LMDeploy 支持多模型在多机、多卡上的推理服务。\n\n- **卓越的兼容性**：LMDeploy 支持 [KV Cache 
量化](docs/zh_cn/quantization/kv_quant.md), [AWQ](docs/zh_cn/quantization/w4a16.md) 和 [Automatic Prefix Caching](docs/zh_cn/inference/turbomind_config.md) 同时使用。\n\n# 性能\n\nLMDeploy TurboMind 引擎拥有卓越的推理能力，在各种规模的模型上，每秒处理的请求数是 vLLM 的 1.36 ~ 1.85 倍。在静态推理能力方面，TurboMind 4bit 模型推理速度（out token/s）远高于 FP16/BF16 推理。在小 batch 时，提高到 2.4 倍。\n\n![v0 1 0-benchmark](https://github.com/InternLM/lmdeploy/assets/4560679/8e455cf1-a792-4fa8-91a2-75df96a2a5ba)\n\n# 支持的模型\n\n<table>\n<tbody>\n<tr align=\"center\" valign=\"middle\">\n<td>\n  <b>LLMs</b>\n</td>\n<td>\n  <b>VLMs</b>\n</td>\n<tr valign=\"top\">\n<td align=\"left\" valign=\"top\">\n<ul>\n  <li>Llama (7B - 65B)</li>\n  <li>Llama2 (7B - 70B)</li>\n  <li>Llama3 (8B, 70B)</li>\n  <li>Llama3.1 (8B, 70B)</li>\n  <li>Llama3.2 (1B, 3B)</li>\n  <li>InternLM (7B - 20B)</li>\n  <li>InternLM2 (7B - 20B)</li>\n  <li>InternLM3 (8B)</li>\n  <li>InternLM2.5 (7B)</li>\n  <li>Qwen (1.8B - 72B)</li>\n  <li>Qwen1.5 (0.5B - 110B)</li>\n  <li>Qwen1.5 - MoE (0.5B - 72B)</li>\n  <li>Qwen2 (0.5B - 72B)</li>\n  <li>Qwen2-MoE (57BA14B)</li>\n  <li>Qwen2.5 (0.5B - 32B)</li>\n  <li>Qwen3, Qwen3-MoE</li>\n  <li>Qwen3-Next(80B)</li>\n  <li>Baichuan (7B)</li>\n  <li>Baichuan2 (7B-13B)</li>\n  <li>Code Llama (7B - 34B)</li>\n  <li>ChatGLM2 (6B)</li>\n  <li>GLM-4 (9B)</li>\n  <li>GLM-4-0414 (9B, 32B)</li>\n  <li>CodeGeeX4 (9B)</li>\n  <li>YI (6B-34B)</li>\n  <li>Mistral (7B)</li>\n  <li>DeepSeek-MoE (16B)</li>\n  <li>DeepSeek-V2 (16B, 236B)</li>\n  <li>DeepSeek-V2.5 (236B)</li>\n  <li>DeepSeek-V3 (685B)</li>\n  <li>DeepSeek-V3.2 (685B)</li>\n  <li>Mixtral (8x7B, 8x22B)</li>\n  <li>Gemma (2B - 7B)</li>\n  <li>StarCoder2 (3B - 15B)</li>\n  <li>Phi-3-mini (3.8B)</li>\n  <li>Phi-3.5-mini (3.8B)</li>\n  <li>Phi-3.5-MoE (16x3.8B)</li>\n  <li>Phi-4-mini (3.8B)</li>\n  <li>MiniCPM3 (4B)</li>\n  <li>SDAR (1.7B-30B)</li>\n  <li>gpt-oss (20B, 120B)</li>\n  <li>GLM-4.7-Flash (30B)</li>\n  <li>GLM-5 (754B)</li>\n</ul>\n</td>\n<td>\n<ul>\n  <li>LLaVA(1.5,1.6) (7B-34B)</li>\n  
<li>InternLM-XComposer2 (7B, 4khd-7B)</li>\n  <li>InternLM-XComposer2.5 (7B)</li>\n  <li>Qwen-VL (7B)</li>\n  <li>Qwen2-VL (2B, 7B, 72B)</li>\n  <li>Qwen2.5-VL (3B, 7B, 72B)</li>\n  <li>Qwen3-VL (2B - 235B)</li>\n  <li>Qwen3.5 (0.8B - 397B)</li>\n  <li>DeepSeek-VL (7B)</li>\n  <li>DeepSeek-VL2 (3B, 16B, 27B)</li>\n  <li>InternVL-Chat (v1.1-v1.5)</li>\n  <li>InternVL2 (1B-76B)</li>\n  <li>InternVL2.5(MPO) (1B-78B)</li>\n  <li>InternVL3 (1B-78B)</li>\n  <li>InternVL3.5 (1B-241BA28B)</li>\n  <li>Intern-S1 (241B)</li>\n  <li>Intern-S1-mini (8.3B)</li>\n  <li>Intern-S1-Pro (1TB)</li>\n  <li>Mono-InternVL (2B)</li>\n  <li>ChemVLM (8B-26B)</li>\n  <li>CogVLM-Chat (17B)</li>\n  <li>CogVLM2-Chat (19B)</li>\n  <li>MiniCPM-Llama3-V-2_5</li>\n  <li>MiniCPM-V-2_6</li>\n  <li>Phi-3-vision (4.2B)</li>\n  <li>Phi-3.5-vision (4.2B)</li>\n  <li>GLM-4V (9B)</li>\n  <li>GLM-4.1V-Thinking (9B)</li>\n  <li>Llama3.2-vision (11B, 90B)</li>\n  <li>Molmo (7B-D,72B)</li>\n  <li>Gemma3 (1B - 27B)</li>\n  <li>Llama4 (Scout, Maverick)</li>\n</ul>\n</td>\n</tr>\n</tbody>\n</table>\n\nLMDeploy 支持 2 种推理引擎： [TurboMind](./docs/zh_cn/inference/turbomind.md) 和 [PyTorch](./docs/zh_cn/inference/pytorch.md)，它们侧重不同。前者追求推理性能的极致优化，后者纯用python开发，着重降低开发者的门槛。\n\n它们在支持的模型类别、计算精度方面有所差别。用户可参考[这里](./docs/zh_cn/supported_models/supported_models.md), 查阅每个推理引擎的能力，并根据实际需求选择合适的。\n\n# 快速开始 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ)\n\n## 安装\n\n我们推荐在一个干净的conda环境下（python3.9 - 3.12），安装 lmdeploy：\n\n```shell\nconda create -n lmdeploy python=3.10 -y\nconda activate lmdeploy\npip install lmdeploy\n```\n\n自 v0.3.0 版本起，默认预编译包基于 **CUDA 12** 编译。\n\n若使用 GeForce RTX 50 系列显卡，请安装基于 **CUDA 12.8** 编译的 LMDeploy 预编译包。\n\n```shell\nexport LMDEPLOY_VERSION=0.12.2\nexport PYTHON_VERSION=310\npip install 
https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu128-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu128\n```\n\n如果需要在 CUDA 11+ 下安装 LMDeploy，或者源码安装 LMDeploy，请参考[安装文档](docs/zh_cn/get_started/installation.md)\n\n## 离线批处理\n\n```python\nimport lmdeploy\nwith lmdeploy.pipeline(\"internlm/internlm3-8b-instruct\") as pipe:\n    response = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\n    print(response)\n```\n\n> \\[!NOTE\\]\n> LMDeploy 默认从 HuggingFace 上面下载模型，如果要从 ModelScope 上面下载模型，请通过命令 `pip install modelscope` 安装ModelScope，并设置环境变量：\n>\n> `export LMDEPLOY_USE_MODELSCOPE=True`\n>\n> 如果要从 openMind Hub 上面下载模型，请通过命令 `pip install openmind_hub` 安装openMind Hub，并设置环境变量：\n>\n> `export LMDEPLOY_USE_OPENMIND_HUB=True`\n\n关于 pipeline 的更多推理参数说明，请参考[这里](docs/zh_cn/llm/pipeline.md)\n\n# 用户教程\n\n请阅读[快速上手](docs/zh_cn/get_started/get_started.md)章节，了解 LMDeploy 的基本用法。\n\n为了帮助用户更进一步了解 LMDeploy，我们准备了用户指南和进阶指南，请阅读我们的[文档](https://lmdeploy.readthedocs.io/zh-cn/latest/)：\n\n- 用户指南\n  - [LLM 推理 pipeline](docs/zh_cn/llm/pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Dh-YlSwg78ZO3AlleO441NF_QP2shs95#scrollTo=YALmXnwCG1pQ)\n  - [VLM 推理 pipeline](docs/zh_cn/multi_modal/vl_pipeline.md) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1nKLfnPeDA3p-FMNw2NhI-KOpk7-nlNjF?usp=sharing)\n  - [LLM 推理服务](docs/zh_cn/llm/api_server.md)\n  - [VLM 推理服务](docs/zh_cn/multi_modal/api_server_vl.md)\n  - [模型量化](./docs/zh_cn/quantization)\n- 进阶指南\n  - [推理引擎 - TurboMind](./docs/zh_cn/inference/turbomind.md)\n  - [推理引擎 - PyTorch](./docs/zh_cn/inference/pytorch.md)\n  - [自定义对话模板](./docs/zh_cn/advance/chat_template.md)\n  - [支持新模型](./docs/zh_cn/advance/pytorch_new_model.md)\n  - gemm tuning\n  - [长文本推理](./docs/zh_cn/advance/long_context.md)\n  
- [多模型推理服务](docs/zh_cn/llm/proxy_server.md)\n\n# 社区项目\n\n- 使用LMDeploy在英伟达Jetson系列板卡部署大模型：[LMDeploy-Jetson](https://github.com/BestAnHongjun/LMDeploy-Jetson)\n- 使用 LMDeploy 和 BentoML 部署大模型的示例项目：[BentoLMDeploy](https://github.com/bentoml/BentoLMDeploy)\n\n# 贡献指南\n\n我们感谢所有的贡献者为改进和提升 LMDeploy 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。\n\n# 致谢\n\n- [FasterTransformer](https://github.com/NVIDIA/FasterTransformer)\n- [llm-awq](https://github.com/mit-han-lab/llm-awq)\n- [vLLM](https://github.com/vllm-project/vllm)\n- [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII)\n\n# 引用\n\n```bibtex\n@misc{2023lmdeploy,\n    title={LMDeploy: A Toolkit for Compressing, Deploying, and Serving LLM},\n    author={LMDeploy Contributors},\n    howpublished = {\\url{https://github.com/InternLM/lmdeploy}},\n    year={2023}\n}\n```\n\n```bibtex\n@article{zhang2025efficient,\n  title={Efficient Mixed-Precision Large Language Model Inference with TurboMind},\n  author={Zhang, Li and Jiang, Youhe and He, Guoliang and Chen, Xin and Lv, Han and Yao, Qian and Fu, Fangcheng and Chen, Kai},\n  journal={arXiv preprint arXiv:2508.15601},\n  year={2025}\n}\n```\n\n# 开源许可证\n\n该项目采用 [Apache 2.0 开源许可证](LICENSE)。\n"
  },
  {
    "path": "autotest/benchmark/test_apiserver_performance.py",
    "content": "import pytest\nfrom utils.benchmark_utils import restful_test\nfrom utils.config_utils import get_func_config_list\n\n\ndef get_models(backend, parallel_config):\n    return get_func_config_list(backend, parallel_config, func_type='benchmark')\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1}))\ndef test_turbomind_apiserver_tp1(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2}))\ndef test_turbomind_apiserver_tp2(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4}))\ndef test_turbomind_apiserver_tp4(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8}))\ndef test_turbomind_apiserver_tp8(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1}))\ndef test_pytorch_apiserver_tp1(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, 
msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2}))\ndef test_pytorch_apiserver_tp2(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4}))\ndef test_pytorch_apiserver_tp4(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8}))\ndef test_pytorch_apiserver_tp8(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16}))\ndef test_pytorch_apiserver_tp16(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.function\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', [{\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 4,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    
'quant_policy': 8,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}])\ndef test_restful_func_tp2(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_smoke=True)\n\n    assert result, msg\n"
  },
  {
    "path": "autotest/benchmark/test_longtext_performance.py",
    "content": "import pytest\nfrom utils.benchmark_utils import longtext_throughput_test\nfrom utils.config_utils import get_func_config_list\n\n\ndef get_models(backend, parallel_config):\n    return get_func_config_list(backend, parallel_config, func_type='longtext_benchmark')\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1}))\ndef test_turbomind_longtext_throughput_tp1(config, run_config, worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2}))\ndef test_turbomind_longtext_throughput_tp2(config, run_config, worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4}))\ndef test_turbomind_longtext_throughput_tp4(config, run_config, worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8}))\ndef test_turbomind_longtext_throughput_tp8(config, run_config, worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1}))\ndef test_pytorch_longtext_throughput_tp1(config, run_config, 
worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2}))\ndef test_pytorch_longtext_throughput_tp2(config, run_config, worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4}))\ndef test_pytorch_longtext_throughput_tp4(config, run_config, worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8}))\ndef test_pytorch_longtext_throughput_tp8(config, run_config, worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16}))\ndef test_pytorch_longtext_throughput_tp16(config, run_config, worker_id):\n    result, msg = longtext_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n"
  },
  {
    "path": "autotest/benchmark/test_mllm_apiserver_performance.py",
    "content": "import pytest\nfrom utils.benchmark_utils import restful_test\nfrom utils.config_utils import get_func_config_list\n\n\ndef get_models(backend, parallel_config):\n    return get_func_config_list(backend, parallel_config, model_type='vl_model', func_type='mllm_evaluate')\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1}))\ndef test_turbomind_mllm_apiserver_tp1(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2}))\ndef test_turbomind_mllm_apiserver_tp2(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4}))\ndef test_turbomind_mllm_apiserver_tp4(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8}))\ndef test_turbomind_mllm_apiserver_tp8(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1}))\ndef test_pytorch_mllm_apiserver_tp1(config, run_config, worker_id):\n    
result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2}))\ndef test_pytorch_mllm_apiserver_tp2(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4}))\ndef test_pytorch_mllm_apiserver_tp4(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8}))\ndef test_pytorch_mllm_apiserver_tp8(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16}))\ndef test_pytorch_mllm_apiserver_tp16(config, run_config, worker_id):\n    result, msg = restful_test(config, run_config, worker_id=worker_id, is_mllm=True)\n    assert result, msg\n"
  },
  {
    "path": "autotest/benchmark/test_prefixcache_performance.py",
    "content": "import pytest\nfrom utils.benchmark_utils import prefixcache_throughput_test\nfrom utils.config_utils import get_func_config_list\n\n\ndef get_models(backend, parallel_config):\n    return get_func_config_list(backend, parallel_config, func_type='benchmark')\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1}))\ndef test_turbomind_prefix_tp1(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2}))\ndef test_turbomind_prefix_tp2(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4}))\ndef test_turbomind_prefix_tp4(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8}))\ndef test_turbomind_prefix_tp8(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 1}))\ndef test_pytorch_prefix_tp1(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, 
run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2}))\ndef test_pytorch_prefix_tp2(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4}))\ndef test_pytorch_prefix_tp4(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8}))\ndef test_pytorch_prefix_tp8(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16}))\ndef test_pytorch_prefix_tp16(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_2\n@pytest.mark.function\n@pytest.mark.parametrize('run_config', [{\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': 
{}\n}])\ndef test_pytorch_prefix_pr_test_tp1(config, run_config, worker_id):\n    result, msg = prefixcache_throughput_test(config, run_config, worker_id=worker_id, is_smoke=True)\n    assert result, msg\n"
  },
  {
    "path": "autotest/benchmark/test_throughput_performance.py",
    "content": "import pytest\nfrom utils.benchmark_utils import throughput_test\nfrom utils.config_utils import get_func_config_list, get_workerid\n\n\ndef get_models(backend, parallel_config):\n    run_configs = get_func_config_list(backend, parallel_config, func_type='benchmark')\n    return [item for item in run_configs\n            if 'gpt' not in item['model']]  # gpt models are excluded because of openai_harmony is not supported yet\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 1}))\ndef test_turbomind_throughput_tp1(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 2}))\ndef test_turbomind_throughput_tp2(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 4}))\ndef test_turbomind_throughput_tp4(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='turbomind', parallel_config={'tp': 8}))\ndef test_turbomind_throughput_tp8(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', 
parallel_config={'tp': 1}))\ndef test_pytorch_throughput_tp1(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 2}))\ndef test_pytorch_throughput_tp2(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 4}))\ndef test_pytorch_throughput_tp4(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 8}))\ndef test_pytorch_throughput_tp8(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models(backend='pytorch', parallel_config={'tp': 16}))\ndef test_pytorch_throughput_tp16(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id)\n    assert result, msg\n\n\n@pytest.mark.function\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', [{\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 
8,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}])\ndef test_throughput_func_tp2(config, run_config, worker_id):\n    result, msg = throughput_test(config, run_config, worker_id=worker_id, is_smoke=True)\n    assert result, msg\n\n\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_1\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', [{\n    'model': 'meta-llama/Meta-Llama-3-1-8B-Instruct',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-VL-8B-Instruct',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}])\ndef test_throughput_prtest_tp1(config, run_config, worker_id):\n    worker_id = 'gw' + str(6 + get_workerid(worker_id))\n    result, msg = throughput_test(config, run_config, worker_id=worker_id, is_smoke=True)\n    assert result, msg\n"
  },
  {
    "path": "autotest/chat_prompt_case.yml",
    "content": "base_testcase:\n    - 乌鲁木齐的景点A brief introduction to Urumqi’s attractions:\n        - contain:\n            - urumqi\n            - 乌鲁木齐\n            - 乌市\n            - xinjiang\n            - 新疆\n            - introduce\n            - 水磨沟\n            - 天池\n        - len_g:\n            10\n    - end:\n    - 介绍它的相应美食#please introduce some delicious foods:\n        - not contain:\n            - urumqi\n            - 乌鲁木齐\n            - 乌市\n            - xinjiang\n            - 新疆\n            - introduce\n            - 羊肉\n        - len_g:\n            10\nchat_testcase:\n    - 你好，你叫什么名字#hi, what's your name:\n    - end:\n    - 简要介绍乌鲁木齐的景点#A brief introduction to Urumqi’s attractions:\n        - contain:\n            - urumqi\n            - 乌鲁木齐\n            - 乌市\n            - xinjiang\n            - 新疆\n    - 介绍它的相应美食#please introduce some delicious foods:\n        - contain:\n            - urumqi\n            - 乌鲁木齐\n            - 乌市\n            - xinjiang\n            - 新疆\n            - 羊肉\n    - end:\n    - 介绍相应美食#please introduce some delicious foods:\n        - not contain:\n            - urumqi\n            - 乌鲁木齐\n            - 乌市\n            - xinjiang\n            - 新疆\ncode_testcase:\n    - 使用python编写一个int数组的冒泡排序代码:\n        - contain:\n            - def\n            - bubble\n            - 冒泡\n    - 快速排序呢:\n        - contain:\n            - def\n            - quick\n"
  },
  {
    "path": "autotest/config.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_log\nserver_log_path: /nvme/qa_test_models/server_log\neval_path: /nvme/qa_test_models/evaluation_report\nmllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nprefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json\nenv_tag: a100\ndevice: cuda\n\nconfig:\n    tp:\n        meta-llama/Llama-4-Scout-17B-16E-Instruct: 4\n        meta-llama/Meta-Llama-3-1-70B-Instruct: 4\n        OpenGVLab/InternVL3-38B: 2\n        Qwen/Qwen3-235B-A22B: 8\n        Qwen/Qwen3-30B-A3B: 2\n        Qwen/Qwen3-32B: 2\n        Qwen/Qwen3-VL-30B-A3B-Instruct: 2\n        Qwen/Qwen3-30B-A3B-Base: 2\n        Qwen/Qwen2.5-VL-32B-Instruct: 2\n        mistralai/Mixtral-8x7B-Instruct-v0.1: 2\n        OpenGVLab/InternVL3_5-30B-A3B: 2\n        zai-org/GLM-4.7-Flash: 2\n\nturbomind_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Meta-Llama-3-1-8B-Instruct\n        - meta-llama/Meta-Llama-3-1-8B-Instruct-AWQ\n        - meta-llama/Meta-Llama-3-1-70B-Instruct\n        - meta-llama/Meta-Llama-3-8B-Instruct\n        - internlm/internlm3-8b-instruct\n        - internlm/internlm3-8b-instruct-awq\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-38B\n        - OpenGVLab/InternVL3_5-30B-A3B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-GPTQ-Int4\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2.5-VL-32B-Instruct\n        - Qwen/Qwen1.5-MoE-A2.7B-Chat\n        - mistralai/Mixtral-8x7B-Instruct-v0.1\n        - THUDM/glm-4-9b-chat\n        - 
zai-org/GLM-4.7-Flash\n\npytorch_chat_model:\n    tp:\n        - meta-llama/Llama-4-Scout-17B-16E-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Meta-Llama-3-1-8B-Instruct\n        - meta-llama/Meta-Llama-3-1-70B-Instruct\n        - meta-llama/Meta-Llama-3-8B-Instruct\n        - internlm/internlm3-8b-instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-38B\n        - OpenGVLab/InternVL3_5-30B-A3B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-VL-8B-Instruct\n        - Qwen/Qwen3-VL-30B-A3B-Instruct\n        - THUDM/cogvlm2-llama3-chinese-chat-19B\n        - THUDM/glm-4v-9b\n        - THUDM/glm-4-9b-chat\n        - google/gemma-2-9b-it\n        - google/gemma-2-27b-it\n        - zai-org/GLM-4.7-Flash\n        - microsoft/Phi-3.5-vision-instruct\n        - microsoft/Phi-3-vision-128k-instruct\n\nturbomind_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-38B\n        - OpenGVLab/InternVL3_5-30B-A3B\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2.5-VL-32B-Instruct\n\npytorch_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3_5-30B-A3B\n        - Qwen/Qwen3-VL-8B-Instruct\n        - Qwen/Qwen3-VL-30B-A3B-Instruct\n        - THUDM/cogvlm-chat-hf\n        - THUDM/cogvlm2-llama3-chinese-chat-19B\n        - THUDM/glm-4v-9b\n        - microsoft/Phi-3-vision-128k-instruct\n        - microsoft/Phi-3.5-vision-instruct\n\nturbomind_base_model:\n    tp:\n        - Qwen/Qwen3-8B-Base\n        - Qwen/Qwen3-30B-A3B-Base\n\npytorch_base_model:\n    tp:\n        - Qwen/Qwen3-8B-Base\n        - Qwen/Qwen3-30B-A3B-Base\n\nturbomind_quantization:\n    no_awq:\n        - meta-llama/Meta-Llama-3-1-70B-Instruct\n        - internlm/internlm3-8b-instruct # ImportError: cannot import 
name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py)\n        - OpenGVLab/InternVL3-8B\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-30B-A3B-Base\n        - Qwen/Qwen1.5-MoE-A2.7B-Chat\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2.5-VL-32B-Instruct\n        - OpenGVLab/InternVL3_5-30B-A3B\n        - zai-org/GLM-4.7-Flash\n    gptq:\n        - empty\n    no_kvint4:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B\n        - OpenGVLab/InternVL3-8B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-GPTQ-Int4\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2.5-VL-32B-Instruct\n        - Qwen/Qwen1.5-MoE-A2.7B-Chat\n        - Qwen/Qwen3-8B-Base\n        - Qwen/Qwen3-30B-A3B-Base\n        - zai-org/GLM-4.7-Flash\n    no_kvint8:\n        - deepseek-ai/DeepSeek-V2-Chat\n        - zai-org/GLM-4.7-Flash\n\npytorch_quantization:\n    awq:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Meta-Llama-3-8B-Instruct\n        - meta-llama/Meta-Llama-3-1-8B-Instruct\n        # - internlm/internlm3-8b-instruct ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py)\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-8B\n        - microsoft/Phi-3-mini-4k-instruct\n        - THUDM/glm-4v-9b\n    w8a8:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - meta-llama/Meta-Llama-3-8B-Instruct\n        - meta-llama/Meta-Llama-3-1-8B-Instruct\n        # - internlm/internlm3-8b-instruct ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py)\n        - microsoft/Phi-3-mini-4k-instruct\n    
no_kvint4:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B\n        - OpenGVLab/InternVL3-8B\n        - Qwen/Qwen3-8B-Base\n        - Qwen/Qwen3-30B-A3B-Base\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-VL-8B-Instruct\n        - Qwen/Qwen3-VL-30B-A3B-Instruct\n        - microsoft/Phi-3-vision-128k-instruct\n        - microsoft/Phi-3.5-vision-instruct\n        - zai-org/GLM-4.7-Flash\n    no_kvint8:\n        - zai-org/GLM-4.7-Flash\n\nlongtext_benchmark_model:\n    - Qwen/Qwen3-8B\n    - Qwen/Qwen3-30B-A3B\n\nevaluate_model:\n  - google/gemma-2-9b-it\n  - google/gemma-2-27b-it\n  - meta-llama/Meta-Llama-3-1-8B-Instruct\n  - Qwen/Qwen1.5-MoE-A2.7B-Chat\n  - Qwen/Qwen3-30B-A3B\n\nbenchmark_model:\n    - meta-llama/Meta-Llama-3-1-8B-Instruct\n    - meta-llama/Meta-Llama-3-1-70B-Instruct\n    - internlm/internlm3-8b-instruct\n    - THUDM/glm-4-9b-chat\n    - Qwen/Qwen3-30B-A3B\n\nmllm_evaluate_model:\n  - OpenGVLab/InternVL3-8B\n  - Qwen/Qwen3-VL-8B-Instruct\n  - Qwen/Qwen3-VL-30B-A3B-Instruct\n  - OpenGVLab/InternVL3_5-30B-A3B\n"
  },
  {
    "path": "autotest/config_3090.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_log\nserver_log_path: /nvme/qa_test_models/server_log\neval_path: /nvme/qa_test_models/evaluation_report\nmllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nenv_tag: 3090\ndevice: cuda\n\nturbomind_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - internlm/internlm3-8b-instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n\npytorch_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - internlm/internlm3-8b-instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n\nturbomind_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n\npytorch_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n\nturbomind_base_model:\n    tp:\n        - internlm/internlm3-8b-instruct\n        - Qwen/Qwen3-8B\n\npytorch_base_model:\n    tp:\n        - internlm/internlm3-8b-instruct\n        - Qwen/Qwen3-8B\n\nturbomind_quantization:\n    
no_awq:\n        - internlm/internlm3-8b-instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n    gptq:\n        - empty\n    no_kvint4:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2.5-VL-3B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n    no_kvint8:\n        - deepseek-ai/DeepSeek-V2-Chat\n\npytorch_quantization:\n    awq:\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n    w8a8:\n        - meta-llama/Llama-3.2-3B-Instruct\n    no_kvint4:\n        - OpenGVLab/InternVL3-8B\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n    no_kvint8:\n        - deepseek-ai/DeepSeek-V2-Lite-Chat\n"
  },
  {
    "path": "autotest/config_3090_legacy.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_log\nserver_log_path: /nvme/qa_test_models/server_log\neval_path: /nvme/qa_test_models/evaluation_report\nmllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nenv_tag: 3090\ndevice: cuda\n\nturbomind_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - internlm/internlm3-8b-instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n\npytorch_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - internlm/internlm3-8b-instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2.5-VL-3B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n\nturbomind_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n\npytorch_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen2.5-VL-3B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n\nturbomind_base_model:\n    tp:\n        - internlm/internlm3-8b-instruct\n  
      - Qwen/Qwen3-8B\n\npytorch_base_model:\n    tp:\n        - internlm/internlm3-8b-instruct\n        - Qwen/Qwen3-8B\n\nturbomind_quantization:\n    no_awq:\n        - internlm/internlm3-8b-instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n    gptq:\n        - empty\n    no_kvint4:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2.5-VL-3B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n    no_kvint8:\n        - deepseek-ai/DeepSeek-V2-Chat\n\npytorch_quantization:\n    awq:\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n    w8a8:\n        - meta-llama/Llama-3.2-3B-Instruct\n    no_kvint4:\n        - OpenGVLab/InternVL3-8B\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2.5-VL-3B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n    no_kvint8:\n        - deepseek-ai/DeepSeek-V2-Lite-Chat\n"
  },
  {
    "path": "autotest/config_5080.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_log\nserver_log_path: /nvme/qa_test_models/server_log\neval_path: /nvme/qa_test_models/evaluation_report\nmllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nenv_tag: 5080\ndevice: cuda\n\nturbomind_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n\npytorch_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n\nturbomind_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n\npytorch_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n\nturbomind_base_model:\n    tp:\n        - Qwen/Qwen3-4B\n\npytorch_base_model:\n    tp:\n        - Qwen/Qwen3-4B\n\nturbomind_quantization:\n    no_awq:\n        - OpenGVLab/InternVL3-2B-Instruct\n    gptq:\n        - empty\n    no_kvint4:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-VL-3B-Instruct\n    no_kvint8:\n        - 
deepseek-ai/DeepSeek-V2-Chat\n\npytorch_quantization:\n    awq:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n    w8a8:\n        - meta-llama/Llama-3.2-3B-Instruct\n    no_kvint4:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n    no_kvint8:\n        - deepseek-ai/DeepSeek-V2-Lite-Chat\n"
  },
  {
    "path": "autotest/config_5080_legacy.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_log\nserver_log_path: /nvme/qa_test_models/server_log\neval_path: /nvme/qa_test_models/evaluation_report\nmllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nenv_tag: 5080\ndevice: cuda\n\nturbomind_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n\npytorch_chat_model:\n    tp:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n\nturbomind_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n\npytorch_vl_model:\n    tp:\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen2.5-VL-3B-Instruct\n\nturbomind_base_model:\n    tp:\n        - Qwen/Qwen3-4B\n\npytorch_base_model:\n    tp:\n        - Qwen/Qwen3-4B\n\nturbomind_quantization:\n    no_awq:\n        - OpenGVLab/InternVL3-2B-Instruct\n    gptq:\n        - empty\n    no_kvint4:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-VL-3B-Instruct\n    
no_kvint8:\n        - deepseek-ai/DeepSeek-V2-Chat\n\npytorch_quantization:\n    awq:\n        - meta-llama/Llama-3.2-3B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n    w8a8:\n        - meta-llama/Llama-3.2-3B-Instruct\n    no_kvint4:\n        - meta-llama/Llama-3.2-1B-Instruct\n        - OpenGVLab/InternVL3-2B-Instruct\n        - OpenGVLab/InternVL3-1B-Instruct\n        - OpenGVLab/InternVL2_5-1B\n        - Qwen/Qwen3-4B\n        - Qwen/Qwen3-1.7B\n        - Qwen/Qwen3-0.6B\n        - Qwen/Qwen2.5-VL-3B-Instruct\n    no_kvint8:\n        - deepseek-ai/DeepSeek-V2-Lite-Chat\n"
  },
  {
    "path": "autotest/config_ascend.yml",
    "content": "model_path: /mnt/vc-intern-delivery/qa-llm-cicd/qa_test_models\nresource_path: /mnt/vc-intern-delivery/qa-llm-cicd/resource\nlog_path: /mnt/vc-intern-delivery/qa-llm-cicd/log\nserver_log_path: /mnt/vc-intern-delivery/qa-llm-cicd/server_log\neval_path: /mnt/vc-intern-delivery/qa-llm-cicd/evaluation_report\nmllm_eval_path: /mnt/vc-intern-delivery/qa-llm-cicd/mllm_evaluation_report\nbenchmark_path: /mnt/vc-intern-delivery/qa-llm-cicd/benchmark_report\ndataset_path: /mnt/vc-intern-delivery/qa-llm-cicd/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nprefix_dataset_path: /mnt/vc-intern-delivery/qa-llm-cicd/datasets/prefix_cache_test.json\nenv_tag: ascend\ndevice: ascend\n\nconfig:\n    tp:\n        Qwen/Qwen3-30B-A3B: 4\n        Qwen/Qwen3-235B-A22B: 16\n        Qwen/Qwen3-32B: 4\n        Qwen/Qwen3-8B: 2\n        internlm/Intern-S1: 16\n        internlm/Intern-S1-mini: 2\n        OpenGVLab/InternVL3_5-8B: 2\n        OpenGVLab/InternVL3_5-38B: 4\n        Qwen/Qwen3-VL-30B-A3B-Instruct: 4\n        Qwen/Qwen3-VL-8B-Instruct: 2\n        Qwen/Qwen3-VL-32B-Instruct: 4\n\npytorch_chat_model:\n    tp:\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-0.6B\n\npytorch_vl_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - OpenGVLab/InternVL3_5-2B\n        - OpenGVLab/InternVL3_5-8B\n        - OpenGVLab/InternVL3_5-38B\n        - Qwen/Qwen3-VL-30B-A3B-Instruct\n        - Qwen/Qwen3-VL-8B-Instruct\n        - Qwen/Qwen3-VL-32B-Instruct\n\n\npytorch_base_model:\n    tp:\n        - Qwen/Qwen3-0.6B\n\npytorch_quantization:\n    awq:\n        - Empty\n    w8a8:\n        - Empty\n    no_kvint4:\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-0.6B\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - OpenGVLab/InternVL3_5-2B\n        - 
OpenGVLab/InternVL3_5-8B\n        - OpenGVLab/InternVL3_5-38B\n        - Qwen/Qwen3-VL-30B-A3B-Instruct\n        - Qwen/Qwen3-VL-8B-Instruct\n        - Qwen/Qwen3-VL-32B-Instruct\n    no_kvint8:\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-8B\n        - Qwen/Qwen3-0.6B\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - OpenGVLab/InternVL3_5-2B\n        - OpenGVLab/InternVL3_5-8B\n        - OpenGVLab/InternVL3_5-38B\n        - Qwen/Qwen3-VL-30B-A3B-Instruct\n        - Qwen/Qwen3-VL-8B-Instruct\n        - Qwen/Qwen3-VL-32B-Instruct\n\nlongtext_model:\n    - Qwen/Qwen3-30B-A3B\n\nbenchmark_model:\n    - Qwen/Qwen3-30B-A3B\n    - Qwen/Qwen3-235B-A22B\n    - Qwen/Qwen3-32B\n    - Qwen/Qwen3-8B\n    - Qwen/Qwen3-0.6B\n    - internlm/Intern-S1\n    - internlm/Intern-S1-mini\n\n\nevaluate_model:\n    - Qwen/Qwen3-30B-A3B\n    - Qwen/Qwen3-235B-A22B\n\nmllm_evaluate_model:\n    - Qwen/Qwen3-VL-30B-A3B-Instruct\n    - Qwen/Qwen3-VL-8B-Instruct\n    - Qwen/Qwen3-VL-32B-Instruct\n"
  },
  {
    "path": "autotest/config_h.yml",
    "content": "model_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/model\nresource_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/resource\nlog_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/log\nserver_log_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/server_log\neval_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/evaluation_report\nmllm_eval_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/mllm_evaluation_report\nbenchmark_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/benchmark_report\ndataset_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nprefix_dataset_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/datasets/prefix_cache_test.json\nenv_tag: h\ndevice: cuda\n\nconfig:\n    tp:\n        Qwen/Qwen3-235B-A22B-FP8: 4\n        internlm/Intern-S1: 4\n        Qwen/Qwen3-235B-A22B-Thinking-2507-FP8: 4\n        Qwen/Qwen3-30B-A3B: 2\n        Qwen/Qwen3-32B: 2\n        openai/gpt-oss-120b: 2\n        openai/gpt-oss-120b-BF16: 4\n        openai/gpt-oss-20b-BF16: 2\n        deepseek/DeepSeek-V3.1: 8\n        Qwen/Qwen3-30B-A3B-Base: 2\n        JetLM/SDAR-30B-A3B-Sci: 2\n        moonshotai/Kimi-K2-Instruct-0905: 16\n        Qwen/Qwen3-235B-A22B-Thinking-2507: 8\n        OpenGVLab/InternVL3_5-38B: 2\n        Qwen/Qwen3-VL-30B-A3B-Instruct: 2\n        internlm/Intern-S1-Pro-FP8: 16\n\n    dp_ep:\n        moonshotai/Kimi-K2-Instruct-0905:\n            dp: 16\n            ep: 16\n        Qwen/Qwen3-235B-A22B-Thinking-2507:\n            dp: 8\n            ep: 8\n        internlm/Intern-S1-Pro-FP8:\n            dp: 16\n            ep: 16\n\n    cp_tp:\n        
Qwen/Qwen3-235B-A22B-Thinking-2507:\n            cp: 2\n            tp: 8\n\n\nturbomind_chat_model:\n    tp:\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - OpenGVLab/InternVL3_5-38B\n        - openai/gpt-oss-120b\n        - openai/gpt-oss-20b\n\n    cp_tp:\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n\npytorch_chat_model:\n    tp:\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n        - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - Qwen/Qwen3-VL-30B-A3B-Instruct\n        - OpenGVLab/InternVL3_5-38B\n        - unsloth/gpt-oss-120b-BF16\n        - unsloth/gpt-oss-20b-BF16\n        - deepseek/DeepSeek-V3.1\n        - moonshotai/Kimi-K2-Instruct-0905\n        - internlm/Intern-S1-Pro-FP8\n        - JetLM/SDAR-30B-A3B-Sci\n    dp_ep:\n        - moonshotai/Kimi-K2-Instruct-0905\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n        - internlm/Intern-S1-Pro-FP8\n\nturbomind_vl_model:\n    tp:\n        - OpenGVLab/InternVL3_5-38B\n\n\npytorch_vl_model:\n    tp:\n        - OpenGVLab/InternVL3_5-38B\n        - Qwen/Qwen3-VL-30B-A3B-Instruct\n\nturbomind_base_model:\n    tp:\n        - Qwen/Qwen3-4B-FP8\n        - openai/gpt-oss-20b\n\npytorch_base_model:\n    tp:\n        - Qwen/Qwen3-8B-Base\n        - Qwen/Qwen3-30B-A3B-Base\n\nturbomind_quantization:\n    no_awq:\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - 
Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - openai/gpt-oss-120b\n        - openai/gpt-oss-20b\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n    gptq:\n        - empty\n    no_kvint4:\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - openai/gpt-oss-120b\n        - openai/gpt-oss-20b\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n    no_kvint8:\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n\npytorch_quantization:\n    awq:\n        - empty\n    w8a8:\n        - empty\n    no_kvint4:\n        - Qwen/Qwen3-8B-Base\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - moonshotai/Kimi-K2-Instruct-0905\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n        - internlm/Intern-S1-Pro-FP8\n        - JetLM/SDAR-30B-A3B-Sci\n        - deepseek/DeepSeek-V3.1\n    no_kvint8:\n        - Qwen/Qwen3-235B-A22B-Thinking-2507\n        - internlm/Intern-S1-Pro-FP8\n        - deepseek/DeepSeek-V3.1\n\nlongtext_model:\n    - Qwen/Qwen3-30B-A3B\n    - Qwen/Qwen3-235B-A22B-Thinking-2507\n\nbenchmark_model:\n    - meta-llama/Meta-Llama-3-1-8B-Instruct\n    - meta-llama/Meta-Llama-3-1-70B-Instruct\n    - Qwen/Qwen3-32B\n    - Qwen/Qwen3-30B-A3B\n    - Qwen/Qwen3-235B-A22B-Thinking-2507\n    - Qwen/Qwen2.5-72B-Instruct\n    - openai/gpt-oss-120b\n    - 
openai/gpt-oss-20b\n    - unsloth/gpt-oss-20b-BF16\n    - unsloth/gpt-oss-120b-BF16\n\nevaluate_model:\n    - Qwen/Qwen3-32B\n    - Qwen/Qwen3-32B-FP8\n    - Qwen/Qwen3-30B-A3B\n    - Qwen/Qwen3-30B-A3B-FP8\n    - Qwen/Qwen3-235B-A22B-Thinking-2507\n    - Qwen/Qwen3-235B-A22B-Thinking-2507-FP8\n    - openai/gpt-oss-120b\n    - unsloth/gpt-oss-120b-BF16\n    - deepseek/DeepSeek-V3.1\n    - moonshotai/Kimi-K2-Instruct-0905\n    - internlm/Intern-S1-Pro-FP8\n    - JetLM/SDAR-30B-A3B-Sci\n\nmllm_evaluate_model:\n    - OpenGVLab/InternVL3_5-38B\n    - Qwen/Qwen3-VL-30B-A3B-Instruct\n"
  },
  {
    "path": "autotest/config_h800.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_model/log\neval_path: /nvme/qa_test_models/evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nprefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json\nenv_tag: h800\ndevice: cuda\n\ntp_config:\n    Intern-S1: 8\n    Qwen3-235B-A22B: 8\n    Qwen3-235B-A22B-FP8: 4\n    Qwen3-30B-A3B: 2\n    Qwen3-32B: 2\n    gpt-oss-120b: 2\n    gpt-oss-120b-BF16: 4\n    gpt-oss-20b-BF16: 2\n    Qwen2.5-32B-Instruct: 2\n\nturbomind_chat_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-235B-A22B-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - openai/gpt-oss-120b\n        - openai/gpt-oss-20b\n\npytorch_chat_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-235B-A22B-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - unsloth/gpt-oss-120b-BF16\n        - unsloth/gpt-oss-20b-BF16\n\nturbomind_vl_model:\n   tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n\npytorch_vl_model:\n   tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n\nturbomind_base_model:\n    tp:\n        - internlm/Intern-S1-mini\n        - Qwen/Qwen3-4B-FP8\n        - 
openai/gpt-oss-20b\n\npytorch_base_model:\n    tp:\n        - internlm/Intern-S1-mini\n        - Qwen/Qwen3-4B-FP8\n        - unsloth/gpt-oss-20b-BF16\n\nturbomind_quantization:\n    no_awq:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-235B-A22B-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - openai/gpt-oss-120b\n        - openai/gpt-oss-20b\n    gptq:\n        - empty\n    no_kvint4:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-235B-A22B-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n        - openai/gpt-oss-120b\n        - openai/gpt-oss-20b\n    no_kvint8:\n        - empty\n\npytorch_quantization:\n    awq:\n        - empty\n    w8a8:\n        - empty\n    no_kvint4:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - Qwen/Qwen3-0.6B-FP8\n        - Qwen/Qwen3-1.7B-FP8\n        - Qwen/Qwen3-4B-FP8\n        - Qwen/Qwen3-8B-FP8\n        - Qwen/Qwen3-14B-FP8\n        - Qwen/Qwen3-235B-A22B\n        - Qwen/Qwen3-235B-A22B-FP8\n        - Qwen/Qwen3-30B-A3B\n        - Qwen/Qwen3-30B-A3B-FP8\n        - Qwen/Qwen3-32B\n        - Qwen/Qwen3-32B-FP8\n    no_kvint8:\n        - empty\n\n\nevaluate_model:\n    - internlm/Intern-S1-mini\n    - Qwen/Qwen3-0.6B-FP8\n    - Qwen/Qwen3-1.7B-FP8\n    - Qwen/Qwen3-4B-FP8\n    - Qwen/Qwen3-8B-FP8\n    - Qwen/Qwen3-14B-FP8\n    - Qwen/Qwen3-32B\n    - Qwen/Qwen3-32B-FP8\n    - Qwen/Qwen3-30B-A3B\n    - 
Qwen/Qwen3-30B-A3B-FP8\n    - Qwen/Qwen3-235B-A22B\n    - Qwen/Qwen3-235B-A22B-FP8\n    - openai/gpt-oss-120b\n    - openai/gpt-oss-20b\n    - unsloth/gpt-oss-120b-BF16\n    - unsloth/gpt-oss-20b-BF16\n"
  },
  {
    "path": "autotest/config_h_legacy.yml",
    "content": "model_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/model\nresource_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/resource\nlog_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/log\nserver_log_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/server_log\neval_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/evaluation_report\nmllm_eval_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/mllm_evaluation_report\nbenchmark_path: /mnt/shared-storage-user/llmrazor-share/qa-llm-cicd/cicd-autotest/eval_resource/benchmark_report\ndataset_path: /mnt/shared-storage-user/auto-eval-pipeline/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nprefix_dataset_path: /mnt/shared-storage-user/auto-eval-pipeline/datasets/prefix_cache_test.json\nenv_tag: h\ndevice: cuda\n\nconfig:\n    tp:\n        internlm/Intern-S1: 4\n\nturbomind_chat_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n\npytorch_chat_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n\nturbomind_vl_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n\npytorch_vl_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n\nturbomind_base_model:\n    tp:\n\npytorch_base_model:\n    tp:\n\nturbomind_quantization:\n    no_awq:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n    gptq:\n        - empty\n    no_kvint4:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n    no_kvint8:\n        - empty\n\npytorch_quantization:\n    awq:\n        - empty\n    w8a8:\n        - empty\n    no_kvint4:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n    no_kvint8:\n        - empty\n\nbenchmark_model:\n    - internlm/Intern-S1\n    - 
internlm/Intern-S1-mini\n\nmllm_evaluate_model:\n    - internlm/Intern-S1\n    - internlm/Intern-S1-mini\n"
  },
  {
    "path": "autotest/config_legacy.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_log\nserver_log_path: /nvme/qa_test_models/server_log\neval_path: /nvme/qa_test_models/evaluation_report\nmllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nprefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json\nenv_tag: a100\ndevice: cuda\n\nconfig:\n    tp:\n        meta-llama/Llama-4-Scout-17B-16E-Instruct: 4\n        meta-llama/Meta-Llama-3-1-70B-Instruct: 4\n        internlm/Intern-S1: 8\n        OpenGVLab/InternVL3-38B: 2\n        OpenGVLab/InternVL2_5-26B: 2\n        OpenGVLab/InternVL2_5-26B-MPO: 2\n        OpenGVLab/InternVL2_5-38B: 4\n        OpenGVLab/InternVL2-40B: 4\n        Qwen/Qwen2.5-72B-Instruct: 4\n        deepseek-ai/deepseek-vl-1.3b-chat: 2\n        baichuan-inc/Baichuan2-13B-Chat: 2\n        mistralai/Mixtral-8x7B-Instruct-v0.1: 2\n        google/gemma-2-27b-it: 2\n        OpenGVLab/InternVL2-Llama3-76B-AWQ: 4\n        unsloth/gpt-oss-20b-BF16: 2\n        unsloth/gpt-oss-120b-BF16: 4\n        OpenGVLab/InternVL3_5-30B-A3B: 2\n\nturbomind_chat_model:\n    tp:\n        - meta-llama/Llama-2-7b-chat-hf\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL2_5-8B\n        - OpenGVLab/Mini-InternVL-Chat-2B-V1-5\n        - OpenGVLab/InternVL2-Llama3-76B-AWQ\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2-VL-7B-Instruct\n        - baichuan-inc/Baichuan2-7B-Chat\n        - liuhaotian/llava-v1.6-vicuna-7b\n        - codellama/CodeLlama-7b-Instruct-hf\n        # - allenai/Molmo-7B-D-0924  This modeling file requires the following packages that were not found in your environment: 
tensorflow. Run `pip install tensorflow`\n\npytorch_chat_model:\n    tp:\n        - meta-llama/Llama-2-7b-chat-hf\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL2_5-8B\n        # - OpenGVLab/Mono-InternVL-2B 'dict' object has no attribute 'image_size'\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2-VL-7B-Instruct\n        - unsloth/gpt-oss-20b-BF16\n        - mistralai/Mixtral-8x7B-Instruct-v0.1\n        - google/gemma-3-12b-it\n        - google/gemma-2-9b-it\n        - google/gemma-2-27b-it\n        - google/gemma-7b-it\n        - baichuan-inc/Baichuan2-13B-Chat\n        - deepseek-ai/deepseek-moe-16b-chat\n        - THUDM/chatglm2-6b\n        - microsoft/Phi-4-mini-instruct\n\nturbomind_vl_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL2_5-8B\n        - OpenGVLab/Mini-InternVL-Chat-2B-V1-5\n        - OpenGVLab/InternVL2-Llama3-76B-AWQ\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2-VL-7B-Instruct\n        - liuhaotian/llava-v1.6-vicuna-7b\n\npytorch_vl_model:\n    tp:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - OpenGVLab/InternVL3-8B\n        - OpenGVLab/InternVL2_5-8B\n        # - OpenGVLab/Mono-InternVL-2B 'dict' object has no attribute 'image_size'\n        - Qwen/Qwen2-VL-7B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n\nturbomind_base_model:\n    tp:\n        - codellama/CodeLlama-7b-hf\n\npytorch_base_model:\n    tp:\n        - bigcode/starcoder2-7b\n\nturbomind_quantization:\n    no_awq:\n        - internlm/Intern-S1\n        - internlm/Intern-S1-mini\n        - OpenGVLab/InternVL3-8B\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2-VL-7B-Instruct\n        - OpenGVLab/InternVL3_5-30B-A3B\n 
       - codellama/CodeLlama-7b-Instruct-hf\n        # - allenai/Molmo-7B-D-0924  This modeling file requires the following packages that were not found in your environment: tensorflow. Run `pip install tensorflow`\n    gptq:\n        - empty\n    no_kvint4:\n        - OpenGVLab/InternVL3-8B\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2-VL-7B-Instruct\n        # - allenai/Molmo-7B-D-0924  This modeling file requires the following packages that were not found in your environment: tensorflow. Run `pip install tensorflow`\n    no_kvint8:\n        - Qwen/Qwen2.5-7B-Instruct\n\npytorch_quantization:\n    awq:\n        - meta-llama/Llama-2-7b-chat-hf\n        # - internlm/internlm3-8b-instruct ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py)\n        - Qwen/Qwen2.5-7B-Instruct\n        # - microsoft/Phi-4-mini-instruct The size of tensor a (5120) must match the size of tensor b (3072) at non-singleton dimension 0\n    w8a8:\n        - meta-llama/Llama-2-7b-chat-hf\n        # - internlm/internlm3-8b-instruct ImportError: cannot import name 'LossKwargs' from 'transformers.utils' (/opt/py3/lib/python3.10/site-packages/transformers/utils/__init__.py)\n        - Qwen/Qwen2.5-7B-Instruct\n        # - microsoft/Phi-4-mini-instruct The size of tensor a (5120) must match the size of tensor b (3072) at non-singleton dimension 0\n    no_kvint4:\n        - OpenGVLab/InternVL3-8B\n        - Qwen/Qwen2.5-7B-Instruct\n        - Qwen/Qwen2.5-VL-7B-Instruct\n        - Qwen/Qwen2-VL-7B-Instruct\n        - microsoft/Phi-3-vision-128k-instruct\n        - microsoft/Phi-3.5-vision-instruct\n        - unsloth/gpt-oss-20b-BF16\n    no_kvint8:\n        - empty\n\nlongtext_benchmark_model:\n    - internlm/Intern-S1-mini\n\nbenchmark_model:\n    - internlm/Intern-S1\n    - internlm/Intern-S1-mini\n    - meta-llama/Llama-2-7b-chat-hf\n    - 
unsloth/gpt-oss-20b-BF16\n\nevaluate_model:\n  - Qwen/Qwen2.5-7B-Instruct\n\nmllm_evaluate_model:\n  - internlm/Intern-S1-mini\n  - internlm/Intern-S1\n"
  },
  {
    "path": "autotest/config_test.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_model/log\nserver_log_path: /nvme/qa_test_models/server_log\neval_path: /nvme/qa_test_models/evaluation_report\nmllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nprefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json\nenv_tag: test\ndevice: cuda\n\nconfig:\n    tp:\n        test/test_tp2: 2\n        test/test_tp2_gpqa: 2\n        test/test_tp2_int4: 2\n        test/test_tp8: 8\n        test/test_vl_tp2: 2\n        test/test_vl_tp2_gpqa: 2\n        test/test_vl_tp2_int4: 2\n        test/test_vl_tp8: 8\n        test/test_allkind: 8\n\n    dp_ep:\n        test/test_dpep16:\n            dp: 16\n            ep: 16\n        test/test_dpep8:\n            dp: 8\n            ep: 8\n        test/test_vl_dpep16:\n            dp: 16\n            ep: 16\n        test/test_vl_dpep8:\n            dp: 8\n            ep: 8\n        test/test_allkind:\n            dp: 8\n            ep: 8\n\n    cp_tp:\n        test/test_cp2tp8:\n            cp: 2\n            tp: 8\n        test/test_vl_cp2tp8:\n            cp: 2\n            tp: 8\n        test/test_allkind:\n            cp: 2\n            tp: 8\n\n\nturbomind_chat_model:\n    tp:\n        - test/test_tp1\n        - test/test_tp2\n        - test/test_tp2_gpqa\n        - test/test_tp2_int4\n        - test/test_tp8\n        - test/test_vl_tp1\n        - test/test_vl_tp2\n        - test/test_vl_tp2_gpqa\n        - test/test_vl_tp2_int4\n        - test/test_vl_tp8\n        - test/test_allkind\n    cp_tp:\n        - test/test_cp2tp8\n        - test/test_vl_cp2tp8\n        - test/test_allkind\n\npytorch_chat_model:\n    tp:\n        - test/test_tp1\n        - test/test_tp1_pytorch\n        - test/test_tp2\n        - 
test/test_tp2_gpqa\n        - test/test_tp2_int4\n        - test/test_tp8\n        - test/test_vl_tp1\n        - test/test_vl_tp1_pytorch\n        - test/test_vl_tp2\n        - test/test_vl_tp2_gpqa\n        - test/test_vl_tp2_int4\n        - test/test_vl_tp8\n        - test/test_allkind\n    dp_ep:\n        - test/test_dpep8\n        - test/test_dpep16\n        - test/test_vl_dpep8\n        - test/test_vl_dpep16\n        - test/test_allkind\n    cp_tp:\n        - test/test_cp2tp8\n        - test/test_vl_cp2tp8\n        - test/test_allkind\n\nturbomind_vl_model:\n    tp:\n        - test/test_vl_tp1\n        - test/test_vl_tp2\n        - test/test_vl_tp2_gpqa\n        - test/test_vl_tp2_int4\n        - test/test_vl_tp8\n        - test/test_allkind\n\npytorch_vl_model:\n    tp:\n        - test/test_vl_tp1\n        - test/test_vl_tp1_pytorch\n        - test/test_vl_tp2\n        - test/test_vl_tp2_gpqa\n        - test/test_vl_tp2_int4\n        - test/test_vl_tp8\n        - test/test_allkind\n    dp_ep:\n        - test/test_vl_dpep8\n        - test/test_vl_dpep16\n        - test/test_allkind\n\nturbomind_base_model:\n    tp:\n        - test/test_tp1\n        - test/test_tp2\n\npytorch_base_model:\n    tp:\n        - test/test_tp1\n        - test/test_tp1_pytorch\n        - test/test_tp2\n\nturbomind_quantization:\n    no_awq:\n        - test/test_tp2\n        - test/test_vl_tp2\n        - test/test_tp2_gpqa\n        - test/test_vl_tp2_gpqa\n        - test/test_cp2tp8\n        - test/test_dpep8\n    gptq:\n        - test/test_tp1\n        - test/test_vl_tp1\n        - test/test_cp2tp8\n        - test/test_dpep8\n    no_kvint4:\n        - test/test_tp2\n        - test/test_vl_tp2\n        - test/test_cp2tp8\n        - test/test_vl_dpep8\n    no_kvint8:\n        - test/test_tp1\n        - test/test_vl_tp1\n        - test/test_dpep8\n        - test/test_vl_cp2tp8\n\npytorch_quantization:\n    awq:\n        - test/test_tp1\n    w8a8:\n        - test/test_tp2\n    
no_kvint4:\n        - test/test_tp2\n        - test/test_cp2tp8\n        - test/test_vl_cp2tp8\n    no_kvint8:\n        - test/test_tp1\n        - test/test_vl_tp1\n        - test/test_vl_dpep8\n\nlongtext_model:\n    - test/test_tp1\n    - test/test_tp1_pytorch\n    - test/test_vl_tp2\n    - test/test_vl_tp8\n    - test/test_cp2tp8\n    - test/test_vl_dpep8\n\nbenchmark_model:\n    - test/test_tp1\n    - test/test_tp1_pytorch\n    - test/test_vl_tp2\n    - test/test_vl_tp8\n    - test/test_cp2tp8\n    - test/test_vl_dpep8\n\nmllm_benchmark_model:\n    - test/test_vl_tp1\n    - test/test_vl_tp1_pytorch\n    - test/test_vl_tp2\n    - test/test_vl_tp8\n    - test/test_vl_dpep16\n    - test/test_vl_cp2tp8\n\nevaluate_model:\n    - test/test_tp1\n    - test/test_tp1_pytorch\n    - test/test_vl_tp2\n    - test/test_vl_tp8\n    - test/test_cp2tp8\n    - test/test_dpep16\n    - test/test_vl_dpep8\n\nmllm_evaluate_model:\n    - test/test_vl_tp1\n    - test/test_vl_tp1_pytorch\n    - test/test_vl_tp2\n    - test/test_vl_tp8\n    - test/test_vl_dpep16\n    - test/test_vl_cp2tp8\n"
  },
  {
    "path": "autotest/config_testascend.yml",
    "content": "model_path: /nvme/qa_test_models\nresource_path: /nvme/qa_test_models/resource\nlog_path: /nvme/qa_test_models/autotest_model/log\nserver_log_path: /nvme/qa_test_models/server_log\neval_path: /nvme/qa_test_models/evaluation_report\nmllm_eval_path: /nvme/qa_test_models/mllm_evaluation_report\nbenchmark_path: /nvme/qa_test_models/benchmark_report\ndataset_path: /nvme/qa_test_models/datasets/ShareGPT_V3_unfiltered_cleaned_split.json\nprefix_dataset_path: /nvme/qa_test_models/datasets/prefix_cache_test.json\nenv_tag: ascend\ndevice: ascend\n\nconfig:\n    tp:\n    dp_ep:\n    cp_tp:\n\n\npytorch_chat_model:\n    tp:\n        - test/test_tp1\n\n\npytorch_quantization:\n    awq:\n        - test/test_tp1\n    w8a8:\n        - test/test_tp1\n    no_kvint4:\n        - test/test_tp1\n    no_kvint8:\n        - test/test_tp1\n"
  },
  {
    "path": "autotest/conftest.py",
    "content": "import os\n\nimport pytest\nimport yaml\nfrom utils.config_utils import get_config\nfrom utils.constant import DEFAULT_SERVER\nfrom utils.proxy_distributed_utils import ProxyDistributedManager\nfrom utils.ray_distributed_utils import RayLMDeployManager\n\ncli_prompt_case_file = 'autotest/chat_prompt_case.yml'\ncommon_prompt_case_file = 'autotest/prompt_case.yml'\nconfig_file = 'autotest/config.yml'\n\nPROXY_PORT = 8000\n\n\n@pytest.fixture(scope='session')\ndef config():\n    # Use device-specific config file if DEVICE environment variable is set\n    return get_config()\n\n\n@pytest.fixture(scope='session')\ndef cli_case_config():\n    case_path = os.path.join(cli_prompt_case_file)\n    with open(case_path) as f:\n        case_config = yaml.load(f.read(), Loader=yaml.SafeLoader)\n    return case_config\n\n\n@pytest.fixture(scope='class', autouse=True)\ndef common_case_config():\n    case_path = os.path.join(common_prompt_case_file)\n    with open(case_path) as f:\n        case_config = yaml.load(f.read(), Loader=yaml.SafeLoader)\n    return case_config\n\n\n@pytest.fixture(scope='session')\ndef shared_ray_manager():\n    master_addr = DEFAULT_SERVER\n    env_tag = os.environ.get('TEST_ENV')\n    if env_tag:\n        device_config_path = f'autotest/config_{env_tag}.yml'\n        if os.path.exists(device_config_path):\n            config_path = device_config_path\n        else:\n            config_path = config_file\n    else:\n        config_path = config_file\n\n    with open(config_path) as f:\n        env_config = yaml.load(f.read(), Loader=yaml.SafeLoader)\n    run_id = os.environ.get('RUN_ID', 'local_run')\n    log_dir = os.path.join(env_config.get('server_log_path', '/tmp/lmdeploy_test'), str(run_id).replace('/', '_'))\n\n    manager = RayLMDeployManager(master_addr=master_addr, api_port=PROXY_PORT, log_dir=log_dir, health_check=True)\n\n    manager.start_ray_cluster()\n\n    if manager.is_master:\n        print('🎯 Master node: Ray cluster 
started, waiting for worker nodes to join...')\n\n    yield manager\n\n    print(f'\\n[Final Cleanup] Node {manager.node_rank} performing final resource cleanup...')\n    manager.cleanup(force=True)\n\n\n@pytest.fixture(scope='session')\ndef shared_proxy_manager():\n    master_addr = DEFAULT_SERVER\n\n    manager = ProxyDistributedManager()\n\n    if manager.is_master:\n        manager.start()\n        print(f'🎯 Master node: LMDeploy Proxy started on {master_addr}:{manager.proxy_port}')\n        print('⏳ Waiting for worker nodes to connect...')\n\n    yield manager\n\n    print(f'\\n[Final Cleanup] Node {manager.node_rank} performing final resource cleanup...')\n    manager.cleanup()\n"
  },
  {
    "path": "autotest/evaluate/eval_config_chat.py",
    "content": "# flake8: noqa\n\nfrom mmengine.config import read_base\nfrom opencompass.models import OpenAISDK\nfrom opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner\nfrom opencompass.runners import LocalRunner\nfrom opencompass.tasks import OpenICLEvalTask, OpenICLInferTask\nfrom opencompass.utils.text_postprocessors import extract_non_reasoning_content\n\n#######################################################################\n#                          PART 0  Essential Configs                  #\n#######################################################################\nwith read_base():\n    # Datasets\n    from opencompass.configs.datasets.aime2025.aime2025_llmjudge_academic import aime2025_datasets\n    from opencompass.configs.datasets.gpqa.gpqa_cascade_eval_academic import gpqa_datasets\n    from opencompass.configs.datasets.HLE.hle_llmverify_academic import hle_datasets\n    from opencompass.configs.datasets.IFEval.IFEval_gen_353ae7 import ifeval_datasets\n    from opencompass.configs.datasets.livecodebench.livecodebench_v6_academic import LCBCodeGeneration_dataset\n    from opencompass.configs.datasets.mmlu_pro.mmlu_pro_0shot_cot_gen_08c1de import mmlu_pro_datasets\n    # Summary Groups\n    from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups\n\n#######################################################################\n#                         Model Configuration                         #\n#######################################################################\n\nMODEL_NAME = ''\nMODEL_PATH = ''\nAPI_BASE = ''\nJUDGE_MODEL_NAME = ''\nJUDGE_MODEL_PATH = ''\nJUDGE_API_BASE = ''\n\napi_meta_template = dict(round=[\n    dict(role='HUMAN', api_role='HUMAN'),\n    dict(role='BOT', api_role='BOT', generate=True),\n])\n\n# Use OpenAISDK to configure LMDeploy OpenAI interface\nmodels = [\n    dict(type=OpenAISDK,\n         abbr=f'{MODEL_NAME}',\n         path=MODEL_PATH,\n         key='EMPTY',\n         
openai_api_base=API_BASE,\n         retry=3,\n         run_cfg=dict(num_gpus=0),\n         meta_template=api_meta_template,\n         timeout=10800,\n         pred_postprocessor=dict(type=extract_non_reasoning_content))\n]\n\n#######################################################################\n#                          PART 1  Datasets List                      #\n#######################################################################\n# datasets list for evaluation\nmmlu_pro_datasets = [x for x in mmlu_pro_datasets if 'math' in x['abbr'] or 'other' in x['abbr']]\n\n# Collect every *_datasets list (including hle_datasets) and append LCBCodeGeneration_dataset\ndatasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) + [LCBCodeGeneration_dataset]\n\n# LLM judge config: using LLM to evaluate predictions\njudge_cfg = dict(\n    type=OpenAISDK,\n    abbr=f'{JUDGE_MODEL_NAME}',\n    path=JUDGE_MODEL_NAME,\n    key='EMPTY',\n    openai_api_base=JUDGE_API_BASE,\n    meta_template=dict(round=[\n        dict(role='HUMAN', api_role='HUMAN'),\n        dict(role='BOT', api_role='BOT', generate=True),\n    ]),\n    query_per_second=16,\n    batch_size=1024,\n    temperature=0.001,\n    tokenizer_path=JUDGE_MODEL_PATH,\n    verbose=True,\n    max_out_len=8192,\n    max_seq_len=32768,\n    mode='mid',\n)\n\nfor item in datasets:\n    if 'judge_cfg' in item['eval_cfg']['evaluator']:\n        item['eval_cfg']['evaluator']['judge_cfg'] = judge_cfg\n    if 'llm_evaluator' in item['eval_cfg']['evaluator'].keys(\n    ) and 'judge_cfg' in item['eval_cfg']['evaluator']['llm_evaluator']:\n        item['eval_cfg']['evaluator']['llm_evaluator']['judge_cfg'] = judge_cfg\n\n#######################################################################\n#                       PART 2  Dataset Summarizer                    #\n#######################################################################\n\ncore_summary_groups = [\n    {\n        'name':\n        'core_average',\n        'subsets': 
[\n            ['IFEval', 'Prompt-level-strict-accuracy'],\n            ['hle_llmjudge', 'accuracy'],\n            ['aime2025_repeat_32', 'accuracy (32 runs average)'],\n            ['GPQA_diamond_repeat_4', 'accuracy (4 runs average)'],\n            ['mmlu_pro', 'naive_average'],\n            'mmlu_pro_math',\n            'mmlu_pro_other',\n            ['lcb_code_generation_repeat_6', 'pass@1 (6 runs average)'],\n        ],\n    },\n]\n\nsummarizer = dict(\n    dataset_abbrs=[\n        ['core_average', 'naive_average'],\n        ['IFEval', 'Prompt-level-strict-accuracy'],\n        ['hle_llmjudge', 'accuracy'],\n        ['GPQA_diamond_repeat_4', 'accuracy (4 runs average)'],\n        ['aime2025_repeat_32', 'accuracy (32 runs average)'],\n        ['mmlu_pro', 'naive_average'],\n        'mmlu_pro_math',\n        'mmlu_pro_other',\n        ['lcb_code_generation_repeat_6', 'pass@1 (6 runs average)'],\n    ],\n    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []) + core_summary_groups,\n)\n\nfor item in datasets:\n    if 'max_out_len' in item['infer_cfg']['inferencer']:\n        del item['infer_cfg']['inferencer']['max_out_len']\n\nNUM_WORKERS = 8\n\ninfer = dict(\n    partitioner=dict(type=NumWorkerPartitioner, num_worker=NUM_WORKERS),\n    runner=dict(\n        type=LocalRunner,\n        max_num_workers=64,\n        retry=0,\n        task=dict(type=OpenICLInferTask),\n    ),\n)\n\n# eval with local runner\neval = dict(\n    partitioner=dict(type=NaivePartitioner, n=10),\n    runner=dict(type=LocalRunner, max_num_workers=64, task=dict(type=OpenICLEvalTask)),\n)\n\ninfer['partitioner']['num_worker'] = 64\n"
  },
  {
    "path": "autotest/evaluate/test_api_evaluate.py",
    "content": "import os\nimport time\n\nimport pytest\nimport utils.constant as constant\nfrom utils.config_utils import get_case_str_by_config, get_func_config_list, get_workerid\nfrom utils.evaluate_utils import eval_test\nfrom utils.proxy_distributed_utils import ApiServerPerTest, proxy_worker_node_wait\nfrom utils.ray_distributed_utils import ray_worker_node_wait\nfrom utils.run_restful_chat import start_openai_service, start_proxy_server, stop_restful_api, terminate_restful_api\n\n\ndef _run_ray_distributed_test(\n        config,\n        run_config,\n        worker_id,\n        test_type='infer',\n        manager=None,  # ← New parameter: pass in shared manager\n        eval_config_name='default'):\n    \"\"\"Universal distributed test executor (using shared Ray cluster)\"\"\"\n    assert manager is not None, 'Manager instance must be provided'\n    if 'gpt' in run_config.get('model', '').lower():\n        eval_config_name = 'gpt'\n    elif 'intern-s1-pro' in run_config.get('model', '').lower():\n        eval_config_name = 'intern-s1-pro'\n    if str(config.get('env_tag')) == 'ascend':\n        eval_config_name = f'{eval_config_name}-2batch'\n\n    preset_config = constant.EVAL_CONFIGS.get(eval_config_name, {})\n\n    if manager.is_master:\n        model_path = os.path.join(config['model_path'], run_config['model'])\n        eval_path = config.get('eval_path')\n\n        # Start API Server for current model (master node starts/stops, worker nodes verify)\n        manager.start_lmdeploy_api_server(config=config, run_config=run_config)\n\n        try:\n            print(f'🧪 Master node executing {test_type} test ({eval_config_name})...')\n            case_name = get_case_str_by_config(run_config)\n\n            result, msg = eval_test(model_path,\n                                    eval_path,\n                                    case_name,\n                                    port=constant.PROXY_PORT,\n                                    
test_type=test_type,\n                                    **preset_config)\n            assert result, f'❌ {test_type} test failed: {msg}'\n            print(f'✅ {test_type} test passed')\n\n        finally:\n            # Clean up API Server for current model (worker nodes skip)\n            manager.cleanup(force=False)\n    else:\n        time.sleep(10)\n        ray_worker_node_wait(manager, timeout_minutes=4880)\n\n\ndef _run_proxy_distributed_test(config,\n                                run_config,\n                                worker_id,\n                                test_type='infer',\n                                manager=None,\n                                eval_config_name='default'):\n    assert manager is not None, 'Manager instance must be provided'\n\n    if 'gpt' in run_config.get('model', '').lower():\n        eval_config_name = 'gpt'\n    elif 'intern-s1-pro' in run_config.get('model', '').lower():\n        eval_config_name = 'intern-s1-pro'\n\n    if str(config.get('env_tag')) == 'ascend':\n        eval_config_name = f'{eval_config_name}-2batch'\n\n    preset_config = constant.EVAL_CONFIGS.get(eval_config_name, {})\n    model_name = run_config['model']\n    model_path = os.path.join(config['model_path'], model_name)\n\n    api_server = ApiServerPerTest(proxy_manager=manager, config=config, run_config=run_config)\n    api_server.start()\n\n    try:\n        if manager.is_master:\n            api_server.wait_until_ready()\n            print(f'🧪 Master node executing {test_type} test ({eval_config_name})...')\n            eval_path = config.get('eval_path')\n            case_name = get_case_str_by_config(run_config)\n\n            extra_config = {'max-num-workers': 16}\n\n            result, msg = eval_test(model_path,\n                                    eval_path,\n                                    case_name,\n                                    port=constant.PROXY_PORT,\n                                    test_type=test_type,\n        
                            extra_config=extra_config,\n                                    **preset_config)\n            assert result, f'❌ {test_type} test failed: {msg}'\n            print(f'✅ {test_type} test passed')\n\n        else:\n            print(f'⏸️ Worker node {manager.node_rank} waiting for master to complete test...')\n            proxy_worker_node_wait(manager, timeout_minutes=4880)\n\n    finally:\n        api_server.cleanup()\n        if manager.is_master:\n            time.sleep(1)\n\n\ndef run_eval_test(config, run_config, worker_id, test_type='infer', eval_config_name='default'):\n    \"\"\"Run test with specified evaluation configuration.\"\"\"\n    if 'gpt' in run_config.get('model', '').lower():\n        eval_config_name = 'gpt'\n    elif 'sdar' in run_config.get('model', '').lower():\n        eval_config_name = 'sdar'\n    elif 'intern-s1-pro' in run_config.get('model', '').lower():\n        eval_config_name = 'intern-s1-pro'\n    if str(config.get('env_tag')) == 'a100':\n        eval_config_name = f'{eval_config_name}-32k'\n    elif str(config.get('env_tag')) == 'ascend':\n        eval_config_name = f'{eval_config_name}-2batch'\n    preset_config = constant.EVAL_CONFIGS.get(eval_config_name, {})\n    eval_path = config.get('eval_path')\n\n    total_gpus = int(os.environ.get('TOTAL_GPU_COUNT', '8'))\n    work_num = int(total_gpus / run_config.get('parallel_config', {}).get('tp', 1))\n    extra_config = {'max-num-workers': min(work_num * 16, 64)}\n    case_name = get_case_str_by_config(run_config)\n\n    if test_type == 'infer':\n        proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), constant.PROXY_PORT,\n                                                      f'{case_name}_infer')\n        run_config_new = run_config.copy()\n        if 'extra_params' not in run_config_new:\n            run_config_new['extra_params'] = {}\n        run_config_new['extra_params']['proxy-url'] = 
f'http://{constant.DEFAULT_SERVER}:{constant.PROXY_PORT}'\n        run_config_new['extra_params']['server-name'] = constant.DEFAULT_SERVER\n\n        from concurrent.futures import ThreadPoolExecutor\n\n        def run_openai_service_start(i):\n            return start_openai_service(config, run_config_new, f'gw{i}')\n\n        with ThreadPoolExecutor(max_workers=work_num) as executor:\n            futures = [executor.submit(run_openai_service_start, i) for i in range(int(work_num))]\n        results = []\n        for future in futures:\n            pid, content = future.result()\n            results.append((pid, content))\n\n        try:\n            model_path = os.path.join(config.get('model_path'), run_config.get('model'))\n            eval_test(model_path,\n                      eval_path,\n                      case_name,\n                      port=constant.PROXY_PORT,\n                      test_type=test_type,\n                      extra_config=extra_config,\n                      **preset_config)\n        finally:\n            for i in range(work_num):\n                terminate_restful_api(f'gw{i}')\n            stop_restful_api(proxy_pid, proxy_process)\n    else:  # eval\n        port = constant.PROXY_PORT + get_workerid(worker_id)\n        proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), port, f'{case_name}_eval')\n        eval_run_config = constant.EVAL_RUN_CONFIG.copy()\n        if 'extra_params' not in eval_run_config:\n            eval_run_config['extra_params'] = {}\n        eval_run_config['extra_params']['proxy-url'] = f'http://{constant.DEFAULT_SERVER}:{port}'\n\n        pid, content = start_openai_service(config, eval_run_config, worker_id)\n        try:\n            if pid > 0:\n                model_path = os.path.join(config.get('model_path'), eval_run_config.get('model'))\n                eval_test(model_path,\n                          eval_path,\n                          case_name,\n                         
 port=port,\n                          test_type=test_type,\n                          extra_config=extra_config,\n                          **preset_config)\n            else:\n                assert False, f'Failed to start RESTful API server: {content}'\n        finally:\n            if pid > 0:\n                terminate_restful_api(worker_id)\n            stop_restful_api(proxy_pid, proxy_process)\n\n\ndef get_models(backend, parallel_config):\n    return get_func_config_list(backend, parallel_config, func_type='evaluate', extra={'session_len': 65536})\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 1}))\ndef test_turbomind_infer_tp1(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 2}))\ndef test_turbomind_infer_tp2(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 4}))\ndef test_turbomind_infer_tp4(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 8}))\ndef test_turbomind_infer_tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_cp2tp8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'cp': 2, 'tp': 8}))\ndef test_turbomind_infer_cp2tp8(config, 
run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 1}))\ndef test_pytorch_restful_tp1(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 2}))\ndef test_pytorch_restful_tp2(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 4}))\ndef test_pytorch_restful_tp4(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 8}))\ndef test_pytorch_restful_tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16}))\ndef test_pytorch_restful_tp16(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_distributed_tp16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16}))\ndef test_pytorch_restful_distributed_tp16(shared_ray_manager, config, run_config, worker_id):\n    
_run_ray_distributed_test(config=config,\n                              run_config=run_config,\n                              worker_id=worker_id,\n                              test_type='infer',\n                              manager=shared_ray_manager)\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_distributed_dpep8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 8, 'ep': 8}))\ndef test_pytorch_restful_distributed_dpep8(shared_proxy_manager, config, run_config, worker_id):\n    _run_proxy_distributed_test(config=config,\n                                run_config=run_config,\n                                worker_id=worker_id,\n                                test_type='infer',\n                                manager=shared_proxy_manager)\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_distributed_dpep16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 16, 'ep': 16}))\ndef test_pytorch_restful_distributed_dpep16(shared_proxy_manager, config, run_config, worker_id):\n    _run_proxy_distributed_test(config=config,\n                                run_config=run_config,\n                                worker_id=worker_id,\n                                test_type='infer',\n                                manager=shared_proxy_manager)\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 1}))\ndef test_turbomind_eval_tp1(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 2}))\ndef test_turbomind_eval_tp2(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 
'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 4}))\ndef test_turbomind_eval_tp4(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 8}))\ndef test_turbomind_eval_tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 1}))\ndef test_pytorch_eval_tp1(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 2}))\ndef test_pytorch_eval_tp2(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 4}))\ndef test_pytorch_eval_tp4(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 8}))\ndef test_pytorch_eval_tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 
'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16}))\ndef test_pytorch_eval_tp16(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_distributed_tp16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16}))\ndef test_pytorch_eval_distributed_tp16(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_distributed_dpep8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 8, 'ep': 8}))\ndef test_pytorch_eval_distributed_dpep8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_distributed_dpep16\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'dp': 16, 'ep': 16}))\ndef test_pytorch_eval_distributed_dpep16(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_cp2tp8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'cp': 2, 'tp': 8}))\ndef test_turbomind_eval_cp2tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n"
  },
  {
    "path": "autotest/evaluate/test_mllm_api_evaluate.py",
    "content": "import os\n\nimport pytest\nimport utils.constant as constant\nfrom utils.config_utils import get_case_str_by_config, get_func_config_list, get_workerid\nfrom utils.evaluate_utils import mllm_eval_test\nfrom utils.run_restful_chat import start_openai_service, start_proxy_server, stop_restful_api, terminate_restful_api\n\n\ndef run_eval_test(config, run_config, worker_id, test_type='infer', eval_config_name='default'):\n    extra_config = constant.MLLM_EVAL_CONFIGS.get(eval_config_name, {})\n    eval_path = config.get('mllm_eval_path')\n    case_name = get_case_str_by_config(run_config)\n    if test_type == 'infer':\n        proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), constant.PROXY_PORT,\n                                                      f'{case_name}_infer')\n        total_gpus = int(os.environ.get('TOTAL_GPU_COUNT', '8'))\n        work_num = int(total_gpus / run_config.get('parallel_config', {}).get('tp', 1))\n        run_config_new = run_config.copy()\n        if 'extra_params' not in run_config_new:\n            run_config_new['extra_params'] = {}\n        run_config_new['extra_params']['proxy-url'] = f'http://{constant.DEFAULT_SERVER}:{constant.PROXY_PORT}'\n\n        from concurrent.futures import ThreadPoolExecutor\n\n        def run_openai_service_start(i):\n            return start_openai_service(config, run_config_new, f'gw{i}')\n\n        with ThreadPoolExecutor(max_workers=work_num) as executor:\n            futures = [executor.submit(run_openai_service_start, i) for i in range(int(work_num))]\n        results = []\n        for future in futures:\n            pid, content = future.result()\n            results.append((pid, content))\n\n        try:\n            model_path = os.path.join(config.get('model_path'), run_config.get('model'))\n            extra_config['api-nproc'] = work_num * 16\n            mllm_eval_test(model_path,\n                           eval_path,\n                           
case_name,\n                           port=constant.PROXY_PORT,\n                           test_type=test_type,\n                           extra_config=extra_config)\n        finally:\n            for i in range(work_num):\n                terminate_restful_api(f'gw{i}')\n            stop_restful_api(proxy_pid, proxy_process)\n    else:  # eval\n        port = constant.PROXY_PORT + get_workerid(worker_id)\n        proxy_pid, proxy_process = start_proxy_server(config.get('server_log_path'), port, f'{case_name}_eval')\n        eval_run_config = constant.EVAL_RUN_CONFIG.copy()\n        if 'extra_params' not in eval_run_config:\n            eval_run_config['extra_params'] = {}\n        eval_run_config['extra_params']['proxy-url'] = f'http://{constant.DEFAULT_SERVER}:{port}'\n        pid, content = start_openai_service(config, eval_run_config, worker_id)\n        try:\n            if pid > 0:\n                model_path = os.path.join(config.get('model_path'), eval_run_config.get('model'))\n                mllm_eval_test(model_path, eval_path, case_name, port=port, test_type=test_type)\n            else:\n                assert False, f'Failed to start RESTful API server: {content}'\n        finally:\n            if pid > 0:\n                terminate_restful_api(worker_id)\n            stop_restful_api(proxy_pid, proxy_process)\n\n\ndef get_models(backend, parallel_config):\n    return get_func_config_list(backend,\n                                parallel_config,\n                                model_type='vl_model',\n                                func_type='mllm_evaluate',\n                                extra={\n                                    'session-len': 65536,\n                                    'cache-max-entry-count': 0.6\n                                })\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 1}))\ndef 
test_turbomind_vl_eval_tp1(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 2}))\ndef test_turbomind_vl_eval_tp2(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 4}))\ndef test_turbomind_vl_eval_tp4(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 8}))\ndef test_turbomind_vl_eval_tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 1}))\ndef test_pytorch_vl_eval_tp1(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 2}))\ndef test_pytorch_vl_eval_tp2(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 4}))\ndef test_pytorch_vl_eval_tp4(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 
'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 8}))\ndef test_pytorch_vl_eval_tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.infer\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16}))\ndef test_pytorch_vl_eval_tp16(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'infer')\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 1}))\ndef test_turbomind_eval_tp1(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 2}))\ndef test_turbomind_eval_tp2(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_4\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 4}))\ndef test_turbomind_eval_tp4(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.turbomind\n@pytest.mark.gpu_num_8\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('turbomind', {'tp': 8}))\ndef test_turbomind_eval_tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 
'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_1\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 1}))\ndef test_pytorch_eval_tp1(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 2}))\ndef test_pytorch_eval_tp2(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_4\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 4}))\ndef test_pytorch_eval_tp4(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_8\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 8}))\ndef test_pytorch_eval_tp8(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n\n\n@pytest.mark.eval\n@pytest.mark.pytorch\n@pytest.mark.gpu_num_16\n@pytest.mark.test_ascend\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.parametrize('run_config', get_models('pytorch', {'tp': 16}))\ndef test_pytorch_eval_tp16(config, run_config, worker_id):\n    run_eval_test(config, run_config, worker_id, 'eval')\n"
  },
  {
    "path": "autotest/interface/pipeline/test_pipeline_func.py",
    "content": "import multiprocessing as mp\n\nimport pydantic\nimport pytest\nfrom utils.config_utils import set_device_env_variable, unset_device_env_variable\nfrom utils.pipeline_chat import (assert_pipeline_batch_return, assert_pipeline_batch_stream_return,\n                                 assert_pipeline_common_log, assert_pipeline_single_return,\n                                 assert_pipeline_single_stream_return, save_pipeline_common_log)\nfrom utils.restful_return_check import has_repeated_fragment\n\nfrom lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline\nfrom lmdeploy.utils import is_bf16_supported\n\n\ndef init_pipeline(model_path, backend_config):\n    if not is_bf16_supported() and isinstance(backend_config, PytorchEngineConfig):\n        backend_config.dtype = 'float16'\n    return pipeline(model_path, backend_config=backend_config)\n\n\ndef run_case_in_spawn(worker_id, target, args):\n    needs_device_env = 'gw' in worker_id\n    if needs_device_env:\n        set_device_env_variable(worker_id, parallel_config=2)\n    ctx = mp.get_context('spawn')\n    process = ctx.Process(target=target, args=args)\n    process.start()\n    process.join()\n    if needs_device_env:\n        unset_device_env_variable()\n\n\ndef run_pipeline_testcase_prompt(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    response = pipe('Hi, pls intro yourself')\n    result, msg = assert_pipeline_single_return(response)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_prompt_stream(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    response = []\n    for item in 
pipe.stream_infer('Hi, pls intro yourself'):\n        response.append(item)\n    result, msg = assert_pipeline_single_stream_return(response)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_multi_prompt(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n    result, msg = assert_pipeline_batch_return(response, 2)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_multi_prompt_stream(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    response = []\n    for item in pipe.stream_infer(['Pls intro yourself', 'Shanghai is']):\n        response.append(item)\n    result, msg = assert_pipeline_batch_stream_return(response, 2)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_message(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    prompts = [[{'role': 'user', 'content': 'Hi, pls intro yourself'}]]\n    response = pipe(prompts)\n    result, msg = assert_pipeline_batch_return(response)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_message_stream(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    prompts = [[{'role': 
'user', 'content': 'Hi, pls intro yourself'}]]\n    response = []\n    for item in pipe.stream_infer(prompts):\n        response.append(item)\n    result, msg = assert_pipeline_single_stream_return(response)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_message_batch(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    prompts = [[{'role': 'user', 'content': 'Hi, pls intro yourself'}], [{'role': 'user', 'content': 'Shanghai is'}]]\n    response = pipe(prompts)\n    result, msg = assert_pipeline_batch_return(response, 2)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_message_batch_stream(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    prompts = [[{'role': 'user', 'content': 'Hi, pls intro yourself'}], [{'role': 'user', 'content': 'Shanghai is'}]]\n    response = []\n    for item in pipe.stream_infer(prompts):\n        response.append(item)\n    result, msg = assert_pipeline_batch_stream_return(response, 2)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_logprobs(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(logprobs=10, max_new_tokens=5, top_k=40, do_sample=True)\n    response = pipe('Hi, pls intro yourself', gen_config=gen_config)\n    result, msg = assert_pipeline_single_return(response, logprobs_num=10)\n    save_pipeline_common_log(config, 
file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_logprobs_stream(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(logprobs=10, max_new_tokens=5, top_k=40, do_sample=True)\n    response = []\n    for item in pipe.stream_infer('Hi, pls intro yourself', gen_config=gen_config):\n        response.append(item)\n    result, msg = assert_pipeline_single_stream_return(response, logprobs_num=10)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_session_len(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(session_len=10, tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n    result = True\n    for i in range(2):\n        result &= response[i].finish_reason == 'error'\n        result &= response[i].generate_token_len == 0\n        result &= response[i].text == 'internal error happened, status code ResponseType.INPUT_LENGTH_ERROR'\n    save_pipeline_common_log(config, file_name, result, response)\n    pipe.close()\n\n\ndef run_pipeline_testcase_min_new_tokens(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(min_new_tokens=200, ignore_eos=True)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config)\n    result = True\n    for i in range(2):\n        result &= response[i].finish_reason == 'length'\n        result &= response[i].index == i\n    save_pipeline_common_log(config, file_name, result, response)\n    
pipe.close()\n\n\ndef run_pipeline_testcase_stop_words(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(stop_words=[' and', '浦', ' to'])\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config)\n    result = True\n    for i in range(2):\n        result &= '浦' not in response[i].text\n        result &= ' and' not in response[i].text and ' to ' not in response[i].text\n        result &= response[i].finish_reason == 'stop' and response[i].generate_token_len < 50\n    save_pipeline_common_log(config, file_name, result, response)\n    pipe.close()\n\n\ndef run_pipeline_testcase_bad_words(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(bad_words=[' and', '浦', ' to'])\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config)\n    result = True\n    for i in range(2):\n        result &= '浦' not in response[i].text and ' and' not in response[i].text and ' to ' not in response[i].text\n    save_pipeline_common_log(config, file_name, result, response)\n    pipe.close()\n\n\ndef run_pipeline_testcase_special_words_false(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    prompt = '<|im_start|>system\\n当开启工具以及代码时，根据需求选择合适的工具进行调用\\n' + \\\n        '<|im_end|><|im_start|>system name=<|interpreter|>\\n你现在已经' + \\\n        '能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。当你向 python ' + \\\n        '发送含有 Python >代码的消息时，它将在该环境中执行。这个工具适用于多种场景，' + \\\n        '如数据分析或处理（包括数据操作、统计分析、图表绘制），复杂的计算问题（解决数学和物理' + \\\n     
   '难题），编程示例（理解编程概念或特性），文本处理和分析（比如文本解析和自然语言处理），机器学习和数据科学（用于' + \\\n        '展示模型训练和数据可视化），以及文件操作和数据导入（处理CSV、JSON等格式的文件）。<|im_end|>\\n' + \\\n        '<|im_start|>user\\n设 $L$ 为圆周$x^2+y^2=2x$，计算曲线积分：$I=\\\\int_L' + \\\n        '{x\\\\mathrm{d}s}=$<|im_end|>\\n<|im_start|>assistant'\n    gen_config = GenerationConfig(skip_special_tokens=False)\n    response = pipe(prompt, gen_config=gen_config)\n    result = '<|action_start|><|interpreter|>' in response.text\n    save_pipeline_common_log(config, file_name, result, response)\n    pipe.close()\n\n\ndef run_pipeline_testcase_special_words_true(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    prompt = '<|im_start|>system\\n当开启工具以及代码时，根据需求选择合适的工具进行调用\\n' + \\\n        '<|im_end|><|im_start|>system name=<|interpreter|>\\n你现在已经' + \\\n        '能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。当你向 python ' + \\\n        '发送含有 Python >代码的消息时，它将在该环境中执行。这个工具适用于多种场景，' + \\\n        '如数据分析或处理（包括数据操作、统计分析、图表绘制），复杂的计算问题（解决数学和物理' + \\\n        '难题），编程示例（理解编程概念或特性），文本处理和分析（比如文本解析和自然语言处理），机器学习和数据科学（用于' + \\\n        '展示模型训练和数据可视化），以及文件操作和数据导入（处理CSV、JSON等格式的文件）。<|im_end|>\\n' + \\\n        '<|im_start|>user\\n设 $L$ 为圆周$x^2+y^2=2x$，计算曲线积分：$I=\\\\int_L' + \\\n        '{x\\\\mathrm{d}s}=$<|im_end|>\\n<|im_start|>assistant'\n    gen_config = GenerationConfig(skip_special_tokens=True)\n    response = pipe(prompt, gen_config=gen_config)\n    result = '<|action_start|><|interpreter|>' not in response.text\n    save_pipeline_common_log(config, file_name, result, response)\n    pipe.close()\n\n\ndef run_pipeline_testcase_repetition_penalty(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(repetition_penalty=0.01, 
random_seed=1, min_new_tokens=50, do_sample=True)\n    response = pipe('Shanghai is', gen_config=gen_config)\n    result, msg = has_repeated_fragment(response.text)\n    save_pipeline_common_log(config, file_name, result, response, msg=msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_repetition_penalty_bigger(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(repetition_penalty=1.2, random_seed=1)\n    response = pipe('Shanghai is', gen_config=gen_config)\n    result, msg = assert_pipeline_single_return(response)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_min_top_p(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(top_p=0, random_seed=1)\n    response = pipe('Shanghai is', gen_config=gen_config)\n    result, msg = assert_pipeline_single_return(response)\n    save_pipeline_common_log(config, file_name, result, response, msg)\n    pipe.close()\n\n\ndef run_pipeline_testcase_min_top_k(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(top_k=1, max_new_tokens=20, do_sample=True)\n    response_list = []\n    for _ in range(3):\n        response_list.append(pipe('Shanghai is', gen_config=gen_config))\n    result = response_list[0].text == response_list[1].text and response_list[1].text == response_list[2].text\n    save_pipeline_common_log(config, file_name, result, response_list)\n    pipe.close()\n\n\ndef 
run_pipeline_testcase_diff_random_seed(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    response_list = []\n    for i in range(3):\n        gen_config = GenerationConfig(random_seed=i, temperature=1.0, top_k=40, do_sample=True)\n        response_list.append(pipe('Shanghai is', gen_config=gen_config))\n    result = response_list[0].text != response_list[1].text and response_list[1].text != response_list[2].text\n    save_pipeline_common_log(config, file_name, result, response_list)\n    pipe.close()\n\n\ndef run_pipeline_testcase_same_random_seed(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(random_seed=1, top_k=40, do_sample=True)\n    response_list = []\n    for _ in range(3):\n        response_list.append(pipe('Shanghai is', gen_config=gen_config))\n    result = response_list[0].text == response_list[1].text and response_list[1].text == response_list[2].text\n    save_pipeline_common_log(config, file_name, result, response_list)\n    pipe.close()\n\n\ndef run_pipeline_testcase_do_sample_batch(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(temperature=1.0, top_k=40, do_sample=True)\n    response = pipe(['Shanghai is'] * 3, gen_config=gen_config)\n    result = response[0].text != response[1].text and response[1].text != response[2].text\n    save_pipeline_common_log(config, file_name, result, response)\n    pipe.close()\n\n\ndef run_pipeline_testcase_max_new_tokens(config, model, backend, file_name):\n    model_path = 
'/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(max_new_tokens=5)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config)\n    result = True\n    for i in range(2):\n        result &= response[i].finish_reason == 'length'\n        result &= response[i].generate_token_len == 6 or response[i].generate_token_len == 5\n    save_pipeline_common_log(config, file_name, result, response)\n    pipe.close()\n\n\ndef run_pipeline_testcase_ignore_eos(config, model, backend, file_name):\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    gen_config = GenerationConfig(ignore_eos=True, max_new_tokens=256)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config)\n    result = True\n    for i in range(2):\n        result &= response[i].finish_reason == 'length'\n        result &= response[i].generate_token_len == 257 or response[i].generate_token_len == 256\n    save_pipeline_common_log(config, file_name, result, response)\n    pipe.close()\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_return_with_prompt(config, model, backend, worker_id):\n\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_prompt, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_return_with_prompt_stream(config, model, backend, worker_id):\n\n    file_name = 
f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_prompt_stream, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_return_with_multi_prompt(config, model, backend, worker_id):\n\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_multi_prompt, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_return_with_multi_prompt_stream(config, model, backend, worker_id):\n\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_multi_prompt_stream, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_return_with_message(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_message, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_return_with_message_stream(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_message_stream, (config, model, backend, file_name))\n    
assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_return_with_message_batch(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_message_batch, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_return_with_message_batch_stream(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_message_batch_stream, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig])\ndef test_return_check_logprobs(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_logprobs, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig])\ndef test_return_check_logprobs_stream(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_logprobs_stream, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', 
[TurbomindEngineConfig, PytorchEngineConfig])\ndef test_backend_config_session_len(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_session_len, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_min_new_tokens(config, model, backend, worker_id):\n    file_name = f'pipeline_log_min_new_tokens_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_min_new_tokens, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_stop_words(config, model, backend, worker_id):\n    file_name = f'pipeline_log_stop_words_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_stop_words, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_bad_words(config, model, backend, worker_id):\n    file_name = f'pipeline_log_bad_words_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_bad_words, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_special_words_false(config, model, backend, worker_id):\n    file_name = 
f'pipeline_log_special_words_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_special_words_false, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_special_words_true(config, model, backend, worker_id):\n    file_name = f'pipeline_log_special_words_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_special_words_true, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_minimum_repetition_penalty(config, model, backend, worker_id):\n    file_name = f'pipeline_log_repetition_penalty_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_repetition_penalty, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_repetition_penalty_bigger_than_1(config, model, backend, worker_id):\n    file_name = f'pipeline_log_repetition_penalty_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_repetition_penalty_bigger, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_minimun_topp(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    
run_case_in_spawn(worker_id, run_pipeline_testcase_min_top_p, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_minimun_topk(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_min_top_k, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_diff_random_seed(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_diff_random_seed, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_same_random_seed(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_same_random_seed, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_do_sample_batch(config, model, backend, worker_id):\n    file_name = f'pipeline_log_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_do_sample_batch, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, 
file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_max_new_tokens(config, model, backend, worker_id):\n    file_name = f'pipeline_log_max_new_tokens_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_max_new_tokens, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_gen_config_ignore_eos(config, model, backend, worker_id):\n    file_name = f'pipeline_log_ignore_eos_{worker_id}.txt'\n    run_case_in_spawn(worker_id, run_pipeline_testcase_ignore_eos, (config, model, backend, file_name))\n    assert_pipeline_common_log(config, file_name)\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig, PytorchEngineConfig])\ndef test_backend_config_input_validation(config, model, backend, worker_id):\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id, parallel_config=2)\n    model_path = '/'.join([config.get('model_path'), model])\n    backend_config = backend(tp=2)\n    pipe = init_pipeline(model_path, backend_config=backend_config)\n    with pytest.raises(AssertionError):\n        gen_config = GenerationConfig(top_p=-0.01)\n        pipe('Shanghai is', gen_config=gen_config)\n\n    with pytest.raises(AssertionError):\n        gen_config = GenerationConfig(top_p=1.01)\n        pipe('Shanghai is', gen_config=gen_config)\n\n    with pytest.raises(AssertionError):\n        gen_config = GenerationConfig(temperature=-1)\n        pipe('Shanghai is', gen_config=gen_config)\n\n    with pytest.raises(AssertionError):\n        gen_config = 
GenerationConfig(temperature=2.01)\n        pipe('Shanghai is', gen_config=gen_config)\n\n    with pytest.raises(AssertionError):\n        gen_config = GenerationConfig(top_k=-1)\n        pipe('Shanghai is', gen_config=gen_config)\n\n    with pytest.raises(AssertionError):\n        gen_config = GenerationConfig(n=-1)\n        pipe('Shanghai is', gen_config=gen_config)\n\n    pipe.close()\n    if 'gw' in worker_id:\n        unset_device_env_variable()\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig])\ndef test_backend_config_validate_turbomind(config, model, backend, worker_id):\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id, parallel_config=2)\n    model_path = '/'.join([config.get('model_path'), model])\n    with pytest.raises(pydantic.ValidationError, match='tp must be a positive integer'):\n        backend_config = backend(tp=0)\n        pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(AssertionError, match='max_batch_size should be greater than 0, but got 0'):\n        backend_config = backend(max_batch_size=0)\n        pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(pydantic.ValidationError):\n        backend_config = backend(cache_max_entry_count=0)\n        pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(pydantic.ValidationError):\n        backend_config = backend(quant_policy=1)\n        pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(pydantic.ValidationError):\n        backend_config = backend(rope_scaling_factor=-1)\n        pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(pydantic.ValidationError):\n        backend_config = backend(max_prefill_token_num=-1)\n        pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(pydantic.ValidationError):\n        
backend_config = backend(num_tokens_per_iter=-1)\n        pipeline(model_path, backend_config=backend_config)\n\n    if 'gw' in worker_id:\n        unset_device_env_variable()\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B', 'Qwen/Qwen3-30B-A3B'])\n@pytest.mark.parametrize('backend', [PytorchEngineConfig])\ndef test_backend_config_validate_pytorch(config, model, backend, worker_id):\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id, parallel_config=2)\n    model_path = '/'.join([config.get('model_path'), model])\n    with pytest.raises(AssertionError):\n        backend_config = backend(tp=0)\n        init_pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(SystemExit):\n        backend_config = backend(max_batch_size=0)\n        init_pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(AssertionError):\n        backend_config = backend(cache_max_entry_count=0)\n        init_pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(AssertionError):\n        backend_config = backend(num_cpu_blocks=-1)\n        init_pipeline(model_path, backend_config=backend_config)\n\n    with pytest.raises(AssertionError):\n        backend_config = backend(num_gpu_blocks=-1)\n        init_pipeline(model_path, backend_config=backend_config)\n\n    if 'gw' in worker_id:\n        unset_device_env_variable()\n\n\n@pytest.mark.parametrize('model', ['OpenGVLab/InternVL3_5-30B-A3B'])\n@pytest.mark.parametrize('backend', [TurbomindEngineConfig])\ndef test_backend_config_tp(config, model, backend, worker_id):\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id, parallel_config=2)\n    model_path = '/'.join([config.get('model_path'), model])\n    with pytest.raises(AssertionError):\n        backend_config = backend(tp=100)\n        pipe = init_pipeline(model_path, backend_config=backend_config)\n        pipe.close()\n    if 'gw' in worker_id:\n
        unset_device_env_variable()\n"
  },
  {
    "path": "autotest/interface/pipeline/test_pipeline_longtext_func.py",
    "content": "import multiprocessing as mp\nimport os\n\nimport numpy as np\nimport pytest\nfrom utils.config_utils import set_device_env_variable, unset_device_env_variable\n\nfrom lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline\n\nSESSION_LEN = 198000\nSESSION_LEN_128K = 128000\nSESSION_LEN_32K = 32000\n\nSESSION_LEN_CONFIG = {\n    'Qwen/Qwen2.5-7B-Instruct': SESSION_LEN_32K,\n    'Qwen/Qwen3-235B-A22B': SESSION_LEN_128K,\n    'Qwen/Qwen3-30B-A3B': SESSION_LEN_128K,\n    'Qwen/Qwen3-32B': SESSION_LEN_128K,\n    'meta-llama/Meta-Llama-3-1-8B-Instruct': SESSION_LEN_128K,\n    'meta-llama/Meta-Llama-3-1-70B-Instruct': SESSION_LEN_128K,\n}\n\n\ndef run_case_in_spawn(target, args):\n    ctx = mp.get_context('spawn')\n    process = ctx.Process(target=target, args=args)\n    process.start()\n    process.join()\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('model', ['Qwen/Qwen3-8B'])\ndef test_history_issue_tp1(config, model, worker_id):\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id)\n    run_case_in_spawn(stream_infer_worker, (config, model, 1))\n    if 'gw' in worker_id:\n        unset_device_env_variable()\n\n\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('model', ['Qwen/Qwen3-32B', 'Qwen/Qwen3-32B-inner-4bits', 'Qwen/Qwen3-30B-A3B'])\ndef test_history_issue_tp2(config, model, worker_id):\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id, parallel_config=2)\n        os.environ['MASTER_PORT'] = str(int(worker_id.replace('gw', '')) + 29500)\n    run_case_in_spawn(stream_infer_worker, (config, model, 2))\n    if 'gw' in worker_id:\n        unset_device_env_variable()\n\n\ndef stream_infer_worker(config, model, tp_num):\n    model_path = os.path.join(config.get('model_path'), model)\n\n    backend_config = TurbomindEngineConfig(session_len=SESSION_LEN, tp=tp_num)\n    pipe = pipeline(model_path, backend_config=backend_config)\n    prompt = '今 天 心 ' * int(SESSION_LEN / 6)\n\n  
  gen_config = GenerationConfig(top_k=40)\n    # stream infer\n    for outputs in pipe.stream_infer(prompt, gen_config=gen_config):\n        continue\n    print(outputs)\n\n    prompts = ['今 天 心 ' * int(SESSION_LEN / 6)] * 2\n    # stream infer\n    for outputs in pipe.stream_infer(prompts, gen_config=gen_config):\n        continue\n    print(outputs)\n\n    pipe.close()\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('model', ['Qwen/Qwen2.5-7B-Instruct', 'meta-llama/Meta-Llama-3-1-8B-Instruct'])\n@pytest.mark.parametrize('backend', ['turbomind', 'pytorch'])\ndef test_long_test_passkey_tp1(config, model, backend, worker_id):\n    log_name = ''.join(['pipeline_longtext_passkey_', worker_id, '.log'])\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id)\n    run_case_in_spawn(passkey_retrival_worker,\n                      (config, model, backend, log_name, 1, SESSION_LEN_CONFIG.get(model, SESSION_LEN_128K)))\n    if 'gw' in worker_id:\n        unset_device_env_variable()\n\n\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('model', ['Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-32B'])\n@pytest.mark.parametrize('backend', ['turbomind', 'pytorch'])\ndef test_long_test_passkey_tp2(config, model, backend, worker_id):\n    log_name = ''.join(['pipeline_longtext_passkey_', worker_id, '.log'])\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id, parallel_config=2)\n        os.environ['MASTER_PORT'] = str(int(worker_id.replace('gw', '')) + 29500)\n    run_case_in_spawn(passkey_retrival_worker,\n                      (config, model, backend, log_name, 2, SESSION_LEN_CONFIG.get(model, SESSION_LEN_128K)))\n    if 'gw' in worker_id:\n        unset_device_env_variable()\n\n\n@pytest.mark.gpu_num_8\n@pytest.mark.parametrize('model', ['Qwen/Qwen3-235B-A22B', 'meta-llama/Meta-Llama-3-1-70B-Instruct'])\n@pytest.mark.parametrize('backend', ['turbomind', 'pytorch'])\ndef test_long_test_passkey_tp8(config, model, backend, worker_id):\n    log_name = 
''.join(['pipeline_longtext_passkey_', worker_id, '.log'])\n    if 'gw' in worker_id:\n        set_device_env_variable(worker_id, parallel_config=8)\n        os.environ['MASTER_PORT'] = str(int(worker_id.replace('gw', '')) + 29500)\n    run_case_in_spawn(passkey_retrival_worker,\n                      (config, model, backend, log_name, 8, SESSION_LEN_CONFIG.get(model, SESSION_LEN_128K)))\n    if 'gw' in worker_id:\n        unset_device_env_variable()\n\n\nYARN_CONFIG = {'rope_scaling': {'rope_type': 'yarn', 'factor': 4.0, 'original_max_position_embeddings': 32768}}\n\nNTK_CONFIG = {\n    'rope_scaling': {\n        'type': 'dynamic',\n        'factor': 2.0\n    },\n}\n\n\ndef passkey_retrival_worker(config, model, backend, log_name, tp_num, session_len: int = SESSION_LEN_128K):\n    model_path = '/'.join([config.get('model_path'), model])\n    if backend == 'turbomind':\n        if 'qwen' in model.lower():\n            backend_config = TurbomindEngineConfig(session_len=session_len,\n                                                   max_batch_size=1,\n                                                   cache_max_entry_count=0.7,\n                                                   tp=tp_num,\n                                                   hf_overrides=YARN_CONFIG)\n        elif 'intern-s1' in model.lower():\n            backend_config = TurbomindEngineConfig(session_len=session_len,\n                                                   max_batch_size=1,\n                                                   cache_max_entry_count=0.7,\n                                                   tp=tp_num,\n                                                   hf_overrides={'text_config': NTK_CONFIG})\n        else:\n            backend_config = TurbomindEngineConfig(session_len=session_len,\n                                                   max_batch_size=1,\n                                                   cache_max_entry_count=0.7,\n                                             
      tp=tp_num)\n    else:\n        if 'qwen' in model.lower():\n            backend_config = PytorchEngineConfig(session_len=session_len,\n                                                 tp=tp_num,\n                                                 max_batch_size=1,\n                                                 hf_overrides=YARN_CONFIG)\n        elif 'intern-s1' in model.lower():\n            backend_config = PytorchEngineConfig(session_len=session_len,\n                                                 max_batch_size=1,\n                                                 cache_max_entry_count=0.7,\n                                                 tp=tp_num,\n                                                 hf_overrides={'text_config': NTK_CONFIG})\n        else:\n            backend_config = PytorchEngineConfig(session_len=session_len, tp=tp_num, max_batch_size=1)\n\n    pipe = pipeline(model_path, backend_config=backend_config)\n\n    gen_config = GenerationConfig(top_k=40)\n    # inference\n    pass_key1, prompt = get_passkey_prompt(pipe, session_len)\n    response1 = pipe(prompt, gen_config=gen_config)\n\n    # inference\n    pass_key2, prompt = get_passkey_prompt(pipe, session_len)\n    response2 = pipe([prompt] * 2, gen_config=gen_config)\n\n    pipe.close()\n\n    assert str(pass_key1) in response1.text, str(response1)\n    assert str(pass_key2) in response2[0].text and str(pass_key2) in response2[1].text, str(response2)\n\n\ndef get_passkey_prompt(pipe, session_len):\n    # create long context input\n    tok = pipe.tokenizer\n    task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.'  # noqa: E501\n    garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.'
  # noqa: E501\n\n    n_times = (session_len - 1000) // len(tok.encode(garbage))\n    n_garbage_prefix = np.random.randint(0, n_times)\n    n_garbage_suffix = n_times - n_garbage_prefix\n    garbage_prefix = ' '.join([garbage] * n_garbage_prefix)\n    garbage_suffix = ' '.join([garbage] * n_garbage_suffix)\n    pass_key = np.random.randint(1, 50000)\n    information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.'  # noqa: E501\n    final_question = 'What is the pass key? The pass key is'\n    lines = [\n        task_description,\n        garbage_prefix,\n        information_line,\n        garbage_suffix,\n        final_question,\n    ]\n\n    # inference\n    prompt = ' '.join(lines)\n    return pass_key, prompt\n"
  },
  {
    "path": "autotest/interface/restful/test_restful_chat_completions_v1.py",
    "content": "from typing import Literal\n\nimport pytest\nfrom openai import OpenAI\nfrom utils.constant import BACKEND_LIST, RESTFUL_MODEL_LIST\nfrom utils.restful_return_check import (assert_chat_completions_batch_return, assert_chat_completions_stream_return,\n                                        has_repeated_fragment)\n\nfrom lmdeploy.serve.openai.api_client import APIClient, get_model_list\n\nBASE_HTTP_URL = 'http://localhost'\nDEFAULT_PORT = 23333\nMODEL = 'internlm/Intern-S1'\nBASE_URL = ':'.join([BASE_HTTP_URL, str(DEFAULT_PORT)])\n\n\n@pytest.mark.order(8)\n@pytest.mark.chat\n@pytest.mark.flaky(reruns=2)\n@pytest.mark.parametrize('backend', BACKEND_LIST)\n@pytest.mark.parametrize('model_case', RESTFUL_MODEL_LIST)\nclass TestRestfulInterfaceBase:\n\n    @pytest.mark.interns1\n    def test_get_model(self, config, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        assert model_name == '/'.join([config.get('model_path'), MODEL]), api_client.available_models\n\n        model_list = get_model_list(BASE_URL + '/v1/models')\n        assert model_name in model_list, model_list\n\n    @pytest.mark.interns1\n    def test_encode_s1(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        input_ids1, length1 = api_client.encode('Hi, pls intro yourself')\n        input_ids2, length2 = api_client.encode('Hi, pls intro yourself', add_bos=False)\n        input_ids3, length3 = api_client.encode('Hi, pls intro yourself', do_preprocess=True)\n        input_ids4, length4 = api_client.encode('Hi, pls intro yourself', do_preprocess=True, add_bos=False)\n        input_ids5, length5 = api_client.encode('Hi, pls intro yourself' * 100, add_bos=False)\n\n        assert len(input_ids1) == length1 and length1 > 0\n        assert len(input_ids2) == length2 and length2 > 0\n        assert len(input_ids3) == length3 and length3 > 0\n        assert len(input_ids4) == length4 and length4 > 
0\n        assert len(input_ids5) == length5 and length5 > 0\n        assert length1 == length2\n        assert input_ids2 == input_ids1\n        assert input_ids1[0] == 13048 and input_ids3[0] == 151644\n        assert length5 == length2 * 100\n        assert input_ids5 == input_ids2 * 100\n\n    @pytest.mark.internlm2_5\n    def test_encode(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        input_ids1, length1 = api_client.encode('Hi, pls intro yourself')\n        input_ids2, length2 = api_client.encode('Hi, pls intro yourself', add_bos=False)\n        input_ids3, length3 = api_client.encode('Hi, pls intro yourself', do_preprocess=True)\n        input_ids4, length4 = api_client.encode('Hi, pls intro yourself', do_preprocess=True, add_bos=False)\n        input_ids5, length5 = api_client.encode('Hi, pls intro yourself' * 100, add_bos=False)\n\n        assert len(input_ids1) == length1 and length1 > 0\n        assert len(input_ids2) == length2 and length2 > 0\n        assert len(input_ids3) == length3 and length3 > 0\n        assert len(input_ids4) == length4 and length4 > 0\n        assert len(input_ids5) == length5 and length5 > 0\n        assert length1 == length2 + 1\n        assert input_ids2 == input_ids1[1:]\n        assert input_ids1[0] == 1 and input_ids3[0] == 1\n        assert length5 == length2 * 100\n        assert input_ids5 == input_ids2 * 100\n\n\n@pytest.mark.order(8)\n@pytest.mark.flaky(reruns=2)\n@pytest.mark.parametrize('backend', BACKEND_LIST)\n@pytest.mark.parametrize('model_case', RESTFUL_MODEL_LIST)\nclass TestRestfulInterfaceChatCompletions:\n\n    def test_return_info_with_prompt(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                   
                                          'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                                     temperature=0.01):\n            continue\n        assert_chat_completions_batch_return(output, model_name)\n\n    def test_return_info_with_messegae(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[{\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself'\n                                                     }],\n                                                     temperature=0.01):\n            continue\n        assert_chat_completions_batch_return(output, model_name)\n\n    def test_return_info_with_prompt_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                                     stream=True,\n                                                     temperature=0.01):\n            outputList.append(output)\n\n    
    assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n\n    def test_return_info_with_messegae_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[{\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself'\n                                                     }],\n                                                     stream=True,\n                                                     temperature=0.01):\n            outputList.append(output)\n\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n\n    def test_single_stopword(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Shanghai is'\n                                                         },\n                                                     ],\n                                                     stop=' is',\n                                                     temperature=0.01):\n            continue\n        
assert_chat_completions_batch_return(output, model_name)\n        assert ' is' not in output.get('choices')[0].get('message').get('content')\n        assert output.get('choices')[0].get('finish_reason') == 'stop'\n\n    def test_single_stopword_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Shanghai is'\n                                                         },\n                                                     ],\n                                                     stop=' is',\n                                                     stream=True,\n                                                     temperature=0.01):\n            outputList.append(output)\n\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            assert ' is' not in outputList[index].get('choices')[0].get('delta').get('content')\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'stop'\n\n    def test_array_stopwords(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                              
               'content': 'Shanghai is'\n                                                         },\n                                                     ],\n                                                     stop=[' is', '上海', ' to'],\n                                                     temperature=0.01):\n            continue\n        assert_chat_completions_batch_return(output, model_name)\n        assert ' is' not in output.get('choices')[0].get('message').get('content')\n        assert ' 上海' not in output.get('choices')[0].get('message').get('content')\n        assert ' to ' not in output.get('choices')[0].get('message').get('content')\n        assert output.get('choices')[0].get('finish_reason') == 'stop'\n\n    def test_array_stopwords_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Shanghai is'\n                                                         },\n                                                     ],\n                                                     stop=[' is', '上海', ' to'],\n                                                     stream=True,\n                                                     temperature=0.01):\n            outputList.append(output)\n\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            assert ' is' not in outputList[index].get('choices')[0].get('delta').get('content')\n            assert '上海' not 
in outputList[index].get('choices')[0].get('delta').get('content')\n            assert ' to ' not in outputList[index].get('choices')[0].get('delta').get('content')\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'stop'\n\n    @pytest.mark.internlm2_5\n    def test_special_words(self, backend, model_case):\n        message = '<|im_start|>system\\n当开启工具以及代码时，根据需求选择合适的工具进行调用\\n' + \\\n                '<|im_end|><|im_start|>system name=<|interpreter|>\\n你现在已经' + \\\n                '能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。当你向 python ' + \\\n                '发送含有 Python >代码的消息时，它将在该环境中执行。这个工具适用于多种场景，' + \\\n                '如数据分析或处理（包括数据操作、统计分析、图表绘制），复杂的计算问题（解决数学和物理' + \\\n                '难题），编程示例（理解编程概念或特性），文本处理和分析（比如文本解析和自然语言处理），机器学习和数据科学（用于' + \\\n                '展示模型训练和数据可视化），以及文件操作和数据导入（处理CSV、JSON等格式的文件）。<|im_end|>\\n' + \\\n                '<|im_start|>user\\n设 $L$ 为圆周$x^2+y^2=2x$，计算曲线积分：$I=\\\\int_L' + \\\n                '{x\\\\mathrm{d}s}=$<|im_end|>\\n<|im_start|>assistant'\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=message,\n                                                     skip_special_tokens=False,\n                                                     temperature=0.01):\n            continue\n        assert_chat_completions_batch_return(output, model_name)\n        assert '<|action_start|><|interpreter|>' in output.get('choices')[0].get('message').get('content')\n\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=message,\n                                                     skip_special_tokens=True,\n                                                     temperature=0.01):\n            continue\n        assert_chat_completions_batch_return(output, model_name)\n    
    assert '<|action_start|><|interpreter|>' not in output.get('choices')[0].get('message').get('content')\n\n    def test_minimum_repetition_penalty(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Shanghai is'\n                                                         },\n                                                     ],\n                                                     repetition_penalty=0.0000001,\n                                                     temperature=0.01,\n                                                     max_tokens=200,\n                                                     min_new_tokens=100):\n            continue\n        assert_chat_completions_batch_return(output, model_name)\n        result, msg = has_repeated_fragment(output.get('choices')[0].get('message').get('content'))\n        assert result, msg\n\n    def test_minimum_repetition_penalty_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                    
                 stream=True,\n                                                     repetition_penalty=0.0000001,\n                                                     temperature=0.01,\n                                                     max_tokens=200,\n                                                     min_new_tokens=100):\n            outputList.append(output)\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        response = ''\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            response += outputList[index].get('choices')[0].get('delta').get('content')\n        result, msg = has_repeated_fragment(response)\n        assert result, msg\n\n    def test_repetition_penalty_bigger_than_1(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Shanghai is'\n                                                         },\n                                                     ],\n                                                     repetition_penalty=1.2,\n                                                     temperature=0.01,\n                                                     max_tokens=200):\n            continue\n        assert_chat_completions_batch_return(output, model_name)\n\n    def test_repetition_penalty_bigger_than_1_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in 
api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                                     stream=True,\n                                                     repetition_penalty=1.2,\n                                                     temperature=0.01,\n                                                     max_tokens=200):\n            outputList.append(output)\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            continue\n\n    def test_minimum_topp(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for i in range(3):\n            for output in api_client.chat_completions_v1(model=model_name,\n                                                         messages=[\n                                                             {\n                                                                 'role': 'user',\n                                                                 'content': 'Shanghai is'\n                                                             },\n                                                         ],\n                                                         top_p=0.0000000001,\n                                                         max_tokens=10):\n                outputList.append(output)\n            assert_chat_completions_batch_return(output, 
model_name)\n        assert outputList[0].get('choices')[0].get('message').get('content') == outputList[1].get('choices')[0].get(\n            'message').get('content')\n        assert outputList[1].get('choices')[0].get('message').get('content') == outputList[2].get('choices')[0].get(\n            'message').get('content')\n\n    def test_minimum_topp_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        responseList = []\n        for i in range(3):\n            outputList = []\n            response = ''\n            for output in api_client.chat_completions_v1(model=model_name,\n                                                         messages=[\n                                                             {\n                                                                 'role': 'user',\n                                                                 'content': 'Hi, pls intro yourself'\n                                                             },\n                                                         ],\n                                                         stream=True,\n                                                         top_p=0.0000000001,\n                                                         max_tokens=10):\n                outputList.append(output)\n            assert_chat_completions_stream_return(outputList[-1], model_name, True)\n            response = ''\n            for index in range(0, len(outputList) - 1):\n                assert_chat_completions_stream_return(outputList[index], model_name)\n                response += outputList[index].get('choices')[0].get('delta').get('content')\n            responseList.append(response)\n        assert responseList[0] == responseList[1] or responseList[1] == responseList[2]\n\n    def test_mistake_modelname_return(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        for output in 
api_client.chat_completions_v1(model='error',\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                                     temperature=0.01):\n            continue\n        assert output.get('code') == 404\n        assert output.get('message') == 'The model \\'error\\' does not exist.'\n        assert output.get('object') == 'error'\n\n    def test_mistake_modelname_return_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        outputList = []\n        for output in api_client.chat_completions_v1(model='error',\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                                     stream=True,\n                                                     max_tokens=5,\n                                                     temperature=0.01):\n            outputList.append(output)\n        assert output.get('code') == 404\n        assert output.get('message') == 'The model \\'error\\' does not exist.'\n        assert output.get('object') == 'error'\n        assert len(outputList) == 1\n\n    def test_mutilple_times_response_should_not_same(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = 
api_client.available_models[0]\n        outputList = []\n        for i in range(3):\n            for output in api_client.chat_completions_v1(model=model_name,\n                                                         messages=[\n                                                             {\n                                                                 'role': 'user',\n                                                                 'content': 'Shanghai is',\n                                                             },\n                                                         ],\n                                                         max_tokens=100):\n                outputList.append(output)\n            assert_chat_completions_batch_return(output, model_name)\n        assert outputList[0].get('choices')[0].get('message').get('content') != outputList[1].get('choices')[0].get(\n            'message').get('content') or outputList[1].get('choices')[0].get('message').get(\n                'content') != outputList[2].get('choices')[0].get('message').get('content')\n\n    def test_mutilple_times_response_should_not_same_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        responseList = []\n        for i in range(3):\n            outputList = []\n            for output in api_client.chat_completions_v1(model=model_name,\n                                                         messages=[\n                                                             {\n                                                                 'role': 'user',\n                                                                 'content': 'Shanghai is',\n                                                             },\n                                                         ],\n                                                         stream=True,\n                                                         
max_tokens=100):\n                outputList.append(output)\n            assert_chat_completions_stream_return(outputList[-1], model_name, True)\n            response = ''\n            for index in range(0, len(outputList) - 1):\n                assert_chat_completions_stream_return(outputList[index], model_name)\n                response += outputList[index].get('choices')[0].get('delta').get('content')\n            responseList.append(response)\n        assert responseList[0] != responseList[1] or responseList[1] != responseList[2]\n\n    def test_longtext_input(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself' * 100000,\n                                                         },\n                                                     ],\n                                                     temperature=0.01):\n            continue\n        assert output.get('choices')[0].get('finish_reason') == 'length'\n        assert output.get('choices')[0].get('message').get('content') == ''\n\n    def test_longtext_input_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro 
yourself' * 100000,\n                                                         },\n                                                     ],\n                                                     stream=True,\n                                                     temperature=0.01):\n            outputList.append(output)\n        assert_chat_completions_stream_return(outputList[0], model_name, is_last=True)\n        assert outputList[0].get('choices')[0].get('finish_reason') == 'length'\n        assert outputList[0].get('choices')[0].get('delta').get('content') == ''\n        assert len(outputList) == 1\n\n    def test_ignore_eos(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, what is your name?'\n                                                         },\n                                                     ],\n                                                     ignore_eos=True,\n                                                     max_tokens=100,\n                                                     temperature=0.01):\n            continue\n        assert_chat_completions_batch_return(output, model_name)\n        assert output.get('usage').get('completion_tokens') == 101 or output.get('usage').get(\n            'completion_tokens') == 100\n        assert output.get('choices')[0].get('finish_reason') == 'length'\n\n    def test_ignore_eos_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in 
api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, what is your name?'\n                                                         },\n                                                     ],\n                                                     ignore_eos=True,\n                                                     stream=True,\n                                                     max_tokens=100,\n                                                     temperature=0.01):\n            outputList.append(output)\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        response = ''\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            response += outputList[index].get('choices')[0].get('delta').get('content')\n        length = api_client.encode(response, add_bos=False)[1]\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'length'\n        assert length >= 99 and length <= 101\n\n    def __test_max_tokens_or_max_completion_tokens(\n        self,\n        max_tokens_or_max_completion_tokens: Literal['max_tokens', 'max_completion_tokens'],\n    ):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        if max_tokens_or_max_completion_tokens == 'max_tokens':\n            for output in api_client.chat_completions_v1(\n                    model=model_name,\n                    messages=[\n                        {\n                            'role': 'user',\n                            'content': 'Hi, pls intro yourself'\n                        },\n                    ],\n                    max_tokens=5,\n    
                temperature=0.01,\n            ):\n                continue\n        else:\n            for output in api_client.chat_completions_v1(\n                    model=model_name,\n                    messages=[\n                        {\n                            'role': 'user',\n                            'content': 'Hi, pls intro yourself'\n                        },\n                    ],\n                    max_completion_tokens=5,\n                    temperature=0.01,\n            ):\n                continue\n        assert_chat_completions_batch_return(output, model_name)\n        assert output.get('choices')[0].get('finish_reason') == 'length'\n        assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5\n\n    def test_max_tokens(self, backend, model_case):\n        self.__test_max_tokens_or_max_completion_tokens('max_tokens')\n\n    def test_max_completion_tokens(self, backend, model_case):\n        self.__test_max_tokens_or_max_completion_tokens('max_completion_tokens')\n\n    def __test_max_tokens_streaming_or_max_completion_tokens_streaming(\n        self,\n        max_tokens_or_max_completion_tokens: Literal['max_tokens', 'max_completion_tokens'],\n    ):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        if max_tokens_or_max_completion_tokens == 'max_tokens':\n            for output in api_client.chat_completions_v1(\n                    model=model_name,\n                    messages=[\n                        {\n                            'role': 'user',\n                            'content': 'Hi, pls intro yourself'\n                        },\n                    ],\n                    stream=True,\n                    max_tokens=5,\n                    temperature=0.01,\n            ):\n                outputList.append(output)\n        else:\n            for output in 
api_client.chat_completions_v1(\n                    model=model_name,\n                    messages=[\n                        {\n                            'role': 'user',\n                            'content': 'Hi, pls intro yourself'\n                        },\n                    ],\n                    stream=True,\n                    max_completion_tokens=5,\n                    temperature=0.01,\n            ):\n                outputList.append(output)\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        response = ''\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            response += outputList[index].get('choices')[0].get('delta').get('content')\n        length = api_client.encode(response, add_bos=False)[1]\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'length'\n        assert length == 5 or length == 6\n\n    def test_max_tokens_streaming(self, backend, model_case):\n        self.__test_max_tokens_streaming_or_max_completion_tokens_streaming('max_tokens')\n\n    def test_max_completion_tokens_streaming(self, backend, model_case):\n        self.__test_max_tokens_streaming_or_max_completion_tokens_streaming('max_completion_tokens')\n\n    @pytest.mark.not_pytorch\n    def test_logprobs(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n         
                                            max_tokens=5,\n                                                     temperature=0.01,\n                                                     logprobs=True,\n                                                     top_logprobs=10):\n            continue\n        assert_chat_completions_batch_return(output, model_name, check_logprobs=True, logprobs_num=10)\n        assert output.get('choices')[0].get('finish_reason') == 'length'\n        assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5\n\n    @pytest.mark.not_pytorch\n    def test_logprobs_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.chat_completions_v1(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                                     stream=True,\n                                                     max_tokens=5,\n                                                     temperature=0.01,\n                                                     logprobs=True,\n                                                     top_logprobs=10):\n            outputList.append(output)\n        assert_chat_completions_stream_return(outputList[-1], model_name, True, check_logprobs=True, logprobs_num=10)\n        response = ''\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name, check_logprobs=True, logprobs_num=10)\n   
         response += outputList[index].get('choices')[0].get('delta').get('content')\n        length = api_client.encode(response, add_bos=False)[1]\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'length'\n        assert length == 5 or length == 6\n\n\n@pytest.mark.order(8)\n@pytest.mark.flaky(reruns=2)\n@pytest.mark.parametrize('backend', BACKEND_LIST)\n@pytest.mark.parametrize('model_case', RESTFUL_MODEL_LIST)\nclass TestRestfulOpenAI:\n\n    @pytest.mark.pr_test\n    def test_return_info(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself'\n                                                     },\n                                                 ],\n                                                 temperature=0.01)\n\n        output = outputs.model_dump()\n        assert_chat_completions_batch_return(output, model_name)\n\n    @pytest.mark.pr_test\n    def test_return_info_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself'\n                                                     },\n                                                 
],\n                                                 temperature=0.01,\n                                                 stream=True)\n\n        outputList = []\n        for output in outputs:\n            outputList.append(output.model_dump())\n\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n\n    def test_single_stopword(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Shanghai is'\n                                                     },\n                                                 ],\n                                                 temperature=0.01,\n                                                 stop=' is')\n\n        output = outputs.model_dump()\n        assert_chat_completions_batch_return(output, model_name)\n        assert ' is' not in output.get('choices')[0].get('message').get('content')\n        assert output.get('choices')[0].get('finish_reason') == 'stop'\n\n    @pytest.mark.pr_test\n    def test_single_stopword_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                
                                         'content': 'Shanghai is'\n                                                     },\n                                                 ],\n                                                 stop=' is',\n                                                 temperature=0.01,\n                                                 stream=True)\n\n        outputList = []\n        for output in outputs:\n            outputList.append(output.model_dump())\n\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            assert ' is ' not in outputList[index].get('choices')[0].get('delta').get('content')\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'stop'\n\n    def test_array_stopwords(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputs = client.chat.completions.create(\n            model=model_name,\n            messages=[\n                {\n                    'role': 'user',\n                    'content': 'Shanghai is'\n                },\n            ],\n            temperature=0.01,\n            stop=[' is', '上海', ' to'],\n        )\n\n        output = outputs.model_dump()\n        assert_chat_completions_batch_return(output, model_name)\n        assert ' is' not in output.get('choices')[0].get('message').get('content')\n        assert ' 上海' not in output.get('choices')[0].get('message').get('content')\n        assert ' to' not in output.get('choices')[0].get('message').get('content')\n        assert output.get('choices')[0].get('finish_reason') == 'stop'\n\n    def test_array_stopwords_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = 
client.models.list().data[0].id\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Shanghai is'\n                                                     },\n                                                 ],\n                                                 stop=[' is', '上海', ' to'],\n                                                 temperature=0.01,\n                                                 stream=True)\n\n        outputList = []\n        for output in outputs:\n            outputList.append(output.model_dump())\n\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            assert ' is' not in outputList[index].get('choices')[0].get('delta').get('content')\n            assert '上海' not in outputList[index].get('choices')[0].get('delta').get('content')\n            assert ' to ' not in outputList[index].get('choices')[0].get('delta').get('content')\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'stop'\n\n    @pytest.mark.pr_test\n    def test_minimum_topp(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputList = []\n        for i in range(3):\n            outputs = client.chat.completions.create(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             
'content': 'Shanghai is'\n                                                         },\n                                                     ],\n                                                     temperature=0.01,\n                                                     top_p=0.0000000001,\n                                                     max_tokens=10)\n            output = outputs.model_dump()\n            outputList.append(output)\n            assert_chat_completions_batch_return(output, model_name)\n        assert outputList[0].get('choices')[0].get('message').get('content') == outputList[1].get('choices')[0].get(\n            'message').get('content')\n        assert outputList[1].get('choices')[0].get('message').get('content') == outputList[2].get('choices')[0].get(\n            'message').get('content')\n\n    def test_minimum_topp_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        responseList = []\n        for i in range(3):\n            outputs = client.chat.completions.create(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                                     top_p=0.0000000001,\n                                                     max_tokens=10,\n                                                     stream=True)\n\n            outputList = []\n            for output in outputs:\n                outputList.append(output.model_dump())\n            assert_chat_completions_stream_return(outputList[-1], model_name, True)\n            
response = ''\n            for index in range(0, len(outputList) - 1):\n                assert_chat_completions_stream_return(outputList[index], model_name)\n                response += outputList[index].get('choices')[0].get('delta').get('content')\n            responseList.append(response)\n        assert responseList[0] == responseList[1] or responseList[1] == responseList[2]\n\n    @pytest.mark.pr_test\n    def test_mistake_modelname_return(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        with pytest.raises(Exception, match='The model \\'error\\' does not exist.'):\n            client.chat.completions.create(\n                model='error',\n                messages=[\n                    {\n                        'role': 'user',\n                        'content': 'Shanghai is'\n                    },\n                ],\n                temperature=0.01,\n                stop=[' is', '上海', ' to'],\n            )\n\n    def test_mistake_modelname_return_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n\n        with pytest.raises(Exception, match='The model \\'error\\' does not exist.'):\n            client.chat.completions.create(model='error',\n                                           messages=[\n                                               {\n                                                   'role': 'user',\n                                                   'content': 'Hi, pls intro yourself'\n                                               },\n                                           ],\n                                           max_tokens=5,\n                                           temperature=0.01,\n                                           stream=True)\n\n    @pytest.mark.pr_test\n    def test_mutilple_times_response_should_not_same(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', 
base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputList = []\n        for i in range(3):\n            outputs = client.chat.completions.create(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Shanghai is'\n                                                         },\n                                                     ],\n                                                     max_tokens=100)\n            output = outputs.model_dump()\n            outputList.append(output)\n            assert_chat_completions_batch_return(output, model_name)\n        assert outputList[0].get('choices')[0].get('message').get('content') != outputList[1].get('choices')[0].get(\n            'message').get('content') or outputList[1].get('choices')[0].get('message').get(\n                'content') != outputList[2].get('choices')[0].get('message').get('content')\n\n    def test_mutilple_times_response_should_not_same_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        responseList = []\n        for i in range(3):\n            outputs = client.chat.completions.create(model=model_name,\n                                                     messages=[\n                                                         {\n                                                             'role': 'user',\n                                                             'content': 'Hi, pls intro yourself'\n                                                         },\n                                                     ],\n                                                     max_tokens=100,\n                  
                                   stream=True)\n\n            outputList = []\n            for output in outputs:\n                outputList.append(output.model_dump())\n            assert_chat_completions_stream_return(outputList[-1], model_name, True)\n            response = ''\n            for index in range(0, len(outputList) - 1):\n                assert_chat_completions_stream_return(outputList[index], model_name)\n                response += outputList[index].get('choices')[0].get('delta').get('content')\n            responseList.append(response)\n        assert responseList[0] != responseList[1] or responseList[1] != responseList[2]\n\n    def test_longtext_input(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself' * 100000\n                                                     },\n                                                 ],\n                                                 max_tokens=100)\n        output = outputs.model_dump()\n        print(output)\n        assert output.get('choices')[0].get('finish_reason') == 'error'\n        assert output.get('choices')[0].get('message').get(\n            'content') == 'internal error happened, status code ResponseType.INPUT_LENGTH_ERROR'\n\n    @pytest.mark.pr_test\n    def test_longtext_input_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n\n        outputs = client.chat.completions.create(model=model_name,\n                                
                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself' * 100000\n                                                     },\n                                                 ],\n                                                 max_tokens=100,\n                                                 stream=True)\n\n        outputList = []\n        for output in outputs:\n            outputList.append(output.model_dump())\n\n        assert_chat_completions_stream_return(outputList[0], model_name, is_last=True)\n        assert outputList[0].get('choices')[0].get('finish_reason') == 'error'\n        assert outputList[0].get('choices')[0].get('delta').get(\n            'content') == 'internal error happened, status code ResponseType.INPUT_LENGTH_ERROR'\n        assert len(outputList) == 1\n\n    @pytest.mark.pr_test\n    def test_max_tokens(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself'\n                                                     },\n                                                 ],\n                                                 max_tokens=5,\n                                                 temperature=0.01)\n        output = outputs.model_dump()\n        assert_chat_completions_batch_return(output, model_name)\n        assert output.get('choices')[0].get('finish_reason') == 'length'\n        assert 
output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5\n\n    def test_max_tokens_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself'\n                                                     },\n                                                 ],\n                                                 max_tokens=5,\n                                                 temperature=0.01,\n                                                 stream=True)\n\n        outputList = []\n        for output in outputs:\n            outputList.append(output.model_dump())\n\n        assert_chat_completions_stream_return(outputList[-1], model_name, True)\n        response = ''\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name)\n            response += outputList[index].get('choices')[0].get('delta').get('content')\n        api_client = APIClient(BASE_URL)\n        length = api_client.encode(response, add_bos=False)[1]\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'length'\n        assert length == 5 or length == 6\n\n    @pytest.mark.not_pytorch\n    @pytest.mark.pr_test\n    def test_logprobs(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n 
                                                    {\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself'\n                                                     },\n                                                 ],\n                                                 max_tokens=5,\n                                                 temperature=0.01,\n                                                 logprobs=True,\n                                                 top_logprobs=10)\n        output = outputs.model_dump()\n        assert_chat_completions_batch_return(output, model_name, check_logprobs=True, logprobs_num=10)\n        assert output.get('choices')[0].get('finish_reason') == 'length'\n        assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5\n\n    @pytest.mark.not_pytorch\n    @pytest.mark.pr_test\n    def test_logprobs_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=[\n                                                     {\n                                                         'role': 'user',\n                                                         'content': 'Hi, pls intro yourself'\n                                                     },\n                                                 ],\n                                                 max_tokens=5,\n                                                 temperature=0.01,\n                                                 logprobs=True,\n                                                 top_logprobs=10,\n                                                 stream=True)\n\n        outputList = 
[]\n        for output in outputs:\n            outputList.append(output.model_dump())\n\n        assert_chat_completions_stream_return(outputList[-1], model_name, True, check_logprobs=True, logprobs_num=10)\n        response = ''\n        for index in range(0, len(outputList) - 1):\n            assert_chat_completions_stream_return(outputList[index], model_name, check_logprobs=True, logprobs_num=10)\n            response += outputList[index].get('choices')[0].get('delta').get('content')\n        api_client = APIClient(BASE_URL)\n        length = api_client.encode(response, add_bos=False)[1]\n        assert outputList[-1].get('choices')[0].get('finish_reason') == 'length'\n        assert length == 5 or length == 6\n\n    def test_input_validation(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        messages = [\n            {\n                'role': 'user',\n                'content': 'Hi, pls intro yourself'\n            },\n        ]\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, top_p=0)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, top_p=1.01)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, top_p='test')\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, n=0)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, n='test')\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, temperature=-0.01)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, 
temperature=2.01)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, temperature='test')\n\n    def test_input_validation_streaming(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        messages = [\n            {\n                'role': 'user',\n                'content': 'Hi, pls intro yourself'\n            },\n        ]\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, top_p=0, stream=True)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, top_p=1.01, stream=True)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, top_p='test', stream=True)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, n=0, stream=True)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, n='test', stream=True)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, temperature=-0.01, stream=True)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, temperature=2.01, stream=True)\n\n        with pytest.raises(Exception):\n            client.chat.completions.create(model=model_name, messages=messages, temperature='test', stream=True)\n\n    @pytest.mark.interns1\n    def test_disable_think(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        output = client.chat.completions.create(model=model_name,\n           
                                     messages=[\n                                                    {\n                                                        'role': 'user',\n                                                        'content': 'Hi, pls intro yourself'\n                                                    },\n                                                ],\n                                                temperature=0.8,\n                                                top_p=0.8)\n        print(output)\n        assert '</think>' in str(output.model_dump())\n\n        output = client.chat.completions.create(model=model_name,\n                                                messages=[\n                                                    {\n                                                        'role': 'user',\n                                                        'content': 'Hi, pls intro yourself'\n                                                    },\n                                                ],\n                                                temperature=0.8,\n                                                top_p=0.8,\n                                                extra_body={\n                                                    'enable_thinking': False,\n                                                })\n\n        response = output.model_dump()\n        assert '</think>' not in str(response)\n        assert_chat_completions_batch_return(response, model_name)\n\n    @pytest.mark.interns1\n    def test_disable_think_with_image(self, backend, model_case):\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{BASE_URL}/v1')\n        model_name = client.models.list().data[0].id\n        output = client.chat.completions.create(\n            model=model_name,\n            messages=[\n                {\n                    'role':\n                    'user',\n                    'content': [{\n                        'type': 'text',\n    
                    'text': 'Describe the image please',\n                    }, {\n                        'type': 'image_url',\n                        'image_url': {\n                            'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n                        },\n                    }],\n                },\n            ],\n            temperature=0.8,\n            top_p=0.8)\n        print(output)\n        assert '</think>' in str(output.model_dump())\n\n        output = client.chat.completions.create(\n            model=model_name,\n            messages=[\n                {\n                    'role':\n                    'user',\n                    'content': [{\n                        'type': 'text',\n                        'text': 'Describe the image please',\n                    }, {\n                        'type': 'image_url',\n                        'image_url': {\n                            'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n                        },\n                    }],\n                },\n            ],\n            temperature=0.8,\n            top_p=0.8,\n            extra_body={\n                'enable_thinking': False,\n            })\n\n        response = output.model_dump()\n        # model_dump() returns a dict; `in` on a dict only checks keys, so search its string form\n        assert '</think>' not in str(response)\n        assert_chat_completions_batch_return(response, model_name)\n"
  },
  {
    "path": "autotest/interface/restful/test_restful_completions_v1.py",
    "content": "import pytest\nfrom utils.constant import BACKEND_LIST, RESTFUL_BASE_MODEL_LIST\nfrom utils.restful_return_check import assert_completions_batch_return, assert_completions_stream_return\n\nfrom lmdeploy.serve.openai.api_client import APIClient\n\nBASE_HTTP_URL = 'http://localhost'\nDEFAULT_PORT = 23333\nMODEL = 'internlm/internlm2_5-20b'\nBASE_URL = ':'.join([BASE_HTTP_URL, str(DEFAULT_PORT)])\n\n\n@pytest.mark.parametrize('backend', BACKEND_LIST)\n@pytest.mark.parametrize('model_case', RESTFUL_BASE_MODEL_LIST)\nclass TestRestfulInterfaceBase:\n\n    @pytest.mark.internlm2_5\n    def test_get_model(self, config, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        assert model_name == '/'.join([config.get('model_path'), MODEL]), api_client.available_models\n\n    @pytest.mark.internlm2_5\n    def test_encode(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        input_ids1, length1 = api_client.encode('Hi, pls intro yourself')\n        input_ids2, length2 = api_client.encode('Hi, pls intro yourself', add_bos=False)\n        input_ids3, length3 = api_client.encode('Hi, pls intro yourself', do_preprocess=True)\n        input_ids4, length4 = api_client.encode('Hi, pls intro yourself', do_preprocess=True, add_bos=False)\n        input_ids5, length5 = api_client.encode('Hi, pls intro yourself' * 100, add_bos=False)\n        assert len(input_ids1) == length1 and length1 > 0\n        assert len(input_ids2) == length2 and length2 > 0\n        assert len(input_ids3) == length3 and length3 > 0\n        assert len(input_ids4) == length4 and length4 > 0\n        assert len(input_ids5) == length5 and length5 > 0\n        assert length1 == length2 + 1\n        assert input_ids2 == input_ids1[1:]\n        assert input_ids1[0] == 1 and input_ids3[0] == 1\n        assert length5 == length2 * 100\n        assert input_ids5 == input_ids2 * 100\n\n    def test_return(self, 
backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for item in api_client.completions_v1(\n                model=model_name,\n                prompt='Hi, pls intro yourself',\n                max_tokens=16,\n                temperature=0.01,\n        ):\n            completion_tokens = item['usage']['completion_tokens']\n            assert completion_tokens > 0\n            assert completion_tokens <= 17\n            assert completion_tokens >= 16\n            assert item.get('choices')[0].get('finish_reason') in ['length']\n        assert_completions_batch_return(item, model_name)\n\n    def test_return_streaming(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for item in api_client.completions_v1(model=model_name,\n                                              prompt='Hi, pls intro yourself',\n                                              max_tokens=16,\n                                              stream=True,\n                                              temperature=0.01):\n            outputList.append(item)\n        assert_completions_stream_return(outputList[-1], model_name, True)\n        for index in range(0, len(outputList) - 1):\n            assert_completions_stream_return(outputList[index], model_name)\n\n    def test_max_tokens(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for item in api_client.completions_v1(model=model_name,\n                                              prompt='Hi, pls intro yourself',\n                                              max_tokens=16,\n                                              temperature=0.01):\n            completion_tokens = item['usage']['completion_tokens']\n            assert completion_tokens > 0\n            assert completion_tokens <= 17\n 
           assert completion_tokens >= 16\n            assert item.get('choices')[0].get('finish_reason') in ['length']\n\n    def test_single_stopword(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for item in api_client.completions_v1(model=model_name,\n                                              prompt='Shanghai is',\n                                              max_tokens=200,\n                                              stop=' Shanghai',\n                                              temperature=0.01):\n            assert ' Shanghai' not in item.get('choices')[0].get('text')\n            assert item.get('choices')[0].get('finish_reason') in ['stop', 'length']\n\n    def test_array_stopwords(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for item in api_client.completions_v1(model=model_name,\n                                              prompt='Shanghai is',\n                                              max_tokens=200,\n                                              stop=[' Shanghai', ' city', ' China'],\n                                              temperature=0.01):\n            assert ' Shanghai' not in item.get('choices')[0].get('text')\n            assert ' city' not in item.get('choices')[0].get('text')\n            assert ' China' not in item.get('choices')[0].get('text')\n            assert item.get('choices')[0].get('finish_reason') in ['stop', 'length']\n\n    def test_completions_stream(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.completions_v1(model=model_name, prompt='Shanghai is', stream='true',\n                                                temperature=0.01):\n            outputList.append(output)\n\n        for index in range(1, 
len(outputList) - 1):\n            output = outputList[index]\n            assert (output.get('model') == model_name)\n            for message in output.get('choices'):\n                assert message.get('index') == 0\n                assert len(message.get('text')) > 0\n\n        output_last = outputList[len(outputList) - 1]\n        assert output_last.get('choices')[0].get('finish_reason') in ['stop', 'length']\n\n    def test_completions_stream_stopword(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.completions_v1(model=model_name,\n                                                prompt='Beijing is',\n                                                stream='true',\n                                                stop=' is',\n                                                temperature=0.01):\n            outputList.append(output)\n\n        for index in range(1, len(outputList) - 2):\n            output = outputList[index]\n            assert (output.get('model') == model_name)\n            assert (output.get('object') == 'text_completion')\n            for message in output.get('choices'):\n                assert ' is' not in message.get('text')\n                assert message.get('index') == 0\n                assert len(message.get('text')) > 0\n\n        output_last = outputList[len(outputList) - 1]\n        assert output_last.get('choices')[0].get('text') == ''\n        assert output_last.get('choices')[0].get('finish_reason') in ['stop', 'length']\n\n    def test_completions_stream_stopwords(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        outputList = []\n        for output in api_client.completions_v1(model=model_name,\n                                                prompt='Beijing is',\n                                                
stream='true',\n                                                stop=[' Beijing', ' city', ' China'],\n                                                temperature=0.01):\n            outputList.append(output)\n\n        for index in range(1, len(outputList) - 2):\n            output = outputList[index]\n            assert (output.get('model') == model_name)\n            assert (output.get('object') == 'text_completion')\n            for message in output.get('choices'):\n                assert ' Beijing' not in message.get('text')\n                assert ' city' not in message.get('text')\n                assert ' China' not in message.get('text')\n                assert message.get('index') == 0\n                assert len(message.get('text')) > 0\n\n        output_last = outputList[len(outputList) - 1]\n        assert output_last.get('choices')[0].get('text') == ''\n        assert output_last.get('choices')[0].get('finish_reason') in ['stop', 'length']\n\n    def test_batch_prompt_order(self, backend, model_case):\n        api_client = APIClient(BASE_URL)\n        model_name = api_client.available_models[0]\n        for item in api_client.completions_v1(model=model_name,\n                                              prompt=['你好', '今天天气怎么样', '你是谁', '帮我写一首以梅花为主题的五言律诗', '5+2等于多少'],\n                                              max_tokens=400,\n                                              min_tokens=50):\n            print(str(item))\n            assert '天' in item.get('choices')[1].get('text'), item.get('choices')[1].get('text')\n            assert '梅' in item.get('choices')[3].get('text') or '对仗' in item.get('choices')[3].get('text'), item.get(\n                'choices')[3].get('text')\n            assert '7' in item.get('choices')[4].get('text'), item.get('choices')[4].get('text')\n"
  },
  {
    "path": "autotest/interface/restful/test_restful_generate.py",
    "content": "import json\nimport os\nimport re\nimport time\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nfrom datetime import datetime\nfrom typing import Any\n\nimport pytest\nimport requests\nfrom transformers import AutoTokenizer\nfrom utils.constant import BACKEND_LIST, DEFAULT_SERVER, RESTFUL_MODEL_LIST\nfrom utils.toolkit import encode_text, parse_sse_stream\n\nBASE_HTTP_URL = f'http://{DEFAULT_SERVER}'\nDEFAULT_PORT = 23333\nBASE_URL = ':'.join([BASE_HTTP_URL, str(DEFAULT_PORT)])\n\n\n@pytest.mark.parametrize('backend', BACKEND_LIST)\n@pytest.mark.parametrize('model_name', RESTFUL_MODEL_LIST)\nclass TestGenerateComprehensive:\n\n    @pytest.fixture(autouse=True)\n    def setup_api(self, request, config, model_name, backend):\n        self.api_url = f'{BASE_URL}/generate'\n        self.headers = {'Content-Type': 'application/json'}\n        self.model_name = model_name\n\n        test_name = request.node.name\n        safe_test_name = re.sub(r'[^\\w\\.-]', '_', test_name)\n        safe_model_name = self.model_name.replace('/', '_')\n        log_base = config.get('log_path', './logs')\n        self.log_dir = os.path.join(log_base, safe_model_name)\n        os.makedirs(self.log_dir, exist_ok=True)\n        self.log_file = os.path.join(self.log_dir, f'{backend}_{safe_test_name}.log')\n\n    def _log_request_response(self, payload, response_data, stream_raw=None):\n        log_entry = {\n            'timestamp': datetime.now().isoformat(),\n            'model': self.model_name,\n            'request': payload,\n            'response': response_data,\n        }\n        if stream_raw is not None:\n            log_entry['stream_raw'] = stream_raw\n\n        try:\n            with open(self.log_file, 'a', encoding='utf-8') as f:\n                json.dump(log_entry, f, indent=2, ensure_ascii=False)\n                f.write('\\n')\n        except Exception as e:\n            print(f'[LOG WARN] Failed to write {self.log_file}: {e}')\n\n    def 
_post(self, payload, stream=False):\n        if 'model' not in payload:\n            payload['model'] = self.model_name\n\n        resp = requests.post(self.api_url, json=payload, headers=self.headers, stream=stream, timeout=60)\n        resp.raise_for_status()\n\n        if stream:\n            raw_content = ''\n            for chunk in resp.iter_content(chunk_size=None):\n                if chunk:\n                    raw_content += chunk.decode('utf-8')\n\n            events = parse_sse_stream(raw_content)\n            accumulated_text = ''\n            output_ids = []\n            stream_events_count = 0\n\n            for event in events:\n                if event == '[DONE]':\n                    break\n                try:\n                    data_str = event.replace('data: ', '').strip()\n                    if not data_str:\n                        continue\n                    data = json.loads(data_str)\n                    delta = data.get('text', '')\n                    if isinstance(delta, str):\n                        accumulated_text += delta\n                    ids = data.get('output_ids')\n                    if isinstance(ids, list):\n                        output_ids.extend(ids)\n                    stream_events_count += 1\n                except Exception as e:\n                    print(f'Error parsing stream event: {e}')\n                    continue\n\n            fake_resp = {\n                'text': accumulated_text,\n                'output_ids': output_ids,\n                'meta_info': {\n                    'stream_events': stream_events_count\n                }\n            }\n            self._log_request_response(payload, fake_resp, raw_content)\n\n            class MockResp:\n\n                def json(self):\n                    return fake_resp\n\n                @property\n                def status_code(self):\n                    return 200\n\n            return MockResp()\n\n        else:\n            data = 
resp.json()\n            self._log_request_response(payload, data)\n            return resp\n\n    def _validate_generation_response(self,\n                                      data: dict[str, Any],\n                                      expected_fields: list[str] | None = None,\n                                      validate_tokens: bool = True,\n                                      expect_logprobs: bool = False,\n                                      validate_experts: bool = False) -> None:\n        assert isinstance(data, dict), f'Response should be a dict, got {type(data)}'\n\n        required_fields = ['text']\n        for field in required_fields:\n            assert field in data, f'Missing required field: {field}'\n            assert data[field] is not None, f'Field {field} should not be None'\n\n        assert isinstance(data['text'], str), \\\n            f\"text should be string, got {type(data['text'])}\"\n\n        if validate_experts:\n            assert 'routed_experts' in data[\n                'meta_info'], \"Response should contain 'routed_experts' when validate_experts=True\"\n\n            experts_data = data['meta_info']['routed_experts']\n\n            assert isinstance(experts_data, list)\n            assert len(experts_data) > 0\n\n            total_steps = len(experts_data)\n\n            for step_idx in range(total_steps):\n                token_experts = experts_data[step_idx]\n\n                assert isinstance(token_experts, list)\n                assert len(token_experts) > 0\n\n                for layer_idx in range(len(token_experts)):\n                    layer_experts = token_experts[layer_idx]\n\n                    assert isinstance(layer_experts, list)\n                    assert len(layer_experts) == 8\n\n                    for expert_idx, expert_id in enumerate(layer_experts):\n                        assert isinstance(expert_id, int)\n                        assert 0 <= expert_id < 256, f'Invalid expert_id: {expert_id}. 
Must be in [0, 256)'\n\n        if validate_tokens:\n            assert 'output_ids' in data, \"Response should contain 'output_ids'\"\n            output_ids = data['output_ids']\n\n            assert isinstance(output_ids, list), \\\n                f'output_ids should be list, got {type(output_ids)}'\n            assert len(output_ids) > 0, 'output_ids should not be empty'\n\n            for i, token_id in enumerate(output_ids):\n                assert isinstance(token_id, int), \\\n                    f'output_ids[{i}] should be int, got {type(token_id)}'\n\n            if 'meta_info' in data:\n                meta = data['meta_info']\n                assert isinstance(meta, dict), 'meta_info should be dict'\n\n                if 'completion_tokens' in meta:\n                    assert meta['completion_tokens'] == len(output_ids), \\\n                        f\"meta.completion_tokens ({meta['completion_tokens']}) \" \\\n                        f'should equal len(output_ids) ({len(output_ids)})'\n\n        if expect_logprobs:\n            assert 'meta_info' in data, \\\n                \"Response should contain 'meta_info' when expecting logprobs\"\n            meta = data['meta_info']\n            assert isinstance(meta, dict)\n\n            assert 'output_token_logprobs' in meta, \\\n                \"meta_info missing 'output_token_logprobs'\"\n            logprobs_data = meta['output_token_logprobs']\n\n            assert isinstance(logprobs_data, list), \\\n                'output_token_logprobs should be a list'\n            assert len(logprobs_data) > 0, \\\n                'output_token_logprobs should not be empty'\n\n            if 'output_ids' in data:\n                assert len(logprobs_data) == len(data['output_ids']), \\\n                    f'Logprobs outer list length ({len(logprobs_data)}) != ' \\\n                    f\"Output IDs length ({len(data['output_ids'])})\"\n\n            for idx, item in enumerate(logprobs_data):\n                
assert isinstance(item, list), \\\n                    f'Logprobs item at index {idx} should be a list, got {type(item)}'\n                assert len(item) == 2, \\\n                    f'Logprobs item at index {idx} should have 2 elements ' \\\n                    f'[logprob, token_id], got {len(item)}'\n\n                logprob_val = item[0]\n                assert isinstance(logprob_val, (float, int)), \\\n                    f'Logprob value at [{idx}][0] should be number, ' \\\n                    f'got {type(logprob_val)}'\n                assert logprob_val <= 0, \\\n                    f'Logprob value should be <= 0, got {logprob_val}'\n\n                token_id_in_logprob = item[1]\n                assert isinstance(token_id_in_logprob, int), \\\n                    f'Token ID in logprobs at [{idx}][1] should be int, ' \\\n                    f'got {type(token_id_in_logprob)}'\n\n                if 'output_ids' in data and idx < len(data['output_ids']):\n                    assert token_id_in_logprob == data['output_ids'][idx], \\\n                        f'Token ID mismatch at index {idx}: output_ids has ' \\\n                        f\"{data['output_ids'][idx]}, but logprobs has \" \\\n                        f'{token_id_in_logprob}'\n\n        if expected_fields:\n            for field in expected_fields:\n                assert field in data, f'Missing expected field: {field}'\n\n        if 'error' in data:\n            assert not data['error'], f\"Response contains error: {data['error']}\"\n        if 'code' in data and data['code'] != 0:\n            assert False, f\"Response contains error code: {data['code']}\"\n\n    def test_basic_generation(self):\n        print(f'\\n[Model: {self.model_name}] Running basic generation test')\n        test_cases = [{\n            'name': 'simple prompt',\n            'payload': {\n                'prompt': 'The sky is',\n                'max_tokens': 5\n            },\n        }, {\n            'name': 'prompt 
with spaces',\n            'payload': {\n                'prompt': '  Hello world  ',\n                'max_tokens': 3\n            },\n        }, {\n            'name': 'unicode prompt',\n            'payload': {\n                'prompt': 'Hello, world',\n                'max_tokens': 3\n            },\n        }, {\n            'name': 'longer generation',\n            'payload': {\n                'prompt': 'Once upon a time',\n                'max_tokens': 10\n            },\n        }]\n\n        for test_case in test_cases:\n            test_name = test_case['name']\n            print(f'\\n[Test: {test_name}]')\n\n            resp = self._post(test_case['payload'])\n            data = resp.json()\n\n            self._validate_generation_response(data=data, validate_tokens=True)\n\n            prompt = test_case['payload']['prompt']\n            generated_text = data['text']\n            assert generated_text != prompt.strip(), \\\n                f\"Generated text should be different from prompt: '{generated_text}'\"\n\n            if 'output_ids' in data:\n                output_ids = data['output_ids']\n                max_tokens = test_case['payload']['max_tokens']\n                max_allowed = max_tokens + 1\n\n                assert len(output_ids) <= max_allowed, \\\n                    f'Too many tokens generated: {len(output_ids)} > {max_allowed}'\n\n                meta = data.get('meta_info', {})\n                finish_type = meta.get('finish_reason', {}).get('type')\n                if len(output_ids) >= max_tokens and finish_type != 'length':\n                    print(f'[WARN] Generated {len(output_ids)} tokens but '\n                          f\"finish_reason is not 'length': {finish_type}\")\n\n            print(f\"  Generated text: '{generated_text[:50]}...'\")\n            print(f\"  Generated tokens: {len(data.get('output_ids', []))}\")\n\n    def test_input_ids_mode(self, config):\n        print(f'\\n[Model: {self.model_name}] Running 
input_ids mode test')\n        model_path = os.path.join(config.get('model_path'), self.model_name)\n\n        test_cases = [{\n            'name': 'simple text',\n            'text': 'Hello world',\n            'max_tokens': 5,\n            'expected_min_text': 3\n        }, {\n            'name': 'question',\n            'text': 'What is the meaning of life?',\n            'max_tokens': 8,\n            'expected_min_text': 5\n        }, {\n            'name': 'short input',\n            'text': 'Yes',\n            'max_tokens': 3,\n            'expected_min_text': 1\n        }]\n\n        for test_case in test_cases:\n            test_name = test_case['name']\n            print(f'\\n[Test: input_ids - {test_name}]')\n\n            try:\n                input_ids = encode_text(model_path, test_case['text'])\n            except Exception as e:\n                pytest.skip(f'Tokenizer failed for {test_name}: {e}')\n\n            assert isinstance(input_ids, list), \\\n                f'input_ids should be list, got {type(input_ids)}'\n            assert len(input_ids) > 0, 'input_ids should not be empty'\n            for i, token_id in enumerate(input_ids):\n                assert isinstance(token_id, int), \\\n                    f'input_ids[{i}] should be int, got {type(token_id)}'\n                assert token_id >= 0, \\\n                    f'input_ids[{i}] should be >= 0, got {token_id}'\n\n            resp = self._post({'input_ids': input_ids, 'max_tokens': test_case['max_tokens']})\n            data = resp.json()\n\n            self._validate_generation_response(data=data, validate_tokens=True)\n\n            generated_text = data['text']\n            try:\n                generated_text.encode('utf-8')\n            except UnicodeEncodeError:\n                pytest.fail(f'Generated text contains invalid UTF-8 characters: '\n                            f'{generated_text[:100]}')\n\n            print(f'  Input tokens: {len(input_ids)}')\n            print(f\" 
 Output tokens: {len(data.get('output_ids', []))}\")\n            print(f\"  Generated text: '{generated_text[:50]}...'\")\n\n    def test_conflict_prompt_and_input_ids(self):\n        print(f'\\n[Model: {self.model_name}] Running conflict test')\n        test_cases = [{\n            'name':\n            'both provided',\n            'payload': {\n                'prompt': 'Hello world',\n                'input_ids': [1, 2, 3, 4, 5],\n                'max_tokens': 5\n            },\n            'expected_status':\n            400,\n            'expected_error_keywords': [\n                'conflict', 'both', 'either', 'cannot', 'mutually exclusive', 'specify exactly one', 'prompt',\n                'input_ids'\n            ]\n        }, {\n            'name':\n            'prompt with empty input_ids',\n            'payload': {\n                'prompt': 'Test',\n                'input_ids': [],\n                'max_tokens': 3\n            },\n            'expected_status':\n            400,\n            'expected_error_keywords': ['conflict', 'invalid', 'empty', 'specify exactly one', 'prompt', 'input_ids']\n        }, {\n            'name':\n            'empty prompt with input_ids',\n            'payload': {\n                'prompt': '',\n                'input_ids': [100, 200, 300],\n                'max_tokens': 3\n            },\n            'expected_status':\n            400,\n            'expected_error_keywords': ['conflict', 'empty', 'invalid', 'specify exactly one', 'prompt', 'input_ids']\n        }]\n\n        for test_case in test_cases:\n            test_name = test_case['name']\n            print(f'\\n[Test: conflict - {test_name}]')\n\n            try:\n                resp = requests.post(self.api_url, json=test_case['payload'], headers=self.headers, timeout=30)\n\n                assert resp.status_code == test_case['expected_status'], \\\n                    f\"Expected status {test_case['expected_status']}, \" \\\n                    f'got 
{resp.status_code}'\n\n                error_data = resp.json()\n                assert 'error' in error_data or 'message' in error_data, \\\n                    \"Error response should contain 'error' or 'message' field\"\n\n                error_msg = ''\n                if 'error' in error_data:\n                    error_msg = str(error_data['error']).lower()\n                elif 'message' in error_data:\n                    error_msg = str(error_data['message']).lower()\n\n                keywords_found = any(keyword in error_msg for keyword in test_case['expected_error_keywords'])\n\n                if not keywords_found:\n                    has_both_fields = ('prompt' in error_msg and 'input_ids' in error_msg)\n                    has_exclusivity = any(phrase in error_msg for phrase in [\n                        'only one', 'specify exactly', 'cannot both', 'mutually exclusive', 'exactly one',\n                        'must specify'\n                    ])\n                    if has_both_fields and has_exclusivity:\n                        keywords_found = True\n\n                assert keywords_found, \\\n                    f'Error message should indicate conflict between prompt and ' \\\n                    f'input_ids, got: {error_msg}'\n\n                assert 'text' not in error_data, \\\n                    \"Error response should not contain 'text' field\"\n                assert 'output_ids' not in error_data, \\\n                    \"Error response should not contain 'output_ids' field\"\n\n                print(f'  Got expected error: {error_msg[:100]}...')\n\n            except Exception as e:\n                print(f'  Unexpected error: {e}')\n                raise\n\n    @pytest.mark.logprob\n    def test_input_ids_with_logprob(self, config):\n        print(f'\\n[Model: {self.model_name}] Running input_ids with logprob test')\n        model_path = os.path.join(config.get('model_path'), self.model_name)\n\n        test_cases = [{\n          
  'name': 'basic logprob',\n            'text': 'The weather is',\n            'max_tokens': 3,\n            'expected_min_text': 3\n        }, {\n            'name': 'single token generation',\n            'text': 'Hello',\n            'max_tokens': 1,\n            'expected_min_text': 1\n        }, {\n            'name': 'multiple tokens with logprob',\n            'text': 'Artificial intelligence is',\n            'max_tokens': 5,\n            'expected_min_text': 5\n        }]\n\n        for test_case in test_cases:\n            test_name = test_case['name']\n            print(f'\\n[Test: logprob - {test_name}]')\n\n            try:\n                input_ids = encode_text(model_path, test_case['text'])\n            except Exception as e:\n                pytest.skip(f'Tokenizer failed for {test_name}: {e}')\n\n            request_payload = {'input_ids': input_ids, 'max_tokens': test_case['max_tokens'], 'return_logprob': True}\n\n            resp = self._post(request_payload)\n            data = resp.json()\n\n            self._validate_generation_response(data=data, validate_tokens=True, expect_logprobs=True)\n\n            assert 'meta_info' in data, \\\n                \"Response should contain 'meta_info' when return_logprob=True\"\n            meta = data['meta_info']\n\n            assert 'output_token_logprobs' in meta, \\\n                \"meta_info should contain 'output_token_logprobs'\"\n            logprobs = meta['output_token_logprobs']\n\n            logprob_values = []\n\n            for i, item in enumerate(logprobs):\n                logprob_values.append(item[0])\n\n            avg_logprob = sum(logprob_values) / len(logprob_values)\n            if avg_logprob < -15.0:\n                pytest.fail(f'Generation confidence critically low '\n                            f'(Avg: {avg_logprob:.2f})')\n\n            generated_text = data.get('text', '')\n            print(f'  Generated tokens: {len(logprob_values)}')\n            print(f'  Avg 
Logprob: {avg_logprob:.3f}')\n            print(f\"  Generated text: '{generated_text[:50]}...'\")\n\n    def test_stop_str_with_include_flag(self):\n        print(f'\\n[Model: {self.model_name}] Running stop_str with include flag test')\n        test_cases = [{\n            'name': 'simple stop word',\n            'prompt': 'Count to 10: 1, 2, 3, ',\n            'stop_word': '6',\n            'max_tokens': 20,\n        }]\n\n        for test_case in test_cases:\n            test_name = test_case['name']\n            print(f'\\n[Test: stop_str - {test_name}]')\n\n            prompt = test_case['prompt']\n            stop_word = test_case['stop_word']\n            max_tokens = test_case['max_tokens']\n\n            print('  Testing EXCLUDE mode (include_stop=False)...')\n            resp1 = self._post({\n                'prompt': prompt,\n                'max_tokens': max_tokens,\n                'stop': [stop_word],\n                'include_stop_str_in_output': False,\n                'return_logprob': True\n            })\n\n            self._validate_generation_response(resp1.json())\n            text_exclude = resp1.json()['text']\n            assert stop_word not in text_exclude, \\\n                f\"Stop word '{stop_word}' should NOT be in output when include_stop=False\"\n\n            print('  Testing INCLUDE mode (include_stop=True)...')\n            resp2 = self._post({\n                'prompt': prompt,\n                'max_tokens': max_tokens,\n                'stop': [stop_word],\n                'include_stop_str_in_output': True,\n                'return_logprob': True\n            })\n\n            self._validate_generation_response(resp2.json())\n            text_include = resp2.json()['text']\n            assert stop_word in text_include, \\\n                f\"Stop word '{stop_word}' should be in output when include_stop=True\"\n\n    def test_streaming_mode(self):\n        print(f'\\n[Model: {self.model_name}] Running streaming mode test')\n  
      prompt = 'Count to 10: 1, 2,'\n\n        resp = self._post({'prompt': prompt, 'max_tokens': 8, 'stream': True}, stream=True)\n        assert resp.status_code == 200\n        data = resp.json()\n\n        text = data['text']\n        output_ids = data['output_ids']\n        meta = data['meta_info']\n\n        assert isinstance(text, str) and len(text.strip()) > 0, \\\n            'Generated text cannot be empty'\n        assert len(output_ids) >= 3, f'Expected at least 3 output tokens, got {len(output_ids)}'\n\n        import re\n        count_matches = len(re.findall(r'\\b[3-9]\\b', text))\n        assert count_matches >= 2, \\\n            f'Expected continuation of counting, but not enough numbers found ' \\\n            f'(found {count_matches})'\n\n        stream_events = meta.get('stream_events', 0)\n        assert stream_events <= len(output_ids) + 2, \\\n            'Streaming event count should not exceed output token count + 2'\n\n        print(f\"  Generated text: '{text}'\")\n        print(f'  Output tokens: {len(output_ids)}, '\n              f'Stream events: {stream_events}')\n\n    def test_streaming_incremental_correctness(self):\n        print(f'\\n[Model: {self.model_name}] Running streaming incremental correctness test')\n        prompt = 'The sky is '\n\n        raw_resp = requests.post(self.api_url,\n                                 json={\n                                     'prompt': prompt,\n                                     'max_tokens': 10,\n                                     'stream': True\n                                 },\n                                 headers=self.headers,\n                                 stream=True,\n                                 timeout=30)\n        raw_resp.raise_for_status()\n\n        full_text_from_delta = ''\n        tokens_from_delta = []\n        event_count = 0\n\n        print('  Streaming chunks:')\n        for line in raw_resp.iter_lines():\n            if line:\n                line_str = 
line.decode('utf-8').strip()\n                if line_str.startswith('data: ') and '[DONE]' not in line_str:\n                    try:\n                        json_str = line_str[6:]\n                        payload = json.loads(json_str)\n\n                        delta_text = payload.get('text', '')\n                        token_id = payload.get('token_id')\n\n                        full_text_from_delta += delta_text\n                        if token_id is not None:\n                            tokens_from_delta.append(token_id)\n\n                        event_count += 1\n                        if delta_text.strip():\n                            print(f\"+'{delta_text}'\")\n\n                    except Exception as e:\n                        print(f'[Parse warning]: {e}')\n                        continue\n\n        assert len(full_text_from_delta.strip()) > 0, \\\n            'Assembled text from streaming deltas is empty'\n        assert event_count >= 3, \\\n            f'Too few streaming events received ({event_count}), ' \\\n            f'connection might be interrupted'\n\n        print(f\"  Final assembled text: '{full_text_from_delta}'\")\n        print(f'  Total events received: {event_count}')\n\n    @pytest.mark.logprob\n    def test_return_logprob(self):\n        print(f'\\n[Model: {self.model_name}] Running return_logprob test')\n\n        resp = self._post({'prompt': 'Paris is the capital of', 'max_tokens': 2, 'return_logprob': True})\n        data = resp.json()\n\n        self._validate_generation_response(data, validate_tokens=True, expect_logprobs=True)\n\n        print(f\"  Generated text: '{data['text']}'\")\n\n    def test_same_session_id_allowed(self):\n        print(f'\\n[Model: {self.model_name}] Running same session_id test')\n        sid = int(time.time_ns()) % 100000\n\n        resp1 = self._post({'prompt': 'First message:', 'session_id': sid, 'max_tokens': 2})\n        resp2 = self._post({'prompt': 'Second message:', 
'session_id': sid, 'max_tokens': 2})\n\n        assert resp1.status_code == 200\n        assert resp2.status_code == 200\n\n        data1 = resp1.json()\n        data2 = resp2.json()\n\n        self._validate_generation_response(data1)\n        self._validate_generation_response(data2)\n\n        text1 = data1['text'].strip()\n        text2 = data2['text'].strip()\n        assert text1 != text2\n\n        print(f\"  First response: '{data1['text']}'\")\n        print(f\"  Second response: '{data2['text']}'\")\n\n    def test_empty_prompt_rejected(self):\n        print(f'\\n[Model: {self.model_name}] Running empty prompt test')\n\n        with pytest.raises(requests.HTTPError) as exc:\n            self._post({'prompt': '', 'max_tokens': 5})\n\n        assert exc.value.response.status_code == 400\n\n        try:\n            error_response = exc.value.response.json()\n            print(f'  Error response: {error_response}')\n            assert 'error' in error_response or 'message' in error_response\n        except json.JSONDecodeError:\n            print(f'  Non-JSON error: {exc.value.response.text[:100]}')\n\n    def test_input_ids_rejected(self):\n        print(f'\\n[Model: {self.model_name}] Running input_ids invalid cases test')\n\n        invalid_cases = [{\n            'case': {\n                'input_ids': [],\n                'max_tokens': 5\n            },\n            'desc': 'Empty input_ids list'\n        }, {\n            'case': {\n                'input_ids': 'not_a_list',\n                'max_tokens': 5\n            },\n            'desc': 'input_ids is a string, not list'\n        }, {\n            'case': {\n                'max_tokens': 5\n            },\n            'desc': 'Missing input_ids field'\n        }]\n\n        for invalid_case in invalid_cases:\n            test_desc = invalid_case['desc']\n            payload = invalid_case['case']\n\n            with pytest.raises(requests.HTTPError) as exc_info:\n                
self._post(payload)\n\n            response = exc_info.value.response\n            assert response.status_code in [400, 422], (f\"Expected 400 or 422 for case '{test_desc}', \"\n                                                        f'but got {response.status_code}')\n\n    def test_stress_concurrent_requests(self):\n        print(f'\\n[Model: {self.model_name}] Running stress concurrent requests test')\n\n        def single_request(idx):\n            start_time = time.time()\n            try:\n                resp = requests.post(self.api_url,\n                                     json={\n                                         'prompt': f'Hello, task {idx}',\n                                         'max_tokens': 5,\n                                         'stream': False\n                                     },\n                                     headers=self.headers,\n                                     timeout=10)\n                resp.raise_for_status()\n                data = resp.json()\n\n                if 'text' in data and len(data['text'].strip()) > 0:\n                    latency = time.time() - start_time\n                    return {'success': True, 'latency': latency}\n                else:\n                    return {'success': False, 'error': 'Empty response'}\n\n            except Exception as e:\n                return {'success': False, 'error': str(e)}\n\n        success_count = 0\n        total_latency = 0\n        failures = []\n\n        with ThreadPoolExecutor(max_workers=10) as executor:\n            futures = [executor.submit(single_request, i) for i in range(20)]\n\n            for i, future in enumerate(as_completed(futures)):\n                result = future.result()\n                if result['success']:\n                    success_count += 1\n                    total_latency += result['latency']\n                    print(f\"  Req {i}: ✓ (Latency: {result['latency']:.2f}s)\")\n                else:\n                    
failures.append(result['error'])\n                    print(f'  Req {i}: ✗')\n\n        success_rate = success_count / 20\n        assert success_rate == 1.0, \\\n            f'Stress test failed: success rate {success_rate * 100:.0f}% < 100%, errors: {failures}'\n\n        if success_count > 0:\n            avg_latency = total_latency / success_count\n            assert avg_latency < 5.0, \\\n                f'Average latency too high: {avg_latency:.2f}s'\n            print(f'  Performance: Avg Latency={avg_latency:.2f}s')\n\n        print(f'  Summary: {success_count}/20 succeeded')\n\n    def test_stress_long_prompt_and_generation(self):\n        print(f'\\n[Model: {self.model_name}] Running stress long prompt test')\n\n        long_prompt = 'Summarize: The quick brown fox jumps over the lazy dog. ' * 100\n\n        resp = self._post({'prompt': long_prompt, 'max_tokens': 512, 'temperature': 0.7})\n\n        data = resp.json()\n        self._validate_generation_response(data=data, validate_tokens=True)\n\n    def test_stress_streaming_under_load(self):\n        print(f'\\n[Model: {self.model_name}] Running stress streaming under load test')\n\n        def stream_request(idx):\n            try:\n                resp = requests.post(self.api_url,\n                                     json={\n                                         'prompt': f'Stream load test {idx}',\n                                         'max_tokens': 10,\n                                         'stream': True\n                                     },\n                                     headers=self.headers,\n                                     stream=True,\n                                     timeout=30)\n\n                assert resp.status_code == 200\n                content_type = resp.headers.get('Content-Type', '')\n                assert 'text/event-stream' in content_type or \\\n                    'application/x-ndjson' in content_type\n\n                full_text = ''\n                event_count = 0\n     
           for line in resp.iter_lines():\n                    if line and line.startswith(b'data:'):\n                        event_count += 1\n                        if b'[DONE]' in line:\n                            break\n                        try:\n                            payload = json.loads(line.decode().replace('data: ', '', 1))\n                            full_text += payload.get('text', '')\n                        except Exception:\n                            pass\n\n                assert len(full_text) > 0\n                assert event_count >= 3\n\n                return True\n\n            except Exception as e:\n                print(f'  Stream {idx} error: {e}')\n                return False\n\n        with ThreadPoolExecutor(max_workers=5) as executor:\n            futures = [executor.submit(stream_request, i) for i in range(10)]\n            results = [f.result() for f in futures]\n\n        success_count = sum(results)\n\n        assert success_count == 10, \\\n            f'Concurrent streaming test failure rate too high: {success_count}/10'\n\n        print(f'  Streaming under load: {success_count}/10 succeeded')\n\n    def test_temperature_parameter(self):\n        print(f'\\n[Model: {self.model_name}] Running temperature parameter test')\n        prompt = 'The capital of France is'\n\n        resp_low = self._post({'prompt': prompt, 'max_tokens': 10, 'temperature': 0.1, 'stream': False})\n        resp_high = self._post({'prompt': prompt, 'max_tokens': 10, 'temperature': 0.9, 'stream': False})\n\n        data_low = resp_low.json()\n        data_high = resp_high.json()\n\n        self._validate_generation_response(data=data_low, validate_tokens=True)\n        self._validate_generation_response(data=data_high, validate_tokens=True)\n\n        assert 'Paris' in data_low['text'] or \\\n            'paris' in data_low['text'].lower(), \\\n            \"Low temperature didn't answer correct capital\"\n        assert data_low['text'] != 
data_high['text'], \\\n            'High and low temperature outputs identical, ' \\\n            'temperature may not be effective'\n\n    def test_top_p_parameter(self):\n        print(f'\\n[Model: {self.model_name}] Running top_p parameter test')\n        prompt = 'The weather today is'\n\n        resp_strict = self._post({'prompt': prompt, 'max_tokens': 20, 'top_p': 0.01, 'stream': False})\n        resp_loose = self._post({'prompt': prompt, 'max_tokens': 20, 'top_p': 0.99, 'stream': False})\n\n        text_strict = resp_strict.json()\n        text_loose = resp_loose.json()\n\n        self._validate_generation_response(data=text_strict, validate_tokens=True)\n        self._validate_generation_response(data=text_loose, validate_tokens=True)\n\n    def test_top_k_parameter(self):\n        print(f'\\n[Model: {self.model_name}] Running top_k parameter test')\n        prompt = 'Artificial intelligence'\n\n        resp_k10 = self._post({'prompt': prompt, 'max_tokens': 10, 'top_k': 10, 'stream': False})\n        resp_k50 = self._post({'prompt': prompt, 'max_tokens': 10, 'top_k': 50, 'stream': False})\n\n        text_k10 = resp_k10.json()\n        text_k50 = resp_k50.json()\n\n        self._validate_generation_response(data=text_k10, validate_tokens=True)\n        self._validate_generation_response(data=text_k50, validate_tokens=True)\n\n    def test_min_p_parameter(self):\n        print(f'\\n[Model: {self.model_name}] Running min_p parameter test')\n        prompt = 'Machine learning is'\n\n        resp = self._post({'prompt': prompt, 'max_tokens': 10, 'min_p': 0.05, 'stream': False})\n        data = resp.json()\n        self._validate_generation_response(data)\n\n    def test_repetition_penalty(self):\n        print(f'\\n[Model: {self.model_name}] Running repetition penalty test')\n        prompt = 'Repeat repeat repeat repeat'\n\n        resp_no_penalty = self._post({'prompt': prompt, 'max_tokens': 10, 'repetition_penalty': 1.0, 'stream': False})\n        
resp_penalty = self._post({'prompt': prompt, 'max_tokens': 10, 'repetition_penalty': 1.5, 'stream': False})\n\n        text_no_penalty = resp_no_penalty.json()['text']\n        text_penalty = resp_penalty.json()['text']\n\n        def count_repeats(text):\n            words = text.lower().split()\n            return sum(1 for i in range(1, len(words)) if words[i] == words[i - 1])\n\n        repeats_no_penalty = count_repeats(text_no_penalty)\n        repeats_penalty = count_repeats(text_penalty)\n\n        assert repeats_penalty <= repeats_no_penalty, (\n            f'High penalty coefficient ({1.5}) repetition count ({repeats_penalty}) '\n            f'not less than low penalty ({1.0}) count ({repeats_no_penalty}), '\n            f'repetition_penalty ineffective')\n\n    def test_ignore_eos_parameter(self):\n        print(f'\\n[Model: {self.model_name}] Running ignore_eos parameter test')\n        prompt = 'The sky is blue.'\n\n        resp_normal = self._post({'prompt': prompt, 'ignore_eos': False, 'stream': False})\n        data_normal = resp_normal.json()\n        self._validate_generation_response(data_normal)\n\n        resp_ignore = self._post({'prompt': prompt, 'ignore_eos': True, 'stream': False})\n        data_ignore = resp_ignore.json()\n        self._validate_generation_response(data_ignore)\n\n        reason_ignore = data_ignore.get('meta_info', {}).get('finish_reason', {}).get('type', 'unknown')\n\n        assert reason_ignore == 'length', \\\n            f'ignore_eos=True must end due to length, actual: {reason_ignore}'\n\n    def test_skip_special_tokens(self, config):\n        print(f'[Model: {self.model_name}] Running skip_special_tokens test')\n        model_path = os.path.join(config.get('model_path'), self.model_name)\n        user_content = 'Hello [world]! 
This is a [test].'\n\n        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n        special_tokens_map = tokenizer.special_tokens_map\n\n        special_patterns = list(special_tokens_map.values())\n        special_patterns = [\n            item for sublist in special_patterns for item in (sublist if isinstance(sublist, list) else [sublist])\n        ]\n\n        print('Special patterns:', special_patterns)\n\n        print(' Executing skip_special_tokens=True')\n        payload_true = {'prompt': user_content, 'max_tokens': 100, 'skip_special_tokens': True, 'stream': False}\n        resp_true = self._post(payload_true)\n        data_true = resp_true.json()\n        self._validate_generation_response(data=data_true, validate_tokens=True)\n        generated_text = data_true['text']\n        assert not any(pattern in generated_text for pattern in special_patterns), \\\n            'Expected no special pattern in the generated text but found one.'\n\n    def test_stop_token_ids(self):\n        print(f'\\n[Model: {self.model_name}] Running stop_token_ids test')\n        payload = {'prompt': 'Once upon a time', 'max_tokens': 500, 'stop_token_ids': [11, 281], 'stream': False}\n\n        resp = self._post(payload)\n        assert resp.status_code == 200, \\\n            f'HTTP request failed, status code: {resp.status_code}'\n\n        try:\n            data = resp.json()\n        except Exception as e:\n            pytest.fail(f'Response JSON parsing failed: {e}')\n\n        self._validate_generation_response(data)\n\n        generated_text = data.get('text', '')\n        finish_reason = data.get('meta_info', {}).get('finish_reason', {}).get('type', 'unknown')\n        actual_length = len(generated_text)\n\n        print(f'\\n stop_token_ids=[11, 281] generation result: length={actual_length}, '\n              f\"end reason='{finish_reason}', text='{generated_text[:20]}...'\")\n\n        assert finish_reason in ['stop'], \\\n            
f'Expected generation to end due to stop token, ' \\\n            f'actual reason: {finish_reason}. This may mean stop_token_ids [11, 281] ' \\\n            f\"didn't take effect, or generation was truncated.\"\n\n    def test_combined_parameters(self):\n        print(f'\\n[Model: {self.model_name}] Running combined parameters test')\n        resp = self._post({\n            'prompt': 'The future of AI',\n            'max_tokens': 15,\n            'temperature': 0.7,\n            'top_p': 0.9,\n            'top_k': 40,\n            'repetition_penalty': 1.1,\n            'stream': False\n        })\n\n        assert resp.status_code == 200\n        data = resp.json()\n        self._validate_generation_response(data)\n\n    def test_streaming_with_all_parameters(self):\n        print(f'\\n[Model: {self.model_name}] Running streaming with all parameters test')\n        resp = self._post(\n            {\n                'prompt': 'Streaming test with parameters',\n                'max_tokens': 10,\n                'temperature': 0.8,\n                'top_p': 0.85,\n                'top_k': 30,\n                'repetition_penalty': 1.2,\n                'stop': ['test'],\n                'stream': True\n            },\n            stream=True)\n\n        assert resp.status_code == 200\n        data = resp.json()\n        self._validate_generation_response(data)\n\n        stream_events = data['meta_info'].get('stream_events', 0)\n\n        assert stream_events <= len(data['output_ids']) + 2, \\\n            'Streaming event count should not exceed generated token count + 2'\n\n    def test_invalid_temperature_values(self):\n        print(f'\\n[Model: {self.model_name}] Running invalid temperature values test')\n        resp1 = self._post({'prompt': 'Test', 'max_tokens': 3, 'temperature': 0.0, 'stream': False})\n        assert resp1.status_code == 200, 'temperature=0.0 should be valid'\n\n        with pytest.raises(requests.HTTPError) as exc_info:\n            
self._post({'prompt': 'Test', 'max_tokens': 3, 'temperature': -0.5, 'stream': False})\n        assert exc_info.value.response.status_code in [400, 422]\n\n        print('  Invalid temperature values test passed')\n\n    def test_invalid_top_p_values(self):\n        print(f'\\n[Model: {self.model_name}] Running invalid top_p values test')\n        with pytest.raises(requests.HTTPError) as exc_info:\n            self._post({'prompt': 'Test', 'max_tokens': 3, 'top_p': 1.5, 'stream': False})\n        assert exc_info.value.response.status_code in [400, 422]\n\n        print('  Invalid top_p values test passed')\n\n    def test_invalid_top_k_values(self):\n        print(f'\\n[Model: {self.model_name}] Running invalid top_k values test')\n        with pytest.raises(requests.HTTPError) as exc_info:\n            self._post({'prompt': 'Test', 'max_tokens': 3, 'top_k': -5, 'stream': False})\n        assert exc_info.value.response.status_code in [400, 422]\n\n        print('  Invalid top_k values test passed')\n\n    def test_boundary_max_tokens(self):\n        print(f'\\n[Model: {self.model_name}] Running boundary max_tokens test')\n        resp1 = self._post({'prompt': 'Min tokens', 'max_tokens': 1, 'stream': False})\n        assert resp1.status_code == 200\n        data1 = resp1.json()\n        assert data1['meta_info']['completion_tokens'] >= 1\n\n        resp2 = self._post({'prompt': 'Max tokens test', 'max_tokens': 2048, 'stream': False})\n        assert resp2.status_code == 200\n\n        with pytest.raises(requests.HTTPError) as exc:\n            self._post({'prompt': 'Test', 'max_tokens': -2, 'stream': False})\n\n        assert exc.value.response.status_code == 400\n\n        with pytest.raises(requests.HTTPError) as exc:\n            self._post({'prompt': 'Test', 'max_tokens': 0, 'stream': False})\n\n        assert exc.value.response.status_code == 400\n\n        print('  Max tokens boundary test passed')\n\n    def test_parameter_interactions(self):\n        
print(f'\\n[Model: {self.model_name}] Running parameter interactions test')\n        resp1 = self._post({\n            'prompt': 'Deterministic generation',\n            'max_tokens': 10,\n            'temperature': 0.0,\n            'top_p': 0.5,\n            'top_k': 10,\n            'stream': False\n        })\n        assert resp1.status_code == 200\n        data1 = resp1.json()\n\n        self._validate_generation_response(data1)\n\n        print('  Parameter interaction (temp=0 with top_p/k) passed')\n\n    def test_session_id_with_all_parameters(self):\n        print(f'\\n[Model: {self.model_name}] Running session_id with all parameters test')\n        session_id = int(time.time_ns()) % 100000\n\n        resp1 = self._post({\n            'session_id': session_id,\n            'prompt': 'Hello, introduce yourself briefly.',\n            'max_tokens': 20,\n            'temperature': 0.7,\n            'stream': False\n        })\n        assert resp1.status_code == 200\n        data1 = resp1.json()\n        self._validate_generation_response(data1)\n\n        resp2 = self._post({\n            'session_id': session_id,\n            'prompt': 'What was I just talking about?',\n            'max_tokens': 20,\n            'temperature': 0.7,\n            'stream': False\n        })\n        assert resp2.status_code == 200\n        data2 = resp2.json()\n        self._validate_generation_response(data2)\n\n        assert 'What' in data2['text'] or 'hello' in data2['text'].lower() or \\\n            len(data2['text']) > 0\n\n        print(f'  Session {session_id} test passed')\n\n    def test_edge_cases_stop_conditions(self):\n        print(f'\\n[Model: {self.model_name}] Running edge cases stop conditions test')\n        resp1 = self._post({'prompt': 'Test with empty stop list', 'max_tokens': 10, 'stop': [], 'stream': False})\n        assert resp1.status_code == 200\n        data1 = resp1.json()\n        assert len(data1['text']) > 0\n\n        resp2 = self._post({\n  
          'prompt': 'Write a sentence ending with a period. Stop here test.',\n            'max_tokens': 200,\n            'stop': ['.'],\n            'stream': False\n        })\n        assert resp2.status_code == 200\n        data2 = resp2.json()\n\n        text2 = data2['text']\n        finish_reason = data2['meta_info']['finish_reason']['type']\n\n        assert '. ' not in text2 and not text2.strip().endswith('.'), \\\n            \"Stop string '.' should end generation at the first period\"\n\n        assert finish_reason in ['stop', 'eos'], \\\n            f'Expected to end due to stop string, actual: {finish_reason}, content is {text2}'\n\n        print(f\"  Stop at '.': generated '{text2}' (Reason: {finish_reason})\")\n\n    def test_spaces_between_special_tokens(self, config):\n        print(f'\\n[Model: {self.model_name}] Running spaces_between_special_tokens test')\n        model_path = os.path.join(config.get('model_path'), self.model_name)\n        user_content = 'Hello [world]! 
This is a [test].'\n\n        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n        special_tokens_map = tokenizer.special_tokens_map\n\n        special_patterns = list(special_tokens_map.values())\n        special_patterns = [\n            item for sublist in special_patterns for item in (sublist if isinstance(sublist, list) else [sublist])\n        ]\n\n        print(' Executing skip_special_tokens=False and checking spaces between special tokens')\n        payload_false = {'prompt': user_content, 'max_tokens': 100, 'skip_special_tokens': False, 'stream': False}\n        resp_false = self._post(payload_false)\n        data_false = resp_false.json()\n        self._validate_generation_response(data=data_false, validate_tokens=True)\n        generated_text = data_false['text']\n\n        # Special tokens are usually multi-character strings, so locate every occurrence and check the next character\n        for pattern in special_patterns:\n            if not pattern:\n                continue\n            idx = generated_text.find(pattern)\n            while idx != -1:\n                end = idx + len(pattern)\n                if end < len(generated_text):\n                    assert generated_text[end] in [' ', '\\n'], \\\n                        f'Expected space after special token {pattern} but found none.'\n                idx = generated_text.find(pattern, end)\n\n    @pytest.mark.experts\n    @pytest.mark.not_turbomind\n    def test_request_returns_experts(self):\n        print(f'\\n[Model: {self.model_name}] Running request with experts test')\n        resp1 = self._post({\n            'prompt': 'Deterministic generation',\n            'max_tokens': 50,\n            'temperature': 0.8,\n            'return_routed_experts': True\n        })\n        assert resp1.status_code == 200\n        data1 = resp1.json()\n\n        self._validate_generation_response(data1, validate_experts=True)\n"
  },
  {
    "path": "autotest/prompt_case.yml",
    "content": "identity:\n    - 你好，你叫什么名字#hi, what's your name:\nmemory_test:\n    - 简要介绍乌鲁木齐的景点#A brief introduction to Urumqi’s attractions:\n        - contain:\n            - urumqi\n            - 乌鲁木齐\n            - 乌市\n            - xinjiang\n            - 新疆\n            - uwumqi\n            - Ürümqi\n    - 介绍它的相应美食#please introduce some delicious foods:\n        - contain:\n            - urumqi\n            - 乌鲁木齐\n            - 乌市\n            - xinjiang\n            - 新疆\n            - uwumqi\n            - Ürümqi\n\nchinese_poem_case:\n    - 给我一首中文诗，需要添加标点符号，请用中文回答Give me a Chinese poem in Chinese:\n        - contain:\n            - \"，\"\n            - \"。\"\n            - poem\n            - poetry\n            - \\n\n        - len_g:\n            5\nenglish_poem_case:\n    - write a romantic English poem in English:\n        - contain:\n            - \" \"\n        - contain:\n            - \".\"\n            - \",\"\n        - len_g:\n            30\nemoji_case:\n    - 请输出👍赞的emoji#print output the emoji of good👍:\n        - contain:\n            - 👍\n            - 😊\n            - 😀\n            - 🎉\n            - 👏\n            - 👌\n            - good\n            - like\n            - 赞\n            - 好\n            - '!'\n            - u1f44d\n            - 🌟\ntraditional_chinese_case:\n    - 介紹澳門景點，使用繁體:\n        - contain:\n            - 澳門\n            - 景點\n            - 澳门\n            - macau\ncode_testcase:\n    - 使用python编写一个int数组的冒泡排序代码:\n        - contain:\n            - def\n            - bubble\n            - 冒泡\n            - code\n            - python\n        - llama2:\n            - contain:\n                - def\n                - bubble\n                - 冒泡\n                - code\n                - python\n                - assist\n                - however\n"
  },
  {
    "path": "autotest/pytest.ini",
    "content": "[pytest]\npython_files = test*_*.py  # test file\npython_classes = Test*     # test class\npython_functions = test_*  # test function\npytest_runtest_call.tryfirst = True\nfilterwarnings = ignore::UserWarning\nreruns = 2\nreruns_delay = 1\n"
  },
  {
    "path": "autotest/template.json",
    "content": "{\n    \"model_name\": \"base\",\n    \"capability\": \"completion\"\n}\n"
  },
  {
    "path": "autotest/toolchain/test_lagent.py",
    "content": "import pytest\n\n\n@pytest.mark.order(10)\n@pytest.mark.lagent\n@pytest.mark.flaky(reruns=2)\n@pytest.mark.parametrize('model', ['internlm/internlm2_5-7b-chat'])\ndef test_repeat(config, model):\n    from lagent.llms import INTERNLM2_META, LMDeployPipeline\n\n    model = LMDeployPipeline(\n        path='/'.join([config.get('model_path'), model]),\n        meta_template=INTERNLM2_META,\n        tp=1,\n        top_k=40,\n        top_p=0.8,\n        temperature=1.2,\n        stop_words=['<|im_end|>'],\n        max_new_tokens=4096,\n    )\n    response_list = []\n    for i in range(3):\n        print(f'run_{i}:')\n        response = model.chat([{\n            'role':\n            'user',\n            'content':\n            '已知$$z_{1}=1$$,$$z_{2}=\\\\text{i}$$,$$z_{3}=-1$$,$$z_{4}=-\\\\text{i}$$,顺次连结它们所表示的点,则所得图形围成的面积为（ ）\\nA. $$\\\\dfrac{1}{4}$$\\n B. $$\\\\dfrac{1}{2}$$\\n C. $$1$$\\n D. $$2$$\\n\\n'  # noqa: E501\n        }])\n        print(response)\n        response_list.append(response)\n        assert len(response) > 10\n    assert response_list[0] != response_list[1] and response_list[1] != response_list[2]\n"
  },
  {
    "path": "autotest/tools/chat/test_command_chat_hf_pytorch.py",
    "content": "import pytest\nfrom tools.common_case_config import (MODELSCOPE_CONFIG, PYTORCH_LORA_TEST_LLM_GPU1, PYTORCH_LORA_TEST_LLM_GPU2,\n                                      PYTORCH_PR_TEST_LLM_GPU1, PYTORCH_PR_TEST_LLM_GPU2)\nfrom utils.config_utils import get_func_config_list, get_workerid\nfrom utils.run_client_chat import run_tests\n\nBACKEND = 'pytorch'\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}))\ndef test_hf_pytorch_chat_tp1(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}))\ndef test_hf_pytorch_chat_tp2(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_4\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}))\ndef test_hf_pytorch_chat_tp4(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_8\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}))\ndef test_hf_pytorch_chat_tp8(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_16\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16}))\ndef test_hf_pytorch_chat_tp16(config, run_config, 
cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}, 'base_model'))\ndef test_hf_pytorch_base_tp1(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'base_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, 'base_model'))\ndef test_hf_pytorch_base_tp2(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'base_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU2)\ndef test_hf_pytorch_chat_pr_tp2(config, run_config, cli_case_config, worker_id):\n    worker_id = 'gw' + str(3 + get_workerid(worker_id))\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU1)\ndef test_hf_pytorch_chat_pr_tp1(config, run_config, cli_case_config, worker_id):\n    worker_id = 'gw' + str(6 + get_workerid(worker_id))\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND])\ndef test_modelscope_pytorch_chat_tp1(config, run_config, cli_case_config, worker_id):\n    run_config['env'] = {'LMDEPLOY_USE_MODELSCOPE': 'True'}\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, 
worker_id)\n\n\n@pytest.mark.order(10)\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.hf_pytorch_chat\n@pytest.mark.gpu_num_1\n@pytest.mark.other\n@pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU1)\ndef test_pytorch_chat_with_lora_tp1(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.order(10)\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.hf_pytorch_chat\n@pytest.mark.gpu_num_2\n@pytest.mark.other\n@pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU2)\ndef test_pytorch_chat_with_lora_tp2(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n"
  },
  {
    "path": "autotest/tools/chat/test_command_chat_hf_turbomind.py",
    "content": "import pytest\nfrom tools.common_case_config import (MODELSCOPE_CONFIG, TURBOMIND_FALLBACK_TEST_LLM_GPU1,\n                                      TURBOMIND_FALLBACK_TEST_LLM_GPU2, TURBOMIND_PR_TEST_LLM_GPU1,\n                                      TURBOMIND_PR_TEST_LLM_GPU2)\nfrom utils.config_utils import get_func_config_list, get_workerid\nfrom utils.run_client_chat import run_tests\n\nBACKEND = 'turbomind'\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}))\ndef test_hf_turbomind_chat_tp1(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}))\ndef test_hf_turbomind_chat_tp2(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}))\ndef test_hf_turbomind_chat_tp4(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_8\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}))\ndef test_hf_turbomind_chat_tp8(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU1)\ndef test_hf_turbomind_chat_fallback_backend_tp1(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', 
cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU2)\ndef test_hf_turbomind_chat_fallback_backend_tp2(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}, 'base_model'))\ndef test_hf_turbomind_base_tp1(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'base_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, 'base_model'))\ndef test_hf_turbomind_base_tp2(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'base_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU2)\ndef test_hf_turbomind_chat_pr_tp2(config, run_config, cli_case_config, worker_id):\n    worker_id = 'gw' + str(3 + get_workerid(worker_id))\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU1)\ndef test_hf_turbomind_chat_pr_tp1(config, run_config, cli_case_config, worker_id):\n    worker_id = 'gw' + str(6 + get_workerid(worker_id))\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('cli_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND])\ndef 
test_modelscope_turbomind_chat_tp1(config, run_config, cli_case_config, worker_id):\n    run_tests(config, 'chat_testcase', cli_case_config, run_config, worker_id)\n"
  },
  {
    "path": "autotest/tools/common_case_config.py",
    "content": "TURBOMIND_PR_TEST_LLM_GPU2 = [{\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'mistralai/Mixtral-8x7B-Instruct-v0.1',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}]\n\nTURBOMIND_PR_TEST_LLM_GPU1 = [{\n    'model': 'Qwen/Qwen3-0.6B',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-0.6B-inner-4bits',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-8B',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}]\n\nTURBOMIND_PR_TEST_MLLM_GPU1 = [{\n    'model': 'OpenGVLab/InternVL3-8B',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}, {\n    'model': 'OpenGVLab/InternVL3-8B',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}]\n\nTURBOMIND_PR_TEST_MLLM_GPU2 = [{\n    'model': 'OpenGVLab/InternVL3_5-30B-A3B',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'OpenGVLab/InternVL3_5-30B-A3B',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': 
{}\n}]\n\nTURBOMIND_FALLBACK_TEST_LLM_GPU1 = [{\n    'model': 'THUDM/cogvlm-chat-hf',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}, {\n    'model': 'microsoft/Phi-3.5-vision-instruct',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}]\n\nTURBOMIND_FALLBACK_TEST_LLM_GPU2 = [{\n    'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}]\n\nTURBOMIND_FALLBACK_TEST_MLLM_GPU1 = [{\n    'model': 'THUDM/glm-4v-9b',\n    'backend': 'turbomind',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 4,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}, {\n    'model': 'THUDM/glm-4v-9b',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}]\n\nTURBOMIND_LOGPROBS_TEST_LLM_GPU2 = [{\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'OpenGVLab/InternVL3-38B',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}]\n\nBASE_MODELSCOPE_CONFIG = [{\n    'model': 'Qwen/Qwen2.5-7B-Instruct',\n    'communicator': 'cuda-ipc',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    
'extra_params': {},\n    'env': {\n        'LMDEPLOY_USE_MODELSCOPE': 'True'\n    }\n}, {\n    'model': 'Qwen/Qwen2.5-7B-Instruct',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {},\n    'env': {\n        'LMDEPLOY_USE_MODELSCOPE': 'True'\n    }\n}]\n\nMODELSCOPE_CONFIG = [{\n    **item, 'backend': 'turbomind'\n} for item in BASE_MODELSCOPE_CONFIG] + [{\n    **item, 'backend': 'pytorch'\n} for item in BASE_MODELSCOPE_CONFIG]\n\nPYTORCH_LORA_TEST_LLM_GPU1 = [{\n    'model': 'meta-llama/Llama-2-7b-chat-hf',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {\n        'adapters': {\n            'default': 'lora/Llama2-Chinese-7b-Chat-LoRA'\n        }\n    }\n}]\n\nPYTORCH_LORA_TEST_LLM_GPU2 = [{\n    'model': 'baichuan-inc/Baichuan2-13B-Chat',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {\n        'adapters': {\n            'a': 'lora/2024-01-25_self_dup',\n            'b': 'lora/2024-01-25_self'\n        }\n    }\n}]\n\nPYTORCH_PR_TEST_LLM_GPU2 = [{\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}, {\n    'model': 'mistralai/Mixtral-8x7B-Instruct-v0.1',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {}\n}]\n\nPYTORCH_PR_TEST_LLM_GPU1 = [{\n    'model': 'meta-llama/Meta-Llama-3-1-8B-Instruct',\n    'backend': 'pytorch',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}, {\n    'model': 'Qwen/Qwen3-0.6B',\n    'backend': 'pytorch',\n    'communicator': 
'nccl',\n    'quant_policy': 8,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {}\n}]\n\nBASE_TOOLCALL_TEST_LLM = [{\n    'model': 'Qwen/Qwen3-8B',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {\n        'tool-call-parser': 'qwen'\n    }\n}, {\n    'model': 'meta-llama/Meta-Llama-3-1-70B-Instruct',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 4\n    },\n    'extra_params': {\n        'tool-call-parser': 'llama3'\n    }\n}, {\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {\n        'tool-call-parser': 'qwen'\n    }\n}]\n\nBASE_REASONING_TEST_LLM = [{\n    'model': 'Qwen/Qwen3-VL-30B-A3B-Instruct',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {\n        'reasoning-parser': 'qwen-qwq'\n    }\n}, {\n    'model': 'Qwen/Qwen3-30B-A3B',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {\n        'reasoning-parser': 'qwen-qwq'\n    }\n}]\n\nTOOLCALL_TEST_LLM = [{\n    **item, 'backend': 'turbomind'\n} for item in BASE_TOOLCALL_TEST_LLM] + [{\n    **item, 'backend': 'pytorch'\n} for item in BASE_TOOLCALL_TEST_LLM]\n\nREASONING_TEST_LLM = [{\n    **item, 'backend': 'turbomind'\n} for item in BASE_REASONING_TEST_LLM] + [{\n    **item, 'backend': 'pytorch'\n} for item in BASE_REASONING_TEST_LLM]\n\nBASE_SPECULATIVE_DECODING_PIPELINE_TEST_LLM = [{\n    'model': 'meta-llama/Meta-Llama-3-1-8B-Instruct',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {\n        'max_batch_size': 128,\n        'speculative_config': {\n            'method': 'eagle3',\n            'num_speculative_tokens': 
3,\n            'model': 'yuhuili/EAGLE3-LLaMA3.1-Instruct-8B'\n        }\n    }\n}]\n\nSPECULATIVE_DECODING_PIPELINE_TEST_LLM = [{\n    **item, 'backend': 'pytorch'\n} for item in BASE_SPECULATIVE_DECODING_PIPELINE_TEST_LLM]\n\nBASE_SPECULATIVE_DECODING_RESTFUL_TEST_LLM = [{\n    'model': 'meta-llama/Meta-Llama-3-1-8B-Instruct',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 1\n    },\n    'extra_params': {\n        'speculative-draft-model': 'yuhuili/EAGLE3-LLaMA3.1-Instruct-8B',\n        'speculative-algorithm': 'eagle3',\n        'speculative-num-draft-tokens': 3,\n        'max-batch-size': 128\n    }\n}, {\n    'model': 'deepseek/DeepSeek-V3',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 16\n    },\n    'extra_params': {\n        'speculative-algorithm': 'deepseek_mtp',\n        'speculative-num-draft-tokens': 3,\n        'max-batch-size': 128\n    }\n}]\n\nSPECULATIVE_DECODING_RESTFUL_TEST_LLM = [{\n    **item, 'backend': 'pytorch'\n} for item in BASE_SPECULATIVE_DECODING_RESTFUL_TEST_LLM]\n"
  },
  {
    "path": "autotest/tools/pipeline/llm_case.py",
    "content": "import json\nimport os\n\nimport fire\nimport yaml\n\nfrom lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline\nfrom lmdeploy.messages import SpeculativeConfig\n\ngen_config = GenerationConfig(max_new_tokens=500, min_new_tokens=10)\n\n\ndef run_pipeline_chat_test(model_path, run_config, cases_path, is_pr_test: bool = False):\n    backend = run_config.get('backend')\n    communicator = run_config.get('communicator')\n    quant_policy = run_config.get('quant_policy')\n    extra_params = run_config.get('extra_params', {})\n    parallel_config = run_config.get('parallel_config', {})\n\n    if backend == 'pytorch':\n        backend_config = PytorchEngineConfig(quant_policy=quant_policy)\n    else:\n        backend_config = TurbomindEngineConfig(communicator=communicator, quant_policy=quant_policy)\n\n    # quant format\n    model_lower = model_path.lower()\n    if 'w4' in model_lower or '4bits' in model_lower or 'awq' in model_lower:\n        backend_config.model_format = 'awq'\n    elif 'gptq' in model_lower:\n        backend_config.model_format = 'gptq'\n\n    # Parallel config\n    for para_key in ('dp', 'ep', 'cp'):\n        if para_key in parallel_config:\n            setattr(backend_config, para_key, parallel_config[para_key])\n    if 'tp' in parallel_config and parallel_config['tp'] > 1:\n        backend_config.tp = parallel_config['tp']\n\n    # Extract speculative_config from extra_params if present\n    speculative_config = None\n    spec_cfg = extra_params.pop('speculative_config', None)\n    if isinstance(spec_cfg, dict):\n        speculative_config = SpeculativeConfig(**spec_cfg)\n\n    # Extra params\n    # Map CLI param names to PytorchEngineConfig attribute names\n    param_name_map = {'device': 'device_type'}\n    for key, value in extra_params.items():\n        attr_name = param_name_map.get(key, key)\n        try:\n            setattr(backend_config, attr_name, value)\n        except AttributeError:\n 
           print(f\"Warning: Cannot set attribute '{attr_name}' on backend_config. Skipping.\")\n\n    print('backend_config config: ' + str(backend_config))\n    print('speculative_config config: ' + str(speculative_config))\n    pipe = pipeline(model_path, backend_config=backend_config, speculative_config=speculative_config)\n\n    cases_path = os.path.join(cases_path)\n    with open(cases_path) as f:\n        cases_info = yaml.load(f.read(), Loader=yaml.SafeLoader)\n\n    for case in cases_info.keys():\n        if is_pr_test and case != 'memory_test':\n            continue\n        if case != 'code_testcase' and 'code' in model_path.lower():\n            continue\n        case_info = cases_info.get(case)\n\n        prompts = []\n        response_list = []\n        for prompt_detail in case_info:\n            prompt = list(prompt_detail.keys())[0]\n            prompts.append({'role': 'user', 'content': prompt})\n            response = pipe([prompts], gen_config=gen_config, log_level='INFO', max_log_len=10)[0].text\n            response_list.append({'prompt': prompt, 'response': response})\n            prompts.append({'role': 'assistant', 'content': response})\n\n        print(f'[caseresult {case} start]' + json.dumps(response_list, ensure_ascii=False) +\n              f'[caseresult {case} end]\\n')\n\n    pipe.close()\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
  {
    "path": "autotest/tools/pipeline/mllm_case.py",
    "content": "import json\n\nimport fire\nimport numpy as np\nfrom PIL import Image\n\nfrom lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline\nfrom lmdeploy.vl import encode_image_base64, load_image\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\n\ngen_config = GenerationConfig(max_new_tokens=500, min_new_tokens=10)\n\nPIC1 = 'tiger.jpeg'\nPIC2 = 'human-pose.jpg'\nPIC_BEIJING = 'Beijing_Small.jpeg'\nPIC_CHONGQING = 'Chongqing_Small.jpeg'\nPIC_REDPANDA = 'redpanda.jpg'\nPIC_PANDA = 'panda.jpg'\nDESC = 'What are the similarities and differences between these two images.'\nDESC_ZH = '两张图有什么相同和不同的地方.'\n\n\ndef run_pipeline_mllm_test(model_path, run_config, resource_path, is_pr_test: bool = False):\n    backend = run_config.get('backend')\n    communicator = run_config.get('communicator')\n    quant_policy = run_config.get('quant_policy')\n    extra_params = run_config.get('extra_params', {})\n    parallel_config = run_config.get('parallel_config', {})\n\n    if 'pytorch' == backend:\n        backend_config = PytorchEngineConfig(session_len=65152, quant_policy=quant_policy, cache_max_entry_count=0.6)\n    else:\n        backend_config = TurbomindEngineConfig(session_len=65152,\n                                               communicator=communicator,\n                                               quant_policy=quant_policy,\n                                               cache_max_entry_count=0.6)\n\n    # quant format\n    model_lower = model_path.lower()\n    if 'w4' in model_lower or '4bits' in model_lower or 'awq' in model_lower:\n        backend_config.model_format = 'awq'\n    elif 'gptq' in model_lower:\n        backend_config.model_format = 'gptq'\n\n    # Parallel config\n    for para_key in ('dp', 'ep', 'cp'):\n        if para_key in parallel_config:\n            setattr(backend_config, para_key, parallel_config[para_key])\n    if 'tp' in parallel_config and parallel_config['tp'] > 1:\n        backend_config.tp = 
parallel_config['tp']\n\n    # Extra params\n    # Map CLI param names to PytorchEngineConfig attribute names\n    param_name_map = {'device': 'device_type'}\n    for key, value in extra_params.items():\n        attr_name = param_name_map.get(key, key)\n        try:\n            setattr(backend_config, attr_name, value)\n        except AttributeError:\n            print(f\"Warning: Cannot set attribute '{attr_name}' on backend_config. Skipping.\")\n\n    print('backend_config config: ' + str(backend_config))\n    pipe = pipeline(model_path, backend_config=backend_config)\n\n    image = load_image(f'{resource_path}/{PIC1}')\n\n    if 'deepseek' in model_lower:\n        prompt = f'describe this image{IMAGE_TOKEN}'\n    else:\n        prompt = 'describe this image'\n\n    response = pipe((prompt, image)).text\n    print('[caseresult single1 start]' + json.dumps(response, ensure_ascii=False) + '[caseresult single1 end]\\n')\n\n    prompts = [{\n        'role':\n        'user',\n        'content': [{\n            'type': 'text',\n            'text': prompt\n        }, {\n            'type': 'image_url',\n            'image_url': {\n                'url': f'{resource_path}/{PIC1}'\n            }\n        }]\n    }]\n    response = pipe(prompts, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult single2 start]' + json.dumps(response.text, ensure_ascii=False) + '[caseresult single2 end]\\n')\n\n    image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}']\n    images = [load_image(img_url) for img_url in image_urls]\n    response = pipe((prompt, images))\n    print('[caseresult multi-imagese start]' + json.dumps(response.text, ensure_ascii=False) +\n          '[caseresult multi-imagese end]\\n')\n\n    image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}']\n    prompts = [(prompt, load_image(img_url)) for img_url in image_urls]\n    response = pipe(prompts, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    
print('[caseresult batch-example1 start]' + json.dumps(response[0].text, ensure_ascii=False) +\n          '[caseresult batch-example1 end]\\n')\n    print('[caseresult batch-example2 start]' + json.dumps(response[1].text, ensure_ascii=False) +\n          '[caseresult batch-example2 end]\\n')\n\n    image = load_image(f'{resource_path}/{PIC2}')\n    sess = pipe.chat((prompt, image))\n    print('[caseresult multi-turn1 start]' + json.dumps(sess.response.text, ensure_ascii=False) +\n          '[caseresult multi-turn1 end]\\n')\n    sess = pipe.chat('What is the woman doing?', session=sess)\n    print('[caseresult multi-turn2 start]' + json.dumps(sess.response.text, ensure_ascii=False) +\n          '[caseresult multi-turn2 end]\\n')\n\n    if not is_pr_test:\n        if 'internvl' in model_path.lower() and 'internvl2-4b' not in model_path.lower():\n            internvl_vl_testcase(pipe, resource_path)\n            internvl_vl_testcase(pipe, resource_path, lang='cn')\n        if 'minicpm' in model_path.lower():\n            MiniCPM_vl_testcase(pipe, resource_path)\n        if 'qwen' in model_path.lower():\n            Qwen_vl_testcase(pipe, resource_path)\n\n    pipe.close()\n\n\ndef internvl_vl_testcase(pipe, resource_path, lang='en'):\n    if lang == 'cn':\n        description = DESC_ZH\n    else:\n        description = DESC\n    # multi-image multi-round conversation, combined images\n    messages = [\n        dict(role='user',\n             content=[\n                 dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\\n{description}'),\n                 dict(type='image_url', image_url=dict(max_dynamic_patch=12, url=f'{resource_path}/{PIC_REDPANDA}')),\n                 dict(type='image_url', image_url=dict(max_dynamic_patch=12, url=f'{resource_path}/{PIC_PANDA}'))\n             ])\n    ]\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print(f'[caseresult internvl-combined-images-{lang} start]' + 
json.dumps(response.text, ensure_ascii=False) +\n          f'[caseresult internvl-combined-images-{lang} end]\\n')\n\n    messages.append(dict(role='assistant', content=response.text))\n    messages.append(dict(role='user', content=description))\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print(f'[caseresult internvl-combined-images2-{lang} start]' + json.dumps(response.text, ensure_ascii=False) +\n          f'[caseresult internvl-combined-images2-{lang} end]\\n')\n\n    # multi-image multi-round conversation, separate images\n    messages = [\n        dict(\n            role='user',\n            content=[\n                dict(\n                    type='text',\n                    text=f'Image-1: {IMAGE_TOKEN}\\nImage-2: {IMAGE_TOKEN}\\n' +  # noqa E251,E501\n                    description),\n                dict(type='image_url', image_url=dict(max_dynamic_patch=12, url=f'{resource_path}/{PIC_REDPANDA}')),\n                dict(type='image_url', image_url=dict(max_dynamic_patch=12, url=f'{resource_path}/{PIC_PANDA}'))\n            ])\n    ]\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print(f'[caseresult internvl-separate-images-{lang} start]' + json.dumps(response.text, ensure_ascii=False) +\n          f'[caseresult internvl-separate-images-{lang} end]\\n')\n\n    messages.append(dict(role='assistant', content=response.text))\n    messages.append(dict(role='user', content=description))\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print(f'[caseresult internvl-separate-images2-{lang} start]' + json.dumps(response.text, ensure_ascii=False) +\n          f'[caseresult internvl-separate-images2-{lang} end]\\n')\n\n    # video multi-round conversation\n    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):\n        if bound:\n            start, end = bound[0], bound[1]\n        else:\n            start, end 
= -100000, 100000\n        start_idx = max(first_idx, round(start * fps))\n        end_idx = min(round(end * fps), max_frame)\n        seg_size = float(end_idx - start_idx) / num_segments\n        frame_indices = np.array(\n            [int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])\n        return frame_indices\n\n    def load_video(video_path, bound=None, num_segments=32):\n        import cv2\n        cap = cv2.VideoCapture(video_path)\n        if not cap.isOpened():\n            raise ValueError(f'Cannot open video file: {video_path}')\n\n        max_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1\n        fps = cap.get(cv2.CAP_PROP_FPS)\n\n        frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)\n        imgs = []\n\n        for frame_index in frame_indices:\n            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)\n            ret, frame = cap.read()\n            if ret:\n                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n                img = Image.fromarray(rgb_frame).convert('RGB')\n                imgs.append(img)\n\n        cap.release()\n        return imgs\n\n    video_path = resource_path + '/red-panda.mp4'\n    imgs = load_video(video_path, num_segments=8)\n\n    question = ''\n    for i in range(len(imgs)):\n        question = question + f'Frame{i+1}: {IMAGE_TOKEN}\\n'\n\n    if lang == 'cn':\n        question += '视频里有什么动物，它在做什么？'\n    else:\n        question += 'What animals are in the video, and what are they doing?'\n\n    content = [{'type': 'text', 'text': question}]\n    for img in imgs:\n        content.append({\n            'type': 'image_url',\n            'image_url': {\n                'max_dynamic_patch': 1,\n                'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'  # noqa E231\n            }\n        })\n\n    messages = [dict(role='user', content=content)]\n    response = pipe(messages, gen_config=gen_config, 
log_level='INFO', max_log_len=10)\n    print(f'[caseresult internvl-video-{lang} start]' + json.dumps(response.text, ensure_ascii=False) +\n          f'[caseresult internvl-video-{lang} end]\\n')\n\n    messages.append(dict(role='assistant', content=response.text))\n    if lang == 'cn':\n        messages.append(dict(role='user', content='描述视频详情，不要重复'))\n    else:\n        messages.append(dict(role='user', content='Describe this video in detail. Don\\'t repeat.'))\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print(f'[caseresult internvl-video2-{lang} start]' + json.dumps(response.text, ensure_ascii=False) +\n          f'[caseresult internvl-video2-{lang} end]\\n')\n\n\ndef MiniCPM_vl_testcase(pipe, resource_path):\n    # Chat with multiple images\n    messages = [\n        dict(role='user',\n             content=[\n                 dict(type='text', text='Describe the two images in detail.'),\n                 dict(type='image_url', image_url=dict(max_slice_nums=9, url=f'{resource_path}/{PIC_REDPANDA}')),\n                 dict(type='image_url', image_url=dict(max_slice_nums=9, url=f'{resource_path}/{PIC_PANDA}'))\n             ])\n    ]\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult minicpm-combined-images start]' + json.dumps(response.text, ensure_ascii=False) +\n          '[caseresult minicpm-combined-images end]\\n')\n\n    messages.append(dict(role='assistant', content=response.text))\n    messages.append(dict(role='user', content=DESC))\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult minicpm-combined-images2 start]' + json.dumps(response.text, ensure_ascii=False) +\n          '[caseresult minicpm-combined-images2 end]\\n')\n\n    # In-context few-shot learning\n    question = 'production date'\n    messages = [\n        dict(role='user',\n             content=[\n                 
dict(type='text', text=question),\n                 dict(type='image_url', image_url=dict(url=f'{resource_path}/data1.jpeg')),\n             ]),\n        dict(role='assistant', content='2021.08.29'),\n        dict(role='user',\n             content=[\n                 dict(type='text', text=question),\n                 dict(type='image_url', image_url=dict(url=f'{resource_path}/data2.jpeg')),\n             ]),\n        dict(role='assistant', content='1999.05.15'),\n        dict(role='user',\n             content=[\n                 dict(type='text', text=question),\n                 dict(type='image_url', image_url=dict(url=f'{resource_path}/data3.jpeg')),\n             ])\n    ]\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult minicpm-fewshot start]' + json.dumps(response.text, ensure_ascii=False) +\n          '[caseresult minicpm-fewshot end]\\n')\n\n    # Chat with video\n    MAX_NUM_FRAMES = 64  # reduce this if CUDA runs out of memory\n\n    def encode_video(video_path):\n\n        def uniform_sample(seq, n):\n            # pick n evenly spaced elements of seq, taking the centre of each bucket\n            gap = len(seq) / n\n            idxs = [int(i * gap + gap / 2) for i in range(n)]\n            return [seq[i] for i in idxs]\n\n        import cv2\n        cap = cv2.VideoCapture(video_path)\n        if not cap.isOpened():\n            raise ValueError(f'Cannot open video file: {video_path}')\n\n        fps = cap.get(cv2.CAP_PROP_FPS)\n        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n\n        # sample about one frame per second; max(1, ...) avoids a zero step when FPS is unavailable\n        sample_step = max(1, round(fps))\n        frame_idx = list(range(0, total_frames, sample_step))\n        if len(frame_idx) > MAX_NUM_FRAMES:\n            frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)\n\n        frames = []\n        for idx in frame_idx:\n            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)\n            ret, frame = cap.read()\n            if ret:\n                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
\n                frames.append(Image.fromarray(rgb_frame.astype('uint8')).convert('RGB'))\n\n        cap.release()\n        print('num frames:', len(frames))\n        return frames\n\n    video_path = resource_path + '/red-panda.mp4'\n    frames = encode_video(video_path)\n    question = 'What animals are in the video, and what are they doing?'\n\n    content = [dict(type='text', text=question)]\n    for frame in frames:\n        content.append(\n            dict(type='image_url',\n                 image_url=dict(use_image_id=False,\n                                max_slice_nums=2,\n                                url=f'data:image/jpeg;base64,{encode_image_base64(frame)}')))  # noqa E231\n\n    messages = [dict(role='user', content=content)]\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult minicpm-video start]' + json.dumps(response.text, ensure_ascii=False) +\n          '[caseresult minicpm-video end]\\n')\n\n\ndef Qwen_vl_testcase(pipe, resource_path):\n    # multi-image multi-round conversation, combined images\n    messages = [\n        dict(role='user',\n             content=[\n                 dict(type='text', text='Describe the two images in detail.'),\n                 dict(type='image_url', image_url=dict(url=f'{resource_path}/{PIC_BEIJING}')),\n                 dict(type='image_url', image_url=dict(url=f'{resource_path}/{PIC_CHONGQING}'))\n             ])\n    ]\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult qwen-combined-images start]' + json.dumps(response.text, ensure_ascii=False) +\n          '[caseresult qwen-combined-images end]\\n')\n\n    messages.append(dict(role='assistant', content=response.text))\n    messages.append(dict(role='user', content=DESC))\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult qwen-combined-images2 start]' + json.dumps(response.text, 
ensure_ascii=False) +\n          '[caseresult qwen-combined-images2 end]\\n')\n\n    # image resolution for performance boost\n    min_pixels = 64 * 28 * 28\n    max_pixels = 64 * 28 * 28\n    messages = [\n        dict(role='user',\n             content=[\n                 dict(type='text', text='Describe the two images in detail.'),\n                 dict(type='image_url',\n                      image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels,\n                                     url=f'{resource_path}/{PIC_BEIJING}')),\n                 dict(type='image_url',\n                      image_url=dict(min_pixels=min_pixels,\n                                     max_pixels=max_pixels,\n                                     url=f'{resource_path}/{PIC_CHONGQING}'))\n             ])\n    ]\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult qwen-performance-images start]' + json.dumps(response.text, ensure_ascii=False) +\n          '[caseresult qwen-performance-images end]\\n')\n\n    messages.append(dict(role='assistant', content=response.text))\n    messages.append(dict(role='user', content=DESC))\n    response = pipe(messages, gen_config=gen_config, log_level='INFO', max_log_len=10)\n    print('[caseresult qwen-performance-images2 start]' + json.dumps(response.text, ensure_ascii=False) +\n          '[caseresult qwen-performance-images2 end]\\n')\n\n\nif __name__ == '__main__':\n    fire.Fire()\n"
  },
  {
    "path": "autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py",
    "content": "import pytest\nfrom tools.common_case_config import (MODELSCOPE_CONFIG, PYTORCH_LORA_TEST_LLM_GPU1, PYTORCH_LORA_TEST_LLM_GPU2,\n                                      PYTORCH_PR_TEST_LLM_GPU1, PYTORCH_PR_TEST_LLM_GPU2,\n                                      SPECULATIVE_DECODING_PIPELINE_TEST_LLM)\nfrom utils.config_utils import get_func_config_list, get_workerid\nfrom utils.pipeline_chat import run_pipeline_llm_test\n\nBACKEND = 'pytorch'\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}))\ndef test_pipeline_chat_tp1(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}))\ndef test_pipeline_chat_tp2(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_4\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}))\ndef test_pipeline_chat_tp4(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_8\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}))\ndef test_pipeline_chat_tp8(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_16\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', 
get_func_config_list(BACKEND, {'tp': 16}))\ndef test_pipeline_chat_tp16(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, extra={'enable-prefix-caching': None}))\ndef test_pipeline_chat_pytorch_prefix_cache_tp2(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU2)\ndef test_hf_pytorch_chat_pr_tp2(config, run_config, common_case_config, worker_id):\n    worker_id = 'gw' + str(3 + get_workerid(worker_id))\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU1)\ndef test_hf_pytorch_chat_pr_tp1(config, run_config, common_case_config, worker_id):\n    worker_id = 'gw' + str(6 + get_workerid(worker_id))\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND])\ndef test_modelscope_pipeline_chat_tp1(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_pipeline_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU1)\ndef test_pytorch_chat_with_lora_tp1(config, 
run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU2)\ndef test_pytorch_chat_with_lora_tp2(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize(\n    'run_config', [item for item in SPECULATIVE_DECODING_PIPELINE_TEST_LLM if item['parallel_config'].get('tp') == 1])\ndef test_pipeline_chat_speculative_decoding_tp1(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_pipeline_llm_test(config, run_config, case_config, worker_id)\n"
  },
  {
    "path": "autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py",
    "content": "import pytest\nfrom utils.config_utils import get_func_config_list\nfrom utils.pipeline_chat import run_pipeline_mllm_test\n\nBACKEND = 'pytorch'\n\n\ndef get_models(parallel_config):\n    return get_func_config_list(BACKEND, parallel_config, model_type='vl_model', extra={'session_len': 8192})\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.parametrize('run_config', get_models({'tp': 1}))\ndef test_pipeline_chat_tp1(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_models({'tp': 2}))\ndef test_pipeline_chat_tp2(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize('run_config', get_models({'tp': 4}))\ndef test_pipeline_chat_tp4(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_8\n@pytest.mark.parametrize('run_config', get_models({'tp': 8}))\ndef test_pipeline_chat_tp8(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_16\n@pytest.mark.parametrize('run_config', get_models({'tp': 16}))\ndef test_pipeline_chat_tp16(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n"
  },
  {
    "path": "autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py",
    "content": "import pytest\nfrom tools.common_case_config import (MODELSCOPE_CONFIG, TURBOMIND_FALLBACK_TEST_LLM_GPU1,\n                                      TURBOMIND_FALLBACK_TEST_LLM_GPU2, TURBOMIND_PR_TEST_LLM_GPU1,\n                                      TURBOMIND_PR_TEST_LLM_GPU2)\nfrom utils.config_utils import get_func_config_list, get_workerid\nfrom utils.pipeline_chat import run_pipeline_llm_test\n\nBACKEND = 'turbomind'\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}))\ndef test_pipeline_chat_tp1(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}))\ndef test_pipeline_chat_tp2(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}))\ndef test_pipeline_chat_tp4(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_8\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}))\ndef test_pipeline_chat_tp8(config, run_config, common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, extra={'enable-prefix-caching': None}))\ndef test_pipeline_chat_prefix_cache_tp2(config, run_config, 
common_case_config, worker_id):\n    run_pipeline_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU1)\ndef test_pipeline_chat_fallback_backend_tp1(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_pipeline_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU2)\ndef test_pipeline_chat_fallback_backend_tp2(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_pipeline_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU2)\ndef test_pipeline_chat_pr_tp2(config, run_config, common_case_config, worker_id):\n    worker_id = 'gw' + str(3 + get_workerid(worker_id))\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_pipeline_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU1)\ndef test_pipeline_chat_pr_tp1(config, run_config, common_case_config, worker_id):\n    worker_id = 'gw' + str(6 + get_workerid(worker_id))\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_pipeline_llm_test(config, run_config, case_config, 
worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND])\ndef test_modelscope_pipeline_chat_tp1(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_pipeline_llm_test(config, run_config, case_config, worker_id)\n"
  },
  {
    "path": "autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py",
    "content": "import pytest\nfrom tools.common_case_config import (TURBOMIND_FALLBACK_TEST_MLLM_GPU1, TURBOMIND_PR_TEST_MLLM_GPU1,\n                                      TURBOMIND_PR_TEST_MLLM_GPU2)\nfrom utils.config_utils import get_func_config_list, get_workerid\nfrom utils.pipeline_chat import run_pipeline_mllm_test\n\nBACKEND = 'turbomind'\n\n\ndef get_models(parallel_config):\n    return get_func_config_list(BACKEND, parallel_config, model_type='vl_model', extra={'session_len': 8192})\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.parametrize('run_config', get_models({'tp': 1}))\ndef test_pipeline_chat_tp1(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_models({'tp': 2}))\ndef test_pipeline_chat_tp2(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize('run_config', get_models({'tp': 4}))\ndef test_pipeline_chat_tp4(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_8\n@pytest.mark.parametrize('run_config', get_models({'tp': 8}))\ndef test_pipeline_chat_tp8(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_16\n@pytest.mark.parametrize('run_config', get_models({'tp': 16}))\ndef test_pipeline_chat_tp16(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.other\n@pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_MLLM_GPU1)\ndef test_pipeline_chat_fallback_backend_tp1(config, run_config, worker_id):\n    run_pipeline_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.other\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_MLLM_GPU1)\ndef test_pipeline_pr_test(config, run_config, worker_id):
\n    worker_id = 'gw' + str(6 + get_workerid(worker_id))\n    run_pipeline_mllm_test(config, run_config, worker_id, is_smoke=True)\n\n\n@pytest.mark.gpu_num_2\n@pytest.mark.other\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_MLLM_GPU2)\ndef test_pipeline_pr_tp2_test(config, run_config, worker_id):\n    worker_id = 'gw' + str(3 + get_workerid(worker_id))\n    run_pipeline_mllm_test(config, run_config, worker_id, is_smoke=True)\n"
  },
  {
    "path": "autotest/tools/quantization/test_quantization_awq.py",
    "content": "import os\n\nimport allure\nimport pytest\nfrom utils.config_utils import get_cuda_prefix_by_workerid, get_quantization_model_list\nfrom utils.quantization_utils import quantization\n\n\n@pytest.mark.order(3)\n@pytest.mark.test_3090\n@pytest.mark.timeout(900)\n@pytest.mark.parametrize('model', get_quantization_model_list('awq'))\ndef test_quantization_awq(config, model, worker_id):\n    quantization_type = 'awq'\n    quantization_all(config, model + '-inner-4bits', model, quantization_type,\n                     get_cuda_prefix_by_workerid(worker_id, {'tp': 1}))\n\n\n@pytest.mark.order(3)\n@pytest.mark.timeout(900)\n@pytest.mark.parametrize('model', get_quantization_model_list('gptq'))\ndef test_quantization_gptq(config, model, worker_id):\n    quantization_type = 'gptq'\n    quantization_all(config, model + '-inner-gptq', model, quantization_type,\n                     get_cuda_prefix_by_workerid(worker_id, {'tp': 1}))\n\n\n@pytest.mark.order(3)\n@pytest.mark.pr_test\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.timeout(900)\n@pytest.mark.parametrize('model', ['Qwen/Qwen3-0.6B'])\ndef test_quantization_awq_pr(config, model):\n    quantization_type = 'awq'\n    quantization_all(config, model + '-inner-4bits', model, quantization_type, cuda_prefix='CUDA_VISIBLE_DEVICES=6')\n\n\ndef quantization_all(config, quantization_model_name, origin_model_name, quantization_type, cuda_prefix: str = ''):\n    result, msg = quantization(config, quantization_model_name, origin_model_name, quantization_type, cuda_prefix)\n    log_path = config.get('log_path')\n    quantization_log = os.path.join(\n        log_path, '_'.join(['quantization', quantization_type,\n                            quantization_model_name.split('/')[1]]) + '.log')\n\n    allure.attach.file(quantization_log, name=quantization_log, attachment_type=allure.attachment_type.TEXT)\n    assert result, msg\n"
  },
  {
    "path": "autotest/tools/quantization/test_quantization_w8a8.py",
    "content": "import os\n\nimport allure\nimport pytest\nfrom utils.config_utils import get_cuda_prefix_by_workerid, get_quantization_model_list\nfrom utils.quantization_utils import quantization\n\n\n@pytest.mark.order(2)\n@pytest.mark.quantization_w8a8\n@pytest.mark.timeout(900)\n@pytest.mark.parametrize('model', get_quantization_model_list('w8a8'))\ndef test_quantization_w8a8(config, model, worker_id):\n    quantization_w8a8(config, model + '-inner-w8a8', model, get_cuda_prefix_by_workerid(worker_id, {'tp': 1}))\n\n\ndef quantization_w8a8(config, quantization_model_name, origin_model_name, cuda_prefix):\n    quantization_type = 'w8a8'\n    result, msg = quantization(config, quantization_model_name, origin_model_name, quantization_type, cuda_prefix)\n    log_path = config.get('log_path')\n    quantization_log = os.path.join(\n        log_path, '_'.join(['quantization', quantization_type,\n                            quantization_model_name.split('/')[1]]) + '.log')\n\n    allure.attach.file(quantization_log, name=quantization_log, attachment_type=allure.attachment_type.TEXT)\n    assert result, msg\n"
  },
  {
    "path": "autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py",
    "content": "import time\n\nimport pytest\nfrom tools.common_case_config import (MODELSCOPE_CONFIG, PYTORCH_LORA_TEST_LLM_GPU1, PYTORCH_LORA_TEST_LLM_GPU2,\n                                      PYTORCH_PR_TEST_LLM_GPU1, PYTORCH_PR_TEST_LLM_GPU2, REASONING_TEST_LLM,\n                                      SPECULATIVE_DECODING_RESTFUL_TEST_LLM, TOOLCALL_TEST_LLM)\nfrom utils.config_utils import get_case_str_by_config, get_func_config_list, get_workerid\nfrom utils.constant import PROXY_PORT\nfrom utils.proxy_distributed_utils import ApiServerPerTest, proxy_worker_node_wait\nfrom utils.ray_distributed_utils import ray_worker_node_wait\nfrom utils.run_restful_chat import run_all_step, run_llm_test, run_reasoning_case, run_tools_case\n\nBACKEND = 'pytorch'\n\n\ndef _run_ray_distributed_test(\n        config,\n        run_config,\n        common_case_config,\n        manager=None,  # shared Ray manager supplied by the test fixture\n):\n    \"\"\"Universal distributed test executor (using shared Ray cluster)\"\"\"\n    assert manager is not None, 'Manager instance must be provided'\n\n    if manager.is_master:\n        # Start API Server for current model (master node starts/stops, worker nodes verify)\n        manager.start_lmdeploy_api_server(config=config, run_config=run_config)\n\n        try:\n            case_name = get_case_str_by_config(run_config)\n            run_all_step(config.get('log_path'), case_name, common_case_config, port=PROXY_PORT)\n\n        finally:\n            # Clean up API Server for current model (worker nodes skip)\n            manager.cleanup(force=False)\n    else:\n        time.sleep(10)\n        ray_worker_node_wait(manager, timeout_minutes=4880)\n\n\ndef _run_proxy_distributed_test(\n        config,\n        run_config,\n        common_case_config,\n        manager=None,  # shared proxy manager supplied by the test fixture\n):\n    \"\"\"Universal distributed test executor (using shared proxy manager)\"\"\"\n    assert manager is not None, 'Manager instance must be provided'
\n\n    api_server = ApiServerPerTest(proxy_manager=manager, config=config, run_config=run_config)\n    api_server.start()\n\n    try:\n\n        if manager.is_master:\n            api_server.wait_until_ready()\n            case_name = get_case_str_by_config(run_config)\n            run_all_step(config.get('log_path'), case_name, common_case_config, port=PROXY_PORT)\n\n        else:\n            print(f'⏸️ Worker node {manager.node_rank} waiting for master to complete test...')\n            proxy_worker_node_wait(manager, timeout_minutes=4880)\n    finally:\n        api_server.cleanup()\n        if manager.is_master:\n            time.sleep(1)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}))\ndef test_restful_chat_tp1(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}))\ndef test_restful_chat_tp2(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_4\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}))\ndef test_restful_chat_tp4(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_8\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}))\ndef test_restful_chat_tp8(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, 
common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_16\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16}))\ndef test_restful_chat_tp16(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.restful_api_pytorch\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_distributed_tp16\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16}))\ndef test_restful_chat_distributed_tp16(shared_ray_manager, config, run_config, common_case_config, worker_id):\n    _run_ray_distributed_test(config=config,\n                              run_config=run_config,\n                              common_case_config=common_case_config,\n                              manager=shared_ray_manager)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.restful_api_pytorch\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_distributed_dpep16\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'dp': 16, 'ep': 16}))\ndef test_restful_chat_distributed_dpep16(shared_proxy_manager, config, run_config, common_case_config, worker_id):\n    _run_proxy_distributed_test(config=config,\n                                run_config=run_config,\n                                common_case_config=common_case_config,\n                                manager=shared_proxy_manager)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.test_ascend\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, extra={'enable-prefix-caching': None}))\ndef test_restful_chat_pytorch_prefix_cache_tp2(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, 
worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU2)\ndef test_hf_pytorch_chat_pr_tp2(config, run_config, common_case_config, worker_id):\n    worker_id = 'gw' + str(3 + get_workerid(worker_id))\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', PYTORCH_PR_TEST_LLM_GPU1)\ndef test_hf_pytorch_chat_pr_tp1(config, run_config, common_case_config, worker_id):\n    worker_id = 'gw' + str(6 + get_workerid(worker_id))\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND])\ndef test_modelscope_restful_chat_tp1(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU1)\ndef test_pytorch_chat_with_lora_tp1(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', PYTORCH_LORA_TEST_LLM_GPU2)\ndef test_pytorch_chat_with_lora_tp2(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in REASONING_TEST_LLM if 
item['backend'] == BACKEND and item['parallel_config'].get('tp') == 1])\ndef test_restful_chat_reasoning_tp1(config, run_config, worker_id):\n    run_reasoning_case(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in REASONING_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 2])\ndef test_restful_chat_reasoning_tp2(config, run_config, worker_id):\n    run_reasoning_case(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 1])\ndef test_restful_chat_tools_tp1(config, run_config, worker_id):\n    run_tools_case(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 2])\ndef test_restful_chat_tools_tp2(config, run_config, worker_id):\n    run_tools_case(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 4])\ndef test_restful_chat_tools_tp4(config, run_config, worker_id):\n    run_tools_case(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize(\n    'run_config', [item for item in SPECULATIVE_DECODING_RESTFUL_TEST_LLM if item['parallel_config'].get('tp') == 1])\ndef 
test_restful_chat_speculative_decoding_tp1(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_distributed_tp16\n@pytest.mark.parametrize(\n    'run_config', [item for item in SPECULATIVE_DECODING_RESTFUL_TEST_LLM if item['parallel_config'].get('tp') == 16])\ndef test_restful_chat_speculative_decoding_tp16(shared_ray_manager, config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    _run_ray_distributed_test(config=config,\n                              run_config=run_config,\n                              common_case_config=case_config,\n                              manager=shared_ray_manager)\n"
  },
  {
    "path": "autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py",
    "content": "import pytest\nfrom utils.config_utils import get_func_config_list\nfrom utils.run_restful_chat import run_mllm_test\n\nBACKEND = 'pytorch'\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}, model_type='vl_model'))\ndef test_restful_chat_tp1(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, model_type='vl_model'))\ndef test_restful_chat_tp2(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}, model_type='vl_model'))\ndef test_restful_chat_tp4(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_8\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}, model_type='vl_model'))\ndef test_restful_chat_tp8(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_16\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16}, model_type='vl_model'))\ndef test_restful_chat_tp16(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n"
  },
  {
    "path": "autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py",
    "content": "import pytest\nfrom tools.common_case_config import (MODELSCOPE_CONFIG, REASONING_TEST_LLM, TOOLCALL_TEST_LLM,\n                                      TURBOMIND_FALLBACK_TEST_LLM_GPU1, TURBOMIND_FALLBACK_TEST_LLM_GPU2,\n                                      TURBOMIND_LOGPROBS_TEST_LLM_GPU2, TURBOMIND_PR_TEST_LLM_GPU1,\n                                      TURBOMIND_PR_TEST_LLM_GPU2)\nfrom utils.config_utils import get_func_config_list, get_workerid\nfrom utils.run_restful_chat import run_llm_test, run_logprob_test, run_reasoning_case, run_tools_case\n\nBACKEND = 'turbomind'\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}))\ndef test_restful_chat_tp1(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}))\ndef test_restful_chat_tp2(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}))\ndef test_restful_chat_tp4(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_8\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}))\ndef test_restful_chat_tp8(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 
2}, extra={'enable-prefix-caching': None}))\ndef test_restful_chat_prefix_cache_tp2(config, run_config, common_case_config, worker_id):\n    run_llm_test(config, run_config, common_case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU1)\ndef test_restful_chat_fallback_backend_tp1(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_LLM_GPU2)\ndef test_restful_chat_fallback_backend_tp2(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_2\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU2)\ndef test_restful_chat_pr_tp2(config, run_config, common_case_config, worker_id):\n    worker_id = 'gw' + str(3 + get_workerid(worker_id))\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_PR_TEST_LLM_GPU1)\ndef test_restful_chat_pr_tp1(config, run_config, common_case_config, worker_id):\n    worker_id = 'gw' + str(6 + get_workerid(worker_id))\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_llm_test(config, run_config, case_config, 
worker_id)\n\n\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_2\n@pytest.mark.pr_test\n@pytest.mark.parametrize('run_config', TURBOMIND_LOGPROBS_TEST_LLM_GPU2)\ndef test_restful_logprobs(config, run_config, worker_id):\n    worker_id = 'gw' + str(3 + get_workerid(worker_id))\n    run_logprob_test(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize('run_config', [item for item in MODELSCOPE_CONFIG if item['backend'] == BACKEND])\ndef test_modelscope_restful_chat_tp1(config, run_config, common_case_config, worker_id):\n    case_config = {k: v for k, v in common_case_config.items() if k == 'memory_test'}\n    run_llm_test(config, run_config, case_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in REASONING_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 1])\ndef test_restful_chat_reasoning_tp1(config, run_config, worker_id):\n    run_reasoning_case(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in REASONING_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 2])\ndef test_restful_chat_reasoning_tp2(config, run_config, worker_id):\n    run_reasoning_case(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_1\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 1])\ndef test_restful_chat_tools_tp1(config, run_config, worker_id):\n    run_tools_case(config, run_config, 
worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 2])\ndef test_restful_chat_tools_tp2(config, run_config, worker_id):\n    run_tools_case(config, run_config, worker_id)\n\n\n@pytest.mark.usefixtures('common_case_config')\n@pytest.mark.flaky(reruns=0)\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize(\n    'run_config',\n    [item for item in TOOLCALL_TEST_LLM if item['backend'] == BACKEND and item['parallel_config'].get('tp') == 4])\ndef test_restful_chat_tools_tp4(config, run_config, worker_id):\n    run_tools_case(config, run_config, worker_id)\n"
  },
  {
    "path": "autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py",
    "content": "import pytest\nfrom tools.common_case_config import TURBOMIND_FALLBACK_TEST_MLLM_GPU1\nfrom utils.config_utils import get_func_config_list\nfrom utils.run_restful_chat import run_mllm_test\n\nBACKEND = 'turbomind'\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.test_3090\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 1}, model_type='vl_model'))\ndef test_restful_chat_tp1(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_2\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 2}, model_type='vl_model'))\ndef test_restful_chat_tp2(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_4\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 4}, model_type='vl_model'))\ndef test_restful_chat_tp4(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_8\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 8}, model_type='vl_model'))\ndef test_restful_chat_tp8(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_16\n@pytest.mark.parametrize('run_config', get_func_config_list(BACKEND, {'tp': 16}, model_type='vl_model'))\ndef test_restful_chat_tp16(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n\n\n@pytest.mark.gpu_num_1\n@pytest.mark.other\n@pytest.mark.parametrize('run_config', TURBOMIND_FALLBACK_TEST_MLLM_GPU1)\ndef test_restful_chat_fallback_backend_tp1(config, run_config, worker_id):\n    run_mllm_test(config, run_config, worker_id)\n"
  },
  {
    "path": "autotest/utils/benchmark_utils.py",
    "content": "import os\nimport time\n\nimport allure\nimport utils.constant as constant\nfrom utils.common_utils import execute_command_with_logging\nfrom utils.config_utils import get_case_str_by_config, get_cli_common_param, get_cuda_prefix_by_workerid, get_workerid\nfrom utils.run_restful_chat import health_check, start_openai_service, terminate_restful_api\n\n\ndef throughput_test(config, run_config, worker_id: str = '', is_smoke: bool = False):\n    model = run_config.get('model')\n    model_path = os.path.join(config.get('model_path'), model)\n    dataset_path = config.get('dataset_path')\n\n    case_name = get_case_str_by_config(run_config)\n    benchmark_path = os.path.join(config.get('benchmark_path'), 'throughput')\n    work_dir = os.path.join(benchmark_path, f'wk_{case_name}')\n    os.makedirs(work_dir, exist_ok=True)\n\n    max_cache_entry = get_max_cache_entry(model, run_config.get('backend'))\n    if max_cache_entry is not None:\n        if 'extra_params' not in run_config:\n            run_config['extra_params'] = {}\n        run_config['extra_params']['cache-max-entry-count'] = max_cache_entry\n\n    cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config'))\n\n    command = f'{cuda_prefix} python3 benchmark/profile_throughput.py {dataset_path} {model_path} {get_cli_common_param(run_config)}'  # noqa\n\n    if is_smoke:\n        num_prompts = '--num-prompts 100'\n    else:\n        num_prompts = '--num-prompts 5000'\n\n    env = os.environ.copy()\n    env.update(run_config.get('env', {}))\n\n    for batch in [128, 256]:\n        csv_path = os.path.join(work_dir, f'{batch}.csv')\n        timestamp = time.strftime('%Y%m%d_%H%M%S')\n        benchmark_log = os.path.join(benchmark_path, f'log_{case_name}_{batch}_{timestamp}.log')\n        cmd = ' '.join([command, '--concurrency', str(batch), num_prompts, '--csv ', csv_path]).strip()\n\n        result, stderr = execute_command_with_logging(cmd, benchmark_log, env=env)\n     
   allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT)\n\n        if result and not os.path.isfile(csv_path):\n            return False, 'result is empty'\n        if not result:\n            return False, stderr\n\n    return True, 'success'\n\n\ndef longtext_throughput_test(config, run_config, worker_id: str = ''):\n    model = run_config.get('model')\n    model_path = os.path.join(config.get('model_path'), model)\n    dataset_path = config.get('dataset_path')\n\n    case_name = get_case_str_by_config(run_config)\n    benchmark_path = os.path.join(config.get('benchmark_path'), 'longtext-throughput')\n    work_dir = os.path.join(benchmark_path, f'wk_{case_name}')\n    os.makedirs(work_dir, exist_ok=True)\n\n    max_cache_entry = get_max_cache_entry(model, run_config.get('backend'))\n    if max_cache_entry is not None:\n        if 'extra_params' not in run_config:\n            run_config['extra_params'] = {}\n        run_config['extra_params']['cache-max-entry-count'] = max_cache_entry\n        run_config['extra_params'].pop('session-len', None)\n\n    cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config'))\n\n    command = f'{cuda_prefix} python3 benchmark/profile_pipeline_api.py {dataset_path} {model_path} {get_cli_common_param(run_config)}'  # noqa\n\n    env = os.environ.copy()\n    env.update(run_config.get('env', {}))\n\n    for input_len, out_len, num_prompts, session_info, concurrency in [(1, 32768, 3, '32k', 3), (1, 65536, 1, '64k', 1),\n                                                                       (65536, 1024, 5, '64k-1k', 5),\n                                                                       (198000, 1024, 1, '198k-1k', 1)]:\n        session_len = input_len + out_len + 1\n        csv_path = os.path.join(work_dir, f'{case_name}_{session_info}.csv')\n        timestamp = time.strftime('%Y%m%d_%H%M%S')\n        benchmark_log = os.path.join(benchmark_path, 
f'log_{case_name}_{session_info}_{timestamp}.log')\n        cmd = ' '.join([\n            command, '--dataset-name random', f'--random-input-len {input_len}', f'--random-output-len {out_len}',\n            f'--num-prompts {num_prompts}', f'--concurrency {concurrency}', '--stream-output',\n            f'--session-len {session_len}', '--random-range-ratio 1', f'--csv {csv_path}'\n        ]).strip()\n\n        result, stderr = execute_command_with_logging(cmd, benchmark_log, timeout=7200, env=env)\n        allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT)\n\n        if result and not os.path.isfile(csv_path):\n            return False, 'result is empty'\n        if not result:\n            return False, stderr\n    return True, 'success'\n\n\ndef restful_test(config, run_config, worker_id: str = '', is_smoke: bool = False, is_mllm: bool = False):\n    max_cache_entry = get_max_cache_entry(run_config.get('model'), run_config.get('backend'))\n    if max_cache_entry is not None:\n        if 'extra_params' not in run_config:\n            run_config['extra_params'] = {}\n        run_config['extra_params']['cache-max-entry-count'] = max_cache_entry\n\n    pid, content = start_openai_service(config, run_config, worker_id)\n    try:\n        if pid > 0:\n            if is_mllm:\n                return mllm_restful_profile(config,\n                                            run_config,\n                                            port=constant.DEFAULT_PORT + get_workerid(worker_id),\n                                            is_smoke=is_smoke)\n            else:\n                return restful_profile(config,\n                                       run_config,\n                                       port=constant.DEFAULT_PORT + get_workerid(worker_id),\n                                       is_smoke=is_smoke)\n        else:\n            assert False, f'Failed to start RESTful API server: {content}'\n    finally:\n        if 
pid > 0:\n            terminate_restful_api(worker_id)\n\n\nBASE_HTTP_URL = f'http://{constant.DEFAULT_SERVER}'\n\n\ndef restful_profile(config, run_config, port, is_smoke: bool = False):\n    model_path = os.path.join(config.get('model_path'), run_config.get('model'))\n    case_name = get_case_str_by_config(run_config)\n    dataset_path = config.get('dataset_path')\n    benchmark_path = os.path.join(config.get('benchmark_path'), 'restful')\n    work_dir = os.path.join(benchmark_path, f'wk_{case_name}')\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    benchmark_log = os.path.join(benchmark_path, f'log_{case_name}_{timestamp}.log')\n    os.makedirs(work_dir, exist_ok=True)\n\n    http_url = f'{BASE_HTTP_URL}:{port}'  # noqa: E231\n    if not health_check(http_url, case_name):\n        return False, 'server not start'\n\n    csv_path = f'{work_dir}/restful.csv'\n\n    command = f'python benchmark/profile_restful_api.py --backend lmdeploy --dataset-name sharegpt --dataset-path {dataset_path} --tokenizer {model_path} --base-url {http_url} --output-file {csv_path}'  # noqa\n    if is_smoke:\n        command += ' --num-prompts 100'\n    else:\n        command += ' --num-prompts 5000'\n\n    result, stderr = execute_command_with_logging(command, benchmark_log)\n    allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT)\n\n    if result and not os.path.isfile(csv_path):\n        return False, 'result is empty'\n    if not result:\n        return False, stderr\n    return True, 'success'\n\n\ndef mllm_restful_profile(config, run_config, port, is_smoke: bool = False):\n    model_path = os.path.join(config.get('model_path'), run_config.get('model'))\n    case_name = get_case_str_by_config(run_config)\n    benchmark_path = os.path.join(config.get('benchmark_path'), 'mllm_restful')\n    work_dir = os.path.join(benchmark_path, f'wk_{case_name}')\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    benchmark_log = 
os.path.join(benchmark_path, f'log_{case_name}_{timestamp}.log')\n    os.makedirs(work_dir, exist_ok=True)\n\n    http_url = f'{BASE_HTTP_URL}:{port}'  # noqa: E231\n    if not health_check(http_url, case_name):\n        return False, 'server not start'\n\n    csv_path = f'{work_dir}/mllm_restful.csv'\n\n    command = f'python benchmark/profile_restful_api.py --backend lmdeploy-chat --dataset-name image --tokenizer {model_path} --model {case_name} --model-path {model_path} --random-input-len 100 --random-output-len 100 --random-range-ratio 1 --image-format jpeg --image-count 1 --image-content random --image-resolution 1024x1024 --base-url {http_url} --output-file {csv_path}'  # noqa\n    if is_smoke:\n        command += ' --num-prompts 100'\n    else:\n        command += ' --num-prompts 1000'\n\n    result, stderr = execute_command_with_logging(command, benchmark_log)\n    allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT)\n\n    if result and not os.path.isfile(csv_path):\n        return False, 'result is empty'\n    if not result:\n        return False, stderr\n    return True, 'success'\n\n\ndef prefixcache_throughput_test(config, run_config, worker_id: str = '', is_smoke: bool = False):\n    model = run_config.get('model')\n    model_path = os.path.join(config.get('model_path'), model)\n    dataset_path = config.get('prefix_dataset_path')\n\n    case_name = get_case_str_by_config(run_config)\n    benchmark_path = os.path.join(config.get('benchmark_path'), 'prefix-throughtput')\n    work_dir = os.path.join(benchmark_path, f'wk_{case_name}')\n    os.makedirs(work_dir, exist_ok=True)\n    max_cache_entry = get_max_cache_entry(model, run_config.get('backend'))\n    if max_cache_entry is not None:\n        if 'extra_params' not in run_config:\n            run_config['extra_params'] = {}\n        run_config['extra_params']['cache-max-entry-count'] = max_cache_entry\n\n    cuda_prefix = 
get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config'))\n\n    run_config_new = run_config.copy()\n    if 'extra_params' not in run_config_new:\n        run_config_new['extra_params'] = {}\n    run_config_new['extra_params'].pop('enable-prefix-caching', None)\n    run_config_new['extra_params']['session-len'] = 32768\n    command = f'{cuda_prefix} python3 benchmark/profile_pipeline_api.py {dataset_path} {model_path} {get_cli_common_param(run_config_new)}'  # noqa\n\n    env = os.environ.copy()\n    env.update(run_config.get('env', {}))\n\n    if is_smoke:\n        test_configs = [(4096, 256, 10, '4k', None)]\n    else:\n        test_configs = [(4096, 256, 100, '4k', None)]\n\n    for enable_prefix_caching in [False, True]:\n        suffix = 'cache' if enable_prefix_caching else 'no_cache'\n\n        for input_len, out_len, num_prompts, session_info, concurrency in test_configs:\n            timestamp = time.strftime('%Y%m%d_%H%M%S')\n            benchmark_log = os.path.join(benchmark_path, f'log_{case_name}_{session_info}_{suffix}_{timestamp}.log')\n            csv_path = os.path.join(work_dir, f'{session_info}_{suffix}.csv')\n\n            # Build a per-run command; reassigning `command` here would accumulate flags across iterations\n            cmd = ' '.join([\n                command, '--dataset-name random', f'--random-input-len {input_len}', f'--random-output-len {out_len}',\n                '--random-range-ratio 1.0', f'--num-prompts {num_prompts}', '--stream-output', f'--csv {csv_path}'\n            ]).strip()\n\n            if enable_prefix_caching:\n                cmd += ' --enable-prefix-caching'\n\n            if concurrency:\n                cmd += f' --concurrency {concurrency}'\n\n            result, stderr = execute_command_with_logging(cmd, benchmark_log, env=env)\n            allure.attach.file(benchmark_log, name=benchmark_log, attachment_type=allure.attachment_type.TEXT)\n\n            if result and not os.path.isfile(csv_path):\n                return False, 'result is empty'\n            if not result:\n            
    return False, stderr\n    return True, 'success'\n\n\ndef get_max_cache_entry(model, backend):\n    if backend == 'pytorch':\n        return 0.8\n    if 'Llama-2' in model:\n        return 0.95\n    elif 'internlm2' in model:\n        return 0.9\n    elif 'Qwen/Qwen3-235B-A22B' == model or 'internlm/Intern-S1' == model:\n        return 0.7\n    else:\n        return None\n"
  },
  {
    "path": "autotest/utils/common_utils.py",
    "content": "import os\nimport subprocess\nimport sys\n\n\ndef execute_command_with_logging(cmd,\n                                 log_file_path: str,\n                                 timeout: int = 3600,\n                                 env=None,\n                                 should_print=True) -> tuple[bool, str]:\n    if env is None:\n        env = os.environ.copy()\n\n    if os.path.isfile(log_file_path):\n        write_type = 'a'\n    else:\n        write_type = 'w'\n    try:\n        result = True\n        with open(log_file_path, write_type, encoding='utf-8') as log_file:\n            start_msg = f'execute command: {cmd}\\n'\n            print(start_msg, end='')\n            log_file.write(start_msg)\n            log_file.flush()\n\n            process = subprocess.run(cmd,\n                                     shell=True,\n                                     text=True,\n                                     encoding='utf-8',\n                                     errors='replace',\n                                     stdout=subprocess.PIPE,\n                                     stderr=subprocess.STDOUT,\n                                     env=env,\n                                     bufsize=1,\n                                     timeout=timeout,\n                                     start_new_session=True)\n\n            if process.stdout:\n                if should_print:\n                    print(process.stdout, end='')\n                log_file.write(process.stdout)\n\n            if process.returncode == 0:\n                result_msg = 'execute command success!\\n'\n            else:\n                result = False\n                result_msg = f'execute command fail: {process.returncode}\\n'\n\n            log_file.write(result_msg)\n\n        return result, result_msg.strip()\n\n    except Exception as e:\n        error_msg = f'execute command fail exception: {str(e)}\\n'\n        print(error_msg, file=sys.stderr, end='')\n\n        
with open(log_file_path, 'a', encoding='utf-8') as log_file:\n            log_file.write(error_msg)\n\n        return False, error_msg.strip()\n"
  },
  {
    "path": "autotest/utils/config_utils.py",
    "content": "import copy\nimport os\nfrom collections import OrderedDict\nfrom typing import Any\n\nimport yaml\n\nfrom lmdeploy.utils import is_bf16_supported\n\nSUFFIX_INNER_AWQ = '-inner-4bits'\nSUFFIX_INNER_GPTQ = '-inner-gptq'\nSUFFIX_INNER_W8A8 = '-inner-w8a8'\n\n\ndef resolve_extra_params(extra_params: dict[str, Any], model_base_path: str) -> None:\n    \"\"\"Resolve relative model paths in extra_params to absolute paths.\n\n    Centralised helper so that every call-site does not need its own\n    ``if key in extra_params …`` guard – adding a new key here is enough.\n    \"\"\"\n    # Keys in extra_params whose string values are relative model paths\n    model_path_keys = ['speculative-draft-model']\n\n    # Flat string-valued keys\n    for key in model_path_keys:\n        if key in extra_params:\n            value = extra_params[key]\n            if value and isinstance(value, str) and not os.path.isabs(value):\n                extra_params[key] = os.path.join(model_base_path, value)\n\n    # Nested speculative_config (pipeline usage)\n    spec_cfg = extra_params.get('speculative_config')\n    if isinstance(spec_cfg, dict) and 'model' in spec_cfg:\n        model = spec_cfg['model']\n        if model and isinstance(model, str) and not os.path.isabs(model):\n            spec_cfg['model'] = os.path.join(model_base_path, model)\n\n\ndef get_func_config_list(backend: str,\n                         parallel_config: dict[str, int],\n                         model_type: str = 'chat_model',\n                         func_type: str = 'func',\n                         extra: dict[str, Any] | None = None) -> list[dict[str, Any]]:\n    \"\"\"Generate all valid running config combinations (communicator + quant\n    policy + model).\n\n    Args:\n        backend: Backend type (turbomind/pytorch)\n        parallel_config: Parallel config for tensor parallel\n        model_type: Model type, default: chat_model\n        func_type: Test func type filter, default: func\n    
    extra: extra config to update in each run config dict\n    Returns:\n        list[dict]: All valid run config dicts\n    \"\"\"\n    config = get_config()\n    device = config.get('device', 'cuda')\n    base_case_list = get_model_list(config, backend, parallel_config, model_type, func_type)\n\n    if extra is None:\n        extra = {}\n\n    run_configs = []\n    dtype = 'float16' if not is_bf16_supported(device) else None\n\n    for communicator in _get_communicator_list(config, backend, parallel_config):\n        for model in base_case_list:\n            for quant_policy in [0, 4, 8]:\n                # temp remove testcase because of issue 3434\n                if 'turbomind' == backend and communicator == 'cuda-ipc' and parallel_config.get(\n                        'tp', 1) > 1 and ('InternVL3' in model or 'InternVL2_5' in model or 'MiniCPM-V-2_6' in model\n                                          or 'InternVL2-Llama3' in model):  # noqa\n                    continue\n                if 'turbomind' == backend and parallel_config.get(\n                        'tp', 1\n                ) > 1 and model_type == 'vl_model' and func_type == 'mllm_evaluate':  # mllm eval with bug when tp > 2\n                    continue\n                # [TM][FATAL] models/llama/LlamaBatch.cc(362): Check failed: r->session.start_flag Mrope doesn't support interactive chat # noqa\n                if ('Qwen2.5-VL' in model or 'Qwen2-VL' in model) and 'turbomind' == backend:\n                    continue\n                # AssertionError: prompts should be a list\n                if 'phi' in model.lower() and model_type == 'vl_model':\n                    continue\n                if not _is_kvint_model(config, backend, model, quant_policy):\n                    continue\n                run_config = {\n                    'model': model,\n                    'backend': backend,\n                    'communicator': communicator,\n                    'quant_policy': quant_policy,\n  
                  'parallel_config': parallel_config,\n                    'extra_params': copy.copy(extra)\n                }\n                if dtype and backend == 'pytorch':\n                    run_config['extra_params']['dtype'] = dtype\n                if device != 'cuda':\n                    run_config['extra_params']['device'] = device\n                run_configs.append(run_config)\n\n    for run_config in run_configs:\n        if 'Qwen3-235B-A22B-Thinking-2507' in run_config['model']:\n            run_config['extra_params']['cache-max-entry-count'] = 0.9\n            run_config['extra_params']['max-batch-size'] = 1024\n\n        if config.get('env_tag', '') in ['3090', '5080']:\n            run_config['extra_params']['cache-max-entry-count'] = 0.5\n\n        if config.get('env_tag', '') in ['a100'] and ('Qwen3-235B-A22B' in run_config['model']\n                                                      or run_config['model'] == 'internlm/Intern-S1'):\n            run_config['extra_params']['cache-max-entry-count'] = 0.6\n\n        if 'sdar' in run_config['model'].lower():\n            run_config['extra_params']['dllm-block-length'] = 4\n            run_config['extra_params']['dllm-denoising-steps'] = 4\n            run_config['extra_params']['dllm-confidence-threshold'] = 0.9\n\n        if 'kimi' in run_config['model'].lower():\n            para_conf = run_config.get('parallel_config', {})\n            if para_conf.get('dp', 0) == 16 and para_conf.get('ep', 0) == 16:\n                run_config['extra_params']['max-batch-size'] = 256\n\n        if 'Intern-S1-Pro-FP8' in run_config['model'] or 'Intern-S1-Pro-BF16' in run_config['model']:\n            if 'Intern-S1-Pro-FP8' in run_config['model']:\n                run_config['extra_params']['model-format'] = 'fp8'\n            para_conf = run_config.get('parallel_config', {})\n            # For dpep16 configuration, add max-prefill-token-num\n            if para_conf.get('dp', 0) == 16 and para_conf.get('ep', 
0) == 16:\n                run_config['extra_params']['max-prefill-token-num'] = 1024\n                run_config['extra_params']['max-batch-size'] = 128\n\n    return run_configs\n\n\ndef get_cli_common_param(run_config: dict[str, Any]) -> str:\n    \"\"\"Generate cli common params string by run config dict.\"\"\"\n    backend = run_config.get('backend')\n    model = run_config.get('model')\n    communicator = run_config.get('communicator')\n    quant_policy = run_config.get('quant_policy')\n    extra_params = run_config.get('extra_params', {})\n    parallel_config = run_config.get('parallel_config', {})\n\n    cli_params = [f'--backend {backend}', f'--communicator {communicator}']\n    # Optional params\n    if quant_policy != 0:\n        cli_params.append(f'--quant-policy {quant_policy}')\n\n    # quant format\n    model_lower = model.lower()\n    if 'w4' in model_lower or '4bits' in model_lower or 'awq' in model_lower:\n        cli_params.append('--model-format awq')\n    if 'gptq' in model_lower:\n        cli_params.append('--model-format gptq')\n\n    # Parallel config\n    for para_key in ('dp', 'ep', 'cp'):\n        if para_key in parallel_config and parallel_config[para_key] > 1:\n            cli_params.append(f'--{para_key} {parallel_config[para_key]}')\n    if 'tp' in parallel_config and parallel_config['tp'] > 1:\n        tp_num = parallel_config['tp']\n        cli_params.append(f'--tp {tp_num}')  # noqa\n\n    # Extra params\n    cli_params.append(get_cli_str(extra_params))\n\n    return ' '.join(cli_params).strip()\n\n\ndef get_cli_str(config: dict[str, Any]) -> str:\n    cli_str = []\n    # Extra params\n    for key, value in config.items():\n        key = key.replace('_', '-')\n        if value is None:\n            cli_str.append(f'--{key}')\n        elif isinstance(value, list):\n            tmp_cli = ' '.join(map(str, value))\n            cli_str.append(f'--{key} {tmp_cli}')\n        elif isinstance(value, dict):\n            tmp_cli = ' 
'.join([f'{k}={v}' for k, v in value.items()])\n            cli_str.append(f'--{key} {tmp_cli}')\n        else:\n            cli_str.append(f'--{key} {value}' if value else f'--{key}')\n\n    return ' '.join(cli_str)\n\n\ndef get_parallel_config(config: dict[str, Any], model_name: str) -> list[dict[str, int]]:\n    \"\"\"Get matched parallel config dict by model name, default tp:1 if no\n    match.\"\"\"\n    result = []\n    base_model = _base_model_name(model_name)\n    parallel_configs = config.get('config', {})\n\n    for conf_key, model_map in parallel_configs.items():\n        if model_map is None:\n            continue\n        if base_model in model_map:\n            conf_value = model_map[base_model]\n            if isinstance(conf_value, dict):\n                result.append(conf_value.copy())\n            elif isinstance(conf_value, int):\n                result.append({conf_key: conf_value})\n\n    return result if result else [{'tp': 1}]\n\n\ndef _extract_models_from_config(config_value: Any) -> list[str]:\n    \"\"\"Extract flat model name list from config value (dict/list supported)\"\"\"\n    models = []\n    if isinstance(config_value, dict):\n        for model_list in config_value.values():\n            if isinstance(model_list, list):\n                models.extend([m for m in model_list if isinstance(m, str)])\n    elif isinstance(config_value, list):\n        models.extend([m for m in config_value if isinstance(m, str)])\n    return models\n\n\ndef get_model_list(config: dict[str, Any],\n                   backend: str,\n                   parallel_config: dict[str, int] | None = None,\n                   model_type: str = 'chat_model',\n                   func_type: str = 'func') -> list[str]:\n    \"\"\"Get filtered model list with quantization extended models by\n    backend/parallel config/model type/func type.\n\n    Args:\n        config: Global system config dict\n        backend: Backend type (turbomind/pytorch)\n        
parallel_config: Parallel filter config\n        model_type: Model type, default: chat_model\n        func_type: Test func type filter, default: func\n    Returns:\n        list[str]: Base models + quantization extended models\n    \"\"\"\n    model_config_key = f'{backend}_{model_type}'\n    all_models = []\n\n    if model_config_key in config:\n        all_models = _extract_models_from_config(config[model_config_key])\n\n    all_models = _filter_by_test_func_type(config, all_models, func_type)\n    all_models = list(OrderedDict.fromkeys(all_models))  # Deduplicate, keep order\n    all_models = [model for model in all_models if is_model_in_list(config, parallel_config, model)]\n\n    extended_models = list(all_models)\n    quantization_config = config.get(f'{backend}_quantization', {})\n\n    # Append quantization models by backend\n    if backend == 'turbomind':\n        _extend_turbomind_quant_models(quantization_config, all_models, extended_models)\n    elif backend == 'pytorch':\n        _extend_pytorch_quant_models(quantization_config, all_models, extended_models)\n\n    return extended_models\n\n\ndef _filter_by_test_func_type(config: dict[str, Any], model_list: list[str], func_type: str) -> list[str]:\n    \"\"\"Filter model list by test function type, return intersection of two\n    model sets.\"\"\"\n    if func_type == 'func':\n        return model_list\n\n    filtered_models = []\n    model_config_key = f'{func_type}_model'\n    if model_config_key in config:\n        filtered_models = _extract_models_from_config(config[model_config_key])\n\n    return list(set(filtered_models) & set(model_list))\n\n\ndef _extend_turbomind_quant_models(quant_config: dict[str, Any], base_models: list[str],\n                                   target_list: list[str]) -> None:\n    \"\"\"Append turbomind quantization models to target list (AWQ 4bits +\n    GPTQ)\"\"\"\n    no_awq_models = quant_config.get('no_awq', [])\n    # Append AWQ 4bits quantization models\n    for 
model_name in base_models:\n        if model_name in target_list and model_name not in no_awq_models and not is_quantization_model(model_name):\n            target_list.append(model_name + SUFFIX_INNER_AWQ)\n    # Append GPTQ quantization models\n    for model_name in quant_config.get('gptq', []):\n        if model_name in target_list:\n            target_list.append(model_name + SUFFIX_INNER_GPTQ)\n\n\ndef _extend_pytorch_quant_models(quant_config: dict[str, Any], base_models: list[str], target_list: list[str]) -> None:\n    \"\"\"Append pytorch quantization models to target list (AWQ 4bits + W8A8)\"\"\"\n    # Append AWQ quantization models\n    for model_name in quant_config.get('awq', []):\n        if model_name in target_list:\n            target_list.append(model_name + SUFFIX_INNER_AWQ)\n    # Append W8A8 quantization models\n    for model_name in quant_config.get('w8a8', []):\n        if model_name in target_list:\n            target_list.append(model_name + SUFFIX_INNER_W8A8)\n\n\ndef _is_kvint_model(config: dict[str, Any], backend: str, model: str, quant_policy: int) -> bool:\n    \"\"\"Check if model supports the kv quantization policy; quant_policy=0\n    always returns True.\"\"\"\n    if quant_policy == 0:\n        return True\n    no_kvint_black_list = config.get(f'{backend}_quantization', {}).get(f'no_kvint{quant_policy}', [])\n\n    return _base_model_name(model) not in no_kvint_black_list\n\n\ndef _base_model_name(model: str) -> str:\n    \"\"\"Simplify model name by removing quantization suffix for config\n    matching.\"\"\"\n    return model.replace('-inner-4bits', '').replace('-inner-w8a8', '').replace('-inner-gptq', '')\n\n\ndef get_quantization_model_list(type: str) -> list[str]:\n    \"\"\"Get quantization model list by specified quant type (awq/gptq/w8a8).\"\"\"\n    config = get_config()\n    quant_model_list = []\n\n    if type == 'awq':\n        # Get all turbomind chat/base models & deduplicate\n        turbo_chat = 
_extract_models_from_config(\n            config['turbomind_chat_model']) if 'turbomind_chat_model' in config else []\n        turbo_base = _extract_models_from_config(\n            config['turbomind_base_model']) if 'turbomind_base_model' in config else []\n        all_turbo_models = list(OrderedDict.fromkeys(turbo_chat + turbo_base))\n\n        # Filter turbomind valid awq models\n        no_awq = config.get('turbomind_quantization', {}).get('no_awq', [])\n        quant_model_list = [m for m in all_turbo_models if m not in no_awq and not is_quantization_model(m)]\n\n        # Append pytorch awq models\n        torch_awq = config.get('pytorch_quantization', {}).get('awq', [])\n        for model in torch_awq:\n            if model not in quant_model_list:\n                quant_model_list.append(model)\n\n    elif type == 'gptq':\n        quant_model_list = config.get('turbomind_quantization', {}).get(type, [])\n\n    elif type == 'w8a8':\n        quant_model_list = config.get('pytorch_quantization', {}).get(type, [])\n\n    return quant_model_list\n\n\ndef get_config() -> dict[str, Any]:\n    \"\"\"Load yaml config (device-specific via TEST_ENV) & append run id to output paths.\"\"\"\n    # Get device env tag & match config file path\n    env_tag = os.environ.get('TEST_ENV')\n    config_path = f'autotest/config_{env_tag}.yml' if env_tag else 'autotest/config.yml'\n\n    # Fall back to the default config if the device-specific config does not exist\n    if env_tag and not os.path.exists(config_path):\n        config_path = 'autotest/config.yml'\n    # Load yaml config file safely\n    with open(config_path, 'r', encoding='utf-8') as f:\n        config = yaml.load(f.read(), Loader=yaml.SafeLoader)\n\n    # Deep copy config to avoid modifying raw data, then append github run id to output paths\n    config_copy = copy.deepcopy(config)\n    run_id = os.environ.get('RUN_ID', 'local_run')\n    config_copy['log_path'] = os.path.join(config_copy['log_path'], str(run_id).replace('/', '_'))\n    
config_copy['eval_path'] = os.path.join(config_copy['eval_path'], str(run_id).replace('/', '_'))\n    config_copy['mllm_eval_path'] = os.path.join(config_copy['mllm_eval_path'], str(run_id).replace('/', '_'))\n    config_copy['benchmark_path'] = os.path.join(config_copy['benchmark_path'], str(run_id).replace('/', '_'))\n    config_copy['server_log_path'] = os.path.join(config_copy['server_log_path'], str(run_id).replace('/', '_'))\n    os.makedirs(config_copy['log_path'], exist_ok=True)\n    os.makedirs(config_copy['eval_path'], exist_ok=True)\n    os.makedirs(config_copy['mllm_eval_path'], exist_ok=True)\n    os.makedirs(config_copy['benchmark_path'], exist_ok=True)\n    os.makedirs(config_copy['server_log_path'], exist_ok=True)\n\n    return config_copy\n\n\ndef get_cuda_prefix_by_workerid(worker_id: str | None, parallel_config: dict[str, int] | None = None) -> str | None:\n    \"\"\"Get cuda/ascend visible devices env prefix by worker id & parallel\n    config.\"\"\"\n    para_conf = parallel_config or {}\n    device_type = os.environ.get('DEVICE', 'cuda')\n\n    tp_num = para_conf.get('tp')\n    if not tp_num:\n        return ''\n\n    cuda_id = get_cuda_id_by_workerid(worker_id, tp_num)\n    if not cuda_id:\n        return ''\n\n    return f'ASCEND_RT_VISIBLE_DEVICES={cuda_id}' if device_type == 'ascend' else f'CUDA_VISIBLE_DEVICES={cuda_id}'\n\n\ndef get_cuda_id_by_workerid(worker_id: str | None, tp_num: int = 1) -> str | None:\n    \"\"\"Get cuda id str by worker id and tp num, return None if invalid worker\n    id.\"\"\"\n    if worker_id is None or 'gw' not in worker_id:\n        return None\n\n    base_id = int(worker_id.replace('gw', ''))\n    cuda_num = base_id * tp_num\n    return ','.join([str(cuda_num + i) for i in range(tp_num)])\n\n\ndef get_workerid(worker_id: str | None) -> int:\n    \"\"\"Parse numeric worker id from worker id str, return 0 if invalid worker\n    id.\"\"\"\n    if worker_id is None or 'gw' not in worker_id:\n        return 0\n\n 
   return int(worker_id.replace('gw', ''))\n\n\ndef is_quantization_model(model: str) -> bool:\n    \"\"\"Check if model name contains quantization related keywords.\"\"\"\n    lower_name = model.lower()\n    return any(key in lower_name for key in ('awq', '4bits', 'w4', 'int4'))\n\n\ndef _get_communicator_list(config: dict[str, Any],\n                           backend: str,\n                           parallel_config: dict[str, int] | None = None) -> list[str]:\n    \"\"\"Get available communicator list by device and parallel config.\"\"\"\n    device = config.get('device', None)\n    # Treat a missing parallel config as empty so the membership checks below do not fail on None\n    parallel_config = parallel_config or {}\n\n    if device == 'ascend':\n        return ['nccl']\n    if backend == 'pytorch':\n        return ['nccl']\n    if ('cp' in parallel_config or 'dp' in parallel_config or 'ep' in parallel_config):\n        return ['nccl']\n    if 'tp' in parallel_config and parallel_config['tp'] == 1:\n        return ['nccl']\n\n    return ['nccl', 'cuda-ipc']\n\n\ndef set_device_env_variable(worker_id: str | None, parallel_config: dict[str, int] | None = None) -> None:\n    \"\"\"Set device environment variable based on the device type.\"\"\"\n    device = os.environ.get('DEVICE', 'cuda')\n\n    tp_num = 1\n    if parallel_config is not None:\n        if isinstance(parallel_config, int):\n            tp_num = parallel_config\n        elif isinstance(parallel_config, dict):\n            tp_num = parallel_config.get('tp', 1)\n\n    if device == 'ascend':\n        device_id = get_cuda_id_by_workerid(worker_id, tp_num)\n        if device_id is not None:\n            os.environ['ASCEND_RT_VISIBLE_DEVICES'] = device_id\n    else:\n        cuda_id = get_cuda_id_by_workerid(worker_id, tp_num)\n        if cuda_id is not None:\n            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_id\n\n\ndef unset_device_env_variable() -> None:\n    \"\"\"Remove the visible-devices env variable for the current device type.\"\"\"\n    device_type = os.environ.get('DEVICE', 'cuda')\n    if device_type == 'ascend':\n        if 'ASCEND_RT_VISIBLE_DEVICES' in os.environ:\n            del 
os.environ['ASCEND_RT_VISIBLE_DEVICES']\n    else:\n        if 'CUDA_VISIBLE_DEVICES' in os.environ:\n            del os.environ['CUDA_VISIBLE_DEVICES']\n\n\ndef is_model_in_list(config: dict[str, Any], parallel_config: dict[str, int], model: str) -> bool:\n    \"\"\"Check if model matches the target parallel config.\"\"\"\n    model_config = get_parallel_config(config, model)\n    return parallel_config in model_config\n\n\ndef get_case_str_by_config(run_config: dict[str, Any], is_simple: bool = True) -> str:\n    \"\"\"Generate case name string by run config dict.\"\"\"\n    model_name = run_config['model']\n    backend_type = run_config['backend']\n    communicator = run_config.get('communicator', 'nccl')\n    quant_policy = run_config.get('quant_policy', 0)\n    parallel_config = run_config.get('parallel_config', {'tp': 1})\n    extra_params = run_config.get('extra_params', {})\n\n    # Sort parallel config items into a fixed string format\n    sorted_items = sorted(parallel_config.items())\n    parallel_str = '_'.join(f'{k}{v}' for k, v in sorted_items)\n    # Keep the last section of the model name, since model names may contain '/'\n    pure_model_name = model_name.split('/')[-1].replace('_', '-')\n    extra_params_case = ''\n    if not is_simple:\n        for k, v in extra_params.items():\n            # str() so that int/float/None values (e.g. session_len=1024) do not break len()\n            if len(str(v)) > 10:\n                extra_params_case += f'_{k}'.replace('_', '-').replace('/', '-').replace('.', '-')\n            else:\n                extra_params_case += f'_{k}{v}'.replace('_', '-').replace('/', '-').replace('.', '-')\n\n    return f'{backend_type}_{pure_model_name}_{communicator}_{parallel_str}_{quant_policy}{extra_params_case}'\n\n\ndef parse_config_by_case(case_str: str) -> dict[str, Any]:\n    \"\"\"Parse run config dict from a simple case name string built by\n    get_case_str_by_config.\"\"\"\n    case_parts = case_str.split('_')\n    # Parse fixed fields & reassemble dynamic parallel config\n    backend = case_parts[0]\n    model = case_parts[1]\n    
communicator = case_parts[2]\n    quant_policy = int(case_parts[-1])\n    parallel_parts = case_parts[3:-1]\n\n    # Convert parallel str to dict, e.g. ['tp1', 'pp2'] -> {'tp': 1, 'pp': 2}\n    parallel_config = {}\n    for part in parallel_parts:\n        for idx, char in enumerate(part):\n            if char.isdigit():\n                k = part[:idx]\n                v = int(part[idx:])\n                parallel_config[k] = v\n                break\n\n    return {\n        'backend': backend,\n        'model': model,\n        'communicator': communicator,\n        'parallel_config': parallel_config,\n        'quant_policy': quant_policy\n    }\n\n\ndef test_config():\n    # get_config() selects the config file by TEST_ENV, not DEVICE\n    os.environ['TEST_ENV'] = 'test'\n    config = get_config()\n    assert 'model_path' in config.keys()\n    assert 'resource_path' in config.keys()\n    assert 'log_path' in config.keys()\n    assert 'server_log_path' in config.keys()\n    assert 'eval_path' in config.keys()\n    assert 'mllm_eval_path' in config.keys()\n    assert 'benchmark_path' in config.keys()\n    assert 'dataset_path' in config.keys()\n    assert 'prefix_dataset_path' in config.keys()\n    assert 'env_tag' in config.keys()\n    assert 'config' in config.keys()\n    assert 'tp' in config.get('config')\n\n    assert is_model_in_list(config, parallel_config={'tp': 1}, model='test/test_tp1')\n    assert is_model_in_list(config, parallel_config={'tp': 2}, model='test/test_tp1') is False\n    assert is_model_in_list(config, parallel_config={'ep': 1}, model='test/test_tp1') is False\n    assert is_model_in_list(config, parallel_config={'tp': 2}, model='test/test_tp2-inner-4bits')\n    assert is_model_in_list(config, parallel_config={'tp': 2}, model='test/test_tp2-inner-w8a8')\n 
   assert is_model_in_list(config, parallel_config={'tp': 8}, model='test/test_tp8-inner-gptq')\n    assert is_model_in_list(config, parallel_config={'tp': 8}, model='test/test_cp2tp8') is False\n    assert is_model_in_list(config, parallel_config={'tp': 8, 'cp': 2}, model='test/test_cp2tp8')\n    assert is_model_in_list(config, parallel_config={'cp': 2, 'tp': 8}, model='test/test_cp2tp8')\n    assert is_model_in_list(config, parallel_config={'cp': 4, 'tp': 8}, model='test/test_cp2tp8') is False\n    assert is_model_in_list(config, parallel_config={'dp': 8, 'ep': 8}, model='test/test_dpep8')\n    assert is_model_in_list(config, parallel_config={'dp': 4, 'ep': 8}, model='test/test_dpep8') is False\n    assert is_model_in_list(config, parallel_config={'ep': 4, 'dp': 8}, model='test/test_dpep8') is False\n\n    assert _is_kvint_model(config, 'turbomind', 'test/test_tp1-inner-4bits', 8) is False\n    assert _is_kvint_model(config, 'turbomind', 'test/test_tp1-inner-4bits', 4)\n    assert _is_kvint_model(config, 'turbomind', 'any', 0)\n    assert _is_kvint_model(config, 'pytorch', 'test/test_tp1-inner-gptq', 8) is False\n    assert _is_kvint_model(config, 'pytorch', 'test/test_tp1-inner-gptq', 4)\n    assert _is_kvint_model(config, 'pytorch', 'test/test_vl_tp1-inner-gptq', 8) is False\n    assert _is_kvint_model(config, 'pytorch', 'test/test_cp2tp8-inner-w8a8', 4) is False\n    # os.unsetenv() does not update os.environ, so pop the key instead\n    os.environ.pop('TEST_ENV', None)\n\n\ndef test_get_case_str_by_config():\n    run_config = {\n        'model': 'test/test_dpep16',\n        'backend': 'turbomind',\n        'communicator': 'nccl',\n        'quant_policy': 8,\n        'parallel_config': {\n            'dp': 16,\n            'ep': 16\n        }\n    }\n    case_str = get_case_str_by_config(run_config)\n    assert case_str == 'turbomind_test-dpep16_nccl_dp16_ep16_8', case_str\n    run_config_parsed = parse_config_by_case(case_str)\n    assert run_config_parsed['model'] == 'test-dpep16'\n    assert run_config_parsed['backend'] == 'turbomind'\n   
 assert run_config_parsed['communicator'] == 'nccl'\n    assert run_config_parsed['quant_policy'] == 8\n    assert run_config_parsed['parallel_config']['dp'] == 16\n    assert run_config_parsed['parallel_config']['ep'] == 16\n\n\ndef test_cli_common_param():\n    run_config = {\n        'model': 'test/test_dpep16-inner-4bits',\n        'backend': 'turbomind',\n        'communicator': 'nccl',\n        'quant_policy': 8,\n        'parallel_config': {\n            'dp': 16,\n            'ep': 16\n        },\n        'extra_params': {\n            'dtype': 'bfloat16',\n            'device': 'ascend',\n            'enable_prefix_caching': None,\n            'max_batch_size': 2048,\n            'session_len': 8192,\n            'cache_max_entry_count': 0.75,\n            'adapters': {\n                'a': 'lora/2024-01-25_self_dup',\n                'b': 'lora/2024-01-25_self'\n            }\n        }\n    }\n\n    cli_params = get_cli_common_param(run_config)\n    assert cli_params == '--backend turbomind --communicator nccl --quant-policy 8 --model-format awq --dp 16 --ep 16 --dtype bfloat16 --device ascend --enable-prefix-caching --max-batch-size 2048 --session-len 8192 --cache-max-entry-count 0.75 --adapters a=lora/2024-01-25_self_dup b=lora/2024-01-25_self', cli_params  # noqa\n    run_config = {\n        'model': 'test/test_dpep16-inner-4bits',\n        'backend': 'pytorch',\n        'communicator': 'nccl',\n        'quant_policy': 0,\n        'parallel_config': {\n            'tp': 8\n        }\n    }\n\n    cli_params = get_cli_common_param(run_config)\n    assert cli_params == '--backend pytorch --communicator nccl --model-format awq --tp 8', cli_params\n    os.unsetenv('TEST_ENV')\n\n\ndef test_return_info_turbomind():\n    os.environ['TEST_ENV'] = 'test'\n    backend = 'turbomind'\n    func_chat_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func')\n    assert len(func_chat_tp1) == 12, len(func_chat_tp1)\n  
  func_chat_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='chat_model', func_type='func')\n    assert len(func_chat_tp2) == 32, len(func_chat_tp2)\n    func_chat_tp8 = get_func_config_list(backend, parallel_config={'tp': 8}, model_type='chat_model', func_type='func')\n    assert len(func_chat_tp8) == 36, len(func_chat_tp8)\n    func_chat_cptp = get_func_config_list(backend,\n                                          parallel_config={\n                                              'cp': 2,\n                                              'tp': 8\n                                          },\n                                          model_type='chat_model',\n                                          func_type='func')\n    assert len(func_chat_cptp) == 14, len(func_chat_cptp)\n    func_chat_dpep8 = get_func_config_list(backend,\n                                           parallel_config={\n                                               'dp': 8,\n                                               'ep': 8\n                                           },\n                                           model_type='chat_model',\n                                           func_type='func')\n    assert len(func_chat_dpep8) == 6, len(func_chat_dpep8)\n    func_chat_dpep16 = get_func_config_list(backend,\n                                            parallel_config={\n                                                'dp': 16,\n                                                'ep': 16\n                                            },\n                                            model_type='chat_model',\n                                            func_type='func')\n    assert len(func_chat_dpep16) == 0, len(func_chat_dpep16)\n    func_base_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='base_model', func_type='func')\n    assert len(func_base_tp1) == 6, len(func_base_tp1)\n    func_base_tp2 = get_func_config_list(backend, 
parallel_config={'tp': 2}, model_type='base_model', func_type='func')\n    assert len(func_base_tp2) == 4, len(func_base_tp2)\n\n    evaluate_tp1 = get_func_config_list(backend,\n                                        parallel_config={'tp': 1},\n                                        model_type='chat_model',\n                                        func_type='evaluate')\n    assert len(evaluate_tp1) == 6, len(evaluate_tp1)\n    benchmark_tp2 = get_func_config_list(backend,\n                                         parallel_config={'tp': 2},\n                                         model_type='chat_model',\n                                         func_type='benchmark')\n    assert len(benchmark_tp2) == 4, len(benchmark_tp2)\n    longtext_tp8 = get_func_config_list(backend,\n                                        parallel_config={'tp': 8},\n                                        model_type='chat_model',\n                                        func_type='longtext')\n    assert len(longtext_tp8) == 12, len(longtext_tp8)\n    evaluate_cptp = get_func_config_list(backend,\n                                         parallel_config={\n                                             'cp': 2,\n                                             'tp': 8\n                                         },\n                                         model_type='chat_model',\n                                         func_type='evaluate')\n    assert len(evaluate_cptp) == 4, len(evaluate_cptp)\n    benchmark_dpep8 = get_func_config_list(backend,\n                                           parallel_config={\n                                               'dp': 8,\n                                               'ep': 8\n                                           },\n                                           model_type='chat_model',\n                                           func_type='benchmark')\n    assert len(benchmark_dpep8) == 0, len(benchmark_dpep8)\n\n    mllm_benchmark_tp1 = 
get_func_config_list(backend,\n                                              parallel_config={'tp': 1},\n                                              model_type='chat_model',\n                                              func_type='mllm_benchmark')\n    assert len(mllm_benchmark_tp1) == 6, len(mllm_benchmark_tp1)\n    mllm_longtext_tp2 = get_func_config_list(backend,\n                                             parallel_config={'tp': 2},\n                                             model_type='chat_model',\n                                             func_type='mllm_longtext')\n    assert len(mllm_longtext_tp2) == 0, len(mllm_longtext_tp2)\n    mllm_evaluate_tp8 = get_func_config_list(backend,\n                                             parallel_config={'tp': 8},\n                                             model_type='chat_model',\n                                             func_type='mllm_evaluate')\n    assert len(mllm_evaluate_tp8) == 12, len(mllm_evaluate_tp8)\n    mllm_evaluate_dpep16 = get_func_config_list(backend,\n                                                parallel_config={\n                                                    'dp': 16,\n                                                    'ep': 16\n                                                },\n                                                model_type='chat_model',\n                                                func_type='evaluate')\n    assert len(mllm_evaluate_dpep16) == 0, len(mllm_evaluate_dpep16)\n    mllm_benchmark_cptp = get_func_config_list(backend,\n                                               parallel_config={\n                                                   'cp': 2,\n                                                   'tp': 8\n                                               },\n                                               model_type='chat_model',\n                                               func_type='benchmark')\n    assert len(mllm_benchmark_cptp) == 4, 
len(mllm_benchmark_cptp)\n    os.unsetenv('TEST_ENV')\n\n\ndef test_return_info_pytorch():\n    os.environ['TEST_ENV'] = 'test'\n    backend = 'pytorch'\n    func_chat_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func')\n    assert len(func_chat_tp1) == 12, len(func_chat_tp1)\n    func_chat_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='chat_model', func_type='func')\n    assert len(func_chat_tp2) == 19, len(func_chat_tp2)\n    func_chat_tp8 = get_func_config_list(backend, parallel_config={'tp': 8}, model_type='chat_model', func_type='func')\n    assert len(func_chat_tp8) == 9, len(func_chat_tp8)\n    func_chat_cptp = get_func_config_list(backend,\n                                          parallel_config={\n                                              'cp': 2,\n                                              'tp': 8\n                                          },\n                                          model_type='chat_model',\n                                          func_type='func')\n    assert len(func_chat_cptp) == 7, len(func_chat_cptp)\n    func_chat_dpep8 = get_func_config_list(backend,\n                                           parallel_config={\n                                               'dp': 8,\n                                               'ep': 8\n                                           },\n                                           model_type='chat_model',\n                                           func_type='func')\n    assert len(func_chat_dpep8) == 8, len(func_chat_dpep8)\n    func_chat_dpep16 = get_func_config_list(backend,\n                                            parallel_config={\n                                                'dp': 16,\n                                                'ep': 16\n                                            },\n                                            model_type='chat_model',\n                                      
      func_type='func')\n    assert len(func_chat_dpep16) == 6, len(func_chat_dpep16)\n    func_base_tp1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='base_model', func_type='func')\n    assert len(func_base_tp1) == 7, len(func_base_tp1)\n    func_base_tp2 = get_func_config_list(backend, parallel_config={'tp': 2}, model_type='base_model', func_type='func')\n    assert len(func_base_tp2) == 4, len(func_base_tp2)\n\n    evaluate_tp1 = get_func_config_list(backend,\n                                        parallel_config={'tp': 1},\n                                        model_type='chat_model',\n                                        func_type='evaluate')\n    assert len(evaluate_tp1) == 7, len(evaluate_tp1)\n    benchmark_tp2 = get_func_config_list(backend,\n                                         parallel_config={'tp': 2},\n                                         model_type='chat_model',\n                                         func_type='benchmark')\n    assert len(benchmark_tp2) == 3, len(benchmark_tp2)\n    longtext_tp8 = get_func_config_list(backend,\n                                        parallel_config={'tp': 8},\n                                        model_type='chat_model',\n                                        func_type='longtext')\n    assert len(longtext_tp8) == 3, len(longtext_tp8)\n    evaluate_cptp = get_func_config_list(backend,\n                                         parallel_config={\n                                             'cp': 2,\n                                             'tp': 8\n                                         },\n                                         model_type='chat_model',\n                                         func_type='evaluate')\n    assert len(evaluate_cptp) == 2, len(evaluate_cptp)\n    benchmark_dpep8 = get_func_config_list(backend,\n                                           parallel_config={\n                                               'dp': 8,\n                       
                        'ep': 8\n                                           },\n                                           model_type='chat_model',\n                                           func_type='benchmark')\n    assert len(benchmark_dpep8) == 2, len(benchmark_dpep8)\n\n    mllm_benchmark_tp1 = get_func_config_list(backend,\n                                              parallel_config={'tp': 1},\n                                              model_type='chat_model',\n                                              func_type='mllm_benchmark')\n    assert len(mllm_benchmark_tp1) == 5, len(mllm_benchmark_tp1)\n    mllm_longtext_tp2 = get_func_config_list(backend,\n                                             parallel_config={'tp': 2},\n                                             model_type='chat_model',\n                                             func_type='mllm_longtext')\n    assert len(mllm_longtext_tp2) == 0, len(mllm_longtext_tp2)\n    mllm_evaluate_tp8 = get_func_config_list(backend,\n                                             parallel_config={'tp': 8},\n                                             model_type='chat_model',\n                                             func_type='mllm_evaluate')\n    assert len(mllm_evaluate_tp8) == 3, len(mllm_evaluate_tp8)\n    mllm_evaluate_dpep16 = get_func_config_list(backend,\n                                                parallel_config={\n                                                    'dp': 16,\n                                                    'ep': 16\n                                                },\n                                                model_type='chat_model',\n                                                func_type='evaluate')\n    assert len(mllm_evaluate_dpep16) == 3, len(mllm_evaluate_dpep16)\n    mllm_benchmark_cptp = get_func_config_list(backend,\n                                               parallel_config={\n                                                   'cp': 2,\n      
                                             'tp': 8\n                                               },\n                                               model_type='chat_model',\n                                               func_type='benchmark')\n    assert len(mllm_benchmark_cptp) == 2, len(mllm_benchmark_cptp)\n    os.unsetenv('TEST_ENV')\n\n\ndef test_run_config():\n    os.environ['TEST_ENV'] = 'test'\n    backend = 'turbomind'\n    run_config1 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func')[0]\n    assert run_config1['model'] == 'test/test_tp1'\n    assert run_config1['backend'] == 'turbomind'\n    assert run_config1['communicator'] == 'nccl'\n    assert run_config1['quant_policy'] == 0\n    assert run_config1['parallel_config'] == {'tp': 1}\n    os.environ['TEST_ENV'] = 'testascend'\n    backend = 'pytorch'\n    run_config2 = get_func_config_list(backend, parallel_config={'tp': 1}, model_type='chat_model', func_type='func')[0]\n    assert run_config2['model'] == 'test/test_tp1'\n    assert run_config2['backend'] == 'pytorch'\n    assert run_config2['communicator'] == 'nccl'\n    assert run_config2['quant_policy'] == 0\n    assert run_config2['parallel_config'] == {'tp': 1}\n    run_config3 = get_func_config_list(backend,\n                                       parallel_config={'tp': 1},\n                                       model_type='chat_model',\n                                       func_type='func',\n                                       extra={\n                                           'speculative_algorithm': 'eagle',\n                                           'session_len': 1024\n                                       })[0]\n    assert run_config3['model'] == 'test/test_tp1'\n    assert run_config3['backend'] == 'pytorch'\n    assert run_config3['communicator'] == 'nccl'\n    assert run_config3['quant_policy'] == 0\n    assert run_config3['parallel_config'] == {'tp': 1}\n    assert 
run_config3['extra_params']['speculative_algorithm'] == 'eagle'\n    assert run_config3['extra_params']['session_len'] == 1024\n    os.unsetenv('TEST_ENV')\n\n\ndef test_get_parallel_config():\n    test = get_parallel_config({}, 'empty')\n    assert test == [{'tp': 1}]\n    test = get_parallel_config(\n        {\n            'config': {\n                'tp': {\n                    'empty': 1\n                },\n                'dp_ep': {\n                    'empty': {\n                        'dp': 1,\n                        'ep': 8\n                    }\n                },\n                'cp_tp': {\n                    'empty': {\n                        'cp': 8,\n                        'tp': 8\n                    }\n                }\n            }\n        }, 'empty')\n    assert test == [{'tp': 1}, {'dp': 1, 'ep': 8}, {'cp': 8, 'tp': 8}]\n\n\nif __name__ == '__main__':\n    test_get_parallel_config()\n    test_cli_common_param()\n    test_run_config()\n    test_get_case_str_by_config()\n    test_return_info_pytorch()\n    test_config()\n    test_return_info_turbomind()\n"
  },
  {
    "path": "autotest/utils/constant.py",
    "content": "import os\n\nDEFAULT_PORT = 23333\nDEFAULT_SERVER = os.getenv('MASTER_ADDR', '127.0.0.1')\nPROXY_PORT = 8000\n\nEVAL_CONFIGS = {\n    'default': {\n        'query_per_second': 4,\n        'max_out_len': 64000,\n        'max_seq_len': 65536,\n        'batch_size': 500,\n        'temperature': 0.6,\n    },\n    'default-32k': {\n        'query_per_second': 4,\n        'max_out_len': 32768,\n        'max_seq_len': 65536,\n        'batch_size': 500,\n        'temperature': 0.6,\n    },\n    'default-2batch': {\n        'query_per_second': 4,\n        'max_out_len': 64000,\n        'max_seq_len': 65536,\n        'batch_size': 2,\n        'temperature': 0.6,\n    },\n    'gpt': {\n        'query_per_second': 4,\n        'max_out_len': 64000,\n        'max_seq_len': 65536,\n        'batch_size': 500,\n        'temperature': 0.6,\n        'openai_extra_kwargs': {\n            'reasoning_effort': 'high',\n        }\n    },\n    'gpt-32k': {\n        'query_per_second': 4,\n        'max_out_len': 32768,\n        'max_seq_len': 65536,\n        'batch_size': 500,\n        'temperature': 0.6,\n        'openai_extra_kwargs': {\n            'reasoning_effort': 'high',\n        }\n    },\n    'gpt-2batch': {\n        'query_per_second': 4,\n        'max_out_len': 64000,\n        'max_seq_len': 65536,\n        'batch_size': 2,\n        'temperature': 0.6,\n        'openai_extra_kwargs': {\n            'reasoning_effort': 'high',\n        }\n    },\n    'sdar': {\n        'query_per_second': 4,\n        'max_out_len': 64000,\n        'max_seq_len': 65536,\n        'batch_size': 500,\n        'temperature': 1.0,\n        'openai_extra_kwargs': {\n            'top_p': 1.0,\n        },\n        'extra_body': {\n            'top_k': 0,\n        }\n    },\n    'sdar-32k': {\n        'query_per_second': 4,\n        'max_out_len': 32768,\n        'max_seq_len': 65536,\n        'batch_size': 500,\n        'temperature': 1.0,\n        'openai_extra_kwargs': {\n            
'top_p': 1.0,\n        },\n        'extra_body': {\n            'top_k': 0,\n        }\n    },\n    'sdar-2batch': {\n        'query_per_second': 4,\n        'max_out_len': 64000,\n        'max_seq_len': 65536,\n        'batch_size': 2,\n        'temperature': 1.0,\n        'openai_extra_kwargs': {\n            'top_p': 1.0,\n        },\n        'extra_body': {\n            'top_k': 0,\n        }\n    },\n    'intern-s1-pro': {\n        'query_per_second': 4,\n        'max_out_len': 64000,\n        'max_seq_len': 65536,\n        'batch_size': 500,\n        'temperature': 0.8,\n        'openai_extra_kwargs': {\n            'top_p': 0.95,\n        },\n        'extra_body': {\n            'top_k': 50,\n            'min_p': 0.0,\n        }\n    },\n    'intern-s1-pro-32k': {\n        'query_per_second': 4,\n        'max_out_len': 32768,\n        'max_seq_len': 65536,\n        'batch_size': 500,\n        'temperature': 0.8,\n        'openai_extra_kwargs': {\n            'top_p': 0.95,\n        },\n        'extra_body': {\n            'top_k': 50,\n            'min_p': 0.0,\n        }\n    },\n    'intern-s1-pro-2batch': {\n        'query_per_second': 4,\n        'max_out_len': 64000,\n        'max_seq_len': 65536,\n        'batch_size': 2,\n        'temperature': 0.8,\n        'openai_extra_kwargs': {\n            'top_p': 0.95,\n        },\n        'extra_body': {\n            'top_k': 50,\n            'min_p': 0.0,\n        }\n    }\n}\n\nMLLM_EVAL_CONFIGS = {\n    'default': {},\n    'internvl': {\n        'repetition-penalty': 1.0,\n        'top-p': 0.8,\n        'top-k': 20,\n        'temperature': 0.7,\n    }\n}\n\nBACKEND_LIST = ['turbomind', 'pytorch']\n\nRESTFUL_MODEL_LIST = [\n    'Qwen/Qwen3-0.6B', 'Qwen/Qwen3-VL-2B-Instruct', 'Qwen/Qwen3-30B-A3B', 'internlm/Intern-S1',\n    'internlm/internlm2_5-20b', 'Qwen/Qwen3-32B', 'OpenGVLab/InternVL3_5-30B-A3B', 'OpenGVLab/InternVL3-38B',\n    'Qwen/Qwen3-VL-8B-Instruct', 'internlm/internlm3-8b-instruct', 
'meta-llama/Llama-3.2-3B-Instruct',\n    'Qwen/Qwen3-VL-30B-A3B-Instruct'\n]\n\nRESTFUL_BASE_MODEL_LIST = [\n    'Qwen/Qwen3-8B-Base', 'internlm/internlm2_5-20b', 'Qwen/Qwen3-4B', 'internlm/internlm3-8b-instruct'\n]\n\nSUFFIX_INNER_AWQ = '-inner-4bits'\nSUFFIX_INNER_GPTQ = '-inner-gptq'\nSUFFIX_INNER_W8A8 = '-inner-w8a8'\n\nEVAL_RUN_CONFIG = {\n    'model': 'Qwen/Qwen2.5-32B-Instruct',\n    'backend': 'turbomind',\n    'communicator': 'nccl',\n    'quant_policy': 0,\n    'parallel_config': {\n        'tp': 2\n    },\n    'extra_params': {\n        'server-name': DEFAULT_SERVER,\n        'session-len': 76000,\n        'cache-max-entry-count': 0.7\n    }\n}\n"
  },
  {
    "path": "autotest/utils/evaluate_utils.py",
    "content": "import csv\nimport glob\nimport json\nimport os\nimport subprocess\nimport time\n\nimport allure\nimport pandas as pd\nfrom mmengine.config import Config\nfrom utils.common_utils import execute_command_with_logging\nfrom utils.config_utils import get_case_str_by_config, get_cli_str, parse_config_by_case\nfrom utils.constant import DEFAULT_PORT, DEFAULT_SERVER, EVAL_RUN_CONFIG\n\n\ndef write_to_summary(case_name, result, msg, metrics, result_dir):\n    status = '✅ PASS' if result else f'❌ FAIL {msg}'\n\n    config = parse_config_by_case(case_name)\n\n    backend = config['backend']\n    model = config['model']\n    communicator = config['communicator']\n    parallel_config_str = config['parallel_config']\n    quant_policy = config['quant_policy']\n\n    dataset_name = []\n    dataset_metrics = []\n    for key in sorted(metrics.keys()):\n        dataset_name.append(key)\n        dataset_metrics.append(metrics.get(key, ''))\n\n    summary_dataset_name = ' | '.join(dataset_name)\n    summary_dataset_metrics = ' | '.join(dataset_metrics)\n\n    summary_file = os.environ.get('GITHUB_STEP_SUMMARY', '')\n    md_summary_file = f'{result_dir}/summary_{case_name}.md'\n    summary_line = f'| {model} | {quant_policy} | {backend} | {communicator} | {parallel_config_str} | {status} | {summary_dataset_metrics} |\\n'  # noqa: E501\n\n    write_header = not os.path.exists(md_summary_file) or os.path.getsize(md_summary_file) == 0\n    with open(md_summary_file, 'a') as f:\n        if write_header:\n            dash_line = '-----|' * (len(metrics.keys()))\n            f.write('## Model Evaluation Results\\n')\n            f.write(\n                f'| Model | QuantPolicy | Backend | Communicator | Parallel config | Status | {summary_dataset_name} |\\n'  # noqa\n            )\n            f.write(f'|-------|-------------|---------|--------------|----|--------|{dash_line}\\n')\n        f.write(summary_line)\n    if summary_file:\n        write_header = not 
os.path.exists(summary_file) or os.path.getsize(summary_file) == 0\n        with open(summary_file, 'a') as f:\n            if write_header:\n                dash_line = '-----|' * (len(metrics.keys()))\n                f.write('## Model Evaluation Results\\n')\n                f.write(\n                    f'| Model | QuantPolicy | Backend | Communicator | Parallel config | Status | {summary_dataset_name} |\\n'  # noqa\n                )\n                f.write(f'|-------|-------------|---------|--------------|----|--------|{dash_line}\\n')\n            f.write(summary_line)\n    else:\n        print(\n            f'Summary: {model} | {backend} | {communicator} | {parallel_config_str} | {status} | {summary_dataset_metrics}'  # noqa: E501\n        )\n\n\ndef llm_summary(case_name, result, msg, work_dir, result_dir=None):\n    metrics = {}\n\n    if work_dir and os.path.exists(work_dir):\n        try:\n            summary_dirs = glob.glob(os.path.join(work_dir, '*', 'summary'))\n            if not summary_dirs:\n                raise FileNotFoundError('No summary directory found')\n\n            summary_dir = summary_dirs[0]\n\n            csv_files = glob.glob(os.path.join(summary_dir, 'summary_*.csv'))\n            if not csv_files:\n                raise FileNotFoundError('No CSV files found')\n\n            csv_file = sorted(csv_files)[-1]\n            if not os.path.exists(csv_file):\n                raise FileNotFoundError('CSV file does not exist')\n\n            with open(csv_file, 'r') as f:\n                reader = csv.reader(f)\n                next(reader)\n                for row in reader:\n                    if len(row) < 5 or not row[4]:\n                        continue\n\n                    dataset = row[0]\n                    metric_value = row[4]\n                    try:\n                        metrics[dataset] = f'{float(metric_value):.2f}'  # noqa: E231\n                    except ValueError:\n                        metrics[dataset] = 
metric_value\n\n        except Exception as e:\n            print(f'Error reading metrics: {str(e)}')\n    if not result_dir:\n        result_dir = work_dir\n    write_to_summary(case_name, result, msg, metrics, result_dir)\n\n\ndef mllm_summary(case_name,\n                 result,\n                 msg,\n                 work_dir,\n                 result_dir=None,\n                 dataset_list=['MMBench_V11_MINI', 'MMStar_MINI', 'AI2D_MINI', 'OCRBench_MINI']):\n\n    metrics = {}\n    pattern = os.path.join(work_dir, case_name, 'T*')\n    t_dirs = [d for d in glob.glob(pattern) if os.path.isdir(d)]\n\n    if not t_dirs:\n        return\n\n    # Sort by modification time, newest first, and read scores from the latest run.\n    t_dirs.sort(key=os.path.getmtime, reverse=True)\n    latest_dir = t_dirs[0]\n\n    for dataset in dataset_list:\n        if dataset == 'OCRBench_MINI':\n            score_file = f'{latest_dir}/{case_name}_{dataset}_score.json'\n            cur_score = 0\n            with open(score_file, 'r') as f:\n                total_score = json.load(f)\n                cur_score = total_score['Final Score Norm']\n            metrics[dataset] = f'{cur_score:.2f}'  # noqa: E231\n        else:\n            score_file = f'{latest_dir}/{case_name}_{dataset}_acc.csv'\n            df = pd.read_csv(score_file)\n            cur_score = df['Overall'].iloc[0]\n            if dataset == 'MMBench_V11_MINI':\n                cur_score = df.loc[df['split'] == 'dev', 'Overall'].values\n            cur_score = cur_score * 100\n            metrics[dataset] = f'{cur_score.item():.2f}'  # noqa: E231\n    if result_dir is None:\n        result_dir = work_dir\n    write_to_summary(case_name, result, msg, metrics, result_dir)\n\n\ndef eval_test(model_path, eval_path, case_name, port=DEFAULT_PORT, test_type='infer', extra_config={}, **kwargs):\n    work_dir = None\n    try:\n\n        work_dir = os.path.join(eval_path, f'wk_{case_name}')\n        timestamp = time.strftime('%Y%m%d_%H%M%S')\n        eval_log = os.path.join(eval_path, 
f'log_{case_name}_{test_type}_{timestamp}.log')\n        temp_config_path = os.path.join(eval_path, f'temp_{case_name}.py')\n\n        current_dir = os.path.dirname(os.path.abspath(__file__))\n        parent_dir = os.path.dirname(current_dir)\n        config_file = os.path.join(parent_dir, 'evaluate/eval_config_chat.py')\n\n        print(f'Starting OpenCompass evaluation for model: {model_path}')\n        print(f'Model path: {model_path}')\n        print(f'Case: {case_name}')\n        print(f'Config file: {config_file}')\n\n        original_cwd = os.getcwd()\n        os.makedirs(work_dir, exist_ok=True)\n\n        test_url = f'http://{DEFAULT_SERVER}:{port}/v1'\n\n        try:\n            if test_type == 'infer':\n                if not os.path.exists(config_file):\n                    return False, f'Config file {config_file} not found'\n\n                cfg = Config.fromfile(config_file)\n\n                cfg.MODEL_NAME = case_name\n                cfg.MODEL_PATH = model_path\n                cfg.API_BASE = test_url  # noqa: E231\n\n                if cfg.models and len(cfg.models) > 0:\n                    model_cfg = cfg.models[0]\n                    model_cfg['abbr'] = case_name\n                    model_cfg['path'] = case_name\n                    model_cfg['openai_api_base'] = test_url\n                    model_cfg['tokenizer_path'] = model_path\n\n                    for key, value in kwargs.items():\n                        model_cfg[key] = value\n\n                cfg.NUM_WORKERS = extra_config.get('max-num-workers', 8)\n                cfg.infer['partitioner']['num_worker'] = extra_config.get('max-num-workers', 8)\n\n                cfg.dump(temp_config_path)\n                print(f'Modified config saved to: {temp_config_path}')\n            elif test_type == 'eval':\n                if not os.path.exists(temp_config_path):\n                    error_msg = f'Temp config file {temp_config_path} not found for eval stage'\n                    
llm_summary(case_name, False, error_msg, work_dir, eval_path)\n                    return False, error_msg\n\n                cfg = Config.fromfile(temp_config_path)\n                print(f'Using existing temp config file: {temp_config_path}')\n                eval_run_config = EVAL_RUN_CONFIG\n                eval_case_name = get_case_str_by_config(eval_run_config)\n                cfg.JUDGE_API_BASE = test_url\n                cfg.JUDGE_MODEL_PATH = model_path\n                cfg.JUDGE_MODEL_NAME = eval_case_name\n\n                if hasattr(cfg, 'judge_cfg'):\n                    cfg.judge_cfg['path'] = eval_case_name\n                    cfg.judge_cfg['abbr'] = eval_case_name\n                    cfg.judge_cfg['openai_api_base'] = test_url\n                    cfg.judge_cfg['tokenizer_path'] = model_path\n\n                if hasattr(cfg, 'datasets') and cfg.datasets:\n                    for dataset in cfg.datasets:\n                        if 'eval_cfg' in dataset and 'evaluator' in dataset['eval_cfg']:\n                            evaluator = dataset['eval_cfg']['evaluator']\n\n                            if 'judge_cfg' in evaluator:\n                                evaluator['judge_cfg']['abbr'] = cfg.JUDGE_MODEL_NAME\n                                evaluator['judge_cfg']['path'] = cfg.JUDGE_MODEL_NAME\n                                evaluator['judge_cfg']['openai_api_base'] = cfg.JUDGE_API_BASE\n                                evaluator['judge_cfg']['tokenizer_path'] = cfg.JUDGE_MODEL_PATH\n\n                            if 'llm_evaluator' in evaluator and 'judge_cfg' in evaluator['llm_evaluator']:\n                                evaluator['llm_evaluator']['judge_cfg']['abbr'] = cfg.JUDGE_MODEL_NAME\n                                evaluator['llm_evaluator']['judge_cfg']['path'] = cfg.JUDGE_MODEL_NAME\n                                evaluator['llm_evaluator']['judge_cfg']['openai_api_base'] = cfg.JUDGE_API_BASE\n                                
evaluator['llm_evaluator']['judge_cfg']['tokenizer_path'] = cfg.JUDGE_MODEL_PATH\n\n                cfg.dump(temp_config_path)\n                print(f'Modified config for eval stage saved to: {temp_config_path}')\n\n            extra_config_str = get_cli_str(extra_config)\n            cmd = f'opencompass {temp_config_path} --reuse -w {work_dir} -m {test_type} --dump-res-length {extra_config_str}'  # noqa\n            print(f'Running command: {cmd}')\n            print(f'Work directory: {work_dir}')\n\n            result, stderr = execute_command_with_logging(cmd, eval_log, timeout=259200)\n\n            allure.attach.file(eval_log, name=eval_log, attachment_type=allure.attachment_type.TEXT)\n\n            if test_type == 'eval':\n                llm_summary(case_name, result, stderr, work_dir, eval_path)\n\n            return result, stderr\n        except Exception as e:\n            print(f'Error occurred: {e}')\n            return False, f'Error occurred: {e}'\n        finally:\n            os.chdir(original_cwd)\n            print(f'Returned to directory: {original_cwd}')\n\n    except subprocess.TimeoutExpired:\n        timeout_msg = (f'Evaluation timed out for {model_path} '\n                       f'after 259200 seconds')\n        if work_dir and test_type == 'eval':\n            llm_summary(case_name, False, timeout_msg, work_dir, eval_path)\n        return False, timeout_msg\n    except Exception as e:\n        error_msg = f'Error during evaluation for {model_path}: {str(e)}'\n        if work_dir and test_type == 'eval':\n            llm_summary(case_name, False, error_msg, work_dir, eval_path)\n        return False, error_msg\n\n\ndef mllm_eval_test(model_path, eval_path, case_name, port=DEFAULT_PORT, test_type='infer', extra_config={}):\n    work_dir = os.path.join(eval_path, f'wk_{case_name}')\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    eval_log = os.path.join(eval_path, f'log_{case_name}_{timestamp}.log')\n\n    print(f'Starting VLMEvalKit 
evaluation for model: {model_path}')\n    print(f'Model path: {model_path}')\n    print(f'Case: {case_name}')\n    print(f'Work directory: {work_dir}')\n\n    os.makedirs(work_dir, exist_ok=True)\n\n    extra_config_str = get_cli_str(extra_config)\n\n    if test_type == 'infer':\n        cmd = f'python run.py --data MMBench_V11_MINI MMStar_MINI AI2D_MINI OCRBench_MINI --model {case_name} --base-url http://{DEFAULT_SERVER}:{port}/v1 --reuse --work-dir {work_dir} --mode infer {extra_config_str}'  # noqa\n    elif test_type == 'eval':\n        cmd = f'python run.py --data MMBench_V11_MINI MMStar_MINI AI2D_MINI OCRBench_MINI --model {case_name} --base-url http://{DEFAULT_SERVER}:empty/v1 --reuse --work-dir {work_dir} --api-nproc 32 --mode eval --judge turbomind_Qwen2.5-32B-Instruct_nccl_tp2_0 --judge-base-url http://{DEFAULT_SERVER}:{port}/v1'  # noqa\n    else:\n        # Without this branch, an unknown test_type would leave cmd unbound.\n        return False, f'Unsupported test_type: {test_type}'\n\n    result, msg = execute_command_with_logging(cmd, eval_log)\n\n    allure.attach.file(eval_log, name=eval_log, attachment_type=allure.attachment_type.TEXT)\n\n    if test_type == 'eval':\n        mllm_summary(case_name,\n                     result,\n                     msg,\n                     work_dir,\n                     eval_path,\n                     dataset_list=['MMBench_V11_MINI', 'MMStar_MINI', 'AI2D_MINI', 'OCRBench_MINI'])\n    return result, msg\n"
  },
  {
    "path": "autotest/utils/get_run_config.py",
    "content": "from lmdeploy.model import MODELS\n\n\n# Deprecated function\ndef get_model_name(model):\n    model_names = ['llama', 'llama2', 'llama3', 'internlm', 'internlm2', 'baichuan2', 'chatglm2', 'yi', 'qwen']\n    model_names += list(MODELS.module_dict.keys())\n    model_names.sort()\n    model_name = _simple_model_name(model)\n    model_name = model_name.lower()\n\n    if model_name in model_names:\n        return model_name\n    if 'llama-2' in model_name:\n        return 'llama2'\n    if 'llama-3-1' in model_name:\n        return 'llama3_1'\n    if 'llama-3' in model_name:\n        return 'llama3'\n    if 'vicuna' in model_name and 'llava' not in model_name:\n        return 'vicuna'\n    if 'llava' in model_name and 'v1' in model_name and 'v1.6-34b' not in model_name and 'mistral' not in model_name:\n        return 'llava-v1'\n    if 'llava' in model_name and 'v1.6-34b' in model_name:\n        return 'llava-chatml'\n    if 'internvl-chat' in model_name and 'v1-2' in model_name:\n        return 'internvl-zh-hermes2'\n    if 'llava-1.5' in model_name:\n        return 'llava-v1'\n    if 'yi-vl' in model_name:\n        return 'yi-vl'\n    if 'qwen' in model_name:\n        return 'qwen'\n    if 'internvl' in model_name:\n        return 'internvl-internlm2'\n    if 'internlm2' in model_name:\n        return 'internlm2'\n    if 'internlm-xcomposer2d5' in model_name:\n        return 'internlm-xcomposer2d5'\n    if 'internlm-xcomposer2' in model_name:\n        return 'internlm-xcomposer2'\n    if 'glm-4' in model_name:\n        return 'glm4'\n    if len(model_name.split('-')) > 2 and '-'.join(model_name.split('-')[0:2]) in model_names:\n        return '-'.join(model_name.split('-')[0:2])\n    return model_name.split('-')[0]\n\n\ndef _simple_model_name(model):\n    if '/' in model:\n        model_name = model.split('/')[1]\n    else:\n        model_name = model\n    model_name = 
model_name.replace('-inner-4bits', '')\n    model_name = model_name.replace('-inner-w8a8', '')\n    model_name = model_name.replace('-4bits', '')\n    return model_name\n"
  },
  {
    "path": "autotest/utils/mp_log_utils.py",
    "content": "import os\n\nimport allure\nfrom pytest_assume.plugin import assume\n\n\ndef write_log(config, result, msg, is_new: bool = True, case_path_tag: str = 'default'):\n    try:\n        log_path = os.path.join(config.get('log_path'), case_path_tag)\n\n        # Truncate the log for a new case, otherwise append to it.\n        mode = 'w' if is_new else 'a'\n        with open(log_path, mode) as file:\n            # str() lets callers pass a bool result as well as a string.\n            file.writelines('result:' + str(result) + ', reason:' + str(msg) + '\\n')\n    except Exception as e:\n        return False, None, f'Unknown error: {e}'\n\n\ndef assert_log(config, case_path_tag: str = 'default'):\n    log_path = os.path.join(config.get('log_path'), case_path_tag)\n\n    # Initialize so the assert below never hits an unbound name when the log\n    # has no result line; any 'result:False' line fails the case.\n    result = False\n    msg = f'No result line found in {log_path}'\n    with open(log_path, 'r') as f:\n        lines = f.readlines()\n\n        for line in lines:\n            if 'result:False, reason:' in line:\n                result = False\n                msg = line\n                break\n            if 'result:True, reason:' in line:\n                result = True\n\n    allure.attach.file(log_path, name=log_path, attachment_type=allure.attachment_type.TEXT)\n    with assume:\n        assert result, msg\n"
  },
  {
    "path": "autotest/utils/pipeline_chat.py",
    "content": "import json\nimport os\nimport shutil\nimport time\n\nimport allure\nfrom pytest_assume.plugin import assume\nfrom utils.common_utils import execute_command_with_logging\nfrom utils.config_utils import get_case_str_by_config, get_cuda_prefix_by_workerid, get_workerid, resolve_extra_params\nfrom utils.rule_condition_assert import assert_result\n\n\ndef run_pipeline_llm_test(config, run_config, common_case_config, worker_id: str = '', is_smoke: bool = False):\n    model = run_config.get('model')\n    if run_config.get('env', {}).get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True':\n        model_path = model\n    else:\n        model_path = os.path.join(config.get('model_path'), model)\n\n    log_path = config.get('log_path')\n    case_name = get_case_str_by_config(run_config)\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    pipeline_log = os.path.join(log_path, f'pipeline_llm_{case_name}_{timestamp}.log')\n\n    env = os.environ.copy()\n    env['MASTER_PORT'] = str(get_workerid(worker_id) + 29500)\n    env.update(run_config.get('env', {}))\n\n    run_config_bk = run_config.copy()\n    run_config_bk.pop('env', None)\n    run_config_bk.pop('model', None)\n\n    resolve_extra_params(run_config_bk.get('extra_params', {}), config.get('model_path'))\n\n    run_config_string = json.dumps(run_config_bk, ensure_ascii=False, indent=None)\n    run_config_string = run_config_string.replace(' ', '').replace('\"', '\\\\\"').replace(',', '\\\\,')\n\n    cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config'))\n    cmd = f'{cuda_prefix} python3 autotest/tools/pipeline/llm_case.py run_pipeline_chat_test {model_path} {run_config_string} autotest/prompt_case.yml {is_smoke}'  # noqa E501\n\n    result, stderr = execute_command_with_logging(cmd, pipeline_log, timeout=1800, env=env)\n\n    with assume:\n        assert result, stderr\n\n    with open(pipeline_log, 'r', encoding='utf-8') as file:\n        output_text = file.read()\n\n    
with open(pipeline_log, 'a') as file:\n        for case in common_case_config.keys():\n            if is_smoke and case != 'memory_test':\n                continue\n            if case != 'code_testcase' and 'code' in model_path.lower():\n                continue\n\n            with allure.step(case):\n                case_info = common_case_config.get(case)\n                case_result = True\n                reason = ''\n\n                for prompt_detail in case_info:\n                    prompt = list(prompt_detail.keys())[0]\n                    case_result, reason = assert_result(get_response_from_output_by_prompt(output_text, case, prompt),\n                                                        prompt_detail.values(), model_path)\n                    if not case_result:\n                        print(f'{case} result: {case_result}, reason: {reason} \\n')\n                    file.writelines(f'{case} result: {case_result}, reason: {reason} \\n')\n                with assume:\n                    assert case_result, reason\n    allure.attach.file(pipeline_log, name=pipeline_log, attachment_type=allure.attachment_type.TEXT)\n\n\ndef run_pipeline_mllm_test(config, run_config, worker_id: str = '', is_smoke: bool = False):\n    model = run_config.get('model')\n    if run_config.get('env', {}).get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True':\n        model_path = model\n    else:\n        model_path = os.path.join(config.get('model_path'), model)\n\n    log_path = config.get('log_path')\n    case_name = get_case_str_by_config(run_config)\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    pipeline_log = os.path.join(log_path, f'pipeline_mllm_{case_name}_{timestamp}.log')\n\n    env = os.environ.copy()\n    env['MASTER_PORT'] = str(get_workerid(worker_id) + 29500)\n    env.update(run_config.get('env', {}))\n\n    run_config_bk = run_config.copy()\n    run_config_bk.pop('env', None)\n    run_config_bk.pop('model', None)\n    run_config_string = 
json.dumps(run_config_bk, ensure_ascii=False, indent=None)\n    run_config_string = run_config_string.replace(' ', '').replace('\"', '\\\\\"').replace(',', '\\\\,')\n\n    cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config'))\n    resource_path = config.get('resource_path')\n    cmd = f'{cuda_prefix} python3 autotest/tools/pipeline/mllm_case.py run_pipeline_mllm_test {model_path} {run_config_string} {resource_path} {is_smoke}'  # noqa E501\n\n    result, stderr = execute_command_with_logging(cmd, pipeline_log, timeout=1800, env=env, should_print=False)\n\n    with assume:\n        assert result, stderr\n\n    with open(pipeline_log, 'r', encoding='utf-8') as file:\n        output_text = file.read()\n\n    with open(pipeline_log, 'a') as file:\n        with allure.step('single1 pic'):\n            response = get_response_from_output(output_text, 'single1')\n            case_result = any(word in response.lower() for word in ['tiger', '虎'])\n            file.writelines(f'single1 pic result: {case_result} reason: simple example tiger should in {response} \\n')\n            with assume:\n                assert case_result, f'reason: simple example tiger should in {response}'\n        with allure.step('single2 pic'):\n            response = get_response_from_output(output_text, 'single2')\n            case_result = any(word in response.lower() for word in ['tiger', '虎'])\n            file.writelines(f'single2 pic result: {case_result} reason: simple example tiger should in {response} \\n')\n            with assume:\n                assert case_result, f'reason: simple example tiger should in {response}'\n        with allure.step('multi-imagese'):\n            response = get_response_from_output(output_text, 'multi-imagese')\n            case_result = any(word in response.lower() for word in ['tiger', '虎', '滑雪', 'ski'])\n            file.writelines(f'multi-imagese pic result: {case_result} reason: tiger or ski should in {response} \\n')\n 
           with assume:\n                assert case_result, f'reason: Multi-images example: tiger or ski should in {response}'\n        with allure.step('batch-example1'):\n            response = get_response_from_output(output_text, 'batch-example1')\n            case_result = any(word in response.lower() for word in ['滑雪', 'ski'])\n            file.writelines(f'batch-example1 pic result: {case_result} reason: ski should in {response} \\n')\n            with assume:\n                assert case_result, f'reason: batch-example1: ski should in {response}'\n        with allure.step('batch-example2'):\n            response = get_response_from_output(output_text, 'batch-example2')\n            case_result = any(word in response.lower() for word in ['tiger', '虎'])\n            file.writelines(f'batch-example2 pic result: {case_result} reason: tiger should in {response} \\n')\n            with assume:\n                assert case_result, f'reason: batch-example2: tiger should in {response}'\n        with allure.step('multi-turn1'):\n            response = get_response_from_output(output_text, 'multi-turn1')\n            case_result = any(word in response.lower() for word in ['滑雪', 'ski'])\n            file.writelines(f'multi-turn1 pic result: {case_result} reason: ski should in {response} \\n')\n            with assume:\n                assert case_result, f'reason: multi-turn1: ski should in {response}'\n        with allure.step('multi-turn2'):\n            response = get_response_from_output(output_text, 'multi-turn2')\n            case_result = any(word in response.lower() for word in ['滑雪', 'ski'])\n            file.writelines(f'multi-turn2 pic result: {case_result} reason: ski should in {response} \\n')\n            with assume:\n                assert case_result, f'reason: multi-turn2: ski should in {response}'\n        if not is_smoke:\n            if 'internvl' in model.lower() and 'internvl2-4b' not in model.lower():\n
                internvl_vl_testcase(output_text, file)\n                internvl_vl_testcase(output_text, file, 'cn')\n            if 'minicpm' in model.lower():\n                MiniCPM_vl_testcase(output_text, file)\n            if 'qwen' in model.lower():\n                Qwen_vl_testcase(output_text, file)\n\n    with open(pipeline_log, 'r', encoding='utf-8') as file:\n        output_text = file.read()\n    print(output_text)\n    allure.attach.file(pipeline_log, name=pipeline_log, attachment_type=allure.attachment_type.TEXT)\n\n\ndef get_response_from_output(output_text, case):\n    return output_text.split(f'[caseresult {case} start]')[1].split(f'[caseresult {case} end]')[0]\n\n\ndef get_response_from_output_by_prompt(output_text, case, prompt):\n    output_list = output_text.split(f'[caseresult {case} start]')[1].split(f'[caseresult {case} end]')[0]\n    output_dict = json.loads(output_list.rstrip())\n    for output in output_dict:\n        if output.get('prompt') == prompt:\n            return output.get('response')\n    return None\n\n\ndef assert_pipeline_single_return(output, logprobs_num: int = 0):\n    result = assert_pipeline_single_element(output, is_last=True, logprobs_num=logprobs_num)\n    if not result:\n        return result, 'single_element is wrong'\n    return result & (len(output.token_ids) == output.generate_token_len\n                     or len(output.token_ids) == output.generate_token_len - 1), 'token_ids length is not correct'\n\n\ndef assert_pipeline_batch_return(output, size: int = 1):\n    if len(output) != size:\n        return False, 'length is not correct'\n    for single_output in output:\n        result, msg = assert_pipeline_single_return(single_output)\n        if not result:\n            return result, msg\n    return True, ''\n\n\ndef assert_pipeline_single_stream_return(output, logprobs_num: int = 0):\n    for i in range(0, len(output) - 2):\n        if not assert_pipeline_single_element(output[i], is_stream=True, 
logprobs_num=logprobs_num):\n            return False, f'single_stream_element is false, index is {i}'\n    if assert_pipeline_single_element(output[-1], is_stream=True, is_last=True, logprobs_num=logprobs_num) is False:\n        return False, 'last single_stream_element is false'\n    return True, ''\n\n\ndef assert_pipeline_batch_stream_return(output, size: int = 1):\n    for i in range(size):\n        output_list = [item for item in output if item.index == i]\n        result, msg = assert_pipeline_single_stream_return(output_list)\n        if not result:\n            return result, msg\n    return True, ''\n\n\ndef assert_pipeline_single_element(output, is_stream: bool = False, is_last: bool = False, logprobs_num: int = 0):\n    result = True\n    result &= output.generate_token_len > 0\n    result &= output.input_token_len > 0\n    result &= output.index >= 0\n    if is_last:\n        result &= output.text is not None\n        result &= output.finish_reason in ['stop', 'length']\n        if is_stream:\n            result &= output.token_ids is None or output.token_ids == []\n        else:\n            result &= len(output.token_ids) > 0\n    else:\n        result &= len(output.text) > 0\n        result &= output.finish_reason is None\n        result &= len(output.token_ids) > 0\n    if logprobs_num == 0 or (is_last and is_stream):\n        result &= output.logprobs is None\n    else:\n        if is_stream:\n            result &= len(output.logprobs) >= 1\n        else:\n            result &= len(output.logprobs) == output.generate_token_len or len(\n                output.logprobs) == output.generate_token_len + 1\n        if result:\n            for content in output.logprobs:\n                result &= len(content.keys()) <= logprobs_num\n                for key in content.keys():\n                    result &= isinstance(content.get(key), float)\n    return result\n\n\ndef internvl_vl_testcase(output_text, file, lang: str = 'en'):\n    with 
allure.step(f'internvl-combined-images-{lang}'):\n        response = get_response_from_output(output_text, f'internvl-combined-images-{lang}')\n        case_result = any(word in response.lower() for word in ['panda', '熊猫'])\n        file.writelines(\n            f'internvl-combined-images-{lang} result: {case_result}, reason: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: combined images: panda should be in {response}'\n    with allure.step(f'internvl-combined-images2-{lang}'):\n        response = get_response_from_output(output_text, f'internvl-combined-images2-{lang}')\n        case_result = any(word in response.lower() for word in ['panda', '熊猫'])\n        file.writelines(\n            f'internvl-combined-images2-{lang} result: {case_result}, reason: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: combined images2: panda should be in {response}'\n    with allure.step(f'internvl-separate-images-{lang}'):\n        response = get_response_from_output(output_text, f'internvl-separate-images-{lang}')\n        case_result = any(word in response.lower() for word in ['panda', '熊猫', 'same', 'different', 'eat', 'cute'])\n        file.writelines(\n            f'internvl-separate-images-{lang} result: {case_result}, reason: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: separate images: panda should be in {response}'\n    with allure.step(f'internvl-separate-images2-{lang}'):\n        response = get_response_from_output(output_text, f'internvl-separate-images2-{lang}')\n        case_result = any(word in response.lower()\n                          for word in ['panda', '熊猫', 'same', 'different', 'difference', 'identical'])\n        file.writelines(\n            f'internvl-separate-images2-{lang} result: {case_result}, reason: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: separate images2: panda should be in {response}'
\n    with allure.step(f'internvl-video-{lang}'):\n        response = get_response_from_output(output_text, f'internvl-video-{lang}')\n        case_result = any(word in response.lower() for word in ['red panda', 'eat', '熊猫', '竹子', 'food', 'hold'])\n        file.writelines(f'internvl-video-{lang} result: {case_result}, reason: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: video: panda should be in {response}'\n    with allure.step(f'internvl-video2-{lang}'):\n        response = get_response_from_output(output_text, f'internvl-video2-{lang}')\n        case_result = any(word in response.lower() for word in ['red panda', 'eat', '熊猫', '竹子'])\n        file.writelines(f'internvl-video2-{lang} result: {case_result}, reason: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: video2: panda should be in {response}'
\n\n\ndef MiniCPM_vl_testcase(output_text, file):\n    with allure.step('minicpm-combined-images'):\n        response = get_response_from_output(output_text, 'minicpm-combined-images')\n        case_result = any(word in response.lower() for word in ['panda', '熊猫'])\n        file.writelines(f'minicpm-combined-images result: {case_result}, reason: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: combined images: panda should be in {response}'\n    with allure.step('minicpm-combined-images2'):\n        response = get_response_from_output(output_text, 'minicpm-combined-images2')\n        case_result = any(word in response.lower() for word in ['panda', '熊猫'])\n        file.writelines(f'minicpm-combined-images2 result: {case_result}, reason: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: combined images2: panda should be in {response}'\n    with allure.step('minicpm-fewshot'):\n        response = get_response_from_output(output_text, 'minicpm-fewshot')\n        case_result = any(word in response.lower() for word in ['2021', '14'])\n        file.writelines(f'minicpm-fewshot result: {case_result}, reason: 2021 or 14 should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: fewshot: 2021 or 14 should be in {response}'\n    with allure.step('minicpm-video'):\n        response = get_response_from_output(output_text, 'minicpm-video')\n        case_result = any(word in response.lower() for word in ['red panda', '熊猫'])\n        file.writelines(f'minicpm-video result: {case_result}, reason: video: panda should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: video: panda should be in {response}'
\n\n\ndef Qwen_vl_testcase(output_text, file):\n    with allure.step('qwen-combined-images'):\n        response = get_response_from_output(output_text, 'qwen-combined-images')\n        case_result = any(word in response.lower() for word in ['buildings', '楼', 'skyline', 'city'])\n        file.writelines(f'qwen-combined-images result: {case_result}, reason: buildings should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: combined images: buildings should be in {response}'\n    with allure.step('qwen-combined-images2'):\n        response = get_response_from_output(output_text, 'qwen-combined-images2')\n        case_result = any(word in response.lower() for word in ['buildings', '楼', 'skyline', 'city'])\n        file.writelines(f'qwen-combined-images2 result: {case_result}, reason: buildings should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: combined images2: buildings should be in {response}'\n    with allure.step('qwen-performance-images'):\n        response = get_response_from_output(output_text, 'qwen-performance-images')\n        case_result = any(word in response.lower() for word in ['buildings', '楼', 'skyline', 'city'])\n        file.writelines(f'qwen-performance-images result: {case_result}, reason: buildings should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: performance images: buildings should be in {response}'\n    with allure.step('qwen-performance-images2'):\n        response = get_response_from_output(output_text, 'qwen-performance-images2')\n        case_result = any(word in response.lower() for word in ['buildings', '楼', 'skyline', 'city'])\n        file.writelines(f'qwen-performance-images2 result: {case_result}, reason: buildings should be in {response} \\n')\n        with assume:\n            assert case_result, f'reason: performance images2: buildings should be in {response}'
\n\n\ndef save_pipeline_common_log(config, log_name, result, content, msg: str = '', write_type: str = 'w'):\n    log_path = config.get('log_path')\n\n    config_log = os.path.join(log_path, log_name)\n    with open(config_log, write_type) as file:\n        file.writelines(f'result:{result}, reason: {msg}, content: {content}')  # noqa: E231\n\n\ndef assert_pipeline_common_log(config, log_name):\n    log_path = config.get('log_path')\n\n    config_log = os.path.join(log_path, log_name)\n    allure.attach.file(config_log, name=config_log, attachment_type=allure.attachment_type.TEXT)\n\n    msg = 'result is empty, please check again'\n    result = False\n    with open(config_log, 'r') as f:\n        lines = f.readlines()\n\n        for line in lines:\n            if 'result:False, reason:' in line:\n                result = False\n                msg = line\n                break\n            if 'result:True, reason:' in line and not result:\n                result = True\n                msg = ''\n    try:\n        if os.path.isfile(config_log):\n            os.remove(config_log)\n        elif os.path.isdir(config_log):\n            shutil.rmtree(config_log)\n    except OSError:\n        pass  # Ignore errors when removing log file\n\n    assert result, msg\n"
  },
  {
    "path": "autotest/utils/proxy_distributed_utils.py",
    "content": "import os\nimport random\nimport socket\nimport subprocess\nimport time\nfrom typing import Any\n\nimport requests\nfrom utils.config_utils import get_case_str_by_config, get_cli_common_param, resolve_extra_params\nfrom utils.ray_distributed_utils import verify_service_functionality\n\ntime_time = time.time\n\nDEFAULT_PROXY_PORT = 8000\nWORKER_WAIT_INTERVAL = 15  # seconds\n\n\ndef is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:\n    \"\"\"Check if a port is open.\"\"\"\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.settimeout(timeout)\n        try:\n            s.connect((host, port))\n            return True\n        except (socket.timeout, ConnectionRefusedError, OSError):\n            return False\n\n\ndef check_nodes_status(host: str, proxy_port: int, model_name: str, expected_instances: int, check_count: int,\n                       current_time: float, last_progress_print: float,\n                       progress_print_interval: int) -> tuple[bool, int]:\n    try:\n        nodes_url = f'http://{host}:{proxy_port}/nodes/status'\n        resp = requests.get(nodes_url, timeout=10)\n\n        if resp.status_code != 200:\n            if current_time - last_progress_print >= progress_print_interval:\n                print(f'🔧 Check {check_count}: Failed to get node status, status code: {resp.status_code}')\n            return False, 0\n\n        nodes_data = resp.json()\n        ready_instances = 0\n        total_instances = len(nodes_data)\n\n        for node_info in nodes_data.values():\n            models = node_info.get('models', [])\n            if model_name in models:\n                ready_instances += 1\n\n        should_print = current_time - last_progress_print >= progress_print_interval\n\n        if should_print:\n            basename = os.path.basename(model_name)\n            print(f'📊 Check {check_count}: Model registration progress: '\n                  
f'{ready_instances}/{expected_instances} instances ready '\n                  f'(Total reported: {total_instances})')\n            for node_url, node_info in nodes_data.items():\n                models = node_info.get('models', [])\n                if model_name in models:\n                    print(f'   ✅ Instance {node_url} registered model {basename}')\n                else:\n                    print(f'   ⏳ Instance {node_url} has not registered target model')\n\n        if ready_instances >= expected_instances:\n            if should_print:\n                print(f'🎯 All {expected_instances} API server instances have registered the target model')\n            return True, ready_instances\n        else:\n            if should_print:\n                print(f'⏳ Waiting for more instances to register... ({ready_instances}/{expected_instances})')\n            return False, ready_instances\n\n    except Exception as e:\n        if current_time - last_progress_print >= progress_print_interval:\n            print(f'🔧 Check {check_count}: Exception getting node status - {e}')\n        return False, 0\n\n\ndef wait_for_model_service_ready(host: str,\n                                 proxy_port: int,\n                                 model_name: str,\n                                 timeout_seconds: int = 2000,\n                                 expected_instances: int = None) -> bool:\n    if expected_instances:\n        print(f'⏳ Waiting for model service to be fully ready (Model: {model_name}), '\n              f'expected instances: {expected_instances}, timeout: {timeout_seconds}s')\n    else:\n        print(f'⏳ Waiting for model service to be fully ready (Model: {model_name}), '\n              f'timeout: {timeout_seconds}s')\n\n    start_time = time_time()\n    check_count = 0\n    last_progress_print = 0\n    progress_print_interval = 30\n\n    initial_delay = random.uniform(1, 5)\n    time.sleep(initial_delay)\n\n    while time_time() - start_time < 
timeout_seconds:\n        check_count += 1\n        current_time = time_time()\n\n        try:\n            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:\n                sock.settimeout(5)\n                if sock.connect_ex((host, proxy_port)) != 0:\n                    if current_time - last_progress_print >= progress_print_interval:\n                        print(f'🔌 Check {check_count}: proxy port not ready')\n                        last_progress_print = current_time\n                    time.sleep(10)\n                    continue\n\n            if expected_instances:\n                instances_ready, ready_count = check_nodes_status(host, proxy_port, model_name, expected_instances,\n                                                                  check_count, current_time, last_progress_print,\n                                                                  progress_print_interval)\n                if not instances_ready:\n                    if ready_count is not None and current_time - last_progress_print >= progress_print_interval:\n                        last_progress_print = current_time\n                    time.sleep(10)\n                    continue\n\n            service_ready = verify_service_functionality(host, proxy_port, model_name, check_count)\n            if service_ready:\n                if expected_instances:\n                    print(f'✅ All {expected_instances} API server instances are ready and service is functional!')\n                else:\n                    print('✅ Model service is fully ready!')\n                return True\n\n        except requests.exceptions.RequestException as e:\n            if current_time - last_progress_print >= progress_print_interval:\n                print(f'🔧 Check {check_count}: Request exception - {e}')\n                last_progress_print = current_time\n        except Exception as e:\n            if current_time - last_progress_print >= progress_print_interval:\n            
    print(f'🔧 Check {check_count}: Unknown exception - {e}')\n                last_progress_print = current_time\n\n        sleep_time = 10 + random.uniform(-2, 2)\n        time.sleep(sleep_time)\n\n    print(f'❌ Model service startup timed out ({timeout_seconds} seconds)')\n    return False\n\n\ndef proxy_worker_node_wait(manager, timeout_minutes: int = 120):\n    \"\"\"Worker node waits by periodically checking if the master's proxy service\n    is still alive. If the proxy becomes unreachable for several consecutive\n    checks, assume master has finished.\n\n    Args:\n        manager: ProxyDistributedManager instance\n        timeout_minutes: Maximum time to wait before giving up (default: 120 minutes)\n    \"\"\"\n    print(f'⏸️ Worker node {manager.node_rank} entering monitoring mode...')\n\n    max_checks = (timeout_minutes * 60) // WORKER_WAIT_INTERVAL\n    consecutive_failures = 0\n    max_consecutive_failures = 3\n\n    for i in range(max_checks):\n        if not is_port_open(manager.master_addr, manager.proxy_port, timeout=2.0):\n            consecutive_failures += 1\n            print(f'⚠️ Proxy connection to master failed ({consecutive_failures}/{max_consecutive_failures})')\n            if consecutive_failures >= max_consecutive_failures:\n                print('📡 Master proxy service stopped, worker node exiting')\n                break\n        else:\n            consecutive_failures = 0\n\n        if i % 4 == 0:\n            elapsed = (i * WORKER_WAIT_INTERVAL) // 60\n            print(f'⏳ Worker node {manager.node_rank} monitoring... 
Running for {elapsed} minutes')\n\n        time.sleep(WORKER_WAIT_INTERVAL)\n    else:\n        print(f'⏰ Worker node {manager.node_rank} monitoring timed out ({timeout_minutes} minutes)')\n\n    print(f'✅ Worker node {manager.node_rank} completed waiting')\n\n\nclass ProxyDistributedManager:\n\n    def __init__(self):\n        self.master_addr = os.getenv('MASTER_ADDR', '127.0.0.1')\n        self.node_rank = int(os.getenv('NODE_RANK', '0'))\n        self.proxy_port = int(os.getenv('PROXY_PORT', str(DEFAULT_PROXY_PORT)))\n\n        self.is_master = (self.node_rank == 0)\n        self.proxy_process = None\n\n    def start(self):\n        if not self.is_master:\n            return\n\n        cmd = [\n            'lmdeploy', 'serve', 'proxy', '--server-name', self.master_addr, '--server-port',\n            str(self.proxy_port), '--routing-strategy', 'min_expected_latency', '--serving-strategy', 'Hybrid'\n        ]\n        print(f\"[Proxy] Starting: {' '.join(cmd)}\")\n        self.proxy_process = subprocess.Popen(cmd)\n\n        time.sleep(5)\n\n    def cleanup(self):\n        if self.proxy_process and self.proxy_process.poll() is None:\n            print('[Proxy] Terminating proxy process...')\n            self.proxy_process.terminate()\n            try:\n                self.proxy_process.wait(timeout=10)\n            except subprocess.TimeoutExpired:\n                self.proxy_process.kill()\n\n\nclass ApiServerPerTest:\n\n    def __init__(self, proxy_manager: ProxyDistributedManager, config: dict[str, Any], run_config: dict[str, Any]):\n        self.proxy_manager = proxy_manager\n        self.config = config\n        self.run_config = run_config\n\n        model_name = run_config['model']\n        self.model_path = os.path.join(config['model_path'], model_name)\n\n        self.master_addr = proxy_manager.master_addr\n        self.proxy_port = proxy_manager.proxy_port\n        self.node_rank = int(os.getenv('NODE_RANK', '0'))\n        self.node_count = 
int(os.getenv('NODE_COUNT', '1'))\n        self.proc_per_node = int(os.getenv('PROC_PER_NODE', '1'))\n\n        self.expected_instances = self.node_count * self.proc_per_node\n        self.is_master = (self.node_rank == 0)\n        self.api_process = None\n\n    def start(self):\n        proxy_url = f'http://{self.master_addr}:{self.proxy_port}'\n\n        extra_params = self.run_config.get('extra_params', {})\n        resolve_extra_params(extra_params, self.config['model_path'])\n\n        # Get model-name: use extra_params['model-name'] if specified, otherwise use case_name\n        case_name = get_case_str_by_config(self.run_config)\n        self.model_name = case_name if extra_params.get('model-name', None) is None else extra_params.get('model-name')\n\n        cmd = [\n            'lmdeploy',\n            'serve',\n            'api_server',\n            self.model_path,\n            '--model-name',\n            self.model_name,\n        ] + get_cli_common_param(self.run_config).split() + [\n            '--proxy-url',\n            proxy_url,\n        ]\n        if self.node_count > 1:\n            cmd += ['--nnodes', str(self.node_count), '--node-rank', str(self.node_rank)]\n\n        print(f\"[API Server] Starting: {' '.join(cmd)}\")\n        timestamp = time.strftime('%Y%m%d_%H%M%S')\n        log_dir = self.config.get('server_log_path', '/tmp/lmdeploy_test')\n        os.makedirs(log_dir, exist_ok=True)\n        log_path = os.path.join(log_dir, f'log_{case_name}_{timestamp}.log')\n        self._log_file = open(log_path, 'w')\n        self.api_process = subprocess.Popen(cmd, stdout=self._log_file, stderr=self._log_file)\n        print(f'📝 API Server log: {log_path}')\n\n    def wait_until_ready(self):\n        if not self.is_master:\n            return\n        success = wait_for_model_service_ready(host=self.master_addr,\n                                               proxy_port=self.proxy_port,\n                                               
model_name=self.model_name,\n                                               timeout_seconds=2000,\n                                               expected_instances=self.expected_instances)\n        if not success:\n            raise RuntimeError(f'API Server failed to register model: {self.model_name}')\n\n    def cleanup(self):\n        if self.api_process and self.api_process.poll() is None:\n            print(f'[API Server] Terminating for model: {self.model_path}')\n            self.api_process.terminate()\n            try:\n                self.api_process.wait(timeout=15)\n            except subprocess.TimeoutExpired:\n                self.api_process.kill()\n        if hasattr(self, '_log_file') and self._log_file and not self._log_file.closed:\n            self._log_file.close()\n"
  },
  {
    "path": "autotest/utils/quantization_utils.py",
    "content": "import os\nimport subprocess\nfrom subprocess import PIPE\n\n\ndef quantization(config,\n                 quantization_model_name,\n                 origin_model_name,\n                 quantization_type: str = 'awq',\n                 cuda_prefix: str = 'CUDA_VISIBLE_DEVICES=0'):\n    model_path = config.get('model_path')\n    log_path = config.get('log_path')\n    origin_model_path = os.path.join(config.get('model_path'), origin_model_name)\n    quantization_model_path = os.path.join(model_path, quantization_model_name)\n    quantization_log = os.path.join(\n        log_path, '_'.join(['quantization', quantization_type,\n                            quantization_model_name.split('/')[1]]) + '.log')\n\n    if quantization_type == 'awq':\n        quantization_cmd = ' '.join(\n            ['lmdeploy lite auto_awq', origin_model_path, '--work-dir', quantization_model_path])\n    elif quantization_type == 'gptq':\n        quantization_cmd = ' '.join(\n            ['lmdeploy lite auto_gptq', origin_model_path, '--work-dir', quantization_model_path])\n    elif quantization_type == 'w8a8':\n        quantization_cmd = ' '.join(\n            ['lmdeploy lite smooth_quant', origin_model_path, '--work-dir', quantization_model_path])\n    else:\n        return False, f'quantization type should be in [awq, gptq, w8a8], now the type is {quantization_type}'\n\n    # Add device option if specified in environment\n    device = os.environ.get('DEVICE', '')\n    if device == 'ascend':\n        quantization_cmd += ' --device npu '\n\n    if cuda_prefix is not None:\n        quantization_cmd = ' '.join([cuda_prefix, quantization_cmd])\n\n    if 'llama-3' in origin_model_name.lower():\n        quantization_cmd += ' --search-scale'\n\n    if quantization_type == 'gptq' or str(config.get('env_tag')) == '3090' or str(config.get('env_tag')) == '5080':\n        quantization_cmd += ' --batch-size 8'\n    else:\n        quantization_cmd += ' --batch-size 32'
\n\n    with open(quantization_log, 'w') as f:\n        # remove existing folder\n        subprocess.run([' '.join(['rm -rf', quantization_model_path])],\n                       stdout=f,\n                       stderr=f,\n                       shell=True,\n                       text=True,\n                       encoding='utf-8')\n\n        f.writelines('reproduce command quantization_cmd: ' + quantization_cmd + '\\n')\n        print('reproduce command quantization_cmd: ' + quantization_cmd)\n        # quantization\n        quantizationRes = subprocess.run([quantization_cmd],\n                                         stdout=f,\n                                         stderr=PIPE,\n                                         shell=True,\n                                         text=True,\n                                         encoding='utf-8',\n                                         errors='replace')\n        f.writelines(quantizationRes.stderr)\n        result = quantizationRes.returncode == 0\n\n    return result, quantizationRes.stderr\n"
  },
  {
    "path": "autotest/utils/ray_distributed_utils.py",
    "content": "import os\nimport random\nimport socket\nimport subprocess\nimport time\nfrom time import time as time_time\nfrom typing import Any\n\nimport requests\nfrom utils.config_utils import get_case_str_by_config, get_cli_common_param, resolve_extra_params\n\n# Default constants\nLM_DEPLOY_API_PORT = 8000\nRAY_PORT = 6379\nHEALTH_CHECK_TIMEOUT = 30\nCONNECTION_CHECK_TIMEOUT = 5\nWORKER_WAIT_INTERVAL = 30\n\n\ndef wait_for_model_service_ready(\n    host: str,\n    api_port: int,\n    model_name: str,\n    timeout_seconds: int = 1000,\n) -> bool:\n    \"\"\"Wait for LMDeploy API Server to be ready and verify basic functionality.\n\n    No longer checks multi-node registration (API Server is a single-point service).\n    \"\"\"\n    print(f'⏳ Waiting for LMDeploy API Server to be ready (Model: {model_name}), Timeout: {timeout_seconds}s')\n\n    start_time = time_time()\n    check_count = 0\n    last_progress_print = 0\n    progress_print_interval = 30\n\n    # Random initial delay to avoid multiple clients requesting simultaneously\n    time.sleep(random.uniform(1, 5))\n\n    while time_time() - start_time < timeout_seconds:\n        check_count += 1\n        current_time = time_time()\n\n        try:\n            # Check if port is open\n            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:\n                sock.settimeout(5)\n                if sock.connect_ex((host, api_port)) != 0:\n                    if current_time - last_progress_print >= progress_print_interval:\n                        print(f'🔌 Check {check_count}: API port {api_port} not ready')\n                        last_progress_print = current_time\n                    time.sleep(10)\n                    continue\n\n            # Verify service functionality\n            if verify_service_functionality(host, api_port, model_name, check_count):\n                print('✅ LMDeploy API Server is fully ready!')\n                return True\n\n        except Exception as e:\n 
           if current_time - last_progress_print >= progress_print_interval:\n                print(f'🔧 Check {check_count}: Exception - {e}')\n                last_progress_print = current_time\n\n        sleep_time = 10 + random.uniform(-2, 2)\n        time.sleep(sleep_time)\n\n    print(f'❌ LMDeploy API Server startup timed out ({timeout_seconds} seconds)')\n    return False\n\n\ndef verify_service_functionality(host: str, api_port: int, model_name: str, check_count: int) -> bool:\n    \"\"\"Verify that the API Server can respond to basic requests.\"\"\"\n    try:\n        test_data = {\n            'model': model_name,\n            'messages': [{\n                'role': 'user',\n                'content': 'hi'\n            }],\n            'max_tokens': 5,\n            'stream': False\n        }\n\n        resp = requests.post(f'http://{host}:{api_port}/v1/chat/completions', json=test_data, timeout=15)\n\n        if resp.status_code == 200:\n            print(f'✅ Check {check_count}: Service functionality normal (received valid response)')\n            return True\n        elif resp.status_code == 400:\n            print(f'✅ Check {check_count}: Service framework activated (received 400)')\n            return True\n        else:\n            print(f'🔧 Check {check_count}: Service test failed, status code: {resp.status_code}')\n            return False\n\n    except requests.exceptions.RequestException as e:\n        print(f'🔧 Check {check_count}: Service test exception - {e}')\n        return False\n\n\nclass RayLMDeployManager:\n\n    def __init__(\n        self,\n        master_addr: str,\n        ray_port: int = RAY_PORT,\n        api_port: int = LM_DEPLOY_API_PORT,\n        log_dir: str = '.',\n        health_check: bool = True,\n    ):\n        self.master_addr = master_addr\n        self.ray_port = ray_port\n        self.api_port = api_port\n        self.log_dir = log_dir\n        self.health_check = health_check\n        self._cleaned = False\n\n        
# Determine if this is the master node (via environment variable NODE_RANK)\n        self.node_rank = int(os.getenv('NODE_RANK', '0'))\n        self.is_master = (self.node_rank == 0)\n\n        os.makedirs(self.log_dir, exist_ok=True)\n        print(f'📝 Node {self.node_rank} log directory: {self.log_dir}')\n\n        # Print cluster information\n        self.node_count = int(os.getenv('NODE_COUNT', '1'))\n        self.job_id = os.getenv('JOB_ID', 'unknown')\n        print(f'🎯 Node {self.node_rank} cluster information:')\n        print(f'- Total nodes: {self.node_count}')\n        print(f\"- Role: {'Master node' if self.is_master else 'Worker node'}\")\n        print(f'- Master address: {self.master_addr}')\n        print(f'- Ray port: {self.ray_port}')\n        print(f'- API port: {self.api_port}')\n        print(f'- Job ID: {self.job_id}')\n\n    def start_ray_cluster(self):\n        \"\"\"Start or join Ray cluster.\"\"\"\n        if self.is_master:\n            cmd = ['ray', 'start', '--head', '--port', str(self.ray_port)]\n            print(f'🚀 Master node starting Ray cluster (Port: {self.ray_port})')\n        else:\n            cmd = ['ray', 'start', '--address', f'{self.master_addr}:{self.ray_port}']\n            print(f'🔌 Worker node {self.node_rank} joining Ray cluster: {self.master_addr}:{self.ray_port}')\n\n        try:\n            subprocess.run(cmd, capture_output=True, text=True, check=True)\n            print('✅ Ray started successfully')\n        except subprocess.CalledProcessError as e:\n            print(f'💥 Ray startup failed: {e.stderr}')\n            raise\n\n    def start_lmdeploy_api_server(self, config: dict[str, Any], run_config: dict[str, Any]) -> None:\n        \"\"\"\n        Master node: Start LMDeploy API Server and wait for it to be ready.\n        Worker nodes: Do not start the service, only verify that the master node's API Server is ready.\n        \"\"\"\n        # Derive model_path from config and run_config\n        model_path 
= os.path.join(config['model_path'], run_config['model'])\n\n        extra_params = run_config.get('extra_params', {})\n        resolve_extra_params(extra_params, config['model_path'])\n\n        # Get model-name: use extra_params['model-name'] if specified, otherwise use case_name\n        case_name = get_case_str_by_config(run_config)\n        model_name = case_name if extra_params.get('model-name', None) is None else extra_params.get('model-name')\n\n        if self.is_master:\n            # === Master node logic: Start service ===\n            timestamp = time.strftime('%Y%m%d_%H%M%S')\n            log_path = os.path.join(self.log_dir, f'log_{model_name}_{timestamp}.log')\n\n            cmd = [\n                'lmdeploy',\n                'serve',\n                'api_server',\n                model_path,\n                '--server-port',\n                str(self.api_port),\n                '--model-name',\n                model_name,\n            ] + get_cli_common_param(run_config).split()\n\n            print(f\"🚀 Master node starting LMDeploy API Server: {' '.join(cmd)}\")\n            self._log_file = open(log_path, 'w')\n            self._api_process = subprocess.Popen(cmd, stdout=self._log_file, stderr=self._log_file)\n            print(f'📝 API Server log: {log_path}')
\n\n            # Wait for service to be ready\n            if self.health_check:\n                ready = wait_for_model_service_ready(host=self.master_addr,\n                                                     api_port=self.api_port,\n                                                     model_name=model_name,\n                                                     timeout_seconds=1000)\n                if not ready:\n                    print('❌ API Server failed to be ready, terminating process')\n                    self._api_process.terminate()\n                    try:\n                        self._api_process.wait(timeout=10)\n                    except subprocess.TimeoutExpired:\n                        self._api_process.kill()\n                    raise RuntimeError('LMDeploy API Server failed to start')\n        else:\n            # === Worker node logic: Only verify that the master node service is ready ===\n            print(f'🔍 Worker node {self.node_rank} is verifying that the master node '\n                  f'({self.master_addr}:{self.api_port}) API Server is ready...')\n            if self.health_check:\n                ready = wait_for_model_service_ready(host=self.master_addr,\n                                                     api_port=self.api_port,\n                                                     model_name=model_name,\n                                                     timeout_seconds=1000)\n                if not ready:\n                    raise RuntimeError(f'Worker node {self.node_rank}: Master node API Server not ready '\n                                       f'within 1000 seconds, cannot continue')\n            else:\n                print('⚠️ health_check=False, skipping API Server readiness check (not recommended)')
\n\n    def cleanup(self, force: bool = True):\n        \"\"\"Clean up resources.\n\n        Args:\n            force (bool):\n                - False: Only stop LMDeploy API Server (used after individual test completion)\n                - True: Stop API Server + Ray cluster (used for final cleanup at session end)\n        \"\"\"\n        if self._cleaned and force:\n            # Note: If this is just an intermediate cleanup with force=False, we shouldn't skip due to _cleaned\n            # So only skip when force=True and already cleaned\n            return\n\n        print(f'🧹 Node {self.node_rank} cleaning resources... 
(force={force})')\n\n        # Stop API Server (master node only)\n        if hasattr(self, '_api_process') and self._api_process.poll() is None:\n            self._api_process.terminate()\n            try:\n                self._api_process.wait(timeout=10)\n            except subprocess.TimeoutExpired:\n                self._api_process.kill()\n            print('✅ LMDeploy API Server stopped')\n            # Note: We don't clear the _api_process attribute here so it can be checked later\n        if hasattr(self, '_log_file') and self._log_file and not self._log_file.closed:\n            self._log_file.close()\n\n        # Stop Ray (only when force=True)\n        if force:\n            try:\n                subprocess.run(['ray', 'stop', '--force'], check=False, capture_output=True)\n                print('✅ Ray cluster stopped')\n            except Exception as e:\n                print(f'⚠️ Ray stop exception: {e}')\n            self._cleaned = True  # Only mark as \"fully cleaned\" when force=True\n\n    def get_cluster_info(self) -> dict[str, Any]:\n        return {\n            'node_rank': self.node_rank,\n            'node_count': self.node_count,\n            'master_addr': self.master_addr,\n            'ray_port': self.ray_port,\n            'api_port': self.api_port,\n            'is_master': self.is_master,\n            'job_id': self.job_id,\n        }\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self.cleanup()\n\n\ndef ray_worker_node_wait(manager: RayLMDeployManager, timeout_minutes: int = 60):\n    \"\"\"Worker node waits for Ray master node (Head Node) to be alive (by\n    detecting GCS service port)\"\"\"\n    if manager.is_master:\n        return\n\n    print(f'⏸️ Worker node {manager.node_rank} entering wait mode...')\n    max_checks = (timeout_minutes * 60) // WORKER_WAIT_INTERVAL\n    consecutive_failures = 0\n    max_consecutive_failures = 3\n\n    for i in 
range(max_checks):\n        try:\n            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:\n                sock.settimeout(CONNECTION_CHECK_TIMEOUT)\n                if sock.connect_ex((manager.master_addr, RAY_PORT)) == 0:\n                    consecutive_failures = 0\n                else:\n                    consecutive_failures += 1\n        except Exception:\n            consecutive_failures += 1\n\n        if consecutive_failures >= max_consecutive_failures:\n            print('📡 Ray master node GCS service unreachable, worker node exiting')\n            break\n\n        if i % 4 == 0:\n            elapsed = (i * WORKER_WAIT_INTERVAL) // 60\n            print(f'⏳ Worker node {manager.node_rank} waiting... Running for {elapsed} minutes')\n\n        time.sleep(WORKER_WAIT_INTERVAL)\n    else:\n        print(f'⏰ Worker node {manager.node_rank} wait timeout ({timeout_minutes} minutes)')\n\n    manager.cleanup()\n"
  },
  {
    "path": "autotest/utils/restful_return_check.py",
    "content": "import re\n\n\ndef assert_chat_completions_batch_return(output, model_name, check_logprobs: bool = False, logprobs_num: int = 5):\n    assert_usage(output.get('usage'))\n    assert output.get('id') is not None\n    assert output.get('object') == 'chat.completion'\n    assert output.get('model') == model_name\n    output_message = output.get('choices')\n    assert len(output_message) == 1\n    for message in output_message:\n        assert message.get('finish_reason') in ['stop', 'length']\n        assert message.get('index') == 0\n        assert len(message.get('message').get('content')) > 0\n        assert message.get('message').get('role') == 'assistant'\n        if check_logprobs:\n            assert len(message.get('logprobs').get('content')) == output.get('usage').get('completion_tokens')\n            for logprob in message.get('logprobs').get('content'):\n                assert_logprobs(logprob, logprobs_num)\n\n\ndef assert_completions_batch_return(output, model_name, check_logprobs: bool = False, logprobs_num: int = 5):\n    assert_usage(output.get('usage'))\n    assert output.get('id') is not None\n    assert output.get('object') == 'text_completion'\n    assert output.get('model') == model_name\n    output_message = output.get('choices')\n    assert len(output_message) == 1\n    for message in output_message:\n        assert message.get('finish_reason') in ['stop', 'length']\n        assert message.get('index') == 0\n        assert len(message.get('text')) > 0\n        if check_logprobs:\n            assert len(message.get('logprobs').get('content')) == output.get('usage').get('completion_tokens')\n            for logprob in message.get('logprobs').get('content'):\n                assert_logprobs(logprob, logprobs_num)\n\n\ndef assert_usage(usage):\n    assert usage.get('prompt_tokens') > 0\n    assert usage.get('total_tokens') > 0\n    assert usage.get('completion_tokens') > 0\n    assert usage.get('completion_tokens') + usage.get('prompt_tokens') == 
usage.get('total_tokens')\n\n\ndef assert_logprobs(logprobs, logprobs_num):\n    assert_logprob_element(logprobs)\n    assert len(logprobs.get('top_logprobs')) >= 0\n    assert type(logprobs.get('top_logprobs')) == list\n    assert len(logprobs.get('top_logprobs')) <= logprobs_num\n    for logprob_element in logprobs.get('top_logprobs'):\n        assert_logprob_element(logprob_element)\n\n\ndef assert_logprob_element(logprob):\n    assert len(logprob.get('token')) > 0 and type(logprob.get('token')) == str\n    assert len(logprob.get('bytes')) > 0 and type(logprob.get('bytes')) == list\n    assert type(logprob.get('logprob')) == float\n\n\ndef assert_chat_completions_stream_return(output,\n                                          model_name,\n                                          is_last: bool = False,\n                                          check_logprobs: bool = False,\n                                          logprobs_num: int = 5):\n    print(output)\n    assert output.get('id') is not None\n    assert output.get('object') == 'chat.completion.chunk'\n    assert output.get('model') == model_name\n    output_message = output.get('choices')\n    assert len(output_message) == 1\n    for message in output_message:\n        assert message.get('delta').get('role') == 'assistant'\n        assert message.get('index') == 0\n        assert len(message.get('delta').get('content')) >= 0\n        if not is_last:\n            assert message.get('finish_reason') is None\n            if check_logprobs:\n                assert (len(message.get('logprobs').get('content')) >= 1)\n                for content in message.get('logprobs').get('content'):\n                    assert_logprobs(content, logprobs_num)\n        if is_last is True:\n            assert len(message.get('delta').get('content')) == 0 or 'error' in message.get('delta').get('content')\n            assert message.get('finish_reason') in ['stop', 'length', 'error']\n            if check_logprobs is True:\n    
            assert message.get('logprobs') is None\n\n\ndef assert_completions_stream_return(output,\n                                     model_name,\n                                     is_last: bool = False,\n                                     check_logprobs: bool = False,\n                                     logprobs_num: int = 5):\n    print(output)\n    assert output.get('id') is not None\n    assert output.get('object') == 'text_completion'\n    assert output.get('model') == model_name\n    output_message = output.get('choices')\n    assert len(output_message) == 1\n    for message in output_message:\n        assert message.get('index') == 0\n        assert len(message.get('text')) >= 0\n        if is_last is False:\n            assert message.get('finish_reason') is None\n            if check_logprobs:\n                assert (len(message.get('logprobs').get('content')) >= 1)\n                for content in message.get('logprobs').get('content'):\n                    assert_logprobs(content, logprobs_num)\n\n        if is_last is True:\n            assert len(message.get('text')) == 0\n            assert message.get('finish_reason') in ['stop', 'length']\n            if check_logprobs is True:\n                assert message.get('logprobs') is None\n\n\ndef has_repeated_fragment(text, repeat_count=5):\n    pattern = r'(.+?)\\1{' + str(repeat_count - 1) + ',}'\n    match = re.search(pattern, text.replace('\\n', ''))\n    if match:\n        repeated_fragment = match.group(1)\n        start_pos = match.start()\n        return True, {'repeated_fragment': repeated_fragment, 'position': start_pos}\n    return False, f'{text} does not contain repeated fragments'\n"
  },
  {
    "path": "autotest/utils/rule_condition_assert.py",
    "content": "def assert_result(input, rule_condition, model_name: str = None):\n    input = input.replace('\\n', '\\\\n')\n    input_lower = input.lower()\n    for dict in rule_condition:\n        if dict is None:\n            return True, ''\n\n        for rule in dict:\n            operator = list(rule.keys())[0]\n            value = list(rule.values())[0]\n            if model_name is not None and model_name == operator:\n                dict = value\n\n        for rule in dict:\n            operator = list(rule.keys())[0]\n            value = list(rule.values())[0]\n            if input is None or len(input) == 0:\n                return False, 'response is empty'\n            if operator == 'contain':\n                if isinstance(value, list):\n                    tmpResult = False\n                    for word in value:\n                        if word.lower() in input_lower:\n                            tmpResult = True\n                    if not tmpResult:\n                        return False, ','.join(value) + \" doesn't exist in \" + input\n                else:\n                    if value.lower() not in input_lower:\n                        msg = value + \" doesn't exist in:\" + input\n                        return False, msg\n            if operator == 'not_contain':\n                if isinstance(value, list):\n                    for word in value:\n                        if word.lower() in input_lower:\n                            msg = word + \" shouldn't exist in:\" + input\n                            return False, msg\n                else:\n                    if value.lower() in input_lower:\n                        msg = value + \" shouldn't exist in \" + input\n                        return False, msg\n            if operator == 'len_g':\n                if len(input) < int(value):\n                    return False, input + ' length: ' + str(len(input)) + ', should greater than ' + str(value)\n        return True, ''\n\n\nif 
__name__ == '__main__':\n    input = '成都的景点hot potdddd'\n    condition = [[{'contain': ['hot pot']}, {'contain': ['。']}, {'len_g': 10}]]\n    print(assert_result(input, condition))\n"
  },
  {
    "path": "autotest/utils/run_client_chat.py",
    "content": "import os\nimport time\nfrom subprocess import PIPE, Popen\n\nimport allure\nfrom utils.config_utils import get_case_str_by_config, get_cli_common_param, get_cuda_prefix_by_workerid\nfrom utils.rule_condition_assert import assert_result\n\nTEMPLATE = 'autotest/template.json'\n\n\ndef run_tests(config, usercase, cli_case_config, run_config, worker_id):\n    if 'coder' in run_config['model'].lower() and usercase == 'chat_testcase':\n        usercase = 'code_testcase'\n\n    hf_command_line_test(config,\n                         usercase,\n                         cli_case_config.get(usercase),\n                         run_config,\n                         cuda_prefix=get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config')))\n\n\ndef hf_command_line_test(config, case, case_info, run_config, cuda_prefix: str = ''):\n    model = run_config.get('model')\n    if run_config.get('env', {}).get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True':\n        model_path = model\n    else:\n        model_path = os.path.join(config.get('model_path'), model)\n\n    # Ensure extra_params exists before modifying\n    run_config.setdefault('extra_params', {})\n    run_config['extra_params']['session_len'] = 4096\n    if case == 'base_testcase':\n        run_config['extra_params']['chat_template'] = TEMPLATE\n        run_config['extra_params']['session_len'] = 512\n\n    print(run_config)\n\n    cmd = ' '.join([cuda_prefix, ' '.join(['lmdeploy chat', model_path, get_cli_common_param(run_config)])]).strip()\n\n    result, chat_log, msg = command_test(config, cmd, run_config, case_info, True)\n    if chat_log:\n        allure.attach.file(chat_log, name=chat_log, attachment_type=allure.attachment_type.TEXT)\n    assert result, msg\n\n\ndef command_test(config, cmd, run_config, case_info, need_extract_output):\n    try:\n        log_path = config.get('log_path')\n        case_name = get_case_str_by_config(run_config)\n        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n        chat_log = os.path.join(log_path, 
f'chat_{case_name}_{timestamp}.log')\n\n        returncode = -1\n        result = True\n\n        spliter = '\\n\\n'\n        # join prompt together\n        prompt = ''\n        for item in case_info:\n            prompt += list(item.keys())[0] + spliter\n        prompt += 'exit' + spliter\n\n        msg = ''\n\n        env = os.environ.copy()\n        env.update(run_config.get('env', {}))\n\n        with Popen([cmd],\n                   stdin=PIPE,\n                   stdout=PIPE,\n                   stderr=PIPE,\n                   shell=True,\n                   text=True,\n                   encoding='utf-8',\n                   errors='replace',\n                   env=env,\n                   start_new_session=True) as proc, open(chat_log, 'a') as file:\n            print(f'reproduce command chat: {cmd} \\n')\n            file.writelines(f'reproduce command chat: {cmd} \\n')\n\n            file.writelines('prompt:' + prompt + '\\n')\n\n            outputs, errors = proc.communicate(input=prompt)\n            returncode = proc.returncode\n            if returncode != 0:\n                file.writelines('error:' + errors + '\\n')\n                result = False\n                return result, chat_log, errors\n\n            outputDialogs = parse_dialogue(outputs)\n            file.writelines('answersize:' + str(len(outputDialogs)) + '\\n')\n\n            index = 0\n            for prompt_detail in case_info:\n                if need_extract_output:\n                    output = extract_output(outputDialogs[index], run_config.get('model'))\n                else:\n                    output = outputDialogs[index]\n                case_result, reason = assert_result(output, prompt_detail.values(), run_config.get('model'))\n                file.writelines(f'prompt: {list(prompt_detail.keys())[0]}\\n')\n                file.writelines(f'output: {output}\\n')\n                file.writelines(f'result: {case_result}, reason: {reason}\\n')\n                index += 
1\n                if not case_result:\n                    print(f'prompt: {list(prompt_detail.keys())[0]}\\n')\n                    print(f'output: {output}\\n')\n                    print(f'result: {case_result}, reason: {reason}\\n')\n                    msg += reason\n                result = result and case_result\n            file.writelines('\\n\\n\\n' + 'full log:' + outputs + '\\n')\n\n        return result, chat_log, msg\n    except Exception as e:\n        return False, None, f'Unknown error: {e}'\n\n\ndef parse_dialogue(inputs: str):\n    sep = 'double enter to end input >>>'\n    # Split the transcript at each input prompt; drop the banner before the first prompt\n    # and the trailing segment after the final 'exit' prompt\n    dialogues = inputs.strip().split(sep)\n    dialogues = [d.strip() for d in dialogues]\n    return dialogues[1:-1]\n\n\ndef extract_output(output: str, model: str):\n    if 'Qwen' in model or 'internlm2' in model:\n        if len(output.split('<|im_start|>assistant')) >= 2:\n            return output.split('<|im_start|>assistant')[1]\n    if 'Baichuan2' in model:\n        if len(output.split('<reserved_107>')) >= 2:\n            return output.split('<reserved_107>')[1]\n    if 'internlm' in model:\n        if len(output.split('<|Bot|>: ')) >= 2:\n            return output.split('<|Bot|>: ')[1]\n    if 'llama' in model or 'Llama' in model:\n        if len(output.split('[/INST]')) >= 2:\n            return output.split('[/INST]')[1]\n\n    return output\n"
  },
  {
    "path": "autotest/utils/run_restful_chat.py",
    "content": "import json\nimport os\nimport subprocess\nimport time\n\nimport allure\nimport psutil\nimport requests\nfrom openai import OpenAI\nfrom pytest_assume.plugin import assume\nfrom utils.config_utils import (get_case_str_by_config, get_cli_common_param, get_cuda_prefix_by_workerid, get_workerid,\n                                resolve_extra_params)\nfrom utils.constant import DEFAULT_PORT, DEFAULT_SERVER\nfrom utils.restful_return_check import assert_chat_completions_batch_return\nfrom utils.rule_condition_assert import assert_result\n\nfrom lmdeploy.serve.openai.api_client import APIClient\n\nBASE_HTTP_URL = f'http://{DEFAULT_SERVER}'\n\n\ndef start_openai_service(config, run_config, worker_id, timeout: int = 1200):\n    port = DEFAULT_PORT + get_workerid(worker_id)\n    case_name = get_case_str_by_config(run_config)\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    server_log = os.path.join(config.get('server_log_path'), f'log_{case_name}_{port}_{timestamp}.log')\n\n    model = run_config.get('model')\n    if run_config.get('env', {}).get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True':\n        model_path = model\n    else:\n        model_path = os.path.join(config.get('model_path'), model)\n\n    cuda_prefix = get_cuda_prefix_by_workerid(worker_id, run_config.get('parallel_config'))\n\n    # Ensure extra_params exists before modifying\n    if 'extra_params' not in run_config:\n        run_config['extra_params'] = {}\n\n    resolve_extra_params(run_config['extra_params'], config.get('model_path'))\n\n    run_config['extra_params']['server-port'] = str(port)\n    run_config['extra_params']['allow-terminate-by-client'] = None\n    model_name = case_name if run_config['extra_params'].get(\n        'model-name', None) is None else run_config['extra_params'].pop('model-name')\n    cmd = ' '.join([\n        cuda_prefix, 'lmdeploy serve api_server', model_path,\n        get_cli_common_param(run_config), f'--model-name {model_name}'\n    ]).strip()\n\n    
env = os.environ.copy()\n    env['MASTER_PORT'] = str(get_workerid(worker_id) + 29500)\n    env.update(run_config.get('env', {}))\n\n    file = open(server_log, 'w')\n    print('reproduce command restful: ' + cmd)\n    file.write('reproduce command restful: ' + cmd + '\\n')\n    startRes = subprocess.Popen(cmd,\n                                stdout=file,\n                                stderr=file,\n                                shell=True,\n                                text=True,\n                                env=env,\n                                encoding='utf-8',\n                                errors='replace',\n                                start_new_session=True)\n    pid = startRes.pid\n\n    http_url = ':'.join([BASE_HTTP_URL, str(port)])\n    start_time = int(time.time())\n    start_timeout = timeout\n\n    time.sleep(5)\n    for i in range(start_timeout):\n        time.sleep(1)\n        end_time = int(time.time())\n        total_time = end_time - start_time\n        result = health_check(http_url, case_name)\n        if result or total_time >= start_timeout:\n            break\n        try:\n            # Check if process is still running\n            return_code = startRes.wait(timeout=1)  # Small timeout to check status\n            if return_code != 0:\n                # Close (and flush) the log handle before reading it back and bailing out\n                file.close()\n                with open(server_log, 'r') as f:\n                    content = f.read()\n                    print(content)\n                return 0, content\n        except subprocess.TimeoutExpired:\n            continue\n    file.close()\n    allure.attach.file(server_log, name=server_log, attachment_type=allure.attachment_type.TEXT)\n    return pid, ''\n\n\ndef stop_restful_api(pid, startRes):\n    if pid > 0:\n        parent = psutil.Process(pid)\n        for child in parent.children(recursive=True):\n            child.terminate()\n        parent.terminate()\n\n\ndef terminate_restful_api(worker_id):\n    port = DEFAULT_PORT + get_workerid(worker_id)\n    http_url = 
':'.join([BASE_HTTP_URL, str(port)])\n\n    response = None\n    request_error = None\n    try:\n        response = requests.get(f'{http_url}/terminate')\n    except requests.exceptions.RequestException as exc:\n        request_error = exc\n    if request_error is not None:\n        assert False, f'terminate request failed: {request_error}'\n    assert response is not None and response.status_code == 200, f'terminate with {response}'\n\n\ndef run_all_step(log_path, case_name, cases_info, port: int = DEFAULT_PORT):\n    http_url = ':'.join([BASE_HTTP_URL, str(port)])\n    model = get_model(http_url)\n\n    if model is None:\n        assert False, 'server not start correctly'\n    for case in cases_info.keys():\n        if case != 'code_testcase' and 'code' in model.lower():\n            continue\n        case_info = cases_info.get(case)\n\n        with allure.step(case + ' restful_test - openai chat'):\n            restful_result, restful_log, msg = open_chat_test(log_path, case_name, case_info, http_url)\n            allure.attach.file(restful_log, name=restful_log, attachment_type=allure.attachment_type.TEXT)\n        with assume:\n            assert restful_result, msg\n\n\ndef open_chat_test(log_path, case_name, case_info, url):\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n\n    restful_log = os.path.join(log_path, f'log_restful_{case_name}_{timestamp}.log')\n\n    file = open(restful_log, 'w')\n\n    result = True\n\n    client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{url}/v1')\n    model_name = client.models.list().data[0].id\n\n    messages = []\n    msg = ''\n    for prompt_detail in case_info:\n        if not result:\n            break\n        prompt = list(prompt_detail.keys())[0]\n        messages.append({'role': 'user', 'content': prompt})\n        file.writelines('prompt:' + prompt + '\\n')\n\n        outputs = client.chat.completions.create(model=model_name,\n                                                 messages=messages,\n                 
                                temperature=0.01,\n                                                 top_p=0.8,\n                                                 max_completion_tokens=1024,\n                                                 stream=True)\n\n        content_chunks = []\n        reasoning_content_chunks = []\n        for output in outputs:\n            # Safely handle streaming chunks: choices may be empty and content may be None\n            if not getattr(output, 'choices', None):\n                continue\n            choice = output.choices[0]\n            delta = getattr(choice, 'delta', None)\n            reasoning_content = getattr(delta, 'reasoning_content', None) if delta is not None else None\n            content = getattr(delta, 'content', None) if delta is not None else None\n            if reasoning_content:\n                reasoning_content_chunks.append(reasoning_content)\n            if content:\n                content_chunks.append(content)\n        reasoning_content = ''.join(reasoning_content_chunks)\n        output_content = ''.join(content_chunks)\n\n        file.writelines(f'reasoning_content :{reasoning_content}, content: {output_content}\\n')\n        messages.append({'role': 'assistant', 'content': output_content})\n\n        case_result, reason = assert_result(reasoning_content + output_content, prompt_detail.values(), model_name)\n        file.writelines('result:' + str(case_result) + ',reason:' + reason + '\\n')\n        if not case_result:\n            msg += reason\n        result = result and case_result\n    file.close()\n    return result, restful_log, msg\n\n\ndef health_check(url, model_name):\n    try:\n        api_client = APIClient(url)\n        model_name_current = api_client.available_models[0]\n        messages = []\n        messages.append({'role': 'user', 'content': '你好'})\n        for output in api_client.chat_completions_v1(model=model_name, messages=messages, top_k=1):\n            if output.get('code') is 
not None and output.get('code') != 0:\n                return False\n            # Return True on first successful response\n            return model_name == model_name_current\n        return False  # No output received\n    except Exception:\n        return False\n\n\ndef get_model(url):\n    print(url)\n    try:\n        api_client = APIClient(url)\n        model_name = api_client.available_models[0]\n        return model_name.split('/')[-1]\n    except Exception:\n        return None\n\n\ndef _run_logprobs_test(port: int = DEFAULT_PORT):\n    http_url = ':'.join([BASE_HTTP_URL, str(port)])\n    api_client = APIClient(http_url)\n    model_name = api_client.available_models[0]\n    output = None\n    for output in api_client.chat_completions_v1(model=model_name,\n                                                 messages='Hi, pls intro yourself',\n                                                 max_tokens=5,\n                                                 temperature=0.01,\n                                                 logprobs=True,\n                                                 top_logprobs=10):\n        continue\n    if output is None:\n        assert False, 'No output received from logprobs test'\n    print(output)\n    assert_chat_completions_batch_return(output, model_name, check_logprobs=True, logprobs_num=10)\n    assert output.get('choices')[0].get('finish_reason') == 'length'\n    assert output.get('usage').get('completion_tokens') == 6 or output.get('usage').get('completion_tokens') == 5\n\n\nPIC = 'tiger.jpeg'  # noqa E501\nPIC2 = 'human-pose.jpg'  # noqa E501\n\n\ndef run_vl_testcase(log_path, resource_path, port: int = DEFAULT_PORT):\n    http_url = ':'.join([BASE_HTTP_URL, str(port)])\n\n    model = get_model(http_url)\n    if model is None:\n        assert False, 'server not start correctly'\n\n    client = OpenAI(api_key='YOUR_API_KEY', base_url=http_url + '/v1')\n    model_name = client.models.list().data[0].id\n\n    timestamp = 
time.strftime('%Y%m%d_%H%M%S')\n\n    simple_model_name = model_name.split('/')[-1]\n    restful_log = os.path.join(log_path, f'restful_vl_{simple_model_name}_{str(port)}_{timestamp}.log')  # noqa\n    file = open(restful_log, 'w')\n\n    prompt_messages = [{\n        'role':\n        'user',\n        'content': [{\n            'type': 'text',\n            'text': 'Describe the image please',\n        }, {\n            'type': 'image_url',\n            'image_url': {\n                'url': f'{resource_path}/{PIC}',\n            },\n        }, {\n            'type': 'image_url',\n            'image_url': {\n                'url': f'{resource_path}/{PIC2}',\n            },\n        }],\n    }]\n\n    response = client.chat.completions.create(model=model_name, messages=prompt_messages, temperature=0.8, top_p=0.8)\n    file.writelines(str(response).lower() + '\\n')\n\n    api_client = APIClient(http_url)\n    model_name = api_client.available_models[0]\n    for item in api_client.chat_completions_v1(model=model_name, messages=prompt_messages):\n        continue\n    file.writelines(str(item) + '\\n')\n    file.close()\n\n    allure.attach.file(restful_log, name=restful_log, attachment_type=allure.attachment_type.TEXT)\n\n    assert 'tiger' in str(response).lower() or '虎' in str(response).lower() or 'ski' in str(\n        response).lower() or '滑雪' in str(response).lower(), response\n    assert 'tiger' in str(item).lower() or '虎' in str(item).lower() or 'ski' in str(item).lower() or '滑雪' in str(\n        item).lower(), item\n\n\ndef _run_reasoning_case(log_path, port: int = DEFAULT_PORT):\n    http_url = ':'.join([BASE_HTTP_URL, str(port)])\n\n    model = get_model(http_url)\n\n    if model is None:\n        assert False, 'server not start correctly'\n\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    restful_log = os.path.join(log_path, f'restful_reasoning_{model}_{str(port)}_{timestamp}.log')\n    file = open(restful_log, 'w')\n\n    client = 
OpenAI(api_key='YOUR_API_KEY', base_url=http_url + '/v1')\n    model_name = client.models.list().data[0].id\n\n    with allure.step('step1 - stream'):\n        messages = [{'role': 'user', 'content': '9.11 and 9.8, which is greater?'}]\n        response = client.chat.completions.create(model=model_name, messages=messages, temperature=0.01, stream=True)\n        outputList = []\n        final_content = ''\n        final_reasoning_content = ''\n        for stream_response in response:\n            if stream_response.choices[0].delta.content is not None:\n                final_content += stream_response.choices[0].delta.content\n            if stream_response.choices[0].delta.reasoning_content is not None:\n                final_reasoning_content += stream_response.choices[0].delta.reasoning_content\n            outputList.append(stream_response)\n        file.writelines(str(outputList) + '\\n')\n        with assume:\n            assert '9.11' in final_reasoning_content and '9.11' in final_content and len(outputList) > 1, str(\n                outputList)\n\n    with allure.step('step2 - batch'):\n        response = client.chat.completions.create(model=model_name, messages=messages, temperature=0.01, stream=False)\n        print(response)\n        reasoning_content = response.choices[0].message.reasoning_content\n        content = response.choices[0].message.content\n        # Log and assert on the batch response, not the stale stream chunks from step1\n        file.writelines(str(response) + '\\n')\n        with assume:\n            assert '9.11' in (reasoning_content or '') and '9.11' in (content or ''), str(response)\n\n    file.close()\n    allure.attach.file(restful_log, name=restful_log, attachment_type=allure.attachment_type.TEXT)\n\n\ndef test_internlm_multiple_round_prompt(client, model):\n\n    def add(a: int, b: int):\n        return a + b\n\n    def mul(a: int, b: int):\n        return a * b\n\n    tools = [{\n        'type': 'function',\n        'function': {\n            'name': 'add',\n            'description': 'Compute the 
sum of two numbers',\n            'parameters': {\n                'type': 'object',\n                'properties': {\n                    'a': {\n                        'type': 'int',\n                        'description': 'A number',\n                    },\n                    'b': {\n                        'type': 'int',\n                        'description': 'A number',\n                    },\n                },\n                'required': ['a', 'b'],\n            },\n        }\n    }, {\n        'type': 'function',\n        'function': {\n            'name': 'mul',\n            'description': 'Calculate the product of two numbers',\n            'parameters': {\n                'type': 'object',\n                'properties': {\n                    'a': {\n                        'type': 'int',\n                        'description': 'A number',\n                    },\n                    'b': {\n                        'type': 'int',\n                        'description': 'A number',\n                    },\n                },\n                'required': ['a', 'b'],\n            },\n        }\n    }]\n    messages = [{'role': 'user', 'content': 'Compute (3+5)*2'}]\n\n    response = client.chat.completions.create(model=model,\n                                              messages=messages,\n                                              temperature=0.01,\n                                              stream=False,\n                                              tools=tools)\n    print(response)\n    response_list = [response]\n    func1_name = response.choices[0].message.tool_calls[0].function.name\n    func1_args = response.choices[0].message.tool_calls[0].function.arguments\n    func1_args_dict = json.loads(func1_args)\n    func1_out = add(**func1_args_dict) if func1_name == 'add' else mul(**func1_args_dict)\n    with assume:\n        assert response.choices[0].finish_reason == 'tool_calls'\n    with assume:\n        assert func1_name == 'add'\n    
with assume:\n        assert func1_args == '{\"a\": 3, \"b\": 5}'\n    with assume:\n        assert func1_out == 8\n    with assume:\n        assert response.choices[0].message.tool_calls[0].type == 'function'\n\n    messages.append({'role': 'assistant', 'content': response.choices[0].message.content})\n    messages.append({'role': 'environment', 'content': f'3+5={func1_out}', 'name': 'plugin'})\n    response = client.chat.completions.create(model=model,\n                                              messages=messages,\n                                              temperature=0.8,\n                                              top_p=0.8,\n                                              stream=False,\n                                              tools=tools)\n    print(response)\n    response_list.append(response)\n    func2_name = response.choices[0].message.tool_calls[0].function.name\n    func2_args = response.choices[0].message.tool_calls[0].function.arguments\n    func2_args_dict = json.loads(func2_args)\n    func2_out = add(**func2_args_dict) if func2_name == 'add' else mul(**func2_args_dict)\n    with assume:\n        assert response.choices[0].finish_reason == 'tool_calls'\n    with assume:\n        assert func2_name == 'mul'\n    with assume:\n        assert func2_args == '{\"a\": 8, \"b\": 2}'\n    with assume:\n        assert func2_out == 16\n    with assume:\n        assert response.choices[0].message.tool_calls[0].type == 'function'\n\n    return response_list\n\n\ndef test_qwen_multiple_round_prompt(client, model):\n\n    def get_current_temperature(location: str, unit: str = 'celsius'):\n        \"\"\"Get current temperature at a location.\n\n        Args:\n            location: The location to get the temperature for, in the format \"City, State, Country\".\n            unit: The unit to return the temperature in. Defaults to \"celsius\". 
(choices: [\"celsius\", \"fahrenheit\"])\n\n        Returns:\n            the temperature, the location, and the unit in a dict\n        \"\"\"\n        return {\n            'temperature': 26.1,\n            'location': location,\n            'unit': unit,\n        }\n\n    def get_temperature_date(location: str, date: str, unit: str = 'celsius'):\n        \"\"\"Get temperature at a location and date.\n\n        Args:\n            location: The location to get the temperature for, in the format 'City, State, Country'.\n            date: The date to get the temperature for, in the format 'Year-Month-Day'.\n            unit: The unit to return the temperature in. Defaults to 'celsius'. (choices: ['celsius', 'fahrenheit'])\n\n        Returns:\n            the temperature, the location, the date and the unit in a dict\n        \"\"\"\n        return {\n            'temperature': 25.9,\n            'location': location,\n            'date': date,\n            'unit': unit,\n        }\n\n    def get_function_by_name(name):\n        if name == 'get_current_temperature':\n            return get_current_temperature\n        if name == 'get_temperature_date':\n            return get_temperature_date\n\n    tools = [{\n        'type': 'function',\n        'function': {\n            'name': 'get_current_temperature',\n            'description': 'Get current temperature at a location.',\n            'parameters': {\n                'type': 'object',\n                'properties': {\n                    'location': {\n                        'type': 'string',\n                        'description':\n                        'The location to get the temperature for, in the format \\'City, State, Country\\'.'\n                    },\n                    'unit': {\n                        'type': 'string',\n                        'enum': ['celsius', 'fahrenheit'],\n                        'description': 'The unit to return the temperature in. 
Defaults to \\'celsius\\'.'\n                    }\n                },\n                'required': ['location']\n            }\n        }\n    }, {\n        'type': 'function',\n        'function': {\n            'name': 'get_temperature_date',\n            'description': 'Get temperature at a location and date.',\n            'parameters': {\n                'type': 'object',\n                'properties': {\n                    'location': {\n                        'type': 'string',\n                        'description':\n                        'The location to get the temperature for, in the format \\'City, State, Country\\'.'\n                    },\n                    'date': {\n                        'type': 'string',\n                        'description': 'The date to get the temperature for, in the format \\'Year-Month-Day\\'.'\n                    },\n                    'unit': {\n                        'type': 'string',\n                        'enum': ['celsius', 'fahrenheit'],\n                        'description': 'The unit to return the temperature in. Defaults to \\'celsius\\'.'\n                    }\n                },\n                'required': ['location', 'date']\n            }\n        }\n    }]\n    messages = [{\n        'role': 'user',\n        'content': 'Today is 2024-11-14, What\\'s the temperature in San Francisco now? 
How about tomorrow?'\n    }]\n\n    response = client.chat.completions.create(model=model,\n                                              messages=messages,\n                                              temperature=0.8,\n                                              top_p=0.8,\n                                              stream=False,\n                                              tools=tools)\n    print(response)\n    response_list = [response]\n    func1_name = response.choices[0].message.tool_calls[0].function.name\n    func1_args = response.choices[0].message.tool_calls[0].function.arguments\n    func2_name = response.choices[0].message.tool_calls[1].function.name\n    func2_args = response.choices[0].message.tool_calls[1].function.arguments\n    with assume:\n        assert response.choices[0].finish_reason == 'tool_calls'\n        assert func1_name == 'get_current_temperature'\n        assert func1_args == '{\"location\": \"San Francisco, CA, USA\"}' \\\n            or func1_args == '{\"location\": \"San Francisco, California, USA\", \"unit\": \"celsius\"}'\n        assert func2_name == 'get_temperature_date'\n        assert func2_args == '{\"location\": \"San Francisco, CA, USA\", \"date\": \"2024-11-15\"}' \\\n            or func2_args == '{\"location\": \"San Francisco, California, USA\", \"date\": \"2024-11-15\", \"unit\": \"celsius\"}'\n        assert response.choices[0].message.tool_calls[0].type == 'function'\n\n    messages.append(response.choices[0].message)\n\n    for tool_call in response.choices[0].message.tool_calls:\n        tool_call_args = json.loads(tool_call.function.arguments)\n        tool_call_result = get_function_by_name(tool_call.function.name)(**tool_call_args)\n        messages.append({\n            'role': 'tool',\n            'name': tool_call.function.name,\n            'content': tool_call_result,\n            'tool_call_id': tool_call.id\n        })\n\n    response = client.chat.completions.create(model=model,\n               
                               messages=messages,\n                                              temperature=0.8,\n                                              top_p=0.8,\n                                              stream=False,\n                                              tools=tools)\n    print(response)\n    response_list.append(response)\n    with assume:\n        assert response.choices[0].finish_reason == 'stop'\n        assert '26.1' in response.choices[0].message.content\n\n    return response_list\n\n\ndef _run_tools_case(log_path, port: int = DEFAULT_PORT):\n    http_url = ':'.join([BASE_HTTP_URL, str(port)])\n\n    model = get_model(http_url)\n\n    if model is None:\n        assert False, 'server did not start correctly'\n\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    restful_log = os.path.join(log_path, f'restful_toolcall_{model}_{str(port)}_{timestamp}.log')\n\n    client = OpenAI(api_key='YOUR_API_KEY', base_url=http_url + '/v1')\n    model_name = client.models.list().data[0].id\n\n    with open(restful_log, 'w') as file:\n        with allure.step('step1 - one_round_prompt'):\n            tools = [{\n                'type': 'function',\n                'function': {\n                    'name': 'get_current_weather',\n                    'description': 'Get the current weather in a given location',\n                    'parameters': {\n                        'type': 'object',\n                        'properties': {\n                            'location': {\n                                'type': 'string',\n                                'description': 'The city and state, e.g. 
San Francisco, CA',\n                            },\n                            'unit': {\n                                'type': 'string',\n                                'enum': ['celsius', 'fahrenheit']\n                            },\n                        },\n                        'required': ['location'],\n                    },\n                }\n            }]\n            messages = [{'role': 'user', 'content': 'What\\'s the weather like in Boston today?'}]\n            response = client.chat.completions.create(model=model_name,\n                                                      messages=messages,\n                                                      temperature=0.01,\n                                                      stream=False,\n                                                      tools=tools)\n            print(response)\n            with assume:\n                assert response.choices[0].finish_reason == 'tool_calls'\n            with assume:\n                assert response.choices[0].message.tool_calls[0].function.name == 'get_current_weather'\n            with assume:\n                assert 'Boston' in response.choices[0].message.tool_calls[0].function.arguments\n            with assume:\n                assert response.choices[0].message.tool_calls[0].type == 'function'\n            file.writelines(str(response) + '\\n')\n\n        with allure.step('step2 - search prompt'):\n            tools = [{\n                'type': 'function',\n                'function': {\n                    'name': 'search',\n                    'description': 'BING search API',\n                    'parameters': {\n                        'type': 'object',\n                        'properties': {\n                            'query': {\n                                'type': 'string',\n                                'description': 'list of search query strings'\n                            }\n                        },\n                        
'required': ['query']\n                    }\n                }\n            }]\n            messages = [{'role': 'user', 'content': '搜索最近的人工智能发展趋势'}]\n            response = client.chat.completions.create(model=model_name,\n                                                      messages=messages,\n                                                      temperature=0.01,\n                                                      stream=False,\n                                                      tools=tools)\n            print(response)\n            with assume:\n                assert response.choices[0].finish_reason == 'tool_calls'\n            with assume:\n                assert response.choices[0].message.tool_calls[0].function.name == 'search'\n            with assume:\n                assert '人工智能' in response.choices[0].message.tool_calls[0].function.arguments\n            with assume:\n                assert response.choices[0].message.tool_calls[0].type == 'function'\n            file.writelines(str(response) + '\\n')\n\n        with allure.step('step3 - multiple_round_prompt'):\n            response_list = None\n            if 'intern' in model.lower():\n                response_list = test_internlm_multiple_round_prompt(client, model_name)\n            elif 'qwen' in model.lower():\n                response_list = test_qwen_multiple_round_prompt(client, model_name)\n\n            if response_list is not None:\n                file.writelines(str(response_list) + '\\n')\n\n    allure.attach.file(restful_log, name=restful_log, attachment_type=allure.attachment_type.TEXT)\n\n\ndef proxy_health_check(url):\n    \"\"\"Check if proxy server is healthy.\"\"\"\n    try:\n        # For proxy server, we check if it responds to the /v1/models endpoint\n        import requests\n        response = requests.get(f'{url}/v1/models', timeout=5)\n        if response.status_code == 200:\n            return True\n        return False\n    except Exception:\n        return 
False\n\n\ndef start_proxy_server(log_path, port, case_name: str = 'default'):\n    \"\"\"Start the proxy server for testing with enhanced error handling and\n    logging.\"\"\"\n    if log_path is None:\n        log_path = '/nvme/qa_test_models/evaluation_report'\n\n    timestamp = time.strftime('%Y%m%d_%H%M%S')\n    proxy_log = os.path.join(log_path, f'proxy_server_{case_name}_{str(port)}_{timestamp}.log')\n\n    proxy_url = f'http://{DEFAULT_SERVER}:{port}'  # noqa: E231, E261\n    try:\n        response = requests.get(f'{proxy_url}/nodes/status', timeout=5)\n        if response.status_code == 200:\n            print(f'Terminating existing nodes on proxy {proxy_url}')\n            requests.get(f'{proxy_url}/nodes/terminate_all', timeout=10)\n            time.sleep(5)\n    except requests.exceptions.RequestException:\n        pass\n\n    cmd = (f'lmdeploy serve proxy --server-name {DEFAULT_SERVER} --server-port {port} '\n           f'--routing-strategy min_expected_latency --serving-strategy Hybrid')\n\n    print(f'Starting proxy server with command: {cmd}')\n    print(f'Proxy log will be saved to: {proxy_log}')\n\n    proxy_file = open(proxy_log, 'w')\n    proxy_process = subprocess.Popen([cmd],\n                                     stdout=proxy_file,\n                                     stderr=proxy_file,\n                                     shell=True,\n                                     text=True,\n                                     encoding='utf-8')\n    pid = proxy_process.pid\n\n    start_time = int(time.time())\n    timeout = 300\n\n    time.sleep(5)\n    healthy = False\n    for _ in range(timeout):\n        time.sleep(1)\n        if proxy_health_check(proxy_url):\n            healthy = True\n            break\n\n        # poll() returns None while the process is still running; an early exit means startup failed\n        if proxy_process.poll() is not None:\n            proxy_file.close()\n            with open(proxy_log, 'r') as f:\n         
       print(f.read())\n            return 0, proxy_process\n\n        if int(time.time()) - start_time >= timeout:\n            break\n\n    proxy_file.close()\n    allure.attach.file(proxy_log, name=proxy_log, attachment_type=allure.attachment_type.TEXT)\n\n    if healthy:\n        print(f'Proxy server started successfully with PID: {pid}')\n    else:\n        print(f'Proxy server (PID {pid}) did not pass the health check within {timeout} seconds')\n    return pid, proxy_process\n\n\ndef run_llm_test(config, run_config, common_case_config, worker_id):\n    pid, content = start_openai_service(config, run_config, worker_id)\n    try:\n        if pid > 0:\n            case_name = get_case_str_by_config(run_config)\n            run_all_step(config.get('log_path'),\n                         case_name,\n                         common_case_config,\n                         port=DEFAULT_PORT + get_workerid(worker_id))\n        else:\n            assert False, f'Failed to start RESTful API server: {content}'\n    finally:\n        if pid > 0:\n            terminate_restful_api(worker_id)\n\n\ndef run_mllm_test(config, run_config, worker_id):\n    pid, content = start_openai_service(config, run_config, worker_id)\n    try:\n        if pid > 0:\n            run_vl_testcase(config.get('log_path'),\n                            config.get('resource_path'),\n                            port=DEFAULT_PORT + get_workerid(worker_id))\n        else:\n            assert False, f'Failed to start RESTful API server: {content}'\n    finally:\n        if pid > 0:\n            terminate_restful_api(worker_id)\n\n\ndef run_reasoning_case(config, run_config, worker_id):\n    pid, content = start_openai_service(config, run_config, worker_id)\n    try:\n        if pid > 0:\n            _run_reasoning_case(config.get('log_path'), port=DEFAULT_PORT + get_workerid(worker_id))\n        else:\n            assert False, f'Failed to start RESTful API 
server: {content}'\n    finally:\n        if pid > 0:\n            terminate_restful_api(worker_id)\n\n\ndef run_tools_case(config, run_config, worker_id):\n    pid, content = start_openai_service(config, run_config, worker_id)\n    try:\n        if pid > 0:\n            _run_tools_case(config.get('log_path'), port=DEFAULT_PORT + get_workerid(worker_id))\n        else:\n            assert False, f'Failed to start RESTful API server: {content}'\n    finally:\n        if pid > 0:\n            terminate_restful_api(worker_id)\n\n\ndef run_logprob_test(config, run_config, worker_id):\n    pid, content = start_openai_service(config, run_config, worker_id)\n    try:\n        if pid > 0:\n            _run_logprobs_test(port=DEFAULT_PORT + get_workerid(worker_id))\n        else:\n            assert False, f'Failed to start RESTful API server: {content}'\n    finally:\n        if pid > 0:\n            terminate_restful_api(worker_id)\n"
  },
  {
    "path": "autotest/utils/toolkit.py",
    "content": "from functools import lru_cache\n\nfrom transformers import AutoTokenizer\n\n\ndef parse_sse_stream(content: str) -> list[str]:\n    \"\"\"Parse SSE (Server-Sent Events) stream content into a list of events.\n\n    Each event is either a JSON string or \"[DONE]\".\n    \"\"\"\n    lines = content.strip().split('\\n')\n    events = []\n    for line in lines:\n        line = line.strip()\n        if line.startswith('data: '):\n            data = line[6:]  # remove \"data: \"\n            if data.strip() == '[DONE]':\n                events.append('[DONE]')\n            else:\n                events.append(data)\n    return events\n\n\n@lru_cache(maxsize=4)\ndef _load_tokenizer_cached(model_path: str):\n    \"\"\"Load a tokenizer, caching up to four instances keyed by model_path.\"\"\"\n    try:\n        return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    except Exception as e:\n        raise RuntimeError(f\"Failed to load tokenizer from '{model_path}': {e}\") from e\n\n\ndef encode_text(model_path: str, text: str) -> list[int]:\n    \"\"\"Encode text into token ids with the cached tokenizer for model_path.\"\"\"\n    tokenizer = _load_tokenizer_cached(model_path)\n    return tokenizer.encode(text)\n"
  },
  {
    "path": "benchmark/README.md",
    "content": "# Benchmark\n\nWe provide several profiling tools to benchmark our models.\n\n## Profile with a dataset\n\nDownload the ShareGPT dataset below, or prepare your own.\n\n```bash\nwget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json\n```\n\nProfile your model with `profile_throughput.py`:\n\n```bash\npython profile_throughput.py \\\n ShareGPT_V3_unfiltered_cleaned_split.json \\\n /path/to/your/model \\\n --concurrency 64\n```\n\n## Profile the RESTful API\n\n`profile_restful_api.py` benchmarks the API server:\n\n```bash\nwget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json\n\npython3 profile_restful_api.py --backend lmdeploy --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json\n```\n"
  },
  {
    "path": "benchmark/benchmark_decode.py",
    "content": "import json\nimport pickle\nimport time\nfrom pathlib import Path\n\nimport fire\nimport numpy as np\nfrom transformers import AutoTokenizer\n\nfrom lmdeploy.pytorch.decode import Engine\n\n\ndef benchmark(model_path, share_gpt_path, downsample=100, accel=None, save_to='decode_result'):\n    \"\"\"Benchmark decoding using ShareGPT data.\n\n    Download `ShareGPT_V3_unfiltered_cleaned_split.json` and pass its path as `share_gpt_path`.\n    \"\"\"\n\n    start = time.monotonic()\n    with open(share_gpt_path, 'r') as f:\n        content = json.load(f)\n\n    texts = []\n    for c in content:\n        for cc in c['conversations']:\n            texts.append(cc['value'])\n\n    print(f'Parsed JSON in {time.monotonic() - start:.1f} seconds.')\n\n    tokenizer = AutoTokenizer.from_pretrained(model_path)\n    tokenizer.pad_token_id = tokenizer.eos_token_id\n    tokenizer.padding_side = 'right'\n\n    texts = texts[::downsample]\n    input_ids = tokenizer(texts, padding=False).input_ids\n\n    print(f'Number of prompts: {len(input_ids)}')\n    print(f'Maximum length: {max(map(len, input_ids))}')\n    print(f'Total length: {sum(map(len, input_ids))}')\n\n    start = time.monotonic()\n    # Init an engine\n    engine = Engine(model_path, tokenizer=tokenizer, accel=accel)\n    # decode prompts\n    probs = engine.decode(input_ids)\n    total_tokens = sum(map(len, input_ids))\n\n    elapsed = time.monotonic() - start\n    print(f'Decoded {total_tokens} tokens in {elapsed:.1f} seconds, '\n          f'{total_tokens / elapsed:.1f} tokens/s.')\n    print(f'Decoded {len(probs)} prompts in {elapsed:.1f} seconds, '\n          f'{len(probs) / elapsed:.1f} requests/s.')\n\n    pkl_path = Path(save_to).with_suffix('.pkl')\n\n    with pkl_path.open('wb') as f:\n        pickle.dump(probs, f)\n\n    txt_path = Path(save_to).with_suffix('.txt')\n    np.savetxt(txt_path.as_posix(), probs, fmt='%.4e')\n\n\nif __name__ == '__main__':\n    fire.Fire(benchmark)\n\n    # llama-2 on 1 A100:\n    # data = ShareGPT, 
downsample = 100\n    # Decoded 1579536 tokens in 175.3 seconds, 9012.821089984884 tokens/s.\n    # Decoded 7022 prompts in 175.3 seconds, 40.067481648961376 requests/s.\n\n    # llama-2 on 3 A100:\n    # data = ShareGPT, downsample = 100\n    # Decoded 1579536 tokens in 77.9 seconds, 20268.736076299527 tokens/s.\n    # Decoded 7022 prompts in 77.9 seconds, 90.10688248180179 requests/s.\n\n    # llama-2 on 8 A100:\n    # data = ShareGPT, downsample = 100\n    # Decoded 1579536 tokens in 55.2 seconds, 28630.35872677815 tokens/s.\n    # Decoded 7022 prompts in 55.2 seconds, 127.27939026361929 requests/s.\n\n    # llama-2 on 8 A100:\n    # data = ShareGPT, downsample = 10\n    # Decoded 15991314 tokens in 242.7 seconds, 65893.38488718234 tokens/s.\n    # Decoded 70216 prompts in 242.7 seconds, 289.33018970413536 requests/s.\n\n    # All timings above include the time for workers to load the model.\n"
  },
  {
    "path": "benchmark/benchmark_pipeline.py",
    "content": "import os\nimport subprocess\nfrom typing import Dict, List\n\nimport fire\nimport yaml\n\n\ndef get_cmd(model_path, backend, engine_config, data_config):\n    assert backend in ['turbomind', 'pytorch']\n\n    # Work on a copy so that popping keys does not mutate the caller's data config,\n    # which is reused for every engine config\n    data_config = dict(data_config)\n    current_dir = os.path.dirname(os.path.abspath(__file__))\n    dataset_path = data_config.pop('dataset_path')\n    data_config.pop('dataset_name', None)\n\n    cmd = ['python3', f'{current_dir}/profile_pipeline_api.py', dataset_path, model_path, '--backend', backend]\n    for key, value in engine_config.items():\n        # profile_pipeline_api.py uses \"--concurrency\" to pass the \"max_batch_size\" value\n        if key == 'max_batch_size':\n            key = 'concurrency'\n        # convert snake_case keys such as 'cache_max_entry_count' into the kebab-case options\n        # (e.g. '--cache-max-entry-count') accepted by \"python3 benchmark/profile_pipeline_api.py\"\n        key = key.replace('_', '-')\n        cmd.append(f'--{key}')\n        cmd.append(str(value))\n\n    for key, value in data_config.items():\n        # convert snake_case keys such as 'sharegpt_output_len' into the kebab-case options\n        # (e.g. '--sharegpt-output-len') accepted by \"python3 benchmark/profile_pipeline_api.py\"\n        key = key.replace('_', '-')\n        cmd.append(f'--{key}')\n        cmd.append(str(value))\n    return cmd\n\n\ndef benchmark(model_path, backend, engine_config, data_config):\n    \"\"\"Benchmark the performance with the given configuration.\n\n    Args:\n        model_path: Path to the model.\n        backend: Backend to use.\n        engine_config: Configuration for the inference engine.\n        data_config: Configuration for the data.\n    \"\"\"\n    model_name = os.path.basename(model_path)\n    bs = engine_config['max_batch_size']\n    cache_ratio = engine_config.get('cache_max_entry_count', 0.8)\n    tp = engine_config.get('tp', 1)\n    output_file = f'benchmark_pipeline_{model_name}_{backend}_bs{bs}_tp{tp}_cache{cache_ratio}.csv'\n    try:\n        if isinstance(data_config, 
Dict):\n            data_config = [data_config]\n        assert isinstance(data_config, List) and all(isinstance(d, Dict) for d in data_config)\n        for _data_config in data_config:\n            _data_config['csv'] = output_file\n            cmd = get_cmd(model_path, backend, engine_config, _data_config)\n            print(f\"Running command: {' '.join(cmd)}\")\n            subprocess.run(cmd, check=True)\n    except Exception as e:\n        print(f'Benchmark failed: {e}')\n\n\ndef main(model_path=None, backend=None, config_path=None):\n    with open(config_path, 'r') as f:\n        config = yaml.safe_load(f)\n        engine_configs = config['engine']\n        data_config = config['data']\n        if isinstance(engine_configs, Dict):\n            engine_configs = [engine_configs]\n        assert isinstance(engine_configs, List) and all(isinstance(s, Dict) for s in engine_configs)\n        for engine_config in engine_configs:\n            # The model_path provided by the user will override the model_path in the config file.\n            # Resolve it per iteration so that one config's model_path does not leak into the next.\n            run_model_path = model_path or engine_config.pop('model_path')\n            engine_config.pop('model_path', None)\n            benchmark(run_model_path, backend, engine_config, data_config)\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n"
  },
  {
    "path": "benchmark/benchmark_serving.py",
    "content": "import os\nimport subprocess\nimport time\nfrom typing import Dict, List, Optional, Tuple\n\nimport fire\nimport yaml\n\n\ndef get_launching_server_cmd(model_path, backend, server_config):\n    if backend in ['turbomind', 'pytorch']:\n        cmd = ['lmdeploy', 'serve', 'api_server', model_path, '--backend', backend]\n    elif backend == 'sglang':\n        cmd = ['python3', '-m', 'sglang.launch_server', '--model-path', model_path]\n    elif backend == 'vllm':\n        cmd = ['vllm', 'serve', model_path]\n    else:\n        raise ValueError(f'unknown backend: {backend}')\n    for key, value in server_config.items():\n        # Convert snake_case to kebab-case for command line args\n        key = key.replace('_', '-')\n        cmd.append(f'--{key}')\n        if str(value):\n            cmd.append(str(value))\n    # Special handling for proxy server case\n    if server_config.get('proxy_url') and server_config.get('dp'):\n        cmd.append('--allow-terminate-by-client')\n    return cmd\n\n\ndef get_output_file(model_path, backend, server_config):\n    \"\"\"Generate the benchmark output filename.\"\"\"\n    model_name = server_config.get('model_name', None) or os.path.basename(model_path)\n\n    if backend not in ['turbomind', 'pytorch', 'sglang', 'vllm']:\n        raise ValueError(f'Unknown backend: {backend}')\n\n    if backend in ['sglang', 'vllm']:\n        return f'benchmark_{model_name}_{backend}.csv'\n\n    # For turbomind/pytorch backends\n    params = [\n        ('bs', server_config['max_batch_size']),\n        ('tp', server_config.get('tp', 1)),\n        ('dp', server_config.get('dp', '')),\n        ('ep', server_config.get('ep', '')),\n        ('cache', server_config.get('cache_max_entry_count', 0.8)),\n        ('mptk', server_config.get('max_prefill_token_num', '')),\n    ]\n    params_str = '_'.join(f'{k}{v}' for k, v in params if v != '')\n    # Turbomind-specific additions\n    if backend == 'turbomind' and (comm := 
server_config.get('communicator')):\n        params_str += f'_{comm}'\n\n    return f'benchmark_{model_name}_{backend}_{params_str}.csv'\n\n\ndef get_server_ip_port(backend: str, server_config: Dict) -> Tuple[str, int]:\n    if backend in ['turbomind', 'pytorch']:\n        if server_config.get('proxy_url'):\n            # If proxy_url is set, we use the proxy server's IP and port\n            parts = server_config['proxy_url'].split(':')\n            server_ip = parts[1].lstrip('/')\n            server_port = int(parts[2])\n        else:\n            # Default to the server IP and port specified in the config\n            server_ip = server_config.get('server_ip', '0.0.0.0')\n            server_port = server_config.get('server_port', 23333)\n    elif backend == 'sglang':\n        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 30000))\n    elif backend == 'vllm':\n        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 8000))\n    else:\n        raise ValueError(f'unknown backend: {backend}')\n    return server_ip, server_port\n\n\ndef wait_server_ready(server_ip: str, server_port: int) -> bool:\n    \"\"\"Wait for the API server to become ready.\"\"\"\n    from openai import OpenAI\n    while True:\n        try:\n            client = OpenAI(api_key='DUMMY', base_url=f'http://{server_ip}:{server_port}/v1')\n            model_name = client.models.list().data[0].id\n            if model_name:\n                print('Server is ready.')\n                return True\n        except Exception as e:\n            print(f'Failed to connect to server http://{server_ip}:{server_port}: {e}')\n        time.sleep(5)\n\n\ndef get_client_cmd(backend: str, server_ip: str, server_port: int, client_config: Dict) -> List[str]:\n    \"\"\"Generate the client benchmark command.\"\"\"\n    current_dir = os.path.dirname(os.path.abspath(__file__))\n    if backend in ['turbomind', 'pytorch']:\n        backend = 'lmdeploy'\n    cmd 
= [\n        'python3', f'{current_dir}/profile_restful_api.py', '--backend', backend, '--host', server_ip, '--port',\n        str(server_port)\n    ]\n    for key, value in client_config.items():\n        # Convert keys like 'dataset_path' to 'dataset-path' to match the optional arguments of\n        # \"python3 benchmark/profile_restful_api.py\"\n        key = key.replace('_', '-')\n        if key == 'disable-warmup':\n            if str(value).lower() == 'true':\n                cmd.append(f'--{key}')\n            continue\n        cmd.append(f'--{key}')\n        cmd.append(str(value))\n    return cmd\n\n\ndef benchmark(model_path: str, backend: str, server_config: Dict, data_config: Dict | List[Dict]):\n    \"\"\"Benchmark the server with the given configuration.\n\n    Args:\n        model_path: Path to the model.\n        backend: Backend to use.\n        server_config: Configuration for the server and the inference engine.\n        data_config: Configuration for the data.\n    \"\"\"\n    if isinstance(data_config, Dict):\n        data_config = [data_config]\n    if not (isinstance(data_config, List) and all(isinstance(d, Dict) for d in data_config)):\n        raise ValueError('data_config must be a dict or list of dicts')\n\n    server_cmd = get_launching_server_cmd(model_path, backend, server_config)\n    server_ip, server_port = get_server_ip_port(backend, server_config)\n    proc = None\n\n    try:\n        print(f\"Starting api_server: {' '.join(server_cmd)}\", flush=True)\n        proc = subprocess.Popen(server_cmd)\n        # Wait for the server to be ready\n        wait_server_ready(server_ip, server_port)\n        # Run benchmarks\n        output_file = get_output_file(model_path, backend, server_config)\n        for data in data_config:\n            data = data.copy()\n            data['output_file'] = output_file\n            client_cmd = get_client_cmd(backend, server_ip, server_port, data)
\n            print(f\"Running benchmark: {' '.join(client_cmd)}\")\n            subprocess.run(client_cmd, check=True)\n    except Exception as e:\n        print(f'Unexpected error: {e}')\n        raise\n    finally:\n        # Clean up server process\n        if proc and proc.poll() is None:\n            if server_config.get('proxy_url') and server_config.get('dp'):\n                # Send a termination request to the proxy server, which broadcasts it to the\n                # api_server on each dp_rank.\n                # Note that api_server is supposed to be launched with --allow-terminate-by-client\n                print('Sending termination request to proxy server')\n                try:\n                    subprocess.run(['curl', '-X', 'POST', f'{server_config[\"proxy_url\"]}/nodes/terminate_all'],\n                                   check=True,\n                                   timeout=10)\n                except (subprocess.SubprocessError, OSError) as e:\n                    # Do not let a failed request skip the termination of the local process below\n                    print(f'Failed to send termination request to proxy server: {e}')\n            proc.terminate()\n            try:\n                proc.wait(timeout=30)\n            except subprocess.TimeoutExpired:\n                print('Server did not terminate gracefully - killing')\n                proc.kill()\n\n\ndef validate_config(config: Dict) -> None:\n    \"\"\"Validate the configuration structure.\n\n    Args:\n        config: Loaded configuration dictionary\n\n    Raises:\n        ValueError: If configuration is invalid\n    \"\"\"\n    required_sections = ['server', 'engine', 'data']\n    for section in required_sections:\n        if section not in config:\n            raise ValueError(f'Missing required config section: {section}')\n\n    if not isinstance(config['engine'], (Dict, List)):\n        raise ValueError('engine config must be a dict or list of dicts')\n\n    if not isinstance(config['data'], (Dict, List)):\n        raise ValueError('data config must be a dict or list of dicts')\n\n\ndef main(backend: str, config_path: str, model_path: Optional[str] = None):\n    \"\"\"Main entry point for the benchmark script.\n\n    Args:\n        backend: Backend to use
\n        config_path: Path to config file\n        model_path: Optional override for model path\n\n    Raises:\n        KeyError: If a required config section is missing\n        ValueError: If no model path is given on the command line or in the engine config\n    \"\"\"\n    with open(config_path, 'r') as f:\n        config = yaml.safe_load(f)\n        server_config = config['server']\n        engine_configs = config['engine']\n        data_config = config['data']\n        if isinstance(engine_configs, Dict):\n            engine_configs = [engine_configs]\n        assert isinstance(engine_configs, List) and all(isinstance(s, Dict) for s in engine_configs)\n        for engine_config in engine_configs:\n            # Merge each engine config into a fresh copy of the server config so that the settings\n            # of one engine entry do not leak into the next one\n            merged_config = server_config.copy()\n            merged_config.update(engine_config)\n            # Remove model_path from the merged config to avoid passing it to the server command.\n            # The model_path provided by the user overrides the model_path in the config file.\n            config_model_path = merged_config.pop('model_path', None)\n            run_model_path = model_path or config_model_path\n            if not run_model_path:\n                raise ValueError('model_path must be given on the command line or in the engine config')\n            benchmark(run_model_path, backend, merged_config, data_config)\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n"
  },
  {
    "path": "benchmark/benchmark_throughput.py",
    "content": "import os\nimport subprocess\nfrom typing import Dict, List\n\nimport fire\nimport yaml\n\n\ndef get_cmd(model_path, backend, engine_config, data_config):\n    assert backend in ['turbomind', 'pytorch']\n\n    current_dir = os.path.dirname(os.path.abspath(__file__))\n\n    dataset_path = data_config.pop('dataset_path')\n\n    cmd = ['python3', f'{current_dir}/profile_throughput.py', dataset_path, model_path, '--backend', backend]\n    for key, value in engine_config.items():\n        # profile_throughput.py uses \"--concurrency\" to pass the \"max_batch_size\" value\n        if key == 'max_batch_size':\n            key = 'concurrency'\n        # Convert keys like 'cache_max_entry_count' to 'cache-max-entry-count' to match the optional\n        # arguments of \"python3 benchmark/profile_throughput.py\"\n        key = key.replace('_', '-')\n        cmd.append(f'--{key}')\n        cmd.append(str(value))\n\n    for key, value in data_config.items():\n        # Convert keys like 'sharegpt_output_len' to 'sharegpt-output-len' to match the optional\n        # arguments of \"python3 benchmark/profile_throughput.py\"\n        key = key.replace('_', '-')\n        cmd.append(f'--{key}')\n        cmd.append(str(value))\n    return cmd\n\n\ndef benchmark(model_path, backend, engine_config, data_config):\n    \"\"\"Benchmark the performance with the given configuration.\n\n    Args:\n        model_path: Path to the model.\n        backend: Backend to use.\n        engine_config: Configuration for the inference engine.\n        data_config: Configuration for the data.\n    \"\"\"\n    model_name = os.path.basename(model_path)\n    bs = engine_config['max_batch_size']\n    cache_ratio = engine_config.get('cache_max_entry_count', 0.8)\n    tp = engine_config.get('tp', 1)\n    output_file = f'benchmark_throughput_{model_name}_{backend}_bs{bs}_tp{tp}_cache{cache_ratio}.csv'\n    try:\n        if isinstance(data_config, Dict):\n            data_config = [data_config]
\n        assert isinstance(data_config, List) and all(isinstance(d, Dict) for d in data_config)\n        for _data_config in data_config:\n            # Work on a copy: get_cmd pops 'dataset_path', and data_config is shared by all engine configs\n            _data_config = dict(_data_config, csv=output_file)\n            cmd = get_cmd(model_path, backend, engine_config, _data_config)\n            print(f\"Running command: {' '.join(cmd)}\")\n            subprocess.run(cmd, check=True)\n    except Exception as e:\n        print(f'Exception happened: {e}')\n\n\ndef main(model_path=None, backend=None, config_path=None):\n    with open(config_path, 'r') as f:\n        config = yaml.safe_load(f)\n        engine_configs = config['engine']\n        data_config = config['data']\n        if isinstance(engine_configs, Dict):\n            engine_configs = [engine_configs]\n        assert isinstance(engine_configs, List) and all(isinstance(s, Dict) for s in engine_configs)\n        for engine_config in engine_configs:\n            # Remove model_path from engine_config so that it is not passed to profile_throughput.py.\n            # The model_path provided by the user overrides the model_path in the config file.\n            config_model_path = engine_config.pop('model_path', None)\n            run_model_path = model_path or config_model_path\n            assert run_model_path, 'model_path must be given on the command line or in the engine config'\n            benchmark(run_model_path, backend, engine_config, data_config)\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n"
  },
  {
    "path": "benchmark/lmdeploy.yml",
    "content": "num_prompts: &num_prompts 10000\ndataset_path: &dataset_path \"/nvme1/shared/ShareGPT_V3_unfiltered_cleaned_split.json\"\ndataset_name: &dataset_name \"sharegpt\"\nmodel_path: &model_path \"Qwen/Qwen3-30B-A3B-FP8\"\nserver:\n  server_port: 23333\n# Inference engine configuration\nengine:\n  - model_path: *model_path\n    max_batch_size: 1280\n    cache_max_entry_count: 0.9\n    tp: 1\n  - model_path: *model_path\n    max_batch_size: 1280\n    cache_max_entry_count: 0.9\n    max_prefill_token_num: 4096\n    tp: 1\n  - model_path: \"Qwen/Qwen3-235B-A22B-FP8\"\n    max_batch_size: 64\n    cache_max_entry_count: 0.7\n    max_prefill_token_num: 4096\n    dp: 8\n    ep: 8\n    proxy_url: \"http://localhost:8000\"\n# Benchmark test configuration for profile_restful_api.py\n# Defines multiple test cases with different output lengths to evaluate API performance\ndata:\n  - dataset_name: *dataset_name\n    dataset_path: *dataset_path\n    num_prompts: *num_prompts\n  - dataset_name: *dataset_name\n    dataset_path: *dataset_path\n    sharegpt_output_len: 2048\n    num_prompts: *num_prompts\n  - dataset_name: *dataset_name\n    dataset_path: *dataset_path\n    sharegpt_output_len: 4096\n    num_prompts: *num_prompts\n  - dataset_name: *dataset_name\n    dataset_path: *dataset_path\n    sharegpt_output_len: 8192\n    num_prompts: *num_prompts\n  - dataset_name: *dataset_name\n    dataset_path: *dataset_path\n    sharegpt_output_len: 16384\n    num_prompts: *num_prompts\n  - dataset_name: *dataset_name\n    dataset_path: *dataset_path\n    sharegpt_output_len: 32768\n    num_prompts: *num_prompts\n"
  },
  {
    "path": "benchmark/profile_pipeline_api.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport argparse\nimport json\nimport os\nimport random\nfrom typing import List, Optional, Tuple\n\nimport numpy as np\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer, PreTrainedTokenizerBase\n\nfrom lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline\nfrom lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter\nfrom lmdeploy.profiler import Profiler, Session\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\ndef sample_sharegpt_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    fixed_output_len: Optional[int] = None,\n) -> List[Tuple[str, int, int]]:\n    if fixed_output_len is not None and fixed_output_len < 4:\n        raise ValueError('output_len too small')\n\n    # Load the dataset.\n    with open(dataset_path) as f:\n        dataset = json.load(f)\n    # Filter out the conversations with less than 2 turns.\n    dataset = [data for data in dataset if len(data['conversations']) >= 2]\n    # Only keep the first two turns of each conversation.\n    dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset]\n\n    # Shuffle the dataset.\n    random.shuffle(dataset)\n\n    # Filter out sequences that are too long or too short\n    filtered_dataset: List[Tuple[str, int, int]] = []\n    for i in range(len(dataset)):\n        if len(filtered_dataset) == num_requests:\n            break\n\n        # Tokenize the prompts and completions.\n        prompt = dataset[i][0]\n        prompt_token_ids = tokenizer.encode(prompt)\n        completion = dataset[i][1]\n        completion_token_ids = tokenizer.encode(completion)\n        prompt_len = len(prompt_token_ids)\n        output_len = (len(completion_token_ids) if fixed_output_len is None else fixed_output_len)\n        if prompt_len < 4 or output_len < 4:\n            # 
Prune too short sequences.\n            continue\n        if prompt_len > 1024 or (prompt_len + output_len > 2048 and fixed_output_len is None):\n            # Prune too long sequences.\n            continue\n        filtered_dataset.append((prompt, prompt_len, output_len))\n\n    print(f'#Input tokens: {np.sum([x[1] for x in filtered_dataset])}')\n    print(f'#Output tokens: {np.sum([x[2] for x in filtered_dataset])}')\n    return filtered_dataset\n\n\ndef sample_random_requests(\n    input_len: int,\n    output_len: int,\n    num_prompts: int,\n    range_ratio: float,\n    tokenizer: PreTrainedTokenizerBase,\n    dataset_path: str,\n) -> List[Tuple[str, int, int]]:\n\n    input_lens = np.random.randint(\n        max(int(input_len * range_ratio), 1),\n        input_len + 1,\n        size=num_prompts,\n    )\n    output_lens = np.random.randint(\n        int(output_len * range_ratio),\n        output_len + 1,\n        size=num_prompts,\n    )\n\n    if True:\n        # Sample token ids from ShareGPT and repeat/truncate them to\n        # satisfy the input_lens\n\n        # Load the dataset.\n        with open(dataset_path) as f:\n            dataset = json.load(f)\n        # Filter out the conversations with less than 2 turns.\n        dataset = [data for data in dataset if len(data['conversations']) >= 2]\n        # Only keep the first two turns of each conversation.\n        dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset]\n        # remove the empty prompt\n        dataset = [(query, answer) for query, answer in dataset if len(query) > 0]\n\n        # Shuffle the dataset.\n        random.shuffle(dataset)\n\n        # Filter out sequences that are too long or too short\n        input_requests: List[Tuple[str, int, int]] = []\n        for i in range(num_prompts):\n            # Tokenize the prompts and completions.\n            prompt = dataset[i][0]\n            prompt_token_ids = tokenizer.encode(prompt)\n    
        prompt_len = len(prompt_token_ids)\n\n            if prompt_len > input_lens[i]:\n                input_ids = prompt_token_ids[:input_lens[i]]\n            else:\n                ratio = (input_lens[i] + prompt_len - 1) // prompt_len\n                input_ids = (prompt_token_ids * ratio)[:input_lens[i]]\n            prompt = tokenizer.decode(input_ids)\n            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))\n    else:\n        # Sample token ids from random integers.\n        # This can cause some NaN issues.\n        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)\n        input_requests = []\n        for i in range(num_prompts):\n            prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])])\n            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))\n\n    print(f'#Input tokens: {np.sum(input_lens)}')\n    print(f'#Output tokens: {np.sum(output_lens)}')\n    return input_requests\n\n\nclass Engine:\n\n    def __init__(self, model_path: str, engine_config, csv: str):\n        self.pipe = pipeline(model_path, backend_config=engine_config, log_level='ERROR')\n        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n        self.return_routed_experts = getattr(self.pipe.backend_config, 'enable_return_routed_experts', False)\n        self.csv = csv\n\n    def process_request(self, requests, profiler: Profiler, temperature, top_p, top_k, stream_output):\n\n        prompts = [prompt for prompt, _, _ in requests]\n        gen_configs = [\n            GenerationConfig(temperature=temperature,\n                             top_p=top_p,\n                             top_k=top_k,\n                             ignore_eos=True,\n                             do_sample=False,\n                             return_routed_experts=self.return_routed_experts,\n                             
max_new_tokens=output_len) for _, _, output_len in requests\n        ]\n\n        sess: List[Session] = []\n        for _, input_len, output_len in requests:\n            sess.append(profiler.new_session(input_len, output_len))\n\n        def _to_status(finish_reason):\n            if finish_reason == 'length':\n                return Session.SUCCESS\n            else:\n                return Session.FAIL\n\n        profiler.start()\n\n        for s in sess:\n            s.tick(0)\n\n        if stream_output:\n            pbar = tqdm(total=len(requests))\n            for output in self.pipe.stream_infer(prompts, gen_config=gen_configs, do_preprocess=False):\n                index = output.index\n                n_token = output.generate_token_len\n                finish_reason = output.finish_reason\n                sess[index].tick(n_token)\n                if finish_reason is not None:\n                    sess[index].finish(_to_status(finish_reason))\n                    pbar.update(1)\n            pbar.close()\n        else:\n            for output in self.pipe(prompts, gen_configs, do_preprocess=False, use_tqdm=True):\n                index = output.index\n                n_token = output.generate_token_len\n                finish_reason = output.finish_reason\n                sess[index].tick(n_token)\n                sess[index].finish(_to_status(finish_reason))\n\n        profiler.finish()\n\n        # report first failure\n        for i, s in enumerate(sess):\n            if s.status != Session.SUCCESS or s.ns[-1] < s.req_output_len:\n                logger.error(f'Request {i} failed with {s.ns[-1]}/{s.req_output_len} tokens generated'  # noqa: E501\n                             )\n                logger.error(f'Prompt: {prompts[i]}')\n                logger.warning('Got failed requests, metrics may be invalid')\n                break\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Benchmark the request throughput of lmdeploy '\n  
                                   'in localhost',\n                                     formatter_class=DefaultsAndTypesHelpFormatter)\n    parser.add_argument('dataset', type=str, help='the path of the dataset')\n    parser.add_argument('model_path',\n                        type=str,\n                        help='the local path of the model or '\n                        'the repo_id of the model on huggingface.co')\n    parser.add_argument('-c',\n                        '--concurrency',\n                        type=int,\n                        help='Number of concurrent requests, also used as the max_batch_size of the engine',\n                        default=256)\n    parser.add_argument('-n', '--num-prompts', type=int, help='Number of prompts to process', default=5000)\n    parser.add_argument('--csv', type=str, help='Where to save the result.', default='./profile_pipeline_api.csv')\n    parser.add_argument('--seed', type=int, default=0, help='Seed used in sampling prompts from dataset')\n    parser.add_argument('--stream-output', action='store_true', help='Use streaming output (stream_infer) when processing the requests')\n    parser.add_argument('--dataset-name',\n                        type=str,\n                        default='sharegpt',\n                        choices=['sharegpt', 'random'],\n                        help='Name of the dataset to benchmark on.')\n    parser.add_argument(\n        '--sharegpt-output-len',\n        type=int,\n        default=None,\n        help='Output length for each request. Overrides the output length '
\n        'from the ShareGPT dataset.',\n    )\n    parser.add_argument(\n        '--random-input-len',\n        type=int,\n        help='Number of input tokens per request, used only for random '\n        'dataset.',\n    )\n    parser.add_argument(\n        '--random-output-len',\n        type=int,\n        help='Number of output tokens per request, used only for random '\n        'dataset.',\n    )\n    parser.add_argument(\n        '--random-range-ratio',\n        type=float,\n        default=0.0,\n        help='Lower bound of the sampled input/output length as a ratio of the requested length, '\n        'used only for random dataset.',\n    )\n    # other args\n    ArgumentHelper.top_p(parser)\n    ArgumentHelper.temperature(parser)\n    ArgumentHelper.top_k(parser)\n    ArgumentHelper.log_level(parser)\n    ArgumentHelper.backend(parser)\n\n    # pytorch engine args\n    pt_group = parser.add_argument_group('PyTorch engine arguments')\n    ArgumentHelper.eager_mode(pt_group)\n    ArgumentHelper.enable_return_routed_experts(pt_group)\n\n    tp_act = ArgumentHelper.tp(pt_group)\n    cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)\n    session_len_act = ArgumentHelper.session_len(pt_group)\n    cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)\n    prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)\n\n    # turbomind engine args\n    tb_group = parser.add_argument_group('TurboMind engine arguments')\n    tb_group._group_actions.append(tp_act)\n    tb_group._group_actions.append(cache_count_act)\n    tb_group._group_actions.append(session_len_act)\n    tb_group._group_actions.append(cache_block_seq_len_act)\n    tb_group._group_actions.append(prefix_caching_act)\n    ArgumentHelper.model_format(tb_group, default='hf')\n    ArgumentHelper.quant_policy(tb_group, default=0)\n    ArgumentHelper.num_tokens_per_iter(tb_group)\n    ArgumentHelper.max_prefill_iters(tb_group)\n    ArgumentHelper.communicator(tb_group)
\n    ArgumentHelper.async_(tb_group)\n\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    args = parse_args()\n    random.seed(args.seed)\n    # The random dataset samples input/output lengths with numpy, so seed it as well\n    np.random.seed(args.seed)\n    os.environ['TM_LOG_LEVEL'] = args.log_level\n    if args.backend == 'turbomind':\n        engine_config = TurbomindEngineConfig(max_batch_size=args.concurrency,\n                                              tp=args.tp,\n                                              cache_max_entry_count=args.cache_max_entry_count,\n                                              session_len=args.session_len,\n                                              cache_block_seq_len=args.cache_block_seq_len,\n                                              model_format=args.model_format,\n                                              quant_policy=args.quant_policy,\n                                              num_tokens_per_iter=args.num_tokens_per_iter,\n                                              max_prefill_iters=args.max_prefill_iters,\n                                              enable_prefix_caching=args.enable_prefix_caching,\n                                              communicator=args.communicator,\n                                              enable_metrics=False,\n                                              async_=args.async_)\n    elif args.backend == 'pytorch':\n        engine_config = PytorchEngineConfig(\n            cache_max_entry_count=args.cache_max_entry_count,\n            session_len=args.session_len,\n            block_size=args.cache_block_seq_len,\n            max_batch_size=args.concurrency,\n            tp=args.tp,\n            thread_safe=False,\n            eager_mode=args.eager_mode,\n            enable_prefix_caching=args.enable_prefix_caching,\n            enable_return_routed_experts=args.enable_return_routed_experts,\n        )\n\n    engine = Engine(args.model_path, engine_config, csv=args.csv)\n\n    profiler = Profiler(args.stream_output, [50, 75, 95, 99])\n\n    if args.dataset_name == 
'sharegpt':\n        assert args.random_input_len is None and args.random_output_len is None\n        requests = sample_sharegpt_requests(\n            dataset_path=args.dataset,\n            num_requests=args.num_prompts,\n            tokenizer=engine.tokenizer,\n            fixed_output_len=args.sharegpt_output_len,\n        )\n    elif args.dataset_name == 'random':\n        assert args.random_input_len is not None and \\\n            args.random_output_len is not None\n        requests = sample_random_requests(\n            input_len=args.random_input_len,\n            output_len=args.random_output_len,\n            num_prompts=args.num_prompts,\n            range_ratio=args.random_range_ratio,\n            tokenizer=engine.tokenizer,\n            dataset_path=args.dataset,\n        )\n    else:\n        raise ValueError(f'Unknown dataset: {args.dataset_name}')\n\n    engine.process_request(requests,\n                           profiler,\n                           temperature=args.temperature,\n                           top_p=args.top_p,\n                           top_k=args.top_k,\n                           stream_output=args.stream_output)\n\n    hyperparams = [('Concurrency', args.concurrency), ('Stream output', str(args.stream_output).lower())]\n\n    profiler.compute_metrics()\n    profiler.summarize(title='Profile Pipeline API', hyperparams=hyperparams)\n\n    if args.csv:\n        # profiler.save_csv(args.csv, (('batch', args.concurrency), ('num_prompts', args.num_prompts)))\n        profiler.save_csv(args.csv, (\n            ('backend', args.backend),\n            ('bs', args.concurrency),\n            ('dataset_name', args.dataset_name),\n            ('sharegpt_output_len', args.sharegpt_output_len),\n            ('random_input_len', args.random_input_len),\n            ('random_output_len', args.random_output_len),\n            ('random_range_ratio', args.random_range_ratio),\n            ('num_prompts', args.num_prompts),\n        ))\n\n\nif 
__name__ == '__main__':\n    main()\n"
  },
  {
    "path": "benchmark/profile_restful_api.py",
    "content": "# Modify from https://github.com/sgl-project/sglang/blob/main/python/sglang/bench_serving.py  # noqa\n# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py  # noqa\n# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py  # noqa\n\"\"\"Benchmark online serving with dynamic requests.\n\nUsage:\npython3 -m sglang.bench_serving --backend sglang --num-prompt 10\n\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5\npython3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi\n\"\"\"  # noqa\nimport argparse\nimport asyncio\nimport csv\nimport io\nimport json\nimport os\nimport random\nimport resource\nimport sys\nimport time\nimport traceback\nimport warnings\nfrom argparse import ArgumentParser\nfrom dataclasses import dataclass, field\nfrom datetime import datetime\nfrom typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union\n\nimport aiohttp\nimport numpy as np\nimport pybase64\nimport requests\nfrom PIL import Image\nfrom tqdm.asyncio import tqdm\nfrom transformers import (AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase,\n                          PreTrainedTokenizerFast)\n\nAIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=None)\n\n_timeout_value = os.getenv('AIOHTTP_TIMEOUT', None)\nif _timeout_value is not None:\n    try:\n        _timeout_value = int(_timeout_value)\n        if _timeout_value < 0:\n            raise ValueError('AIOHTTP_TIMEOUT cannot be negative.')\n        AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=_timeout_value * 60 * 60)\n    except ValueError as e:\n        print(f'Invalid AIOHTTP_TIMEOUT: {e}.')\n        
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=None)\n\nglobal args\n\n\n@dataclass\nclass RequestFuncInput:\n    prompt: str\n    api_url: str\n    prompt_len: int\n    output_len: int\n    model: str\n    image_data: Optional[List[str]]\n    extra_request_body: Dict[str, Any]\n\n\n@dataclass\nclass RequestFuncOutput:\n    generated_text: str = ''\n    success: bool = False\n    latency: float = 0.0\n    ttft: float = 0.0  # Time to first token\n    itl: List[float] = field(default_factory=list)  # List of inter-token latencies\n    prompt_len: int = 0\n    output_len: int = 0\n    error: str = ''\n\n\ndef remove_prefix(text: str, prefix: str) -> str:\n    return text[len(prefix):] if text.startswith(prefix) else text\n\n\n# trt llm not support ignore_eos\n# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505\nasync def async_request_trt_llm(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    assert api_url.endswith('generate_stream')\n\n    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:\n        payload = {\n            'accumulate_tokens': True,\n            'text_input': request_func_input.prompt,\n            'temperature': 0.000001,\n            'top_p': 1.0,\n            'max_tokens': request_func_input.output_len,\n            'stream': True,\n            'min_length': request_func_input.output_len,\n            'end_id': 1048576,\n            **request_func_input.extra_request_body,\n        }\n        if args.disable_ignore_eos:\n            del payload['min_length']\n            del payload['end_id']\n        output = RequestFuncOutput()\n        output.prompt_len = request_func_input.prompt_len\n\n        ttft = 0.0\n        st = time.perf_counter()\n        most_recent_timestamp = st\n        try:\n            async with session.post(url=api_url, json=payload) as response:\n                if response.status == 
200:\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                        chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data:')\n\n                        data = json.loads(chunk)\n                        output.generated_text += data['text_output']\n                        timestamp = time.perf_counter()\n                        # First token\n                        if ttft == 0.0:\n                            ttft = time.perf_counter() - st\n                            output.ttft = ttft\n\n                        # Decoding phase\n                        else:\n                            output.itl.append(timestamp - most_recent_timestamp)\n\n                        most_recent_timestamp = timestamp\n\n                    output.latency = most_recent_timestamp - st\n                    output.success = True\n                    output.output_len = request_func_input.output_len\n\n                else:\n                    output.error = response.reason or ''\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = ''.join(traceback.format_exception(*exc_info))\n\n        if pbar:\n            pbar.update(1)\n        return output\n\n\n# set ignore_eos True by default\nasync def async_request_openai_completions(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    assert api_url.endswith('completions'), \"OpenAI Completions API URL must end with 'completions'.\"\n\n    prompt = request_func_input.prompt\n\n    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:\n        payload = {\n            'model': request_func_input.model,\n            'prompt': prompt,\n        
    'temperature': 0.0,\n            'best_of': 1,\n            'max_tokens': request_func_input.output_len,\n            'stream': not args.disable_stream,\n            'ignore_eos': not args.disable_ignore_eos,\n            **request_func_input.extra_request_body,\n        }\n        headers = {'Authorization': f\"Bearer {os.environ.get('OPENAI_API_KEY')}\"}\n\n        output = RequestFuncOutput()\n        output.prompt_len = request_func_input.prompt_len\n\n        generated_text = ''\n        ttft = 0.0\n        st = time.perf_counter()\n        most_recent_timestamp = st\n        try:\n            async with session.post(url=api_url, json=payload, headers=headers) as response:\n                if response.status == 200:\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n\n                        chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data: ')\n                        latency = time.perf_counter() - st\n                        if chunk == '[DONE]':\n                            pass\n                        else:\n                            data = json.loads(chunk)\n\n                            # NOTE: Some completion API might have a last\n                            # usage summary response without a token so we\n                            # want to check a token was generated\n                            if data['choices'][0]['text']:\n                                timestamp = time.perf_counter()\n                                # First token\n                                if ttft == 0.0:\n                                    ttft = time.perf_counter() - st\n                                    output.ttft = ttft\n\n                                # Decoding phase\n                                else:\n                                    output.itl.append(timestamp - 
most_recent_timestamp)\n\n                                most_recent_timestamp = timestamp\n                                generated_text += data['choices'][0]['text']\n\n                    output.generated_text = generated_text\n                    output.success = True\n                    output.latency = latency\n                    output.output_len = request_func_input.output_len\n                else:\n                    output.error = response.reason or ''\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = ''.join(traceback.format_exception(*exc_info))\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_openai_chat_completions(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    assert api_url.endswith('chat/completions'), \"OpenAI Chat Completions API URL must end with 'chat/completions'.\"\n\n    if request_func_input.image_data:\n        # Build multi-image content: a list of image_url entries followed by the text\n        content_items = [{\n            'type': 'image_url',\n            'image_url': {\n                'url': img_url\n            },\n        } for img_url in request_func_input.image_data]\n        content_items.append({'type': 'text', 'text': request_func_input.prompt})\n        messages = [\n            {\n                'role': 'user',\n                'content': content_items,\n            },\n        ]\n    else:\n        messages = [{'role': 'user', 'content': request_func_input.prompt}]\n\n    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:\n        payload = {\n            'model': request_func_input.model,\n            'messages': messages,\n            'temperature': 0.0,\n            'max_completion_tokens': request_func_input.output_len,\n           
 'stream': not args.disable_stream,\n            'ignore_eos': not args.disable_ignore_eos,\n            **request_func_input.extra_request_body,\n        }\n        headers = {'Authorization': f\"Bearer {os.environ.get('OPENAI_API_KEY')}\"}\n\n        output = RequestFuncOutput()\n        output.prompt_len = request_func_input.prompt_len\n\n        generated_text = ''\n        output_len = request_func_input.output_len\n        ttft = 0.0\n        st = time.perf_counter()\n        most_recent_timestamp = st\n        try:\n            async with session.post(url=api_url, json=payload, headers=headers) as response:\n                if response.status == 200:\n                    if args.disable_stream:\n                        # Non-streaming response\n                        response_json = await response.json()\n                        output.generated_text = response_json['choices'][0]['message']['content']\n                        output.success = True\n                        output.latency = time.perf_counter() - st\n                        output.ttft = (output.latency)  # For non-streaming, TTFT = total latency\n                        output.output_len = response_json.get('usage', {}).get('completion_tokens', output_len)\n                    else:\n                        # Streaming response\n                        async for chunk_bytes in response.content:\n                            chunk_bytes = chunk_bytes.strip()\n                            if not chunk_bytes:\n                                continue\n\n                            chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data: ')\n                            latency = time.perf_counter() - st\n                            if chunk == '[DONE]':\n                                pass\n                            else:\n                                data = json.loads(chunk)\n\n                                # Check if this chunk contains content\n                                delta = 
data.get('choices', [{}])[0].get('delta', {})\n                                content = delta.get('content', '')\n\n                                if content:\n                                    timestamp = time.perf_counter()\n                                    # First token\n                                    if ttft == 0.0:\n                                        ttft = timestamp - st\n                                        output.ttft = ttft\n\n                                    # Decoding phase\n                                    else:\n                                        output.itl.append(timestamp - most_recent_timestamp)\n\n                                    most_recent_timestamp = timestamp\n                                    generated_text += content\n\n                                # Check for usage info in final chunk\n                                output_len = (data.get('usage') or {}).get('completion_tokens', output_len)\n\n                        output.generated_text = generated_text\n                        output.success = True\n                        output.latency = latency\n                        output.output_len = output_len\n                else:\n                    output.error = ((response.reason or '') + ': ' + (await response.text()))\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = ''.join(traceback.format_exception(*exc_info))\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_sglang_generate(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    api_url = request_func_input.api_url\n    prompt = request_func_input.prompt\n\n    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:\n        payload = {\n            'text': prompt,\n            'sampling_params': {\n                
'temperature': 0.0,\n                'max_new_tokens': request_func_input.output_len,\n                'ignore_eos': not args.disable_ignore_eos,\n            },\n            'stream': not args.disable_stream,\n            **request_func_input.extra_request_body,\n        }\n        headers = {}\n\n        output = RequestFuncOutput()\n        output.prompt_len = request_func_input.prompt_len\n\n        generated_text = ''\n        ttft = 0.0\n        st = time.perf_counter()\n        most_recent_timestamp = st\n        try:\n            async with session.post(url=api_url, json=payload, headers=headers) as response:\n                if response.status == 200:\n                    async for chunk_bytes in response.content:\n                        chunk_bytes = chunk_bytes.strip()\n                        if not chunk_bytes:\n                            continue\n                        # print(chunk_bytes)\n\n                        chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data: ')\n                        latency = time.perf_counter() - st\n                        if chunk == '[DONE]':\n                            pass\n                        else:\n                            data = json.loads(chunk)\n\n                            # NOTE: Some completion API might have a last\n                            # usage summary response without a token so we\n                            # want to check a token was generated\n                            if data['text']:\n                                timestamp = time.perf_counter()\n                                # First token\n                                if ttft == 0.0:\n                                    ttft = time.perf_counter() - st\n                                    output.ttft = ttft\n\n                                # Decoding phase\n                                else:\n                                    output.itl.append(timestamp - most_recent_timestamp)\n\n                             
   most_recent_timestamp = timestamp\n                                generated_text = data['text']\n\n                    output.generated_text = generated_text\n                    output.success = True\n                    output.latency = latency\n                    output.output_len = request_func_input.output_len\n                else:\n                    output.error = response.reason or ''\n                    output.success = False\n        except Exception:\n            output.success = False\n            exc_info = sys.exc_info()\n            output.error = ''.join(traceback.format_exception(*exc_info))\n\n    if pbar:\n        pbar.update(1)\n    return output\n\n\nasync def async_request_gserver(\n    request_func_input: RequestFuncInput,\n    pbar: Optional[tqdm] = None,\n) -> RequestFuncOutput:\n    raise NotImplementedError()\n\n\ndef get_model(pretrained_model_name_or_path: str) -> str:\n    if os.getenv('SGLANG_USE_MODELSCOPE', 'False').lower() == 'true':\n        import huggingface_hub.constants\n        from modelscope import snapshot_download\n\n        model_path = snapshot_download(\n            model_id=pretrained_model_name_or_path,\n            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,\n            ignore_file_pattern=['.*.pt', '.*.safetensors', '.*.bin'],\n        )\n\n        return model_path\n    return pretrained_model_name_or_path\n\n\ndef get_tokenizer(pretrained_model_name_or_path: str, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:\n    if pretrained_model_name_or_path.endswith('.json') or pretrained_model_name_or_path.endswith('.model'):\n        from sglang.srt.hf_transformers_utils import get_tokenizer\n\n        return get_tokenizer(pretrained_model_name_or_path)\n\n    if pretrained_model_name_or_path is not None and not os.path.exists(pretrained_model_name_or_path):\n        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)\n    return 
AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)\n\n\ndef get_processor(pretrained_model_name_or_path: str, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:\n    assert (pretrained_model_name_or_path is not None and pretrained_model_name_or_path != '')\n    if pretrained_model_name_or_path.endswith('.json') or pretrained_model_name_or_path.endswith('.model'):\n        from sglang.srt.utils.hf_transformers_utils import get_processor\n\n        return get_processor(pretrained_model_name_or_path)\n\n    if pretrained_model_name_or_path is not None and not os.path.exists(pretrained_model_name_or_path):\n        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)\n    return AutoProcessor.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)\n\n\nASYNC_REQUEST_FUNCS = {\n    'sglang': async_request_sglang_generate,\n    'sglang-native': async_request_sglang_generate,\n    'sglang-oai': async_request_openai_completions,\n    'sglang-oai-chat': async_request_openai_chat_completions,\n    'vllm': async_request_openai_completions,\n    'vllm-chat': async_request_openai_chat_completions,\n    'lmdeploy': async_request_openai_completions,\n    'lmdeploy-chat': async_request_openai_chat_completions,\n    'trt': async_request_trt_llm,\n    'gserver': async_request_gserver,\n}\n\n\n@dataclass\nclass BenchmarkMetrics:\n    completed: int\n    total_input: int\n    total_input_text: int\n    total_input_vision: int\n    total_output: int\n    total_output_retokenized: int\n    request_throughput: float\n    input_throughput: float\n    output_throughput: float\n    output_throughput_retokenized: float\n    mean_ttft_ms: float\n    median_ttft_ms: float\n    std_ttft_ms: float\n    p99_ttft_ms: float\n    mean_tpot_ms: float\n    median_tpot_ms: float\n    std_tpot_ms: float\n    p99_tpot_ms: float\n    mean_itl_ms: float\n    median_itl_ms: float\n    std_itl_ms: float\n    p99_itl_ms: float\n    
mean_e2e_latency_ms: float\n    median_e2e_latency_ms: float\n\n\nSHAREGPT_URL = 'https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json'  # noqa\n\n\ndef download_and_cache_file(url: str, filename: Optional[str] = None):\n    \"\"\"Read and cache a file from a url.\"\"\"\n    if filename is None:\n        filename = os.path.join('/tmp', url.split('/')[-1])\n\n    # Check if the cache file already exists\n    if os.path.exists(filename):\n        return filename\n\n    print(f'Downloading from {url} to {filename}')\n\n    # Stream the response to show the progress bar\n    response = requests.get(url, stream=True)\n    response.raise_for_status()  # Check for request errors\n\n    # Total size of the file in bytes\n    total_size = int(response.headers.get('content-length', 0))\n    chunk_size = 1024  # Download in chunks of 1KB\n\n    # Use tqdm to display the progress bar\n    with open(filename, 'wb') as f, tqdm(\n            desc=filename,\n            total=total_size,\n            unit='B',\n            unit_scale=True,\n            unit_divisor=1024,\n    ) as bar:\n        for chunk in response.iter_content(chunk_size=chunk_size):\n            f.write(chunk)\n            bar.update(len(chunk))\n\n    return filename\n\n\n@dataclass\nclass DatasetRow:\n    prompt: str\n    prompt_len: int\n    output_len: int\n    text_prompt_len: Optional[int] = None\n    vision_prompt_len: Optional[int] = None\n    image_data: Optional[List[str]] = None\n\n    def __post_init__(self):\n        if self.text_prompt_len is None:\n            self.text_prompt_len = self.prompt_len\n        if self.vision_prompt_len is None:\n            self.vision_prompt_len = 0\n\n\ndef sample_sharegpt_requests(dataset_path: str,\n                             num_requests: int,\n                             tokenizer: PreTrainedTokenizerBase,\n                             fixed_output_len: Optional[int] = None) -> 
List[DatasetRow]:\n    if fixed_output_len is not None and fixed_output_len < 4:\n        raise ValueError('output_len too small')\n\n    # Download sharegpt if necessary\n    if not os.path.isfile(dataset_path):\n        dataset_path = download_and_cache_file(SHAREGPT_URL)\n\n    # Load the dataset.\n    with open(dataset_path) as f:\n        dataset = json.load(f)\n    # Filter out the conversations with less than 2 turns.\n    dataset = [data for data in dataset if len(data['conversations']) >= 2]\n    # Only keep the first two turns of each conversation.\n    dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset]\n\n    # Shuffle the dataset.\n    random.shuffle(dataset)\n\n    # Filter out sequences that are too long or too short\n    filtered_dataset: List[DatasetRow] = []\n    for i in range(len(dataset)):\n        if len(filtered_dataset) == num_requests:\n            break\n\n        # Tokenize the prompts and completions.\n        prompt = dataset[i][0]\n        prompt_token_ids = tokenizer.encode(prompt)\n        completion = dataset[i][1]\n        completion_token_ids = tokenizer.encode(completion)\n        prompt_len = len(prompt_token_ids)\n        output_len = (len(completion_token_ids) if fixed_output_len is None else fixed_output_len)\n        if prompt_len < 4 or output_len < 4:\n            # Prune too short sequences.\n            continue\n        if prompt_len > 1024 or (prompt_len + output_len > 2048 and fixed_output_len is None):\n            # Prune too long sequences.\n            continue\n\n        filtered_dataset.append(DatasetRow(\n            prompt=prompt,\n            prompt_len=prompt_len,\n            output_len=output_len,\n        ))\n\n    print(f'#Input tokens: {sum(x.prompt_len for x in filtered_dataset)}')\n    print(f'#Output tokens: {sum(x.output_len for x in filtered_dataset)}')\n    return filtered_dataset\n\n\ndef compute_random_lens(full_len: int, range_ratio: float, 
num: int):\n    return np.random.randint(\n        max(int(full_len * range_ratio), 1),\n        full_len + 1,\n        size=num,\n    )\n\n\ndef sample_random_requests(\n    input_len: int,\n    output_len: int,\n    num_prompts: int,\n    range_ratio: float,\n    tokenizer: PreTrainedTokenizerBase,\n    dataset_path: str,\n) -> List[DatasetRow]:\n\n    input_lens = compute_random_lens(\n        full_len=input_len,\n        range_ratio=range_ratio,\n        num=num_prompts,\n    )\n    output_lens = compute_random_lens(\n        full_len=output_len,\n        range_ratio=range_ratio,\n        num=num_prompts,\n    )\n\n    # Sample token ids from ShareGPT and repeat/truncate them to\n    # satisfy the input_lens\n\n    # Download sharegpt if necessary\n    if not os.path.isfile(dataset_path):\n        dataset_path = download_and_cache_file(SHAREGPT_URL)\n\n    # Load the dataset.\n    with open(dataset_path) as f:\n        dataset = json.load(f)\n    # Filter out the conversations with less than 2 turns.\n    dataset = [data for data in dataset if len(data['conversations']) >= 2]\n    # Only keep the first two turns of each conversation.\n    dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset]\n    # Remove conversations with an empty prompt.\n    dataset = [(query, answer) for query, answer in dataset if len(query) > 0]\n\n    # Shuffle the dataset.\n    random.shuffle(dataset)\n\n    # Build one request per prompt, resized to its sampled input length.\n    input_requests: List[DatasetRow] = []\n    origin_output_lens: List[int] = []\n    for i in range(num_prompts):\n        # Tokenize the prompts and completions.\n        prompt = dataset[i][0]\n        prompt_token_ids = tokenizer.encode(prompt)\n        prompt_len = len(prompt_token_ids)\n        completion = dataset[i][1]\n        completion_token_ids = tokenizer.encode(completion)\n        origin_output_lens.append(len(completion_token_ids))\n\n        if prompt_len > input_lens[i]:\n        
    input_ids = prompt_token_ids[:input_lens[i]]\n        else:\n            ratio = (input_lens[i] + prompt_len - 1) // prompt_len\n            input_ids = (prompt_token_ids * ratio)[:input_lens[i]]\n        prompt = tokenizer.decode(input_ids)\n        input_requests.append(DatasetRow(\n            prompt=prompt,\n            prompt_len=int(input_lens[i]),\n            output_len=int(output_lens[i]),\n        ))\n\n    print(f'#Input tokens: {sum(x.prompt_len for x in input_requests)}')\n    print(f'#Output tokens: {sum(x.output_len for x in input_requests)}')\n    return input_requests\n\n\ndef parse_image_resolution(image_resolution: str) -> Tuple[int, int]:\n    \"\"\"Parse image resolution into (width, height).\n\n    Supports the presets '4k', '1080p', '720p' and '360p', plus a custom 'heightxwidth' format\n    (e.g., '1080x1920' means height=1080, width=1920). Either form is returned as (width, height).\n    \"\"\"\n    resolution_to_size = {\n        '4k': (3840, 2160),\n        '1080p': (1920, 1080),\n        '720p': (1280, 720),\n        '360p': (640, 360),\n    }\n    res = image_resolution.strip().lower()\n    if res in resolution_to_size:\n        return resolution_to_size[res]\n\n    if 'x' in res:\n        parts = res.split('x')\n        if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit():\n            height = int(parts[0])\n            width = int(parts[1])\n            if height > 0 and width > 0:\n                return (width, height)\n\n    raise ValueError(f'Unsupported image resolution: {image_resolution}. 
'\n                     \"Choose from 4k, 1080p, 720p, 360p, or provide custom 'heightxwidth' (e.g., 1080x1920).\")\n\n\ndef gen_mm_prompt(tokenizer, image_pad_id, token_num):\n    \"\"\"Generate a random prompt of specified token length using tokenizer\n    vocabulary.\"\"\"\n    all_available_tokens = list(tokenizer.get_vocab().values())\n    if image_pad_id:\n        all_available_tokens.remove(image_pad_id)\n    selected_tokens = random.choices(all_available_tokens, k=token_num)\n    return tokenizer.decode(selected_tokens)\n\n\ndef create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor, backend):\n    try:\n        content_items = [{'type': 'image', 'image': {'url': image_base64}} for image_base64 in images_base64]\n        content_items.append({'type': 'text', 'text': text_prompt})\n        prompt_str = processor.apply_chat_template(\n            [{\n                'role': 'user',\n                'content': content_items\n            }],\n            add_generation_prompt=True,\n            tokenize=False,\n        )\n    except Exception as e:\n        # Note (Xinyuan): This is a workaround for an issue where some tokenizers\n        # do not support content as a list. (e.g. 
InternVL)\n        print(f'Error applying chat template: {e}, fallback to <image> tag')\n        # Some tokenizers do not support list content; fall back to a placeholder in the text\n        prompt_str = f'<image>{text_prompt}'\n\n    # Calculate total tokens (text + vision)\n    prompt_len = processor(\n        text=[prompt_str],\n        images=images,\n        padding=False,\n        return_tensors='pt',\n    )['input_ids'].numel()\n\n    # Calculate text-only tokens\n    try:\n        # Create text-only version of the prompt\n        text_only_prompt = processor.apply_chat_template(\n            [{\n                'role': 'user',\n                'content': text_prompt\n            }],\n            add_generation_prompt=True,\n            tokenize=False,\n        )\n        text_prompt_len = processor(\n            text=[text_only_prompt],\n            padding=False,\n            return_tensors='pt',\n        )['input_ids'].numel()\n    except Exception:\n        # Fallback: just tokenize the text prompt directly\n        tokenizer_to_use = (processor.tokenizer if hasattr(processor, 'tokenizer') else processor)\n        text_prompt_len = len(tokenizer_to_use.encode(text_prompt))\n\n    # Vision tokens = total tokens - text tokens\n    vision_prompt_len = prompt_len - text_prompt_len\n\n    use_raw_prompt = backend in [\n        'sglang',\n        'sglang-oai',\n        'sglang-oai-chat',\n        'vllm',\n        'vllm-chat',\n        'lmdeploy',\n        'lmdeploy-chat',\n    ]\n    return DatasetRow(\n        prompt=text_prompt if use_raw_prompt else prompt_str,\n        prompt_len=prompt_len,\n        output_len=output_len,\n        text_prompt_len=text_prompt_len,\n        vision_prompt_len=vision_prompt_len,\n        image_data=images_base64,\n    )\n\n\ndef sample_image_requests(\n    num_requests: int,\n    image_count: int,\n    input_len: int,\n    output_len: int,\n    range_ratio: float,\n    processor: AutoProcessor,\n    image_content: str,\n    
image_format: str,\n    image_resolution: str,\n    backend: str,\n) -> List[DatasetRow]:\n    \"\"\"Generate requests with images.\n\n    - Each request includes ``image_count`` images.\n    - Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360),\n      or custom 'heightxwidth' (e.g., 1080x1920).\n    - Text lengths follow the 'random' dataset sampling rule. ``prompt_len``\n      counts both text and vision tokens; ``text_prompt_len`` and\n      ``vision_prompt_len`` hold the breakdown.\n    \"\"\"\n\n    # Parse resolution (supports presets and 'heightxwidth')\n    width, height = parse_image_resolution(image_resolution)\n\n    # Check for potentially problematic combinations and warn user\n    if width * height >= 1920 * 1080 and image_count * num_requests >= 100:\n        warnings.warn(\n            f'High resolution ({width}x{height}) with {image_count * num_requests} total images '\n            f'may take a long time. Consider reducing resolution or image count.',\n            UserWarning,\n            stacklevel=2,\n        )\n\n    # Sample text lengths\n    input_lens = compute_random_lens(\n        full_len=input_len,\n        range_ratio=range_ratio,\n        num=num_requests,\n    )\n    output_lens = compute_random_lens(\n        full_len=output_len,\n        range_ratio=range_ratio,\n        num=num_requests,\n    )\n\n    def _gen_random_image_data_uri(width: int = width, height: int = height) -> Tuple[Image.Image, str, int]:\n        if image_content == 'blank':\n            # Generate blank white image\n            arr = np.full((height, width, 3), 255, dtype=np.uint8)\n        else:\n            # Generate random colored image\n            arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8)\n        img = Image.fromarray(arr)\n        buf = io.BytesIO()\n        img.save(buf, format=image_format, quality=85)\n        encoded = pybase64.b64encode(buf.getvalue()).decode('utf-8')\n        image_data = f'data:image/{image_format};base64,{encoded}'  # noqa\n      
  image_bytes = len(image_data.encode('utf-8'))\n        return img, image_data, image_bytes\n\n    dataset: List[DatasetRow] = []\n    total_image_bytes = 0\n    for i in range(num_requests):\n        # Generate text prompt\n        text_prompt = gen_mm_prompt(\n            processor.tokenizer,\n            processor.image_token_id if hasattr(processor, 'image_token_id') else None,\n            int(input_lens[i]),\n        )\n\n        # Generate image list\n        images, images_base64, images_bytes = zip(*[_gen_random_image_data_uri() for _ in range(image_count)])\n        total_image_bytes += sum(list(images_bytes))\n\n        data_row = create_mm_data_row(\n            text_prompt,\n            list(images),\n            list(images_base64),\n            int(output_lens[i]),\n            processor,\n            backend,\n        )\n\n        dataset.append(data_row)\n    avg_image_bytes = total_image_bytes // num_requests if num_requests > 0 else 0\n\n    print(f'#Input tokens: {np.sum([x.prompt_len for x in dataset])}')\n    print(f'#Output tokens: {np.sum([x.output_len for x in dataset])}')\n    print(f'\\nCreated {len(dataset)} requests with {image_count} {image_content} {image_format} image(s) each, '\n          f'averaging {avg_image_bytes} bytes of encoded image data per request')\n    return dataset\n\n\nasync def get_request(\n    input_requests: List[DatasetRow],\n    request_rate: float,\n) -> AsyncGenerator[DatasetRow, None]:\n    input_requests = iter(input_requests)\n    for request in input_requests:\n        yield request\n\n        if request_rate == float('inf'):\n            # If the request rate is infinity, then we don't need to wait.\n            continue\n\n        # Sample the request interval from the exponential distribution.\n        interval = np.random.exponential(1.0 / request_rate)\n        # The next request will be sent after the interval.\n        await asyncio.sleep(interval)\n\n\ndef calculate_metrics(\n    input_requests: List[DatasetRow],\n    outputs: 
List[RequestFuncOutput],\n    dur_s: float,\n    tokenizer: PreTrainedTokenizerBase,\n    backend: str,\n) -> Tuple[BenchmarkMetrics, List[int]]:\n    output_lens: List[int] = []\n    retokenized_output_lens: List[int] = []\n    total_input = 0\n    total_input_text = 0\n    total_input_vision = 0\n    completed = 0\n    itls: List[float] = []\n    tpots: List[float] = []\n    ttfts: List[float] = []\n    e2e_latencies: List[float] = []\n\n    for i in range(len(outputs)):\n        if outputs[i].success:\n            output_len = outputs[i].output_len\n            output_lens.append(output_len)\n            retokenized_output_len = len(tokenizer.encode(outputs[i].generated_text, add_special_tokens=False))\n            retokenized_output_lens.append(retokenized_output_len)\n            total_input += input_requests[i].prompt_len\n            total_input_text += input_requests[i].text_prompt_len\n            total_input_vision += input_requests[i].vision_prompt_len\n            if output_len > 1:\n                tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))\n            itls += outputs[i].itl\n            ttfts.append(outputs[i].ttft)\n\n            e2e_latencies.append(outputs[i].latency)\n\n            completed += 1\n        else:\n            output_lens.append(0)\n            retokenized_output_lens.append(0)\n\n    if completed == 0:\n        warnings.warn(\n            'All requests failed. 
This is likely due to a misconfiguration '\n            'on the benchmark arguments.',\n            stacklevel=2,\n        )\n    metrics = BenchmarkMetrics(\n        completed=completed,\n        total_input=total_input,\n        total_input_text=total_input_text,\n        total_input_vision=total_input_vision,\n        total_output=sum(output_lens),\n        total_output_retokenized=sum(retokenized_output_lens),\n        request_throughput=completed / dur_s,\n        input_throughput=total_input / dur_s,\n        output_throughput=sum(output_lens) / dur_s,\n        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,\n        mean_ttft_ms=np.mean(ttfts or 0) * 1000,  # `or 0` keeps the statistics finite when a list is empty\n        median_ttft_ms=np.median(ttfts or 0) * 1000,\n        std_ttft_ms=np.std(ttfts or 0) * 1000,\n        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,\n        mean_tpot_ms=np.mean(tpots or 0) * 1000,\n        median_tpot_ms=np.median(tpots or 0) * 1000,\n        std_tpot_ms=np.std(tpots or 0) * 1000,\n        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,\n        mean_itl_ms=np.mean(itls or 0) * 1000,\n        median_itl_ms=np.median(itls or 0) * 1000,\n        std_itl_ms=np.std(itls or 0) * 1000,\n        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,\n        mean_e2e_latency_ms=np.mean(e2e_latencies or 0) * 1000,\n        median_e2e_latency_ms=np.median(e2e_latencies or 0) * 1000,\n    )\n\n    return metrics, output_lens\n\n\nasync def benchmark(\n    backend: str,\n    api_url: str,\n    model_id: str,\n    tokenizer: PreTrainedTokenizerBase,\n    input_requests: List[DatasetRow],\n    request_rate: float,\n    disable_tqdm: bool,\n    extra_request_body: Dict[str, Any],\n):\n    if backend in ASYNC_REQUEST_FUNCS:\n        request_func = ASYNC_REQUEST_FUNCS[backend]\n    else:\n        raise ValueError(f'Unknown backend: {backend}')\n\n    if not args.disable_warmup:\n        print('Starting initial single prompt test 
run...')\n        start_warmup = time.perf_counter()\n        test_request = input_requests[0]\n        test_input = RequestFuncInput(\n            model=model_id,\n            prompt=test_request.prompt,\n            api_url=api_url,\n            prompt_len=test_request.prompt_len,\n            output_len=test_request.output_len,\n            extra_request_body=extra_request_body,\n            image_data=test_request.image_data,\n        )\n        test_output = await request_func(request_func_input=test_input)\n        if not test_output.success:\n            raise ValueError('Initial test run failed - Please make sure benchmark arguments '\n                             f'are correctly specified. Error: {test_output.error}')\n        else:\n            print('Initial test run completed. Starting main benchmark run...')\n        end_warmup = time.perf_counter()\n        print(f'warmup time: {end_warmup - start_warmup:.2f}s')\n        # Use asyncio.sleep so the pause does not block the event loop.\n        await asyncio.sleep(1.5)\n\n    pbar = None if disable_tqdm else tqdm(total=len(input_requests))\n\n    benchmark_start_time = time.perf_counter()\n    tasks: List[asyncio.Task] = []\n    async for request in get_request(input_requests, request_rate):\n        request_func_input = RequestFuncInput(\n            model=model_id,\n            prompt=request.prompt,\n            api_url=api_url,\n            prompt_len=request.prompt_len,\n            output_len=request.output_len,\n            image_data=request.image_data,\n            extra_request_body=extra_request_body,\n        )\n        tasks.append(asyncio.create_task(request_func(request_func_input=request_func_input, pbar=pbar)))\n    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)\n\n    if pbar is not None:\n        pbar.close()\n\n    benchmark_duration = time.perf_counter() - benchmark_start_time\n\n    metrics, output_lens = calculate_metrics(\n        input_requests=input_requests,\n        outputs=outputs,\n        dur_s=benchmark_duration,\n        
tokenizer=tokenizer,\n        backend=backend,\n    )\n\n    print('\\n{s:{c}^{n}}'.format(s=' Serving Benchmark Result ', n=50, c='='))\n    print('{:<40} {:<10}'.format('Backend:', backend))\n    print('{:<40} {:<10}'.format('Traffic request rate:', request_rate))\n    print('{:<40} {:<10}'.format('Successful requests:', metrics.completed))\n    print('{:<40} {:<10.2f}'.format('Benchmark duration (s):', benchmark_duration))\n    print('{:<40} {:<10}'.format('Total input tokens:', metrics.total_input))\n    print('{:<40} {:<10}'.format('Total input text tokens:', metrics.total_input_text))\n    print('{:<40} {:<10}'.format('Total input vision tokens:', metrics.total_input_vision))\n    print('{:<40} {:<10}'.format('Total generated tokens:', metrics.total_output))\n    print('{:<40} {:<10}'.format('Total generated tokens (retokenized):', metrics.total_output_retokenized))\n    print('{:<40} {:<10.2f}'.format('Request throughput (req/s):', metrics.request_throughput))\n    print('{:<40} {:<10.2f}'.format('Input token throughput (tok/s):', metrics.input_throughput))\n    print('{:<40} {:<10.2f}'.format('Output token throughput (tok/s):', metrics.output_throughput))\n    print('{s:{c}^{n}}'.format(s='End-to-End Latency', n=50, c='-'))\n    print('{:<40} {:<10.2f}'.format('Mean E2E Latency (ms):', metrics.mean_e2e_latency_ms))\n    print('{:<40} {:<10.2f}'.format('Median E2E Latency (ms):', metrics.median_e2e_latency_ms))\n    print('{s:{c}^{n}}'.format(s='Time to First Token', n=50, c='-'))\n    print('{:<40} {:<10.2f}'.format('Mean TTFT (ms):', metrics.mean_ttft_ms))\n    print('{:<40} {:<10.2f}'.format('Median TTFT (ms):', metrics.median_ttft_ms))\n    print('{:<40} {:<10.2f}'.format('P99 TTFT (ms):', metrics.p99_ttft_ms))\n    print('{s:{c}^{n}}'.format(s='Time per Output Token (excl. 
1st token)', n=50, c='-'))\n    print('{:<40} {:<10.2f}'.format('Mean TPOT (ms):', metrics.mean_tpot_ms))\n    print('{:<40} {:<10.2f}'.format('Median TPOT (ms):', metrics.median_tpot_ms))\n    print('{:<40} {:<10.2f}'.format('P99 TPOT (ms):', metrics.p99_tpot_ms))\n    print('{s:{c}^{n}}'.format(s='Inter-token Latency', n=50, c='-'))\n    print('{:<40} {:<10.2f}'.format('Mean ITL (ms):', metrics.mean_itl_ms))\n    print('{:<40} {:<10.2f}'.format('Median ITL (ms):', metrics.median_itl_ms))\n    print('{:<40} {:<10.2f}'.format('P99 ITL (ms):', metrics.p99_itl_ms))\n    print('=' * 50)\n\n    if (metrics.median_ttft_ms is not None and metrics.mean_itl_ms is not None\n            and metrics.output_throughput is not None):\n        FIELD_NAMES = [\n            'backend', 'dataset_name', 'sharegpt_output_len', 'random_input_len', 'random_output_len',\n            'random_range_ratio', 'request_rate', 'completed', 'total_input_tokens', 'total_output_tokens', 'duration',\n            'request_throughput', 'input_throughput', 'output_throughput', 'mean_e2e_latency_ms', 'mean_ttft_ms',\n            'mean_tpot_ms', 'mean_itl_ms'\n        ]\n        result = {\n            'backend': args.backend,\n            'dataset_name': args.dataset_name,\n            'request_rate': request_rate,\n            'total_input_tokens': metrics.total_input,\n            'total_output_tokens': metrics.total_output,\n            'mean_e2e_latency_ms': metrics.mean_e2e_latency_ms,\n            'output_throughput': metrics.output_throughput,\n            'sharegpt_output_len': args.sharegpt_output_len,\n            'random_input_len': args.random_input_len,\n            'random_output_len': args.random_output_len,\n            'random_range_ratio': args.random_range_ratio,\n            'duration': benchmark_duration,\n            'completed': metrics.completed,\n            'request_throughput': metrics.request_throughput,\n            'input_throughput': metrics.input_throughput,\n            
'mean_ttft_ms': metrics.mean_ttft_ms,\n            'mean_tpot_ms': metrics.mean_tpot_ms,\n            'mean_itl_ms': metrics.mean_itl_ms,\n        }\n    else:\n        print(f'Error running benchmark for request rate: {request_rate}')\n        print('-' * 30)\n        result = None\n\n    # Determine output file name\n    if args.output_file:\n        output_file_name = args.output_file\n    else:\n        now = datetime.now().strftime('%m%d')\n        if args.dataset_name == 'random':\n            output_file_name = f'{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.csv'  # noqa\n        else:\n            output_file_name = f'{args.backend}_{now}_{args.num_prompts}_sharegpt.csv'  # noqa\n\n    # Append results to a CSV file (skipped when the metrics above are unavailable)\n    if result is not None:\n        file_exists = os.path.isfile(output_file_name)\n        with open(output_file_name, mode='a', newline='') as f:\n            writer = csv.DictWriter(f, fieldnames=FIELD_NAMES)\n            if not file_exists:\n                writer.writeheader()\n            writer.writerow(result)\n\n    result = {\n        'duration': benchmark_duration,\n        'completed': metrics.completed,\n        'total_input_tokens': metrics.total_input,\n        'total_output_tokens': metrics.total_output,\n        'total_output_tokens_retokenized': metrics.total_output_retokenized,\n        'request_throughput': metrics.request_throughput,\n        'input_throughput': metrics.input_throughput,\n        'output_throughput': metrics.output_throughput,\n        'mean_ttft_ms': metrics.mean_ttft_ms,\n        'median_ttft_ms': metrics.median_ttft_ms,\n        'std_ttft_ms': metrics.std_ttft_ms,\n        'p99_ttft_ms': metrics.p99_ttft_ms,\n        'mean_tpot_ms': metrics.mean_tpot_ms,\n        'median_tpot_ms': metrics.median_tpot_ms,\n        'std_tpot_ms': metrics.std_tpot_ms,\n        'p99_tpot_ms': metrics.p99_tpot_ms,\n        'mean_itl_ms': metrics.mean_itl_ms,\n        'median_itl_ms': metrics.median_itl_ms,\n        'std_itl_ms': metrics.std_itl_ms,\n      
  'p99_itl_ms': metrics.p99_itl_ms,\n        'input_lens': [output.prompt_len for output in outputs],\n        'output_lens': output_lens,\n        'ttfts': [output.ttft for output in outputs],\n        'itls': [output.itl for output in outputs],\n        'generated_texts': [output.generated_text for output in outputs],\n        'errors': [output.error for output in outputs],\n        'mean_e2e_latency_ms': metrics.mean_e2e_latency_ms,\n        'median_e2e_latency_ms': metrics.median_e2e_latency_ms,\n    }\n    return result\n\n\ndef parse_request_rate_range(request_rate_range):\n    if len(request_rate_range.split(',')) == 3:\n        start, stop, step = map(int, request_rate_range.split(','))\n        return list(range(start, stop, step))\n    else:\n        return list(map(int, request_rate_range.split(',')))\n\n\ndef check_chat_template(model_path):\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n        return 'chat_template' in tokenizer.init_kwargs\n    except Exception as e:\n        print(f'Failed to load tokenizer config with error={e}')\n        return False\n\n\ndef run_benchmark(args_: argparse.Namespace):\n    global args\n    args = args_\n\n    # Set global environments\n    set_ulimit()\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    extra_request_body = {}\n    if args.extra_request_body:\n        extra_request_body = json.loads(args.extra_request_body)\n\n    # Set url\n    if args.port is None:\n        args.port = {\n            'sglang': 30000,\n            'sglang-native': 30000,\n            'sglang-oai': 30000,\n            'sglang-oai-chat': 30000,\n            'lmdeploy': 23333,\n            'lmdeploy-chat': 23333,\n            'vllm': 8000,\n            'vllm-chat': 8000,\n            'trt': 8000,\n            'gserver': 9988,\n        }.get(args.backend, 30000)\n\n    model_url = (f'{args.base_url}/v1/models' if args.base_url else 
f'http://{args.host}:{args.port}/v1/models')\n\n    if args.backend in ['sglang', 'sglang-native']:\n        api_url = (f'{args.base_url}/generate' if args.base_url else f'http://{args.host}:{args.port}/generate')\n    elif args.backend in ['sglang-oai', 'vllm', 'lmdeploy']:\n        api_url = (f'{args.base_url}/v1/completions'\n                   if args.base_url else f'http://{args.host}:{args.port}/v1/completions')\n    elif args.backend in ['lmdeploy-chat', 'vllm-chat', 'sglang-oai-chat']:\n        api_url = (f'{args.base_url}/v1/chat/completions'\n                   if args.base_url else f'http://{args.host}:{args.port}/v1/chat/completions')\n    elif args.backend == 'trt':\n        api_url = (\n            f'{args.base_url}/v2/models/ensemble/generate_stream'\n            if args.base_url else f'http://{args.host}:{args.port}/v2/models/ensemble/generate_stream'  # noqa\n        )\n        if args.model is None:\n            print('Please provide a model using `--model` when using '\n                  'the `trt` backend.')\n            sys.exit(1)\n    elif args.backend == 'gserver':\n        api_url = args.base_url if args.base_url else \\\n            f'{args.host}:{args.port}'\n        args.model = args.model or 'default'\n\n    # Get model name\n    if args.model is None:\n        try:\n            response = requests.get(model_url)\n            model_list = response.json().get('data', [])\n            args.model = model_list[0]['id'] if model_list else None\n        except Exception as e:\n            print(f'Failed to fetch model from {model_url}. 
Error: {e}')\n            print('Please specify the correct host and port using '\n                  '`--host` and `--port`.')\n            sys.exit(1)\n\n    # Read dataset\n    backend = args.backend\n    model_id = args.model\n    model_path = args.model_path if args.model_path is not None else args.model\n    tokenizer_id = args.tokenizer if args.tokenizer is not None else model_path\n\n    if args.model is None:\n        print('No model specified or found. Please provide a model '\n              'using `--model`.')\n        sys.exit(1)\n\n    if not check_chat_template(model_path):\n        print('\\nWARNING: It is recommended to use a `Chat` or `Instruct` '\n              'model for benchmarking, because if the model generates '\n              'gibberish, the tokenizer may count the output tokens incorrectly.\\n')\n\n    print(f'{args}\\n')\n\n    tokenizer = get_tokenizer(tokenizer_id)\n\n    if args.dataset_name == 'sharegpt':\n        assert args.random_input_len is None and args.random_output_len is None\n        input_requests = sample_sharegpt_requests(\n            dataset_path=args.dataset_path,\n            num_requests=args.num_prompts,\n            tokenizer=tokenizer,\n            fixed_output_len=args.sharegpt_output_len,\n        )\n    elif args.dataset_name == 'random':\n        assert args.random_input_len is not None and \\\n            args.random_output_len is not None\n        input_requests = sample_random_requests(\n            input_len=args.random_input_len,\n            output_len=args.random_output_len,\n            num_prompts=args.num_prompts,\n            range_ratio=args.random_range_ratio,\n            tokenizer=tokenizer,\n            dataset_path=args.dataset_path,\n        )\n    elif args.dataset_name == 'image':\n        processor = get_processor(model_path)\n        input_requests = sample_image_requests(\n            num_requests=args.num_prompts,\n            image_count=args.image_count,\n   
         input_len=args.random_input_len,\n            output_len=args.random_output_len,\n            range_ratio=args.random_range_ratio,\n            processor=processor,\n            image_content=args.image_content,\n            image_format=args.image_format,\n            image_resolution=args.image_resolution,\n            backend=args.backend,\n        )\n    else:\n        raise ValueError(f'Unknown dataset: {args.dataset_name}')\n\n    if not args.multi:\n        return asyncio.run(\n            benchmark(\n                backend=backend,\n                api_url=api_url,\n                model_id=model_id,\n                tokenizer=tokenizer,\n                input_requests=input_requests,\n                request_rate=args.request_rate,\n                disable_tqdm=args.disable_tqdm,\n                extra_request_body=extra_request_body,\n            ))\n    else:\n        # Benchmark multiple rps.\n        # TODO: use a fixed duration to compute num_prompts\n        request_rates = parse_request_rate_range(args.request_rate_range)\n\n        for rate in request_rates:\n            asyncio.run(\n                benchmark(\n                    backend=backend,\n                    api_url=api_url,\n                    model_id=model_id,\n                    tokenizer=tokenizer,\n                    input_requests=input_requests,\n                    request_rate=rate,\n                    disable_tqdm=args.disable_tqdm,\n                    extra_request_body=extra_request_body,\n                ))\n\n\ndef set_ulimit(target_soft_limit=65535):\n    resource_type = resource.RLIMIT_NOFILE\n    current_soft, current_hard = resource.getrlimit(resource_type)\n\n    if current_soft < target_soft_limit:\n        try:\n            resource.setrlimit(resource_type, (target_soft_limit, current_hard))\n        except ValueError as e:\n            print(f'Failed to set RLIMIT_NOFILE: {e}')\n\n\nif __name__ == '__main__':\n    parser = 
ArgumentParser(description='Benchmark the online serving throughput.')\n    parser.add_argument(\n        '--backend',\n        type=str,\n        choices=list(ASYNC_REQUEST_FUNCS.keys()),\n        default='sglang',\n        help='Backend to benchmark, depending on the LLM inference engine being served.',\n    )\n    parser.add_argument(\n        '--base-url',\n        type=str,\n        default=None,\n        help='Server or API base URL, used instead of --host and --port when set.',\n    )\n    parser.add_argument('--host', type=str, default='0.0.0.0', help='Default host is 0.0.0.0.')\n    parser.add_argument(\n        '--port',\n        type=int,\n        help='If not set, the default port of the selected LLM inference engine '\n        'is used.',\n    )\n    parser.add_argument(\n        '--dataset-name',\n        type=str,\n        default='sharegpt',\n        choices=['sharegpt', 'random', 'image'],\n        help='Name of the dataset to benchmark on.',\n    )\n    parser.add_argument('--dataset-path', type=str, default='', help='Path to the dataset.')\n    parser.add_argument(\n        '--model',\n        type=str,\n        help='Name or path of the model. If not set, the first model listed '\n        'by the /v1/models endpoint of the server is used.',\n    )\n    parser.add_argument(\n        '--model-path',\n        type=str,\n        help='Path to the model. If not set, the value of --model is used.',\n    )\n    parser.add_argument(\n        '--tokenizer',\n        type=str,\n        help='Name or path of the tokenizer. If not set, the model path '\n        'is used.',\n    )\n    parser.add_argument(\n        '--num-prompts',\n        type=int,\n        default=1000,\n        help='Number of prompts to process. Default is 1000.',\n    )\n    parser.add_argument(\n        '--sharegpt-output-len',\n        type=int,\n        default=None,\n        help='Output length for each request. 
Overrides the output length '\n        'from the ShareGPT dataset.',\n    )\n    parser.add_argument(\n        '--random-input-len',\n        type=int,\n        help='Number of input tokens per request, used only for the random '\n        'dataset.',\n    )\n    parser.add_argument(\n        '--random-output-len',\n        type=int,\n        help='Number of output tokens per request, used only for the random '\n        'dataset.',\n    )\n    parser.add_argument(\n        '--random-range-ratio',\n        type=float,\n        default=0.0,\n        help='Lower bound of the sampled input/output length, as a fraction of '\n        'the configured length; used only for the random dataset.',\n    )\n    # image dataset args\n    parser.add_argument(\n        '--image-count',\n        type=int,\n        default=1,\n        help='Number of images per request (only available with the image dataset)',\n    )\n    parser.add_argument(\n        '--image-resolution',\n        type=str,\n        default='1080p',\n        help=('Resolution of images for the image dataset. '\n              \"Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920).\"),\n    )\n    parser.add_argument(\n        '--image-format',\n        type=str,\n        default='jpeg',\n        help=('Format of images for the image dataset. '\n              'Supports jpeg and png.'),\n    )\n    parser.add_argument(\n        '--image-content',\n        type=str,\n        default='random',\n        help=('Content of images for the image dataset. '\n              'Supports random and blank.'),\n    )\n    parser.add_argument(\n        '--request-rate',\n        type=float,\n        default=float('inf'),\n        help='Number of requests per second. If this is inf, then all the '\n        'requests are sent at time 0. Otherwise, a Poisson process is used to '\n        'synthesize the request arrival times. 
Default is inf.',\n    )\n    parser.add_argument('--seed', type=int, default=1, help='The random seed.')\n    parser.add_argument(\n        '--multi',\n        action='store_true',\n        help='Benchmark a range of request rates (see --request-rate-range) instead of a single value.',\n    )\n    parser.add_argument(\n        '--request-rate-range',\n        type=str,\n        default='2,34,2',\n        help='Range of request rates in the format start,stop,step. Default '\n        'is 2,34,2. A comma-separated list of request rates is also accepted, '\n        'as long as it does not contain exactly three values.',\n    )\n    parser.add_argument('--output-file', type=str, help='Output CSV file name.')\n    parser.add_argument(\n        '--disable-tqdm',\n        action='store_true',\n        help='Disable the tqdm progress bar.',\n    )\n    parser.add_argument(\n        '--disable-stream',\n        action='store_true',\n        help='Disable streaming mode.',\n    )\n    parser.add_argument(\n        '--disable-ignore-eos',\n        action='store_true',\n        help='Disable ignoring EOS.',\n    )\n    parser.add_argument(\n        '--extra-request-body',\n        metavar='{\"key1\": \"value1\", \"key2\": \"value2\"}',\n        type=str,\n        help='Append the given JSON object to the request payload. Use this to '\n        'specify additional generation parameters such as sampling parameters.',\n    )\n    parser.add_argument(\n        '--disable-warmup',\n        action='store_true',\n        default=None,\n        help='Disable the warmup request before the benchmark.',\n    )\n    args = parser.parse_args()\n    run_benchmark(args)\n"
  },
  {
    "path": "benchmark/profile_throughput.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport argparse\nimport asyncio\nimport json\nimport os\nimport random\nfrom queue import Queue\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nfrom tqdm import tqdm\nfrom transformers import PreTrainedTokenizerBase\n\nfrom lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter\nfrom lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig\nfrom lmdeploy.profiler import Profiler, Session\nfrom lmdeploy.tokenizer import DetokenizeState, Tokenizer\nfrom lmdeploy.utils import get_logger\n\nget_logger('lmdeploy').setLevel('ERROR')\nos.environ['TM_LOG_LEVEL'] = 'ERROR'\n\n\ndef sample_sharegpt_requests(\n    dataset_path: str,\n    num_requests: int,\n    tokenizer: PreTrainedTokenizerBase,\n    fixed_output_len: Optional[int] = None,\n) -> List[Tuple[str, int, int]]:\n    if fixed_output_len is not None and fixed_output_len < 4:\n        raise ValueError('output_len too small')\n    # Load the dataset.\n    with open(dataset_path) as f:\n        dataset = json.load(f)\n    # Filter out the conversations with less than 2 turns.\n    dataset = [data for data in dataset if len(data['conversations']) >= 2]\n    # Only keep the first two turns of each conversation.\n    dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset]\n\n    # Shuffle the dataset.\n    random.shuffle(dataset)\n\n    # Filter out sequences that are too long or too short\n    filtered_dataset: List[Tuple[str, int, int]] = []\n    for i in range(len(dataset)):\n        if len(filtered_dataset) == num_requests:\n            break\n\n        # Tokenize the prompts and completions.\n        prompt = dataset[i][0]\n        prompt_token_ids = tokenizer.encode(prompt)\n        completion = dataset[i][1]\n        completion_token_ids = tokenizer.encode(completion)\n        prompt_len = len(prompt_token_ids)\n        output_len = 
(len(completion_token_ids) if fixed_output_len is None else fixed_output_len)\n        if prompt_len < 4 or output_len < 4:\n            # Prune too short sequences.\n            continue\n        if prompt_len > 1024 or (prompt_len + output_len > 2048 and fixed_output_len is None):\n            # Prune too long sequences.\n            continue\n        filtered_dataset.append((prompt, prompt_len, output_len))\n\n    print(f'#Input tokens: {np.sum([x[1] for x in filtered_dataset])}')\n    print(f'#Output tokens: {np.sum([x[2] for x in filtered_dataset])}')\n    return filtered_dataset\n\n\ndef sample_random_requests(\n    input_len: int,\n    output_len: int,\n    num_prompts: int,\n    range_ratio: float,\n    tokenizer: PreTrainedTokenizerBase,\n    dataset_path: str,\n) -> List[Tuple[str, int, int]]:\n\n    input_lens = np.random.randint(\n        max(int(input_len * range_ratio), 1),\n        input_len + 1,\n        size=num_prompts,\n    )\n    output_lens = np.random.randint(\n        int(output_len * range_ratio),\n        output_len + 1,\n        size=num_prompts,\n    )\n\n    if True:\n        # Sample token ids from ShareGPT and repeat/truncate them to\n        # satisfy the input_lens\n\n        # Load the dataset.\n        with open(dataset_path) as f:\n            dataset = json.load(f)\n        # Filter out the conversations with less than 2 turns.\n        dataset = [data for data in dataset if len(data['conversations']) >= 2]\n        # Only keep the first two turns of each conversation.\n        dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset]\n        # Remove empty prompts.\n        dataset = [(query, answer) for query, answer in dataset if len(query) > 0]\n\n        # Shuffle the dataset.\n        random.shuffle(dataset)\n\n        # Truncate or repeat each prompt to match its sampled input length\n        input_requests: List[Tuple[str, int, int]] = []\n        for i in range(num_prompts):\n           
 # Tokenize the prompts and completions.\n            prompt = dataset[i][0]\n            prompt_token_ids = tokenizer.encode(prompt)\n            prompt_len = len(prompt_token_ids)\n\n            if prompt_len > input_lens[i]:\n                input_ids = prompt_token_ids[:input_lens[i]]\n            else:\n                ratio = (input_lens[i] + prompt_len - 1) // prompt_len\n                input_ids = (prompt_token_ids * ratio)[:input_lens[i]]\n            prompt = tokenizer.decode(input_ids)\n            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))\n    else:\n        # Sample token ids from random integers.\n        # This can cause some NaN issues.\n        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)\n        input_requests = []\n        for i in range(num_prompts):\n            prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])])\n            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))\n\n    print(f'#Input tokens: {np.sum(input_lens)}')\n    print(f'#Output tokens: {np.sum(output_lens)}')\n    return input_requests\n\n\nclass Engine:\n\n    def __init__(self, model_path: str, engine_config: Union[PytorchEngineConfig, TurbomindEngineConfig]):\n        self.tokenizer = Tokenizer(model_path)\n        if isinstance(engine_config, TurbomindEngineConfig):\n            from lmdeploy.turbomind import TurboMind\n            tm_model = TurboMind.from_pretrained(model_path, engine_config=engine_config)\n            self.backend = 'turbomind'\n        elif isinstance(engine_config, PytorchEngineConfig):\n            from lmdeploy.pytorch.engine import Engine as PytorchEngine\n            tm_model = PytorchEngine.from_pretrained(model_path, engine_config=engine_config)\n            self.backend = 'pytorch'\n\n        self.tm_model = tm_model\n        self.pbar = None\n\n    async def _inference(self, req_queue: Queue, session_id: int, 
temperature: float, top_p: float, top_k: int,\n                         stream_output: bool, skip_tokenize: bool, skip_detokenize: bool, concurrency: int):\n        model_inst = self.tm_model.create_instance()\n        sess: Session = None\n        for prompt, _, output_seqlen, cancel_after, sess in iter(req_queue.get_nowait, None):\n\n            sess.tick(0)\n\n            if skip_tokenize:\n                input_ids = prompt\n            else:\n                input_ids = self.tokenizer(prompt).input_ids\n\n            state = DetokenizeState(len(input_ids))\n\n            n_token = 0\n            token_ids = input_ids.copy()\n\n            generator = model_inst.async_stream_infer(session_id,\n                                                      input_ids=input_ids,\n                                                      gen_config=GenerationConfig(max_new_tokens=output_seqlen,\n                                                                                  temperature=temperature,\n                                                                                  top_p=top_p,\n                                                                                  top_k=top_k,\n                                                                                  ignore_eos=True),\n                                                      sequence_start=True,\n                                                      sequence_end=True,\n                                                      stream_output=stream_output)\n            try:\n                async for outputs in generator:\n                    n_token += len(outputs.token_ids)\n                    token_ids += outputs.token_ids\n                    if not skip_detokenize:\n                        _, state = self.tokenizer.detokenize_incrementally(token_ids, state)\n                    sess.tick(n_token)\n                    if n_token > cancel_after:\n                        break\n                
sess.finish(Session.SUCCESS)\n            finally:\n                await generator.aclose()\n\n            # for pytorch engine to restart a session\n            if self.backend == 'pytorch':\n                await model_inst.async_end(session_id)\n\n            self.pbar.update(1)\n            session_id += concurrency\n\n    def process_request(self, requests, profiler: Profiler, concurrency, temperature, top_p, top_k, stream_output,\n                        skip_tokenize, skip_detokenize, cancel_rate):\n        req_queue = Queue()\n\n        # Enqueue all requests, followed by one None sentinel per worker\n        for prompt, input_len, output_len in requests:\n            cancel_after = output_len + 1\n            if cancel_rate > 0:\n                if random.random() < cancel_rate:\n                    cancel_after = random.randint(0, cancel_after)\n            sess = profiler.new_session(input_len, output_len)\n            req = [prompt, input_len, output_len, cancel_after, sess]\n            if skip_tokenize:\n                req[0] = self.tokenizer.encode(prompt)\n            req_queue.put(req)\n        for i in range(concurrency):\n            req_queue.put(None)\n\n        # Create one inference coroutine per concurrency slot\n        tasks = []\n        for i in range(concurrency):\n            task = self._inference(req_queue, i, temperature, top_p, top_k, stream_output, skip_tokenize,\n                                   skip_detokenize, concurrency)\n            tasks.append(task)\n\n        async def _gather_tasks(tasks):\n            profiler.start()\n            ret = await asyncio.gather(*tasks)\n            profiler.finish()\n            return ret\n\n        self.pbar = tqdm(total=len(requests))\n\n        asyncio.run(_gather_tasks(tasks))\n\n        self.pbar.close()\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Benchmark the request throughput of lmdeploy '\n                                     'on localhost',\n                                     formatter_class=DefaultsAndTypesHelpFormatter)\n  
  parser.add_argument('dataset', type=str, help='the path of the dataset')\n    parser.add_argument('model_path',\n                        type=str,\n                        help='the path of the model in localhost or '\n                        'the repo_id of the model in huggingface.co')\n    parser.add_argument('-c',\n                        '--concurrency',\n                        type=int,\n                        help='Number of concurrent workers (coroutines) that process the sampled prompts',\n                        default=256)\n    parser.add_argument('-n', '--num-prompts', type=int, help='Number of prompts to process', default=5000)\n    parser.add_argument('--no-stream-output', action='store_true', help='Disable stream output')\n    parser.add_argument('--skip-tokenize', action='store_true', help='Pre-tokenize input prompts before starting')\n    parser.add_argument('--skip-detokenize', action='store_true', help='Skip detokenizing output tokens')\n    parser.add_argument('--cancel-rate', type=float, help='Probability of a request being canceled', default=0)\n    parser.add_argument('--use-uvloop', action='store_true', help='Use uvloop as the asyncio event loop')\n    parser.add_argument('--csv', type=str, help='Where to save the result.', default='./profile_throughput.csv')\n    parser.add_argument('--seed', type=int, default=0, help='Seed used when sampling prompts from the dataset')\n    parser.add_argument('--distributed-executor-backend',\n                        type=str,\n                        default=None,\n                        choices=['uni', 'mp', 'ray'],\n                        help='Distributed executor backend (PyTorch engine only)')\n    parser.add_argument('--dataset-name',\n                        type=str,\n                        default='sharegpt',\n                        choices=['sharegpt', 'random'],\n                        help='Name of the dataset to benchmark on.')\n    parser.add_argument(\n        '--sharegpt-output-len',\n        type=int,\n        default=None,\n        help='Output length for each request. 
Overrides the output length '\n        'from the ShareGPT dataset.',\n    )\n    parser.add_argument(\n        '--random-input-len',\n        type=int,\n        help='Number of input tokens per request, used only for the random '\n        'dataset.',\n    )\n    parser.add_argument(\n        '--random-output-len',\n        type=int,\n        help='Number of output tokens per request, used only for the random '\n        'dataset.',\n    )\n    parser.add_argument(\n        '--random-range-ratio',\n        type=float,\n        default=0.0,\n        help='Lower bound of the sampled input/output length, as a fraction of '\n        'the configured length; used only for the random dataset.',\n    )\n    # other args\n    ArgumentHelper.top_p(parser)\n    ArgumentHelper.temperature(parser)\n    ArgumentHelper.top_k(parser)\n    ArgumentHelper.backend(parser)\n\n    # pytorch engine args\n    pt_group = parser.add_argument_group('PyTorch engine arguments')\n    ArgumentHelper.eager_mode(pt_group)\n    ArgumentHelper.dllm_block_length(pt_group)\n    ArgumentHelper.dllm_unmasking_strategy(pt_group)\n    ArgumentHelper.dllm_denoising_steps(pt_group)\n    ArgumentHelper.dllm_confidence_threshold(pt_group)\n\n    tp_act = ArgumentHelper.tp(pt_group)\n    cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)\n    cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)\n    prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)\n    quant_policy_act = ArgumentHelper.quant_policy(pt_group, default=0)\n    dtype_act = ArgumentHelper.dtype(pt_group)\n\n    # turbomind engine args\n    tb_group = parser.add_argument_group('TurboMind engine arguments')\n    tb_group._group_actions.append(tp_act)\n    tb_group._group_actions.append(cache_count_act)\n    tb_group._group_actions.append(cache_block_seq_len_act)\n    tb_group._group_actions.append(prefix_caching_act)\n    tb_group._group_actions.append(quant_policy_act)\n    tb_group._group_actions.append(dtype_act)\n\n    ArgumentHelper.dp(tb_group)\n   
 ArgumentHelper.cp(tb_group)\n    ArgumentHelper.model_format(tb_group, default='hf')\n    ArgumentHelper.num_tokens_per_iter(tb_group)\n    ArgumentHelper.max_prefill_iters(tb_group)\n    ArgumentHelper.async_(tb_group)\n    ArgumentHelper.communicator(tb_group)\n\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    args = parse_args()\n    random.seed(args.seed)\n    if args.backend == 'turbomind':\n        engine_config = TurbomindEngineConfig(\n            max_batch_size=args.concurrency // args.dp,\n            tp=args.tp,\n            dp=args.dp,\n            cp=args.cp,\n            cache_max_entry_count=args.cache_max_entry_count,\n            cache_block_seq_len=args.cache_block_seq_len,\n            model_format=args.model_format,\n            quant_policy=args.quant_policy,\n            num_tokens_per_iter=args.num_tokens_per_iter,\n            max_prefill_iters=args.max_prefill_iters,\n            async_=args.async_,\n            enable_prefix_caching=args.enable_prefix_caching,\n            dtype=args.dtype,\n            communicator=args.communicator,\n        )\n    elif args.backend == 'pytorch':\n        engine_config = PytorchEngineConfig(\n            cache_max_entry_count=args.cache_max_entry_count,\n            block_size=args.cache_block_seq_len,\n            max_batch_size=args.concurrency,\n            tp=args.tp,\n            eager_mode=args.eager_mode,\n            enable_prefix_caching=args.enable_prefix_caching,\n            quant_policy=args.quant_policy,\n            dtype=args.dtype,\n            distributed_executor_backend=args.distributed_executor_backend,\n            dllm_block_length=args.dllm_block_length,\n            dllm_unmasking_strategy=args.dllm_unmasking_strategy,\n            dllm_denoising_steps=args.dllm_denoising_steps,\n            dllm_confidence_threshold=args.dllm_confidence_threshold,\n        )\n\n    if args.use_uvloop:\n        import uvloop\n        
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())\n\n    engine = Engine(args.model_path, engine_config)\n\n    if args.dataset_name == 'sharegpt':\n        assert args.random_input_len is None and args.random_output_len is None\n        requests = sample_sharegpt_requests(\n            dataset_path=args.dataset,\n            num_requests=args.num_prompts,\n            tokenizer=engine.tokenizer.model.model,\n            fixed_output_len=args.sharegpt_output_len,\n        )\n    elif args.dataset_name == 'random':\n        assert args.random_input_len is not None and \\\n            args.random_output_len is not None\n        requests = sample_random_requests(\n            input_len=args.random_input_len,\n            output_len=args.random_output_len,\n            num_prompts=args.num_prompts,\n            range_ratio=args.random_range_ratio,\n            tokenizer=engine.tokenizer.model.model,\n            dataset_path=args.dataset,\n        )\n    else:\n        raise ValueError(f'Unknown dataset: {args.dataset_name}')\n\n    stream_output = not args.no_stream_output\n\n    profiler = Profiler(stream_output, [50, 75, 95, 99])\n\n    engine.process_request(requests,\n                           profiler,\n                           temperature=args.temperature,\n                           top_p=args.top_p,\n                           top_k=args.top_k,\n                           concurrency=min(args.concurrency, args.num_prompts),\n                           stream_output=stream_output,\n                           skip_tokenize=args.skip_tokenize,\n                           skip_detokenize=args.skip_detokenize,\n                           cancel_rate=args.cancel_rate)\n\n    hyperparams = [('Concurrency', args.concurrency), ('Cancel rate', args.cancel_rate),\n                   ('Stream output', str(stream_output).lower()), ('Skip tokenize', str(args.skip_tokenize).lower()),\n                   ('Skip 
detokenize', str(args.skip_detokenize).lower())]\n    profiler.compute_metrics()\n    profiler.summarize(title='Profile Throughput', hyperparams=hyperparams)\n    if args.csv:\n        profiler.save_csv(args.csv, (\n            ('backend', args.backend),\n            ('bs', args.concurrency),\n            ('dataset_name', args.dataset_name),\n            ('sharegpt_output_len', args.sharegpt_output_len),\n            ('random_input_len', args.random_input_len),\n            ('random_output_len', args.random_output_len),\n            ('random_range_ratio', args.random_range_ratio),\n            ('num_prompts', args.num_prompts),\n        ))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "builder/manywheel/Dockerfile_2014",
    "content": "# WARNING: CentOS 7 reached end of life on 6/30/2024; switch to the following base image in the future\n# FROM quay.io/pypa/manylinux_2_28_x86_64 as base\nFROM quay.io/pypa/manylinux2014_x86_64 as base\nARG BASE_CUDA_VERSION=11.8\n\nENV LC_ALL=en_US.UTF-8\nENV LANG=en_US.UTF-8\nENV LANGUAGE=en_US.UTF-8\n\n# The CentOS 7 mirrors are offline after EOL, so point yum at vault.centos.org\nRUN sed -i 's|^mirrorlist=|#mirrorlist=|g' /etc/yum.repos.d/CentOS-*.repo && \\\n    sed -i 's|^#baseurl=http://mirror.centos.org|baseurl=https://vault.centos.org|g' /etc/yum.repos.d/CentOS-*.repo && \\\n    yum install -y \\\n    wget \\\n    rapidjson-devel \\\n    glog-devel && \\\n    yum clean all\n\nENV LD_LIBRARY_PATH=/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:$LD_LIBRARY_PATH\n\nFROM base as cuda\nCOPY manywheel/scripts/install_cuda.sh /tmp/install_cuda.sh\nRUN bash /tmp/install_cuda.sh ${BASE_CUDA_VERSION} && rm /tmp/install_cuda.sh\n\nFROM base as conda\nCOPY manywheel/scripts/install_conda.sh /tmp/install_conda.sh\nRUN bash /tmp/install_conda.sh && rm /tmp/install_conda.sh\n\n# Accept Anaconda's Terms of Service to avoid `CondaToSNonInteractiveError`\nRUN /opt/conda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \\\n    /opt/conda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r\n\nRUN PY_VERSIONS=(3.10 3.11 3.12 3.13) && \\\n    for pyver in \"${PY_VERSIONS[@]}\"; do \\\n        /opt/conda/bin/conda create -n py${pyver//./} python=${pyver} -yq && \\\n        /opt/conda/envs/py${pyver//./}/bin/pip install -i 'https://mirrors.aliyun.com/pypi/simple/' --no-cache-dir pybind11; \\\n    done && \\\n    /opt/conda/bin/conda clean -ya\n\nFROM base as cuda_final\nCOPY --from=cuda /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda-${BASE_CUDA_VERSION}\nRUN ln -sf /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda\nENV PATH=/usr/local/cuda/bin:$PATH\nCOPY --from=conda /opt/conda /opt/conda\nRUN /opt/conda/bin/conda init bash\n"
  },
  {
    "path": "builder/manywheel/README.md",
    "content": "# LMDeploy Build System\n\n## Building lmdeploy builder images\n\nTo build all lmdeploy builder images (\"openmmlab/lmdeploy-builder:cuda12.4\", \"openmmlab/lmdeploy-builder:cuda12.6\" and \"openmmlab/lmdeploy-builder:cuda12.8\"), run:\n\n```bash\n./build_all_lmdeploy_builders.sh\n\n# Build and push images (for CI/CD)\nWITH_PUSH=true ./build_all_lmdeploy_builders.sh\n```\n\nTo build a single image for a specific manylinux and CUDA version:\n\n```bash\nMANY_LINUX_VERSION=2014 GPU_ARCH_VERSION=12.4 ./build_lmdeploy_builder.sh\n```\n\n## Building lmdeploy wheels\n\nTo compile wheels for Python 3.10 to 3.13 (CUDA 12.8 by default; set `CUDA_VER` to override), run:\n\n```bash\n./build_all_wheel.sh\n```\n"
  },
  {
    "path": "builder/manywheel/build_all_lmdeploy_builders.sh",
    "content": "#!/usr/bin/env bash\n\nset -eou pipefail\n\nTOPDIR=$(git rev-parse --show-toplevel)/builder\n\nfor cuda_version in 12.4 12.6 12.8; do\n    MANY_LINUX_VERSION=2014 GPU_ARCH_VERSION=\"${cuda_version}\" \"${TOPDIR}/manywheel/build_lmdeploy_builder.sh\"\ndone\n"
  },
  {
    "path": "builder/manywheel/build_all_wheel.sh",
    "content": "#!/usr/bin/env bash\n\nset -eou pipefail\n\nTOPDIR=$(git rev-parse --show-toplevel)/builder\n\nCUDA_VER=${CUDA_VER:-12.8}\n\nPLAT_NAME=manylinux2014_x86_64\nfor cuver in ${CUDA_VER}; do\n    DOCKER_TAG=cuda${cuver}\n    OUTPUT_FOLDER=cuda${cuver}_dist\n    for pyver in py310 py311 py312 py313; do\n        bash ${TOPDIR}/manywheel/build_wheel.sh ${pyver} ${PLAT_NAME} ${DOCKER_TAG} ${OUTPUT_FOLDER} \\\n            |& tee ${PLAT_NAME}.${pyver}.cuda${cuver}.log.txt\n    done\ndone\n"
  },
  {
    "path": "builder/manywheel/build_lmdeploy_builder.sh",
    "content": "#!/usr/bin/env bash\n\nset -eou pipefail\n\nTOPDIR=$(git rev-parse --show-toplevel)/builder\nGPU_ARCH_VERSION=${GPU_ARCH_VERSION:?GPU_ARCH_VERSION must be set, e.g. 12.8}\nWITH_PUSH=${WITH_PUSH:-}\n\nTARGET=cuda_final\nDOCKER_TAG=cuda${GPU_ARCH_VERSION}\n\nDOCKER_IMAGE=openmmlab/lmdeploy-builder:${DOCKER_TAG}\n# MANY_LINUX_VERSION is optional; default to empty so `set -u` does not abort when it is unset\nDOCKERFILE_SUFFIX=$([[ -n ${MANY_LINUX_VERSION:-} ]] && echo \"_${MANY_LINUX_VERSION}\" || echo \"\")\n\n# List of all build arguments (format: KEY=VALUE)\n# Empty values will be automatically filtered out later\nBUILD_ARGS=(\n    \"BASE_CUDA_VERSION=${GPU_ARCH_VERSION}\"\n    \"DEVTOOLSET_VERSION=9\"\n    \"HTTPS_PROXY=${HTTPS_PROXY:-}\"\n    \"HTTP_PROXY=${HTTP_PROXY:-}\"\n    # Add more parameters here if needed\n)\n\n# Base Docker build command arguments\ndocker_build_args=(\n    -t \"${DOCKER_IMAGE}\"\n    --target \"${TARGET}\"\n    -f \"${TOPDIR}/manywheel/Dockerfile${DOCKERFILE_SUFFIX}\"\n)\n\n# Process build arguments: filter empty values and format as --build-arg\nfor arg in \"${BUILD_ARGS[@]}\"; do\n    IFS='=' read -r key value <<< \"$arg\"  # Split KEY=VALUE\n    if [[ -n \"$value\" ]]; then  # Only add non-empty values\n        docker_build_args+=(--build-arg \"$arg\")\n    fi\ndone\n\n(\n    set -x\n    DOCKER_BUILDKIT=1 docker build \"${docker_build_args[@]}\" \"${TOPDIR}\"\n)\n\nif [[ \"${WITH_PUSH}\" == true ]]; then\n    (\n        set -x\n        docker push \"${DOCKER_IMAGE}\"\n    )\nfi\n"
  },
  {
    "path": "builder/manywheel/build_wheel.sh",
    "content": "#!/usr/bin/env bash\nset -eux\n\nPYTHON_VERSION=\"$1\"\nPLAT_NAME=\"$2\"\nDOCKER_TAG=\"$3\"\nOUTPUT_DIR=\"$4\"\n\nDOCKER_IMAGE=\"openmmlab/lmdeploy-builder:${DOCKER_TAG}\"\nexport USERID=$(id -u)\nexport GROUPID=$(id -g)\n\ncd \"$(dirname \"$0\")\"  # move inside the script directory\nmkdir -p \"${OUTPUT_DIR}\"\ndocker pull ${DOCKER_IMAGE}\ndocker run --rm -it \\\n    --env PYTHON_VERSION=\"${PYTHON_VERSION}\" \\\n    --env PLAT_NAME=\"${PLAT_NAME}\" \\\n    --env USERID=\"${USERID}\" \\\n    --env GROUPID=\"${GROUPID}\" \\\n    --volume \"$(pwd)/../../:/lmdeploy\" \\\n    --volume \"$(pwd)/${OUTPUT_DIR}:/lmdeploy_build\" \\\n    --volume \"$(pwd)/entrypoint_build.sh:/entrypoint_build.sh\" \\\n    --entrypoint /entrypoint_build.sh \\\n    ${DOCKER_IMAGE}\n"
  },
  {
    "path": "builder/manywheel/entrypoint_build.sh",
    "content": "#!/usr/bin/env bash\nset -eux\n\nexport PYTHON_VERSION=$PYTHON_VERSION\nexport PLAT_NAME=$PLAT_NAME\nexport USERID=${USERID}\nexport GROUPID=${GROUPID}\nexport NCCL_INCLUDE_DIR=/usr/local/cuda/include\nexport NCCL_LIB_DIR=/usr/local/cuda/lib64\n\nsource /opt/conda/bin/activate\nconda activate $PYTHON_VERSION\n\ncd /lmdeploy\npip install build change-wheel-version\npython -m build --wheel -o /tmpbuild/\nfor file in $(find /tmpbuild/ -name \"*.whl\")\ndo\n    platform_tag=\"$(basename \"$file\" | cut -d- -f3-4)-${PLAT_NAME}\"\n    change_wheel_version \"$file\" --delete-old-wheel --platform-tag \"${platform_tag}\"\ndone\nchown ${USERID}:${GROUPID} /tmpbuild/*\nmv /tmpbuild/* /lmdeploy_build/\n"
  },
  {
    "path": "builder/manywheel/scripts/install_conda.sh",
    "content": "#!/bin/bash\n\nset -ex\n\nwget -q https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh\nchmod +x  Miniconda3-latest-Linux-x86_64.sh\nbash ./Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda\nrm Miniconda3-latest-Linux-x86_64.sh\n"
  },
  {
    "path": "builder/manywheel/scripts/install_cuda.sh",
    "content": "#!/bin/bash\n\nset -ex\n\nfunction install_118 {\n    echo \"Installing CUDA 11.8 and NCCL 2.15\"\n    rm -rf /usr/local/cuda-11.8 /usr/local/cuda\n    # install CUDA 11.8.0 in the same container\n    wget -q https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run\n    chmod +x cuda_11.8.0_520.61.05_linux.run\n    ./cuda_11.8.0_520.61.05_linux.run --toolkit --silent\n    rm -f cuda_11.8.0_520.61.05_linux.run\n    rm -f /usr/local/cuda && ln -s /usr/local/cuda-11.8 /usr/local/cuda\n\n    # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses\n    mkdir tmp_nccl && cd tmp_nccl\n    wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.15.5/nccl_2.15.5-1+cuda11.8_x86_64.txz\n    tar xf nccl_2.15.5-1+cuda11.8_x86_64.txz\n    cp -a nccl_2.15.5-1+cuda11.8_x86_64/include/* /usr/local/cuda/include/\n    cp -a nccl_2.15.5-1+cuda11.8_x86_64/lib/* /usr/local/cuda/lib64/\n    cd ..\n    rm -rf tmp_nccl\n    ldconfig\n}\n\nfunction install_121 {\n    echo \"Installing CUDA 12.1 and NCCL 2.18.1\"\n    rm -rf /usr/local/cuda-12.1 /usr/local/cuda\n    # install CUDA 12.1.0 in the same container\n    wget -q https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run\n    chmod +x cuda_12.1.0_530.30.02_linux.run\n    ./cuda_12.1.0_530.30.02_linux.run --toolkit --silent\n    rm -f cuda_12.1.0_530.30.02_linux.run\n    rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.1 /usr/local/cuda\n\n    # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses\n    mkdir tmp_nccl && cd tmp_nccl\n    wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.18.1/nccl_2.18.1-1+cuda12.1_x86_64.txz\n    tar xf nccl_2.18.1-1+cuda12.1_x86_64.txz\n    cp -a nccl_2.18.1-1+cuda12.1_x86_64/include/* /usr/local/cuda/include/\n    cp -a nccl_2.18.1-1+cuda12.1_x86_64/lib/* /usr/local/cuda/lib64/\n    cd ..\n    rm -rf tmp_nccl\n    
ldconfig\n}\n\nfunction install_124 {\n    echo \"Installing CUDA 12.4 and NCCL 2.25.1\"\n    rm -rf /usr/local/cuda-12.4 /usr/local/cuda\n    # install CUDA 12.4.1 in the same container\n    wget -q https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run\n    chmod +x cuda_12.4.1_550.54.15_linux.run\n    ./cuda_12.4.1_550.54.15_linux.run --toolkit --silent\n    rm -f cuda_12.4.1_550.54.15_linux.run\n    rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.4 /usr/local/cuda\n\n    # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses\n    mkdir tmp_nccl && cd tmp_nccl\n    wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.25.1/nccl_2.25.1-1+cuda12.4_x86_64.txz\n    tar xf nccl_2.25.1-1+cuda12.4_x86_64.txz\n    cp -a nccl_2.25.1-1+cuda12.4_x86_64/include/* /usr/local/cuda/include/\n    cp -a nccl_2.25.1-1+cuda12.4_x86_64/lib/* /usr/local/cuda/lib64/\n    cd ..\n    rm -rf tmp_nccl\n    ldconfig\n}\n\nfunction install_126 {\n    echo \"Installing CUDA 12.6 and NCCL 2.24.3\"\n    rm -rf /usr/local/cuda-12.6 /usr/local/cuda\n    # install CUDA 12.6.3 in the same container\n    wget -q https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run\n    chmod +x cuda_12.6.3_560.35.05_linux.run\n    ./cuda_12.6.3_560.35.05_linux.run --toolkit --silent\n    rm -f cuda_12.6.3_560.35.05_linux.run\n    rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.6 /usr/local/cuda\n\n    # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses\n    mkdir tmp_nccl && cd tmp_nccl\n    wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.24.3/nccl_2.24.3-1+cuda12.6_x86_64.txz\n    tar xf nccl_2.24.3-1+cuda12.6_x86_64.txz\n    cp -a nccl_2.24.3-1+cuda12.6_x86_64/include/* /usr/local/cuda/include/\n    cp -a nccl_2.24.3-1+cuda12.6_x86_64/lib/* /usr/local/cuda/lib64/\n    cd ..\n    rm -rf tmp_nccl\n    ldconfig\n}\n\nfunction install_128 {\n   
 echo \"Installing CUDA 12.8 and NCCL 2.25.1\"\n    rm -rf /usr/local/cuda-12.8 /usr/local/cuda\n    # install CUDA 12.8.1 in the same container\n    wget -q https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_570.124.06_linux.run\n    chmod +x cuda_12.8.1_570.124.06_linux.run\n    ./cuda_12.8.1_570.124.06_linux.run --toolkit --silent\n    rm -f cuda_12.8.1_570.124.06_linux.run\n    rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.8 /usr/local/cuda\n\n    # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses\n    mkdir tmp_nccl && cd tmp_nccl\n    wget -q https://developer.download.nvidia.com/compute/redist/nccl/v2.25.1/nccl_2.25.1-1+cuda12.8_x86_64.txz\n    tar xf nccl_2.25.1-1+cuda12.8_x86_64.txz\n    cp -a nccl_2.25.1-1+cuda12.8_x86_64/include/* /usr/local/cuda/include/\n    cp -a nccl_2.25.1-1+cuda12.8_x86_64/lib/* /usr/local/cuda/lib64/\n    cd ..\n    rm -rf tmp_nccl\n    ldconfig\n}\n\nif test $# -eq 0\nthen\n    echo \"no CUDA version provided\" >&2; exit 1\nfi\n\n# idiomatic parameter and option handling in sh\nwhile test $# -gt 0\ndo\n    case \"$1\" in\n    11.8) install_118\n            ;;\n    12.1) install_121\n            ;;\n    12.4) install_124\n            ;;\n    12.6) install_126\n            ;;\n    12.8) install_128\n            ;;\n    *) echo \"unsupported CUDA version: $1\" >&2; exit 1\n            ;;\n    esac\n    shift\ndone\n"
  },
  {
    "path": "builder/manywheel/scripts/install_openmpi.sh",
    "content": "#!/bin/bash\n\nset -ex\n\nwget -q https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz\ntar xf openmpi-4.1.5.tar.gz\ncd openmpi-4.1.5\n./configure --prefix=/usr/local/mpi\nmake -j$(nproc)\nmake install\n"
  },
  {
    "path": "builder/windows/README.md",
    "content": "# Build lmdeploy on Windows\n\n## Requirements\n\n- [CMake 3.17+](https://github.com/Kitware/CMake/releases)\n- [Visual Studio 2019+](https://visualstudio.microsoft.com/downloads/)\n- [CUDA Toolkit 11.8+](https://developer.nvidia.com/cuda-toolkit-archive)\n\n## Build lmdeploy wheel\n\nRun the following from the repository root:\n\n```powershell\npip install build\npython -m build --wheel\n```\n"
  },
  {
    "path": "builder/windows/generate.ps1",
    "content": "cmake .. -A x64 -T \"v143,cuda=$env:CUDA_PATH\" `\n    -DCMAKE_BUILD_TYPE=Release `\n    -DCMAKE_INSTALL_PREFIX=install `\n    -DBUILD_PY_FFI=ON `\n    -DBUILD_MULTI_GPU=OFF `\n    -DUSE_NVTX=OFF `\n    -DBUILD_TEST=\"$env:BUILD_TEST\"\n"
  },
  {
    "path": "builder/windows/setup_cuda.ps1",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Adapted from https://github.com/thewh1teagle/vibe/blob/5d7b75568ca65ab635bdf0ce912bbc975a043066/scripts/setup_cuda.ps1\n\n$CUDA_VERSION_FULL = $env:INPUT_CUDA_VERSION # e.g. 12.1.0 or 11.8.0 (no leading 'v')\n\n# Validate that CUDA_VERSION_FULL is set and has the form <major>.<minor>.<patch>, extracting the components via regex\n$cuda_ver_matched = $CUDA_VERSION_FULL -match \"^(?<major>[1-9][0-9]*)\\.(?<minor>[0-9]+)\\.(?<patch>[0-9]+)$\"\nif(-not $cuda_ver_matched){\n    Write-Output \"Invalid CUDA version '$CUDA_VERSION_FULL': <major>.<minor>.<patch> required.\"\n    exit 1\n}\n$CUDA_MAJOR=$Matches.major\n$CUDA_MINOR=$Matches.minor\n$CUDA_PATCH=$Matches.patch\n\nWrite-Output \"Selected CUDA version: $CUDA_VERSION_FULL\"\n\n$src = \"cuda\"\n$dst = \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v$($CUDA_MAJOR).$($CUDA_MINOR)\"\n$installer = \"cuda.exe\"\n\nif ($CUDA_VERSION_FULL -eq \"12.1.0\") {\n    $downloadUrl = \"https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_531.14_windows.exe\"\n} elseif ($CUDA_VERSION_FULL -eq \"11.8.0\") {\n    $downloadUrl = \"https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe\"\n} elseif ($CUDA_VERSION_FULL -eq \"12.5.0\") {\n    $downloadUrl = \"https://developer.download.nvidia.com/compute/cuda/12.5.0/local_installers/cuda_12.5.0_555.85_windows.exe\"\n} elseif ($CUDA_VERSION_FULL -eq \"12.6.2\") {\n    $downloadUrl = \"https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.94_windows.exe\"\n} elseif ($CUDA_VERSION_FULL -eq \"12.8.1\") {\n    $downloadUrl = \"https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_572.61_windows.exe\"\n} else {\n    Write-Output \"Unsupported CUDA version: $CUDA_VERSION_FULL\"\n    exit 1\n}\n\n# Download cuda\nWrite-Output \"Downloading CUDA from: $downloadUrl\"\nif (-not (Test-Path 
-Path $installer)) {\n    Write-Output \"Downloading CUDA installer...\"\n    # If the file does not exist, download it\n    & \"C:\\msys64\\usr\\bin\\wget\" $downloadUrl -O $installer -q\n}\n\n# Extract cuda\nif (-not (Test-Path -Path $src -Type Container)) {\n    # Extract CUDA using 7-Zip\n    Write-Output \"Extracting CUDA using 7-Zip...\"\n    mkdir \"$src\"\n    & 'C:\\Program Files\\7-Zip\\7z' x $installer -o\"$src\"\n}\n\n# Create destination directory if it doesn't exist\nif (-Not (Test-Path -Path $dst)) {\n    Write-Output \"Creating destination directory: $dst\"\n    New-Item -Path $dst -ItemType Directory\n}\n\n# Get directories to process from the source path\n$directories = Get-ChildItem -Directory -Path $src\n$whitelist = @(\"CUDA_Toolkit_Release_Notes.txt\", \"DOCS\", \"EULA.txt\", \"LICENSE\", \"README\", \"version.json\")\n\nforeach ($dir in $directories) {\n    # Get all subdirectories and files in the current directory\n    $items = Get-ChildItem -Path (Join-Path $src $dir.Name)\n\n    foreach ($item in $items) {\n        if ($item.PSIsContainer) {\n            # If the item is a directory, copy its contents\n            Write-Output \"Copying contents of directory $($item.FullName) to $dst\"\n            Copy-Item -Path \"$($item.FullName)\\*\" -Destination $dst -Recurse -Force\n        } else {\n            if ($whitelist -contains $item.Name) {\n                Write-Output \"Copying file $($item.FullName) to $dst\"\n                Copy-Item -Path $item.FullName -Destination $dst -Force\n            }\n        }\n    }\n}\n\n# Add msbuild cuda extensions\n$msBuildExtensions = (Get-ChildItem  \"$src\\visual_studio_integration\\CUDAVisualStudioIntegration\\extras\\visual_studio_integration\\MSBuildExtensions\").fullname\n(Get-ChildItem 'C:\\Program Files\\Microsoft Visual Studio\\2022\\*\\MSBuild\\Microsoft\\VC\\*\\BuildCustomizations').FullName | ForEach-Object {\n    $destination = $_\n    $msBuildExtensions | ForEach-Object {\n        
$extension = $_\n        Copy-Item $extension -Destination $destination -Force\n        Write-Output \"Copied $extension to $destination\"\n    }\n}\n\n$CUDA_FLAGS=\"-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH=1\"\n\n# Add to Github env\nWrite-Output \"Setting environment variables for GitHub Actions...\"\n\nWrite-Output \"CUDA_PATH=$dst\"\nWrite-Output \"CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)=$dst\"\nWrite-Output \"CUDA_PATH_VX_Y=CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)\"\nWrite-Output \"CUDA_VERSION=$CUDA_VERSION_FULL\"\n\nWrite-Output \"CUDA_PATH=$dst\" >> $env:GITHUB_ENV\nWrite-Output \"CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)=$dst\" >> $env:GITHUB_ENV\nWrite-Output \"CUDA_PATH_VX_Y=CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)\" >> $env:GITHUB_ENV\nWrite-Output \"CudaToolkitDir=$dst\" >> $env:GITHUB_ENV\nWrite-Output \"CMAKE_CUDA_COMPILER=$dst\\bin\\nvcc.exe\" >> $env:GITHUB_ENV\nWrite-Output \"NVCC_APPEND_FLAGS=$CUDA_FLAGS\" >> $env:GITHUB_ENV\n\nWrite-Output \"CUDA_VERSION=$CUDA_VERSION_FULL\" >> $env:GITHUB_ENV\nWrite-Output \"Setup completed.\"\n"
  },
  {
    "path": "cmake/Modules/FindNCCL.cmake",
    "content": "# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# From PyTorch:\n#\n# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n# Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n#\n# From Caffe2:\n#\n# Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n#\n# All contributions by Facebook:\n# Copyright (c) 2016 Facebook Inc.\n#\n# All contributions by Google:\n# Copyright (c) 2015 Google Inc.\n# All rights reserved.\n#\n# All contributions by Yangqing Jia:\n# Copyright (c) 2015 Yangqing Jia\n# All rights reserved.\n#\n# All contributions by Kakao Brain:\n# Copyright 2019-2020 Kakao Brain\n#\n# All contributions from Caffe:\n# Copyright(c) 2013, 2014, 2015, the respective contributors\n# All rights reserved.\n#\n# All other contributions:\n# Copyright(c) 2015, 2016 the respective contributors\n# All rights reserved.\n#\n# Caffe2 uses a copyright model similar to Caffe: each contributor holds\n# copyright over their contributions to Caffe2. The project versioning records\n# all such contribution and copyright details. 
If a contributor wants to further\n# mark their specific copyright on a particular contribution, they should\n# indicate their copyright solely in the commit message of the change when it is\n# committed.\n#\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# 1. Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n#\n# 2. Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n#\n# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n#    and IDIAP Research Institute nor the names of its contributors may be\n#    used to endorse or promote products derived from this software without\n#    specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n#\n# Find the nccl libraries\n#\n# The following variables are optionally searched for defaults\n#  NCCL_ROOT: Base directory where all NCCL components are found\n#  NCCL_INCLUDE_DIR: Directory where NCCL header is found\n#  NCCL_LIB_DIR: Directory where NCCL library is found\n#\n# The following are set after configuration is done:\n#  NCCL_FOUND\n#  NCCL_INCLUDE_DIRS\n#  NCCL_LIBRARIES\n#\n# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks\n# install NCCL in the same location as the CUDA toolkit.\n# See https://github.com/caffe2/caffe2/issues/1601\n\nset(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH \"Folder contains NVIDIA NCCL headers\")\nset(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH \"Folder contains NVIDIA NCCL libraries\")\nset(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING \"Version of NCCL to build with\")\n\nif ($ENV{NCCL_ROOT_DIR})\n  message(WARNING \"NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.\")\nendif()\nlist(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})\n# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.\nlist(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT})\n\nfind_path(NCCL_INCLUDE_DIRS\n  NAMES nccl.h\n  HINTS ${NCCL_INCLUDE_DIR})\n\nif (USE_STATIC_NCCL)\n  MESSAGE(STATUS \"USE_STATIC_NCCL is set. 
Linking with static NCCL library.\")\n  SET(NCCL_LIBNAME \"nccl_static\")\n  if (NCCL_VERSION)  # Prefer the versioned library if a specific NCCL version is specified\n    set(CMAKE_FIND_LIBRARY_SUFFIXES \".a.${NCCL_VERSION}\" ${CMAKE_FIND_LIBRARY_SUFFIXES})\n  endif()\nelse()\n  SET(NCCL_LIBNAME \"nccl\")\n  if (NCCL_VERSION)  # Prefer the versioned library if a specific NCCL version is specified\n    set(CMAKE_FIND_LIBRARY_SUFFIXES \".so.${NCCL_VERSION}\" ${CMAKE_FIND_LIBRARY_SUFFIXES})\n  endif()\nendif()\n\nfind_library(NCCL_LIBRARIES\n  NAMES ${NCCL_LIBNAME}\n  HINTS ${NCCL_LIB_DIR})\n\ninclude(FindPackageHandleStandardArgs)\nfind_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES)\n\nif(NCCL_FOUND)  # obtaining NCCL version and some sanity checks\n  set (NCCL_HEADER_FILE \"${NCCL_INCLUDE_DIRS}/nccl.h\")\n  message (STATUS \"Determining NCCL version from ${NCCL_HEADER_FILE}...\")\n  set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})\n  list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS})\n  include(CheckCXXSymbolExists)\n  check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED)\n\n  if (NCCL_VERSION_DEFINED)\n    set(file \"${PROJECT_BINARY_DIR}/detect_nccl_version.cc\")\n    file(WRITE ${file} \"\n      #include <iostream>\n      #include <nccl.h>\n      int main()\n      {\n        std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl;\n        int x;\n        ncclGetVersion(&x);\n        return x == NCCL_VERSION_CODE;\n      }\n\")\n    try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}\n          RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER\n          CMAKE_FLAGS  \"-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}\"\n          LINK_LIBRARIES ${NCCL_LIBRARIES})\n    if (NOT NCCL_VERSION_MATCHED)\n      message(FATAL_ERROR \"Found NCCL header version and library version do not match! 
\\\n(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.\")\n    endif()\n    message(STATUS \"NCCL version: ${NCCL_VERSION_FROM_HEADER}\")\n  else()\n    # message(STATUS \"NCCL version < 2.3.5-5\")\n  endif ()\n  set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})\n\n  message(STATUS \"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})\")\n  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)\nendif()\n"
  },
  {
    "path": "cmake/TritonTurboMindBackendConfig.cmake.in",
    "content": "# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions\n# are met:\n#  * Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n#  * Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n#  * Neither the name of NVIDIA CORPORATION nor the names of its\n#    contributors may be used to endorse or promote products derived\n#    from this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY\n# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\n# PURPOSE ARE DISCLAIMED.  
IN NO EVENT SHALL THE COPYRIGHT OWNER OR\n# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,\n# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,\n# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR\n# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY\n# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\ninclude(CMakeFindDependencyMacro)\n\nget_filename_component(\n  TRITONPYTORCHBACKEND_CMAKE_DIR \"${CMAKE_CURRENT_LIST_FILE}\" PATH\n)\n\nlist(APPEND CMAKE_MODULE_PATH ${TRITONPYTORCHBACKEND_CMAKE_DIR})\n\nif(NOT TARGET TritonPyTorchBackend::triton-pytorch-backend)\n  include(\"${TRITONPYTORCHBACKEND_CMAKE_DIR}/TritonPyTorchBackendTargets.cmake\")\nendif()\n\nset(TRITONPYTORCHBACKEND_LIBRARIES TritonPyTorchBackend::triton-pytorch-backend)\n"
  },
  {
    "path": "cmake/TurboMindConfig.cmake.in",
    "content": "# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions\n# are met:\n#  * Redistributions of source code must retain the above copyright\n#    notice, this list of conditions and the following disclaimer.\n#  * Redistributions in binary form must reproduce the above copyright\n#    notice, this list of conditions and the following disclaimer in the\n#    documentation and/or other materials provided with the distribution.\n#  * Neither the name of NVIDIA CORPORATION nor the names of its\n#    contributors may be used to endorse or promote products derived\n#    from this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY\n# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\n# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR\n# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,\n# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,\n# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR\n# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY\n# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\ninclude(CMakeFindDependencyMacro)\n\nget_filename_component(\n  TURBOMIND_CMAKE_DIR \"${CMAKE_CURRENT_LIST_FILE}\" PATH\n)\n\nlist(APPEND CMAKE_MODULE_PATH ${TURBOMIND_CMAKE_DIR})\n\nif(NOT TARGET transformer-shared)\n  include(\"${TURBOMIND_CMAKE_DIR}/TurboMindTargets.cmake\")\nendif()\n\nset(TURBOMIND_LIBRARIES transformer-shared)\n"
  },
  {
    "path": "cmake/yaml-cpp_cmake_policy.patch",
    "content": "diff --git a/CMakeLists.txt b/CMakeLists.txt\nindex 46dc180..b746ac1 100644\n--- a/CMakeLists.txt\n+++ b/CMakeLists.txt\n@@ -1,5 +1,5 @@\n # 3.5 is actually available almost everywhere, but this a good minimum\n-cmake_minimum_required(VERSION 3.4)\n+cmake_minimum_required(VERSION 3.5)\n \n # enable MSVC_RUNTIME_LIBRARY target property\n # see https://cmake.org/cmake/help/latest/policy/CMP0091.html\n"
  },
  {
    "path": "debug.sh",
    "content": "#!/bin/bash -e\n\nbuilder=\"-G Ninja\"\n\nif [ \"$1\" == \"make\" ]; then\n    builder=\"\"\nfi\n\ncmake ${builder} .. \\\n    -DCMAKE_BUILD_TYPE=RelWithDebInfo \\\n    -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \\\n    -DCMAKE_INSTALL_PREFIX=./install \\\n    -DBUILD_PY_FFI=ON \\\n    -DBUILD_MULTI_GPU=ON \\\n    -DCMAKE_CUDA_FLAGS=\"-lineinfo\" \\\n    -DUSE_NVTX=ON \\\n    -DPYTHON_EXECUTABLE=$(which python3) \\\n    -DFETCHCONTENT_QUIET=OFF \\\n    -DBUILD_TEST=ON\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "# Base images\nARG IMAGE_TYPE=final\nARG CUDA_VERSION=cu12\n\nFROM nvidia/cuda:13.0.2-devel-ubuntu22.04 AS cu13\nENV CUDA_VERSION_SHORT=cu130\n\nFROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS cu12.8\nENV CUDA_VERSION_SHORT=cu128\n\nFROM nvidia/cuda:12.6.3-devel-ubuntu22.04 AS cu12\nENV CUDA_VERSION_SHORT=cu126\n\n# Builder image\nFROM ${CUDA_VERSION} AS dev\nARG PYTHON_VERSION=3.10\n\nENV PATH=/opt/py3/bin:/root/.local/bin:${PATH}\nENV DEBIAN_FRONTEND=noninteractive\nENV TZ=Etc/UTC\n\nRUN --mount=type=cache,target=/root/.cache \\\n    sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list && \\\n    apt-get update -y && \\\n    apt-get install -y --no-install-recommends \\\n        tzdata wget curl openssh-server ssh sudo git-core \\\n        libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 libibverbs-dev rdma-core libmlx5-1 \\\n        libssl-dev pkg-config vim rapidjson-dev libgoogle-glog-dev gdb && \\\n    apt-get clean -y && \\\n    rm -rf /var/lib/apt/lists/* && \\\n    wget -qO- https://astral.sh/uv/install.sh | sh && \\\n    uv venv -p python${PYTHON_VERSION} --seed /opt/py3 && \\\n    pip install --upgrade pip build\n\nFROM dev AS builder\n\n# Should be in the lmdeploy root directory when building docker image\nCOPY . 
/opt/lmdeploy\nWORKDIR /opt/lmdeploy\n\nRUN --mount=type=cache,target=/root/.cache \\\n    pip install -r requirements/runtime_cuda.txt --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT}\n\nRUN --mount=type=cache,target=/root/.cache \\\n    docker/build.sh\n\nRUN --mount=type=cache,target=/root/.cache \\\n    docker/prepare_wheel.sh\n\n# Runtime image\nFROM nvidia/cuda:13.0.2-base-ubuntu22.04 AS cu13-base\nENV CUDA_VERSION_SHORT=cu130\n\nFROM nvidia/cuda:12.8.1-base-ubuntu22.04 AS cu12.8-base\nENV CUDA_VERSION_SHORT=cu128\n\nFROM nvidia/cuda:12.6.3-base-ubuntu22.04 AS cu12-base\nENV CUDA_VERSION_SHORT=cu126\n\nFROM ${CUDA_VERSION}-base AS final\nARG PYTHON_VERSION=3.10\n\n# Some dependencies, such as timm (required by InternVL models), are missing from the base image.\n# They are listed in requirements/serve.txt and installed via pip by install.sh,\n# so we copy the requirements directory here.\nCOPY requirements /tmp/requirements\nCOPY docker/install.sh /tmp/install.sh\nRUN --mount=type=cache,target=/root/.cache \\\n    --mount=type=cache,target=/wheels,from=builder,source=/wheels \\\n    /tmp/install.sh\n\n# explicitly set ptxas path for triton\nENV PATH=/opt/py3/bin:$PATH\nENV TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas\nFROM ${IMAGE_TYPE}\n"
  },
  {
    "path": "docker/Dockerfile.jetson",
    "content": "# Base images\nFROM nvcr.io/nvidia/l4t-base:r36.2.0\nENV CUDA_VER=12.6 \\\n    PYTHON_VERSION=3.10 \\\n    PATH=/opt/py3/bin:/root/.local/bin:/usr/local/cuda/bin:${PATH}\n\nRUN --mount=type=cache,target=/root/.cache \\\n    --mount=type=cache,target=/tmp/download \\\n    export CUDA_SUFFIX=$(echo $CUDA_VER | sed 's/\\./-/g') && \\\n    cd /tmp/download && \\\n    mkdir -p /opt/nvidia/l4t-packages/ && \\\n    touch /opt/nvidia/l4t-packages/.nv-l4t-disable-boot-fw-update-in-preinstall && \\\n    wget -q \"https://repo.download.nvidia.com/jetson/t234/pool/main/n/nvidia-l4t-core/nvidia-l4t-core_36.2.0-20231218214829_arm64.deb\" && \\\n    wget -q \"https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/arm64/cuda-keyring_1.1-1_all.deb\" && \\\n    yes | dpkg -i nvidia-l4t-core_*.deb cuda-keyring_*.deb && \\\n    rm -rf *.deb *.deb.* && \\\n    apt update -y && \\\n    apt-get install -y --no-install-recommends \\\n        cuda-toolkit-${CUDA_SUFFIX} cuda-compat-${CUDA_SUFFIX} libcudnn9-cuda-12 libcusparselt0 cudss \\\n        git libopenblas-dev python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv && \\\n    apt-get clean -y && \\\n    rm -rf /var/lib/apt/lists/* && \\\n    python${PYTHON_VERSION} -m venv /opt/py3 && \\\n    mkdir -p /wheels\n\n# Should be in the lmdeploy root directory when building docker image\nCOPY . /opt/lmdeploy\nWORKDIR /opt/lmdeploy\n\nRUN --mount=type=cache,target=/root/.cache \\\n    --mount=type=cache,target=/opt/pytorch \\\n    pip install build change-wheel-version && \\\n    python -m build -w -o /wheels -v . && \\\n    change_wheel_version --local-version cu126 --delete-old-wheel /wheels/lmdeploy*.whl && \\\n    pip install -v /wheels/lmdeploy*.whl --index-url https://pypi.jetson-ai-lab.io/jp6/cu126/+simple/\n"
  },
  {
    "path": "docker/Dockerfile_ascend_a2_300i",
    "content": "# DOCKER_BUILDKIT=1 docker build --build-arg ASCEND_DEVICE_TYPE=ascend_a2 \\\n#     --build-arg DLINFER_TAG=main --build-arg LMDEPLOY_TAG=main --network=host \\\n#     -t lmdeploy_dlinfer:a2 -f  Dockerfile_ascend_a2_300i .\nARG ASCEND_DEVICE_TYPE=ascend_a2\nARG ASCEND_HUB=swr.cn-south-1.myhuaweicloud.com/ascendhub\n\nFROM ${ASCEND_HUB}/cann:8.3.rc1-910b-ubuntu22.04-py3.11 AS ascend_a2_base\nFROM ${ASCEND_HUB}/cann:8.3.rc1-310p-ubuntu22.04-py3.11 AS ascend_300i_base\n\nFROM ${ASCEND_DEVICE_TYPE}_base AS builder\nENV DEBIAN_FRONTEND=noninteractive\nRUN apt update -y && \\\n    apt install -y libjemalloc-dev git && \\\n    apt clean && rm -rf /var/lib/apt/lists/*\n\nENV HCCL_CONNECT_TIMEOUT=7200 \\\n    PYTORCH_NPU_ALLOC_CONF=\"expandable_segments:True\" \\\n    HCCL_OP_EXPANSION_MODE=\"AIV\" \\\n    LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libjemalloc.so:$LD_PRELOAD\n\nARG DLINFER_TAG=main\nARG LMDEPLOY_TAG=main\nRUN --mount=type=cache,target=/root/.cache \\\n    pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \\\n    pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn && \\\n    pip install --no-cache-dir torch==2.8.0 torch-npu==2.8.0 torchvision==0.23.0 && \\\n    TORCH_DEVICE_BACKEND_AUTOLOAD=0 DEVICE=ascend pip install git+https://github.com/DeepLink-org/dlinfer.git@${DLINFER_TAG} && \\\n    LMDEPLOY_TARGET_DEVICE=ascend pip install git+https://github.com/InternLM/lmdeploy.git@${LMDEPLOY_TAG}\n"
  },
  {
    "path": "docker/Dockerfile_ascend_a3",
    "content": "# DOCKER_BUILDKIT=1 docker build --build-arg ASCEND_DEVICE_TYPE=ascend_a3 \\\n#     --build-arg DLINFER_TAG=main --build-arg LMDEPLOY_TAG=main --network=host \\\n#     -t lmdeploy_dlinfer:a3 -f Dockerfile_ascend_a3 .\nARG ASCEND_DEVICE_TYPE=ascend_a3\nARG ASCEND_HUB=swr.cn-south-1.myhuaweicloud.com/ascendhub\n\nFROM ${ASCEND_HUB}/cann:8.5.0-a3-openeuler24.03-py3.11 AS ascend_a3_base\n\nFROM ${ASCEND_DEVICE_TYPE}_base AS builder\nENV DEBIAN_FRONTEND=noninteractive\nRUN dnf update -y && \\\n    dnf install -y jemalloc jemalloc-devel && \\\n    dnf clean all && rm -rf /var/cache/dnf\n\nENV HCCL_CONNECT_TIMEOUT=7200 \\\n    PYTORCH_NPU_ALLOC_CONF=\"expandable_segments:True\" \\\n    HCCL_OP_EXPANSION_MODE=\"AIV\" \\\n    LD_PRELOAD=/usr/lib64/libjemalloc.so.2:$LD_PRELOAD\n\nARG DLINFER_TAG=main\nARG LMDEPLOY_TAG=main\nRUN --mount=type=cache,target=/root/.cache \\\n    pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \\\n    pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn && \\\n    pip install --no-cache-dir torch==2.9.0 torch-npu==2.9.0 torchvision==0.24.0 && \\\n    TORCH_DEVICE_BACKEND_AUTOLOAD=0 DEVICE=ascend pip install git+https://github.com/DeepLink-org/dlinfer.git@${DLINFER_TAG} && \\\n    LMDEPLOY_TARGET_DEVICE=ascend pip install git+https://github.com/InternLM/lmdeploy.git@${LMDEPLOY_TAG}\n"
  },
  {
    "path": "docker/Dockerfile_dev",
    "content": "FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS cu12.8\n\n# environment variables\nENV DEBIAN_FRONTEND=noninteractive \\\n    TZ=Etc/UTC \\\n    PATH=/opt/py3/bin:/root/.local/bin:${PATH} \\\n    TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \\\n    CUDA_VERSION_SHORT=cu128\n\n# Install dependencies and create python virtual environment\nRUN --mount=type=cache,target=/var/cache/apt \\\n    --mount=type=cache,target=/root/.cache \\\n    sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list && \\\n    apt-get update -y && \\\n    apt-get install -y --no-install-recommends \\\n        tzdata wget curl openssh-server ssh sudo git-core \\\n        libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 \\\n        libibverbs-dev rdma-core libmlx5-1 libssl-dev pkg-config \\\n        vim rapidjson-dev libgoogle-glog-dev gdb cmake build-essential \\\n        python3-dev ninja-build htop tree jq unzip && \\\n    apt-get clean -y && \\\n    rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \\\n    # install UV\n    wget -qO- https://astral.sh/uv/install.sh | sh && \\\n    # create Python virtual environment\n    uv venv -p python3.12 --seed /opt/py3\n\n# Should be in the lmdeploy root directory when building docker image\nCOPY . 
/opt/lmdeploy\nWORKDIR /opt/lmdeploy\n\n# install lmdeploy and its dependencies\nRUN --mount=type=cache,target=/root/.cache/uv \\\n    uv pip install -r requirements_cuda.txt --extra-index-url https://download.pytorch.org/whl/cu128 && \\\n    uv pip install -e .\n\nRUN --mount=type=cache,target=/root/.cache/uv \\\n    docker/prepare_wheel.sh\n\nRUN --mount=type=cache,target=/root/.cache/uv \\\n    cp -r requirements /tmp/requirements && \\\n    docker/install.sh\n\n# Clean up to reduce image size\nRUN uv cache clean && \\\n    rm -rf /wheels /tmp/* /var/tmp/* /root/.cache/uv/* && \\\n    find /opt/lmdeploy -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true && \\\n    find /opt/lmdeploy -type f -name \"*.pyc\" -delete 2>/dev/null || true\n"
  },
  {
    "path": "docker/InternVL_Dockerfile",
    "content": "ARG CUDA_VERSION=cu12\n\nFROM openmmlab/lmdeploy:latest-cu12 AS cu12\nENV CUDA_VERSION_SHORT=cu123\n\nFROM openmmlab/lmdeploy:latest-cu11 AS cu11\nENV CUDA_VERSION_SHORT=cu118\n\nFROM ${CUDA_VERSION} AS final\n\nRUN python3 -m pip install timm!=1.0.23\n\nRUN python3 -m pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+${CUDA_VERSION_SHORT}torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl\n"
  },
  {
    "path": "docker/Qwen2VL_Dockerfile",
    "content": "ARG CUDA_VERSION=cu12\n\nFROM openmmlab/lmdeploy:latest-cu12 AS cu12\nENV CUDA_VERSION_SHORT=cu123\n\nFROM openmmlab/lmdeploy:latest-cu11 AS cu11\nENV CUDA_VERSION_SHORT=cu118\n\nFROM ${CUDA_VERSION} AS final\n\n# we use transformers to load vision part of qwen2_vl and it needs transformers > v4.44.2\nRUN python3 -m pip install git+https://github.com/huggingface/transformers.git\n\nRUN python3 -m pip install qwen_vl_utils\n"
  },
  {
    "path": "docker/build.sh",
    "content": "#!/bin/bash -ex\n\nmkdir -p /wheels\n\nif [[ \"${CUDA_VERSION_SHORT}\" = \"cu130\" ]]; then\n    pip install nvidia-nccl-cu13\nelse\n    pip install nvidia-nccl-cu12\nfi\n\npython3 -m build -w -o /wheels -v .\n"
  },
  {
    "path": "docker/install.sh",
    "content": "#!/bin/bash -ex\n\n# Skip system setup if virtual env already exists (e.g., in dev image)\nif [ ! -f \"/opt/py3/bin/python\" ]; then\n    # install system packages\n    export DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC\n    sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list\n    apt-get update -y\n    apt-get install -y --no-install-recommends \\\n        tzdata wget curl ssh sudo git-core vim libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 libibverbs-dev rdma-core libmlx5-1\n\n    if [[ ${PYTHON_VERSION} != \"3.10\" ]]; then\n        apt-get install -y --no-install-recommends software-properties-common\n        add-apt-repository -y ppa:deadsnakes/ppa\n        apt-get update -y\n    fi\n\n    # install python, create virtual env\n    apt-get install -y --no-install-recommends \\\n        python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv\n\n    pushd /opt >/dev/null\n        python${PYTHON_VERSION} -m venv py3\n    popd >/dev/null\n\n    # install CUDA build tools\n    if [[ \"${CUDA_VERSION_SHORT}\" = \"cu126\" ]]; then\n        apt-get install -y --no-install-recommends cuda-minimal-build-12-6 numactl dkms\n    elif [[ \"${CUDA_VERSION_SHORT}\" = \"cu128\" ]]; then\n        apt-get install -y --no-install-recommends cuda-minimal-build-12-8 numactl dkms\n    elif [[ \"${CUDA_VERSION_SHORT}\" = \"cu130\" ]]; then\n        apt-get install -y --no-install-recommends cuda-minimal-build-13-0 numactl dkms\n    fi\n\n    apt-get clean -y\n    rm -rf /var/lib/apt/lists/*\nfi\n\n# install GDRCopy debs\nif [ \"$(ls -A /wheels/*.deb 2>/dev/null)\" ]; then\n    dpkg -i /wheels/*.deb\nfi\n\n# install python packages\nexport PATH=/opt/py3/bin:$PATH\n\npip install -U pip wheel setuptools\n\nif [[ \"${CUDA_VERSION_SHORT}\" = \"cu130\" ]]; then\n    pip install nvidia-nvshmem-cu13==3.4.5\nelse\n    pip install nvidia-nvshmem-cu12==3.4.5\nfi\n\npip install /wheels/*.whl\npip install 
dlblas==0.0.7 dlslime==0.0.2.post1\n\n# install pre-built flash attention 3 wheel\nTORCH_VER=$(python3 -c \"import torch; print(''.join(torch.__version__.split('+')[0].split('.')))\")\n\npip install ninja einops packaging\nFA3_WHEELS_URL=\"https://windreamer.github.io/flash-attention3-wheels/${CUDA_VERSION_SHORT}_torch${TORCH_VER}\"\npip install --no-index flash_attn_3 --find-links ${FA3_WHEELS_URL}\n\n# install requirements/serve.txt dependencies such as timm\nif [ -f /tmp/requirements/serve.txt ]; then\n    pip install -r /tmp/requirements/serve.txt\nfi\n\nif [[ \"${CUDA_VERSION_SHORT}\" = \"cu128\" ]]; then\n    # As described in https://github.com/InternLM/lmdeploy/pull/4313,\n    # window registration may cause memory leaks in NCCL 2.27; NCCL 2.28+ resolves the issue,\n    # but the turbomind engine will use NCCL GIN for EP in the future, which was introduced in 2.29\n    pip install \"nvidia-nccl-cu12>2.29\"\nfi\n"
  },
  {
    "path": "docker/prepare_wheel.sh",
    "content": "#!/bin/bash -ex\n\nexport PATH=/opt/py3/bin:$PATH\n\npip install \"cmake<4.0\" wheel ninja setuptools packaging\n\nif [[ ${PYTHON_VERSION} = \"3.13\" ]]; then\n    curl https://sh.rustup.rs -sSf | sh -s -- -y\n    . \"$HOME/.cargo/env\"\n\n    pip install setuptools_rust\n    pip wheel -v --no-build-isolation --no-deps -w /wheels \"git+https://github.com/google/sentencepiece.git@v0.2.0#subdirectory=python\"\nfi\n\nGDRCOPY_VERSION=2.5.1\nDEEP_EP_VERSION=9af0e0d  # v1.2.1\nDEEP_GEMM_VERSION=c9f8b34  # v2.1.1.post3\nFLASH_MLA_VERSION=1408756  # no release, pick the latest commit\n\n# DeepEP\nif [[ \"${CUDA_VERSION_SHORT}\" = \"cu130\" ]]; then\n    export CPLUS_INCLUDE_PATH=\"/usr/local/cuda/include/cccl\":${CPLUS_INCLUDE_PATH}\n    pip install nvidia-nvshmem-cu13==3.4.5\nelse\n    pip install nvidia-nvshmem-cu12==3.4.5\nfi\npip wheel -v --no-build-isolation --no-deps -w /wheels \"git+https://github.com/deepseek-ai/DeepEP.git@${DEEP_EP_VERSION}\"\n\n# DeepGEMM\npip wheel -v --no-build-isolation --no-deps -w /wheels \"git+https://github.com/deepseek-ai/DeepGEMM.git@${DEEP_GEMM_VERSION}\"\n\n# FlashMLA\n# sm100 compilation for Flash MLA requires NVCC 12.9 or higher\nFLASH_MLA_DISABLE_SM100=1 pip wheel -v --no-build-isolation --no-deps -w /wheels \"git+https://github.com/deepseek-ai/FlashMLA.git@${FLASH_MLA_VERSION}\"\n\n# GDRCopy debs\napt-get update -y \\\n&& apt-get install -y --no-install-recommends build-essential devscripts debhelper fakeroot pkg-config dkms\n\nwget -q https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v${GDRCOPY_VERSION}.tar.gz \\\n&& tar -xzf v${GDRCOPY_VERSION}.tar.gz && rm v${GDRCOPY_VERSION}.tar.gz \\\n&& cd gdrcopy-${GDRCOPY_VERSION}/packages \\\n&& CUDA=/usr/local/cuda ./build-deb-packages.sh \\\n&& mv ./*.deb /wheels\n\n# Clean up build artifacts\ncd / && rm -rf gdrcopy-${GDRCOPY_VERSION}\napt-get clean -y && rm -rf /var/lib/apt/lists/*\n"
  },
  {
    "path": "docs/en/.readthedocs.yaml",
    "content": "version: 2\n\nformats: all\n\nbuild:\n  os: \"ubuntu-22.04\"\n  tools:\n    python: \"3.10\"\n\n\nsphinx:\n  configuration: docs/en/conf.py\n\n\npython:\n  install:\n    - requirements: requirements/docs.txt\n    - requirements: requirements/readthedocs.txt\n"
  },
  {
    "path": "docs/en/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/en/_static/css/readthedocs.css",
    "content": "table.autosummary td {\n  width: 50%\n}\n\nimg.align-center {\n  display: block;\n  margin-left: auto;\n  margin-right: auto;\n}\n"
  },
  {
    "path": "docs/en/advance/chat_template.md",
    "content": "# Customized chat template\n\nYou can observe the effect of the applied chat template by setting the log level to `INFO`.\n\nLMDeploy supports two methods of adding chat templates:\n\n- One approach is to build on an existing conversation template by directly configuring a JSON file like the following:\n\n  ```json\n  {\n      \"model_name\": \"your awesome chat template name\",\n      \"system\": \"<|im_start|>system\\n\",\n      \"meta_instruction\": \"You are a robot developed by LMDeploy.\",\n      \"eosys\": \"<|im_end|>\\n\",\n      \"user\": \"<|im_start|>user\\n\",\n      \"eoh\": \"<|im_end|>\\n\",\n      \"assistant\": \"<|im_start|>assistant\\n\",\n      \"eoa\": \"<|im_end|>\",\n      \"separator\": \"\\n\",\n      \"capability\": \"chat\",\n      \"stop_words\": [\"<|im_end|>\"]\n  }\n  ```\n\n  The new chat template is applied as follows:\n\n  ```\n  {system}{meta_instruction}{eosys}{user}{user_content}{eoh}{assistant}{assistant_content}{eoa}{separator}{user}...\n  ```\n\n  When using the CLI tool, you can pass in a custom chat template with `--chat-template`, for example:\n\n  ```shell\n  lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template ${JSON_FILE}\n  ```\n\n  You can also pass it in through the Python API, for example:\n\n  ```python\n  from lmdeploy import ChatTemplateConfig, serve\n  serve('internlm/internlm2_5-7b-chat',\n        chat_template_config=ChatTemplateConfig.from_json('${JSON_FILE}'))\n  ```\n\n- Another approach is to define a Python chat template class, like the existing LMDeploy chat templates. It can be used directly once it is registered. The advantages are a high degree of customization and strong controllability. 
Below is an example of registering an LMDeploy chat template.\n\n  ```python\n  from lmdeploy.model import MODELS, BaseChatTemplate\n\n\n  @MODELS.register_module(name='customized_model')\n  class CustomizedModel(BaseChatTemplate):\n      \"\"\"A customized chat template.\"\"\"\n\n      def __init__(self,\n                   system='<|im_start|>system\\n',\n                   meta_instruction='You are a robot developed by LMDeploy.',\n                   user='<|im_start|>user\\n',\n                   assistant='<|im_start|>assistant\\n',\n                   eosys='<|im_end|>\\n',\n                   eoh='<|im_end|>\\n',\n                   eoa='<|im_end|>',\n                   separator='\\n',\n                   stop_words=['<|im_end|>', '<|action_end|>']):\n          super().__init__(system=system,\n                           meta_instruction=meta_instruction,\n                           eosys=eosys,\n                           user=user,\n                           eoh=eoh,\n                           assistant=assistant,\n                           eoa=eoa,\n                           separator=separator,\n                           stop_words=stop_words)\n\n\n  from lmdeploy import ChatTemplateConfig, pipeline\n\n  messages = [{'role': 'user', 'content': 'who are you?'}]\n  pipe = pipeline('internlm/internlm2_5-7b-chat',\n                  chat_template_config=ChatTemplateConfig('customized_model'))\n  for response in pipe.stream_infer(messages):\n      print(response.text, end='')\n  ```\n\n  In this example, we register an LMDeploy chat template whose meta instruction states that the model was developed by LMDeploy, so when the user asks who the model is, it answers that it was created by LMDeploy.\n"
  },
  {
    "path": "docs/en/advance/context_parallel.md",
    "content": "# Context Parallel\n\nWhen the memory on a single GPU is insufficient to deploy a model, it is often deployed using tensor parallelism (TP), which generally requires `num_key_value_heads` to be divisible by `TP`. If you want to deploy with `TP > num_key_value_heads`, the kv-heads must be duplicated to meet the divisibility requirement. However, this has two disadvantages:\n\n1. The amount of available kv_cache is reduced (halved when each kv-head is stored twice), which shortens the maximum supported session length.\n2. The maximum inference batch size is reduced, leading to lower throughput.\n\nTo address this issue, the TurboMind inference backend supports setting `attn_dp_size`, which avoids creating copies of kv-heads, but this introduces data imbalance. To eliminate the data imbalance, TurboMind supports context (sequence) parallelism, which stores the kv_cache interleaved across different cp_ranks. In the example below, with two ranks, even token positions are stored on cp_rank0 and odd ones on cp_rank1:\n\n```\ncp_size=2, prompt_len=5, generation_len=4\nkv_cache stored on cp_rank0: 0, 2, 4, 6, 8\nkv_cache stored on cp_rank1: 1, 3, 5, 7\n```\n\n## Usage\n\nTaking Intern-S1 / Qwen3-235B-A22B as an example, both models have `num_key_value_heads` equal to 4. To deploy them with `TP=8` without duplicating the kv_cache, run:\n\n```\nlmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2\n\nlmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2\n```\n"
  },
  {
    "path": "docs/en/advance/debug_turbomind.md",
    "content": "# How to debug Turbomind\n\nTurbomind is implemented in C++, which is not as easy to debug as Python. This document provides basic methods for debugging Turbomind.\n\n## Prerequisite\n\nFirst, complete the local compilation according to the commands in [Install from source](../get_started/installation.md).\n\n## Configure Python debug environment\n\nSince many large companies currently run CentOS 7 in production, we use CentOS 7 as an example to illustrate the process.\n\n### Obtain `glibc` and `python3` versions\n\n```bash\nrpm -qa | grep glibc\nrpm -qa | grep python3\n```\n\nThe result should be similar to this:\n\n```\n[username@hostname workdir]# rpm -qa | grep glibc\nglibc-2.17-325.el7_9.x86_64\nglibc-common-2.17-325.el7_9.x86_64\nglibc-headers-2.17-325.el7_9.x86_64\nglibc-devel-2.17-325.el7_9.x86_64\n\n[username@hostname workdir]# rpm -qa | grep python3\npython3-pip-9.0.3-8.el7.noarch\npython3-rpm-macros-3-34.el7.noarch\npython3-rpm-generators-6-2.el7.noarch\npython3-setuptools-39.2.0-10.el7.noarch\npython3-3.6.8-21.el7_9.x86_64\npython3-devel-3.6.8-21.el7_9.x86_64\npython3.6.4-sre-1.el6.x86_64\n```\n\nFrom the output above, the version of `glibc` is `2.17-325.el7_9.x86_64` and the version of `python3` is `3.6.8-21.el7_9.x86_64`.\n\n### Download and install `debuginfo` libraries\n\nDownload `glibc-debuginfo-common-2.17-325.el7.x86_64.rpm`, `glibc-debuginfo-2.17-325.el7.x86_64.rpm`, and `python3-debuginfo-3.6.8-21.el7.x86_64.rpm` from http://debuginfo.centos.org/7/x86_64.\n\n```bash\nrpm -ivh glibc-debuginfo-common-2.17-325.el7.x86_64.rpm\nrpm -ivh glibc-debuginfo-2.17-325.el7.x86_64.rpm\nrpm -ivh python3-debuginfo-3.6.8-21.el7.x86_64.rpm\n```\n\n### Upgrade GDB\n\n```bash\nsudo yum install devtoolset-10 -y\necho \"source scl_source enable devtoolset-10\" >> ~/.bashrc\nsource ~/.bashrc\n```\n\n### Verification\n\n```bash\ngdb python3\n```\n\nThe output should be similar to 
this:\n\n```\n[username@hostname workdir]# gdb python3\nGNU gdb (GDB) Red Hat Enterprise Linux 9.2-10.el7\nCopyright (C) 2020 Free Software Foundation, Inc.\nLicense GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>\nThis is free software: you are free to change and redistribute it.\nThere is NO WARRANTY, to the extent permitted by law.\nType \"show copying\" and \"show warranty\" for details.\nThis GDB was configured as \"x86_64-redhat-linux-gnu\".\nType \"show configuration\" for configuration details.\nFor bug reporting instructions, please see:\n<http://www.gnu.org/software/gdb/bugs/>.\nFind the GDB manual and other documentation resources online at:\n   <http://www.gnu.org/software/gdb/documentation/>.\n\nFor help, type \"help\".\nType \"apropos word\" to search for commands related to \"word\"...\nReading symbols from python3...\n(gdb)\n```\n\nIf it shows `Reading symbols from python3`, the configuration was successful.\n\nFor other operating systems, please refer to [DebuggingWithGdb](https://wiki.python.org/moin/DebuggingWithGdb).\n\n## Set up symbolic links\n\nWith these symbolic links in place, you do not need to reinstall lmdeploy locally with `pip` after every rebuild.\n\n```bash\n# Change directory to lmdeploy, e.g.\ncd /workdir/lmdeploy\n\n# Since it has already been built in the build directory,\n# link the lib directory\ncd lmdeploy && ln -s ../build/lib . 
&& cd ..\n# (Optional) Link compile_commands.json for clangd indexing\nln -s build/compile_commands.json .\n```\n\n## Start debugging\n\n````bash\n# Use gdb to start the API server with Llama-2-13b-chat-hf, e.g.\ngdb --args python3 -m lmdeploy serve api_server /workdir/Llama-2-13b-chat-hf\n\n# Set directories in gdb\nReading symbols from python3...\n(gdb) set directories /workdir/lmdeploy\n\n# Set a breakpoint using the relative path, e.g.\n(gdb) b src/turbomind/models/llama/BlockManager.cc:104\n\n# When it shows\n# ```\n# No source file named src/turbomind/models/llama/BlockManager.cc.\n# Make breakpoint pending on future shared library load? (y or [n])\n# ```\n# type `y` and press Enter\n\n# Run\n(gdb) r\n\n# (Optional) Use https://github.com/InternLM/lmdeploy/blob/main/benchmark/profile_restful_api.py to send a request\n\npython3 profile_restful_api.py --backend lmdeploy --dataset-path /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1\n````\n\n## Using GDB\n\nRefer to [GDB Execution Commands](https://lldb.llvm.org/use/map.html) for common commands. Happy debugging!\n"
  },
  {
    "path": "docs/en/advance/long_context.md",
    "content": "# Context length extrapolation\n\nLong text extrapolation refers to the ability of an LLM to handle inputs longer than its training context during inference. The TurboMind engine now supports [LlamaDynamicNTKScalingRotaryEmbedding](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L178), and its implementation is consistent with Hugging Face's.\n\n## Usage\n\nYou can enable context length extrapolation by modifying `TurbomindEngineConfig`: set `session_len` to the expected length and set `rope_scaling_factor` to a value no less than 1.0.\n\nTake `internlm2_5-7b-chat-1m` as an example, which supports a context length of up to **1 million tokens**:\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(\n        rope_scaling_factor=2.5,\n        session_len=1000000,\n        max_batch_size=1,\n        cache_max_entry_count=0.7,\n        tp=4)\npipe = pipeline('internlm/internlm2_5-7b-chat-1m', backend_config=backend_config)\nprompt = 'Use a long prompt to replace this sentence'\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\nresponse = pipe(prompt, gen_config=gen_config)\nprint(response)\n```\n\n## Evaluation\n\nWe use several methods to evaluate the long-context inference ability of LMDeploy, including [passkey retrieval](#passkey-retrieval), [needle in a haystack](#needle-in-a-haystack) and computing [perplexity](#perplexity).\n\n### Passkey Retrieval\n\nYou can try the following code to test how many times LMDeploy can retrieve the special key.\n\n```python\nimport numpy as np\nfrom lmdeploy import pipeline\nfrom lmdeploy import TurbomindEngineConfig\nimport time\n\nsession_len = 1000000\nbackend_config = TurbomindEngineConfig(\n        rope_scaling_factor=2.5,\n        
session_len=session_len,\n        max_batch_size=1,\n        cache_max_entry_count=0.7,\n        tp=4)\npipe = pipeline('internlm/internlm2_5-7b-chat-1m', backend_config=backend_config)\n\n\ndef passkey_retrieval(session_len, n_round=5):\n    # create long context input\n    tok = pipe.tokenizer\n    task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.'\n    garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.'\n\n    for _ in range(n_round):\n        start = time.perf_counter()\n        n_times = (session_len - 1000) // len(tok.encode(garbage))\n        n_garbage_prefix = np.random.randint(0, n_times)\n        n_garbage_suffix = n_times - n_garbage_prefix\n        garbage_prefix = ' '.join([garbage] * n_garbage_prefix)\n        garbage_suffix = ' '.join([garbage] * n_garbage_suffix)\n        pass_key = np.random.randint(1, 50000)\n        information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.'  # noqa: E501\n        final_question = 'What is the pass key? The pass key is'\n        lines = [\n            task_description,\n            garbage_prefix,\n            information_line,\n            garbage_suffix,\n            final_question,\n        ]\n\n        # inference\n        prompt = ' '.join(lines)\n        response = pipe([prompt])\n        print(pass_key, response)\n        end = time.perf_counter()\n        print(f'duration: {end - start} s')\n\npasskey_retrieval(session_len, 5)\n```\n\nThis test takes approximately 364 seconds per round on A100-80G GPUs.\n\n### Needle In A Haystack\n\n[OpenCompass](https://github.com/open-compass/opencompass) offers useful tools for needle-in-a-haystack evaluation. 
For specific instructions, please refer to the [guide](https://github.com/open-compass/opencompass/blob/main/docs/en/advanced_guides/needleinahaystack_eval.md).\n\n### Perplexity\n\nThe following code demonstrates how to use LMDeploy to calculate perplexity.\n\n```python\nfrom transformers import AutoTokenizer\nfrom lmdeploy import TurbomindEngineConfig, pipeline\n\n# load model and tokenizer\nmodel_repoid_or_path = 'internlm/internlm2_5-7b-chat-1m'\nbackend_config = TurbomindEngineConfig(\n        rope_scaling_factor=2.5,\n        session_len=1000000,\n        max_batch_size=1,\n        cache_max_entry_count=0.7,\n        tp=4)\npipe = pipeline(model_repoid_or_path, backend_config=backend_config)\ntokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)\n\n# get perplexity\ntext = 'Use a long prompt to replace this sentence'\ninput_ids = tokenizer.encode(text)\nppl = pipe.get_ppl(input_ids)[0]\nprint(ppl)\n```\n"
  },
  {
    "path": "docs/en/advance/metrics.md",
    "content": "# Production Metrics\n\nLMDeploy exposes a set of metrics for Prometheus and provides visualization via Grafana.\n\n## Setup Guide\n\nThis section describes how to set up the monitoring stack (Prometheus + Grafana) provided in the `lmdeploy/monitoring` directory.\n\n## Prerequisites\n\n- [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) installed\n\n- LMDeploy server running with metrics enabled\n\n## Usage (DP = 1)\n\n1. **Start your LMDeploy server with metrics enabled**\n\n```\nlmdeploy serve api_server Qwen/Qwen2.5-7B-Instruct --enable-metrics\n```\n\nReplace the model path according to your needs.\nBy default, the metrics endpoint will be available at `http://<lmdeploy_server_host>:23333/metrics`.\n\n2. **Navigate to the monitoring directory**\n\n```\ncd lmdeploy/monitoring\n```\n\n3. **Start the monitoring stack**\n\n```\ndocker compose up -d\n```\n\nThis command starts Prometheus and Grafana in the background.\n\n4. **Access the monitoring interfaces**\n\n- Prometheus: Open your web browser and go to http://localhost:9090.\n\n- Grafana: Open your web browser and go to http://localhost:3000.\n\n5. **Log in to Grafana**\n\n- Default Username: `admin`\n\n- Default Password: `admin`. You will be prompted to change the password upon your first login.\n\n6. **View the Dashboard**\n\nThe LMDeploy dashboard is pre-configured and should be available automatically.\n\n## Usage (DP > 1)\n\n1. **Start your LMDeploy server with metrics enabled**\n\nAs an example, we use the model `Qwen/Qwen2.5-7B-Instruct` with `DP=2, TP=2`. 
Start the service as follows:\n\n```bash\n# Proxy server\nlmdeploy serve proxy --server-port 8000 --routing-strategy 'min_expected_latency' --serving-strategy Hybrid --log-level INFO\n\n# API server\nLMDEPLOY_DP_MASTER_ADDR=127.0.0.1 \\\nLMDEPLOY_DP_MASTER_PORT=29555 \\\nlmdeploy serve api_server \\\n    Qwen/Qwen2.5-7B-Instruct \\\n    --backend pytorch \\\n    --tp 2 \\\n    --dp 2 \\\n    --proxy-url http://0.0.0.0:8000 \\\n    --nnodes 1 \\\n    --node-rank 0 \\\n    --enable-metrics\n```\n\nYou should see multiple API servers added to the proxy server list. Details can be found in `lmdeploy/serve/proxy/proxy_config.json`.\n\nFor example, you may have the following API servers:\n\n```\nhttp://$host_ip:$api_server_port1\n\nhttp://$host_ip:$api_server_port2\n```\n\n2. **Modify the Prometheus configuration**\n\nWhen `DP > 1`, LMDeploy launches one API server for each DP rank. To monitor a specific API server, e.g. `http://$host_ip:$api_server_port1`, modify the configuration file `lmdeploy/monitoring/prometheus.yaml` as follows.\n\n> Note that you should use the actual host machine IP instead of `127.0.0.1` here, since LMDeploy starts the API server using the actual host IP when `DP > 1`.\n\n```\nglobal:\n  scrape_interval: 5s\n  evaluation_interval: 30s\n\nscrape_configs:\n  - job_name: lmdeploy\n    static_configs:\n      - targets:\n          - '$host_ip:$api_server_port1' # <= Modify this\n```\n\n3. **Navigate to the monitoring folder and perform the same steps as described above**\n\n## Troubleshooting\n\n1. **Port conflicts**\n\nCheck whether any services are occupying ports `23333` (LMDeploy server port), `9090` (Prometheus port), or `3000` (Grafana port). 
You can either stop the services occupying these ports or modify the config files as follows:\n\n- Modify the LMDeploy server port scraped by Prometheus\n\nIn `lmdeploy/monitoring/prometheus.yaml`:\n\n```yaml\nglobal:\n  scrape_interval: 5s\n  evaluation_interval: 30s\n\nscrape_configs:\n  - job_name: lmdeploy\n    static_configs:\n      - targets:\n          - '127.0.0.1:23333' # <= Change 23333 to match the port of the running LMDeploy server\n```\n\n- Modify the Prometheus port\n\nIn `lmdeploy/monitoring/grafana/datasources/datasource.yaml`:\n\n```yaml\napiVersion: 1\ndatasources:\n  - name: Prometheus\n    type: prometheus\n    access: proxy\n    url: http://localhost:9090 # <= Change the Prometheus port 9090 here\n    isDefault: true\n    editable: false\n```\n\n- Modify the Grafana port\n\nIn `lmdeploy/monitoring/docker-compose.yaml`, for example, change the port to `3090`:\n\nOption 1: Add `GF_SERVER_HTTP_PORT` to the environment section.\n\n```yaml\n  environment:\n    - GF_AUTH_ANONYMOUS_ENABLED=true\n    - GF_SERVER_HTTP_PORT=3090  # <= Add this line\n```\n\nOption 2: Use port mapping.\n\n```yaml\ngrafana:\n  image: grafana/grafana:latest\n  container_name: grafana\n  ports:\n  - \"3090:3000\"  # <= Host:Container port mapping\n```\n\n2. **No data on the dashboard**\n\n- Create traffic\n\nSend some requests to the LMDeploy server to generate traffic, for example with the benchmark script:\n\n```bash\npython3 benchmark/profile_restful_api.py --backend lmdeploy --num-prompts 5000 --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json\n```\n\nAfter refreshing, you should be able to see data on the dashboard.\n"
  },
  {
    "path": "docs/en/advance/pytorch_multinodes.md",
    "content": "# PyTorchEngine Multi-Node Deployment Guide\n\nTo support larger-scale model deployment requirements, PyTorchEngine provides multi-node deployment support. Below are the detailed steps for deploying a `tp=16` model across two 8-GPU nodes.\n\n## 1. Create Docker Containers (Optional)\n\nTo ensure consistency across the cluster environment, it is recommended to use Docker to set up the cluster. Create containers on each node as follows:\n\n```bash\ndocker run -it \\\n    --network host \\\n    -v $MODEL_PATH:$CONTAINER_MODEL_PATH \\\n    openmmlab/lmdeploy:latest\n```\n\n> \\[!IMPORTANT\\]\n> Ensure that the model is placed in the same directory on all node containers.\n\n## 2. Set Up the Cluster Using Ray\n\n### 2.1 Start the Head Node\n\nSelect one node as the **head node** and run the following command in its container:\n\n```bash\nray start --head --port=$DRIVER_PORT\n```\n\n### 2.2 Join the Cluster\n\nOn the other nodes, use the following command in their containers to join the cluster created by the head node:\n\n```bash\nray start --address=$DRIVER_NODE_ADDR:$DRIVER_PORT\n```\n\nrun `ray status` on head node to check the cluster.\n\n> \\[!IMPORTANT\\]\n> Ensure that `DRIVER_NODE_ADDR` is the address of the head node and `DRIVER_PORT` matches the port number used during the head node initialization.\n\n## 3. 
Use LMDeploy Interfaces\n\nIn the head node's container, you can use all PyTorchEngine features as usual.\n\n### 3.1 Start the Server\n\n```bash\nlmdeploy serve api_server \\\n    $CONTAINER_MODEL_PATH \\\n    --backend pytorch \\\n    --tp 16\n```\n\n### 3.2 Use the Pipeline\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nif __name__ == '__main__':\n    model_path = '/path/to/model'\n    backend_config = PytorchEngineConfig(tp=16)\n    with pipeline(model_path, backend_config=backend_config) as pipe:\n        outputs = pipe('Hakuna Matata')\n```\n\n> \\[!NOTE\\]\n> PyTorchEngine automatically chooses the appropriate launch method (single-node or multi-node) based on the `tp` parameter and the number of devices available in the cluster. To force the use of the Ray cluster, set `distributed_executor_backend='ray'` in `PytorchEngineConfig` or set the environment variable `LMDEPLOY_EXECUTOR_BACKEND=ray`.\n\n______________________________________________________________________\n\nBy following the steps above, you can deploy PyTorchEngine in a multi-node environment and leverage the Ray cluster for distributed computing.\n\n> \\[!WARNING\\]\n> For better performance, we recommend using a high-bandwidth, low-latency network between nodes (such as [InfiniBand](https://en.wikipedia.org/wiki/InfiniBand)).\n"
  },
  {
    "path": "docs/en/advance/pytorch_multithread.md",
    "content": "# PyTorchEngine Multithread\n\nWe have removed `thread_safe` mode from PytorchEngine since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907). We encourage users to achieve high concurrency by using **service API** or **coroutines** whenever possible, for example:\n\n```python\nimport asyncio\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nevent_loop = asyncio.new_event_loop()\nasyncio.set_event_loop(event_loop)\n\nmodel_path = 'Llama-3.2-1B-Instruct'\npipe = pipeline(model_path, backend_config=PytorchEngineConfig())\n\nasync def _gather_output():\n    tasks = [\n        pipe.async_batch_infer('Hakuna Matata'),\n        pipe.async_batch_infer('giraffes are heartless creatures'),\n    ]\n    return await asyncio.gather(*tasks)\n\noutput = asyncio.run(_gather_output())\nprint(output[0].text)\nprint(output[1].text)\n```\n\nIf you do need multithreading, it would be easy to warp it like below:\n\n```python\nimport threading\nfrom queue import Queue\nimport asyncio\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nmodel_path = 'Llama-3.2-1B-Instruct'\n\n\nasync def _batch_infer(inque: Queue, outque: Queue, pipe):\n    while True:\n        if inque.empty():\n            await asyncio.sleep(0)\n            continue\n\n        input = inque.get_nowait()\n        output = await pipe.async_batch_infer(input)\n        outque.put(output)\n\n\ndef server(inques, outques):\n    event_loop = asyncio.new_event_loop()\n    asyncio.set_event_loop(event_loop)\n    pipe = pipeline(model_path, backend_config=PytorchEngineConfig())\n    for inque, outque in zip(inques, outques):\n        event_loop.create_task(_batch_infer(inque, outque, pipe))\n    event_loop.run_forever()\n\ndef client(inque, outque, message):\n    inque.put(message)\n    print(outque.get().text)\n\n\ninques = [Queue(), Queue()]\noutques = [Queue(), Queue()]\n\nt_server = threading.Thread(target=server, args=(inques, outques))\nt_client0 = threading.Thread(target=client, 
args=(inques[0], outques[0], 'Hakuna Matata'))\nt_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))\n\nt_server.start()\nt_client0.start()\nt_client1.start()\n\nt_client0.join()\nt_client1.join()\n```\n\n> \\[!WARNING\\]\n> This is NOT recommended, as multithreading introduces additional overhead, leading to unstable inference performance.\n"
  },
  {
    "path": "docs/en/advance/pytorch_new_model.md",
    "content": "# lmdeploy.pytorch New Model Support\n\nlmdeploy.pytorch is designed to simplify the support for new models and the development of prototypes. Users can adapt new models according to their own needs.\n\n## Model Support\n\n### Configuration Loading (Optional)\n\nlmdeploy.pytorch initializes the engine based on the model's config file. If the parameter naming of the model to be integrated differs from common models in transformers, parsing errors may occur. A custom ConfigBuilder can be added to parse the configuration.\n\n```python\n# lmdeploy/pytorch/configurations/gemma.py\n\nfrom lmdeploy.pytorch.config import ModelConfig\n\nfrom .builder import AutoModelConfigBuilder\n\n\nclass GemmaModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        # Check if hf_config is suitable for this builder\n        return hf_config.model_type in ['gemma', 'gemma2']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None):\n        # Use the hf_config loaded by transformers\n        # Construct the ModelConfig for the pytorch engine\n        return ModelConfig(hidden_size=hf_config.hidden_size,\n                           num_layers=hf_config.num_hidden_layers,\n                           num_attention_heads=hf_config.num_attention_heads,\n                           num_key_value_heads=hf_config.num_key_value_heads,\n                           bos_token_id=hf_config.bos_token_id,\n                           eos_token_id=hf_config.eos_token_id,\n                           head_dim=hf_config.head_dim,\n                           vocab_size=hf_config.vocab_size)\n```\n\nThe `lmdeploy.pytorch.check_env.check_model` function can be used to verify if the configuration can be parsed correctly.\n\n### Implementing the Model\n\nAfter ensuring that the configuration can be parsed correctly, you can start implementing the model logic. 
Taking the Llama implementation as an example, the model is created from the transformers configuration.\n\n```python\nclass LlamaForCausalLM(nn.Module):\n\n    # Constructor, builds the model with the given config\n    # ctx_mgr is the context manager, which can be used to pass engine configurations or additional parameters\n    def __init__(self,\n                 config: LlamaConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build LlamaModel\n        self.model = LlamaModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    # Model inference function\n    # It is recommended to use the same parameters as below\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n        return logits\n```\n\nIn addition, the following members need to be added:\n\n```python\nclass LlamaForCausalLM(nn.Module):\n\n    ...\n\n    # Indicates whether the model supports cudagraph\n    # Can be a callable 
object, receiving forward inputs\n    # Dynamically determines if cudagraph is supported\n    support_cuda_graph = True\n\n    # Builds model inputs\n    # Returns a dictionary, the keys of which must be inputs to forward\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        ...\n\n    # Loads weights\n    # `weights` yields (name, tensor) pairs of the state dict\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        ...\n```\n\nWe provide many fused operators to simplify model construction. These operators handle features such as tensor parallelism and quantization, so we encourage developers to use them as much as possible.\n\n```python\n# Using the predefined build_merged_colwise_linear, SiluAndMul and build_rowwise_linear\n# helps build the model faster, without worrying about tensor parallelism, quantization, etc.\nclass LlamaMLP(nn.Module):\n\n    def __init__(self,\n                 config: LlamaConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=config.mlp_bias,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(config.intermediate_size,\n                                              config.hidden_size,\n                                              bias=config.mlp_bias,\n    
                                          quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n```\n\n### Model Registration\n\nTo make the new model implementation available to the engine, register it in `lmdeploy/pytorch/models/module_map.py`:\n\n```python\nMODULE_MAP.update({\n    'LlamaForCausalLM':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM',\n})\n```\n\nIf you do not want to modify the LMDeploy source code, you can instead pass a custom module map from outside, which makes it easier to integrate into other projects.\n\n```python\nfrom lmdeploy import PytorchEngineConfig, pipeline\n\nmodel_path = '/path/to/model'\nbackend_config = PytorchEngineConfig(custom_module_map='/path/to/custom/module_map.py')\ngenerator = pipeline(model_path, backend_config=backend_config)\n```\n"
  },
  {
    "path": "docs/en/advance/pytorch_profiling.md",
    "content": "# PyTorchEngine Profiling\n\nWe provide multiple profiler to analysis the performance of PyTorchEngine.\n\n## PyTorch Profiler\n\nWe have integrated the PyTorch Profiler. You can enable it by setting environment variables when launching the pipeline or API server:\n\n```bash\n# enable profile cpu\nexport LMDEPLOY_PROFILE_CPU=1\n# enable profile cuda\nexport LMDEPLOY_PROFILE_CUDA=1\n# profile would start after 3 seconds\nexport LMDEPLOY_PROFILE_DELAY=3\n# profile 10 seconds\nexport LMDEPLOY_PROFILE_DURATION=10\n# prefix path to save profile files\nexport LMDEPLOY_PROFILE_OUT_PREFIX=\"/path/to/save/profile_\"\n```\n\nAfter the program exits, the profiling data will be saved to the path specified by `LMDEPLOY_PROFILE_OUT_PREFIX` for performance analysis.\n\n## Nsight System\n\nWe also support using Nsight System to profile NVIDIA devices.\n\n### Single GPU\n\nFor single-GPU scenarios, simply use `nsys profile`:\n\n```bash\nnsys profile python your_script.py\n```\n\n### Multi-GPU\n\nWhen using multi-GPU solutions like DP/TP/EP, set the following environment variables:\n\n```bash\n# enable nsight system\nexport LMDEPLOY_RAY_NSYS_ENABLE=1\n# prefix path to save profile files\nexport LMDEPLOY_RAY_NSYS_OUT_PREFIX=\"/path/to/save/profile_\"\n```\n\nThen launch the script or API server as usual (Do **NOT** use nsys profile here).\n\nThe profiling results will be saved under `LMDEPLOY_RAY_NSYS_OUT_PREFIX`. If `LMDEPLOY_RAY_NSYS_OUT_PREFIX` is not configured, you can find the results in `/tmp/ray/session_xxx/nsight`.\n\n## Ray timeline\n\nWe use `ray` to support multi-device deployment. You can get the ray timeline with the environments below.\n\n```bash\nexport LMDEPLOY_RAY_TIMELINE_ENABLE=1\nexport LMDEPLOY_RAY_TIMELINE_OUT_PATH=\"/path/to/save/timeline.json\"\n```\n"
  },
  {
    "path": "docs/en/advance/spec_decoding.md",
    "content": "# Speculative Decoding\n\nSpeculative decoding is an optimization technique that introcude a lightweight draft model to propose multiple next tokens and then, the main model verify and choose the longest matched tokens in a forward pass. Compared with standard auto-regressive decoding, this methold lets the system generate multiple tokens at once.\n\n> \\[!NOTE\\]\n> This is an experimental feature in lmdeploy.\n\n## Examples\n\nHere are some examples.\n\n### Eagle 3\n\n#### Prepare\n\nInstall [flash-atten3 ](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release)\n\n```shell\ngit clone --depth=1 https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention/hopper\npython setup.py install\n```\n\n#### pipeline\n\n```python\nfrom lmdeploy import PytorchEngineConfig, pipeline\nfrom lmdeploy.messages import SpeculativeConfig\n\n\nif __name__ == '__main__':\n\n    model_path = 'meta-llama/Llama-3.1-8B-Instruct'\n    spec_cfg = SpeculativeConfig(\n        method='eagle3',\n        num_speculative_tokens=3,\n        model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B',\n    )\n    pipe = pipeline(model_path, backend_config=PytorchEngineConfig(max_batch_size=128), speculative_config=spec_cfg)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n    print(response)\n\n```\n\n#### serving\n\n```shell\nlmdeploy serve api_server \\\nmeta-llama/Llama-3.1-8B-Instruct \\\n--backend pytorch \\\n--server-port 24545 \\\n--speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \\\n--speculative-algorithm eagle3 \\\n--speculative-num-draft-tokens 3 \\\n--max-batch-size 128 \\\n--enable-metrics\n```\n\n### Deepseek MTP\n\n#### Prepare\n\nInstall [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation)\n\n```shell\ngit clone https://github.com/deepseek-ai/FlashMLA.git flash-mla\ncd flash-mla\ngit submodule update --init --recursive\npip install -v .\n```\n\n#### pipeline\n\n```python\nfrom 
lmdeploy import PytorchEngineConfig, pipeline\nfrom lmdeploy.messages import SpeculativeConfig\n\n\nif __name__ == '__main__':\n\n    model_path = 'deepseek-ai/DeepSeek-V3'\n    spec_cfg = SpeculativeConfig(\n        method='deepseek_mtp',\n        num_speculative_tokens=3,\n    )\n    pipe = pipeline(model_path,\n                    backend_config=PytorchEngineConfig(tp=16, max_batch_size=128),\n                    speculative_config=spec_cfg)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n    print(response)\n\n```\n\n#### serving\n\n```shell\nlmdeploy serve api_server \\\ndeepseek-ai/DeepSeek-V3 \\\n--backend pytorch \\\n--server-port 24545 \\\n--tp 16 \\\n--speculative-algorithm deepseek_mtp \\\n--speculative-num-draft-tokens 3 \\\n--max-batch-size 128 \\\n--enable-metrics\n```\n"
  },
  {
    "path": "docs/en/advance/structed_output.md",
    "content": "# Structured output\n\nStructured output, also known as guided decoding, forces the model to generate text that exactly matches a user-supplied JSON schema, grammar, or regex.\nBoth the PyTorch and Turbomind backends now support structured (schema-constrained) generation.\nBelow are examples for the pipeline API and the API server.\n\n## pipeline\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.messages import GenerationConfig, PytorchEngineConfig\n\nmodel = 'internlm/internlm2-chat-1_8b'\nguide = {\n    'type': 'object',\n    'properties': {\n        'name': {\n            'type': 'string'\n        },\n        'skills': {\n            'type': 'array',\n            'items': {\n                'type': 'string',\n                'maxLength': 10\n            },\n            'minItems': 3\n        },\n        'work history': {\n            'type': 'array',\n            'items': {\n                'type': 'object',\n                'properties': {\n                    'company': {\n                        'type': 'string'\n                    },\n                    'duration': {\n                        'type': 'string'\n                    }\n                },\n                'required': ['company']\n            }\n        }\n    },\n    'required': ['name', 'skills', 'work history']\n}\npipe = pipeline(model, backend_config=PytorchEngineConfig(), log_level='INFO')\ngen_config = GenerationConfig(\n    response_format=dict(type='json_schema', json_schema=dict(name='test', schema=guide)))\nresponse = pipe(['Make a self introduction please.'], gen_config=gen_config)\nprint(response)\n```\n\n## api_server\n\nFirstly, start the api_server service for the InternLM2 model.\n\n```shell\nlmdeploy serve api_server internlm/internlm2-chat-1_8b --backend pytorch\n```\n\nThe client can test using OpenAI’s python package: The output result is a response in JSON format.\n\n```python\nfrom openai import OpenAI\nguide = {\n    'type': 'object',\n    
'properties': {\n        'name': {\n            'type': 'string'\n        },\n        'skills': {\n            'type': 'array',\n            'items': {\n                'type': 'string',\n                'maxLength': 10\n            },\n            'minItems': 3\n        },\n        'work history': {\n            'type': 'array',\n            'items': {\n                'type': 'object',\n                'properties': {\n                    'company': {\n                        'type': 'string'\n                    },\n                    'duration': {\n                        'type': 'string'\n                    }\n                },\n                'required': ['company']\n            }\n        }\n    },\n    'required': ['name', 'skills', 'work history']\n}\nresponse_format = dict(type='json_schema', json_schema=dict(name='test', schema=guide))\nmessages = [{'role': 'user', 'content': 'Make a self-introduction please.'}]\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    response_format=response_format,\n    top_p=0.8)\nprint(response)\n```\n"
  },
  {
    "path": "docs/en/advance/update_weights.md",
    "content": "# Update Weights\n\nLMDeploy supports update model weights online for scenes such as RL training. Here are the steps to do so.\n\n## Step 1: Launch server\n\nFor pytorch backend you have to add `--distributed-executor-backend ray`.\n\n```shell\nlmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray # for pytorch backend\n```\n\n## Step 2: Offloads weights & kv cache\n\nBefore update model weights, the server should offloads weights and kv cache.\n\n```python\nfrom lmdeploy.utils import serialize_state_dict\nimport requests\n\nBASE_URL = 'http://0.0.0.0:23333'\napi_key = 'sk-xxx'\n\nheaders = {\n                \"Content-Type\": \"application/json\",\n                \"Authorization\": f\"Bearer {api_key}\",\n            }\n\n# offloads weights and kv cache with level=2\nresponse = requests.post(f\"{BASE_URL}/sleep\", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))\nassert response.status_code == 200, response.status_code\n\n# wake up weights, the server is ready for update weights\nresponse = requests.post(f\"{BASE_URL}/wakeup\", headers=headers, params=dict(tags=['weights']))\nassert response.status_code == 200, response.status_code\n```\n\n## Step 3: Update weights\n\nSplit model weights into multi segments and update through `update_weights` endpoint.\n\n```python\nsegmented_state_dict: List[Dict[str, torch.Tensor]] = ...\nnum_segment = len(segmented_state_dict)\nfor seg_idx in range(num_segment):\n    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])\n    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment-1)\n    response = requests.post(f\"{BASE_URL}/update_weights\", headers=headers, json=data)\n    assert response.status_code == 200, f\"response.status_code = {response.status_code}\"\n\n```\n\n**Note**: For pytorch backend, lmdeploy also supports flattened bucket tensors:\n\n```python\nfrom lmdeploy.utils import 
serialize_state_dict, FlattenedTensorBucket, FlattenedTensorMetadata\nfrom typing import Dict, List\n\nimport torch\n\nsegmented_state_dict: List[Dict[str, torch.Tensor]] = ...\nnum_segment = len(segmented_state_dict)\nfor seg_idx in range(num_segment):\n    named_tensors = list(segmented_state_dict[seg_idx].items())\n    bucket = FlattenedTensorBucket(named_tensors=named_tensors)\n    metadata = bucket.get_metadata()\n    flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata)\n    serialized_data = serialize_state_dict(flattened_tensor_data)\n    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment-1, load_format='flattened_bucket')\n    response = requests.post(f\"{BASE_URL}/update_weights\", headers=headers, json=data)\n    assert response.status_code == 200, f\"response.status_code = {response.status_code}\"\n```\n\n## Step 4: Wake up the server\n\nAfter updating the model weights, wake up the KV cache so that the server resumes serving with the new weights.\n\n```python\nresponse = requests.post(f\"{BASE_URL}/wakeup\", headers=headers, params=dict(tags=['kv_cache']))\nassert response.status_code == 200, response.status_code\n```\n"
  },
  {
    "path": "docs/en/api/cli.rst",
    "content": "Command-line Tools\n===================\n\n.. sphinx_argparse_cli::\n   :module: lmdeploy.cli\n   :func: run\n   :hook:\n   :prog: lmdeploy\n"
  },
  {
    "path": "docs/en/api/openapi.rst",
    "content": "OpenAPI Endpoints\n==================\n.. currentmodule:: lmdeploy\n\nOpenAI Compatible API Endpoints\n-------------------------------\n\n.. openapi:: ../_static/openai.yaml\n    :request:\n    :examples:\n\n\n\nProxy Server API\n----------------\n\n.. openapi:: ../_static/proxy.yaml\n    :request:\n    :examples:\n"
  },
  {
    "path": "docs/en/api/pipeline.rst",
    "content": "Inference pipeline\n==================\n.. currentmodule:: lmdeploy\n\nPipeline\n--------\n.. autofunction:: pipeline\n.. autoclass:: Pipeline\n   :undoc-members:\n   :show-inheritance:\n   :members: __init__, infer, stream_infer, chat, get_ppl\n   :member-order: bysource\n\nConfig\n-------------------\n.. autoclass:: PytorchEngineConfig\n.. autoclass:: TurbomindEngineConfig\n.. autoclass:: GenerationConfig\n.. autoclass:: ChatTemplateConfig\n"
  },
  {
    "path": "docs/en/benchmark/a100_fp16.md",
    "content": "# TurboMind Benchmark on A100\n\nAll the following results are tested on A100-80G(x8) CUDA 11.8.\n\nThe tested lmdeploy version is `v0.2.0`\n\n## Request Throughput Benchmark\n\n- `batch`: the max batch size during inference\n- `tp`: the number of GPU cards for tensor parallelism\n- `num_prompts`: the number of prompts, i.e. the number of requests\n- `PRS`: **R**equest **P**er **S**econd\n- `FTL`: **F**irst **T**oken **L**atency\n\n### FP16\n\n| model        | batch | tp  | num_promts | RPS    | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) | throughput(out tok/s) | throughput(total tok/s) |\n| ------------ | ----- | --- | ---------- | ------ | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ | --------------------- | ----------------------- |\n| llama2-7b    | 256   | 1   | 3000       | 14.556 | 0.526       | 0.092       | 4.652       | 0.066  | 0.101  | 0.155  | 0.220  | 3387.419              | 6981.159                |\n| llama2-13b   | 128   | 1   | 3000       | 7.950  | 0.352       | 0.075       | 4.193       | 0.051  | 0.067  | 0.138  | 0.202  | 1850.145              | 3812.978                |\n| internlm-20b | 128   | 2   | 3000       | 10.291 | 0.287       | 0.073       | 3.845       | 0.053  | 0.072  | 0.113  | 0.161  | 2053.266              | 4345.057                |\n| llama2-70b   | 256   | 4   | 3000       | 7.231  | 1.075       | 0.139       | 14.524      | 0.102  | 0.153  | 0.292  | 0.482  | 1682.738              | 3467.969                |\n\n## Static Inference Benchmark\n\n- `batch`: the max batch size during inference\n- `tp`: the number of GPU cards for tensor parallelism\n- `prompt_tokens`: the number of input tokens\n- `output_tokens`: the number of generated tokens\n- `throughput`: the number of generated tokens per second\n- `FTL`: **F**irst **T**oken **L**atency\n\n### FP16 llama2-7b\n\n| batch | tp  | prompt_tokens | output_tokens | throughput(out tok/s) | mem(GB) | 
FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | 50%(s) | 75%(s) | 95%(s) | 99%(s) |\n| ----- | --- | ------------- | ------------- | --------------------- | ------- | ----------- | ----------- | ----------- | ------ | ------ | ------ | ------ |\n| 1     | 1   | 1             | 128           | 100.02                | 76.55   | 0.011       | 0.01        | 0.011       | 0.009  | 0.009  | 0.01   | 0.011  |\n| 1     | 1   | 128           | 128           | 102.21                | 76.59   | 0.022       | 0.022       | 0.022       | 0.01   | 0.01   | 0.01   | 0.01   |\n| 1     | 1   | 128           | 2048          | 98.92                 | 76.59   | 0.022       | 0.022       | 0.022       | 0.01   | 0.01   | 0.01   | 0.01   |\n| 1     | 1   | 2048          | 128           | 86.1                  | 76.77   | 0.139       | 0.139       | 0.14        | 0.01   | 0.01   | 0.01   | 0.011  |\n| 1     | 1   | 2048          | 2048          | 93.78                 | 76.77   | 0.14        | 0.139       | 0.141       | 0.011  | 0.011  | 0.011  | 0.011  |\n| 16    | 1   | 1             | 128           | 1504.72               | 76.59   | 0.021       | 0.011       | 0.031       | 0.01   | 0.011  | 0.011  | 0.013  |\n| 16    | 1   | 128           | 128           | 1272.47               | 76.77   | 0.129       | 0.023       | 0.149       | 0.011  | 0.011  | 0.012  | 0.014  |\n| 16    | 1   | 128           | 2048          | 1010.62               | 76.77   | 0.13        | 0.023       | 0.144       | 0.015  | 0.018  | 0.02   | 0.021  |\n| 16    | 1   | 2048          | 128           | 348.87                | 78.3    | 2.897       | 0.143       | 3.576       | 0.02   | 0.021  | 0.022  | 0.025  |\n| 16    | 1   | 2048          | 2048          | 601.63                | 78.3    | 2.678       | 0.142       | 3.084       | 0.025  | 0.028  | 0.03   | 0.031  |\n| 32    | 1   | 1             | 128           | 2136.73               | 76.62   | 0.079       | 0.014       | 0.725       | 0.011  | 0.012  | 0.013  
| 0.021  |\n| 32    | 1   | 128           | 128           | 2125.47               | 76.99   | 0.214       | 0.022       | 0.359       | 0.012  | 0.013  | 0.014  | 0.035  |\n| 32    | 1   | 128           | 2048          | 1462.12               | 76.99   | 0.2         | 0.026       | 0.269       | 0.021  | 0.026  | 0.031  | 0.033  |\n| 32    | 1   | 2048          | 128           | 450.43                | 78.3    | 4.288       | 0.143       | 5.267       | 0.031  | 0.032  | 0.034  | 0.161  |\n| 32    | 1   | 2048          | 2048          | 733.34                | 78.34   | 4.118       | 0.19        | 5.429       | 0.04   | 0.045  | 0.05   | 0.053  |\n| 64    | 1   | 1             | 128           | 4154.81               | 76.71   | 0.042       | 0.013       | 0.21        | 0.012  | 0.018  | 0.028  | 0.041  |\n| 64    | 1   | 128           | 128           | 3024.07               | 77.43   | 0.44        | 0.026       | 1.061       | 0.014  | 0.018  | 0.026  | 0.158  |\n| 64    | 1   | 128           | 2048          | 1852.06               | 77.96   | 0.535       | 0.027       | 1.231       | 0.03   | 0.041  | 0.048  | 0.053  |\n| 64    | 1   | 2048          | 128           | 493.46                | 78.4    | 6.59        | 0.142       | 16.235      | 0.046  | 0.049  | 0.055  | 0.767  |\n| 64    | 1   | 2048          | 2048          | 755.65                | 78.4    | 39.105      | 0.142       | 116.285     | 0.047  | 0.049  | 0.051  | 0.207  |\n"
  },
  {
    "path": "docs/en/benchmark/benchmark.md",
    "content": "# Benchmark\n\nPlease install the lmdeploy precompiled package and download the script and the test dataset:\n\n```shell\npip install lmdeploy\n# clone the repo to get the benchmark script\ngit clone --depth=1 https://github.com/InternLM/lmdeploy\ncd lmdeploy\n# switch to the tag corresponding to the installed version:\ngit fetch --tags\n# Check the installed lmdeploy version:\npip show lmdeploy | grep Version\n# Then, check out the corresponding tag (replace <version> with the version string):\ngit checkout <version>\n# download the test dataset\nwget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json\n```\n\n## Benchmark offline pipeline API\n\n```shell\npython3 benchmark/profile_pipeline_api.py ShareGPT_V3_unfiltered_cleaned_split.json meta-llama/Meta-Llama-3-8B-Instruct\n```\n\nFor a comprehensive list of available arguments, please execute `python3 benchmark/profile_pipeline_api.py -h`\n\n## Benchmark offline engine API\n\n```shell\npython3 benchmark/profile_throughput.py ShareGPT_V3_unfiltered_cleaned_split.json meta-llama/Meta-Llama-3-8B-Instruct\n```\n\nDetailed argument specification can be retrieved by running `python3 benchmark/profile_throughput.py -h`\n\n## Benchmark online serving\n\nLaunch the server first (you may refer [here](../llm/api_server.md) for guide) and run the following command:\n\n```shell\npython3 benchmark/profile_restful_api.py --backend lmdeploy --num-prompts 5000 --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json\n```\n\nFor detailed argument specification of `profile_restful_api.py`, please run the help command `python3 benchmark/profile_restful_api.py -h`.\n"
  },
  {
    "path": "docs/en/benchmark/evaluate_with_opencompass.md",
    "content": "# Model Evaluation Guide\n\nThis document describes how to evaluate a model's capabilities on academic datasets using OpenCompass and LMDeploy. The complete evaluation process consists of two main stages: the inference stage and the evaluation stage.\n\nDuring the inference stage, the target model is first deployed as an inference service using LMDeploy. OpenCompass then sends dataset content as requests to this service and collects the generated responses.\n\nIn the evaluation stage, the OpenCompass evaluation model `opencompass/CompassVerifier-32B` is deployed as a service via LMDeploy. OpenCompass subsequently submits the inference results to this service to obtain the final evaluation scores.\n\nIf sufficient computational resources are available, please refer to the [End-to-End Evaluation](#end-to-end-evaluation) section to run the complete workflow. Otherwise, we recommend following the [Step-by-Step Evaluation](#step-by-step-evaluation) section to execute both stages sequentially.\n\n## Environment Setup\n\n```shell\npip install lmdeploy\npip install \"opencompass[full]\"\n\n# Download the lmdeploy source code, which is used in subsequent steps to access the evaluation script and configuration\ngit clone --depth=1 https://github.com/InternLM/lmdeploy.git\n```\n\nIt is recommended to install LMDeploy and OpenCompass in separate Python virtual environments to avoid potential dependency conflicts.\n\n## End-to-End Evaluation\n\n1. **Deploy Target Model**\n\n```shell\nlmdeploy serve api_server <model_path> --server-port 10000 <--other-options>\n```\n\n2. **Deploy Evaluation Model (Judger)**\n\n```shell\nlmdeploy serve api_server opencompass/CompassVerifier-32B --server-port 20000 --tp 2\n```\n\n3. **Generate Evaluation Configuration and Execute**\n\n```shell\n\ncd {the/root/path/of/lmdeploy/repo}\n\n## Specify the dataset path. 
OC will download the datasets automatically if they are\n## not found in the path\nexport HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets\nexport COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache\npython eval/eval.py {task_name} \\\n    --mode all \\\n    --api-server http://{api-server-ip}:10000 \\\n    --judger-server http://{judger-server-ip}:20000 \\\n    -w {oc_output_dir}\n```\n\nFor detailed usage instructions about `eval.py`, such as specifying evaluation datasets, please run `python eval/eval.py --help`.\n\nAfter evaluation completion, results are saved in `{oc_output_dir}/{yyyymmdd_hhmmss}`, where `{yyyymmdd_hhmmss}` represents the task timestamp.\n\n## Step-by-Step Evaluation\n\n### Inference Stage\n\nThis stage generates model responses for the dataset.\n\n1. **Deploy Target Model**\n\n```shell\nlmdeploy serve api_server <model_path> --server-port 10000 <--other-options>\n```\n\n2. **Generate Inference Configuration and Execute**\n\n```shell\ncd {the/root/path/of/lmdeploy/repo}\n\n## Specify the dataset path. OC will download the datasets automatically if they are\n## not found in the path\nexport COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache\nexport HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets\n# Run inference task\npython eval/eval.py {task_name} \\\n    --mode infer \\\n    --api-server http://{api-server-ip}:10000 \\\n    -w {oc_output_dir}\n```\n\nFor detailed usage instructions about `eval.py`, such as specifying evaluation datasets, please run `python eval/eval.py --help`.\n\n### Evaluation Stage\n\nThis stage uses the evaluation model (Judger) to assess the quality of inference results.\n\n1. **Deploy Evaluation Model (Judger)**\n\n```shell\nlmdeploy serve api_server opencompass/CompassVerifier-32B --server-port 20000 --tp 2 --session-len 65536\n```\n\n2. **Generate Evaluation Configuration and Execute**\n\n```shell\ncd {the/root/path/of/lmdeploy/repo}\n\n## Specify the dataset path. 
OC will download the datasets automatically if they are\n## not found in the path\nexport COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache\nexport HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets\n# Run evaluation task\nopencompass /path/to/judger_config.py -m eval -w {oc_output_dir} -r {yyyymmdd_hhmmss}\n```\n\nImportant Notes:\n\n- `task_name` must be identical to the one used in the inference stage\n- The `oc_output_dir` specified with `-w` must match the directory used in the inference stage\n- The `-r` parameter indicates \"previous outputs & results\" and should specify the timestamp directory generated during the inference stage (the subdirectory under `{oc_output_dir}`)\n\nFor detailed usage instructions about `eval.py`, such as specifying evaluation datasets, please run `python eval/eval.py --help`.\n"
  },
  {
    "path": "docs/en/benchmark/evaluate_with_vlmevalkit.md",
    "content": "# Multi-Modal Model Evaluation Guide\n\nThis document describes how to evaluate multi-modal models' capabilities using VLMEvalKit and LMDeploy.\n\n## Environment Setup\n\n```shell\npip install lmdeploy\n\ngit clone https://github.com/open-compass/VLMEvalKit.git\ncd VLMEvalKit && pip install -e .\n```\n\nIt is recommended to install LMDeploy and VLMEvalKit in separate Python virtual environments to avoid potential dependency conflicts.\n\n## Evaluations\n\n1. **Deploy Large Multi-Modality Models (LMMs)**\n\n```shell\nlmdeploy serve api_server <model_path> --server-port 23333 <--other-options>\n```\n\n2. **Configure the Evaluation Settings**\n\nModify `VLMEvalKit/vlmeval/config.py` by adding the following LMDeploy API configuration to the `api_models` dictionary.\n\nThe `<task_name>` is a custom name for your evaluation task (e.g., `lmdeploy_qwen3vl-4b`). The `model` parameter should match the `<model_path>` used in the `lmdeploy serve` command.\n\n```python\n# filepath: VLMEvalKit/vlmeval/config.py\n# ...existing code...\napi_models = {\n    # ...other API models...\n    # lmdeploy api\n    \"<task_name>\": partial(\n        LMDeployAPI,\n        api_base=\"http://0.0.0.0:23333/v1/chat/completions\",\n        model=\"<model_path>\",\n        retry=4,\n        timeout=1200,\n        temperature=0.7,  # modify if needed\n        max_new_tokens=16384,  # modify if needed\n    ),\n    # ...other API models...\n}\n# ...existing code...\n```\n\n3. 
**Start Evaluations**\n\n```shell\ncd VLMEvalKit\npython run.py --data OCRBench --model <task_name> --api-nproc 16 --reuse --verbose --api 123\n```\n\nThe `<task_name>` should match the one used in the above config file.\n\nParameter explanations:\n\n- `--data`: Specify the dataset for evaluation (e.g., `OCRBench`).\n- `--model`: Specify the model name, which must match the `<task_name>` in your `config.py`.\n- `--api-nproc`: Specify the number of parallel API calls.\n- `--reuse`: Reuse previous inference results to avoid re-running completed evaluations.\n- `--verbose`: Enable verbose logging.\n"
  },
  {
    "path": "docs/en/conf.py",
    "content": "#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a selection of the most common options. For a\n# full list see the documentation:\n# http://www.sphinx-doc.org/en/master/config\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\nfrom pathlib import Path\n\nfrom fastapi import FastAPI\nfrom fastapi.responses import Response\nfrom yaml import safe_dump\n\nsys.path.insert(0, os.path.abspath('../..'))\n\nfrom lmdeploy.serve.openai.api_server import router  # noqa: E402\nfrom lmdeploy.serve.proxy.proxy import app as proxy_server  # noqa: E402\n\nversion_file = '../../lmdeploy/version.py'\nwith open(version_file, 'r') as f:\n    exec(compile(f.read(), version_file, 'exec'))\n__version__ = locals()['__version__']\n\n# -- Project information -----------------------------------------------------\n\nproject = 'lmdeploy'\ncopyright = '2021-2024, OpenMMLab'\nauthor = 'LMDeploy Authors'\n\n# The short X.Y version\nversion = __version__\n# The full version, including alpha/beta/rc tags\nrelease = __version__\n\n# -- Generate OpenAPI Spec -----------------------------------------------------\n\nopenai_server = FastAPI()\nopenai_server.include_router(router)\n\n\n@openai_server.get('/metrics',\n                   response_class=Response,\n                   responses={\n                       200: {\n                           'content': {\n                               'text/plain': {}\n                           },\n                           'description': 'Prometheus metrics data'\n                       },\n                       404: {\n                           'description': 'Metrics Endpoint not enabled'\n         
              }\n                   })\ndef metrics():\n    \"\"\"**[Optional]** Prometheus metrics endpoint.\"\"\"\n    pass\n\n\nspec_dir = Path('_static')\nspec_dir.mkdir(exist_ok=True)\n\nwith open(spec_dir / 'openai.yaml', 'w', encoding='utf-8') as f:\n    f.write(safe_dump(openai_server.openapi()))\n\nwith open(spec_dir / 'proxy.yaml', 'w', encoding='utf-8') as f:\n    f.write(safe_dump(proxy_server.openapi()))\n\n# -- General configuration ---------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\n\nextensions = [\n    'myst_parser',\n    'sphinx_argparse_cli',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.autosectionlabel',\n    'sphinx.ext.autosummary',\n    'sphinx.ext.intersphinx',\n    'sphinx.ext.napoleon',\n    'sphinx.ext.viewcode',\n    'sphinx_autodoc_typehints',\n    'sphinx_copybutton',\n    'sphinx_tabs.tabs',\n    'sphinxcontrib.mermaid',\n    'sphinxcontrib.openapi',\n]  # yapf: disable\n\n\nautosectionlabel_prefix_document = True\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\nsource_suffix = {\n    '.rst': 'restructuredtext',\n    '.md': 'markdown',\n}\n\n# The master toctree document.\nmaster_doc = 'index'\n\n# The language for content autogenerated by Sphinx. 
Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = 'en'\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = 'sphinx'\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\n# html_theme = 'sphinx_rtd_theme'\nhtml_theme = 'sphinx_book_theme'\nhtml_logo = '_static/image/lmdeploy-logo.svg'\nhtml_title = project\nhtml_copy_source = True\nhtml_last_updated_fmt = ''\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  
For a list of options available for each theme, see the\n# documentation.\n#\nhtml_theme_options = {\n    'path_to_docs': 'docs/en',\n    'repository_url': 'https://github.com/InternLM/lmdeploy',\n    'repository_branch': 'main',\n    # 'show_navbar_depth': 3,\n    # 'navigation_depth': 4,\n    # 'collapse_navigation': False,\n    'use_edit_page_button': True,\n    'use_source_button': True,\n    'use_issues_button': True,\n    'use_repository_button': True,\n    'use_download_button': True,\n    'use_sidenotes': True,\n    # 'show_toc_level': 2,\n    # \"icon_links\": [\n    #     {\n    #         \"name\": \"切换至简体中文\",\n    #         \"url\": \"https://lmdeploy.readthedocs.io/en/latest\",\n    #         \"icon\": \"https://img.shields.io/badge/Doc-%E7%AE%80%E4%BD%93%E4%B8%AD%E6%96%87-blue\", # noqa: #501\n    #         \"type\": \"url\",\n    #     },\n    # ],\n}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\nhtml_css_files = ['css/readthedocs.css']\n\n# Enable ::: for my_st\nmyst_enable_extensions = [\n    'dollarmath',\n    'amsmath',\n    'deflist',\n    # \"html_admonition\",\n    # \"html_image\",\n    'colon_fence',\n    # \"smartquotes\",\n    # \"replacements\",\n    # \"linkify\",\n    # \"substitution\",\n]\nmyst_heading_anchors = 5\n\n# Custom sidebar templates, must be a dictionary that maps document names\n# to template names.\n#\n# The default sidebars (for documents that don't match any pattern) are\n# defined by theme itself.  
Builtin themes are using these templates by\n# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',\n# 'searchbox.html']``.\n#\n# html_sidebars = {}\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = 'lmdeploydoc'\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, 'lmdeploy.tex', 'lmdeploy Documentation', 'LMDeploy Contributors', 'manual'),\n]\n\n# -- Options for manual page output ------------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [(master_doc, 'lmdeploy', 'lmdeploy Documentation', [author], 1)]\n\n# -- Options for Texinfo output ----------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (master_doc, 'lmdeploy', 'lmdeploy Documentation', author, 'lmdeploy', 'One line description of project.',\n     'Miscellaneous'),\n]\n\n# -- Options for Epub output -------------------------------------------------\n\n# Bibliographic Dublin Core info.\nepub_title = project\n\n# The unique identifier of the text. 
This can be a ISBN number\n# or the project homepage.\n#\n# epub_identifier = ''\n\n# A unique identification for the text.\n#\n# epub_uid = ''\n\n# A list of files that should not be packed into the epub file.\nepub_exclude_files = ['search.html']\n\n# -- Extension configuration -------------------------------------------------\n# Ignore >>> when copying code\ncopybutton_prompt_text = r'>>> |\\.\\.\\. '\ncopybutton_prompt_is_regexp = True\n\nautodoc_preserve_defaults = True\nnavigation_with_keys = False\n\n# Mock out external dependencies here,\n# otherwise the autodoc pages may be blank.\nautodoc_mock_imports = [\n    'torch',\n    'torchvision',\n    'transformers',\n    '_turbomind',\n    'triton',\n]\n\nautodoc_type_aliases = {'PydanticDataclass': 'pydantic.dataclasses.PydanticDataclass'}\n\nintersphinx_mapping = {\n    'python': ('https://docs.python.org/3.10', None),\n    'typing_extensions': ('https://typing-extensions.readthedocs.io/en/latest', None),\n    'pillow': ('https://pillow.readthedocs.io/en/stable', None),\n    'numpy': ('https://numpy.org/doc/stable', None),\n    'torch': ('https://pytorch.org/docs/stable', None),\n    'torchvision': ('https://pytorch.org/vision/stable', None),\n}\n"
  },
  {
    "path": "docs/en/faq.md",
    "content": "# FAQ\n\n## ModuleNotFoundError\n\n### No module named 'mmengine.config.lazy'\n\nThere is probably an outdated, cached mmengine on your local host. Try upgrading it to the latest version.\n\n```shell\npip install --upgrade mmengine\n```\n\n### No module named '\\_turbomind'\n\nThis may be caused by one of the following reasons.\n\n1. You haven't installed lmdeploy's precompiled package. `_turbomind` is the pybind package of the C++ TurboMind engine and requires compilation, so installing the precompiled package is recommended.\n\n```shell\npip install lmdeploy[all]\n```\n\n2. If you have installed it and still encounter this issue, you are probably running turbomind-related commands from the root directory of the lmdeploy source code. Switching to another directory will fix it.\n\nIf you are a developer, you often need to develop and compile locally, and reinstalling the wheel after every build is inefficient. Instead, you can point lmdeploy to the compiled libraries through a symbolic link.\n\n```shell\n# mkdir and build locally\nmkdir bld && cd bld && bash ../generate.sh && ninja -j$(nproc)\n\n# go to the lmdeploy subdirectory from bld and create the symbolic link\ncd ../lmdeploy && ln -s ../bld/lib .\n\n# go to the lmdeploy root directory\ncd ..\n\n# run a python command such as check_env\npython3 -m lmdeploy check_env\n```\n\nIf the turbomind shared library still cannot be found, there may be multiple Python environments on your local machine, and the Python version used for compilation may not match the one used at runtime. In this case, set `PYTHON_EXECUTABLE` in `lmdeploy/generate.sh` to match your environment, such as `-DPYTHON_EXECUTABLE=/usr/local/bin/python3`. 
Then recompile lmdeploy.\n\n## Libs\n\n### libnccl.so.2 not found\n\nMake sure you have installed lmdeploy (>=v0.0.5) through `pip install lmdeploy[all]`.\n\nIf the issue persists after installing lmdeploy, add the directory containing `libnccl.so.2` to the `LD_LIBRARY_PATH` environment variable.\n\n```shell\n# Get the location of the nvidia-nccl-cu11 package\npip show nvidia-nccl-cu11 | grep Location\n# prepend the directory containing \"libnccl.so.2\" to LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH={Location}/nvidia/nccl/lib:$LD_LIBRARY_PATH\n```\n\n### symbol cudaFreeAsync version libcudart.so.11.0 not defined in file libcudart.so.11.0 with link time reference\n\nThis is probably caused by an outdated CUDA toolkit. The LMDeploy runtime requires CUDA 11.2 or later.\n\n## Inference\n\n### RuntimeError: \\[TM\\]\\[ERROR\\] CUDA runtime error: out of memory /workspace/lmdeploy/src/turbomind/utils/allocator.h\n\nThis is usually due to a disproportionately large memory ratio for the k/v cache, which is controlled by `TurbomindEngineConfig.cache_max_entry_count`.\nThe meaning of this parameter varies slightly across lmdeploy versions. For specifics, please refer to the [detailed notes](https://github.com/InternLM/lmdeploy/blob/52419bd5b6fb419a5e3aaf3c3b4dea874b17e094/lmdeploy/messages.py#L107) in the source code.\n\nIf you encounter this issue while using the pipeline interface, reduce `cache_max_entry_count` in `TurbomindEngineConfig` as follows:\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(cache_max_entry_count=0.2)\n\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'])\nprint(response)\n```\n\nIf OOM occurs when you run CLI tools, pass `--cache-max-entry-count` to decrease the k/v cache memory ratio. 
For example:\n\n```shell\n# chat command\nlmdeploy chat internlm/internlm2_5-7b-chat --cache-max-entry-count 0.2\n\n# server command\nlmdeploy serve api_server internlm/internlm2_5-7b-chat --cache-max-entry-count 0.2\n```\n\n## Serve\n\n### API Server Fetch Timeout\n\nThe image URL fetch timeout for the API server can be configured via the environment variable `LMDEPLOY_FETCH_TIMEOUT`.\nBy default, requests may take up to 10 seconds before timing out. See [lmdeploy/vl/utils.py](https://github.com/InternLM/lmdeploy/blob/7b6876eafcb842633e0efe8baabe5906d7beeeea/lmdeploy/vl/utils.py#L31) for usage.\n\n## Quantization\n\n### RuntimeError: \\[enforce fail at inline_container.cc:337\\] . unexpected pos 4566829760 vs 4566829656\n\nPlease check your disk space. This error is caused by insufficient disk space when saving weights, which may happen when quantizing a 70B model.\n\n### ModuleNotFoundError: No module named 'flash_attn'\n\nQuantizing `qwen` requires `flash-attn`. However, based on feedback from community users, `flash-attn` can be challenging to install. Therefore, we have removed it from lmdeploy's dependencies and recommend that users install it manually as needed.\n"
  },
  {
    "path": "docs/en/get_started/ascend/get_started.md",
    "content": "# Get Started with Huawei Ascend\n\nWe currently support running lmdeploy on **Atlas 800T A3, Atlas 800T A2 and Atlas 300I Duo**.\nThe usage of lmdeploy on a Huawei Ascend device is almost the same as its usage on CUDA with PytorchEngine in lmdeploy.\nPlease read the original [Get Started](../get_started.md) guide before reading this tutorial.\n\nHere is the [supported model list](../../supported_models/supported_models.md#PyTorchEngine-on-Other-Platforms).\n\n> \\[!IMPORTANT\\]\n> We have uploaded Docker images for KUNPENG CPUs to Aliyun.\n> Pull the image for your device with the corresponding command:\n>\n> Atlas 800T A3:\n>\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a3-latest`\n>\n> (Atlas 800T A3 currently supports only Qwen-series models in eager mode.)\n>\n> Atlas 800T A2:\n>\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest`\n>\n> Atlas 300I Duo:\n>\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:300i-duo-latest`\n>\n> (Atlas 300I Duo currently works only with graph mode.)\n>\n> To build the environment yourself, refer to the Dockerfiles [here](../../../../docker).\n\n## Offline batch inference\n\n### LLM inference\n\nSet `device_type=\"ascend\"` in the `PytorchEngineConfig`:\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy import PytorchEngineConfig\npipe = pipeline(\"internlm/internlm2_5-7b-chat\",\n        backend_config=PytorchEngineConfig(tp=1, device_type=\"ascend\"))\nquestion = [\"Shanghai is\", \"Please introduce China\", \"How are you?\"]\nresponse = pipe(question)\nprint(response)\n```\n\n### VLM inference\n\nSet `device_type=\"ascend\"` in the `PytorchEngineConfig`:\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\nfrom lmdeploy.vl import load_image\npipe = pipeline('OpenGVLab/InternVL2-2B',\n        backend_config=PytorchEngineConfig(tp=1, 
device_type='ascend'))\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## Online serving\n\n### Serve a LLM model\n\nAdd `--device ascend` to the serve command.\n\n```bash\nlmdeploy serve api_server --backend pytorch --device ascend internlm/internlm2_5-7b-chat\n```\n\nRun the following command to launch a Docker container for LLM serving:\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device ascend internlm/internlm2_5-7b-chat\"\n```\n\n### Serve a VLM model\n\nAdd `--device ascend` to the serve command.\n\n```bash\nlmdeploy serve api_server --backend pytorch --device ascend OpenGVLab/InternVL2-2B\n```\n\nRun the following command to launch a Docker container for VLM serving:\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device ascend OpenGVLab/InternVL2-2B\"\n```\n\n## Inference with Command Line Interface\n\nAdd `--device ascend` to the chat command.\n\n```bash\nlmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device ascend\n```\n\nRun the following command to start an lmdeploy chat session in a container:\n\n```bash\ndocker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \\\n    bash -i -c \"lmdeploy chat --backend pytorch --device ascend internlm/internlm2_5-7b-chat\"\n```\n\n## Quantization\n\n### w4a16 AWQ\n\nRun the following command to quantize weights on Atlas 800T A2.\n\n```bash\nlmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu\n```\n\nPlease check [supported_models](../../supported_models/supported_models.md) before using 
this feature.\n\n### w8a8 SMOOTH_QUANT\n\nRun the following command to quantize weights on Atlas 800T A2.\n\n```bash\nlmdeploy lite smooth_quant $HF_MODEL --work-dir $WORK_DIR --device npu\n```\n\nPlease check [supported_models](../../supported_models/supported_models.md) before using this feature.\n\n### int8 KV-cache Quantization\n\nThe Ascend backend supports offline int8 KV-cache quantization in eager mode.\n\nPlease refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details.\n\n## Limitations on 300I Duo\n\n1. Only `dtype=float16` is supported.\n2. Only graph mode is supported; do not add `--eager-mode`.\n"
  },
  {
    "path": "docs/en/get_started/camb/get_started.md",
    "content": "# Cambricon\n\nThe usage of lmdeploy on a Cambricon device is almost the same as its usage on CUDA with PytorchEngine in lmdeploy.\nPlease read the original [Get Started](../get_started.md) guide before reading this tutorial.\n\nHere is the [supported model list](../../supported_models/supported_models.md#PyTorchEngine-on-Other-Platforms).\n\n> \\[!IMPORTANT\\]\n> We have uploaded a Docker image to Aliyun.\n> Pull the image with the following command:\n>\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest`\n\n> \\[!IMPORTANT\\]\n> Currently, launching multi-device inference on Cambricon accelerators requires manually starting Ray.\n>\n> Below is an example for a 2-device setup:\n>\n> ```shell\n> export MLU_VISIBLE_DEVICES=0,1\n> ray start --head --resources='{\"MLU\": 2}'\n> ```\n\n## Offline batch inference\n\n### LLM inference\n\nSet `device_type=\"camb\"` in the `PytorchEngineConfig`:\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy import PytorchEngineConfig\npipe = pipeline(\"internlm/internlm2_5-7b-chat\",\n        backend_config=PytorchEngineConfig(tp=1, device_type=\"camb\"))\nquestion = [\"Shanghai is\", \"Please introduce China\", \"How are you?\"]\nresponse = pipe(question)\nprint(response)\n```\n\n### VLM inference\n\nSet `device_type=\"camb\"` in the `PytorchEngineConfig`:\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\nfrom lmdeploy.vl import load_image\npipe = pipeline('OpenGVLab/InternVL2-2B',\n        backend_config=PytorchEngineConfig(tp=1, device_type='camb'))\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## Online serving\n\n### Serve a LLM model\n\nAdd `--device camb` to the serve command.\n\n```bash\nlmdeploy serve api_server --backend pytorch --device camb internlm/internlm2_5-7b-chat\n```\n\nRun the 
following command to launch a Docker container for LLM serving:\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device camb internlm/internlm2_5-7b-chat\"\n```\n\n### Serve a VLM model\n\nAdd `--device camb` to the serve command.\n\n```bash\nlmdeploy serve api_server --backend pytorch --device camb OpenGVLab/InternVL2-2B\n```\n\nRun the following command to launch a Docker container for VLM serving:\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device camb OpenGVLab/InternVL2-2B\"\n```\n\n## Inference with Command Line Interface\n\nAdd `--device camb` to the chat command.\n\n```bash\nlmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device camb\n```\n\nRun the following command to start an lmdeploy chat session in a container:\n\n```bash\ndocker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \\\n    bash -i -c \"lmdeploy chat --backend pytorch --device camb internlm/internlm2_5-7b-chat\"\n```\n"
  },
  {
    "path": "docs/en/get_started/get_started.md",
    "content": "# Quick Start\n\nThis tutorial shows how to use LMDeploy on the CUDA platform:\n\n- Offline inference with LLMs and VLMs\n- Serving an LLM or VLM with the OpenAI-compatible server\n- Chatting interactively with an LLM through the console CLI\n\nBefore reading further, please ensure that you have installed lmdeploy as outlined in the [installation guide](installation.md).\n\n## Offline batch inference\n\n### LLM inference\n\n```python\nfrom lmdeploy import pipeline\npipe = pipeline('internlm/internlm2_5-7b-chat')\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'])\nprint(response)\n```\n\nIf you do not specify an inference engine (the TurboMind Engine or the PyTorch Engine) when constructing the `pipeline`, LMDeploy automatically assigns one based on [their respective capabilities](../supported_models/supported_models.md), with the TurboMind Engine taking precedence by default.\n\nHowever, you can also select an engine manually. For instance,\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=TurbomindEngineConfig(\n                    max_batch_size=32,\n                    enable_prefix_caching=True,\n                    cache_max_entry_count=0.8,\n                    session_len=8192,\n                ))\n```\n\nor,\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=PytorchEngineConfig(\n                    max_batch_size=32,\n                    enable_prefix_caching=True,\n                    cache_max_entry_count=0.8,\n                    session_len=8192,\n                ))\n```\n\n```{note}\nThe parameter \"cache_max_entry_count\" significantly influences GPU memory usage.\nIt specifies the proportion of FREE GPU memory occupied by the K/V cache after the model weights are loaded.\n\nThe default value is 0.8. 
The K/V cache memory is allocated once and reused repeatedly, which is why the pipeline built here, as well as the \"api_server\" introduced later, consumes a substantial amount of GPU memory.\n\nIf you encounter an Out-of-Memory (OOM) error, consider lowering the value of \"cache_max_entry_count\".\n```\n\nWhen using the callable `pipe()` to generate tokens for given prompts, you can set the sampling parameters via `GenerationConfig` as shown below:\n\n```python\nfrom lmdeploy import GenerationConfig, pipeline\n\npipe = pipeline('internlm/internlm2_5-7b-chat')\nprompts = ['Hi, pls intro yourself', 'Shanghai is']\nresponse = pipe(prompts,\n                gen_config=GenerationConfig(\n                    max_new_tokens=1024,\n                    top_p=0.8,\n                    top_k=40,\n                    temperature=0.6\n                ))\n```\n\nIn the `GenerationConfig`, `top_k=1` or `temperature=0.0` indicates greedy search.\n\nFor more information about the pipeline, please read the [detailed tutorial](../llm/pipeline.md).\n\n### VLM inference\n\nThe VLM inference pipeline is used in much the same way as the LLM pipeline, with the additional ability to process image data.\nFor example, you can use the following code snippet to run inference with an InternVL model:\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2-8B')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nIn the VLM pipeline, the default image processing batch size is 1. This can be adjusted through `VisionConfig`. 
For instance, you might set it like this:\n\n```python\nfrom lmdeploy import pipeline, VisionConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2-8B',\n                vision_config=VisionConfig(\n                    max_batch_size=8\n                ))\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nHowever, the larger the image batch size, the greater the risk of an OOM error, because the LLM component of the VLM pre-allocates a large amount of GPU memory for the K/V cache in advance.\n\nWe encourage you to manually choose between the TurboMind Engine and the PyTorch Engine based on their respective capabilities, as detailed in [the supported-models matrix](../supported_models/supported_models.md).\nAdditionally, follow the instructions in the [LLM Inference](#llm-inference) section to reduce the values of memory-related parameters such as \"cache_max_entry_count\".\n\n## Serving\n\nMirroring the previous [offline batch inference](#offline-batch-inference) section, this part presents how to serve LLMs and VLMs respectively.\n\n### Serve a LLM model\n\n```shell\nlmdeploy serve api_server internlm/internlm2_5-7b-chat\n```\n\nThis command launches an OpenAI-compatible server on localhost at port `23333`. You can specify a different port with the `--server-port` option.\nFor more options, consult the help documentation by running `lmdeploy serve api_server --help`. Most of these options align with the engine configuration.\n\nTo access the service, you can use the official OpenAI Python package (`pip install openai`). 
Below is an example demonstrating how to call the `v1/chat/completions` endpoint:\n\n```python\nfrom openai import OpenAI\nclient = OpenAI(\n    api_key='YOUR_API_KEY',\n    base_url=\"http://0.0.0.0:23333/v1\"\n)\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"Provide three suggestions about time management\"},\n    ],\n    temperature=0.8,\n    top_p=0.8\n)\nprint(response)\n```\n\nWe encourage you to refer to the detailed guides for more comprehensive information about [serving with Docker](../llm/api_server.md), [function calls](../llm/api_server_tools.md), and other topics.\n\n### Serve a VLM model\n\n```shell\nlmdeploy serve api_server OpenGVLab/InternVL2-8B\n```\n\n```{note}\nLMDeploy reuses the vision component from upstream VLM repositories, and each upstream VLM model may have different dependencies.\nConsequently, LMDeploy does not include the dependencies of the upstream VLM repositories in its own dependency list.\nIf you encounter an \"ImportError\" when using LMDeploy for inference with VLM models, please install the relevant dependencies yourself.\n```\n\nAfter the service is launched successfully, you can access the VLM service the same way you would access OpenAI's GPT-4V service, changing only the `api_key` and `base_url` parameters:\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[{\n        'role': 'user',\n        'content': [{\n            'type': 'text',\n            'text': 'Describe the image please',\n        }, {\n            'type': 'image_url',\n            'image_url': {\n                'url':\n                
'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n            },\n        }],\n    }],\n    temperature=0.8,\n    top_p=0.8)\nprint(response)\n```\n\n## Inference with Command line Interface\n\nLMDeploy offers a convenient CLI tool for chatting with an LLM model locally. For example:\n\n```shell\nlmdeploy chat internlm/internlm2_5-7b-chat --backend turbomind\n```\n\nIt is designed to help users verify whether LMDeploy supports their model, whether the chat template is applied correctly, and whether the inference results are delivered smoothly.\n\nAnother tool, `lmdeploy check_env`, gathers essential environment information. Please include its output when reporting an issue to us, as it helps us diagnose and resolve the problem more effectively.\n\nIf you have any questions about their usage, run either tool with the `--help` option to obtain detailed information.\n"
  },
  {
    "path": "docs/en/get_started/index.rst",
    "content": "On Other Platforms\n=================================\n\n.. toctree::\n   :maxdepth: 1\n   :caption: OtherPF\n\n   ascend/get_started.md\n   maca/get_started.md\n   camb/get_started.md\n"
  },
  {
    "path": "docs/en/get_started/installation.md",
    "content": "# Installation\n\nLMDeploy is a Python library for compressing, deploying, and serving Large Language Models (LLMs) and Vision-Language Models (VLMs).\nIts core inference engines are the TurboMind Engine and the PyTorch Engine. The former is developed in C++ and CUDA and strives for the ultimate optimization of inference performance, while the latter, developed purely in Python, aims to lower the barrier for developers.\n\nIt supports deploying LLMs and VLMs on both Linux and Windows, with a minimum requirement of CUDA 11.3. Furthermore, it is compatible with the following NVIDIA GPUs:\n\n- Volta (sm70): V100\n- Turing (sm75): 20 series, T4\n- Ampere (sm80, sm86): 30 series, A10, A16, A30, A100\n- Ada Lovelace (sm89): 40 series\n\n## Install with pip (Recommend)\n\nWe recommend installing lmdeploy with pip in a conda environment (Python 3.10 - 3.13):\n\n```shell\nconda create -n lmdeploy python=3.10 -y\nconda activate lmdeploy\npip install lmdeploy\n```\n\nThe default prebuilt package is compiled with **CUDA 12**. If you need CUDA 11 (>= 11.3), you can install lmdeploy with:\n\n```shell\nexport LMDEPLOY_VERSION=0.12.2\nexport PYTHON_VERSION=310\npip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118\n```\n\n## Install from source\n\nBy default, LMDeploy builds with NVIDIA CUDA support, enabling both the TurboMind and PyTorch backends. 
Before installing LMDeploy, ensure that the CUDA Toolkit is installed.\n\nOnce the CUDA Toolkit is set up, you can build and install LMDeploy with a single command:\n\n```shell\npip install git+https://github.com/InternLM/lmdeploy.git\n```\n\nYou can also explicitly disable the TurboMind backend to avoid CUDA compilation by setting the `DISABLE_TURBOMIND` environment variable:\n\n```shell\nDISABLE_TURBOMIND=1 pip install git+https://github.com/InternLM/lmdeploy.git\n```\n\nIf you prefer a specific release instead of the `main` branch of LMDeploy, specify it in the URL:\n\n```shell\npip install https://github.com/InternLM/lmdeploy/archive/refs/tags/v0.11.0.zip\n```\n\nTo build LMDeploy with support for Ascend, Cambricon, or MACA, set the corresponding `LMDEPLOY_TARGET_DEVICE` environment variable when installing it.\n\nLMDeploy also supports installation on AMD GPUs with ROCm.\n\n```shell\n# The recommended way is to use the official ROCm PyTorch Docker image with pre-installed dependencies:\ndocker run -it \\\n    --cap-add=SYS_PTRACE \\\n    --security-opt seccomp=unconfined \\\n    --device=/dev/kfd \\\n    --device=/dev/dri \\\n    --group-add video \\\n    --ipc=host \\\n    --network=host \\\n    --shm-size 32G \\\n    -v /root:/workspace \\\n    rocm/pytorch:latest\n\n# Once inside the container, install LMDeploy with ROCm support:\nLMDEPLOY_TARGET_DEVICE=rocm pip install git+https://github.com/InternLM/lmdeploy.git\n```\n"
  },
  {
    "path": "docs/en/get_started/maca/get_started.md",
    "content": "# MetaX-tech\n\nUsing lmdeploy on a MetaX-tech device is almost the same as using it on CUDA with the PyTorch Engine.\nPlease read the original [Get Started](../get_started.md) guide before reading this tutorial.\n\nHere is the [supported model list](../../supported_models/supported_models.md#PyTorchEngine-on-Other-Platforms).\n\n> \\[!IMPORTANT\\]\n> We have uploaded a Docker image to Aliyun.\n> Please pull the image with the following command:\n>\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest`\n\n## Offline batch inference\n\n### LLM inference\n\nSet `device_type=\"maca\"` in the `PytorchEngineConfig`:\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy import PytorchEngineConfig\npipe = pipeline(\"internlm/internlm2_5-7b-chat\",\n                backend_config=PytorchEngineConfig(tp=1, device_type=\"maca\"))\nquestion = [\"Shanghai is\", \"Please introduce China\", \"How are you?\"]\nresponse = pipe(question)\nprint(response)\n```\n\n### VLM inference\n\nSet `device_type=\"maca\"` in the `PytorchEngineConfig`:\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\nfrom lmdeploy.vl import load_image\npipe = pipeline('OpenGVLab/InternVL2-2B',\n                backend_config=PytorchEngineConfig(tp=1, device_type='maca'))\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## Online serving\n\n### Serve a LLM model\n\nAdd `--device maca` to the serve command.\n\n```bash\nlmdeploy serve api_server --backend pytorch --device maca internlm/internlm2_5-7b-chat\n```\n\nRun the following command to launch a Docker container for lmdeploy LLM serving:\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device maca 
internlm/internlm2_5-7b-chat\"\n```\n\n### Serve a VLM model\n\nAdd `--device maca` to the serve command.\n\n```bash\nlmdeploy serve api_server --backend pytorch --device maca OpenGVLab/InternVL2-2B\n```\n\nRun the following command to launch a Docker container for lmdeploy VLM serving:\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device maca OpenGVLab/InternVL2-2B\"\n```\n\n## Inference with Command line Interface\n\nAdd `--device maca` to the chat command.\n\n```bash\nlmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device maca\n```\n\nRun the following command to start an lmdeploy chat session in a Docker container:\n\n```bash\ndocker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \\\n    bash -i -c \"lmdeploy chat --backend pytorch --device maca internlm/internlm2_5-7b-chat\"\n```\n"
  },
  {
    "path": "docs/en/index.rst",
    "content": "Welcome to LMDeploy's tutorials!\n====================================\n\n.. figure:: ./_static/image/lmdeploy-logo.svg\n  :width: 50%\n  :align: center\n  :alt: LMDeploy\n  :class: no-scaled-link\n\n.. raw:: html\n\n   <p style=\"text-align:center\">\n   <strong>LMDeploy is a toolkit for compressing, deploying, and serving LLMs.\n   </strong>\n   </p>\n\n   <p style=\"text-align:center\">\n   <script async defer src=\"https://buttons.github.io/buttons.js\"></script>\n   <a class=\"github-button\" href=\"https://github.com/InternLM/lmdeploy\" data-show-count=\"true\" data-size=\"large\" aria-label=\"Star\">Star</a>\n   <a class=\"github-button\" href=\"https://github.com/InternLM/lmdeploy/subscription\" data-icon=\"octicon-eye\" data-size=\"large\" aria-label=\"Watch\">Watch</a>\n   <a class=\"github-button\" href=\"https://github.com/InternLM/lmdeploy/fork\" data-icon=\"octicon-repo-forked\" data-size=\"large\" aria-label=\"Fork\">Fork</a>\n   </p>\n\nLMDeploy has the following core features:\n\n* **Efficient Inference**: LMDeploy delivers up to 1.8x higher request throughput than vLLM by introducing key features like persistent batch (a.k.a. continuous batching), blocked KV cache, dynamic split&fuse, tensor parallelism, high-performance CUDA kernels and so on.\n\n* **Effective Quantization**: LMDeploy supports weight-only and k/v quantization, and the 4-bit inference performance is 2.4x higher than FP16. 
The quantization quality has been confirmed via OpenCompass evaluation.\n\n* **Effortless Distribution Server**: Leveraging the request distribution service, LMDeploy facilitates an easy and efficient deployment of multi-model services across multiple machines and cards.\n\n* **Excellent Compatibility**: LMDeploy supports `KV Cache Quant <https://lmdeploy.readthedocs.io/en/latest/quantization/kv_quant.html>`_, `AWQ <https://lmdeploy.readthedocs.io/en/latest/quantization/w4a16.html>`_ and `Automatic Prefix Caching <https://lmdeploy.readthedocs.io/en/latest/inference/turbomind_config.html>`_ to be used simultaneously.\n\nDocumentation\n-------------\n\n.. _get_started:\n.. toctree::\n   :maxdepth: 1\n   :caption: Get Started\n\n   get_started/installation.md\n   get_started/get_started.md\n   get_started/index.rst\n\n.. _supported_models:\n.. toctree::\n   :maxdepth: 1\n   :caption: Models\n\n   supported_models/supported_models.md\n   supported_models/reward_models.md\n\n.. _llm_deployment:\n.. toctree::\n   :maxdepth: 1\n   :caption: Large Language Models(LLMs) Deployment\n\n   llm/pipeline.md\n   llm/api_server.md\n   llm/api_server_tools.md\n   llm/api_server_reasoning.md\n   llm/api_server_lora.md\n   llm/proxy_server.md\n\n.. _vlm_deployment:\n.. toctree::\n   :maxdepth: 1\n   :caption: Vision-Language Models(VLMs) Deployment\n\n   multi_modal/vl_pipeline.md\n   multi_modal/api_server_vl.md\n   multi_modal/index.rst\n\n.. _quantization:\n.. toctree::\n   :maxdepth: 1\n   :caption: Quantization\n\n   quantization/w4a16.md\n   quantization/w8a8.md\n   quantization/kv_quant.md\n   quantization/llm_compressor.md\n\n.. _benchmark:\n.. toctree::\n   :maxdepth: 1\n   :caption: Benchmark\n\n   benchmark/benchmark.md\n   benchmark/evaluate_with_opencompass.md\n   benchmark/evaluate_with_vlmevalkit.md\n\n.. 
toctree::\n   :maxdepth: 1\n   :caption: Advanced Guide\n\n   inference/turbomind.md\n   inference/pytorch.md\n   advance/pytorch_new_model.md\n   advance/long_context.md\n   advance/chat_template.md\n   advance/debug_turbomind.md\n   advance/structed_output.md\n   advance/pytorch_multinodes.md\n   advance/pytorch_profiling.md\n   advance/metrics.md\n   advance/context_parallel.md\n   advance/spec_decoding.md\n   advance/update_weights.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: API Reference\n\n   api/pipeline.rst\n   api/openapi.rst\n   api/cli.rst\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`search`\n* :ref:`routingtable`\n"
  },
  {
    "path": "docs/en/inference/load_hf.md",
    "content": "# Load huggingface model directly\n\nStarting from v0.1.0, TurboMind can pre-process model parameters on the fly while loading them from Hugging Face-style models.\n\n## Supported model type\n\nCurrently, TurboMind supports loading two types of models:\n\n1. An lmdeploy-quantized model hosted on huggingface.co, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc.\n2. Other LM models on huggingface.co, such as Qwen/Qwen-7B-Chat\n\n## Usage\n\n### 1) A lmdeploy-quantized model\n\nThis applies to models quantized by `lmdeploy.lite`, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc.\n\n```shell\nrepo_id=internlm/internlm-chat-20b-4bit\nmodel_name=internlm-chat-20b\n# or\n# repo_id=/path/to/downloaded_model\n\n# Inference by TurboMind\nlmdeploy chat $repo_id --model-name $model_name\n\n# Serving with RESTful API\nlmdeploy serve api_server $repo_id --model-name $model_name --tp 1\n```\n\n### 2) Other LM models\n\nThis applies to other LM models such as Qwen/Qwen-7B-Chat or baichuan-inc/Baichuan2-7B-Chat. The models supported by LMDeploy can be listed with `lmdeploy list`.\n\n```shell\nrepo_id=Qwen/Qwen-7B-Chat\nmodel_name=qwen-7b\n# or\n# repo_id=/path/to/Qwen-7B-Chat/local_path\n\n# Inference by TurboMind\nlmdeploy chat $repo_id --model-name $model_name\n\n# Serving with RESTful API\nlmdeploy serve api_server $repo_id --model-name $model_name --tp 1\n```\n"
  },
  {
    "path": "docs/en/inference/pytorch.md",
    "content": "# Architecture of lmdeploy.pytorch\n\n`lmdeploy.pytorch` is an inference engine in LMDeploy that offers a developer-friendly framework to users interested in deploying their own models and developing new features.\n\n## Design\n\n![pytorch arch](https://github.com/grimoire/lmdeploy/blob/media/lmdeploy_pytorch_arch.png?raw=true)\n\n## API\n\n`lmdeploy.pytorch` shares service interfaces with `Turbomind`, and the inference service is implemented by `Engine` and `EngineInstance`.\n\n`EngineInstance` acts as the sender of inference requests, encapsulating requests and sending them to the `Engine` to achieve streaming inference. The inference interface of `EngineInstance` is thread-safe, allowing instances in different threads to initiate requests simultaneously. The `Engine` automatically batches requests based on the current system resources.\n\nThe `Engine` is the request receiver and executor. It contains the following modules:\n\n- `ModelAgent` serves as a wrapper for the model, handling tasks such as loading the model and adapters, managing the cache, and implementing tensor parallelism.\n- The `Scheduler` functions as the sequence manager, determining which sequences and adapters participate in the current step, and then allocating resources for them.\n- `RequestManager` is tasked with sending and receiving requests, acting as the bridge between the `Engine` and `EngineInstance`.\n\n## Engine\n\nThe `Engine` responds to requests in a sub-thread, following this looping sequence:\n\n1. Get new requests through `RequestManager`. These requests are cached for now.\n2. The `Scheduler` performs scheduling, deciding which cached requests should be processed and allocating resources for them.\n3. `ModelAgent` swaps the caches according to the information provided by the Scheduler, then performs inference with the patched model.\n4. The `Scheduler` updates the status of requests based on the inference results from `ModelAgent`.\n5. 
`RequestManager` responds to the sender (`EngineInstance`), and the process returns to step 1.\n\nNow, let's delve deeper into the modules that participate in these steps.\n\n### Scheduler\n\nIn LLM inference, caching the key and value states of history tokens is a common practice to avoid redundant computation. However, as history lengths vary across a batch of sequences, the caches need to be padded to enable batched inference. Unfortunately, this padding can waste a significant amount of memory, limiting the transformer's performance.\n\n[vLLM](https://docs.vllm.ai) employs a paging-based strategy, allocating caches in page blocks to minimize extra memory usage. Our Scheduler module in the Engine shares a similar design, allocating resources in blocks based on sequence length and evicting unused blocks to support larger batches and longer session lengths.\n\nAdditionally, we support [S-LoRA](https://github.com/S-LoRA/S-LoRA), which enables the use of multiple LoRA adapters with limited memory.\n\n### ModelAgent\n\n`lmdeploy.pytorch` supports Tensor Parallelism, which complicates model initialization, cache allocation, and weight partitioning. ModelAgent is designed to abstract these complexities, allowing the Engine to focus solely on maintaining the pipeline.\n\nModelAgent consists of two components:\n\n1. **patched_model**: This is the transformer model after patching. Compared to the original model, the patched model incorporates additional features such as Tensor Parallelism, quantization, and high-performance kernels.\n2. **cache_engine**: This component manages the caches. It receives commands from the Scheduler and performs host-device page swaps. Only GPU blocks are utilized for caching key/value pairs and adapters.\n\n## Features\n\n`lmdeploy.pytorch` supports the following features:\n\n- **Continuous Batching**: As the sequence length in a batch may vary, padding is often necessary for batched inference. 
However, heavy padding leads to additional memory usage and unnecessary computation. To address this, we employ continuous batching, where all sequences are concatenated into a single long sequence to avoid padding.\n\n- **Tensor Parallelism**: The GPU memory usage of an LLM might exceed the capacity of a single GPU. Tensor parallelism is used to accommodate such models on multiple devices. Each device handles part of the model simultaneously, and the results are gathered to ensure correctness.\n\n- **S-LoRA**: LoRA adapters can be used to train LLMs on devices with limited memory. While it's common practice to merge adapters into the model weights before deployment, loading multiple adapters this way can consume a significant amount of memory. We support S-LoRA, where adapters are paged and swapped in when necessary. Special kernels have been developed to support inference with unmerged adapters, enabling various adapters to be loaded efficiently.\n\n- **Quantization**: Model quantization performs computations at low precision. `lmdeploy.pytorch` supports w8a8 quantization. For more details, refer to [w8a8](../quantization/w8a8.md).\n"
  },
  {
    "path": "docs/en/inference/turbomind.md",
    "content": "# Architecture of TurboMind\n\nTurboMind is an inference engine that supports high-throughput inference for conversational LLMs. It is based on NVIDIA's [FasterTransformer](https://github.com/NVIDIA/FasterTransformer). Major features of TurboMind include an efficient LLaMa implementation, the persistent batch inference model, and an extendable KV cache manager.\n\n## High level overview of TurboMind\n\n```\n  +--------------------+\n  |        API         |\n  +--------------------+\n          |    ^\n  request |    | stream callback\n          v    |\n  +--------------------+   fetch   +-------------------+\n  |  Persistent Batch  | <-------> |  KV Cache Manager |\n  +--------------------+   update  +-------------------+\n             ^\n             |\n             v\n+------------------------+\n|  LLaMA implementation  |\n+------------------------+\n| FT kernels & utilities |\n+------------------------+\n```\n\n## Persistent Batch\n\nYou may recognize this feature as \"continuous batching\" in other repos. However, having developed the feature concurrently with them, we modeled the inference of a conversational LLM as a persistently running batch whose lifetime spans the entire serving process, hence the name \"persistent batch\". To put it simply:\n\n- The persistent batch has N pre-configured batch slots.\n- Requests join the batch when there are free slots available. 
A batch slot is released and can be reused once the generation of the requested tokens is finished.\n- __On cache-hits (see below), history tokens don't need to be decoded in every round of a conversation; generation of response tokens starts instantly.__\n- The batch grows or shrinks automatically to minimize unnecessary computation.\n\n## KV Cache Manager\n\nThe [KV cache manager](https://github.com/InternLM/lmdeploy/blob/main/src/turbomind/models/llama/SequenceManager.h) of TurboMind is a memory-pool-like object that also implements an LRU policy, so it can be viewed as a form of __cache of KV caches__. It works in the following way:\n\n- All device memory required for the KV cache is allocated by the manager. A fixed number of slots is pre-configured to match the memory size of the system. Each slot corresponds to the memory required by the KV cache of a single sequence. The allocation chunk size can be configured to implement a pre-allocated or on-demand allocation policy (or something in between).\n- When space for the KV cache of a new sequence is requested but no free slots are left in the pool, the least recently used sequence is evicted from the cache and its device memory is directly reused by the new sequence. However, this is not the end of the story.\n- Fetching a sequence that currently resides in one of the slots is a _cache-hit_: the history KV cache is returned directly and no context decoding is needed.\n- Victim (evicted) sequences are not erased entirely but converted to their most compact form, i.e. token IDs. When the same sequence ID is fetched later (_cache-miss_), the token IDs are decoded by the FMHA-backed context decoder and converted back to KV cache.\n- The eviction and conversion are handled automatically inside TurboMind and are thus transparent to users. 
__From the user's perspective, a system that uses TurboMind behaves as if it had unlimited device memory for KV caches.__\n\n## LLaMa implementation\n\nOur implementation of the LLaMa family models is modified from the GPT-NeoX model in FasterTransformer. In addition to basic refactoring and modifications to support the LLaMa family, we made some improvements to enable high-performance inference of conversational models, most importantly:\n\n- To support fast context decoding in multi-round conversations, we replaced the attention implementation in the context decoder with a [cutlass](https://github.com/NVIDIA/cutlass)-based FMHA implementation that supports mismatched Q/K lengths.\n- We introduced indirect buffer pointers in both context FMHA and generation FMHA to support discontinuous KV cache within the batch.\n- To support concurrent inference with persistent batch, a new synchronization mechanism was designed to orchestrate the worker threads running in tensor parallel mode.\n- To maximize throughput, we implemented INT8 KV cache support to increase the max batch size. This is effective because, in real-world serving scenarios, the KV cache costs more memory and consumes more memory bandwidth than weights or other activations.\n- We resolved an NCCL hang issue when running multiple model instances in TP mode within a single process; NCCL APIs are now guarded by host-side synchronization barriers.\n\n## API\n\nTurboMind supports a Python API that enables streaming output and tensor parallel mode.\n\n## Difference between FasterTransformer and TurboMind\n\nApart from the features described above, there are many minor differences that we don't cover in this document. Notably, many capabilities of FT are dropped in TurboMind because of the difference in objectives (e.g. 
prefix prompt, beam search, context embedding, sparse GEMM, GPT/T5/other model families, etc.).\n\n## FAQ\n\n### Supporting Huggingface models\n\nFor historical reasons, TurboMind's weight layout is based on [the original LLaMa implementation](https://github.com/facebookresearch/llama) (differing only by a transpose). The implementation in Hugging Face transformers uses a [different layout](https://github.com/huggingface/transformers/blob/45025d92f815675e483f32812caa28cce3a960e7/src/transformers/models/llama/convert_llama_weights_to_hf.py#L123C76-L123C76) for `W_q` and `W_k`, which is handled in [deploy.py](https://github.com/InternLM/lmdeploy/blob/ff4648a1d09e5aec74cf70efef35bfaeeac552e0/lmdeploy/serve/turbomind/deploy.py#L398).\n"
  },
  {
    "path": "docs/en/inference/turbomind_config.md",
    "content": "# TurboMind Config\n\nTurboMind is one of the inference engines of LMDeploy. When using it to do model inference, you need to convert the input model into a TurboMind model. In the TurboMind model folder, besides model weight files, the TurboMind model also includes some other files, among which the most important is the configuration file `triton_models/weights/config.ini` that is closely related to inference performance.\n\nIf you are using LMDeploy version 0.0.x, please refer to the [turbomind 1.0 config](#turbomind-10-config) section to learn the relevant content in the configuration. Otherwise, please read [turbomind 2.0 config](#turbomind-2x-config) to familiarize yourself with the configuration details.\n\n## TurboMind 2.x config\n\nTake the `llama-2-7b-chat` model as an example. In TurboMind 2.x, its config.ini content is as follows:\n\n```toml\n[llama]\nmodel_name = \"llama2\"\ntensor_para_size = 1\nhead_num = 32\nkv_head_num = 32\nvocab_size = 32000\nnum_layer = 32\ninter_size = 11008\nnorm_eps = 1e-06\nattn_bias = 0\nstart_id = 1\nend_id = 2\nsession_len = 4104\nweight_type = \"fp16\"\nrotary_embedding = 128\nrope_theta = 10000.0\nsize_per_head = 128\ngroup_size = 0\nmax_batch_size = 64\nmax_context_token_num = 1\nstep_length = 1\ncache_max_entry_count = 0.5\ncache_block_seq_len = 128\ncache_chunk_size = 1\nenable_prefix_caching = false\nquant_policy = 0\nmax_position_embeddings = 2048\nrope_scaling_factor = 0.0\nuse_logn_attn = 0\n```\n\nThese parameters are composed of model attributes and inference parameters. 
Model attributes include the number of layers, the number of heads, dimensions, etc., and they are **not modifiable**.\n\n```toml\nmodel_name = \"llama2\"\nhead_num = 32\nkv_head_num = 32\nvocab_size = 32000\nnum_layer = 32\ninter_size = 11008\nnorm_eps = 1e-06\nattn_bias = 0\nstart_id = 1\nend_id = 2\nrotary_embedding = 128\nrope_theta = 10000.0\nsize_per_head = 128\n```\n\nComparing to TurboMind 1.0, the model attribute part in the config remains the same with TurboMind 1.0, while the inference parameters have changed\nIn the following sections, we will focus on introducing the inference parameters.\n\n### data type\n\n`weight_type` and `group_size` are the relevant parameters, **which cannot be modified**.\n\n`weight_type` represents the data type of weights. Currently, `fp16` and `int4` are supported. `int4` represents 4bit weights. When `weight_type` is `int4`, `group_size` means the group size used when quantizing weights with `awq`. In LMDeploy prebuilt package, kernels with `group size = 128` are included.\n\n### batch size\n\nThe maximum batch size is still set through `max_batch_size`. But its default value has been changed from 32 to 64, and `max_batch_size` is no longer related to `cache_max_entry_count`.\n\n### k/v cache size\n\nk/v cache memory is determined by `cache_block_seq_len` and `cache_max_entry_count`.\n\nTurboMind 2.x has implemented Paged Attention, managing the k/v cache in blocks.\n\n`cache_block_seq_len` represents the length of the token sequence in a k/v block with a default value 128. 
TurboMind calculates the memory size of the k/v block according to the following formula:\n\n```\ncache_block_seq_len * num_layer * kv_head_num * size_per_head * 2 * sizeof(kv_data_type)\n```\n\nFor the llama2-7b model, when storing k/v as the `half` type, the memory of a k/v block is: `128 * 32 * 32 * 128 * 2 * sizeof(half) = 64MB`\n\nThe meaning of `cache_max_entry_count` varies depending on its value:\n\n- When it's a decimal between (0, 1), `cache_max_entry_count` represents the percentage of memory used by k/v blocks. For example, if turbomind launches on a A100-80G GPU with `cache_max_entry_count` being `0.5`, the total memory used by the k/v blocks is `80 * 0.5 = 40G`.\n- When lmdeploy is greater than v0.2.1, `cache_max_entry_count` determines the percentage of **free memory** for k/v blocks, defaulting to `0.8`. For example, with Turbomind on an A100-80G GPU running a 13b model, the memory for k/v blocks would be `(80 - 26) * 0.8 = 43.2G`, utilizing 80% of the free 54G.\n- When it's an integer > 0, it represents the total number of k/v blocks\n\nThe `cache_chunk_size` indicates the size of the k/v cache chunk to be allocated each time new k/v cache blocks are needed. Different values represent different meanings:\n\n- When it is an integer > 0, `cache_chunk_size` number of k/v cache blocks are allocated.\n- When the value is -1, `cache_max_entry_count` number of k/v cache blocks are allocated.\n- When the value is 0, `sqrt(cache_max_entry_count)` number of k/v cache blocks are allocated.\n\n### prefix caching switch\n\nPrefix caching feature can be controlled by setting the `enable_prefix_caching` parameter. When set to `True`, it indicates that the feature is enabled, and when set to `False`, it indicates that the feature is disabled. The default value is `False`.\n\nPrefix caching feature is mainly applicable to scenarios where multiple requests have the same prompt prefix (such as system prompt). 
The k/v blocks of this shared prefix are cached and reused across requests, which avoids redundant computation and improves inference performance. The longer the shared prompt prefix, the greater the performance improvement.\n\nSince the k/v block is the smallest unit of reuse in prefix caching, a shared prompt prefix shorter than one block (prefix length \\< cache_block_seq_len) brings no improvement in inference performance.\n\n### kv quantization and inference switch\n\n- `quant_policy=4` enables 4-bit k/v quantization and inference.\n- `quant_policy=8` enables 8-bit k/v quantization and inference.\n\nPlease refer to [kv quant](../quantization/kv_quant.md) for a detailed guide.\n\n### long context switch\n\nBy setting `rope_scaling_factor = 1.0`, you can enable the Dynamic NTK option of RoPE, which allows the model to handle long-text input and output.\n\nFor the principle of Dynamic NTK, please refer to:\n\n1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases\n2. 
https://kexue.fm/archives/9675\n\nYou can also turn on [LogN attention scaling](https://kexue.fm/archives/8823) by setting `use_logn_attn = 1`.\n\n## TurboMind 1.0 config\n\nTaking the `llama-2-7b-chat` model as an example, its `config.ini` in TurboMind 1.0 is as follows:\n\n```toml\n[llama]\nmodel_name = \"llama2\"\ntensor_para_size = 1\nhead_num = 32\nkv_head_num = 32\nvocab_size = 32000\nnum_layer = 32\ninter_size = 11008\nnorm_eps = 1e-06\nattn_bias = 0\nstart_id = 1\nend_id = 2\nsession_len = 4104\nweight_type = \"fp16\"\nrotary_embedding = 128\nrope_theta = 10000.0\nsize_per_head = 128\ngroup_size = 0\nmax_batch_size = 32\nmax_context_token_num = 4\nstep_length = 1\ncache_max_entry_count = 48\ncache_chunk_size = 1\nuse_context_fmha = 1\nquant_policy = 0\nmax_position_embeddings = 2048\nuse_dynamic_ntk = 0\nuse_logn_attn = 0\n```\n\nThese parameters consist of model attributes and inference parameters. Model attributes include the number of layers, the number of heads, dimensions, etc., and they are **not modifiable**.\n\n```toml\nmodel_name = \"llama2\"\nhead_num = 32\nkv_head_num = 32\nvocab_size = 32000\nnum_layer = 32\ninter_size = 11008\nnorm_eps = 1e-06\nattn_bias = 0\nstart_id = 1\nend_id = 2\nrotary_embedding = 128\nrope_theta = 10000.0\nsize_per_head = 128\n```\n\nIn the following sections, we will focus on introducing the inference parameters.\n\n### data type\n\n`weight_type` and `group_size` are the relevant parameters, **which cannot be modified**.\n\n`weight_type` represents the data type of weights. Currently, `fp16` and `int4` are supported, where `int4` denotes 4-bit weights. When `weight_type` is `int4`, `group_size` is the group size used when quantizing weights with `awq`. The prebuilt LMDeploy package includes kernels for `group size = 128`.\n\n### batch size\n\n`max_batch_size` determines the maximum batch size during inference. In general, the larger the batch size, the higher the throughput. 
However, make sure that `max_batch_size <= cache_max_entry_count`.\n\n### k/v cache size\n\nTurboMind allocates k/v cache memory based on `session_len`, `cache_chunk_size`, and `cache_max_entry_count`.\n\n- `session_len` denotes the maximum length of a sequence, i.e., the size of the context window.\n- `cache_chunk_size` indicates the number of k/v sequences to be allocated each time new sequences are added.\n- `cache_max_entry_count` signifies the maximum number of k/v sequences that can be cached.\n\n### kv int8 switch\n\nTo enable 8-bit k/v inference, set `quant_policy = 4` and `use_context_fmha = 0`. Please refer to [kv int8](../quantization/kv_quant.md) for a guide.\n\n### long context switch\n\nBy setting `use_dynamic_ntk = 1`, you can enable the Dynamic NTK option of RoPE, which allows the model to handle long-text input and output.\n\nFor the principle of Dynamic NTK, please refer to:\n\n1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases\n2. https://kexue.fm/archives/9675\n\nYou can also turn on [LogN attention scaling](https://kexue.fm/archives/8823) by setting `use_logn_attn = 1`.\n"
  },
  {
    "path": "docs/en/llm/api_server.md",
    "content": "# OpenAI Compatible Server\n\nThis article primarily discusses the deployment of a single LLM model across multiple GPUs on a single node, providing a service that is compatible with the OpenAI interface, as well as the usage of the service API.\nFor the sake of convenience, we refer to this service as `api_server`. Regarding parallel services with multiple models, please refer to the guide about [Request Distribution Server](proxy_server.md).\n\nIn the following sections, we will first introduce methods for starting the service, so that you can choose the appropriate one for your application scenario.\n\nNext, we focus on the definition of the service's RESTful API, explore the various ways to interact with the interface, and demonstrate how to try the service through the Swagger UI or LMDeploy CLI tools.\n\n## Launch Service\n\nTaking the [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) model hosted on the Hugging Face Hub as an example, you can choose one of the following methods to start the service.\n\n### Option 1: Launching with lmdeploy CLI\n\n```shell\nlmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333\n```\n\nThe arguments of `api_server` can be viewed through the command `lmdeploy serve api_server -h`. For instance, `--tp` sets tensor parallelism, `--session-len` specifies the max length of the context window, and `--cache-max-entry-count` adjusts the GPU memory ratio for the k/v cache.\n\n### Option 2: Deploying with docker\n\nWith the LMDeploy [official docker image](https://hub.docker.com/r/openmmlab/lmdeploy/tags), you can run an OpenAI-compatible server as follows:\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server internlm/internlm2_5-7b-chat\n```\n\nThe parameters of `api_server` are the same 
as those mentioned in the \"[option 1](#option-1-launching-with-lmdeploy-cli)\" section.\n\n### Option 3: Deploying to Kubernetes cluster\n\nConnect to a running Kubernetes cluster and deploy the internlm2_5-7b-chat model service with the [kubectl](https://kubernetes.io/docs/reference/kubectl/) command-line tool (replace `<your token>` with your Hugging Face Hub token):\n\n```shell\nsed 's/{{HUGGING_FACE_HUB_TOKEN}}/<your token>/' k8s/deployment.yaml | kubectl create -f - \\\n    && kubectl create -f k8s/service.yaml\n```\n\nIn the example above, the model data is placed on the local disk of the node (hostPath). If multiple replicas are desired, consider replacing it with high-availability shared storage, which can be mounted into the container using a [PersistentVolume](https://kubernetes.io/docs/concepts/storage/persistent-volumes/).\n\n## RESTful API\n\nLMDeploy's RESTful API is compatible with the following three OpenAI interfaces:\n\n- /v1/chat/completions\n- /v1/models\n- /v1/completions\n\nAfter the service launches successfully, you can browse and try out the offered RESTful APIs at `http://0.0.0.0:23333`, as shown in the image below.\n\n![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459)\n\nIf you need to integrate the service into your own projects or products, we recommend the following approaches:\n\n### Integrate with `OpenAI`\n\nHere is an example of interacting with the `v1/chat/completions` endpoint via the openai package.\nBefore running it, please install the openai package with `pip install openai`.\n\n```python\nfrom openai import OpenAI\nclient = OpenAI(\n    api_key='YOUR_API_KEY',\n    base_url=\"http://0.0.0.0:23333/v1\"\n)\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n  model=model_name,\n  messages=[\n    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n    {\"role\": \"user\", \"content\": \" provide three 
suggestions about time management\"},\n  ],\n    temperature=0.8,\n    top_p=0.8\n)\nprint(response)\n```\n\nIf you want to use async functions, you can try the following example:\n\n```python\nimport asyncio\nfrom openai import AsyncOpenAI\n\nasync def main():\n    client = AsyncOpenAI(api_key='YOUR_API_KEY',\n                         base_url='http://0.0.0.0:23333/v1')\n    # awaiting the paginator returned by models.list() yields its first page\n    model_cards = await client.models.list()\n    response = await client.chat.completions.create(\n        model=model_cards.data[0].id,\n        messages=[\n            {\n                'role': 'system',\n                'content': 'You are a helpful assistant.'\n            },\n            {\n                'role': 'user',\n                'content': ' provide three suggestions about time management'\n            },\n        ],\n        temperature=0.8,\n        top_p=0.8)\n    print(response)\n\nasyncio.run(main())\n```\n\nYou can invoke other OpenAI interfaces in a similar way. For more detailed information, please refer to the [OpenAI API guide](https://platform.openai.com/docs/guides/text-generation).\n\n### Integrate with lmdeploy `APIClient`\n\nBelow are some examples demonstrating how to access the service through `APIClient`.\n\nIf you want to use the `/v1/chat/completions` endpoint, you can try the following code:\n\n```python\nfrom lmdeploy.serve.openai.api_client import APIClient\napi_client = APIClient('http://{server_ip}:{server_port}')\nmodel_name = api_client.available_models[0]\nmessages = [{\"role\": \"user\", \"content\": \"Say this is a test!\"}]\nfor item in api_client.chat_completions_v1(model=model_name, messages=messages):\n    print(item)\n```\n\nFor the `/v1/completions` endpoint, you can try:\n\n```python\nfrom lmdeploy.serve.openai.api_client import APIClient\napi_client = APIClient('http://{server_ip}:{server_port}')\nmodel_name = api_client.available_models[0]\nfor item in api_client.completions_v1(model=model_name, prompt='hi'):\n    
print(item)\n```\n\n### Tools\n\nPlease refer to [api_server_tools](./api_server_tools.md).\n\n### Integrate with Java/Golang/Rust\n\nYou can use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to generate a Java/Rust/Golang client from `http://{server_ip}:{server_port}/openapi.json`.\nHere is an example:\n\n```shell\n$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust\n\n$ ls rust/*\nrust/Cargo.toml  rust/git_push.sh  rust/README.md\n\nrust/docs:\nChatCompletionRequest.md  EmbeddingsRequest.md  HttpValidationError.md  LocationInner.md  Prompt.md\nDefaultApi.md             GenerateRequest.md    Input.md                Messages.md       ValidationError.md\n\nrust/src:\napis  lib.rs  models\n```\n\n### Integrate with cURL\n\nYou can use cURL to observe the output of the RESTful APIs.\n\n- list served models `v1/models`\n\n```bash\ncurl http://{server_ip}:{server_port}/v1/models\n```\n\n- chat `v1/chat/completions`\n\n```bash\ncurl http://{server_ip}:{server_port}/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"internlm-chat-7b\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Hello! How are you?\"}]\n  }'\n```\n\n- text completions `v1/completions`\n\n```shell\ncurl http://{server_ip}:{server_port}/v1/completions \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"model\": \"llama\",\n  \"prompt\": \"two steps to build a house:\"\n}'\n```\n\n## Launch multiple api servers\n\nThe following two steps launch multiple api servers through torchrun. First, create a Python script with the code below.\n\n1. Launch the proxy server through `lmdeploy serve proxy` and get the correct proxy server URL.\n2. 
Launch the script through `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **Note**: do not use `0.0.0.0:8000` here; use the real IP address instead, for example `11.25.34.55:8000`.\n\n```python\nimport os\nimport socket\nfrom typing import List, Literal\n\nimport fire\n\n\ndef get_host_ip():\n    try:\n        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n        s.connect(('8.8.8.8', 80))\n        ip = s.getsockname()[0]\n    finally:\n        s.close()\n    return ip\n\n\ndef main(model_path: str,\n         tp: int = 1,\n         proxy_url: str = 'http://0.0.0.0:8000',\n         port: int = 23333,\n         backend: Literal['turbomind', 'pytorch'] = 'turbomind'):\n    local_rank = int(os.environ.get('LOCAL_RANK', -1))\n    world_size = int(os.environ.get('WORLD_SIZE', -1))\n    local_ip = get_host_ip()\n    if isinstance(port, List):\n        assert len(port) == world_size\n        port = port[local_rank]\n    else:\n        port += local_rank * 10\n    if (world_size - local_rank) % tp == 0:\n        rank_list = ','.join([str(local_rank + i) for i in range(tp)])\n        command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\\\n                  f'--server-name {local_ip} --server-port {port} --tp {tp} '\\\n                  f'--proxy-url {proxy_url} --backend {backend}'\n        print(f'running command: {command}')\n        os.system(command)\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n```\n\n## FAQ\n\n1. If the response contains `\"finish_reason\":\"length\"`, the session is too long to be continued. The session length can be\n   increased by passing `--session-len` to api_server.\n\n2. If OOM occurs on the server side, please reduce the `cache_max_entry_count` of `backend_config` when launching the service.\n\n3. Regarding stop words, we only support characters that encode into a single index. 
Furthermore, multiple indexes may decode into results containing the stop word. In such cases, if the number of these indexes is too large, we only use the index produced by the tokenizer's encoding. If you want to use a stop symbol that encodes into multiple indexes, consider performing string matching on the streaming client side and breaking out of the streaming loop once a match is found.\n\n4. To customize a chat template, please refer to [chat_template.md](../advance/chat_template.md).\n"
  },
  {
    "path": "docs/en/llm/api_server_lora.md",
    "content": "# Serving LoRA\n\n## Launch LoRA\n\nLoRA is currently only supported by the PyTorch backend. Its deployment process is similar to that of other models, and you can view the available options using `lmdeploy serve api_server -h`. The arguments supported by the PyTorch backend include configuration options for LoRA.\n\n```\nPyTorch engine arguments:\n  --adapters [ADAPTERS [ADAPTERS ...]]\n                        Used to set path(s) of lora adapter(s). One can input key-value pairs in xxx=yyy format for multiple lora adapters. If only have one adapter, one can only input the path of the adapter.. Default:\n                        None. Type: str\n```\n\nYou only need to pass the Hugging Face model path of the LoRA weights to `--adapters` as a key-value pair in `name=path` format.\n\n```shell\nlmdeploy serve api_server THUDM/chatglm2-6b --adapters mylora=chenchi/lora-chatglm2-6b-guodegang\n```\n\nAfter the service starts, you can find two available model names in the Swagger UI: `THUDM/chatglm2-6b` and `mylora`. The latter is the key passed to `--adapters`.\n\n## Client usage\n\n### CLI\n\nWhen using the OpenAI endpoint, the `model` parameter selects either the base model or a specific LoRA adapter for inference. 
The following example uses the `chenchi/lora-chatglm2-6b-guodegang` adapter, served as `mylora`, for inference.\n\n```shell\ncurl -X 'POST' \\\n  'http://localhost:23333/v1/chat/completions' \\\n  -H 'accept: application/json' \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"model\": \"mylora\",\n  \"messages\": [\n    {\n      \"content\": \"hi\",\n      \"role\": \"user\"\n    }\n  ]\n}'\n```\n\nAnd here is the output:\n\n```json\n{\n  \"id\": \"2\",\n  \"object\": \"chat.completion\",\n  \"created\": 1721377275,\n  \"model\": \"mylora\",\n  \"choices\": [\n    {\n      \"index\": 0,\n      \"message\": {\n        \"role\": \"assistant\",\n        \"content\": \" 很高兴哪有什么赶凳儿？（按东北语说的“起早哇”），哦，东北人都学会外语了？\",\n        \"tool_calls\": null\n      },\n      \"logprobs\": null,\n      \"finish_reason\": \"stop\"\n    }\n  ],\n  \"usage\": {\n    \"prompt_tokens\": 17,\n    \"total_tokens\": 43,\n    \"completion_tokens\": 26\n  }\n}\n```\n\n### Python\n\n```python\nfrom openai import OpenAI\nclient = OpenAI(\n    api_key='YOUR_API_KEY',\n    base_url=\"http://0.0.0.0:23333/v1\"\n)\nmodel_name = 'mylora'\nresponse = client.chat.completions.create(\n  model=model_name,\n  messages=[\n    {\"role\": \"user\", \"content\": \"hi\"},\n  ],\n    temperature=0.8,\n    top_p=0.8\n)\nprint(response)\n```\n\nThe printed response content is:\n\n```\nChatCompletion(id='4', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' 很高兴能够见到你哪，我也在辐射区开了个愣儿，你呢，还活着。', role='assistant', function_call=None, tool_calls=None))], created=1721377497, model='mylora', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=22, prompt_tokens=17, total_tokens=39))\n```\n"
  },
  {
    "path": "docs/en/llm/api_server_reasoning.md",
    "content": "# Reasoning Outputs\n\nFor models that support reasoning capabilities, such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), LMDeploy can parse the reasoning results on the server side and return the reasoning content separately in the `reasoning_content` field.\n\n## Examples\n\n### DeepSeek R1\n\nWe can start the api_server of a DeepSeek R1 model just like launching other models. The only difference is that the `--reasoning-parser` parameter must be specified.\n\n```shell\nlmdeploy serve api_server deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek-r1\n```\n\nThen, we can call the service from the client:\n\n```python\nfrom openai import OpenAI\n\nopenai_api_key = \"Your API key\"\nopenai_api_base = \"http://0.0.0.0:23333/v1\"\n\nclient = OpenAI(\n    api_key=openai_api_key,\n    base_url=openai_api_base,\n)\n\nmodels = client.models.list()\nmodel = models.data[0].id\n\nmessages = [{\"role\": \"user\", \"content\": \"9.11 and 9.8, which is greater?\"}]\nresponse = client.chat.completions.create(model=model, messages=messages, stream=True)\nfor stream_response in response:\n    print('reasoning content: ', stream_response.choices[0].delta.reasoning_content)\n    print('content: ', stream_response.choices[0].delta.content)\n\nresponse = client.chat.completions.create(model=model, messages=messages, stream=False)\nreasoning_content = response.choices[0].message.reasoning_content\ncontent = response.choices[0].message.content\n\nprint(\"reasoning_content:\", reasoning_content)\nprint(\"content:\", content)\n```\n\n## Custom parser\n\nTo use a custom parser, you only need to add a parser class similar to the following in `lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py`.\n\n```python\n# import the required packages\nfrom typing import Sequence, Union, Tuple, Optional\n\nfrom lmdeploy.serve.openai.reasoning_parser import (\n    ReasoningParser, ReasoningParserManager)\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest,\n                                            DeltaMessage)\n\n# define a reasoning parser and register it to lmdeploy\n# the name list in register_module can be used\n# in --reasoning-parser.\n@ReasoningParserManager.register_module([\"example\"])\nclass ExampleParser(ReasoningParser):\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n\n    def extract_reasoning_content_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n    ) -> Union[DeltaMessage, None]:\n        \"\"\"\n        Instance method that should be implemented for extracting reasoning\n        from an incomplete response; for use when handling reasoning calls and\n        streaming. Has to be an instance method because it requires state -\n        the current tokens/diffs, but also the information about what has\n        previously been parsed and extracted (see constructor)\n        \"\"\"\n\n    def extract_reasoning_content(\n            self, model_output: str, request: ChatCompletionRequest\n    ) -> Tuple[Optional[str], Optional[str]]:\n        \"\"\"\n        Extract reasoning content from a complete model-generated string.\n\n        Used for non-streaming responses where we have the entire model response\n        available before sending to the client.\n\n        Args:\n            model_output (str): The model-generated string to extract reasoning content from.\n            request (ChatCompletionRequest): The request object that was used to generate the model_output.\n\n        Returns:\n            reasoning_content (str | None): The reasoning content.\n            final_output (str | None): The content.\n        \"\"\"\n```\n\nSimilarly, the command to start the service becomes:\n\n```shell\nlmdeploy serve api_server $model_path --reasoning-parser example\n```\n"
  },
  {
    "path": "docs/en/llm/api_server_tools.md",
    "content": "# Tools Calling\n\nLMDeploy supports tool calling for InternLM2, InternLM2.5, Llama 3.1 and Qwen2.5 models. Please use `--tool-call-parser` to specify\nwhich parser to use when launching the api_server. Supported names are:\n\n1. internlm\n2. qwen\n3. llama3\n\n## Single Round Invocation\n\nPlease start the model service before running the following example.\n\n```python\nfrom openai import OpenAI\n\ntools = [\n  {\n    \"type\": \"function\",\n    \"function\": {\n      \"name\": \"get_current_weather\",\n      \"description\": \"Get the current weather in a given location\",\n      \"parameters\": {\n        \"type\": \"object\",\n        \"properties\": {\n          \"location\": {\n            \"type\": \"string\",\n            \"description\": \"The city and state, e.g. San Francisco, CA\",\n          },\n          \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n        },\n        \"required\": [\"location\"],\n      },\n    }\n  }\n]\nmessages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response)\n```\n\n## Multiple Round Invocation\n\n### InternLM\n\nThe following example demonstrates a complete tool-calling workflow.\n\n```python\nfrom openai import OpenAI\n\n\ndef add(a: int, b: int):\n    return a + b\n\n\ndef mul(a: int, b: int):\n    return a * b\n\n\ntools = [{\n    'type': 'function',\n    'function': {\n        'name': 'add',\n        'description': 'Compute the sum of two numbers',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'a': {\n                    'type': 'int',\n                    
'description': 'A number',\n                },\n                'b': {\n                    'type': 'int',\n                    'description': 'A number',\n                },\n            },\n            'required': ['a', 'b'],\n        },\n    }\n}, {\n    'type': 'function',\n    'function': {\n        'name': 'mul',\n        'description': 'Calculate the product of two numbers',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'a': {\n                    'type': 'int',\n                    'description': 'A number',\n                },\n                'b': {\n                    'type': 'int',\n                    'description': 'A number',\n                },\n            },\n            'required': ['a', 'b'],\n        },\n    }\n}]\nmessages = [{'role': 'user', 'content': 'Compute (3+5)*2'}]\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response)\nfunc1_name = response.choices[0].message.tool_calls[0].function.name\nfunc1_args = response.choices[0].message.tool_calls[0].function.arguments\nfunc1_out = eval(f'{func1_name}(**{func1_args})')\nprint(func1_out)\n\nmessages.append(response.choices[0].message)\nmessages.append({\n    'role': 'tool',\n    'content': f'3+5={func1_out}',\n    'tool_call_id': response.choices[0].message.tool_calls[0].id\n})\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response)\nfunc2_name = response.choices[0].message.tool_calls[0].function.name\nfunc2_args = response.choices[0].message.tool_calls[0].function.arguments\nfunc2_out = eval(f'{func2_name}(**{func2_args})')\nprint(func2_out)\n```\n\nUsing the 
InternLM2-Chat-7B model to execute the above example prints the following results.\n\n```\nChatCompletion(id='1', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": 3, \"b\": 5}', name='add'), type='function')]))], created=1722852901, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=263, total_tokens=288))\n8\nChatCompletion(id='2', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='1', function=Function(arguments='{\"a\": 8, \"b\": 2}', name='mul'), type='function')]))], created=1722852901, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=293, total_tokens=318))\n16\n```\n\n### Llama 3.1\n\nAs stated in [Llama3's official user guide](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1):\n\n> There are three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:\n>\n> 1. Brave Search: Tool call to perform web searches.\n> 2. Wolfram Alpha: Tool call to perform complex mathematical calculations.\n> 3. Code Interpreter: Enables the model to output python code.\n\nAdditionally, it cautions: \"**Note:** We recommend using Llama 70B-instruct or Llama 405B-instruct for applications that combine conversation and tool calling. Llama 8B-Instruct can not reliably maintain a conversation alongside tool calling definitions. 
It can be used for zero-shot tool calling, but tool instructions should be removed for regular conversations between the model and the user.\"\n\nTherefore, we use [Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) to show how to invoke tool calling through LMDeploy's `api_server`.\n\nOn an A100-SXM-80G node, you can start the service as follows:\n\n```shell\nlmdeploy serve api_server /the/path/of/Meta-Llama-3.1-70B-Instruct/model --tp 4\n```\n\nFor an in-depth understanding of the api_server, please refer to the detailed documentation available [here](./api_server.md).\n\nThe following code snippet demonstrates how to use the Wolfram Alpha tool. It assumes that you have already registered on the [Wolfram Alpha](https://www.wolframalpha.com) website and obtained a valid API key to access its services.\n\n```python\nfrom openai import OpenAI\nimport requests\n\n\ndef request_llama3_1_service(messages):\n    client = OpenAI(api_key='YOUR_API_KEY',\n                    base_url='http://0.0.0.0:23333/v1')\n    model_name = client.models.list().data[0].id\n    response = client.chat.completions.create(\n        model=model_name,\n        messages=messages,\n        temperature=0.8,\n        top_p=0.8,\n        stream=False)\n    return response.choices[0].message.content\n\n\n# The role of \"system\" MUST be specified, including the required tools\nmessages = [\n    {\n        \"role\": \"system\",\n        \"content\": \"Environment: ipython\\nTools: wolfram_alpha\\n\\n Cutting Knowledge Date: December 2023\\nToday Date: 23 Jul 2024\\n\\nYou are a helpful Assistant.\" # noqa\n    },\n    {\n        \"role\": \"user\",\n        \"content\": \"Can you help me solve this equation: x^3 - 4x^2 + 6x - 24 = 0\"  # noqa\n    }\n]\n\n# send request to the api_server of llama3.1-70b and get the response\n# the \"assistant_response\" is supposed to be:\n# <|python_tag|>wolfram_alpha.call(query=\"solve x^3 - 4x^2 + 6x - 24 = 0\")\nassistant_response = request_llama3_1_service(messages)\nprint(assistant_response)\n\n# Call the API of Wolfram Alpha with the query generated by the model\napp_id = 'YOUR-Wolfram-Alpha-API-KEY'\nparams = {\n    \"input\": assistant_response,\n    \"appid\": app_id,\n    \"format\": \"plaintext\",\n    \"output\": \"json\",\n}\n\nwolframalpha_response = requests.get(\n    \"https://api.wolframalpha.com/v2/query\",\n    params=params\n)\nwolframalpha_response = wolframalpha_response.json()\n\n# Append the contents obtained from the model and the Wolfram Alpha API\n# to \"messages\", and send it again to the api_server\nmessages += [\n    {\n        \"role\": \"assistant\",\n        \"content\": assistant_response\n    },\n    {\n        \"role\": \"ipython\",\n        \"content\": wolframalpha_response\n    }\n]\n\nassistant_response = request_llama3_1_service(messages)\nprint(assistant_response)\n```\n\n### Qwen2.5\n\nQwen2.5 supports multiple tool calls, which means that several tool calls can be issued in a single request.\n\n```python\nfrom openai import OpenAI\nimport json\n\ndef get_current_temperature(location: str, unit: str = \"celsius\"):\n    \"\"\"Get current temperature at a location.\n\n    Args:\n        location: The location to get the temperature for, in the format \"City, State, Country\".\n        unit: The unit to return the temperature in. Defaults to \"celsius\". 
(choices: [\"celsius\", \"fahrenheit\"])\n\n    Returns:\n        the temperature, the location, and the unit in a dict\n    \"\"\"\n    return {\n        \"temperature\": 26.1,\n        \"location\": location,\n        \"unit\": unit,\n    }\n\n\ndef get_temperature_date(location: str, date: str, unit: str = \"celsius\"):\n    \"\"\"Get temperature at a location and date.\n\n    Args:\n        location: The location to get the temperature for, in the format \"City, State, Country\".\n        date: The date to get the temperature for, in the format \"Year-Month-Day\".\n        unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n    Returns:\n        the temperature, the location, the date and the unit in a dict\n    \"\"\"\n    return {\n        \"temperature\": 25.9,\n        \"location\": location,\n        \"date\": date,\n        \"unit\": unit,\n    }\n\ndef get_function_by_name(name):\n    if name == \"get_current_temperature\":\n        return get_current_temperature\n    if name == \"get_temperature_date\":\n        return get_temperature_date\n\ntools = [{\n    'type': 'function',\n    'function': {\n        'name': 'get_current_temperature',\n        'description': 'Get current temperature at a location.',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'location': {\n                    'type': 'string',\n                    'description': 'The location to get the temperature for, in the format \\'City, State, Country\\'.'\n                },\n                'unit': {\n                    'type': 'string',\n                    'enum': [\n                        'celsius',\n                        'fahrenheit'\n                    ],\n                    'description': 'The unit to return the temperature in. 
Defaults to \\'celsius\\'.'\n                }\n            },\n            'required': [\n                'location'\n            ]\n        }\n    }\n}, {\n    'type': 'function',\n    'function': {\n        'name': 'get_temperature_date',\n        'description': 'Get temperature at a location and date.',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'location': {\n                    'type': 'string',\n                    'description': 'The location to get the temperature for, in the format \\'City, State, Country\\'.'\n                },\n                'date': {\n                    'type': 'string',\n                    'description': 'The date to get the temperature for, in the format \\'Year-Month-Day\\'.'\n                },\n                'unit': {\n                    'type': 'string',\n                    'enum': [\n                        'celsius',\n                        'fahrenheit'\n                    ],\n                    'description': 'The unit to return the temperature in. Defaults to \\'celsius\\'.'\n                }\n            },\n            'required': [\n                'location',\n                'date'\n            ]\n        }\n    }\n}]\nmessages = [{'role': 'user', 'content': 'Today is 2024-11-14, What\\'s the temperature in San Francisco now? 
How about tomorrow?'}]\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response.choices[0].message.tool_calls)\nmessages.append(response.choices[0].message)\n\n# execute each requested tool locally and append its result as a 'tool' message\nfor tool_call in response.choices[0].message.tool_calls:\n    tool_call_args = json.loads(tool_call.function.arguments)\n    tool_call_result = get_function_by_name(tool_call.function.name)(**tool_call_args)\n    messages.append({\n        'role': 'tool',\n        'name': tool_call.function.name,\n        'content': json.dumps(tool_call_result),\n        'tool_call_id': tool_call.id\n    })\n\n# send the tool results back so the model can compose the final answer\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response.choices[0].message.content)\n\n```\n\nWith Qwen2.5-14B-Instruct, you should get output similar to the following:\n\n```\n[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"location\": \"San Francisco, California, USA\"}', name='get_current_temperature'), type='function'),\n ChatCompletionMessageToolCall(id='1', function=Function(arguments='{\"location\": \"San Francisco, California, USA\", \"date\": \"2024-11-15\"}', name='get_temperature_date'), type='function')]\n\nThe current temperature in San Francisco, California, USA is 26.1°C. For tomorrow, 2024-11-15, the temperature is expected to be 25.9°C.\n```\n\nNote that when multiple tools are called, the order of the tool call results can affect response quality, because the `tool_call_id` is not yet correctly passed to the LLM.\n"
  },
  {
    "path": "docs/en/llm/codellama.md",
    "content": "# codellama\n\n## Introduction\n\n[codellama](https://github.com/facebookresearch/codellama) features enhanced coding capabilities. It can generate code and natural language about code, from both code and natural language prompts (e.g., “Write me a function that outputs the fibonacci sequence”). It can also be used for code completion and debugging. It supports many of the most popular programming languages used today, including Python, C++, Java, PHP, Typescript (Javascript), C#, Bash and more.\n\nThere are three sizes (7b, 13b, 34b) as well as three flavours (base model, Python fine-tuned, and instruction tuned) released on [HuggingFace](https://huggingface.co/codellama).\n\n| Base Model                                                                      | Python                                                                                        | Instruct                                                                                          |\n| ------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- |\n| [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf)   | [codellama/CodeLlama-7b-Python-hf](https://huggingface.co/codellama/CodeLlama-7b-Python-hf)   | [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf)   |\n| [codellama/CodeLlama-13b-hf](https://huggingface.co/codellama/CodeLlama-13b-hf) | [codellama/CodeLlama-13b-Python-hf](https://huggingface.co/codellama/CodeLlama-13b-Python-hf) | [codellama/CodeLlama-13b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) |\n| [codellama/CodeLlama-34b-hf](https://huggingface.co/codellama/CodeLlama-34b-hf) | 
[codellama/CodeLlama-34b-Python-hf](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) | [codellama/CodeLlama-34b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf) |\n\nThe capabilities supported by each model flavour are:\n\n| models     | code completion | infilling         | instructions / chat | python specialist |\n| ---------- | --------------- | ----------------- | ------------------- | ----------------- |\n| Base Model | Y               | Y(7B,13B), N(34B) | N                   | N                 |\n| Python     | Y               | N                 | N                   | Y                 |\n| Instruct   | Y               | Y(7B,13B), N(34B) | Y                   | N                 |\n\n## Inference\n\nBased on the table above, this section demonstrates each of CodeLlama's capabilities with an example.\n\n### Completion\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig\n\npipe = pipeline('meta-llama/CodeLlama-7b-hf',\n                chat_template_config=ChatTemplateConfig(\n                    model_name='codellama',\n                    capability='completion'\n                ))\n\nresponse = pipe(\n    'import socket\\n\\ndef ping_exponential_backoff(host: str):',\n    gen_config=GenerationConfig(\n        top_k=10,\n        temperature=0.1,\n        top_p=0.95\n    )\n)\nprint(response.text)\n```\n\n### Infilling\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig\n\npipe = pipeline('meta-llama/CodeLlama-7b-hf',\n                chat_template_config=ChatTemplateConfig(\n                    model_name='codellama',\n                    capability='infilling'\n                ))\n\nprompt = \"\"\"\ndef remove_non_ascii(s: str) -> str:\n    \\\"\\\"\\\"\n    <FILL>\n    \\\"\\\"\\\"\n    return result\n\"\"\"\nresponse = pipe(\n    prompt,\n    gen_config=GenerationConfig(\n        top_k=10,\n        temperature=0.1,\n        top_p=0.95,\n        
max_new_tokens=500\n    )\n)\nprint(response.text)\n```\n\n### Chat\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig\n\npipe = pipeline('meta-llama/CodeLlama-7b-Instruct-hf',\n                chat_template_config=ChatTemplateConfig(\n                    model_name='codellama',\n                    capability='chat'\n                ))\n\nresponse = pipe(\n    'implement quick sort in C++',\n    gen_config=GenerationConfig(\n        top_k=10,\n        temperature=0.1,\n        top_p=0.95\n    )\n)\nprint(response.text)\n```\n\n### Python specialist\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig\n\npipe = pipeline('meta-llama/CodeLlama-7b-Python-hf',\n                chat_template_config=ChatTemplateConfig(\n                    model_name='codellama',\n                    capability='python'\n                ))\n\nresponse = pipe(\n    'implement quick sort',\n    gen_config=GenerationConfig(\n        top_k=10,\n        temperature=0.1,\n        top_p=0.95\n    )\n)\nprint(response.text)\n```\n\n## Quantization\n\nTBD\n\n## Serving\n\nPrepare a chat template JSON file, for instance `codellama.json`, with the following content:\n\n```json\n{\n    \"model_name\": \"codellama\",\n    \"capability\": \"completion\"\n}\n```\n\nThen launch the service as follows:\n\n```shell\nlmdeploy serve api_server meta-llama/CodeLlama-7b-Instruct-hf --chat-template codellama.json\n```\n\nAfter the service launches successfully, you can access it with the `openai` package:\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(\n    api_key='YOUR_API_KEY',\n    base_url=\"http://0.0.0.0:23333/v1\"\n)\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[\n        {\"role\": \"user\", \"content\": \"import socket\\n\\ndef ping_exponential_backoff(host: str):\"},\n    ],\n    temperature=0.1,\n    top_p=0.95,\n    max_tokens=500\n)\nprint(response)\n```\n\nFor details about the api_server, refer to the [guide](../llm/api_server.md).\n"
  },
  {
    "path": "docs/en/llm/pipeline.md",
    "content": "# Offline Inference Pipeline\n\nIn this tutorial, we present a series of examples that introduce the usage of `lmdeploy.pipeline`.\n\nThe detailed pipeline API is described in [this](https://lmdeploy.readthedocs.io/en/latest/api/pipeline.html) guide.\n\n## Usage\n\n### A 'Hello, world' example\n\n```python\nfrom lmdeploy import pipeline\n\npipe = pipeline('internlm/internlm2_5-7b-chat')\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'])\nprint(response)\n```\n\nIn this example, the pipeline by default allocates a predetermined percentage of GPU memory for storing the k/v cache. The ratio is set by the parameter `TurbomindEngineConfig.cache_max_entry_count`.\n\nThe strategy for setting the k/v cache ratio has changed as LMDeploy evolved. The change history is as follows:\n\n1. `v0.2.0 <= lmdeploy <= v0.2.1`\n\n   `TurbomindEngineConfig.cache_max_entry_count` defaults to 0.5, meaning that 50% of the GPU's **total memory** is allocated to the k/v cache. Out Of Memory (OOM) errors may occur when a 7B model is deployed on a GPU with less than 40G of memory. If you encounter an OOM error, decrease the k/v cache ratio as follows:\n\n   ```python\n   from lmdeploy import pipeline, TurbomindEngineConfig\n\n   # decrease the k/v cache ratio to 20%\n   backend_config = TurbomindEngineConfig(cache_max_entry_count=0.2)\n\n   pipe = pipeline('internlm/internlm2_5-7b-chat',\n                   backend_config=backend_config)\n   response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n   print(response)\n   ```\n\n2. `lmdeploy > v0.2.1`\n\n   The k/v cache is now allocated as a proportion of the **free GPU memory**, and `TurbomindEngineConfig.cache_max_entry_count` defaults to 0.8. 
If an OOM error occurs, reduce the ratio in the same way as shown above to decrease the memory used by the k/v cache.\n\n### Set tensor parallelism\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(tp=2)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'])\nprint(response)\n```\n\n### Set sampling parameters\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(tp=2)\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'],\n                gen_config=gen_config)\nprint(response)\n```\n\n### Apply OpenAI format prompt\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(tp=2)\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nprompts = [[{\n    'role': 'user',\n    'content': 'Hi, pls intro yourself'\n}], [{\n    'role': 'user',\n    'content': 'Shanghai is'\n}]]\nresponse = pipe(prompts,\n                gen_config=gen_config)\nprint(response)\n```\n\n### Apply streaming output\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(tp=2)\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              
temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nprompts = [[{\n    'role': 'user',\n    'content': 'Hi, pls intro yourself'\n}], [{\n    'role': 'user',\n    'content': 'Shanghai is'\n}]]\nfor item in pipe.stream_infer(prompts, gen_config=gen_config):\n    print(item)\n```\n\n### Get logits for generated tokens\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('internlm/internlm2_5-7b-chat')\n\ngen_config = GenerationConfig(output_logits='generation',\n                              max_new_tokens=10)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'],\n                gen_config=gen_config)\nlogits = [x.logits for x in response]\n```\n\n### Get last layer's hidden states for generated tokens\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('internlm/internlm2_5-7b-chat')\n\ngen_config = GenerationConfig(output_last_hidden_state='generation',\n                              max_new_tokens=10)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'],\n                gen_config=gen_config)\nhidden_states = [x.last_hidden_state for x in response]\n```\n\n### Calculate ppl\n\n```python\nfrom transformers import AutoTokenizer\nfrom lmdeploy import pipeline\n\n\nmodel_repoid_or_path = 'internlm/internlm2_5-7b-chat'\npipe = pipeline(model_repoid_or_path)\ntokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)\nmessages = [\n    {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n]\ninput_ids = tokenizer.apply_chat_template(messages)\n\n# ppl is a list of float numbers\nppl = pipe.get_ppl(input_ids)\nprint(ppl)\n```\n\n```{note}\n- When input_ids is too long, an OOM (Out Of Memory) error may occur. 
Use it with caution.\n- `get_ppl` returns the cross-entropy loss; it does not apply the exponential that would convert the loss into perplexity.\n```\n\n### Use PyTorchEngine\n\n```shell\npip install \"triton>=2.1.0\"\n```\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig\n\nbackend_config = PytorchEngineConfig(session_len=2048)\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nprompts = [[{\n    'role': 'user',\n    'content': 'Hi, pls intro yourself'\n}], [{\n    'role': 'user',\n    'content': 'Shanghai is'\n}]]\nresponse = pipe(prompts, gen_config=gen_config)\nprint(response)\n```\n\n### Inference with LoRA\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig\n\nbackend_config = PytorchEngineConfig(session_len=2048,\n                                     adapters=dict(lora_name_1='chenchi/lora-chatglm2-6b-guodegang'))\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('THUDM/chatglm2-6b',\n                backend_config=backend_config)\nprompts = [[{\n    'role': 'user',\n    'content': '您猜怎么着'  # Chinese: \"Guess what?\"\n}]]\nresponse = pipe(prompts, gen_config=gen_config, adapter_name='lora_name_1')\nprint(response)\n```\n\n### Release pipeline\n\nYou can release the pipeline explicitly by calling its `close()` method, or use a `with` statement as shown below:\n\n```python\nfrom lmdeploy import pipeline\n\nwith pipeline('internlm/internlm2_5-7b-chat') as pipe:\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n    print(response)\n```\n\n## FAQs\n\n- **RuntimeError: An attempt has been made to start a new process before the current 
process has finished its bootstrapping phase**.\n\n  If you get this error when using tp>1 with the PyTorch backend, make sure the Python script contains the following guard and that the pipeline is created inside it:\n\n  ```python\n  if __name__ == '__main__':\n  ```\n\n  When multiprocessing starts a new process (for example, with the `spawn` start method), the child process re-imports the main script, so any top-level code runs again in that process. Placing the initialization code under `if __name__ == '__main__':` ensures that it runs only in the main process and is not repeated in each newly created process.\n\n- To customize a chat template, refer to [chat_template.md](../advance/chat_template.md).\n\n- If a LoRA adapter has its own chat template, you can first register that chat template with lmdeploy, and then use the chat template name as the adapter name.\n"
  },
  {
    "path": "docs/en/llm/proxy_server.md",
    "content": "# Request Distributor Server\n\nThe request distributor service runs multiple api_server services in parallel behind a single entry point. Users only need to access the proxy URL to reach the different api_server services indirectly. The proxy automatically distributes requests among them, achieving load balancing.\n\n## Startup\n\nStart the proxy service:\n\n```shell\nlmdeploy serve proxy --server-name {server_name} --server-port {server_port} --routing-strategy \"min_expected_latency\" --serving-strategy Hybrid\n```\n\nOnce the proxy starts successfully, the script prints its URL. Open this URL in your browser to access the Swagger UI.\nAn `api_server` can then register itself with the proxy at startup through the `--proxy-url` option. For example:\n`lmdeploy serve api_server InternLM/internlm2-chat-1_8b --proxy-url http://0.0.0.0:8000`.\nUsers can then access the `api_server` services through the proxy node. The proxy node is used in exactly the same way as an `api_server`, and both are compatible with the OpenAI format through the following endpoints:\n\n- /v1/models\n- /v1/chat/completions\n- /v1/completions\n\n## Node Management\n\nThe Swagger UI lists multiple APIs. 
Those related to api_server node management include:\n\n- /nodes/status\n- /nodes/add\n- /nodes/remove\n\nThey list all api_server nodes, add a node, and remove a node, respectively.\n\n### Node Management through curl\n\n```shell\ncurl -X 'GET' \\\n  'http://localhost:8000/nodes/status' \\\n  -H 'accept: application/json'\n```\n\n```shell\ncurl -X 'POST' \\\n  'http://localhost:8000/nodes/add' \\\n  -H 'accept: application/json' \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"url\": \"http://0.0.0.0:23333\"\n}'\n```\n\n```shell\ncurl -X 'POST' \\\n  'http://localhost:8000/nodes/remove?node_url=http://0.0.0.0:23333' \\\n  -H 'accept: application/json' \\\n  -d ''\n```\n\n### Node Management through python\n\n```python\n# query all nodes\nimport requests\nurl = 'http://localhost:8000/nodes/status'\nheaders = {'accept': 'application/json'}\nresponse = requests.get(url, headers=headers)\nprint(response.text)\n```\n\n```python\n# add a new node\nimport requests\nurl = 'http://localhost:8000/nodes/add'\nheaders = {\n    'accept': 'application/json',\n    'Content-Type': 'application/json'\n}\ndata = {\"url\": \"http://0.0.0.0:23333\"}\nresponse = requests.post(url, headers=headers, json=data)\nprint(response.text)\n```\n\n```python\n# delete a node\nimport requests\nurl = 'http://localhost:8000/nodes/remove'\nheaders = {'accept': 'application/json'}\nparams = {'node_url': 'http://0.0.0.0:23333'}\nresponse = requests.post(url, headers=headers, data='', params=params)\nprint(response.text)\n```\n\n## Serving Strategy\n\nLMDeploy currently supports two serving strategies:\n\n- Hybrid: does not distinguish between Prefill and Decoding instances, following the traditional inference deployment mode.\n- DistServe: separates Prefill and Decoding instances and deploys them on different service nodes, enabling more flexible and efficient resource scheduling and scalability.\n\n## Dispatch Strategy\n\nThe current distribution 
strategies of the proxy service are as follows:\n\n- random: dispatches requests randomly, weighted by the request-processing capability (throughput) that the user provides for each api_server node. The higher a node's throughput, the more likely it is to receive a request. Nodes without a provided throughput are assigned the average throughput of the other nodes.\n- min_expected_latency: estimates the expected time to complete a response on each node from the number of requests currently waiting on that node and its throughput, and dispatches to the node with the shortest expected time. Nodes without a provided throughput are handled in the same way as above.\n- min_observed_latency: dispatches to the node with the shortest average handling time over a certain number of its past requests.\n"
  },
  {
    "path": "docs/en/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset SOURCEDIR=.\nset BUILDDIR=_build\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\n\techo.installed, then set the SPHINXBUILD environment variable to point\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\n\techo.may add the Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.http://sphinx-doc.org/\n\texit /b 1\n)\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\n\n:end\npopd\n"
  },
  {
    "path": "docs/en/multi_modal/api_server_vl.md",
    "content": "# OpenAI Compatible Server\n\nThis article primarily discusses deploying a single large vision-language model across multiple GPUs on a single node to provide a service compatible with the OpenAI interface, as well as how to use the service API.\nFor convenience, we refer to this service as `api_server`. For parallel services with multiple models, refer to the [Request Distribution Server](../llm/proxy_server.md) guide.\n\nIn the following sections, we first introduce two ways to start the service; choose the one that suits your application scenario.\n\nNext, we describe the service's RESTful API, explore the various ways to interact with it, and show how to try the service through the Swagger UI or the LMDeploy CLI tools.\n\n## Launch Service\n\nTaking the [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) model hosted on the Hugging Face Hub as an example, you can start the service with either of the following methods.\n\n### Option 1: Launching with lmdeploy CLI\n\n```shell\nlmdeploy serve api_server liuhaotian/llava-v1.6-vicuna-7b --server-port 23333\n```\n\nThe arguments of `api_server` can be listed with `lmdeploy serve api_server -h`. For instance, `--tp` sets tensor parallelism, `--session-len` specifies the maximum length of the context window, and `--cache-max-entry-count` adjusts the ratio of GPU memory used for the k/v cache.\n\n### Option 2: Deploying with docker\n\nWith the LMDeploy [official docker image](https://hub.docker.com/r/openmmlab/lmdeploy/tags), you can run the OpenAI-compatible server as follows:\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server liuhaotian/llava-v1.6-vicuna-7b\n```\n\nThe 
parameters of `api_server` are the same as those described in the \"[option 1](#option-1-launching-with-lmdeploy-cli)\" section.\n\nEach model may require specific dependencies not included in the Docker image. If you run into issues, you may need to install those yourself\non a case-by-case basis. If in doubt, refer to the specific model's project for documentation.\n\nFor example, for Llava:\n\n```dockerfile\nFROM openmmlab/lmdeploy:latest\n\nRUN apt-get update && apt-get install -y python3 python3-pip git\n\nWORKDIR /app\n\nRUN pip3 install --upgrade pip\nRUN pip3 install timm\nRUN pip3 install git+https://github.com/haotian-liu/LLaVA.git --no-deps\n\nCOPY . .\n\nCMD [\"lmdeploy\", \"serve\", \"api_server\", \"liuhaotian/llava-v1.6-34b\"]\n```\n\n## RESTful API\n\nLMDeploy's RESTful API is compatible with the following three OpenAI interfaces:\n\n- /v1/chat/completions\n- /v1/models\n- /v1/completions\n\nThe interface for image interaction is `/v1/chat/completions`, which is consistent with OpenAI.\n\nAfter the service launches successfully, you can browse and try out the RESTful APIs at `http://0.0.0.0:23333`, as shown in the image below.\n\n![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459)\n\nIf you need to integrate the service into your own projects or products, we recommend the following approaches:\n\n### Integrate with `OpenAI`\n\nHere is an example of interacting with the `v1/chat/completions` endpoint via the openai package.\nBefore running it, install the openai package with `pip install openai`.\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[{\n        'role':\n        'user',\n        'content': [{\n            'type': 'text',\n            'text': 'Describe the image please',\n        
}, {\n            'type': 'image_url',\n            'image_url': {\n                'url':\n                'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n            },\n        }],\n    }],\n    temperature=0.8,\n    top_p=0.8)\nprint(response)\n```\n\n### Integrate with lmdeploy `APIClient`\n\nBelow is an example of accessing the service through `APIClient`.\n\nIf you want to use the `/v1/chat/completions` endpoint, you can try the following code:\n\n```python\nfrom lmdeploy.serve.openai.api_client import APIClient\n\napi_client = APIClient('http://0.0.0.0:23333')\nmodel_name = api_client.available_models[0]\nmessages = [{\n    'role':\n    'user',\n    'content': [{\n        'type': 'text',\n        'text': 'Describe the image please',\n    }, {\n        'type': 'image_url',\n        'image_url': {\n            'url':\n            'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n        },\n    }]\n}]\nfor item in api_client.chat_completions_v1(model=model_name,\n                                           messages=messages):\n    print(item)\n```\n\n### Integrate with Java/Golang/Rust\n\nYou can use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to generate a Java/Rust/Golang client from `http://{server_ip}:{server_port}/openapi.json`.\nHere is an example:\n\n```shell\n$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust\n\n$ ls rust/*\nrust/Cargo.toml  rust/git_push.sh  rust/README.md\n\nrust/docs:\nChatCompletionRequest.md  EmbeddingsRequest.md  HttpValidationError.md  LocationInner.md  Prompt.md\nDefaultApi.md             GenerateRequest.md    Input.md                Messages.md       ValidationError.md\n\nrust/src:\napis  lib.rs  models\n```\n"
  },
  {
    "path": "docs/en/multi_modal/cogvlm.md",
    "content": "# CogVLM\n\n## Introduction\n\nCogVLM is a powerful open-source visual language model (VLM). LMDeploy supports CogVLM-17B models like [THUDM/cogvlm-chat-hf](https://huggingface.co/THUDM/cogvlm-chat-hf) and CogVLM2-19B models like [THUDM/cogvlm2-llama3-chat-19B](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B) in the PyTorch engine.\n\n## Quick Start\n\nInstall LMDeploy by following the [installation guide](../get_started/installation.md).\n\n### Prepare\n\nWhen deploying a **CogVLM** model using LMDeploy, you need to download the model first, because the **CogVLM** model repository does not include the tokenizer model; the tokenizer files are downloaded separately from `lmsys/vicuna-7b-v1.5`.\nThis step is not required for **CogVLM2**.\n\nTaking the **CogVLM** model `cogvlm-chat-hf` as an example, you can prepare it as follows:\n\n```shell\nhuggingface-cli download THUDM/cogvlm-chat-hf --local-dir ./cogvlm-chat-hf --local-dir-use-symlinks False\nhuggingface-cli download lmsys/vicuna-7b-v1.5 special_tokens_map.json tokenizer.model tokenizer_config.json --local-dir ./cogvlm-chat-hf --local-dir-use-symlinks False\n```\n\n### Offline inference pipeline\n\nThe following sample code shows the basic usage of the VLM pipeline. For more examples, refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\n\nif __name__ == \"__main__\":\n    # load the model from the local directory prepared above\n    pipe = pipeline('./cogvlm-chat-hf')\n\n    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n    response = pipe(('describe this image', image))\n    print(response)\n```\n"
  },
  {
    "path": "docs/en/multi_modal/deepseek_vl2.md",
    "content": "# DeepSeek-VL2\n\n## Introduction\n\nDeepSeek-VL2 is an advanced series of large Mixture-of-Experts (MoE) vision-language models that significantly improves upon its predecessor, DeepSeek-VL.\nDeepSeek-VL2 demonstrates superior capabilities across various tasks, including but not limited to visual question answering, optical character recognition, document/table/chart understanding, and visual grounding.\n\nLMDeploy supports [deepseek-vl2-tiny](https://huggingface.co/deepseek-ai/deepseek-vl2-tiny), [deepseek-vl2-small](https://huggingface.co/deepseek-ai/deepseek-vl2-small) and [deepseek-vl2](https://huggingface.co/deepseek-ai/deepseek-vl2) in the PyTorch engine.\n\n## Quick Start\n\nInstall LMDeploy by following the [installation guide](../get_started/installation.md).\n\n### Prepare\n\nWhen deploying a **DeepSeek-VL2** model using LMDeploy, you must install the official GitHub repository and the related third-party libraries, because LMDeploy reuses the image processing functions provided in the official repository.\n\n```shell\npip install git+https://github.com/deepseek-ai/DeepSeek-VL2.git --no-deps\npip install attrdict timm 'transformers<4.48.0'\n```\n\nNote that DeepSeek-VL2 may not work with `transformers>=4.48.0`, as reported in this [issue](https://github.com/deepseek-ai/DeepSeek-VL2/issues/45).\n\n### Offline inference pipeline\n\nThe following sample code shows the basic usage of the VLM pipeline. For more examples, refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\nTo construct valid DeepSeek-VL2 prompts with image inputs, insert `<IMAGE_TOKEN>` into the prompt manually.\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\n\nif __name__ == \"__main__\":\n    pipe = pipeline('deepseek-ai/deepseek-vl2-tiny')\n\n    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n    response = pipe(('<IMAGE_TOKEN>describe this image', image))\n    print(response)\n```\n"
  },
  {
    "path": "docs/en/multi_modal/gemma3.md",
    "content": "# Gemma3\n\n## Introduction\n\nGemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. Gemma 3 models are multimodal, handling text and image input and generating text output, with open weights for both pre-trained variants and instruction-tuned variants. Gemma 3 has a large, 128K context window, multilingual support in over 140 languages, and is available in more sizes than previous versions. Gemma 3 models are well-suited for a variety of text generation and image understanding tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as laptops, desktops or your own cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.\n\n## Quick Start\n\nInstall LMDeploy by following the [installation guide](../get_started/installation.md).\n\n### Prepare\n\nWhen deploying the **Gemma3** model using LMDeploy, please install the latest version of transformers.\n\n### Offline inference pipeline\n\nThe following sample code shows the basic usage of the VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\n\nif __name__ == \"__main__\":\n    pipe = pipeline('google/gemma-3-12b-it')\n\n    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n    response = pipe(('describe this image', image))\n    print(response)\n```\n"
  },
  {
    "path": "docs/en/multi_modal/index.rst",
    "content": "Vision-Language Models\n=================================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Examples\n\n   deepseek_vl2.md\n   llava.md\n   internvl.md\n   xcomposer2d5.md\n   cogvlm.md\n   minicpmv.md\n   phi3.md\n   qwen2_vl.md\n   qwen2_5_vl.md\n   molmo.md\n   gemma3.md\n"
  },
  {
    "path": "docs/en/multi_modal/internvl.md",
    "content": "# InternVL\n\nLMDeploy supports the following InternVL series of models, which are detailed in the table below:\n\n|         Model         |     Size      | Supported Inference Engine |\n| :-------------------: | :-----------: | :------------------------: |\n|       InternVL        |    13B-19B    |         TurboMind          |\n|      InternVL1.5      |    2B-26B     |     TurboMind, PyTorch     |\n|       InternVL2       |      4B       |          PyTorch           |\n|       InternVL2       | 1B-2B, 8B-76B |     TurboMind, PyTorch     |\n| InternVL2.5/2.5-MPO/3 |    1B-78B     |     TurboMind, PyTorch     |\n|     Mono-InternVL     |      2B       |          PyTorch           |\n\nThe next section demonstrates how to deploy an InternVL model using LMDeploy, with [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) as an example.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the other packages that InternVL2 needs:\n\n```shell\npip install timm\n# It is recommended to find the whl package that matches the environment from the releases on https://github.com/Dao-AILab/flash-attention.\npip install flash-attn\n```\n\nAlternatively, you can build a Docker image to set up the inference environment. If the CUDA version on your host machine is `>=12.4`, run the following from the root directory of the lmdeploy repository:\n\n```shell\ndocker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:internvl . -f ./docker/InternVL_Dockerfile\n```\n\nOtherwise, clone the repository and build the image with `CUDA_VERSION=cu11`:\n\n```shell\ngit clone https://github.com/InternLM/lmdeploy.git\ncd lmdeploy\ndocker build --build-arg CUDA_VERSION=cu11 -t openmmlab/lmdeploy:internvl . -f ./docker/InternVL_Dockerfile\n```\n\n## Offline inference\n\nThe following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2-8B')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nMore examples are listed below:\n\n<details>\n  <summary>\n    <b>multi-image multi-round conversation, combined images</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\n\npipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\\nDescribe the two images in detail.'),\n        dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),\n        dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images?'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>multi-image multi-round conversation, separate images</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\n\npipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text=f'Image-1: {IMAGE_TOKEN}\\nImage-2: {IMAGE_TOKEN}\\nDescribe the two images in detail.'),\n        dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),\n        dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images?'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>video multi-round conversation</b>\n  </summary>\n\n```python\nimport numpy as np\nfrom lmdeploy import pipeline, GenerationConfig\nfrom decord import VideoReader, cpu\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\nfrom lmdeploy.vl import encode_image_base64\nfrom PIL import Image\npipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')\n\n\ndef get_index(bound, fps, max_frame, first_idx=0, num_segments=32):\n    if bound:\n        start, end = bound[0], bound[1]\n    else:\n        start, end = -100000, 100000\n    start_idx = max(first_idx, round(start * fps))\n    end_idx = min(round(end * fps), max_frame)\n    seg_size = float(end_idx - start_idx) / num_segments\n    frame_indices = np.array([\n        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))\n        for idx in range(num_segments)\n    ])\n    return frame_indices\n\n\ndef load_video(video_path, bound=None, num_segments=32):\n    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)\n    max_frame = len(vr) - 1\n    fps = float(vr.get_avg_fps())\n    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)\n    imgs = []\n    for frame_index in frame_indices:\n        img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')\n        imgs.append(img)\n    return imgs\n\n\nvideo_path = 'red-panda.mp4'\nimgs = load_video(video_path, num_segments=8)\n\nquestion = ''\nfor i in range(len(imgs)):\n    question = question + f'Frame{i+1}: {IMAGE_TOKEN}\\n'\n\nquestion += 'What is the red panda doing?'\n\ncontent = [{'type': 'text', 'text': question}]\nfor img in imgs:\n    content.append({'type': 'image_url', 'image_url': {'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'}})\n\nmessages = [dict(role='user', content=content)]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='Describe this video in detail. Don\\'t repeat.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n## Online serving\n\nYou can launch the server with the `lmdeploy serve api_server` CLI:\n\n```shell\nlmdeploy serve api_server OpenGVLab/InternVL2-8B\n```\n\nYou can also start the service using the Docker image built above:\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:internvl \\\n    lmdeploy serve api_server OpenGVLab/InternVL2-8B\n```\n\nDocker Compose is another option. Create a `docker-compose.yml` configuration file in the root directory of the lmdeploy project as follows:\n\n```yaml\nversion: '3.5'\n\nservices:\n  lmdeploy:\n    container_name: lmdeploy\n    image: openmmlab/lmdeploy:internvl\n    ports:\n      - \"23333:23333\"\n    environment:\n      HUGGING_FACE_HUB_TOKEN: <secret>\n    volumes:\n      - ~/.cache/huggingface:/root/.cache/huggingface\n    stdin_open: true\n    tty: true\n    ipc: host\n    command: lmdeploy serve api_server OpenGVLab/InternVL2-8B\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: \"all\"\n              capabilities: [gpu]\n```\n\nThen, start the service with:\n\n```shell\ndocker-compose up -d\n```\n\nIf you see the following logs after running `docker logs -f lmdeploy`, the service has launched successfully.\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\nYou can review all arguments of `lmdeploy serve api_server` with `lmdeploy serve api_server -h`.\n\nMore information about `api_server`, including how to access the service, can be found [here](api_server_vl.md).\n"
  },
  {
    "path": "docs/en/multi_modal/llava.md",
    "content": "# LLaVA\n\nLMDeploy supports the following LLaVA series of models, which are detailed in the table below:\n\n|                Model                 | Size | Supported Inference Engine |\n| :----------------------------------: | :--: | :------------------------: |\n| llava-hf/llava-interleave-qwen-7b-hf |  7B  |     TurboMind, PyTorch     |\n|       llava-hf/llava-1.5-7b-hf       |  7B  |     TurboMind, PyTorch     |\n|  llava-hf/llava-v1.6-mistral-7b-hf   |  7B  |          PyTorch           |\n|   llava-hf/llava-v1.6-vicuna-7b-hf   |  7B  |          PyTorch           |\n|   liuhaotian/llava-v1.6-mistral-7b   |  7B  |         TurboMind          |\n|   liuhaotian/llava-v1.6-vicuna-7b    |  7B  |         TurboMind          |\n\nThe next section demonstrates how to deploy a LLaVA model using LMDeploy, with [llava-hf/llava-interleave-qwen-7b-hf](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) as an example.\n\n```{note}\nThe PyTorch engine dropped support for the original LLaVA models after v0.6.4. Please use the corresponding transformers models instead, which can be found at https://huggingface.co/llava-hf\n```\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md).\n\nAlternatively, you can use the official Docker image:\n\n```shell\ndocker pull openmmlab/lmdeploy:latest\n```\n\n## Offline inference\n\nThe following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline\nfrom lmdeploy.vl import load_image\n\n\npipe = pipeline(\"llava-hf/llava-interleave-qwen-7b-hf\",\n                backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5))\n\nimage = load_image('https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg')\nprompt = 'Describe the image.'\nprint(f'prompt:{prompt}')\nresponse = pipe((prompt, image), gen_config=GenerationConfig(max_new_tokens=512))\nprint(response)\n```\n\nMore examples are listed below:\n\n<details>\n  <summary>\n    <b>multi-image multi-round conversation, combined images</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('llava-hf/llava-interleave-qwen-7b-hf', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images?'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n## Online serving\n\nYou can launch the server with the `lmdeploy serve api_server` CLI:\n\n```shell\nlmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf\n```\n\nYou can also start the service using the official Docker image:\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf\n```\n\nDocker Compose is another option. Create a `docker-compose.yml` configuration file in the root directory of the lmdeploy project as follows:\n\n```yaml\nversion: '3.5'\n\nservices:\n  lmdeploy:\n    container_name: lmdeploy\n    image: openmmlab/lmdeploy:latest\n    ports:\n      - \"23333:23333\"\n    environment:\n      HUGGING_FACE_HUB_TOKEN: <secret>\n    volumes:\n      - ~/.cache/huggingface:/root/.cache/huggingface\n    stdin_open: true\n    tty: true\n    ipc: host\n    command: lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: \"all\"\n              capabilities: [gpu]\n```\n\nThen, start the service with:\n\n```shell\ndocker-compose up -d\n```\n\nIf you see the following logs after running `docker logs -f lmdeploy`, the service has launched successfully.\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\nYou can review all arguments of `lmdeploy serve api_server` with `lmdeploy serve api_server -h`.\n\nMore information about `api_server`, including how to access the service, can be found [here](api_server_vl.md).\n"
  },
  {
    "path": "docs/en/multi_modal/minicpmv.md",
    "content": "# MiniCPM-V\n\nLMDeploy supports the following MiniCPM-V series of models, which are detailed in the table below:\n\n|        Model         | Supported Inference Engine |\n| :------------------: | :------------------------: |\n| MiniCPM-Llama3-V-2_5 |         TurboMind          |\n|    MiniCPM-V-2_6     |         TurboMind          |\n\nThe next section demonstrates how to deploy a MiniCPM-V model using LMDeploy, with [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) as an example.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md).\n\n## Offline inference\n\nThe following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('openbmb/MiniCPM-V-2_6')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nMore examples are listed below:\n\n<details>\n  <summary>\n    <b>Chat with multiple images</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(max_slice_nums=9, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),\n        dict(type='image_url', image_url=dict(max_slice_nums=9, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\nprint(out.text)\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images?'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\nprint(out.text)\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>In-context few-shot learning</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO')\n\nquestion = \"production date\"\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text=question),\n        dict(type='image_url', image_url=dict(url='example1.jpg')),\n    ]),\n    dict(role='assistant', content='2023.08.04'),\n    dict(role='user', content=[\n        dict(type='text', text=question),\n        dict(type='image_url', image_url=dict(url='example2.jpg')),\n    ]),\n    dict(role='assistant', content='2007.04.24'),\n    dict(role='user', content=[\n        dict(type='text', text=question),\n        dict(type='image_url', image_url=dict(url='test.jpg')),\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\nprint(out.text)\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>Chat with video</b>\n  </summary>\n\n```python\nfrom decord import VideoReader, cpu  # pip install decord\nfrom PIL import Image\n\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl import encode_image_base64\n\npipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO')\n\nMAX_NUM_FRAMES = 64  # if CUDA runs out of memory, set a smaller number\n\n\ndef encode_video(video_path):\n    def uniform_sample(seq, n):\n        # pick n evenly spaced elements from seq\n        gap = len(seq) / n\n        idxs = [int(i * gap + gap / 2) for i in range(n)]\n        return [seq[i] for i in idxs]\n\n    vr = VideoReader(video_path, ctx=cpu(0))\n    sample_fps = round(vr.get_avg_fps() / 1)  # step of ~1 second, i.e. sample one frame per second\n    frame_idx = list(range(0, len(vr), sample_fps))\n    if len(frame_idx) > MAX_NUM_FRAMES:\n        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)\n    frames = vr.get_batch(frame_idx).asnumpy()\n    frames = [Image.fromarray(v.astype('uint8')) for v in frames]\n    print('num frames:', len(frames))\n    return frames\n\n\nvideo_path = \"video_test.mp4\"\nframes = encode_video(video_path)\nquestion = \"Describe the video\"\n\ncontent = [dict(type='text', text=question)]\nfor frame in frames:\n    content.append(dict(type='image_url', image_url=dict(use_image_id=False, max_slice_nums=2,\n        url=f'data:image/jpeg;base64,{encode_image_base64(frame)}')))\n\nmessages = [dict(role='user', content=content)]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\nprint(out.text)\n```\n\n</details>\n\n## Online serving\n\nYou can launch the server with the `lmdeploy serve api_server` CLI:\n\n```shell\nlmdeploy serve api_server openbmb/MiniCPM-V-2_6\n```\n\nYou can also start the service using the official lmdeploy Docker image:\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server openbmb/MiniCPM-V-2_6\n```\n\nDocker Compose is another option. Create a `docker-compose.yml` configuration file in the root directory of the lmdeploy project as follows:\n\n```yaml\nversion: '3.5'\n\nservices:\n  lmdeploy:\n    container_name: lmdeploy\n    image: openmmlab/lmdeploy:latest\n    ports:\n      - \"23333:23333\"\n    environment:\n      HUGGING_FACE_HUB_TOKEN: <secret>\n    volumes:\n      - ~/.cache/huggingface:/root/.cache/huggingface\n    stdin_open: true\n    tty: true\n    ipc: host\n    command: lmdeploy serve api_server openbmb/MiniCPM-V-2_6\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: \"all\"\n              capabilities: [gpu]\n```\n\nThen, start the service with:\n\n```shell\ndocker-compose up -d\n```\n\nIf you see the following logs after running `docker logs -f lmdeploy`, the service has launched successfully.\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\nYou can review all arguments of `lmdeploy serve api_server` with `lmdeploy serve api_server -h`.\n\nMore information about `api_server`, including how to access the service, can be found [here](api_server_vl.md).\n"
  },
  {
    "path": "docs/en/multi_modal/molmo.md",
    "content": "# Molmo\n\nLMDeploy supports the following Molmo series of models, which are detailed in the table below:\n\n|      Model      | Size | Supported Inference Engine |\n| :-------------: | :--: | :------------------------: |\n| Molmo-7B-D-0924 |  7B  |         TurboMind          |\n|  Molmo-72B-0924 | 72B  |         TurboMind          |\n\nThe next section demonstrates how to deploy a Molmo model using LMDeploy, with [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) as an example.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md).\n\n## Offline inference\n\nThe following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('allenai/Molmo-7B-D-0924')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nMore examples are listed below:\n\n<details>\n  <summary>\n    <b>multi-image multi-round conversation, combined images</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('allenai/Molmo-7B-D-0924', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(do_sample=False))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images?'))\nout = pipe(messages, gen_config=GenerationConfig(do_sample=False))\n```\n\n</details>\n\n## Online serving\n\nYou can launch the server with the `lmdeploy serve api_server` CLI:\n\n```shell\nlmdeploy serve api_server allenai/Molmo-7B-D-0924\n```\n\nYou can also start the service using the official Docker image:\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server allenai/Molmo-7B-D-0924\n```\n\nIf you see the following logs, the service has launched successfully.\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\nYou can review all arguments of `lmdeploy serve api_server` with `lmdeploy serve api_server -h`.\n\nMore information about `api_server`, including how to access the service, can be found [here](api_server_vl.md).\n"
  },
  {
    "path": "docs/en/multi_modal/phi3.md",
    "content": "# Phi-3 Vision\n\n## Introduction\n\n[Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) is a family of small language and multi-modal models from Microsoft. LMDeploy supports the following multi-modal models:\n\n|                                                Model                                                | Size | Supported Inference Engine |\n| :-------------------------------------------------------------------------------------------------: | :--: | :------------------------: |\n| [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) | 4.2B |          PyTorch           |\n|    [microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)    | 4.2B |          PyTorch           |\n\nThe next section demonstrates how to deploy a Phi-3 model using LMDeploy, with [microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct) as an example.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the dependency [Flash-Attention](https://github.com/Dao-AILab/flash-attention):\n\n```shell\n# It is recommended to find the whl package that matches the environment from the releases on https://github.com/Dao-AILab/flash-attention.\npip install flash-attn\n```\n\n## Offline inference\n\nThe following sample code shows the basic usage of the VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('microsoft/Phi-3.5-vision-instruct')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## Online serving\n\n### Launch Service\n\nYou can launch the server with the `lmdeploy serve api_server` CLI:\n\n```shell\nlmdeploy serve api_server microsoft/Phi-3.5-vision-instruct\n```\n\n### Integrate with `OpenAI`\n\nHere is an example of interacting with the `v1/chat/completions` endpoint via the openai package.\nBefore running it, please install the openai package with `pip install openai`.\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[{\n        'role': 'user',\n        'content': [{\n            'type': 'text',\n            'text': 'Describe the image please',\n        }, {\n            'type': 'image_url',\n            'image_url': {\n                'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n            },\n        }],\n    }],\n    temperature=0.8,\n    top_p=0.8)\nprint(response)\n```\n"
  },
  {
    "path": "docs/en/multi_modal/qwen2_5_vl.md",
    "content": "# Qwen2.5-VL\n\nLMDeploy supports the following Qwen-VL series of models, which are detailed in the table below:\n\n|   Model    |       Size       | Supported Inference Engine |\n| :--------: | :--------------: | :------------------------: |\n| Qwen2.5-VL | 3B, 7B, 32B, 72B |          PyTorch           |\n\nThe next section demonstrates how to deploy a Qwen-VL model using LMDeploy, with [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) as an example.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the other packages that Qwen2.5-VL needs:\n\n```shell\n# Qwen2.5-VL requires the latest transformers (transformers >= 4.49.0)\npip install git+https://github.com/huggingface/transformers\n# Installing the `[decord]` extra is highly recommended for faster video loading.\npip install qwen-vl-utils[decord]==0.0.8\n```\n\n## Offline inference\n\nThe following sample code shows the basic usage of the VLM pipeline. For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nMore examples are listed below:\n\n<details>\n  <summary>\n    <b>multi-image multi-round conversation, combined images</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images?'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>image resolution for performance boost</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO')\n\nmin_pixels = 64 * 28 * 28\nmax_pixels = 64 * 28 * 28\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images?'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>video multi-round conversation</b>\n  </summary>\n\n```python\nimport numpy as np\nfrom lmdeploy import pipeline, GenerationConfig\nfrom decord import VideoReader, cpu\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\nfrom lmdeploy.vl import encode_image_base64\nfrom PIL import Image\npipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO')\n\n\ndef get_index(bound, fps, max_frame, first_idx=0, num_segments=32):\n    if bound:\n        start, end = bound[0], bound[1]\n    else:\n        start, end = -100000, 100000\n    start_idx = max(first_idx, round(start * fps))\n    end_idx = min(round(end * fps), max_frame)\n    seg_size = float(end_idx - start_idx) / num_segments\n    frame_indices = np.array([\n        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))\n        for idx in range(num_segments)\n    ])\n    return frame_indices\n\n\ndef load_video(video_path, bound=None, num_segments=32):\n    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)\n    max_frame = len(vr) - 1\n    fps = float(vr.get_avg_fps())\n    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)\n    imgs = []\n    for frame_index in frame_indices:\n        img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')\n        imgs.append(img)\n    return imgs\n\n\nvideo_path = 'red-panda.mp4'\nimgs = load_video(video_path, num_segments=8)\n\nquestion = ''\nfor i in range(len(imgs)):\n    question = question + f'Frame{i+1}: {IMAGE_TOKEN}\\n'\n\nquestion += 'What is the red panda doing?'\n\ncontent = [{'type': 'text', 'text': question}]\nfor img in imgs:\n    content.append({'type': 'image_url', 'image_url': {'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'}})\n\nmessages = [dict(role='user', content=content)]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='Describe this video in detail. Don\\'t repeat.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n"
  },
  {
    "path": "docs/en/multi_modal/qwen2_vl.md",
    "content": "# Qwen2-VL\n\nLMDeploy supports the following Qwen-VL series of models, which are detailed in the table below:\n\n|    Model     |  Size  | Supported Inference Engine |\n| :----------: | :----: | :------------------------: |\n| Qwen-VL-Chat |   -    |         TurboMind          |\n|   Qwen2-VL   | 2B, 7B |          PyTorch           |\n\nThe following section demonstrates how to deploy a Qwen-VL model using LMDeploy, with [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the other packages that Qwen2-VL needs:\n\n```shell\npip install qwen_vl_utils\n```\n\nAlternatively, you can build a Docker image to set up the inference environment. If the CUDA version on your host machine is `>=12.4`, you can run:\n\n```shell\ngit clone https://github.com/InternLM/lmdeploy.git\ncd lmdeploy\ndocker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile\n```\n\nOtherwise, you can go with:\n\n```shell\ndocker build --build-arg CUDA_VERSION=cu11 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile\n```\n\n## Offline inference\n\nThe following sample code shows the basic usage of the VLM pipeline. 
For detailed information, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('Qwen/Qwen2-VL-2B-Instruct')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nMore examples are listed below:\n\n<details>\n  <summary>\n    <b>multi-image multi-round conversation, combined images</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>image resolution for performance boost</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO')\n\n# Each visual token covers a 28x28-pixel area, so these bounds resize each image to about 64 visual tokens\nmin_pixels = 64 * 28 * 28\nmax_pixels = 64 * 28 * 28\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, 
url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n## Online serving\n\nYou can launch the server with the `lmdeploy serve api_server` CLI:\n\n```shell\nlmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct\n```\n\nYou can also start the service using the Docker image built above:\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:qwen2vl \\\n    lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct\n```\n\nDocker Compose is another option. 
Create a `docker-compose.yml` configuration file in the root directory of the lmdeploy project as follows:\n\n```yaml\nversion: '3.5'\n\nservices:\n  lmdeploy:\n    container_name: lmdeploy\n    image: openmmlab/lmdeploy:qwen2vl\n    ports:\n      - \"23333:23333\"\n    environment:\n      HUGGING_FACE_HUB_TOKEN: <secret>\n    volumes:\n      - ~/.cache/huggingface:/root/.cache/huggingface\n    stdin_open: true\n    tty: true\n    ipc: host\n    command: lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: \"all\"\n              capabilities: [gpu]\n```\n\nThen, run the following command to start the service:\n\n```shell\ndocker-compose up -d\n```\n\nIf you see the following logs after running `docker logs -f lmdeploy`, the service has launched successfully.\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\nAll arguments of `lmdeploy serve api_server` can be reviewed by running `lmdeploy serve api_server -h`.\n\nMore information about `api_server`, as well as how to access the service, can be found [here](api_server_vl.md).\n"
  },
  {
    "path": "docs/en/multi_modal/vl_pipeline.md",
    "content": "# Offline Inference Pipeline\n\nLMDeploy abstracts the complex inference process of multi-modal Vision-Language Models (VLM) into an easy-to-use pipeline, similar to the Large Language Model (LLM) inference [pipeline](../llm/pipeline.md).\n\nThe supported models are listed [here](../supported_models/supported_models.md). We genuinely invite the community to contribute new VLM support to LMDeploy. Your involvement is truly appreciated.\n\nThis article showcases the VLM pipeline using the [OpenGVLab/InternVL2_5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) model as a case study.\nYou'll learn about the simplest ways to leverage the pipeline and how to gradually unlock more advanced features by adjusting engine parameters and generation arguments, such as tensor parallelism, context window sizing, random sampling, and chat template customization.\nMoreover, we will provide practical inference examples tailored to scenarios with multiple images, batch prompts etc.\n\nUsing the pipeline interface to infer other VLM models is similar, with the main difference being the configuration and installation dependencies of the models. You can read [here](https://lmdeploy.readthedocs.io/en/latest/multi_modal/index.html) for environment installation and configuration methods for different models.\n\n## A 'Hello, world' example\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nIf `ImportError` occurs while executing this case, please install the required dependency packages as prompted.\n\nIn the above example, the inference prompt is a tuple structure consisting of (prompt, image). 
Besides this structure, the pipeline also supports prompts in the OpenAI format:\n\n```python\nfrom lmdeploy import pipeline\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B')\n\nprompts = [\n    {\n        'role': 'user',\n        'content': [\n            {'type': 'text', 'text': 'describe this image'},\n            {'type': 'image_url', 'image_url': {'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'}}\n        ]\n    }\n]\nresponse = pipe(prompts)\nprint(response)\n```\n\n### Set tensor parallelism\n\nTensor parallelism can be activated by setting the engine parameter `tp`:\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\n# tp=2 shards the model across 2 GPUs\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(tp=2))\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n### Set context window size\n\nWhen creating the pipeline, you can customize the size of the context window by setting the engine parameter `session_len`.\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(session_len=8192))\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n### Set sampling parameters\n\nYou can change the default sampling parameters of the pipeline by passing a `GenerationConfig`:\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(tp=2, session_len=8192))\ngen_config = GenerationConfig(top_k=40, top_p=0.8, 
temperature=0.6)\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image), gen_config=gen_config)\nprint(response)\n```\n\n### Customize image token position\n\nBy default, LMDeploy inserts the special image token into the user prompt following the chat template defined by the upstream algorithm repository. However, for certain models where the image token's position is unrestricted, such as deepseek-vl, or when users need a custom image token placement, the special image token must be inserted into the prompt manually. LMDeploy uses `<IMAGE_TOKEN>` as the special image token.\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\n\npipe = pipeline('deepseek-ai/deepseek-vl-1.3b-chat')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n# IMAGE_TOKEN marks where the image is placed in the prompt\nresponse = pipe((f'describe this image{IMAGE_TOKEN}', image))\nprint(response)\n```\n\n### Set chat template\n\nWhile performing inference, LMDeploy picks an appropriate chat template from its built-in collection based on the model path and applies it to the input prompts. However, when the chat template cannot be inferred from the model path, users have to specify it. 
For example, [liuhaotian/llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b) employs the ['llava-v1'](https://github.com/haotian-liu/LLaVA/blob/v1.2.2/llava/conversation.py#L325-L335) chat template. If the model is stored in a folder with a custom name instead of the official 'llava-v1.5-7b', you need to specify the template by passing 'llava-v1' to `ChatTemplateConfig` as follows:\n\n```python\nfrom lmdeploy import pipeline, ChatTemplateConfig\nfrom lmdeploy.vl import load_image\n\n# 'local_model_folder' stands for the path of your local model folder\npipe = pipeline('local_model_folder',\n                chat_template_config=ChatTemplateConfig(model_name='llava-v1'))\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nFor more information about customizing a chat template, please refer to [this](../advance/chat_template.md) guide.\n\n### Setting vision model parameters\n\nThe default parameters of the vision model can be modified by passing a `VisionConfig`.\n\n```python\nfrom lmdeploy import pipeline, VisionConfig\nfrom lmdeploy.vl import load_image\n\nvision_config = VisionConfig(max_batch_size=16)\npipe = pipeline('liuhaotian/llava-v1.5-7b', vision_config=vision_config)\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n### Output logits for generated tokens\n\nSetting `output_logits='generation'` in `GenerationConfig` returns the logits of the generated tokens in `response.logits`:\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n\nresponse = pipe(('describe this image', image),\n                gen_config=GenerationConfig(output_logits='generation'))\nlogits = response.logits\nprint(logits)\n```\n\n## Multi-images inference\n\nWhen dealing with multiple images, you can put them all in one list. 
Keep in mind that multiple images produce more input tokens, so the size of the [context window](#set-context-window-size) typically needs to be increased.\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(session_len=8192))\n\nimage_urls = [\n    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg',\n    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg'\n]\n\nimages = [load_image(img_url) for img_url in image_urls]\nresponse = pipe(('describe these images', images))\nprint(response)\n```\n\n## Batch prompts inference\n\nInference with batch prompts is straightforward; just put the prompts in a list:\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(session_len=8192))\n\nimage_urls = [\n    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg',\n    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg'\n]\n# One (prompt, image) tuple per request\nprompts = [('describe this image', load_image(img_url)) for img_url in image_urls]\nresponse = pipe(prompts)\nprint(response)\n```\n\n## Multi-turn conversation\n\nThere are two ways to conduct multi-turn conversations with the pipeline. 
One is to construct messages in the OpenAI format and use the method introduced above; the other is to use the `pipeline.chat` interface.\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(session_len=8192))\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')\ngen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.8)\nsess = pipe.chat(('describe this image', image), gen_config=gen_config)\nprint(sess.response.text)\n# Pass the returned session back in to continue the same conversation\nsess = pipe.chat('What is the woman doing?', session=sess, gen_config=gen_config)\nprint(sess.response.text)\n```\n\n## Release pipeline\n\nYou can release the pipeline explicitly by calling its `close()` method, or alternatively, use the `with` statement as demonstrated below:\n\n```python\nimport gc\n\nimport torch\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\n# The pipeline is released when the `with` block exits\nwith pipeline('OpenGVLab/InternVL2_5-8B') as pipe:\n    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n    response = pipe(('describe this image', image))\n    print(response)\n\n# If needed, run garbage collection first, then release the GPU memory cached by PyTorch\ngc.collect()\ntorch.cuda.empty_cache()\n```\n"
  },
  {
    "path": "docs/en/multi_modal/xcomposer2d5.md",
    "content": "# InternLM-XComposer-2.5\n\n## Introduction\n\n[InternLM-XComposer-2.5](https://github.com/InternLM/InternLM-XComposer) excels in various text-image comprehension and composition applications, achieving GPT-4V level capabilities with merely a 7B LLM backend. IXC-2.5 is trained with 24K interleaved image-text contexts and can seamlessly extend to 96K long contexts via RoPE extrapolation. This long-context capability allows IXC-2.5 to perform exceptionally well in tasks requiring extensive input and output contexts. LMDeploy supports the model [internlm/internlm-xcomposer2d5-7b](https://huggingface.co/internlm/internlm-xcomposer2d5-7b) in the TurboMind engine.\n\n## Quick Start\n\n### Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the other packages that InternLM-XComposer-2.5 needs:\n\n```shell\npip install decord\n```\n\n### Offline inference pipeline\n\nThe following sample code shows the basic usage of the VLM pipeline. For more examples, please refer to [VLM Offline Inference Pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('internlm/internlm-xcomposer2d5-7b')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## LoRA Model\n\nInternLM-XComposer-2.5 provides LoRA weights trained for webpage creation and article writing. As the TurboMind backend does not support S-LoRA, only one LoRA model can be deployed at a time, and the LoRA weights need to be merged into the base model before deployment. 
LMDeploy provides the corresponding conversion script, which is used as follows:\n\n```shell\nexport HF_MODEL=internlm/internlm-xcomposer2d5-7b\nexport WORK_DIR=internlm/internlm-xcomposer2d5-7b-web\nexport TASK=web\npython -m lmdeploy.vl.tools.merge_xcomposer2d5_task $HF_MODEL $WORK_DIR --task $TASK\n```\n\n## Quantization\n\nThe following example quantizes the base model. If you want to use a LoRA model, please merge it first as described in the previous section.\n\n```shell\nexport HF_MODEL=internlm/internlm-xcomposer2d5-7b\nexport WORK_DIR=internlm/internlm-xcomposer2d5-7b-4bit\n\nlmdeploy lite auto_awq \\\n  $HF_MODEL \\\n  --work-dir $WORK_DIR\n```\n\n## More examples\n\n<details>\n  <summary>\n    <b>Video Understanding</b>\n  </summary>\n\nThe following example uses the `pipeline.chat` interface. The other interfaces also support this kind of inference, but require assembling the conversation content manually.\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom transformers.dynamic_module_utils import get_class_from_dynamic_module\n\nHF_MODEL = 'internlm/internlm-xcomposer2d5-7b'\n# Load the video helpers shipped with the model's remote code\nload_video = get_class_from_dynamic_module('ixc_utils.load_video', HF_MODEL)\nframe2img = get_class_from_dynamic_module('ixc_utils.frame2img', HF_MODEL)\nVideo_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', HF_MODEL)\nget_font = get_class_from_dynamic_module('ixc_utils.get_font', HF_MODEL)\n\nvideo = load_video('liuxiang.mp4')  # https://github.com/InternLM/InternLM-XComposer/raw/main/examples/liuxiang.mp4\nimg = frame2img(video, get_font())\nimg = Video_transform(img)\n\npipe = pipeline(HF_MODEL)\ngen_config = GenerationConfig(top_k=50, top_p=0.8, temperature=1.0)\nquery = 'Here are some frames of a video. 
Describe this video in detail'\nsess = pipe.chat((query, img), gen_config=gen_config)\nprint(sess.response.text)\n\nquery = 'tell me the athlete code of Liu Xiang'\nsess = pipe.chat(query, session=sess, gen_config=gen_config)\nprint(sess.response.text)\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>Multi-Image</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\nfrom lmdeploy.vl import load_image\n\nquery = f'Image1 {IMAGE_TOKEN}; Image2 {IMAGE_TOKEN}; Image3 {IMAGE_TOKEN}; I want to buy a car from the three given cars, analyze their advantages and weaknesses one by one'\n\nurls = ['https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars1.jpg',\n        'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars2.jpg',\n        'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars3.jpg']\nimages = [load_image(url) for url in urls]\n\npipe = pipeline('internlm/internlm-xcomposer2d5-7b', log_level='INFO')\noutput = pipe((query, images), gen_config=GenerationConfig(top_k=0, top_p=0.8, random_seed=89247526689433939))\n```\n\nSince LMDeploy does not support beam search, the generated results will be quite different from those using beam search with transformers. It is recommended to turn off top_k or use a larger top_k sampling to increase diversity.\n\n</details>\n\n<details>\n  <summary>\n    <b>Instruction to Webpage</b>\n  </summary>\n\nPlease first convert the web model using the instructions above.\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-web', log_level='INFO')\npipe.chat_template.meta_instruction = None\n\nquery = 'A website for Research institutions. The name is Shanghai AI lab. Top Navigation Bar is blue.Below left, an image shows the logo of the lab. 
In the right, there is a passage of text below that describes the mission of the laboratory.There are several images to show the research projects of Shanghai AI lab.'\noutput = pipe(query, gen_config=GenerationConfig(max_new_tokens=2048))\n```\n\nWhen testing with transformers, we found that if `repetition_penalty` is set and `num_beams` is 1, decoding is very likely to never stop. As LMDeploy does not support beam search, it is recommended to disable `repetition_penalty` when running inference with LMDeploy.\n\n</details>\n\n<details>\n  <summary>\n    <b>Write Article</b>\n  </summary>\n\nPlease first convert the write model using the conversion script above.\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-write', log_level='INFO')\npipe.chat_template.meta_instruction = None\n\nquery = 'Please write a blog based on the title: French Pastries: A Sweet Indulgence'\noutput = pipe(query, gen_config=GenerationConfig(max_new_tokens=8192))\n```\n\n</details>\n"
  },
  {
    "path": "docs/en/quantization/kv_quant.md",
    "content": "# INT4/INT8 KV Cache\n\nSince v0.4.0, LMDeploy has supported **online** key-value (kv) cache quantization with int4 and int8 numerical precision, utilizing an asymmetric quantization method that is applied on a per-head, per-token basis. The original kv offline quantization method has been removed.\n\nIntuitively, quantization increases the number of kv blocks. Compared to fp16, int4/int8 kv stores each element in 4/8 bits instead of 16, so the number of kv blocks can be increased by roughly 4 times and 2 times respectively. This means that, under the same memory budget, the system can serve significantly more concurrent requests after kv quantization, ultimately improving throughput.\n\nHowever, quantization typically incurs some loss of model accuracy. We have used OpenCompass to evaluate the accuracy of several models after applying int4/int8 quantization. int8 kv preserves accuracy, while int4 kv incurs a slight loss. The detailed results are presented in the [Evaluation](#evaluation) section. You can refer to this information to choose the option that best fits your requirements.\n\nLMDeploy inference with quantized kv supports the following NVIDIA GPU models:\n\n- Volta architecture (sm70): V100\n- Turing architecture (sm75): 20 series, T4\n- Ampere architecture (sm80, sm86): 30 series, A10, A16, A30, A100\n- Ada Lovelace architecture (sm89): 40 series\n- Hopper architecture (sm90): H100, H200\n\nIn summary, LMDeploy kv quantization has the following advantages:\n\n1. Data-free online quantization\n2. Supports all NVIDIA GPU models with Volta architecture (sm70) and above\n3. KV int8 quantization has almost lossless accuracy, and KV int4 quantization accuracy is within an acceptable range\n4. 
Efficient inference: with int8/int4 kv quantization applied to llama2-7b, RPS improves by around 30% and 40% respectively compared to fp16\n\nIn the next section, we take the `internlm2_5-7b-chat` model as an example to introduce kv quantization and inference with LMDeploy. Before that, please ensure that LMDeploy is installed.\n\n```shell\npip install lmdeploy\n```\n\n## Usage\n\nApplying kv quantization and inference via LMDeploy is straightforward: simply set the `quant_policy` parameter.\n\n**LMDeploy specifies that `quant_policy=4` stands for 4-bit kv, whereas `quant_policy=8` indicates 8-bit kv.**\n\n### Offline inference\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\n\n# quant_policy=8 enables 8-bit kv cache; use 4 for 4-bit\nengine_config = TurbomindEngineConfig(quant_policy=8)\npipe = pipeline(\"internlm/internlm2_5-7b-chat\", backend_config=engine_config)\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\n### Serving\n\n```shell\nlmdeploy serve api_server internlm/internlm2_5-7b-chat --quant-policy 8\n```\n\n## Evaluation\n\nWe applied LMDeploy's kv quantization to several LLMs and used OpenCompass to evaluate the inference accuracy. 
The results are shown in the table below:\n\n| -           | -       | -             | llama2-7b-chat | -       | -       | internlm2-chat-7b | -       | -       | internlm2.5-chat-7b | -       | -       | qwen1.5-7b-chat | -       | -       |\n| ----------- | ------- | ------------- | -------------- | ------- | ------- | ----------------- | ------- | ------- | ------------------- | ------- | ------- | --------------- | ------- | ------- |\n| dataset     | version | metric        | kv fp16        | kv int8 | kv int4 | kv fp16           | kv int8 | kv int4 | kv fp16             | kv int8 | kv int4 | kv fp16         | kv int8 | kv int4 |\n| ceval       | -       | naive_average | 28.42          | 27.96   | 27.58   | 60.45             | 60.88   | 60.28   | 78.06               | 77.87   | 77.05   | 70.56           | 70.49   | 68.62   |\n| mmlu        | -       | naive_average | 35.64          | 35.58   | 34.79   | 63.91             | 64      | 62.36   | 72.30               | 72.27   | 71.17   | 61.48           | 61.56   | 60.65   |\n| triviaqa    | 2121ce  | score         | 56.09          | 56.13   | 53.71   | 58.73             | 58.7    | 58.18   | 65.09               | 64.87   | 63.28   | 44.62           | 44.77   | 44.04   |\n| gsm8k       | 1d7fe4  | accuracy      | 28.2           | 28.05   | 27.37   | 70.13             | 69.75   | 66.87   | 85.67               | 85.44   | 83.78   | 54.97           | 56.41   | 54.74   |\n| race-middle | 9a54b6  | accuracy      | 41.57          | 41.78   | 41.23   | 88.93             | 88.93   | 88.93   | 92.76               | 92.83   | 92.55   | 87.33           | 87.26   | 86.28   |\n| race-high   | 9a54b6  | accuracy      | 39.65          | 39.77   | 40.77   | 85.33             | 85.31   | 84.62   | 90.51               | 90.42   | 90.42   | 82.53           | 82.59   | 82.02   |\n\nFor detailed evaluation methods, please refer to [this](../benchmark/evaluate_with_opencompass.md) guide. 
Remember to pass `quant_policy` to the inference engine in the config file.\n\n## Performance\n\n| model             | kv type | test settings                            | RPS   | v.s. kv fp16 |\n| ----------------- | ------- | ---------------------------------------- | ----- | ------------ |\n| llama2-chat-7b    | fp16    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 14.98 | 1.0          |\n| -                 | int8    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 19.01 | 1.27         |\n| -                 | int4    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 20.81 | 1.39         |\n| llama2-chat-13b   | fp16    | tp1 / ratio 0.9 / bs 128 / prompts 10000 | 8.55  | 1.0          |\n| -                 | int8    | tp1 / ratio 0.9 / bs 256 / prompts 10000 | 10.96 | 1.28         |\n| -                 | int4    | tp1 / ratio 0.9 / bs 256 / prompts 10000 | 11.91 | 1.39         |\n| internlm2-chat-7b | fp16    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 24.13 | 1.0          |\n| -                 | int8    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 25.28 | 1.05         |\n| -                 | int4    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 25.80 | 1.07         |\n\nThe performance data is obtained by `benchmark/profile_throughput.py`\n"
  },
  {
    "path": "docs/en/quantization/llm_compressor.md",
    "content": "# llm-compressor Support\n\nThis guide introduces how to use LMDeploy's TurboMind inference engine to run models quantized with the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) tool.\n\nCurrently supported `llm-compressor` quantization types include:\n\n- int4 quantization (e.g., AWQ, GPTQ)\n\nThe TurboMind engine can run these quantized models on the following NVIDIA GPU architectures:\n\n| Compute Capability | Micro-architecture | GPUs                            |\n| ------------------ | ------------------ | ------------------------------- |\n| 7.0                | Volta              | V100                            |\n| 7.2                | Volta              | Jetson Xavier                   |\n| 7.5                | Turing             | GeForce RTX 20 series, T4       |\n| 8.0                | Ampere             | A100, A800, A30                 |\n| 8.6                | Ampere             | GeForce RTX 30 series, A40, A10 |\n| 8.7                | Ampere             | Jetson Orin                     |\n| 8.9                | Ada Lovelace       | GeForce RTX 40 series, L40, L20 |\n| 9.0                | Hopper             | H20, H200, H100, GH200          |\n| 12.0               | Blackwell          | GeForce RTX 50 series           |\n\nLMDeploy will continue to track the `llm-compressor` project and expand its support for it.\n\nThe remainder of this document consists of the following sections:\n\n<!-- toc -->\n\n- [Model Quantization](#model-quantization)\n- [Model Deployment](#model-deployment)\n- [Accuracy Evaluation](#accuracy-evaluation)\n\n<!-- tocstop -->\n\n## Model Quantization\n\n`llm-compressor` provides a rich set of model quantization [examples](https://github.com/vllm-project/llm-compressor/tree/main/examples). 
Please refer to its tutorials to choose a quantization algorithm that LMDeploy supports and quantize your model.\n\nFor reference, LMDeploy also provides a [script](https://github.com/InternLM/lmdeploy/blob/main/examples/lite/qwen3_30b_a3b_awq.py) that uses `llm-compressor` to apply AWQ quantization to **Qwen3-30B-A3B**:\n\n```shell\n# Create conda environment\nconda create -n lmdeploy python=3.10 -y\nconda activate lmdeploy\n\n# Install llm-compressor\npip install llmcompressor\n\n# Clone lmdeploy source code and run the quantization example\ngit clone https://github.com/InternLM/lmdeploy\ncd lmdeploy\npython examples/lite/qwen3_30b_a3b_awq.py --work-dir ./qwen3_30b_a3b_awq\n```\n\nThe following sections use this quantized model as an example to demonstrate model deployment and accuracy evaluation.\n\n## Model Deployment\n\n### Offline Inference\n\nWith the quantized model, offline batch inference takes only a few lines of code:\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\n\nengine_config = TurbomindEngineConfig()\nwith pipeline(\"./qwen3_30b_a3b_awq\", backend_config=engine_config) as pipe:\n    response = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\n    print(response)\n```\n\nFor a detailed introduction to the pipeline, please refer to the [pipeline guide](https://lmdeploy.readthedocs.io/en/latest/llm/pipeline.html).\n\n### Online Serving\n\nLMDeploy's `api_server` can package the model as a service with a single command. The provided RESTful APIs are compatible with OpenAI interfaces. Below is an example of starting the service:\n\n```shell\nlmdeploy serve api_server ./qwen3_30b_a3b_awq --backend turbomind\n```\n\nThe default service port is 23333. After the server starts, you can access the service via the OpenAI SDK. 
For command-line arguments and ways to access the service, please read [this](https://lmdeploy.readthedocs.io/en/latest/llm/api_server.html) document.\n\n## Accuracy Evaluation\n\nAfter deploying AWQ symmetric and asymmetric quantized versions of Qwen3-8B (dense) and Qwen3-30B-A3B (MoE) as services with LMDeploy, we evaluated their accuracy on several academic datasets using [opencompass](https://github.com/open-compass/opencompass). The results indicate that for Qwen3-8B, asymmetric quantization generally outperforms symmetric quantization, while Qwen3-30B-A3B shows no substantial difference between the two. Relative to BF16, Qwen3-8B shows a smaller accuracy gap than Qwen3-30B-A3B under both symmetric and asymmetric quantization. Accuracy drops significantly relative to BF16 on long-output datasets such as aime2025 (avg 17,635 tokens) and LCB (avg 14,157 tokens), while on medium- and short-output datasets such as ifeval (avg 1,885 tokens) and mmlu_pro (avg 2,826 tokens), the drop is small (at most about 2.2 points).\n\n| dataset           | Qwen3-8B |         |          | Qwen3-30B-A3B |         |          |\n| ----------------- | -------- | ------- | -------- | ------------- | ------- | -------- |\n|                   | bf16     | awq sym | awq asym | bf16          | awq sym | awq asym |\n| ifeval            | 85.58    | 83.73   | 85.77    | 86.32         | 84.10   | 84.29    |\n| hle               | 5.05     | 5.05    | 5.24     | 7.00          | 5.47    | 5.65     |\n| gpqa              | 59.97    | 56.57   | 59.47    | 61.74         | 57.95   | 57.07    |\n| aime2025          | 69.48    | 64.38   | 63.96    | 73.44         | 64.79   | 66.67    |\n| mmlu_pro          | 73.69    | 71.73   | 72.34    | 77.85         | 75.77   | 75.69    |\n| LCBCodeGeneration | 50.86    | 44.10   | 46.95    | 56.67         | 50.86   | 49.24    |\n\nFor reproduction methods, please refer to 
[this](https://lmdeploy.readthedocs.io/en/latest/benchmark/evaluate_with_opencompass.html) document.\n"
  },
  {
    "path": "docs/en/quantization/w4a16.md",
    "content": "# AWQ/GPTQ\n\nThe LMDeploy TurboMind engine supports inference of 4-bit models quantized with either [AWQ](https://arxiv.org/abs/2306.00978) or [GPTQ](https://github.com/AutoGPTQ/AutoGPTQ), but its quantization module only supports the AWQ algorithm.\n\nThe following NVIDIA GPUs support AWQ/GPTQ INT4 inference:\n\n- V100(sm70): V100\n- Turing(sm75): 20 series, T4\n- Ampere(sm80,sm86): 30 series, A10, A16, A30, A100\n- Ada Lovelace(sm89): 40 series\n\nBefore proceeding with quantization and inference, please make sure lmdeploy is installed by following the [installation guide](../get_started/installation.md).\n\nThe remainder of this article is structured into the following sections:\n\n<!-- toc -->\n\n- [Quantization](#quantization)\n- [Evaluation](#evaluation)\n- [Inference](#inference)\n- [Service](#service)\n- [Performance](#performance)\n\n<!-- tocstop -->\n\n## Quantization\n\nA single command is all it takes to quantize the model. The resulting quantized weights are stored in the $WORK_DIR directory.\n\n```shell\nexport HF_MODEL=internlm/internlm2_5-7b-chat\nexport WORK_DIR=internlm/internlm2_5-7b-chat-4bit\n\nlmdeploy lite auto_awq \\\n  $HF_MODEL \\\n  --calib-dataset 'wikitext2' \\\n  --calib-samples 128 \\\n  --calib-seqlen 2048 \\\n  --w-bits 4 \\\n  --w-group-size 128 \\\n  --batch-size 1 \\\n  --work-dir $WORK_DIR\n```\n\nTypically, the optional parameters in the above command can be omitted, as the defaults usually suffice. For instance, when quantizing the [internlm/internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) model, the command can be condensed to:\n\n```shell\nlmdeploy lite auto_awq internlm/internlm2_5-7b-chat --work-dir internlm2_5-7b-chat-4bit\n```\n\n**Note:**\n\n- We recommend specifying the --work-dir parameter and including the model name in it, as demonstrated in the example above. 
This lets LMDeploy fuzzy-match the --work-dir name to an appropriate built-in chat template. Otherwise, you will have to specify the chat template during inference.\n- If the quantized model’s accuracy is degraded, we recommend enabling --search-scale for re-quantization and increasing --batch-size, for example to 8. With --search-scale enabled, quantization takes longer. --batch-size affects memory usage and can be adjusted to suit the available GPU memory.\n\nOnce quantization is complete, you can interact with the model using a variety of tools.\nFor example, you can start a conversation with it from the command line:\n\n```shell\nlmdeploy chat ./internlm2_5-7b-chat-4bit --model-format awq\n```\n\n## Evaluation\n\nPlease refer to the [OpenCompass](https://opencompass.readthedocs.io/en/latest/index.html) documentation for model evaluation with LMDeploy. Here is the [guide](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lmdeploy.html).\n\n## Inference\n\nWith the following code, you can perform batched offline inference with the quantized model:\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nengine_config = TurbomindEngineConfig(model_format='awq')\npipe = pipeline(\"./internlm2_5-7b-chat-4bit\", backend_config=engine_config)\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\nFor more information about the pipeline parameters, please refer to the [pipeline guide](../llm/pipeline.md).\n\nIn addition to running a locally quantized model, LMDeploy can also run inference on 4-bit AWQ models hosted on the Hugging Face Hub, such as models from the [lmdeploy space](https://huggingface.co/lmdeploy) and [TheBloke space](https://huggingface.co/TheBloke):\n\n```python\n# inference with models from lmdeploy space\nfrom lmdeploy import pipeline, 
TurbomindEngineConfig\npipe = pipeline(\"lmdeploy/llama2-chat-70b-4bit\",\n                backend_config=TurbomindEngineConfig(model_format='awq', tp=4))\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n\n# inference with models from thebloke space\nfrom lmdeploy import pipeline, TurbomindEngineConfig, ChatTemplateConfig\npipe = pipeline(\"TheBloke/LLaMA2-13B-Tiefighter-AWQ\",\n                backend_config=TurbomindEngineConfig(model_format='awq'),\n                chat_template_config=ChatTemplateConfig(model_name='llama2')\n                )\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\n## Service\n\nLMDeploy's `api_server` can package a model as a service with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of starting the service:\n\n```shell\nlmdeploy serve api_server ./internlm2_5-7b-chat-4bit --backend turbomind --model-format awq\n```\n\nThe default port of `api_server` is `23333`. After the server is launched, you can communicate with it from the terminal through `api_client`:\n\n```shell\nlmdeploy serve api_client http://0.0.0.0:23333\n```\n\nYou can browse and try out the `api_server` APIs online through the Swagger UI at `http://0.0.0.0:23333`, or read the API specification [here](../llm/api_server.md).\n\n## Performance\n\nWe benchmarked the 4-bit quantized Llama-2-7B-chat and Llama-2-13B-chat models on an NVIDIA GeForce RTX 4090, measuring token generation throughput (tokens/s) with a single prompt token and 512 generated tokens. All results are for single-batch inference.\n\n| model            | llm-awq | mlc-llm | turbomind |\n| ---------------- | ------- | ------- | --------- |\n| Llama-2-7B-chat  | 112.9   | 159.4   | 206.4     |\n| Llama-2-13B-chat | N/A     | 90.7    | 115.8     |\n\n## FAQs\n\n1. 
Out-of-memory error during quantization due to insufficient GPU memory: this can be addressed by reducing `--calib-seqlen`, increasing `--calib-samples`, and setting `--batch-size` to 1.\n"
  },
  {
    "path": "docs/en/quantization/w8a8.md",
    "content": "# SmoothQuant\n\nLMDeploy provides functions for quantization and inference of large language models using 8-bit integers (INT8). On GPUs such as the NVIDIA H100, LMDeploy also supports 8-bit floating point (FP8).\n\nThe following NVIDIA GPUs support INT8 and FP8 inference, respectively:\n\n- INT8\n  - V100(sm70): V100\n  - Turing(sm75): 20 series, T4\n  - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100\n  - Ada Lovelace(sm89): 40 series\n  - Hopper(sm90): H100\n- FP8\n  - Ada Lovelace(sm89): 40 series\n  - Hopper(sm90): H100\n\nFirst, run the following command to install lmdeploy:\n\n```shell\npip install lmdeploy[all]\n```\n\n## 8-bit Weight Quantization\n\nPerforming 8-bit weight quantization involves three steps:\n\n1. **Smooth Weights**: Start by smoothing the weights of the large language model (LLM). This makes the weights more amenable to quantization.\n2. **Replace Modules**: Locate the DecoderLayers and replace the RMSNorm and nn.Linear modules with QRMSNorm and QLinear modules, respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file.\n3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.\n\nLMDeploy provides the `lmdeploy lite smooth_quant` command to accomplish all three steps above. The `--quant-dtype` argument determines whether int8 or fp8 weight quantization is performed. 
For more information about the CLI usage, run `lmdeploy lite smooth_quant --help`.\n\nHere are two examples:\n\n- int8\n\n  ```shell\n  lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8\n  ```\n\n- fp8\n\n  ```shell\n  lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8\n  ```\n\n## Inference\n\nWith the following code, you can perform batched offline inference with the quantized model:\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nengine_config = PytorchEngineConfig(tp=1)\npipe = pipeline(\"./internlm2_5-7b-chat-int8\", backend_config=engine_config)\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\n## Service\n\nLMDeploy's `api_server` can package a model as a service with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of starting the service:\n\n```shell\nlmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch\n```\n\nThe default port of `api_server` is `23333`. After the server is launched, you can communicate with it from the terminal through `api_client`:\n\n```shell\nlmdeploy serve api_client http://0.0.0.0:23333\n```\n\nYou can browse and try out the `api_server` APIs online through the Swagger UI at `http://0.0.0.0:23333`, or read the API specification [here](../llm/api_server.md).\n"
  },
  {
    "path": "docs/en/supported_models/reward_models.md",
    "content": "# Reward Models\n\nLMDeploy supports the reward models listed in the table below:\n\n|      Model       |     Size      | Supported Inference Engine |\n| :--------------: | :-----------: | :------------------------: |\n| Qwen2.5-Math-RM  |      72B      |          PyTorch           |\n| InternLM2-Reward | 1.8B, 7B, 20B |          PyTorch           |\n|      POLAR       |   1.8B, 7B    |          PyTorch           |\n\n## Offline Inference\n\nThe following example uses `internlm/internlm2-1_8b-reward`:\n\n```python\nfrom transformers import AutoTokenizer\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nmodel_path = \"internlm/internlm2-1_8b-reward\"\nchat = [\n    {\"role\": \"system\", \"content\": \"Please reason step by step, and put your final answer within \\\\boxed{}.\"},\n    {\"role\": \"user\", \"content\": \"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\"},\n    {\"role\": \"assistant\", \"content\": \"To determine how much Janet makes from selling the duck eggs at the farmers' market, we need to follow these steps:\\n\\n1. Calculate the total number of eggs laid by the ducks each day.\\n2. Determine how many eggs Janet eats and bakes for herself each day.\\n3. Find out how many eggs are left to be sold.\\n4. Calculate the revenue from selling the remaining eggs at $2 per egg.\\n\\nLet's start with the first step:\\n\\n1. Janet's ducks lay 16 eggs per day.\\n\\nNext, we calculate how many eggs Janet eats and bakes for herself each day:\\n\\n2. Janet eats 3 eggs for breakfast every morning.\\n3. 
Janet bakes 4 eggs for her friends every day.\\n\\nSo, the total number of eggs Janet eats and bakes for herself each day is:\\n\\\\[ 3 + 4 = 7 \\\\text{ eggs} \\\\]\\n\\nNow, we find out how many eggs are left to be sold:\\n\\\\[ 16 - 7 = 9 \\\\text{ eggs} \\\\]\\n\\nFinally, we calculate the revenue from selling the remaining eggs at $2 per egg:\\n\\\\[ 9 \\\\times 2 = 18 \\\\text{ dollars} \\\\]\\n\\nTherefore, Janet makes 18 dollars every day at the farmers' market.\"}\n]\n\ntokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n\nconversation_str = tokenizer.apply_chat_template(\n    chat,\n    tokenize=False,\n    add_generation_prompt=False\n)\n\ninput_ids = tokenizer.encode(\n    conversation_str,\n    add_special_tokens=False\n)\n\n\nif __name__ == '__main__':\n    engine_config = PytorchEngineConfig(tp=1)\n    with pipeline(model_path, backend_config=engine_config) as pipe:\n        score = pipe.get_reward_score(input_ids)\n        print(f'score: {score}')\n```\n\n## Online Inference\n\nStart the API server:\n\n```bash\nlmdeploy serve api_server internlm/internlm2-1_8b-reward --backend pytorch\n```\n\nGet the reward score from the `/pooling` API endpoint:\n\n```bash\ncurl http://0.0.0.0:23333/pooling \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"internlm/internlm2-1_8b-reward\",\n    \"input\": \"Who are you?\"\n  }'\n```\n"
  },
  {
    "path": "docs/en/supported_models/supported_models.md",
    "content": "# Supported Models\n\nThe following tables detail the models supported by LMDeploy's TurboMind engine and PyTorch engine across different platforms.\n\n## TurboMind on CUDA Platform\n\n|              Model               |       Size       | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |\n| :------------------------------: | :--------------: | :--: | :-------: | :-----: | :-----: | :---: |\n|              Llama               |     7B - 65B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|              Llama2              |     7B - 70B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|              Llama3              |     8B, 70B      | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|             Llama3.1             |     8B, 70B      | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|     Llama3.2<sup>\\[2\\]</sup>     |      1B, 3B      | LLM  |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|             InternLM             |     7B - 20B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            InternLM2             |     7B - 20B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|           InternLM2.5            |        7B        | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            InternLM3             |        8B        | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|       InternLM-XComposer2        |   7B, 4khd-7B    | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|      InternLM-XComposer2.5       |        7B        | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|            Intern-S1             |       241B       | MLLM |    Yes    |   Yes   |   Yes   |  No   |\n|          Intern-S1-mini          |       8.3B       | MLLM |    Yes    |   Yes   |   Yes   |  No   |\n|               Qwen               |    1.8B - 72B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|     Qwen1.5<sup>\\[1\\]</sup>      |   1.8B - 110B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|      
Qwen2<sup>\\[2\\]</sup>       |    0.5B - 72B    | LLM  |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|            Qwen2-MoE             |     57BA14B      | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|     Qwen2.5<sup>\\[2\\]</sup>      |    0.5B - 72B    | LLM  |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|              Qwen3               |    0.6B-235B     | LLM  |    Yes    |   Yes   |  Yes\\*  | Yes\\* |\n|     Qwen3.5<sup>\\[3\\]</sup>      |    0.8B-397B     | MLLM |    Yes    |   Yes   |   No    |  Yes  |\n|     Mistral<sup>\\[1\\]</sup>      |        7B        | LLM  |    Yes    |   Yes   |   Yes   |  No   |\n|             Mixtral              |   8x7B, 8x22B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|           DeepSeek-V2            |    16B, 236B     | LLM  |    Yes    |   Yes   |   Yes   |  No   |\n|          DeepSeek-V2.5           |       236B       | LLM  |    Yes    |   Yes   |   Yes   |  No   |\n|             Qwen-VL              |        7B        | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|           DeepSeek-VL            |        7B        | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|             Baichuan             |        7B        | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            Baichuan2             |        7B        | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            Code Llama            |     7B - 34B     | LLM  |    Yes    |   Yes   |   Yes   |  No   |\n|                YI                |     6B - 34B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|          LLaVA(1.5,1.6)          |     7B - 34B     | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|             InternVL             |   v1.1 - v1.5    | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|    InternVL2<sup>\\[2\\]</sup>     | 1 - 2B, 8B - 76B | MLLM |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n| InternVL2.5(MPO)<sup>\\[2\\]</sup> |     1 - 78B      | MLLM |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|    
InternVL3<sup>\\[2\\]</sup>     |     1 - 78B      | MLLM |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|   InternVL3.5<sup>\\[3\\]</sup>    |   1 - 241BA28B   | MLLM |    Yes    |  Yes\\*  |  Yes\\*  |  No   |\n|             ChemVLM              |     8B - 26B     | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|       MiniCPM-Llama3-V-2_5       |        -         | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|          MiniCPM-V-2_6           |        -         | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|               GLM4               |        9B        | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            CodeGeeX4             |        9B        | LLM  |    Yes    |   Yes   |   Yes   |   -   |\n|              Molmo               |     7B-D,72B     | MLLM |    Yes    |   Yes   |   Yes   |  No   |\n|             gpt-oss              |     20B,120B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n\n\"-\" means not verified yet.\n\n```{note}\n* [1] The TurboMind engine doesn't support window attention. 
Therefore, for models that use window attention and have the corresponding \"use_sliding_window\" switch enabled, such as Mistral and Qwen1.5, please use the PyTorch engine for inference.\n* [2] When the head_dim of a model is not 128, as with llama3.2-1B, qwen2-0.5B and internvl2-1B, TurboMind does not support 4/8-bit KV cache quantization and inference for that model.\n* [3] TurboMind does not currently support the vision encoder for the Qwen3.5 series.\n```\n\n## PyTorchEngine on CUDA Platform\n\n|             Model              |      Size       | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |\n| :----------------------------: | :-------------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |\n|             Llama              |    7B - 65B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|             Llama2             |    7B - 70B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|             Llama3             |     8B, 70B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            Llama3.1            |     8B, 70B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            Llama3.2            |     1B, 3B      | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|             Llama4             | Scout, Maverick | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|            InternLM            |    7B - 20B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|           InternLM2            |    7B - 20B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|          InternLM2.5           |       7B        | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|           InternLM3            |       8B        | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|           Intern-S1            |      241B       | MLLM |    Yes    |   Yes   |   Yes   | Yes  |   -   |\n|         Intern-S1-mini         |      8.3B       | MLLM |    Yes    |   Yes   |   
Yes   | Yes  |   -   |\n|         Intern-S1-Pro          |       1TB       | MLLM |    Yes    |    -    |    -    |  -   |  No   |\n|           Baichuan2            |       7B        | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  No   |\n|           Baichuan2            |       13B       | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|            ChatGLM2            |       6B        | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|               YI               |    6B - 34B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            Mistral             |       7B        | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            Mixtral             |   8x7B, 8x22B   | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|              QWen              |   1.8B - 72B    | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            QWen1.5             |   0.5B - 110B   | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|          QWen1.5-MoE           |      A2.7B      | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|             QWen2              |   0.5B - 72B    | LLM  |    Yes    |   Yes   |   No    | Yes  |  Yes  |\n|            Qwen2.5             |   0.5B - 72B    | LLM  |    Yes    |   Yes   |   No    | Yes  |  Yes  |\n|             Qwen3              |   0.6B - 235B   | LLM  |    Yes    |   Yes   |  Yes\\*  |  -   | Yes\\* |\n|           QWen3-Next           |       80B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|            QWen2-VL            |     2B, 7B      | MLLM |    Yes    |   Yes   |   No    |  No  |  Yes  |\n|           QWen2.5-VL           |    3B - 72B     | MLLM |    Yes    |   No    |   No    |  No  |  No   |\n|            QWen3-VL            |    2B - 235B    | MLLM |    Yes    |   No    |   No    |  No  |  No   |\n|            QWen3.5             |    0.8B-397B    | MLLM |    Yes    |   No    |   No    |  No  |  No   |\n|          DeepSeek-MoE       
   |       16B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|          DeepSeek-V2           |    16B, 236B    | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|         DeepSeek-V2.5          |      236B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|          DeepSeek-V3           |      685B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|         DeepSeek-V3.2          |      685B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|          DeepSeek-VL2          |    3B - 27B     | MLLM |    Yes    |   No    |   No    |  No  |  No   |\n|            MiniCPM3            |       4B        | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|         MiniCPM-V-2_6          |       8B        | LLM  |    Yes    |   No    |   No    |  No  |  Yes  |\n|             Gemma              |      2B-7B      | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|           StarCoder2           |     3B-15B      | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|           Phi-3-mini           |      3.8B       | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|          Phi-3-vision          |      4.2B       | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|           Phi-4-mini           |      3.8B       | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|          CogVLM-Chat           |       17B       | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|          CogVLM2-Chat          |       19B       | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n| LLaVA(1.5,1.6)<sup>\\[2\\]</sup> |     7B-34B      | MLLM |    No     |   No    |   No    |  No  |  No   |\n|         InternVL(v1.5)         |     2B-26B      | MLLM |    Yes    |   Yes   |   Yes   |  No  |  Yes  |\n|           InternVL2            |     1B-76B      | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|        InternVL2.5(MPO)        |     1B-78B      | MLLM |    Yes    |   Yes   |   
Yes   |  -   |   -   |\n|           InternVL3            |     1B-78B      | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|          InternVL3.5           |   1B-241BA28B   | MLLM |    Yes    |   Yes   |   Yes   |  No  |  No   |\n| Mono-InternVL<sup>\\[1\\]</sup>  |       2B        | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|            ChemVLM             |     8B-26B      | MLLM |    Yes    |   Yes   |   No    |  -   |   -   |\n|             Gemma2             |     9B-27B      | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|             Gemma3             |     1B-27B      | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|             GLM-4              |       9B        | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|           GLM-4-0414           |       9B        | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|             GLM-4V             |       9B        | MLLM |    Yes    |   Yes   |   Yes   |  No  |  Yes  |\n|       GLM-4.1V-Thinking        |       9B        | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|            GLM-4.5             |      355B       | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|          GLM-4.5-Air           |      106B       | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|           CodeGeeX4            |       9B        | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|          Phi-3.5-mini          |      3.8B       | LLM  |    Yes    |   Yes   |   No    |  -   |   -   |\n|          Phi-3.5-MoE           |     16x3.8B     | LLM  |    Yes    |   Yes   |   No    |  -   |   -   |\n|         Phi-3.5-vision         |      4.2B       | MLLM |    Yes    |   Yes   |   No    |  -   |   -   |\n|              SDAR              |    1.7B-30B     | LLM  |    Yes    |   Yes   |   No    |  -   |   -   |\n|         GLM-4.7-Flash          |       30B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|             GLM-5           
   |      754B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n\n```{note}\n* [1] Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.\n* [2] The PyTorch engine dropped support for the original LLaVA models after v0.6.4. Please use the corresponding transformers models instead, which can be found at https://huggingface.co/llava-hf.\n* Starting from version 0.11.1, the PyTorch engine no longer supports mllama.\n```\n\n## PyTorchEngine on Other Platforms\n\n|                |           |      |  Atlas 800T A2   |  Atlas 800T A2   | Atlas 800T A2 | Atlas 800T A2 | Atlas 300I Duo |  Atlas 800T A3   | Maca C500 | Cambricon |\n| :------------: | :-------: | :--: | :--------------: | :--------------: | :-----------: | :-----------: | :------------: | :--------------: | :-------: | :-------: |\n|     Model      |   Size    | Type | FP16/BF16(eager) | FP16/BF16(graph) |  W8A8(graph)  | W4A16(eager)  |  FP16(graph)   | FP16/BF16(eager) |  BF/FP16  |  BF/FP16  |\n|     Llama2     | 7B - 70B  | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |       -        |       Yes        |    Yes    |    Yes    |\n|     Llama3     |    8B     | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|    Llama3.1    |    8B     | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|   InternLM2    | 7B - 20B  | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|  InternLM2.5   | 7B - 20B  | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|   InternLM3    |    8B     | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       
|       Yes        |    Yes    |    Yes    |\n|    Mixtral     |   8x7B    | LLM  |       Yes        |       Yes        |      No       |      No       |      Yes       |        -         |    Yes    |    Yes    |\n|  QWen1.5-MoE   |   A2.7B   | LLM  |       Yes        |        -         |      No       |      No       |       -        |        -         |    Yes    |     -     |\n|   QWen2(.5)    |    7B     | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |        -         |    Yes    |    Yes    |\n|    QWen2-VL    |  2B, 7B   | MLLM |       Yes        |       Yes        |       -       |       -       |       -        |        -         |    Yes    |    No     |\n|   QWen2.5-VL   | 3B - 72B  | MLLM |       Yes        |       Yes        |       -       |       -       |      Yes       |        -         |    Yes    |    No     |\n|   QWen2-MoE    |  A14.57B  | LLM  |       Yes        |        -         |      No       |      No       |       -        |        -         |    Yes    |     -     |\n|     QWen3      | 0.6B-235B | LLM  |       Yes        |       Yes        |      No       |      No       |      Yes       |       Yes        |    Yes    |    Yes    |\n|  DeepSeek-V2   |    16B    | LLM  |        No        |       Yes        |      No       |      No       |       -        |        -         |     -     |     -     |\n| InternVL(v1.5) |  2B-26B   | MLLM |       Yes        |        -         |      Yes      |      Yes      |       -        |        -         |    Yes    |     -     |\n|   InternVL2    |  1B-40B   | MLLM |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |        -         |    Yes    |    Yes    |\n|  InternVL2.5   |  1B-78B   | MLLM |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |        -         |    Yes    |    Yes    |\n|   InternVL3    |  1B-78B   | MLLM |       Yes        |       Yes        |      Yes      |    
  Yes      |      Yes       |        -         |    Yes    |    Yes    |\n|  CogVLM2-chat  |    19B    | MLLM |       Yes        |        No        |       -       |       -       |       -        |        -         |    Yes    |     -     |\n|     GLM4V      |    9B     | MLLM |       Yes        |        No        |       -       |       -       |       -        |        -         |     -     |     -     |\n"
  },
  {
    "path": "docs/zh_cn/.readthedocs.yaml",
    "content": "version: 2\n\nformats: all\n\nbuild:\n  os: \"ubuntu-22.04\"\n  tools:\n    python: \"3.10\"\n\n\nsphinx:\n  configuration: docs/zh_cn/conf.py\n\n\npython:\n  install:\n    - requirements: requirements/docs.txt\n    - requirements: requirements/readthedocs.txt\n"
  },
  {
    "path": "docs/zh_cn/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/zh_cn/_static/css/readthedocs.css",
    "content": "table.autosummary td {\n  width: 50%\n}\n\nimg.align-center {\n  display: block;\n  margin-left: auto;\n  margin-right: auto;\n}\n"
  },
  {
    "path": "docs/zh_cn/advance/chat_template.md",
    "content": "# 自定义对话模板\n\n被应用的对话模板效果，可以通过设置日志等级为`INFO`进行观测。\n\nLMDeploy 支持两种添加对话模板的形式：\n\n- 一种是利用现有对话模板，直接配置一个如下的 json 文件使用。\n\n  ```json\n  {\n      \"model_name\": \"your awesome chat template name\",\n      \"system\": \"<|im_start|>system\\n\",\n      \"meta_instruction\": \"You are a robot developed by LMDeploy.\",\n      \"eosys\": \"<|im_end|>\\n\",\n      \"user\": \"<|im_start|>user\\n\",\n      \"eoh\": \"<|im_end|>\\n\",\n      \"assistant\": \"<|im_start|>assistant\\n\",\n      \"eoa\": \"<|im_end|>\",\n      \"separator\": \"\\n\",\n      \"capability\": \"chat\",\n      \"stop_words\": [\"<|im_end|>\"]\n  }\n  ```\n\n  这样一个模板将会以下面的形式进行拼接。\n\n  ```\n  {system}{meta_instruction}{eosys}{user}{user_content}{eoh}{assistant}{assistant_content}{eoa}{separator}{user}...\n  ```\n\n  在使用 CLI 工具时，可以通过 `--chat-template` 传入自定义对话模板，比如：\n\n  ```shell\n  lmdeploy serve api_server internlm/internlm2_5-7b-chat --chat-template ${JSON_FILE}\n  ```\n\n  也可以在通过接口函数传入，比如：\n\n  ```python\n  from lmdeploy import ChatTemplateConfig, serve\n\n  serve('internlm/internlm2_5-7b-chat',\n        chat_template_config=ChatTemplateConfig.from_json('${JSON_FILE}'))\n  ```\n\n- 一种是以 LMDeploy 现有对话模板，自定义一个python对话模板类，注册成功后直接用即可。优点是自定义程度高，可控性强。\n  下面是一个注册 LMDeploy 对话模板的例子：\n\n  ```python\n  from lmdeploy.model import MODELS, BaseChatTemplate\n\n\n  @MODELS.register_module(name='customized_model')\n  class CustomizedModel(BaseChatTemplate):\n      \"\"\"A customized chat template.\"\"\"\n\n      def __init__(self,\n                   system='<|im_start|>system\\n',\n                   meta_instruction='You are a robot developed by LMDeploy.',\n                   user='<|im_start|>user\\n',\n                   assistant='<|im_start|>assistant\\n',\n                   eosys='<|im_end|>\\n',\n                   eoh='<|im_end|>\\n',\n                   eoa='<|im_end|>',\n                   separator='\\n',\n                   stop_words=['<|im_end|>', '<|action_end|>']):\n          
super().__init__(system=system,\n                           meta_instruction=meta_instruction,\n                           eosys=eosys,\n                           user=user,\n                           eoh=eoh,\n                           assistant=assistant,\n                           eoa=eoa,\n                           separator=separator,\n                           stop_words=stop_words)\n\n\n  from lmdeploy import ChatTemplateConfig, pipeline\n\n  messages = [{'role': 'user', 'content': 'who are you?'}]\n  pipe = pipeline('internlm/internlm2_5-7b-chat',\n                  chat_template_config=ChatTemplateConfig('customized_model'))\n  for response in pipe.stream_infer(messages):\n      print(response.text, end='')\n  ```\n\n  在这个例子中，我们注册了一个 LMDeploy 的对话模板，该模板将模型设置为由 LMDeploy 创造，所以当用户提问模型是谁的时候，模型就会回答由 LMDeploy 所创。\n"
  },
  {
    "path": "docs/zh_cn/advance/context_parallel.md",
    "content": "# 序列并行\n\n在单卡显存不足以部署模型的时候，通常会以 `TP` 的方式进行部署，而这一般要求 `num_key_value_heads` 被 `TP` 整除。如果要以 `TP > num_key_value_heads` 的方式进行部署，需要创建 kv-heads 的副本，以满足整除需求。但是这样会有两个缺点：\n\n1. 可用的 kvcache 数量减半，进而减少请求最大推理长度\n2. 降低推理的最大 batch 数量，减少吞吐量。\n\n为了解决这个问题，TurboMind 推理后端支持设置 `attn_dp_size`，避免了创建 kv-heads 的副本，但是这会引入数据的不均衡性。为了消除数据的不均衡，TurboMind 支持了序列并行，支持将 kv_cache 交错存储到不同的 cp_rank 上，例如\n\n```\ncp_rank=2, prompt_len=5, generation_len=4\nkv_cache stored on cp_rank0: 0, 2, 4, 6, 8\nkv_cache stored on cp_rank1: 1, 3, 5, 7\n```\n\n## 使用说明\n\n以 `Intern-S1` / `Qwen3-235B-A22B` 为例，他们的 `num_key_value_heads` 为 4，若要用 `TP=8` 的方式部署，并避免 kv_cache 的拷贝，可以用如下的方式部署\n\n```\nlmdeploy serve api_server internlm/Intern-S1 --tp 8 --cp 2\n\nlmdeploy serve api_server Qwen/Qwen3-235B-A22B --tp 8 --cp 2\n```\n"
  },
  {
    "path": "docs/zh_cn/advance/debug_turbomind.md",
    "content": "# 如何调试 Turbomind\n\nTurbomind 使用 C++ 实现，不像 Python 一样易于调试。该文档提供了调试 Turbomind 的基本方法。\n\n## 前置工作\n\n首先，根据构建[命令](../get_started/installation.md)完成源码编译和安装。\n\n## 配置 Python 调试环境\n\n由于目前许多大公司在线上生产环境中使用 Centos 7，我们将以 Centos 7 为例来说明配置过程。\n\n### 获取 `glibc` 和 `python3` 的版本\n\n```bash\nrpm -qa | grep glibc\nrpm -qa | grep python3\n```\n\n结果类似于这样：\n\n```\n[username@hostname workdir]# rpm -qa | grep glibc\nglibc-2.17-325.el7_9.x86_64\nglibc-common-2.17-325.el7_9.x86_64\nglibc-headers-2.17-325.el7_9.x86_64\nglibc-devel-2.17-325.el7_9.x86_64\n\n[username@hostname workdir]# rpm -qa | grep python3\npython3-pip-9.0.3-8.el7.noarch\npython3-rpm-macros-3-34.el7.noarch\npython3-rpm-generators-6-2.el7.noarch\npython3-setuptools-39.2.0-10.el7.noarch\npython3-3.6.8-21.el7_9.x86_64\npython3-devel-3.6.8-21.el7_9.x86_64\npython3.6.4-sre-1.el6.x86_64\n```\n\n根据上述信息，我们可以看到 `glibc` 的版本是 `2.17-325.el7_9.x86_64`，`python3` 的版本是 `3.6.8-21.el7_9.x86_64`。\n\n### 下载并安装 `debuginfo` 库\n\n从 http://debuginfo.centos.org/7/x86_64 下载 `glibc-debuginfo-common-2.17-325.el7.x86_64.rpm`、`glibc-debuginfo-2.17-325.el7.x86_64.rpm` 和 `python3-debuginfo-3.6.8-21.el7.x86_64.rpm`。\n\n```bash\nrpm -ivh glibc-debuginfo-common-2.17-325.el7.x86_64.rpm\nrpm -ivh glibc-debuginfo-2.17-325.el7.x86_64.rpm\nrpm -ivh python3-debuginfo-3.6.8-21.el7.x86_64.rpm\n```\n\n### 升级 GDB\n\n```bash\nsudo yum install devtoolset-10 -y\necho \"source scl_source enable devtoolset-10\" >> ~/.bashrc\nsource ~/.bashrc\n```\n\n### 验证\n\n```bash\ngdb python3\n```\n\n输出类似于这样：\n\n```\n[username@hostname workdir]# gdb python3\nGNU gdb (GDB) Red Hat Enterprise Linux 9.2-10.el7\nCopyright (C) 2020 Free Software Foundation, Inc.\nLicense GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>\nThis is free software: you are free to change and redistribute it.\nThere is NO WARRANTY, to the extent permitted by law.\nType \"show copying\" and \"show warranty\" for details.\nThis GDB was configured as 
\"x86_64-redhat-linux-gnu\".\nType \"show configuration\" for configuration details.\nFor bug reporting instructions, please see:\n<http://www.gnu.org/software/gdb/bugs/>.\nFind the GDB manual and other documentation resources online at:\n   <http://www.gnu.org/software/gdb/documentation/>.\n\nFor help, type \"help\".\nType \"apropos word\" to search for commands related to \"word\"...\nReading symbols from python3...\n(gdb)\n```\n\n如果显示 `Reading symbols from python3`，说明配置成功。\n\n对于其他操作系统，请参考 [DebuggingWithGdb](https://wiki.python.org/moin/DebuggingWithGdb)。\n\n## 设置符号链接\n\n设置符号链接后，不需要每次都通过 `pip` 进行本地安装。\n\n```bash\n# 更改目录到 lmdeploy，例如\ncd /workdir/lmdeploy\n\n# 因为编译文件在 build 文件夹中\n# 设置 lib 的软链接\ncd lmdeploy && ln -s ../build/lib . && cd ..\n# （可选）创建 compile_commands.json 软链接，用于 clangd 构建 index\nln -s build/compile_commands.json .\n```\n\n## 开始调试\n\n````bash\n# 使用 gdb 启动 API Server，例如\ngdb --args python3 -m lmdeploy serve api_server /workdir/Llama-2-13b-chat-hf\n\n# 在 gdb 中设置 lmdeploy 文件夹路径\nReading symbols from python3...\n(gdb) set directories /workdir/lmdeploy\n\n# 使用相对路径设置断点，例如\n(gdb) b src/turbomind/models/llama/BlockManager.cc:104\n\n# 当出现\n# ```\n# No source file named src/turbomind/models/llama/BlockManager.cc.\n# Make breakpoint pending on future shared library load? (y or [n])\n# ```\n# 输入 y 并回车\n\n# 运行\n(gdb) r\n\n# (可选) 使用 https://github.com/InternLM/lmdeploy/blob/main/benchmark/profile_restful_api.py 发送请求\n\npython3 profile_restful_api.py --backend lmdeploy --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json\n````\n\n## 使用 GDB\n\n参考 [GDB Execution Commands](https://lldb.llvm.org/use/map.html) 进行调试。\n"
  },
  {
    "path": "docs/zh_cn/advance/long_context.md",
    "content": "# 长文本外推\n\n长文本外推指 LLM 推理时处理比训练文本更长数据的能力。TurboMind 引擎目前支持 [LlamaDynamicNTKScalingRotaryEmbedding](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L178), 并与 HuggingFace 的实现对齐。\n\n## 如何使用\n\n如果要直接加载 HuggingFace 格式的模型，可以通过修改 TurbomindEngineConfig 参数的方式赋予模型外推能力。将 `session_len` 修改为外推的长度，并将 `rope_scaling_factor` 修改为不小于 1.0 的值。\n\n以具有 **1M 上下文长度**的`internlm2_5-7b-chat-1m`为例，可以使用如下方式，激活长文本推理能力：\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(\n        rope_scaling_factor=2.5,\n        session_len=1000000,\n        max_batch_size=1,\n        cache_max_entry_count=0.7,\n        tp=4)\npipe = pipeline('internlm/internlm2_5-7b-chat-1m', backend_config=backend_config)\nprompt = 'Use a long prompt to replace this sentence'\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\nresponse = pipe(prompt, gen_config=gen_config)\nprint(response)\n```\n\n## 评测\n\n我们使用多种方式评测 LMDeploy 长文本推理能力，分别是 [passkey retrieval 实验](#passkey-retrieval)、[大海捞针实验](#大海捞针) 和[计算困惑度](#困惑度)\n\n### Passkey Retrieval\n\n执行如下代码，可以测试在长文本中找到特殊 key 成功和失败的次数\n\n```python\nimport numpy as np\nfrom lmdeploy import pipeline\nfrom lmdeploy import TurbomindEngineConfig\nimport time\n\nsession_len = 1000000\nbackend_config = TurbomindEngineConfig(\n        rope_scaling_factor=2.5,\n        session_len=session_len,\n        max_batch_size=1,\n        cache_max_entry_count=0.7,\n        tp=4)\npipe = pipeline('internlm/internlm2_5-7b-chat-1m', backend_config=backend_config)\n\n\ndef passkey_retrieval(session_len, n_round=5):\n    # create long context input\n    tok = pipe.tokenizer\n    task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. 
I will quiz you about the important information there.'\n    garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.'\n\n    for _ in range(n_round):\n        start = time.perf_counter()\n        n_times = (session_len - 1000) // len(tok.encode(garbage))\n        n_garbage_prefix = np.random.randint(0, n_times)\n        n_garbage_suffix = n_times - n_garbage_prefix\n        garbage_prefix = ' '.join([garbage] * n_garbage_prefix)\n        garbage_suffix = ' '.join([garbage] * n_garbage_suffix)\n        pass_key = np.random.randint(1, 50000)\n        information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.'  # noqa: E501\n        final_question = 'What is the pass key? The pass key is'\n        lines = [\n            task_description,\n            garbage_prefix,\n            information_line,\n            garbage_suffix,\n            final_question,\n        ]\n\n        # inference\n        prompt = ' '.join(lines)\n        response = pipe([prompt])\n        print(pass_key, response)\n        end = time.perf_counter()\n        print(f'duration: {end - start} s')\n\npasskey_retrieval(session_len, 5)\n```\n\n在 A100-80G GPU上，执行上述实验，每轮测试大约需要 364 秒\n\n### 大海捞针\n\n可使用 OpenCompass 进行测评，具体使用方法，请参考[文档](https://github.com/open-compass/opencompass/blob/main/docs/zh_cn/advanced_guides/needleinahaystack_eval.md)\n\n### 困惑度\n\n下面展示使用 LMDeploy 计算困惑度的用法\n\n```python\nfrom transformers import AutoTokenizer\nfrom lmdeploy import TurbomindEngineConfig, pipeline\nimport numpy as np\n\n# load model and tokenizer\nmodel_repoid_or_path = 'internlm/internlm2_5-7b-chat-1m'\nbackend_config = TurbomindEngineConfig(\n        rope_scaling_factor=2.5,\n        session_len=1000000,\n        max_batch_size=1,\n        cache_max_entry_count=0.7,\n        tp=4)\npipe = pipeline(model_repoid_or_path, backend_config=backend_config)\ntokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)\n\n# 
get perplexity\ntext = 'Use a long prompt to replace this sentence'\ninput_ids = tokenizer.encode(text)\nloss = pipe.get_ppl(input_ids)[0]\nprint(ppl)\n```\n"
  },
  {
    "path": "docs/zh_cn/advance/metrics.md",
    "content": "# 生产环境指标监控\n\nLMDeploy 通过 Prometheus 暴露监控指标，并通过 Grafana 提供可视化界面。\n\n## 配置指南\n\n本节介绍如何设置 `lmdeploy/monitoring` 目录中提供的监控套件（Prometheus + Grafana）\n\n## 前提条件\n\n- 已安装 [Docker](https://docs.docker.com/engine/install/) 和 [Docker Compose](https://docs.docker.com/compose/install/)\n\n- 已启用指标系统的 LMDeploy 服务正在运行\n\n## 使用说明 (DP = 1)\n\n1. **启动已启用指标的 LMDeploy 服务**\n\n```\nlmdeploy serve api_server Qwen/Qwen2.5-7B-Instruct --enable-metrics\n```\n\n请根据需求替换模型路径。默认 metrics endpoint 位于 `http://<lmdeploy_server_host>:23333/metrics`。\n\n2. **进入监控目录**\n\n```\ncd lmdeploy/monitoring\n```\n\n3. **启动监控套件**\n\n```\ndocker compose up\n```\n\n此命令将在后台启动 Prometheus 和 Grafana。\n\n4. **访问监控界面**\n\n- Prometheus：浏览器访问 http://localhost:9090.\n\n- Grafana：浏览器访问 http://localhost:3000.\n\n5. **登录 Grafana**\n\n- 默认用户名：`admin`\n\n- 默认密码：`admin` （首次登录后会提示修改密码）\n\n6. **查看仪表盘**\n\n预配置的 LMDeploy 仪表盘将自动加载。\n\n## 使用说明 (DP > 1)\n\n1. **启动已启用指标的 LMDeploy 服务**\n\n以模型 `Qwen/Qwen2.5-7B-Instruct` 为例，使用 `DP=2，TP=2` 启动服务：\n\n```bash\n# Proxy server\nlmdeploy serve proxy --server-port 8000 --routing-strategy 'min_expected_latency' --serving-strategy Hybrid --log-level INFO\n\n# API server\nLMDEPLOY_DP_MASTER_ADDR=127.0.0.1 \\\nLMDEPLOY_DP_MASTER_PORT=29555 \\\nlmdeploy serve api_server \\\n    Qwen/Qwen2.5-7B-Instruct \\\n    --backend pytorch \\\n    --tp 2 \\\n    --dp 2 \\\n    --proxy-url http://0.0.0.0:8000 \\\n    --nnodes 1 \\\n    --node-rank 0 \\\n    --enable-metrics\n```\n\n您应该能在代理服务器列表中看到多个 API 服务实例。详细信息可以在 `lmdeploy/serve/proxy/proxy_config.json` 中找到。\n\n例如，您可能会看到如下 API 服务地址：\n\n```\nhttp://$host_ip:$api_server_port1\n\nhttp://$host_ip:$api_server_port2\n```\n\n2. 
**修改 Prometheus 配置**\n\n当 DP > 1 时，LMDeploy 会为每个 DP Rank 启动一个 API 服务。如果你想监控其中某个 API 服务，例如：`http://$host_ip:$api_server_port1`，请修改配置文件 `lmdeploy/monitoring/prometheus.yaml` 如下所示。\n\n> 注意：这里应使用实际主机的 IP 地址而非 127.0.0.1，因为当 DP > 1 时，LMDeploy 是通过实际主机 IP 启动 API 服务的。\n\n```\nglobal:\n  scrape_interval: 5s\n  evaluation_interval: 30s\n\nscrape_configs:\n  - job_name: lmdeploy\n    static_configs:\n      - targets:\n          - '$host_ip:$api_server_port1' # <= 修改此处\n```\n\n3. **进入监控目录并执行上述相同步骤**\n\n## 故障排除\n\n1. **端口冲突**\n\n检查端口 `23333` (LMDeploy 服务端口)、`9090` (Prometheus 端口) 或 `3000` (Grafana 端口) 是否被占用。解决方案，关闭冲突的端口或如下修改配置文件：\n\n- 修改 Prometheus 抓取的 LMDeploy 服务端口\n\n在 `lmdeploy/monitoring/prometheus.yaml` 中\n\n```\nglobal:\n  scrape_interval: 5s\n  evaluation_interval: 30s\n\nscrape_configs:\n  - job_name: lmdeploy\n    static_configs:\n      - targets:\n          - '127.0.0.1:23333' # <= 修改此处的 LMDeploy 服务端口 23333，需与实际运行端口一致\n```\n\n- 修改 Prometheus 端口\n\n在 `lmdeploy/monitoring/grafana/datasources/datasource.yaml` 中\n\n```\napiVersion: 1\ndatasources:\n  - name: Prometheus\n    type: prometheus\n    access: proxy\n    url: http://localhost:9090 # <= 修改此处的 Prometheus 接口端口 9090\n    isDefault: true\n    editable: false\n```\n\n- 修改 Grafana 端口\n\n在 `lmdeploy/monitoring/docker-compose.yaml` 中操作（例如改为 3090 端口）:\n\n方案一：在环境变量中添加 `GF_SERVER_HTTP_PORT`\n\n```\n  environment:\n- GF_AUTH_ANONYMOUS_ENABLED=true\n- GF_SERVER_HTTP_PORT=3090  # <= 添加此行\n```\n\n方案二：使用端口映射\n\n```\ngrafana:\n  image: grafana/grafana:latest\n  container_name: grafana\n  ports:\n  - \"3090:3000\"  # <= 主机端口:容器端口映射\n```\n\n- **仪表盘无数据**\n\n尝试向 LMDeploy 服务发送请求生成流量：\n\n```\npython3 benchmark/profile_restful_api.py --backend lmdeploy --num-prompts 5000 --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json\n```\n\n刷新后仪表盘应显示数据。\n"
  },
  {
    "path": "docs/zh_cn/advance/pytorch_multinodes.md",
    "content": "# PyTorchEngine 多节点部署指南\n\n为了支持更大规模的模型部署需求，PyTorchEngine 提供了多节点部署的支持。以下是如何在两个8卡节点上部署 tp=16 模型的详细步骤。\n\n## 1. 创建 Docker 容器（可选）\n\n为了确保集群环境的一致性，建议使用 Docker 搭建集群。在每个节点上创建容器：\n\n```bash\ndocker run -it \\\n    --network host \\\n    -v $MODEL_PATH:$CONTAINER_MODEL_PATH \\\n    openmmlab/lmdeploy:latest\n```\n\n> \\[!IMPORTANT\\]\n> 请确保将模型放置在各个节点容器的相同目录中。\n\n## 2. 使用 ray 搭建集群\n\n### 2.1 启动主节点\n\n选择其中一个节点做为`主节点`，并在该节点的容器中运行以下命令：\n\n```bash\nray start --head --port=$DRIVER_PORT\n```\n\n### 2.2 加入集群\n\n在其他节点的容器中，使用以下命令加入主节点所在的集群：\n\n```bash\nray start --address=$DRIVER_NODE_ADDR:$DRIVER_PORT\n```\n\n完成后可以在主节点使用 `ray status` 查看集群状态，确保所有节点都被成功加入集群。\n\n> \\[!IMPORTANT\\]\n> 请确保 `DRIVER_NODE_ADDR` 为主节点的地址，`DRIVER_PORT` 与主节点初始化时使用的端口号一致。\n\n## 3. 使用 LMDeploy 接口\n\n在主节点的容器中，您可以正常使用 PyTorchEngine 的所有功能。\n\n### 3.1 启动服务 API\n\n```bash\nlmdeploy serve api_server \\\n    $CONTAINER_MODEL_PATH \\\n    --backend pytorch \\\n    --tp 16\n```\n\n### 3.2 使用 pipeline 接口\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nif __name__ == '__main__':\n    model_path = '/path/to/model'\n    backend_config = PytorchEngineConfig(tp=16)\n    with pipeline(model_path, backend_config=backend_config) as pipe:\n        outputs = pipe('Hakuna Matata')\n```\n\n> \\[!NOTE\\]\n> PytorchEngine 会根据 tp 数以及集群上的设备数量自动选择合适的启动方式（单机/多机）。如果希望强制使用 ray 集群，可以配置 `PytorchEngineConfig` 中的 `distributed_executor_backend='ray'` 或使用环境变量 `LMDEPLOY_EXECUTOR_BACKEND=ray`。\n\n通过以上步骤，您可以成功在多节点环境中部署 PyTorchEngine，并利用 Ray 集群进行分布式计算。\n\n> \\[!WARNING\\]\n> 为了能够得到更好的性能，我们建议用户配置更好的网络环境（比如使用 [InfiniBand](https://en.wikipedia.org/wiki/InfiniBand)）以提高引擎运行效率\n"
  },
  {
    "path": "docs/zh_cn/advance/pytorch_multithread.md",
    "content": "# PyTorchEngine 多线程推理\n\n自 [PR2907](https://github.com/InternLM/lmdeploy/pull/2907) 起，我们废除了 PytorchEngine 的 thread_safe 模式以保证引擎能够更高效的运行。我们鼓励用户尽可能使用**服务接口**或**协程**来实现高并发,\n如果你确实有多线程推理的需求，那么可以进行简单的封装，来实现类似的效果。\n\n```python\nimport threading\nfrom queue import Queue\nimport asyncio\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nmodel_path = 'Llama-3.2-1B-Instruct'\n\n\nasync def _batch_infer(inque: Queue, outque: Queue, pipe):\n    while True:\n        if inque.empty():\n            await asyncio.sleep(0)\n            continue\n\n        input = inque.get_nowait()\n        output = await pipe.async_batch_infer(input)\n        outque.put(output)\n\n\ndef server(inques, outques):\n    event_loop = asyncio.new_event_loop()\n    asyncio.set_event_loop(event_loop)\n    pipe = pipeline(model_path, backend_config=PytorchEngineConfig())\n    for inque, outque in zip(inques, outques):\n        event_loop.create_task(_batch_infer(inque, outque, pipe))\n    event_loop.run_forever()\n\ndef client(inque, outque, message):\n    inque.put(message)\n    print(outque.get().text)\n\n\ninques = [Queue(), Queue()]\noutques = [Queue(), Queue()]\n\nt_server = threading.Thread(target=server, args=(inques, outques))\nt_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))\nt_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))\n\nt_server.start()\nt_client0.start()\nt_client1.start()\n\nt_client0.join()\nt_client1.join()\n```\n\n> \\[!WARNING\\]\n> 我们不鼓励这样实现，多线程会带来额外的开销，使得推理性能不稳定\n"
  },
  {
    "path": "docs/zh_cn/advance/pytorch_new_model.md",
    "content": "# lmdeploy.pytorch 新模型支持\n\nlmdeploy.pytorch 被设计用来简化新模型的支持以及原型的开发，用户可以根据自己的需求适配新的模型。\n\n## 模型支持\n\n### 配置加载（可选）\n\nlmdeploy.pytorch 会根据模型的参数初始化引擎，如果需要接入的模型的参数命名与 transformers 中常见模型不同，可能存在解析错误的情况。可以添加自定义的 ConfigBuilder 来解析配置\n\n```python\n# lmdeploy/pytorch/configurations/gemma.py\n\nfrom lmdeploy.pytorch.config import ModelConfig\n\nfrom .builder import AutoModelConfigBuilder\n\n\nclass GemmaModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        # 判断 hf_config 是否适配该 builder\n        return hf_config.model_type in ['gemma', 'gemma2']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None):\n        # 使用 transformers 加载的 hf_config\n        # 构造 pytorch engine 的 ModelConfig\n        return ModelConfig(hidden_size=hf_config.hidden_size,\n                           num_layers=hf_config.num_hidden_layers,\n                           num_attention_heads=hf_config.num_attention_heads,\n                           num_key_value_heads=hf_config.num_key_value_heads,\n                           bos_token_id=hf_config.bos_token_id,\n                           eos_token_id=hf_config.eos_token_id,\n                           head_dim=hf_config.head_dim,\n                           vocab_size=hf_config.vocab_size)\n```\n\n可以使用 `lmdeploy.pytorch.check_env.check_model` 函数验证配置是否能够正确解析\n\n### 实现模型\n\n在确保能够正确解析配置后，就可以开始实现模型逻辑。以 llama 的实现为例，我们需要通过 transformers 的配置文件创建模型\n\n```python\nclass LlamaForCausalLM(nn.Module):\n\n    # 构造函数，通过传入的 config 搭建模型\n    # ctx_mgr 是上下文管理器，可以通过它传入引擎配置或额外参数\n    def __init__(self,\n                 config: LlamaConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build LLamaModel\n        self.model = LlamaModel(config, dtype=dtype, device=device)\n        # build lm_head\n    
    self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    # 模型推理函数\n    # 推荐尽可能使用与下面相同的参数\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n        return logits\n```\n\n除了这些以外，还有如下内容需要添加\n\n```python\nclass LlamaForCausalLM(nn.Module):\n\n    ...\n\n    # 标注该模型是否支持 cudagraph\n    # 可以是一个 callable 对象，接收 forward 输入\n    # 动态判断是否支持 cudagraph\n    support_cuda_graph = True\n\n    # 构建模型输入\n    # 返回词典，词典的 key 必须是 forward 的输入\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        ...\n\n    # 加载权重\n    # 模型的输入是 state dict 的 key value 对\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        ...\n```\n\n我们封装了许多融合算子以简化模型的搭建。这些算子能够更好的支持 tensor 并行、量化等各种功能，我们鼓励开发者尽可能使用这些 op 进行开发。\n\n```python\n# 使用预定义的 build_merged_colwise_linear, SiluAndMul, build_rowwise_linear\n# 可以帮助我们更快搭建模型，并且不用关心 tensor 并发、量化等细节\nclass LlamaMLP(nn.Module):\n\n    def __init__(self,\n                 config: LlamaConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        
super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=config.mlp_bias,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(config.intermediate_size,\n                                              config.hidden_size,\n                                              bias=config.mlp_bias,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n```\n\n### 模型注册\n\n为了能够让开发的模型实现可以正常使用，我们还需要在 `lmdeploy/pytorch/models/module_map.py` 中注册该模型\n\n```python\nMODULE_MAP.update({\n    'LlamaForCausalLM':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM',\n})\n```\n\n如果你不希望修改模型源码，也可以从外部传入自定义的 module map，方便整合进其他项目中\n\n```\nfrom lmdeploy import PytorchEngineConfig, pipeline\n\nbackend_config = PytorchEngineConfig(custom_module_map='/path/to/custom/module_map.py')\ngenerator = pipeline(model_path, backend_config=backend_config)\n```\n"
  },
  {
    "path": "docs/zh_cn/advance/pytorch_profiling.md",
    "content": "# PyTorchEngine 性能分析\n\n我们提供了数种分析 PytorchEngine 性能的方式\n\n## PyTorch Profiler\n\n我们集成了 PyTorch Profiler，可以在启动 pipeline 或 api server 时添加环境变量：\n\n```bash\n# enable profile cpu\nexport LMDEPLOY_PROFILE_CPU=1\n# enable profile cuda\nexport LMDEPLOY_PROFILE_CUDA=1\n# profile would start after 3 seconds\nexport LMDEPLOY_PROFILE_DELAY=3\n# profile 10 seconds\nexport LMDEPLOY_PROFILE_DURATION=10\n# prefix path to save profile files\nexport LMDEPLOY_PROFILE_OUT_PREFIX=\"/path/to/save/profile_\"\n```\n\n这样在退出程序后，统计信息会被存储在 `LMDEPLOY_PROFILE_OUT_PREFIX` 指定的地址，方便进行性能分析。\n\n## Nsight System\n\n我们也支持使用 Nsight System 分析 nVidia 设备的性能。\n\n### 单卡\n\n单卡情况下比较简单，可以直接使用 `nsys profile`：\n\n```bash\nnsys profile python your_script.py\n```\n\n### 多卡\n\n当启用了 DP/TP/EP 等多卡方案时，可以设置环境变量\n\n```bash\n# enable nsight system\nexport LMDEPLOY_RAY_NSYS_ENABLE=1\n# prefix path to save profile files\nexport LMDEPLOY_RAY_NSYS_OUT_PREFIX=\"/path/to/save/profile_\"\n```\n\n然后正常启动脚本或 api server 即可（注意**不要**添加 `nsys profile`）\n\n这样 profile 的结果就会被保存在 `LMDEPLOY_RAY_NSYS_OUT_PREFIX` 下，如果没有配置 `LMDEPLOY_RAY_NSYS_OUT_PREFIX`，可以在 `/tmp/ray/session_xxx/nsight` 目录下找到。\n\n## Ray timeline\n\n我们使用 ray 实现多卡支持，如果希望查看 ray timeline，可以配置如下环境变量：\n\n```bash\nexport LMDEPLOY_RAY_TIMELINE_ENABLE=1\nexport LMDEPLOY_RAY_TIMELINE_OUT_PATH=\"/path/to/save/timeline.json\"\n```\n"
  },
  {
    "path": "docs/zh_cn/advance/spec_decoding.md",
    "content": "# Speculative Decoding\n\n投机解码是一种优化技术，它通过引入轻量级草稿模型来预测多个后续token，再由主模型在前向推理过程中验证并选择匹配度最高的长token序列。与标准的自回归解码相比，这种方法可使系统一次性生成多个token。\n\n> \\[!NOTE\\]\n> 请注意，这是lmdeploy中的实验性功能。\n\n## 示例\n\n请参考如下使用示例。\n\n### Eagle 3\n\n#### 安装依赖\n\n安装 [flash-atten3 ](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release)\n\n```shell\ngit clone --depth=1 https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention/hopper\npython setup.py install\n```\n\n#### pipeline\n\n```python\nfrom lmdeploy import PytorchEngineConfig, pipeline\nfrom lmdeploy.messages import SpeculativeConfig\n\n\nif __name__ == '__main__':\n\n    model_path = 'meta-llama/Llama-3.1-8B-Instruct'\n    spec_cfg = SpeculativeConfig(\n        method='eagle3',\n        num_speculative_tokens=3,\n        model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B',\n    )\n    pipe = pipeline(model_path, backend_config=PytorchEngineConfig(max_batch_size=128), speculative_config=spec_cfg)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n    print(response)\n\n```\n\n#### serving\n\n```shell\nlmdeploy serve api_server \\\nmeta-llama/Llama-3.1-8B-Instruct \\\n--backend pytorch \\\n--server-port 24545 \\\n--speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \\\n--speculative-algorithm eagle3 \\\n--speculative-num-draft-tokens 3 \\\n--max-batch-size 128 \\\n--enable-metrics\n```\n\n### Deepseek MTP\n\n#### 安装依赖\n\nInstall [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation)\n\n```shell\ngit clone https://github.com/deepseek-ai/FlashMLA.git flash-mla\ncd flash-mla\ngit submodule update --init --recursive\npip install -v .\n```\n\n#### pipeline\n\n```python\nfrom lmdeploy import PytorchEngineConfig, pipeline\nfrom lmdeploy.messages import SpeculativeConfig\n\n\nif __name__ == '__main__':\n\n    model_path = 'deepseek-ai/DeepSeek-V3'\n    spec_cfg = SpeculativeConfig(\n        method='deepseek_mtp',\n        
num_speculative_tokens=3,\n    )\n    pipe = pipeline(model_path,\n                    backend_config=PytorchEngineConfig(tp=16, max_batch_size=128),\n                    speculative_config=spec_cfg)\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n    print(response)\n```\n\n#### serving\n\n```shell\nlmdeploy serve api_server \\\ndeepseek-ai/DeepSeek-V3 \\\n--backend pytorch \\\n--server-port 24545 \\\n--tp 16 \\\n--speculative-algorithm deepseek_mtp \\\n--speculative-num-draft-tokens 3 \\\n--max-batch-size 128 \\\n--enable-metrics\n```\n"
  },
  {
    "path": "docs/zh_cn/advance/structed_output.md",
    "content": "# 结构化输出\n\n结构化输出（也称为引导解码）会强制模型生成与用户提供的 JSON 模式、语法或正则表达式完全匹配的文本。\n当前，PyTorch 与 Turbomind 两个后端均已支持这种（受模式约束的）结构化生成。\n以下分别为 pipeline API 和 API 服务的使用示例。\n\n## pipeline\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.messages import GenerationConfig, PytorchEngineConfig\n\nmodel = 'internlm/internlm2-chat-1_8b'\nguide = {\n    'type': 'object',\n    'properties': {\n        'name': {\n            'type': 'string'\n        },\n        'skills': {\n            'type': 'array',\n            'items': {\n                'type': 'string',\n                'maxLength': 10\n            },\n            'minItems': 3\n        },\n        'work history': {\n            'type': 'array',\n            'items': {\n                'type': 'object',\n                'properties': {\n                    'company': {\n                        'type': 'string'\n                    },\n                    'duration': {\n                        'type': 'string'\n                    }\n                },\n                'required': ['company']\n            }\n        }\n    },\n    'required': ['name', 'skills', 'work history']\n}\npipe = pipeline(model, backend_config=PytorchEngineConfig(), log_level='INFO')\ngen_config = GenerationConfig(\n    response_format=dict(type='json_schema', json_schema=dict(name='test', schema=guide)))\nresponse = pipe(['Make a self introduction please.'], gen_config=gen_config)\nprint(response)\n```\n\n## api_server\n\n首先，启动 InternLM2 模型的 api_server 服务。\n\n```shell\nlmdeploy serve api_server internlm/internlm2-chat-1_8b --backend pytorch\n```\n\n客户端可以使用 OpenAI 的 Python 包进行测试：\n\n```python\nfrom openai import OpenAI\nguide = {\n    'type': 'object',\n    'properties': {\n        'name': {\n            'type': 'string'\n        },\n        'skills': {\n            'type': 'array',\n            'items': {\n                'type': 'string',\n                'maxLength': 10\n            },\n            'minItems': 3\n        },\n
        'work history': {\n            'type': 'array',\n            'items': {\n                'type': 'object',\n                'properties': {\n                    'company': {\n                        'type': 'string'\n                    },\n                    'duration': {\n                        'type': 'string'\n                    }\n                },\n                'required': ['company']\n            }\n        }\n    },\n    'required': ['name', 'skills', 'work history']\n}\nresponse_format = dict(type='json_schema', json_schema=dict(name='test', schema=guide))\nmessages = [{'role': 'user', 'content': 'Make a self-introduction please.'}]\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    response_format=response_format,\n    top_p=0.8)\nprint(response)\n```\n\n输出结果是一个符合上述 JSON 模式的回答。\n"
  },
  {
    "path": "docs/zh_cn/advance/update_weights.md",
    "content": "# 权重更新\n\nLMDeploy支持在线权重更新，方便RL训练等场景下的使用。以下是权重更新的步骤：\n\n## 步骤 1: 启动服务\n\nFor pytorch backend you have to add `--distributed-executor-backend ray`.\n\n```shell\nlmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray # for pytorch backend\n```\n\n## 步骤 2: 卸载权重和KV缓存\n\n在权重更新前，需要调用API卸载权重和KV缓存，使推理引擎处于可更新状态：\n\n```python\nfrom lmdeploy.utils import serialize_state_dict\nimport requests\n\nBASE_URL = 'http://0.0.0.0:23333'\napi_key = 'sk-xxx'\n\nheaders = {\n    \"Content-Type\": \"application/json\",\n    \"Authorization\": f\"Bearer {api_key}\",\n}\n\n# offload weights and kv cache with level=2\nresponse = requests.post(f\"{BASE_URL}/sleep\", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))\nassert response.status_code == 200, response.status_code\n\n# wake up the weights only; the server is now ready for weight updates\nresponse = requests.post(f\"{BASE_URL}/wakeup\", headers=headers, params=dict(tags=['weights']))\nassert response.status_code == 200, response.status_code\n```\n\n## 步骤 3: 更新权重\n\n将模型权重切分为多个分段后，逐段调用 `update_weights` API 进行更新；发送最后一个分段时需设置 `finished=True`。\n\n```python\nfrom typing import Dict, List\n\nimport torch\n\nsegmented_state_dict: List[Dict[str, torch.Tensor]] = ...\nnum_segment = len(segmented_state_dict)\nfor seg_idx in range(num_segment):\n    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])\n    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment-1)\n    response = requests.post(f\"{BASE_URL}/update_weights\", headers=headers, json=data)\n    assert response.status_code == 200, f\"response.status_code = {response.status_code}\"\n```\n\n**注意**：对于 PyTorch 推理后端，lmdeploy 还支持扁平化桶张量（flattened bucket tensor）传输方式：\n\n```python\nfrom typing import Dict, List\n\nimport torch\nfrom lmdeploy.utils import serialize_state_dict, FlattenedTensorBucket, FlattenedTensorMetadata\n\nsegmented_state_dict: List[Dict[str, torch.Tensor]] = ...\nnum_segment = len(segmented_state_dict)\nfor seg_idx in range(num_segment):\n
    named_tensors = list(segmented_state_dict[seg_idx].items())\n    bucket = FlattenedTensorBucket(named_tensors=named_tensors)\n    metadata = bucket.get_metadata()\n    flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata)\n    serialized_data = serialize_state_dict(flattened_tensor_data)\n    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment-1, load_format='flattened_bucket')\n    response = requests.post(f\"{BASE_URL}/update_weights\", headers=headers, json=data)\n    assert response.status_code == 200, f\"response.status_code = {response.status_code}\"\n```\n\n## 步骤 4: 唤醒引擎\n\n权重更新完成后，调用 API 重新构建 KV 缓存并唤醒引擎，使其恢复推理服务。\n\n```python\nresponse = requests.post(f\"{BASE_URL}/wakeup\", headers=headers, params=dict(tags=['kv_cache']))\nassert response.status_code == 200, response.status_code\n```\n"
  },
  {
    "path": "docs/zh_cn/api/cli.rst",
    "content": "命令行工具\n===========\n\n.. sphinx_argparse_cli::\n   :module: lmdeploy.cli\n   :func: run\n   :hook:\n   :prog: lmdeploy\n"
  },
  {
    "path": "docs/zh_cn/api/openapi.rst",
    "content": "OpenAPI 接口\n============\n.. currentmodule:: lmdeploy\n\nOpenAI 兼容服务器接口\n----------------------\n\n.. openapi:: ../_static/openai.yaml\n    :request:\n    :examples:\n\n\nProxy 服务器接口\n-----------------\n\n.. openapi:: ../_static/proxy.yaml\n    :request:\n    :examples:\n"
  },
  {
    "path": "docs/zh_cn/api/pipeline.rst",
    "content": "推理 pipeline\n==================\n.. currentmodule:: lmdeploy\n\nPipeline\n--------\n.. autofunction:: pipeline\n.. autoclass:: Pipeline\n   :undoc-members:\n   :show-inheritance:\n   :members: __init__, infer, stream_infer, chat, get_ppl\n   :member-order: bysource\n\nConfig\n-------------------\n.. autoclass:: PytorchEngineConfig\n.. autoclass:: TurbomindEngineConfig\n.. autoclass:: GenerationConfig\n.. autoclass:: ChatTemplateConfig\n"
  },
  {
    "path": "docs/zh_cn/benchmark/benchmark.md",
    "content": "# 性能测试\n\n测试之前，请安装 lmdeploy 预编译包，并下载测试脚本和数据。\n\n```shell\npip install lmdeploy\n# 下载 lmdeploy 源码，获取其中的性能测试脚本\ngit clone --depth=1 https://github.com/InternLM/lmdeploy\ncd lmdeploy\n# 切换到与已安装 lmdeploy 版本对应的 tag：\ngit fetch --tags\n# 查看已安装 lmdeploy 的版本：\npip show lmdeploy | grep Version\n# 切换到对应的 tag（将 <version> 替换为实际的版本号）：\ngit checkout <version>\n# 下载测试数据\nwget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json\n```\n\n## 测试 pipeline 接口\n\n```shell\npython3 benchmark/profile_pipeline_api.py ShareGPT_V3_unfiltered_cleaned_split.json meta-llama/Meta-Llama-3-8B-Instruct\n```\n\n可通过 `python3 benchmark/profile_pipeline_api.py -h` 查看脚本的参数详情。\n\n## 测试推理引擎接口\n\n```shell\npython3 benchmark/profile_throughput.py ShareGPT_V3_unfiltered_cleaned_split.json meta-llama/Meta-Llama-3-8B-Instruct\n```\n\n可通过 `python3 benchmark/profile_throughput.py -h` 查看脚本的参数详情。\n\n## 测试 api_server 性能\n\n首先启动模型服务（可以参考[这里](../llm/api_server.md)），然后执行下面的命令：\n\n```shell\npython3 benchmark/profile_restful_api.py --backend lmdeploy --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json\n```\n\n关于 `profile_restful_api.py` 的帮助信息，可以通过 `python3 benchmark/profile_restful_api.py -h` 查阅。\n"
  },
  {
    "path": "docs/zh_cn/benchmark/evaluate_with_opencompass.md",
    "content": "# 模型评测指南\n\n本文档介绍如何使用 OpenCompass 和 LMDeploy 对模型在学术数据集上的能力进行评测。完整的评测流程包含两个主要阶段：推理阶段和评判阶段。\n\n在推理阶段，首先通过 LMDeploy 将待评测模型部署为推理服务，随后使用 OpenCompass 将数据集内容作为请求发送至该服务，并获取模型生成的结果。\n\n在评判阶段，需将 OpenCompass 提供的评测模型 `opencompass/CompassVerifier-32B` 通过 LMDeploy 部署为服务，再使用 OpenCompass 将推理阶段生成的结果提交至该服务，从而获得最终的评测结果。\n\n若评测资源充足，建议参考[端到端评测](#端到端评测)章节执行完整流程；若资源有限，则建议按照[逐步评测](#逐步评测)章节依次执行两个阶段。\n\n## 环境准备\n\n```shell\npip install lmdeploy\npip install \"opencompass[full]\"\n\n# 下载 lmdeploy 源码，在后续步骤中会使用到 eval/* 中的评测脚本和配置文件\ngit clone --depth=1 https://github.com/InternLM/lmdeploy.git\n```\n\n建议将 LMDeploy 和 OpenCompass 安装在不同的 Python 虚拟环境中，以避免可能的依赖冲突。\n\n## 端到端评测\n\n1. **部署待评测模型**\n\n```shell\nlmdeploy serve api_server <model_path> --server-port 10000 <--other-options>\n```\n\n2. **部署评测模型（Judger）**\n\n```shell\nlmdeploy serve api_server opencompass/CompassVerifier-32B --server-port 20000 --tp 2 --session-len 65536\n```\n\n3. **生成评测配置并执行评测**\n\n```shell\ncd {the/root/path/of/lmdeploy/repo}\n\n## 指定数据集路径。如果在路径下没有找到评测数据集，OC会自动下载\nexport HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets\nexport COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache\npython eval/eval.py {task_name} \\\n    --mode all \\\n    --api-server http://{api-server-ip}:10000 \\\n    --judger-server http://{judger-server-ip}:20000 \\\n    -w {oc_output_dir}\n```\n\n关于 `eval.py` 的详细使用方法，比如指定评测集，请通过 `python eval/eval.py --help` 查阅。\n\n评测任务完成后，结果将保存在 `{oc_output_dir}/{yyyymmdd_hhmmss}` 目录中，其中 `{yyyymmdd_hhmmss}` 为任务执行的时间戳。\n\n## 逐步评测\n\n### 推理阶段\n\n本阶段用于生成模型对数据集的回答结果。\n\n1. **部署待评测模型**\n\n```shell\nlmdeploy serve api_server <model_path> --server-port 10000 <--other-options>\n```\n\n2. 
**生成推理配置并执行推理**\n\n```shell\ncd {the/root/path/of/lmdeploy/repo}\n\n## 指定数据集路径。如果在路径下没有找到评测数据集，OC会自动下载\nexport HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets\nexport COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache\n# 执行推理任务\npython eval/eval.py {task_name} \\\n    --mode infer \\\n    --api-server http://{api-server-ip}:10000 \\\n    -w {oc_output_dir}\n```\n\n关于 `eval.py` 的详细使用方法，比如指定评测集，请通过 `python eval/eval.py --help` 查阅。\n\n推理完成后，结果将保存在 `{oc_output_dir}/{yyyymmdd_hhmmss}` 目录中，其中 `{yyyymmdd_hhmmss}` 为任务执行的时间戳。\n\n### 评判阶段\n\n本阶段由评测模型（Judger）对推理阶段生成的结果进行判断。\n\n1. **部署评测模型（Judger）**\n\n```shell\nlmdeploy serve api_server opencompass/CompassVerifier-32B --server-port 20000 --tp 2\n```\n\n2. **生成评判配置并执行评判**\n\n```shell\ncd {the/root/path/of/lmdeploy/repo}\n\n## 指定数据集路径。如果在路径下没有找到评测数据集，OC会自动下载\nexport HF_DATASETS_CACHE=/nvme4/huggingface_hub/datasets\nexport COMPASS_DATA_CACHE=/nvme1/shared/opencompass/.cache\n# 执行评测任务\npython eval/eval.py {task_name} \\\n    --mode eval \\\n    --judger-server http://{judger-server-ip}:20000 \\\n    -w {oc_output_dir} -r {yyyymmdd_hhmmss}\n```\n\n注意事项：\n\n- `task_name` 必须与推理阶段的任务名称保持一致\n- `-w` 参数指定的输出目录 `oc_output_dir` 需与推理阶段一致\n- `-r` 参数用于指定“之前的输出与结果”，应填入推理阶段生成的时间戳目录名，即 `{oc_output_dir}` 下的子目录名称\n\n关于 `eval.py` 的详细使用方法，比如指定评测集，请通过 `python eval/eval.py --help` 查阅。\n"
  },
  {
    "path": "docs/zh_cn/benchmark/evaluate_with_vlmevalkit.md",
    "content": "# 多模态模型评测指南\n\n本文档介绍如何使用 VLMEvalKit 和 LMDeploy 评测多模态模型能力。\n\n## 环境准备\n\n```shell\npip install lmdeploy\n\ngit clone https://github.com/open-compass/VLMEvalKit.git\ncd VLMEvalKit && pip install -e .\n```\n\n建议在不同的 Python 虚拟环境中分别安装 LMDeploy 和 VLMEvalKit，以避免潜在的依赖冲突。\n\n## 评测\n\n1. **部署大语言多模态模型 (LMMs)**\n\n```shell\nlmdeploy serve api_server <model_path> --server-port 23333 <--other-options>\n```\n\n2. **配置评测设置**\n\n修改 `VLMEvalKit/vlmeval/config.py`，在 `api_models` 字典中添加以下 LMDeploy API 配置。\n\n`<task_name>` 是您评测任务的自定义名称（例如 `lmdeploy_qwen3vl-4b`）。`model` 参数应与 `lmdeploy serve` 命令中使用的 `<model_path>` 保持一致。\n\n```python\n# filepath: VLMEvalKit/vlmeval/config.py\n# ...existing code...\napi_models = {\n    # lmdeploy api\n    ...,\n    \"<task_name>\": partial(\n        LMDeployAPI,\n        api_base=\"http://0.0.0.0:23333/v1/chat/completions\",\n        model=\"<model_path>\",\n        retry=4,\n        timeout=1200,\n        temperature=0.7,  # modify if needed\n        max_new_tokens=16384,  # modify if needed\n    ),\n    ...\n}\n# ...existing code...\n```\n\n3. **开始评测**\n\n```shell\ncd VLMEvalKit\npython run.py --data OCRBench --model <task_name> --api-nproc 16 --reuse --verbose --api 123\n```\n\n`<task_name>` 应与上述配置文件中使用的名称保持一致。\n\n参数说明：\n\n- `--data`: 指定用于评测的数据集（例如 `OCRBench`）。\n- `--model`: 指定模型名称，必须与您在 `config.py` 中设置的 `<task_name>` 匹配。\n- `--api-nproc`: 指定并行的 API 调用数量。\n- `--reuse`: 复用先前的推理结果，以避免重新运行已完成的评测。\n- `--verbose`: 启用详细日志记录。\n"
  },
  {
    "path": "docs/zh_cn/conf.py",
    "content": "#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a selection of the most common options. For a\n# full list see the documentation:\n# http://www.sphinx-doc.org/en/master/config\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\nfrom pathlib import Path\n\nfrom fastapi import FastAPI\nfrom fastapi.responses import Response\nfrom yaml import safe_dump\n\nsys.path.insert(0, os.path.abspath('../..'))\n\nfrom lmdeploy.serve.openai.api_server import router  # noqa: E402\nfrom lmdeploy.serve.proxy.proxy import app as proxy_server  # noqa: E402\n\nversion_file = '../../lmdeploy/version.py'\nwith open(version_file, 'r') as f:\n    exec(compile(f.read(), version_file, 'exec'))\n__version__ = locals()['__version__']\n\n# -- Project information -----------------------------------------------------\n\nproject = 'lmdeploy'\ncopyright = '2021-2024, OpenMMLab'\nauthor = 'LMDeploy Authors'\n\n# The short X.Y version\nversion = __version__\n# The full version, including alpha/beta/rc tags\nrelease = __version__\n\n# -- Generate OpenAPI Spec -----------------------------------------------------\n\nopenai_server = FastAPI()\nopenai_server.include_router(router)\n\n\n@openai_server.get('/metrics',\n                   response_class=Response,\n                   responses={\n                       200: {\n                           'content': {\n                               'text/plain': {}\n                           },\n                           'description': 'Prometheus metrics data'\n                       },\n                       404: {\n                           'description': 'Metrics Endpoint not enabled'\n         
              }\n                   })\ndef metrics():\n    \"\"\"**[Optional]** Prometheus metrics endpoint.\"\"\"\n    pass\n\n\nspec_dir = Path('_static')\nspec_dir.mkdir(exist_ok=True)\n\nwith open(spec_dir / 'openai.yaml', 'w', encoding='utf-8') as f:\n    f.write(safe_dump(openai_server.openapi()))\n\nwith open(spec_dir / 'proxy.yaml', 'w', encoding='utf-8') as f:\n    f.write(safe_dump(proxy_server.openapi()))\n\n# -- General configuration ---------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\n\nextensions = [\n    'myst_parser',\n    'sphinx_argparse_cli',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.autosectionlabel',\n    'sphinx.ext.autosummary',\n    'sphinx.ext.intersphinx',\n    'sphinx.ext.napoleon',\n    'sphinx.ext.viewcode',\n    'sphinx_autodoc_typehints',\n    'sphinx_copybutton',\n    'sphinx_tabs.tabs',\n    'sphinxcontrib.mermaid',\n    'sphinxcontrib.openapi',\n]  # yapf: disable\n\n\nautosectionlabel_prefix_document = True\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\nsource_suffix = {\n    '.rst': 'restructuredtext',\n    '.md': 'markdown',\n}\n\n# The master toctree document.\nmaster_doc = 'index'\n\n# The language for content autogenerated by Sphinx. 
Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = 'zh_CN'\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = 'sphinx'\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\n# html_theme = 'sphinx_rtd_theme'\nhtml_theme = 'sphinx_book_theme'\nhtml_logo = '_static/image/lmdeploy-logo.svg'\nhtml_title = project\nhtml_copy_source = True\nhtml_last_updated_fmt = ''\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  
For a list of options available for each theme, see the\n# documentation.\n#\nhtml_theme_options = {\n    'path_to_docs': 'docs/zh_cn',\n    'repository_url': 'https://github.com/InternLM/lmdeploy',\n    'repository_branch': 'main',\n    # 'show_navbar_depth': 3,\n    # 'navigation_depth': 4,\n    # 'collapse_navigation': True,\n    'use_edit_page_button': True,\n    'use_source_button': True,\n    'use_issues_button': True,\n    'use_repository_button': True,\n    'use_download_button': True,\n    'use_sidenotes': True,\n    # 'show_toc_level': 2,\n    # \"icon_links\": [\n    #     {\n    #         \"name\": \"Switch to English\",\n    #         \"url\": \"https://lmdeploy.readthedocs.io/en/latest\",\n    #         \"icon\": \"https://img.shields.io/badge/Doc-English-blue\",\n    #         \"type\": \"url\",\n    #     },\n    # ],\n}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\nhtml_css_files = ['css/readthedocs.css']\n\n# Enable ::: for my_st\nmyst_enable_extensions = [\n    'dollarmath',\n    'amsmath',\n    'deflist',\n    # \"html_admonition\",\n    # \"html_image\",\n    'colon_fence',\n    # \"smartquotes\",\n    # \"replacements\",\n    # \"linkify\",\n    # \"substitution\",\n]\nmyst_heading_anchors = 5\n\n# Custom sidebar templates, must be a dictionary that maps document names\n# to template names.\n#\n# The default sidebars (for documents that don't match any pattern) are\n# defined by theme itself.  
Builtin themes are using these templates by\n# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',\n# 'searchbox.html']``.\n#\n# html_sidebars = {}\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = 'lmdeploydoc'\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, 'lmdeploy.tex', 'lmdeploy Documentation', 'LMDeploy Contributors', 'manual'),\n]\n\n# -- Options for manual page output ------------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [(master_doc, 'lmdeploy', 'lmdeploy Documentation', [author], 1)]\n\n# -- Options for Texinfo output ----------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (master_doc, 'lmdeploy', 'lmdeploy Documentation', author, 'lmdeploy', 'One line description of project.',\n     'Miscellaneous'),\n]\n\n# -- Options for Epub output -------------------------------------------------\n\n# Bibliographic Dublin Core info.\nepub_title = project\n\n# The unique identifier of the text. 
This can be a ISBN number\n# or the project homepage.\n#\n# epub_identifier = ''\n\n# A unique identification for the text.\n#\n# epub_uid = ''\n\n# A list of files that should not be packed into the epub file.\nepub_exclude_files = ['search.html']\n\n# -- Extension configuration -------------------------------------------------\n# Ignore >>> when copying code\ncopybutton_prompt_text = r'>>> |\\.\\.\\. '\ncopybutton_prompt_is_regexp = True\n\nautodoc_preserve_defaults = True\nnavigation_with_keys = False\n\n# Mock out external dependencies here,\n# otherwise the autodoc pages may be blank.\nautodoc_mock_imports = [\n    'torch',\n    'torchvision',\n    'transformers',\n    '_turbomind',\n    'triton',\n]\n\nautodoc_type_aliases = {'PydanticDataclass': 'pydantic.dataclasses.PydanticDataclass'}\n\nintersphinx_mapping = {\n    'python': ('https://docs.python.org/3.10', None),\n    'typing_extensions': ('https://typing-extensions.readthedocs.io/en/latest', None),\n    'pillow': ('https://pillow.readthedocs.io/en/stable', None),\n    'numpy': ('https://numpy.org/doc/stable', None),\n    'torch': ('https://pytorch.org/docs/stable', None),\n    'torchvision': ('https://pytorch.org/vision/stable', None),\n}\n"
  },
  {
    "path": "docs/zh_cn/faq.md",
    "content": "# 常见问题\n\n## ModuleNotFoundError\n\n### No module named 'mmengine.config.lazy'\n\n可能是因为本机缓存了旧版本的 mmengine。更新到最新版应该可以解决这个问题。\n\n```shell\npip install --upgrade mmengine\n```\n\n### No module named '\\_turbomind'\n\n可能是因为：\n\n1. 您没有安装 lmdeploy 的预编译包。`_turbomind`是 turbomind c++ 的 pybind 部分，需要编译。推荐您直接安装预编译包。\n\n```shell\npip install lmdeploy[all]\n```\n\n2. 如果已经安装了预编译包，但仍然出现这个问题，请检查执行目录。不要在 lmdeploy 的源码根目录下执行 python -m lmdeploy.turbomind.\\*下的 package，请换到其他目录下执行。\n\n但是如果您是开发人员，通常需要在本地进行开发和编译，每次都安装 whl 的效率太低。您可以在编译后通过符号链接指定 lib 的路径。\n\n```shell\n# 创建 bld 目录并进行本地编译\nmkdir bld && cd bld && bash ../generate.sh && ninja -j$(nproc)\n\n# 从 bld 切换到 lmdeploy 子目录并设置软链接\ncd ../lmdeploy && ln -s ../bld/lib .\n\n# 切换到 lmdeploy 根目录\ncd ..\n\n# 使用 python 命令，比如 check_env\npython3 -m lmdeploy check_env\n```\n\n如果您仍然遇到在本地机器上找不到 turbomind so 的问题，这意味着您的本地机器上可能存在多个 Python 环境，并且在编译和执行过程中 Python 的版本不匹配。在这种情况下，您需要根据实际情况设置 `lmdeploy/generate.sh` 中的 `PYTHON_EXECUTABLE`，例如 `-DPYTHON_EXECUTABLE=/usr/local/bin/python3`，并且需要重新编译。\n\n## Libs\n\n### libnccl.so.2 not found\n\n确保通过 `pip install lmdeploy[all]` 安装了 lmdeploy (>=v0.0.5)。\n\n如果安装之后，问题仍然存在，请把`libnccl.so.2`的路径加入到环境变量 LD_LIBRARY_PATH 中。\n\n```shell\n# 获取nvidia-nccl-cu11 package的安装目录\npip show nvidia-nccl-cu11|grep Location\n# 把\"libnccl.so.2\"的路径加入到 LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH={Location}/nvidia/nccl/lib:$LD_LIBRARY_PATH\n```\n\n### symbol cudaFreeAsync version libcudart.so.11.0 not defined in file libcudart.so.11.0 with link time reference\n\n很可能是机器上的 cuda 版本太低导致的。LMDeploy 运行时要求 cuda 不低于 11.2。\n\n## 推理\n\n### RuntimeError: \\[TM\\]\\[ERROR\\] CUDA runtime error: out of memory /workspace/lmdeploy/src/turbomind/utils/allocator.h\n\n通常这是因为 k/v cache 内存比例过大导致的。比例的控制参数是 `TurbomindEngineConfig.cache_max_entry_count`。该参数在不同版本的 lmdeploy 中，含义略有不同。具体请参考代码中的[演进说明](https://github.com/InternLM/lmdeploy/blob/52419bd5b6fb419a5e3aaf3c3b4dea874b17e094/lmdeploy/messages.py#L107)。\n\n如果在使用 pipeline 接口时遇到该问题，请调低比例，比如：\n\n```python\n
from lmdeploy import pipeline, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(cache_max_entry_count=0.2)\n\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'])\nprint(response)\n```\n\n如果在使用 CLI 工具时遇到此问题，请传入参数`--cache-max-entry-count`，调低 k/v cache 的内存占用比例。比如，\n\n```shell\n# chat 命令\nlmdeploy chat internlm/internlm2_5-7b-chat --cache-max-entry-count 0.2\n\n# server 命令\nlmdeploy serve api_server internlm/internlm2_5-7b-chat --cache-max-entry-count 0.2\n```\n\n## 服务\n\n### API 服务器获取超时\n\nAPI 服务器获取图像 URL 的超时时间可通过环境变量 `LMDEPLOY_FETCH_TIMEOUT` 进行配置。默认情况下，请求在 10 秒后超时。\n\n请参阅 [lmdeploy/vl/utils.py](https://github.com/InternLM/lmdeploy/blob/7b6876eafcb842633e0efe8baabe5906d7beeeea/lmdeploy/vl/utils.py#L31) 了解用法。\n\n## 量化\n\n### RuntimeError: \\[enforce fail at inline_container.cc:337\\] . unexpected pos 4566829760 vs 4566829656\n\n请检查你的硬盘空间。\n\n这个错误是因为保存权重时硬盘空间不足导致的，在量化 70B 模型时可能会遇到。\n\n### ModuleNotFoundError: No module named 'flash_attn'\n\n量化 `qwen` 模型需要安装 `flash-attn`。但是，根据社区用户的反馈，`flash-attn` 比较难安装。因此，lmdeploy 已将 `flash-attn` 从依赖列表中移除，用户在需要时可以手动安装。\n"
  },
  {
    "path": "docs/zh_cn/get_started/ascend/get_started.md",
    "content": "# 华为昇腾\n\n我们基于 LMDeploy 的 PytorchEngine，增加了华为昇腾设备的支持，目前支持的型号是**Atlas 800T A3，Atlas 800T A2和Atlas 300I Duo**。在华为昇腾上使用 LMDeploy 的方法与在英伟达 GPU 上使用 PytorchEngine 后端的方法几乎相同。在阅读本教程之前，请先阅读原版的[快速开始](../get_started.md)。\n\n支持的模型列表在[这里](../../supported_models/supported_models.md#PyTorchEngine-其他平台).\n\n> \\[!IMPORTANT\\]\n> 我们已经在阿里云上提供了构建完成的鲲鹏CPU版本的镜像。\n> 请使用下面的命令来拉取镜像:\n>\n> Atlas 800T A3:\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a3-latest`\n> （Atlas 800T A3目前只支持Qwen系列的算子模式下运行）\n>\n> Atlas 800T A2:\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest`\n>\n> Atlas 300I Duo:\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:300i-duo-latest`\n> （Atlas 300I Duo目前只支持非eager模式）\n>\n> 如果您希望自己构建环境，请参考[这里](../../../../docker)的dockerfile来自己构建。\n\n## 离线批处理\n\n### LLM 推理\n\n将`device_type=\"ascend\"`加入`PytorchEngineConfig`的参数中。\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy import PytorchEngineConfig\npipe = pipeline(\"internlm/internlm2_5-7b-chat\",\n             backend_config=PytorchEngineConfig(tp=1, device_type=\"ascend\"))\nquestion = [\"Shanghai is\", \"Please introduce China\", \"How are you?\"]\nresponse = pipe(question)\nprint(response)\n```\n\n### VLM 推理\n\n将`device_type=\"ascend\"`加入`PytorchEngineConfig`的参数中。\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\nfrom lmdeploy.vl import load_image\npipe = pipeline('OpenGVLab/InternVL2-2B',\n     backend_config=PytorchEngineConfig(tp=1, device_type='ascend'))\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## 在线服务\n\n### LLM 模型服务\n\n将`--device ascend`加入到服务启动命令中。\n\n```bash\nlmdeploy serve api_server --backend pytorch --device ascend 
internlm/internlm2_5-7b-chat\n```\n\n也可以运行以下命令启动容器运行LLM模型服务。\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device ascend internlm/internlm2_5-7b-chat\"\n```\n\n### VLM 模型服务\n\n将`--device ascend`加入到服务启动命令中。\n\n```bash\nlmdeploy serve api_server --backend pytorch --device ascend OpenGVLab/InternVL2-2B\n```\n\n也可以运行以下命令启动容器运行VLM模型服务。\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device ascend OpenGVLab/InternVL2-2B\"\n```\n\n## 使用命令行与LLM模型对话\n\n将`--device ascend`加入到 chat 命令中。\n\n```bash\nlmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device ascend\n```\n\n也可以运行以下命令，在启动容器后直接开启 lmdeploy 聊天。\n\n```bash\ndocker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/ascend:a2-latest \\\n    bash -i -c \"lmdeploy chat --backend pytorch --device ascend internlm/internlm2_5-7b-chat\"\n```\n\n## 量化\n\n### w4a16 AWQ\n\n运行下面的命令可以在Atlas 800T A2上对权重进行W4A16量化。\n\n```bash\nlmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu\n```\n\n支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。\n\n### w8a8 SMOOTH_QUANT\n\n运行下面的命令可以在Atlas 800T A2上对模型进行W8A8量化。\n\n```bash\nlmdeploy lite smooth_quant $HF_MODEL --work-dir $WORK_DIR --device npu\n```\n\n支持的模型列表请参考[支持的模型](../../supported_models/supported_models.md)。\n\n### int8 KV-cache 量化\n\n昇腾后端目前支持在 eager 模式下进行离线 int8 KV-cache 量化。\n\n详细使用方式请参考这篇[文章](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md)。\n\n## Atlas 300I Duo上的限制\n\n1. 只支持 `dtype=float16`。\n2. 只支持图模式，请不要添加 `--eager-mode` 参数。\n"
  },
  {
    "path": "docs/zh_cn/get_started/camb/get_started.md",
    "content": "# 寒武纪云端加速卡\n\n我们基于 LMDeploy 的 PytorchEngine，增加了寒武纪云端加速卡设备的支持。所以，在寒武纪云端加速卡上使用 LMDeploy 的方法与在英伟达 GPU 上使用 PytorchEngine 后端的方法几乎相同。在阅读本教程之前，请先阅读原版的[快速开始](../get_started.md)。\n\n支持的模型列表在[这里](../../supported_models/supported_models.md#PyTorchEngine-其他平台).\n\n> \\[!IMPORTANT\\]\n> 我们已经在阿里云上提供了构建完成的寒武纪云端加速卡镜像。\n> 请使用下面的命令来拉取镜像:\n>\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest`\n\n> \\[!IMPORTANT\\]\n> 目前寒武纪加速卡上启动多卡推理需要手动启动ray。下面是一个2卡的例子：\n>\n> ```shell\n>  export MLU_VISIBLE_DEVICES=0,1\n>  ray start --head --resources='{\"MLU\": 2}'\n> ```\n\n## 离线批处理\n\n### LLM 推理\n\n将`device_type=\"camb\"`加入`PytorchEngineConfig`的参数中。\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy import PytorchEngineConfig\npipe = pipeline(\"internlm/internlm2_5-7b-chat\",\n             backend_config=PytorchEngineConfig(tp=1, device_type=\"camb\"))\nquestion = [\"Shanghai is\", \"Please introduce China\", \"How are you?\"]\nresponse = pipe(question)\nprint(response)\n```\n\n### VLM 推理\n\n将`device_type=\"camb\"`加入`PytorchEngineConfig`的参数中。\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\nfrom lmdeploy.vl import load_image\npipe = pipeline('OpenGVLab/InternVL2-2B',\n     backend_config=PytorchEngineConfig(tp=1, device_type='camb'))\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## 在线服务\n\n### LLM 模型服务\n\n将`--device camb`加入到服务启动命令中。\n\n```bash\nlmdeploy serve api_server --backend pytorch --device camb internlm/internlm2_5-7b-chat\n```\n\n也可以运行以下命令启动容器运行LLM模型服务。\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device camb internlm/internlm2_5-7b-chat\"\n```\n\n### VLM 模型服务\n\n将`--device camb`加入到服务启动命令中。\n\n```bash\nlmdeploy serve 
api_server --backend pytorch --device camb OpenGVLab/InternVL2-2B\n```\n\n也可以运行以下命令启动容器运行VLM模型服务。\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device camb OpenGVLab/InternVL2-2B\"\n```\n\n## 使用命令行与LLM模型对话\n\n将`--device camb`加入到 chat 命令中。\n\n```bash\nlmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device camb\n```\n\n也可以运行以下命令，在启动容器后直接开启 lmdeploy 聊天。\n\n```bash\ndocker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/camb:latest \\\n    bash -i -c \"lmdeploy chat --backend pytorch --device camb internlm/internlm2_5-7b-chat\"\n```\n"
  },
  {
    "path": "docs/zh_cn/get_started/get_started.md",
    "content": "# 快速开始\n\nLMDeploy提供了快速安装、模型量化、离线批处理、在线推理服务等功能。每个功能只需简单的几行代码或者命令就可以完成。\n\n本教程将展示 LMDeploy 在以下几方面的使用方法：\n\n- LLM 模型和 VLM 模型的离线推理\n- 搭建与 OpenAI 接口兼容的 LLM 或 VLM 模型服务\n- 通过控制台命令行与 LLM 模型进行交互式聊天\n\n在继续阅读之前，请确保你已经按照[安装指南](installation.md)安装了 lmdeploy。\n\n## 离线批处理\n\n### LLM 推理\n\n```python\nimport lmdeploy\npipe = lmdeploy.pipeline(\"internlm/internlm2_5-7b-chat\")\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\n在构造 `pipeline` 时，如果没有指定使用 TurboMind 引擎或 PyTorch 引擎进行推理，LMDeploy 将根据[它们各自的能力](../supported_models/supported_models.md)自动分配一个，默认优先使用 TurboMind 引擎。\n\n当然，你也可以手动指定推理引擎。例如，\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=TurbomindEngineConfig(\n                    max_batch_size=32,\n                    enable_prefix_caching=True,\n                    cache_max_entry_count=0.8,\n                    session_len=8192,\n                ))\n```\n\n或者，\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=PytorchEngineConfig(\n                    max_batch_size=32,\n                    enable_prefix_caching=True,\n                    cache_max_entry_count=0.8,\n                    session_len=8192,\n                ))\n```\n\n```{note}\n参数 \"cache_max_entry_count\" 显著影响 GPU 内存占用。它表示加载模型权重后 K/V 缓存占用的空闲 GPU 内存的比例。\n默认值是 0.8。K/V 缓存采用一次性申请、重复使用的分配方式，这就是为什么 pipeline 以及下文中的 api_server 在启动后会消耗大量 GPU 内存。\n如果你遇到内存不足（OOM）错误，可能需要考虑降低 cache_max_entry_count 的值。\n```\n\n调用 `pipe()` 进行生成时，你可以通过 `GenerationConfig` 设置采样参数，如下所示：\n\n```python\nfrom lmdeploy import GenerationConfig, pipeline\n\npipe = pipeline('internlm/internlm2_5-7b-chat')\nprompts = ['Hi, pls intro yourself', 'Shanghai is']\nresponse = pipe(prompts,\n                gen_config=GenerationConfig(\n                    max_new_tokens=1024,\n
                    top_p=0.8,\n                    top_k=40,\n                    temperature=0.6\n                ))\n```\n\n在 `GenerationConfig` 中，`top_k=1` 或 `temperature=0.0` 表示贪心搜索。\n\n有关 pipeline 的更多信息，请参考[这里](../llm/pipeline.md)。\n\n### VLM 推理\n\nVLM 推理 pipeline 与 LLM 类似，但增加了使用 pipeline 处理图像数据的能力。例如，你可以使用以下代码片段对 InternVL 模型进行推理：\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2-8B')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n在 VLM pipeline 中，默认的图像处理批量大小是 1。这可以通过 `VisionConfig` 调整。例如，你可以这样设置：\n\n```python\nfrom lmdeploy import pipeline, VisionConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2-8B',\n                vision_config=VisionConfig(\n                    max_batch_size=8\n                ))\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n需要注意的是，图像批量大小越大，OOM 错误的风险越大，因为 VLM 模型中的 LLM 部分会提前预分配大量的内存。\n\nVLM pipeline 对于推理引擎的选择方式与 LLM pipeline 类似。你可以参考 [LLM 推理](#llm-推理)并结合两个引擎支持的 VLM 模型列表，手动选择和配置推理引擎。\n\n## 模型服务\n\n与前文[离线批处理](#离线批处理)类似，本章节分别介绍 LLM 和 VLM 模型服务的构建方法。\n\n### LLM 模型服务\n\n```shell\nlmdeploy serve api_server internlm/internlm2_5-7b-chat\n```\n\n此命令将在本地主机上的端口 `23333` 启动一个与 OpenAI 接口兼容的模型推理服务。你可以使用 `--server-port` 选项指定不同的服务器端口。\n更多选项，请通过运行 `lmdeploy serve api_server --help` 查阅帮助文档。这些选项大多与引擎配置一致。\n\n要访问服务，你可以使用官方的 OpenAI Python 包（`pip install openai`）。以下示例演示了如何调用 `/v1/chat/completions` 接口：\n\n```python\nfrom openai import OpenAI\nclient = OpenAI(\n    api_key='YOUR_API_KEY',\n    base_url=\"http://0.0.0.0:23333/v1\"\n)\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n
    model=model_name,\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"Provide three suggestions about time management\"},\n    ],\n    temperature=0.8,\n    top_p=0.8\n)\nprint(response)\n```\n\n我们鼓励你参考详细指南，了解关于[使用 Docker 部署服务](../llm/api_server.md)、[工具调用](../llm/api_server_tools.md)和其他更多功能的信息。\n\n### VLM 模型服务\n\n```shell\nlmdeploy serve api_server OpenGVLab/InternVL2-8B\n```\n\n```{note}\nLMDeploy 复用了上游 VLM 仓库的视觉组件。而不同上游 VLM 模型的视觉模型可能互不相同，依赖库也各有区别。\n因此，LMDeploy 决定不在自身的依赖列表中加入上游 VLM 库的依赖。如果你在使用 LMDeploy 推理 VLM 模型时出现 \"ImportError\" 的问题，请自行安装相关的依赖。\n```\n\n服务成功启动后，你可以以类似访问 GPT-4V 服务的方式访问 VLM 服务：\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(api_key='YOUR_API_KEY',  # A dummy api_key is required\n                base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[{\n        'role':\n        'user',\n        'content': [{\n            'type': 'text',\n            'text': 'Describe the image please',\n        }, {\n            'type': 'image_url',\n            'image_url': {\n                'url':\n                'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n            },\n        }],\n    }],\n    temperature=0.8,\n    top_p=0.8)\nprint(response)\n```\n\n## 使用命令行与 LLM 模型对话\n\nLMDeploy 提供了一个非常方便的 CLI 工具，供用户与 LLM 模型进行本地聊天。例如：\n\n```shell\nlmdeploy chat internlm/internlm2_5-7b-chat --backend turbomind\n```\n\n它的设计目的是帮助用户检查和验证 LMDeploy 是否支持提供的模型，聊天模板是否被正确应用，以及推理结果是否正确。\n\n另外，`lmdeploy check_env` 用于收集基本的环境信息。在给 LMDeploy 提交问题报告时，这些信息非常重要，因为它们有助于我们更有效地诊断和解决问题。\n\n如果你对它们的使用方法有任何疑问，可以使用 `--help` 选项获取详细信息。\n"
  },
  {
    "path": "docs/zh_cn/get_started/index.rst",
    "content": "其他软硬件平台\n=================================\n\n.. toctree::\n   :maxdepth: 1\n   :caption: OtherPF\n\n   ascend/get_started.md\n   maca/get_started.md\n   camb/get_started.md\n"
  },
  {
    "path": "docs/zh_cn/get_started/installation.md",
    "content": "# 安装\n\nLMDeploy 是一个用于大型语言模型（LLMs）和视觉-语言模型（VLMs）压缩、部署和服务的 Python 库。\n其核心推理引擎包括 TurboMind 引擎和 PyTorch 引擎。前者由 C++ 和 CUDA 开发，致力于推理性能的优化；后者采用纯 Python 开发，旨在降低开发者的门槛。\n\nLMDeploy 支持在 Linux 和 Windows 平台上部署 LLMs 和 VLMs，最低要求 CUDA 版本为 11.3。此外，它还与以下 NVIDIA GPU 兼容：\n\n- Volta(sm70): V100\n- Turing(sm75): 20 系列，T4\n- Ampere(sm80,sm86): 30 系列，A10, A16, A30, A100\n- Ada Lovelace(sm89): 40 系列\n\n## 使用 pip 安装（推荐）\n\n我们推荐在一个干净的 conda 环境（python3.9 - 3.13）中安装 lmdeploy：\n\n```shell\nconda create -n lmdeploy python=3.10 -y\nconda activate lmdeploy\npip install lmdeploy\n```\n\n默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3)，你可以使用以下命令安装 lmdeploy：\n\n```shell\nexport LMDEPLOY_VERSION=0.12.2\nexport PYTHON_VERSION=310\npip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118\n```\n\n## 从源码安装\n\n默认情况下，LMDeploy 将面向 NVIDIA CUDA 环境进行编译安装，并同时启用 TurboMind 和 PyTorch 两种后端引擎。在安装 LMDeploy 之前，请确保已成功安装 CUDA 工具包。\n\n成功安装 CUDA 工具包后，您可以使用以下单行命令构建并安装 LMDeploy：\n\n```shell\npip install git+https://github.com/InternLM/lmdeploy.git\n```\n\n您还可以通过设置 `DISABLE_TURBOMIND` 环境变量，显式禁用 TurboMind 后端，以避免 CUDA 编译：\n\n```shell\nDISABLE_TURBOMIND=1 pip install git+https://github.com/InternLM/lmdeploy.git\n```\n\n如果您希望使用特定版本，而不是 LMDeploy 的 `main` 分支，可以在命令行中指定：\n\n```shell\npip install https://github.com/InternLM/lmdeploy/archive/refs/tags/v0.11.0.zip\n```\n\n如果您希望构建支持昇腾、寒武纪或沐曦的 LMDeploy，请使用相应的 `LMDEPLOY_TARGET_DEVICE` 环境变量进行安装。\n\nLMDeploy 也支持在 AMD GPU 的 ROCm 环境中安装。\n\n```shell\n# The recommended way is to use the official ROCm PyTorch Docker image with pre-installed dependencies:\ndocker run -it \\\n    --cap-add=SYS_PTRACE \\\n    --security-opt seccomp=unconfined \\\n    --device=/dev/kfd \\\n    --device=/dev/dri \\\n    --group-add video \\\n    --ipc=host \\\n    --network=host \\\n    --shm-size 32G \\\n    -v /root:/workspace \\\n    rocm/pytorch:latest\n\n# Once inside the container, install LMDeploy with ROCm support:\nLMDEPLOY_TARGET_DEVICE=rocm pip install git+https://github.com/InternLM/lmdeploy.git\n```\n"
  },
  {
    "path": "docs/zh_cn/get_started/maca/get_started.md",
    "content": "# 沐曦C500\n\n我们基于 LMDeploy 的 PytorchEngine，增加了沐曦C500设备的支持。所以，在沐曦上使用 LMDeploy 的方法与在英伟达 GPU 上使用 PytorchEngine 后端的方法几乎相同。在阅读本教程之前，请先阅读原版的[快速开始](../get_started.md)。\n\n支持的模型列表在[这里](../../supported_models/supported_models.md#PyTorchEngine-其他平台).\n\n> \\[!IMPORTANT\\]\n> 我们已经在阿里云上提供了构建完成的沐曦的镜像。\n> 请使用下面的命令来拉取镜像:\n> `docker pull crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest`\n\n## 离线批处理\n\n### LLM 推理\n\n将`device_type=\"maca\"`加入`PytorchEngineConfig`的参数中。\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy import PytorchEngineConfig\npipe = pipeline(\"internlm/internlm2_5-7b-chat\",\n             backend_config=PytorchEngineConfig(tp=1, device_type=\"maca\"))\nquestion = [\"Shanghai is\", \"Please introduce China\", \"How are you?\"]\nresponse = pipe(question)\nprint(response)\n```\n\n### VLM 推理\n\n将`device_type=\"maca\"`加入`PytorchEngineConfig`的参数中。\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\nfrom lmdeploy.vl import load_image\npipe = pipeline('OpenGVLab/InternVL2-2B',\n     backend_config=PytorchEngineConfig(tp=1, device_type='maca'))\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## 在线服务\n\n### LLM 模型服务\n\n将`--device maca`加入到服务启动命令中。\n\n```bash\nlmdeploy serve api_server --backend pytorch --device maca internlm/internlm2_5-7b-chat\n```\n\n也可以运行以下命令启动容器运行LLM模型服务。\n\n```bash\ndocker run -it --net=host crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device maca internlm/internlm2_5-7b-chat\"\n```\n\n### VLM 模型服务\n\n将`--device maca`加入到服务启动命令中。\n\n```bash\nlmdeploy serve api_server --backend pytorch --device maca OpenGVLab/InternVL2-2B\n```\n\n也可以运行以下命令启动容器运行VLM模型服务。\n\n```bash\ndocker run -it --net=host 
crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \\\n    bash -i -c \"lmdeploy serve api_server --backend pytorch --device maca OpenGVLab/InternVL2-2B\"\n```\n\n## 使用命令行与LLM模型对话\n\n将`--device maca`加入到聊天命令中。\n\n```bash\nlmdeploy chat internlm/internlm2_5-7b-chat --backend pytorch --device maca\n```\n\n也可以运行以下命令，在启动容器后直接开启 lmdeploy 聊天：\n\n```bash\ndocker run -it crpi-4crprmm5baj1v8iv.cn-hangzhou.personal.cr.aliyuncs.com/lmdeploy_dlinfer/maca:latest \\\n    bash -i -c \"lmdeploy chat --backend pytorch --device maca internlm/internlm2_5-7b-chat\"\n```\n"
  },
  {
    "path": "docs/zh_cn/index.rst",
    "content": "欢迎来到 LMDeploy 的中文教程！\n====================================\n\n.. figure:: ./_static/image/lmdeploy-logo.svg\n  :width: 50%\n  :align: center\n  :alt: LMDeploy\n  :class: no-scaled-link\n\n.. raw:: html\n\n   <p style=\"text-align:center\">\n   <strong>LMDeploy 是一个高效且友好的 LLMs 模型部署工具箱，功能涵盖了量化、推理和服务\n   </strong>\n   </p>\n\n   <p style=\"text-align:center\">\n   <script async defer src=\"https://buttons.github.io/buttons.js\"></script>\n   <a class=\"github-button\" href=\"https://github.com/InternLM/lmdeploy\" data-show-count=\"true\" data-size=\"large\" aria-label=\"Star\">Star</a>\n   <a class=\"github-button\" href=\"https://github.com/InternLM/lmdeploy/subscription\" data-icon=\"octicon-eye\" data-size=\"large\" aria-label=\"Watch\">Watch</a>\n   <a class=\"github-button\" href=\"https://github.com/InternLM/lmdeploy/fork\" data-icon=\"octicon-repo-forked\" data-size=\"large\" aria-label=\"Fork\">Fork</a>\n   </p>\n\nLMDeploy 工具箱提供以下核心功能：\n\n- **高效的推理：** LMDeploy 开发了 Persistent Batch(即 Continuous Batch)，Blocked K/V Cache，动态拆分和融合，张量并行，高效的计算 kernel等重要特性。推理性能是 vLLM 的 1.8 倍\n\n- **可靠的量化：** LMDeploy 支持权重量化和 k/v 量化。4bit 模型推理效率是 FP16 下的 2.4 倍。量化模型的可靠性已通过 OpenCompass 评测得到充分验证。\n\n- **便捷的服务：** 通过请求分发服务，LMDeploy 支持多模型在多机、多卡上的推理服务。\n\n- **卓越的兼容性:**  LMDeploy 支持 `KV Cache 量化 <https://lmdeploy.readthedocs.io/zh-cn/latest/quantization/kv_quant.html>`_, `AWQ <https://lmdeploy.readthedocs.io/zh-cn/latest/quantization/w4a16.html>`_ 和 `Automatic Prefix Caching <https://lmdeploy.readthedocs.io/zh-cn/latest/inference/turbomind_config.html>`_ 同时使用。\n\n中文文档\n--------\n\n.. _快速上手:\n.. toctree::\n   :maxdepth: 2\n   :caption: 快速上手\n\n   get_started/installation.md\n   get_started/get_started.md\n   get_started/index.rst\n\n.. _支持的模型:\n.. toctree::\n   :maxdepth: 1\n   :caption: 模型列表\n\n   supported_models/supported_models.md\n   supported_models/reward_models.md\n\n.. _llm_部署:\n.. 
toctree::\n   :maxdepth: 1\n   :caption: 大语言模型(LLMs)部署\n\n   llm/pipeline.md\n   llm/api_server.md\n   llm/api_server_tools.md\n   llm/api_server_reasoning.md\n   llm/api_server_lora.md\n   llm/proxy_server.md\n\n.. _vlm_部署:\n.. toctree::\n   :maxdepth: 1\n   :caption: 视觉-语言模型(VLMs)部署\n\n   multi_modal/vl_pipeline.md\n   multi_modal/api_server_vl.md\n   multi_modal/index.rst\n\n\n.. _量化:\n.. toctree::\n   :maxdepth: 1\n   :caption: 量化\n\n   quantization/w4a16.md\n   quantization/w8a8.md\n   quantization/kv_quant.md\n   quantization/llm_compressor.md\n\n.. _测试基准:\n.. toctree::\n   :maxdepth: 1\n   :caption: 测试基准\n\n   benchmark/benchmark.md\n   benchmark/evaluate_with_opencompass.md\n   benchmark/evaluate_with_vlmevalkit.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: 进阶指南\n\n   inference/turbomind.md\n   inference/pytorch.md\n   advance/pytorch_new_model.md\n   advance/long_context.md\n   advance/chat_template.md\n   advance/debug_turbomind.md\n   advance/structed_output.md\n   advance/pytorch_multinodes.md\n   advance/pytorch_profiling.md\n   advance/metrics.md\n   advance/context_parallel.md\n   advance/spec_decoding.md\n   advance/update_weights.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: API 文档\n\n   api/pipeline.rst\n   api/openapi.rst\n   api/cli.rst\n\n索引与表格\n==================\n\n* :ref:`genindex`\n* :ref:`search`\n* :ref:`routingtable`\n"
  },
  {
    "path": "docs/zh_cn/inference/load_hf.md",
    "content": "# 直接读取 huggingface 模型\n\n从 v0.1.0 开始，TurboMind 添加了直接读取 Huggingface 格式权重的能力。\n\n## 支持的类型\n\n目前，TurboMind 支持加载两种类型的模型：\n\n1. 在 huggingface.co 上面通过 lmdeploy 量化的模型，如 [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit)\n2. huggingface.co 上面其他 LM 模型，如 Qwen/Qwen-7B-Chat\n\n## 使用方式\n\n### 1) 通过 lmdeploy 量化的模型\n\n对于通过 `lmdeploy.lite` 量化的模型，TurboMind 可以直接加载，比如 [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit)。\n\n```shell\nrepo_id=internlm/internlm-chat-20b-4bit\nmodel_name=internlm-chat-20b\n\n# or\n# repo_id=/path/to/downloaded_model\n\n# Inference by TurboMind\nlmdeploy chat $repo_id --model-name $model_name\n\n# Serving with RESTful API\nlmdeploy serve api_server $repo_id --model-name $model_name --tp 1\n```\n\n### 2) 其他的 LM 模型\n\n其他 LM 模型，比如 Qwen/Qwen-7B-Chat、baichuan-inc/Baichuan2-7B-Chat。LMDeploy 的模型支持情况可通过 `lmdeploy list` 查看。\n\n```shell\nrepo_id=Qwen/Qwen-7B-Chat\nmodel_name=qwen-7b\n# or\n# repo_id=/path/to/Qwen-7B-Chat/local_path\n\n# Inference by TurboMind\nlmdeploy chat $repo_id --model-name $model_name\n\n# Serving with RESTful API\nlmdeploy serve api_server $repo_id --model-name $model_name --tp 1\n```\n"
  },
  {
    "path": "docs/zh_cn/inference/pytorch.md",
    "content": "# lmdeploy.pytorch 架构\n\n`lmdeploy.pytorch` 是 LMDeploy 提供的推理后端之一。与着重于性能的 turbomind 相比，lmdeploy.pytorch 以较小的性能开销为代价，提供了一套更容易开发与扩展的大模型推理实现。\n\n## 设计\n\n![pytorch arch](https://github.com/grimoire/lmdeploy/blob/media/lmdeploy_pytorch_arch.png?raw=true)\n\n## API\n\nlmdeploy.pytorch 可以与 turbomind 共享同样的服务接口，这些服务接口通过 Engine 和 EngineInstance 与 lmdeploy.pytorch 进行交互。\n\nEngineInstance 是推理请求的发起者，它会将推理请求组织成特定格式发送给 Engine，以此实现流式推理。EngineInstance 的推理接口是线程安全的，服务发起者可以在不同线程中启动各自的 EngineInstance，Engine 会根据当前资源与推理请求自动进行 batch 化处理。\n\nEngine 是推理请求的接收者与执行者。它包含如下组件来完成这项任务：\n\n- ModelAgent 对象负责模型的加载、缓存管理以及 tensor parallelism 的管理。\n- Scheduler 对象负责 session 的管理，以及 sequence 与 lora adapter 所需资源的分配。\n- RequestManager 负责请求的发送与接收，可以通过它与 EngineInstance 交互。\n\n## Engine\n\n为了应对异步推理请求，Engine 在启动后会维护一个线程，循环执行如下操作：\n\n1. 通过 RequestManager 读取请求，对各种请求进行分类处理。\n2. Scheduler 规划哪些请求可以被处理，以及它们所需的缓存和 adapters。\n3. ModelAgent 根据步骤 2 得到的信息为输入分配资源，然后使用 patch 后的模型进行推理。\n4. Scheduler 根据推理结果更新请求状态。\n5. RequestManager 将输出返回给发送者（EngineInstance），然后回到步骤 1。\n\n下面我们将介绍上述步骤中用到的几个重要组件。\n\n### Scheduler\n\n在进行大模型的推理时，通常会把 attention 的历史输入 key 和 value 缓存起来，以避免在后续推理中进行重复计算。这种情况下如果要进行多 batch 的推理，由于不同数据的序列长度可能不同，kv 缓存需要大量填充，浪费很多显存资源，也限制了模型并发推理能力的上限。\n\n[vLLM](https://docs.vllm.ai) 提出了一种 paging 策略，以 page block 为单位为 key value 分配缓存，这样就可以避免由于 padding 导致的显存浪费。lmdeploy.pytorch 中的 Scheduler 也遵循同样的设计，根据请求的长度合理分配所需的资源，并撤出暂时不使用的资源，以保证存储资源的高效利用。\n\nlmdeploy.pytorch 还加入了对 [S-LoRA](https://github.com/S-LoRA/S-LoRA) 的支持。S-LoRA 是一种在单个模型上同时支持多个 adapter 的方案。LoRA 在推理时通常会把 adapter 融合进模型权重当中，同时使用多个 adapter 会导致显存使用量激增；S-LoRA 不对 adapter 进行融合，而是通过 unified paging 在推理时动态换入需要使用的 adapter，大幅降低了使用 adapter 的显存开销。Scheduler 中也实现了相关的功能，让用户可以更方便地使用自己的 adapter。\n\n### ModelAgent\n\nlmdeploy.pytorch 支持 Tensor Parallelism（TP），不同的 TP 参数对模型的构造、权重处理、cache 分配都有影响。ModelAgent 对这些内容进行了封装，让 Engine 不用再关心这部分细节。\n\nModelAgent 有两个重要组件：\n\n1. patched_model 是更新后的 transformer 模型，更新后的模型添加了各种功能的支持，包括更高性能的子模块实现、TP、量化等等。\n2. cache_engine 是缓存的分配与交换模块。它接收来自 scheduler 的交换请求，执行 host-device 间的显存交换、adapter 加载等工作。\n\n## 特性\n\n- **Continuous Batching**: 由于输入序列的长度不一样，batching 通常需要对输入进行 padding，这种 padding 会增加后续运算的计算量、影响速度，也会使显存占用大幅增加。与许多其他成熟框架的方案一样，lmdeploy.pytorch 采用 continuous batching 的方式对输入做了连续化处理，避免了多余的资源占用。\n\n- **Tensor Parallelism**: 大模型占用的显存可能远超单张显卡的容量。为了支持这类大模型的推理，我们实现了张量并行：模型的权重会被分布到不同的设备上，每张 GPU 负责一部分计算，既减少了单卡显存占用，也充分利用了多显卡的计算能力。\n\n- **S-LoRA**: LoRA adapter 可以帮助我们使用有限的显存来调优大模型，S-LoRA 可以帮助我们在有限的显存中同时使用多个 LoRA 权重，扩展模型的能力。\n\n- **Quantization**: 量化可以帮助我们进一步减少显存占用，提高推理性能。lmdeploy.pytorch 分支中添加了 w8a8 模型量化的支持，可以阅读 [w8a8](../quantization/w8a8.md) 了解更多细节。\n"
  },
  {
    "path": "docs/zh_cn/inference/turbomind.md",
    "content": "# TurboMind 框架\n\nTurboMind 是一款高效的 LLM 推理引擎，基于英伟达的 [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) 研发而成。它的主要功能包括：LLaMa 结构模型的支持、persistent batch 推理模式和可扩展的 KV 缓存管理器。\n\n## TurboMind 结构\n\n```\n  +--------------------+\n  |        API         |\n  +--------------------+\n          |    ^\n    请 求  |    | 流式回调\n          v    |\n  +--------------------+    获取   +-------------------+\n  |  Persistent Batch  | <-------> |  KV Cache 管理器 |\n  +--------------------+    更新   +-------------------+\n             ^\n             |\n             v\n+------------------------+\n|      LLaMa推理实现      |\n+------------------------+\n| FT kernels & utilities |\n+------------------------+\n```\n\n## Persistent Batch\n\n你也许在别的项目中看到过这项机制的另一个名字：`continuous batching`。在开发这个功能时，我们将对话式 LLM 的推理建模为一个持续运行的 batch，其生命周期跨越整个服务过程，故将其命名为 `persistent batch`。简单来说，它是这样实现的：\n\n- 该功能会预先准备好 N 个 batch slots。\n- 当有空闲 slots 时，请求就会加入到 batch 中。当请求对应的 tokens 都生成完毕后，对应的 batch slot 会立刻被释放，用于接收新的请求。\n- **当一个 sequence 命中缓存时（见下文），它的历史 token 不必在每轮中都重新解码，所以它的 token 生成过程会即刻开始**。\n- 整个 batch 会自动扩缩容，以避免不必要的计算。\n\n## KV 缓存管理器\n\nTurboMind 的 [KV 缓存管理器](https://github.com/InternLM/lmdeploy/blob/main/src/turbomind/models/llama/SequenceManager.h) 是一个内存池类型的对象，并且在其中加入了 LRU 的实现，这样整个管理器可以被看作是一个 **KV 缓存的缓存**。大致工作方式如下：\n\n- KV 缓存由管理器分配。管理器会根据预先配置好的 slot 数量开辟空间。每个 slot 对应于一个 sequence 所需的 KV 缓存。分配的内存块大小可通过配置来实现预分配或者按需分配（或介于两者之间）。\n- 当有新的请求，但是缓存池中没有空闲 slot 时，根据 LRU 机制，管理器会踢除最近最少使用的 sequence，把它占据的 slot 分给新的请求。\n- 此外，当 sequence 获取到 slot 时（类似于缓存命中），它在缓存中的历史 KV 会被直接返回，而不用再进行 context decoding。\n- 被踢除的 sequences 不会被完全删除，而是会被转换成最简洁的形式，例如 token IDs。当之后获取到相同的 sequence id 时（即 _cache-miss_ 状态），这些 token IDs 将被 FMHA 的 context decoder 解码并被转回 KV 缓存。\n- 踢除和转换均由 TurboMind 内部自动管理，因此对用户来说是透明的。__从用户的使用角度来看，使用了 TurboMind 的系统就像是可以访问无限的设备内存__。\n\n## TurboMind 的 LLaMa 实现\n\n我们对 LLaMa 系列模型的实现是从 FasterTransformer 中的 GPT-NeoX 模型修改而来的。除了对 LLaMa 系列进行基本重构和修改外，我们还做了一些改进以实现会话模型的高性能推理，其中最重要的是：\n\n- 支持多轮对话中的快速文本解码。我们用基于 
[cutlass](https://github.com/NVIDIA/cutlass) 的 FMHA 实现替代了 context decoder 中的注意力机制实现，从而支持了 Q/K 长度不匹配的情况。\n- 我们在 context FMHA 和 generation FMHA 中都加入了间接缓冲指针，支持 batch 中不连续的 KV 缓存。\n- 为了支持 persistent batch 的并发推理，我们设计了新的同步机制来协调张量并行模式下的工作线程。\n- 我们实现了 INT8 KV cache，降低了内存开销，从而可以支持更大的批处理大小，提高系统吞吐量。这在实际场景中非常有用，因为相比权重和其他激活，KV cache 会消耗更多的内存和内存带宽。\n- 我们解决了单个进程内多个模型实例在 TP 模式下运行时 NCCL 卡住的问题。NCCL APIs 现由 host 端的同步 barriers 保护。\n\n## API\n\nTurboMind 的 Python API 支持流式结果返回和张量并行模式。\n\n## TurboMind 和 FasterTransformer 的区别\n\n除了上文中提到的功能外，TurboMind 相较于 FasterTransformer 还有不少差别。譬如不少 FasterTransformer 的功能在 TurboMind 中都被去掉了，这其中包括前缀提示词、beam search、上下文 embedding、稀疏化 GEMM 操作，以及对 GPT、T5 等结构模型的支持等等。\n\n## FAQ\n\n### 对 Huggingface 模型的支持\n\n由于历史原因，TurboMind 的权重设计是基于 [LLaMa 的官方实现](https://github.com/facebookresearch/llama) 完成的，两者只相差一个转置操作。但是 Huggingface 版本的实现却是[另一种形式](https://github.com/huggingface/transformers/blob/45025d92f815675e483f32812caa28cce3a960e7/src/transformers/models/llama/convert_llama_weights_to_hf.py#L123C76-L123C76)，两种权重实现方式在 `W_q` 和 `W_k` 上的区别，我们在 [deploy.py](https://github.com/InternLM/lmdeploy/blob/ff4648a1d09e5aec74cf70efef35bfaeeac552e0/lmdeploy/serve/turbomind/deploy.py#L398) 中进行了适配处理，用户可前往查看。\n"
  },
  {
    "path": "docs/zh_cn/inference/turbomind_config.md",
    "content": "# TurboMind 配置\n\nTurboMind 是 LMDeploy 的推理引擎，在用它推理 LLM 模型时，需要把输入模型转成 TurboMind 模型。在 TurboMind 的模型文件夹中，除模型权重外，TurboMind 模型还包括其他一些文件，其中最重要的是和推理性能息息相关的配置文件`triton_models/weights/config.ini`。\n\n如果你使用的是 LMDeploy 0.0.x 版本，请参考[turbomind 1.0 配置](#turbomind-10-配置)章节，了解配置中的相关内容。如果使用的是 LMDeploy 0.1.x 版本，请阅读[turbomind 2.x 配置](#turbomind-2x-配置)了解配置细节。\n\n## TurboMind 2.x 配置\n\n以 `llama-2-7b-chat` 模型为例，在 TurboMind 2.x 中，它的`config.ini`内容如下：\n\n```toml\n[llama]\nmodel_name = \"llama2\"\ntensor_para_size = 1\nhead_num = 32\nkv_head_num = 32\nvocab_size = 32000\nnum_layer = 32\ninter_size = 11008\nnorm_eps = 1e-06\nattn_bias = 0\nstart_id = 1\nend_id = 2\nsession_len = 4104\nweight_type = \"fp16\"\nrotary_embedding = 128\nrope_theta = 10000.0\nsize_per_head = 128\ngroup_size = 0\nmax_batch_size = 64\nmax_context_token_num = 1\nstep_length = 1\ncache_max_entry_count = 0.5\ncache_block_seq_len = 128\ncache_chunk_size = 1\nenable_prefix_caching = false\nquant_policy = 0\nmax_position_embeddings = 2048\nrope_scaling_factor = 0.0\nuse_logn_attn = 0\n```\n\n这些参数由模型属性和推理参数组成。模型属性包括层数、head个数、维度等等，它们**不可修改**\n\n```toml\nmodel_name = \"llama2\"\nhead_num = 32\nkv_head_num = 32\nvocab_size = 32000\nnum_layer = 32\ninter_size = 11008\nnorm_eps = 1e-06\nattn_bias = 0\nstart_id = 1\nend_id = 2\nrotary_embedding = 128\nrope_theta = 10000.0\nsize_per_head = 128\n```\n\n和 TurboMind 1.0 config 相比，TurboMind 2.x config 中的模型属性部分和 1.0 一致，但推理参数发生了变化。\n\n在接下来的章节中，我们重点介绍推理参数。\n\n### 数据类型\n\n和数据类型相关的参数是 `weight_type` 和 `group_size`。它们**不可被修改**。\n\n`weight_type` 表示权重的数据类型。目前支持 fp16 和 int4。int4 表示 4bit 权重。当 `weight_type`为 4bit 权重时，`group_size` 表示 `awq` 量化权重时使用的 group 大小。目前，在 LMDeploy 的预编译包中，使用的是 `group_size = 128`。\n\n### 批处理大小\n\n仍通过 `max_batch_size` 设置最大批处理量。默认值由原来的 32 改成 64。\n在 TurboMind 2.x 中，`max_batch_size` 和 `cache_max_entry_count`无关。\n\n### k/v 缓存大小\n\n`cache_block_seq_len` 和 `cache_max_entry_count` 用来调节 k/v cache 的内存大小。\n\nTurboMind 2.x 实现了 Paged Attention，按块管理 k/v 
cache。\n\n`cache_block_seq_len` 表示一块 k/v block 可以存放的 token 序列长度，默认 128。TurboMind 按照以下公式计算 k/v block 的内存大小：\n\n```\ncache_block_seq_len * num_layer * kv_head_num * size_per_head * 2 * sizeof(kv_data_type)\n```\n\n对于 llama2-7b 模型来说，以 half 类型存放 k/v 时，一块 k/v block 的内存为：`128 * 32 * 32 * 128 * 2 * sizeof(half) = 64MB`\n\n`cache_max_entry_count` 根据取值不同，表示不同的含义：\n\n- 当值为 (0, 1) 之间的小数时，`cache_max_entry_count` 表示 k/v block 使用的内存百分比。比如 A100-80G 显卡内存为 80G，当 `cache_max_entry_count` 为 0.5 时，表示 k/v block 使用的内存总量为 80 * 0.5 = 40G\n- 当 lmdeploy 版本大于 0.2.1 时，`cache_max_entry_count` 表示 k/v blocks 占用**空闲**内存的百分比，默认值为 `0.8`。例如，在 A100-80G GPU 上运行 TurboMind 加载 13b 模型时，k/v blocks 使用的内存为 `(80 - 26) * 0.8 = 43.2G`，即利用剩余 54G 中的 80%\n- 当值为大于 1 的整数时，表示 k/v block 数量\n\n`cache_chunk_size` 表示每次需要新的 k/v cache 块时，一次开辟的 k/v cache 块数量。不同的取值，表示不同的含义：\n\n- 当值为 > 0 的整数时，开辟 `cache_chunk_size` 个 k/v cache 块\n- 当值为 -1 时，开辟 `cache_max_entry_count` 个 k/v cache 块\n- 当值为 0 时，开辟 `sqrt(cache_max_entry_count)` 个 k/v cache 块\n\n### 前缀缓存开关\n\n`enable_prefix_caching` 是前缀缓存（Prefix Caching）功能的开关。值为 `True` 时表示开启，`False` 表示关闭，默认为 `False`。\n\n前缀缓存功能主要适用于多个请求具有相同 prompt 前缀（比如 system prompt）的场景。相同前缀部分的 k/v block 会被缓存起来，被多个请求重复利用，从而节省重复计算的开销，提高推理性能。相同 prompt 前缀越长，性能提升越大。\n\n由于前缀缓存对 k/v 重复利用的最小粒度是 block，如果相同 prompt 前缀不足一个 block（前缀长度\\<`cache_block_seq_len`），则推理性能不会有提升。\n\n### kv 量化推理开关\n\n`quant_policy` 是 kv 量化和推理开关。\n\n- `quant_policy=4` 代表 4bit k/v 量化和推理\n- `quant_policy=8` 代表 8bit k/v 量化和推理\n\n具体使用方法，请参考 [kv quant](../quantization/kv_quant.md) 部署文档\n\n### 外推能力开关\n\n默认 `rope_scaling_factor = 0`，不具备外推能力。设置为 1.0，可以开启 RoPE 的 Dynamic NTK 功能，支持长文本推理。\n\n关于 Dynamic NTK 的原理，详细请参考：\n\n1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases\n2. 
https://kexue.fm/archives/9675\n\n设置 `use_logn_attn = 1`，可以开启 [LogN attention scaling](https://kexue.fm/archives/8823)。\n\n## TurboMind 1.0 配置\n\n以 `llama-2-7b-chat` 模型为例，在 TurboMind 1.0 中，它的`config.ini`内容如下：\n\n```toml\n[llama]\nmodel_name = \"llama2\"\ntensor_para_size = 1\nhead_num = 32\nkv_head_num = 32\nvocab_size = 32000\nnum_layer = 32\ninter_size = 11008\nnorm_eps = 1e-06\nattn_bias = 0\nstart_id = 1\nend_id = 2\nsession_len = 4104\nweight_type = \"fp16\"\nrotary_embedding = 128\nrope_theta = 10000.0\nsize_per_head = 128\ngroup_size = 0\nmax_batch_size = 32\nmax_context_token_num = 4\nstep_length = 1\ncache_max_entry_count = 48\ncache_chunk_size = 1\nuse_context_fmha = 1\nquant_policy = 0\nmax_position_embeddings = 2048\nuse_dynamic_ntk = 0\nuse_logn_attn = 0\n```\n\n这些参数由模型属性和推理参数组成。模型属性包括层数、head个数、维度等等，它们**不可修改**\n\n```toml\nmodel_name = \"llama2\"\nhead_num = 32\nkv_head_num = 32\nvocab_size = 32000\nnum_layer = 32\ninter_size = 11008\nnorm_eps = 1e-06\nattn_bias = 0\nstart_id = 1\nend_id = 2\nrotary_embedding = 128\nrope_theta = 10000.0\nsize_per_head = 128\n```\n\n在接下来的章节中，我们重点介绍推理参数。\n\n### 数据类型\n\n和数据类型相关的参数是 `weight_type` 和 `group_size`。它们**不可被修改**。\n\n`weight_type` 表示权重的数据类型。目前支持 fp16 和 int4。int4 表示 4bit 权重。当 `weight_type`为 4bit 权重时，`group_size` 表示 `awq` 量化权重时使用的 group 大小。目前，在 LMDeploy 的预编译包中，使用的是 `group_size = 128`。\n\n### 批处理大小\n\n可通过`max_batch_size`调节推理时最大的 batch 数。一般，batch 越大吞吐量越高。但务必保证 `max_batch_size <= cache_max_entry_count`\n\n### k/v cache 大小\n\nTurboMind 根据 `session_len`、 `cache_chunk_size` 和 `cache_max_entry_count` 开辟 k/v cache 内存。\n\n- `session_len` 表示一个序列的最大长度，即 context window 的大小。\n- `cache_chunk_size` 表示当新增对话序列时，每次要开辟多少个序列的 k/v cache\n- `cache_max_entry_count` 表示最多缓存多少个对话序列\n\n### kv int8 开关\n\n当启动 8bit k/v 推理时，需要修改参数 `quant_policy` 和 `use_context_fmha`。详细内容请查阅 [kv int8](../quantization/kv_quant.md) 部署文档。\n\n### 外推能力开关\n\n设置 `use_dynamic_ntk = 1`，可以开启 RoPE 的 Dynamic NTK 选项，支持长文本推理。\n\n关于 Dynamic NTK 的原理，详细请参考：\n\n1. 
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases\n2. https://kexue.fm/archives/9675\n\n设置 `use_logn_attn = 1`，可以开启 [LogN attention scaling](https://kexue.fm/archives/8823)。\n"
  },
  {
    "path": "docs/zh_cn/llm/api_server.md",
    "content": "# 部署 LLM 类 openai 服务\n\n本文主要介绍单个模型在单机多卡环境下，部署兼容 openai 接口服务的方式，以及服务接口的用法。为行文方便，我们将该服务称为 `api_server`。对于多模型的并行服务，请阅读[请求分发服务器](./proxy_server.md)一文。\n\n在这篇文章中，我们首先介绍启动服务的三种方法，你可以根据应用场景选择合适的方式。\n\n其次，我们重点介绍服务的 RESTful API 定义，以及接口使用的方式，并展示如何通过 Swagger UI、LMDeploy CLI 工具体验服务功能。\n\n## 启动服务\n\n以 huggingface hub 上的 [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) 模型为例，你可以任选以下方式之一，启动推理服务。\n\n### 方式一：使用 lmdeploy cli 工具\n\n```shell\nlmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333\n```\n\napi_server 启动时的参数可以通过命令行 `lmdeploy serve api_server -h` 查看。\n比如，`--tp` 设置张量并行，`--session-len` 设置推理的最大上下文窗口长度，`--cache-max-entry-count` 调整 k/v cache 的内存使用比例等等。\n\n### 方式二：使用 docker\n\n使用 LMDeploy 官方[镜像](https://hub.docker.com/r/openmmlab/lmdeploy/tags)，可以运行兼容 OpenAI 的服务。下面是使用示例：\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server internlm/internlm2_5-7b-chat\n```\n\n在这个例子中，`lmdeploy serve api_server` 的命令参数与方式一相同。\n\n每个模型可能需要 Docker 镜像中未包含的特定依赖项。如果遇到问题，您可能需要根据具体情况自行安装这些依赖项。如有疑问，请参阅相应模型项目的文档。\n\n例如，对于 LLaVA：\n\n```dockerfile\nFROM openmmlab/lmdeploy:latest\n\nRUN apt-get update && apt-get install -y python3 python3-pip git\n\nWORKDIR /app\n\nRUN pip3 install --upgrade pip\nRUN pip3 install timm\nRUN pip3 install git+https://github.com/haotian-liu/LLaVA.git --no-deps\n\nCOPY . .\n\nCMD [\"lmdeploy\", \"serve\", \"api_server\", \"liuhaotian/llava-v1.6-34b\"]\n```\n\n### 方式三：部署到Kubernetes集群\n\n使用 [kubectl](https://kubernetes.io/docs/reference/kubectl/) 命令行工具，连接到一个运行中的 Kubernetes 集群，并部署 internlm2_5-7b-chat 模型服务。下面是使用示例（需要将 `<your token>` 替换为你的 huggingface hub token）：\n\n```shell\nsed 's/{{HUGGING_FACE_HUB_TOKEN}}/<your token>/' k8s/deployment.yaml | kubectl create -f - \\\n    && kubectl create -f k8s/service.yaml\n```\n\n示例中的模型数据来源于 node 上的本地磁盘（hostPath）。多副本部署时，可以考虑替换为高可用的共享存储，通过 [PersistentVolume](https://kubernetes.io/docs/concepts/storage/persistent-volumes/) 的方式挂载到容器中。\n\n## RESTful API\n\nLMDeploy 的 RESTful API 兼容了 OpenAI 的以下 3 个接口：\n\n- /v1/chat/completions\n- /v1/models\n- /v1/completions\n\n服务启动后，你可以在浏览器中打开网页 http://0.0.0.0:23333，通过 Swagger UI 查看接口的详细说明，也可以直接在网页上操作，体验每个接口的用法，如下图所示。\n\n![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459)\n\n若需要把服务集成到自己的项目或者产品中，我们推荐以下用法：\n\n### 使用 openai 接口\n\n以下代码是通过 openai 包使用 `v1/chat/completions` 服务的例子。运行之前，请先安装 openai 包：`pip install openai`。\n\n```python\nfrom openai import OpenAI\nclient = OpenAI(\n    api_key='YOUR_API_KEY',\n    base_url=\"http://0.0.0.0:23333/v1\"\n)\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n  model=model_name,\n  messages=[\n    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n    {\"role\": \"user\", \"content\": \" provide three suggestions about time management\"},\n  ],\n    temperature=0.8,\n    top_p=0.8\n)\nprint(response)\n```\n\n如果你想使用异步的接口，可以尝试下面的例子：\n\n```python\nimport asyncio\nfrom openai import AsyncOpenAI\n\nasync def main():\n    client = AsyncOpenAI(api_key='YOUR_API_KEY',\n                         base_url='http://0.0.0.0:23333/v1')\n    model_cards = await client.models.list()\n    response = await client.chat.completions.create(\n        model=model_cards.data[0].id,\n        messages=[\n            {\n                'role': 'system',\n     
           'content': 'You are a helpful assistant.'\n            },\n            {\n                'role': 'user',\n                'content': ' provide three suggestions about time management'\n            },\n        ],\n        temperature=0.8,\n        top_p=0.8)\n    print(response)\n\nasyncio.run(main())\n```\n\n关于其他 openai 接口的调用，也可以如法炮制。详情请参考 openai 官方[文档](https://platform.openai.com/docs/guides/text-generation)\n\n### 使用 lmdeploy `APIClient` 接口\n\n如果你想用 `/v1/chat/completions` 接口，你可以尝试下面代码：\n\n```python\nfrom lmdeploy.serve.openai.api_client import APIClient\napi_client = APIClient(f'http://{server_ip}:{server_port}')\nmodel_name = api_client.available_models[0]\nmessages = [{\"role\": \"user\", \"content\": \"Say this is a test!\"}]\nfor item in api_client.chat_completions_v1(model=model_name, messages=messages):\n    print(item)\n```\n\n如果你想用 `/v1/completions` 接口，你可以尝试：\n\n```python\nfrom lmdeploy.serve.openai.api_client import APIClient\napi_client = APIClient(f'http://{server_ip}:{server_port}')\nmodel_name = api_client.available_models[0]\nfor item in api_client.completions_v1(model=model_name, prompt='hi'):\n    print(item)\n```\n\n### 工具调用\n\n参考 [api_server_tools](./api_server_tools.md)。\n\n### 使用 Java/Golang/Rust\n\n可以使用代码生成工具 [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) 将 `http://{server_ip}:{server_port}/openapi.json` 转成 java/rust/golang 客户端。\n下面是一个使用示例：\n\n```shell\n$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust\n\n$ ls rust/*\nrust/Cargo.toml  rust/git_push.sh  rust/README.md\n\nrust/docs:\nChatCompletionRequest.md  EmbeddingsRequest.md  HttpValidationError.md  LocationInner.md  Prompt.md\nDefaultApi.md             GenerateRequest.md    Input.md                Messages.md       ValidationError.md\n\nrust/src:\napis  lib.rs  models\n```\n\n### 使用 cURL\n\ncURL 也可以用于查看 API 的输出结果\n\n- 查看模型列表 `v1/models`\n\n```bash\ncurl 
http://{server_ip}:{server_port}/v1/models\n```\n\n- 对话 `v1/chat/completions`\n\n```bash\ncurl http://{server_ip}:{server_port}/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"internlm-chat-7b\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Hello! How are you?\"}]\n  }'\n```\n\n- 文本补全 `v1/completions`\n\n```shell\ncurl http://{server_ip}:{server_port}/v1/completions \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"model\": \"llama\",\n  \"prompt\": \"two steps to build a house:\"\n}'\n```\n\n## 同时启动多个 api_server\n\n两步直接启动多机多卡服务。先用下面的代码创建一个启动脚本。然后：\n\n1. 启动代理服务 `lmdeploy serve proxy`。\n2. torchrun 启动脚本 `torchrun --nproc_per_node 2 script.py InternLM/internlm2-chat-1_8b --proxy_url http://{proxy_node_name}:{proxy_node_port}`. **注意**： 多机多卡不要用默认 url `0.0.0.0:8000`，我们需要输入真实ip对应的地址，如：`11.25.34.55:8000`。多机情况下，因为不需要子节点间的通信，所以并不需要用户指定 torchrun 的 `--nnodes` 等参数，只要能保证每个节点执行一次单节点的 torchrun 就行。\n\n```python\nimport os\nimport socket\nfrom typing import List, Literal\n\nimport fire\n\n\ndef get_host_ip():\n    try:\n        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n        s.connect(('8.8.8.8', 80))\n        ip = s.getsockname()[0]\n    finally:\n        s.close()\n    return ip\n\n\ndef main(model_path: str,\n         tp: int = 1,\n         proxy_url: str = 'http://0.0.0.0:8000',\n         port: int = 23333,\n         backend: Literal['turbomind', 'pytorch'] = 'turbomind'):\n    local_rank = int(os.environ.get('LOCAL_RANK', -1))\n    world_size = int(os.environ.get('WORLD_SIZE', -1))\n    local_ip = get_host_ip()\n    if isinstance(port, List):\n        assert len(port) == world_size\n        port = port[local_rank]\n    else:\n        port += local_rank * 10\n    if (world_size - local_rank) % tp == 0:\n        rank_list = ','.join([str(local_rank + i) for i in range(tp)])\n        command = f'CUDA_VISIBLE_DEVICES={rank_list} lmdeploy serve api_server {model_path} '\\\n                  
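f'--server-name {local_ip} --server-port {port} --tp {tp} '\\\n                  f'--proxy-url {proxy_url} --backend {backend}'\n        print(f'running command: {command}')\n        os.system(command)\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n```\n\n### Example\n\nTo further show how to use the multi-node, multi-GPU service in a cluster environment, here is an example on Volcano Engine:\n\n```shell\n#!/bin/bash\n# Activate the conda environment\nsource /path/to/your/home/miniconda3/bin/activate /path/to/your/home/miniconda3/envs/your_env\nexport HOME=/path/to/your/home\n# Get the master node IP (assuming MLP_WORKER_0_HOST is the master node's IP)\nMASTER_IP=${MLP_WORKER_0_HOST}\n# Define the proxy port on every node, because all nodes pass it to torchrun below\nPROXY_PORT=8000\n# Check whether this is the master node\nif [ \"${MLP_ROLE_INDEX}\" -eq 0 ]; then\n    # Start lmdeploy serve proxy in the background\n    echo \"Starting lmdeploy serve proxy on master node...\"\n    lmdeploy serve proxy --server-name ${MASTER_IP} --server-port ${PROXY_PORT} &\nelse\n    # This assumes the scheduler starts all machines at the same time; otherwise, sleep a while until the proxy is up\n    echo \"Not starting lmdeploy serve proxy on worker node ${MLP_ROLE_INDEX}.\"\nfi\n# Start torchrun in the foreground\n# Again, no --nnodes or --master-addr arguments are needed in a multi-node setting: just run a single-node torchrun once on each machine.\ntorchrun \\\n--nproc_per_node=${MLP_WORKER_GPU} \\\n/path/to/script.py \\\nInternLM/internlm2-chat-1_8b --tp 8 --proxy_url http://${MASTER_IP}:${PROXY_PORT}\n# Print the host IP addresses\necho \"Host IP addresses:\"\nhostname -I\n```\n\n## FAQ\n\n1. When the response's finish reason is `\"finish_reason\":\"length\"`, the session length has exceeded the maximum. To change the maximum session length, set `--session-len` when launching `api_server`.\n\n2. If the server runs out of GPU memory (OOM), reduce `cache_max_entry_count` in `backend_config` (`--cache-max-entry-count` on the command line) when launching the service.\n\n3. Regarding stop words, only strings that encode to a single token index are supported. Moreover, several different indices may decode to text containing the stop word; if there are too many such indices, only the index produced by the tokenizer's encoding is used. If you need a stop word that encodes to multiple indices, consider matching the string in a streaming client and breaking out of the streaming loop once it matches.\n\n4. For custom chat templates, see [chat_template.md](../advance/chat_template.md).\n"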
  },
  {
    "path": "docs/zh_cn/llm/api_server_lora.md",
    "content": "# LoRA Inference Service\n\n## Launching the LoRA service\n\nLoRA is currently supported only by the PyTorch backend. A LoRA model is served in the same way as any other model, and all available options can be listed with `lmdeploy serve api_server -h`. The LoRA settings are among the PyTorch engine arguments:\n\n```\nPyTorch engine arguments:\n  --adapters [ADAPTERS [ADAPTERS ...]]\n                        Used to set path(s) of lora adapter(s). One can input key-value pairs in xxx=yyy format for multiple lora adapters. If only have one adapter, one can only input the path of the adapter.. Default:\n                        None. Type: str\n```\n\nPass the Hugging Face model path of the LoRA weights to `--adapters` as a `name=path` key-value pair:\n\n```shell\nlmdeploy serve api_server THUDM/chatglm2-6b --adapters mylora=chenchi/lora-chatglm2-6b-guodegang\n```\n\nOnce the service is running, the Swagger UI lists two available model names: `THUDM/chatglm2-6b` and `mylora`. The latter is the key given in `--adapters`.\n\n## Client usage\n\n### CLI\n\nThe OpenAI API parameter `model` selects whether inference uses the base model or one of the LoRA adapters. The example below selects the `chenchi/lora-chatglm2-6b-guodegang` adapter, registered as `mylora`.\n\n```shell\ncurl -X 'POST' \\\n  'http://localhost:23333/v1/chat/completions' \\\n  -H 'accept: application/json' \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"model\": \"mylora\",\n  \"messages\": [\n    {\n      \"content\": \"hi\",\n      \"role\": \"user\"\n    }\n  ]\n}'\n```\n\nThis returns a reply in the style of that LoRA adapter:\n\n```json\n{\n  \"id\": \"2\",\n  \"object\": \"chat.completion\",\n  \"created\": 1721377275,\n  \"model\": \"mylora\",\n  \"choices\": [\n    {\n      \"index\": 0,\n      \"message\": {\n        \"role\": \"assistant\",\n        \"content\": \" 很高兴哪有什么赶凳儿？（按东北语说的“起早哇”），哦，东北人都学会外语了？\",\n        \"tool_calls\": null\n      },\n      \"logprobs\": null,\n      \"finish_reason\": \"stop\"\n    }\n  ],\n  \"usage\": {\n    \"prompt_tokens\": 17,\n    \"total_tokens\": 43,\n    \"completion_tokens\": 26\n  }\n}\n```\n\n### Python\n\n```python\nfrom openai import OpenAI\nclient = OpenAI(\n    api_key='YOUR_API_KEY',\n    base_url=\"http://0.0.0.0:23333/v1\"\n)\nmodel_name = 'mylora'\nresponse = client.chat.completions.create(\n  model=model_name,\n  
messages=[\n    {\"role\": \"user\", \"content\": \"hi\"},\n  ],\n    temperature=0.8,\n    top_p=0.8\n)\nprint(response)\n```\n\nThe printed response is:\n\n```\nChatCompletion(id='4', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' 很高兴能够见到你哪，我也在辐射区开了个愣儿，你呢，还活着。', role='assistant', function_call=None, tool_calls=None))], created=1721377497, model='mylora', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=22, prompt_tokens=17, total_tokens=39))\n```\n"
  },
  {
    "path": "docs/zh_cn/llm/api_server_reasoning.md",
    "content": "# Reasoning Outputs\n\nFor models with reasoning capabilities, such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), LMDeploy can parse the reasoning out of the model output when serving and return it separately in the `reasoning_content` field.\n\n## Usage example\n\n### DeepSeek R1\n\nDeepSeek R1 is launched like any other model's api_server, except that `--reasoning-parser` must be given the name of the parser to use:\n\n```\nlmdeploy serve api_server deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek-r1\n```\n\nThe service can then be called from a client:\n\n```python\nfrom openai import OpenAI\n\nopenai_api_key = \"Your API key\"\nopenai_api_base = \"http://0.0.0.0:23333/v1\"\n\nclient = OpenAI(\n    api_key=openai_api_key,\n    base_url=openai_api_base,\n)\n\nmodels = client.models.list()\nmodel = models.data[0].id\n\nmessages = [{\"role\": \"user\", \"content\": \"9.11 and 9.8, which is greater?\"}]\nresponse = client.chat.completions.create(model=model, messages=messages, stream=True)\nfor stream_response in response:\n    delta = stream_response.choices[0].delta\n    # reasoning_content is a non-standard extra field; getattr guards against chunks that omit it\n    print('reasoning content: ', getattr(delta, 'reasoning_content', None))\n    print('content: ', delta.content)\n\nresponse = client.chat.completions.create(model=model, messages=messages, stream=False)\nreasoning_content = response.choices[0].message.reasoning_content\ncontent = response.choices[0].message.content\n\nprint(\"reasoning_content:\", reasoning_content)\nprint(\"content:\", content)\n```\n\n## Custom parser\n\nTo add a parser, define a parser class like the one below in `lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py`.\n\n```python\n# import the required packages\nfrom typing import Sequence, Union, Tuple, Optional\n\nfrom lmdeploy.serve.openai.reasoning_parser import (\n    ReasoningParser, ReasoningParserManager)\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest,\n                                              DeltaMessage)\n\n# define a reasoning parser and register it to lmdeploy\n# the name list in register_module can be used\n# in 
--reasoning-parser.\n@ReasoningParserManager.register_module([\"example\"])\nclass ExampleParser(ReasoningParser):\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n\n    def extract_reasoning_content_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n    ) -> Union[DeltaMessage, None]:\n        \"\"\"\n        Instance method that should be implemented for extracting reasoning\n        from an incomplete response; for use when handling reasoning calls and\n        streaming. Has to be an instance method because it requires state -\n        the current tokens/diffs, but also the information about what has\n        previously been parsed and extracted (see constructor)\n        \"\"\"\n\n    def extract_reasoning_content(\n            self, model_output: str, request: ChatCompletionRequest\n    ) -> Tuple[Optional[str], Optional[str]]:\n        \"\"\"\n        Extract reasoning content from a complete model-generated string.\n\n        Used for non-streaming responses where we have the entire model response\n        available before sending to the client.\n\n        Args:\n            model_output (str): The model-generated string to extract reasoning content from.\n            request (ChatCompletionRequest): The request object that was used to generate the model_output.\n\n        Returns:\n            reasoning_content (str | None): The reasoning content.\n            final_output (str | None): The content.\n        \"\"\"\n```\n\nThe service is then launched with:\n\n```\nlmdeploy serve api_server $model_path --reasoning-parser example\n```\n"
  },
  {
    "path": "docs/zh_cn/llm/api_server_tools.md",
    "content": "# Tools\n\nLMDeploy supports tool calling for InternLM2, InternLM2.5, Llama3.1 and Qwen2.5 models. Specify the parser name with `--tool-call-parser` when launching the api_server. The supported names are:\n\n1. internlm\n2. qwen\n3. llama3\n\n## Single-turn call\n\nOnce the model service is running, run the following demo:\n\n```python\nfrom openai import OpenAI\n\ntools = [\n  {\n    \"type\": \"function\",\n    \"function\": {\n      \"name\": \"get_current_weather\",\n      \"description\": \"Get the current weather in a given location\",\n      \"parameters\": {\n        \"type\": \"object\",\n        \"properties\": {\n          \"location\": {\n            \"type\": \"string\",\n            \"description\": \"The city and state, e.g. San Francisco, CA\",\n          },\n          \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n        },\n        \"required\": [\"location\"],\n      },\n    }\n  }\n]\nmessages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response)\n```\n\n## Multi-turn calls\n\n### InternLM\n\nThe following example shows a complete tool-calling chain:\n\n```python\nimport json\n\nfrom openai import OpenAI\n\n\ndef add(a: int, b: int):\n    return a + b\n\n\ndef mul(a: int, b: int):\n    return a * b\n\n\ntools = [{\n    'type': 'function',\n    'function': {\n        'name': 'add',\n        'description': 'Compute the sum of two numbers',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'a': {\n                    'type': 'integer',\n                    'description': 'A number',\n                },\n                'b': {\n                    'type': 'integer',\n                    'description': 'A number',\n                },\n            },\n            'required': ['a', 'b'],\n        },\n   
 }\n}, {\n    'type': 'function',\n    'function': {\n        'name': 'mul',\n        'description': 'Calculate the product of two numbers',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'a': {\n                    'type': 'integer',\n                    'description': 'A number',\n                },\n                'b': {\n                    'type': 'integer',\n                    'description': 'A number',\n                },\n            },\n            'required': ['a', 'b'],\n        },\n    }\n}]\n# Map tool names to local functions so the model's choice is dispatched without eval()\navailable_functions = {'add': add, 'mul': mul}\nmessages = [{'role': 'user', 'content': 'Compute (3+5)*2'}]\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response)\nfunc1_name = response.choices[0].message.tool_calls[0].function.name\nfunc1_args = response.choices[0].message.tool_calls[0].function.arguments\nfunc1_out = available_functions[func1_name](**json.loads(func1_args))\nprint(func1_out)\n\nmessages.append(response.choices[0].message)\nmessages.append({\n    'role': 'tool',\n    'content': f'3+5={func1_out}',\n    'tool_call_id': response.choices[0].message.tool_calls[0].id\n})\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response)\nfunc2_name = response.choices[0].message.tool_calls[0].function.name\nfunc2_args = response.choices[0].message.tool_calls[0].function.arguments\nfunc2_out = available_functions[func2_name](**json.loads(func2_args))\nprint(func2_out)\n```\n\nRunning the example above with the InternLM2-Chat-7B model produces the following output:\n\n```\nChatCompletion(id='1', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, 
tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": 3, \"b\": 5}', name='add'), type='function')]))], created=1722852901, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=263, total_tokens=288))\n8\nChatCompletion(id='2', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='1', function=Function(arguments='{\"a\": 8, \"b\": 2}', name='mul'), type='function')]))], created=1722852901, model='/nvme/shared_data/InternLM/internlm2-chat-7b', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=25, prompt_tokens=293, total_tokens=318))\n16\n```\n\n### Llama3.1\n\nMeta states in the [official Llama3 user guide](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1):\n\n> There are three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:\n>\n> 1. Brave Search: Tool call to perform web searches.\n> 2. Wolfram Alpha: Tool call to perform complex mathematical calculations.\n> 3. 
Code Interpreter: Enables the model to output python code.\n\nIt also warns: \"Note: We recommend using Llama 70B-instruct or Llama 405B-instruct for applications that combine conversation and tool calling. Llama 8B-Instruct can not reliably maintain a conversation alongside tool calling definitions. It can be used for zero-shot tool calling, but tool instructions should be removed for regular conversations between the model and the user.\"\n\nTherefore, we use [Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) to show how to call the model's tool capabilities through LMDeploy's `api_server`.\n\nOn an A100-SXM-80G node, the service can be launched as follows:\n\n```shell\nlmdeploy serve api_server /the/path/of/Meta-Llama-3.1-70B-Instruct/model --tp 4\n```\n\nFor a detailed introduction to api_server, see the documentation [here](./api_server.md).\n\nThe code below shows how to use the \"Wolfram Alpha\" tool. It assumes you have registered on the [Wolfram Alpha](https://www.wolframalpha.com) website and obtained an API key; a valid key is required to access the Wolfram Alpha service.\n\n```python\nfrom openai import OpenAI\nimport requests\n\n\ndef request_llama3_1_service(messages):\n    client = OpenAI(api_key='YOUR_API_KEY',\n                    base_url='http://0.0.0.0:23333/v1')\n    model_name = client.models.list().data[0].id\n    response = client.chat.completions.create(\n        model=model_name,\n        messages=messages,\n        temperature=0.8,\n        top_p=0.8,\n        stream=False)\n    return response.choices[0].message.content\n\n\n# The role of \"system\" MUST be specified, including the required tools\nmessages = [\n    {\n        \"role\": \"system\",\n        \"content\": \"Environment: ipython\\nTools: wolfram_alpha\\n\\n Cutting Knowledge Date: December 2023\\nToday Date: 23 Jul 2024\\n\\nYou are a helpful Assistant.\"  # noqa\n    },\n    {\n        \"role\": \"user\",\n        \"content\": \"Can you help me solve this equation: x^3 - 4x^2 + 6x - 24 = 0\"  # noqa\n    }\n]\n\n# send request to the api_server of llama3.1-70b and get the response\n# the \"assistant_response\" is supposed to be:\n# <|python_tag|>wolfram_alpha.call(query=\"solve x^3 - 4x^2 + 6x - 24 = 0\")\nassistant_response = request_llama3_1_service(messages)\nprint(assistant_response)\n\n# Call the API of Wolfram Alpha with the query generated by the model\napp_id = 
'YOUR-Wolfram-Alpha-API-KEY'\nparams = {\n    \"input\": assistant_response,\n    \"appid\": app_id,\n    \"format\": \"plaintext\",\n    \"output\": \"json\",\n}\n\nwolframalpha_response = requests.get(\n    \"https://api.wolframalpha.com/v2/query\",\n    params=params\n)\nwolframalpha_response = wolframalpha_response.json()\n\n# Append the model's response and the Wolfram Alpha API result\n# to \"messages\", and send them to the api_server again\nmessages += [\n    {\n        \"role\": \"assistant\",\n        \"content\": assistant_response\n    },\n    {\n        \"role\": \"ipython\",\n        \"content\": wolframalpha_response\n    }\n]\n\nassistant_response = request_llama3_1_service(messages)\nprint(assistant_response)\n```\n\n### Qwen2.5\n\nQwen2.5 supports parallel tool calls, which means a single response may contain several tool call requests.\n\n```python\nfrom openai import OpenAI\nimport json\n\ndef get_current_temperature(location: str, unit: str = \"celsius\"):\n    \"\"\"Get current temperature at a location.\n\n    Args:\n        location: The location to get the temperature for, in the format \"City, State, Country\".\n        unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n\n    Returns:\n        the temperature, the location, and the unit in a dict\n    \"\"\"\n    return {\n        \"temperature\": 26.1,\n        \"location\": location,\n        \"unit\": unit,\n    }\n\n\ndef get_temperature_date(location: str, date: str, unit: str = \"celsius\"):\n    \"\"\"Get temperature at a location and date.\n\n    Args:\n        location: The location to get the temperature for, in the format \"City, State, Country\".\n        date: The date to get the temperature for, in the format \"Year-Month-Day\".\n        unit: The unit to return the temperature in. Defaults to \"celsius\". 
(choices: [\"celsius\", \"fahrenheit\"])\n\n    Returns:\n        the temperature, the location, the date and the unit in a dict\n    \"\"\"\n    return {\n        \"temperature\": 25.9,\n        \"location\": location,\n        \"date\": date,\n        \"unit\": unit,\n    }\n\ndef get_function_by_name(name):\n    if name == \"get_current_temperature\":\n        return get_current_temperature\n    if name == \"get_temperature_date\":\n        return get_temperature_date\n\ntools = [{\n    'type': 'function',\n    'function': {\n        'name': 'get_current_temperature',\n        'description': 'Get current temperature at a location.',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'location': {\n                    'type': 'string',\n                    'description': 'The location to get the temperature for, in the format \\'City, State, Country\\'.'\n                },\n                'unit': {\n                    'type': 'string',\n                    'enum': [\n                        'celsius',\n                        'fahrenheit'\n                    ],\n                    'description': 'The unit to return the temperature in. 
Defaults to \\'celsius\\'.'\n                }\n            },\n            'required': [\n                'location'\n            ]\n        }\n    }\n}, {\n    'type': 'function',\n    'function': {\n        'name': 'get_temperature_date',\n        'description': 'Get temperature at a location and date.',\n        'parameters': {\n            'type': 'object',\n            'properties': {\n                'location': {\n                    'type': 'string',\n                    'description': 'The location to get the temperature for, in the format \\'City, State, Country\\'.'\n                },\n                'date': {\n                    'type': 'string',\n                    'description': 'The date to get the temperature for, in the format \\'Year-Month-Day\\'.'\n                },\n                'unit': {\n                    'type': 'string',\n                    'enum': [\n                        'celsius',\n                        'fahrenheit'\n                    ],\n                    'description': 'The unit to return the temperature in. Defaults to \\'celsius\\'.'\n                }\n            },\n            'required': [\n                'location',\n                'date'\n            ]\n        }\n    }\n}]\nmessages = [{'role': 'user', 'content': 'Today is 2024-11-14, What\\'s the temperature in San Francisco now? 
How about tomorrow?'}]\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response.choices[0].message.tool_calls)\nmessages.append(response.choices[0].message)\n\nfor tool_call in response.choices[0].message.tool_calls:\n    tool_call_args = json.loads(tool_call.function.arguments)\n    tool_call_result = get_function_by_name(tool_call.function.name)(**tool_call_args)\n    messages.append({\n        'role': 'tool',\n        'name': tool_call.function.name,\n        # message content should be a string, so serialize the dict returned by the tool\n        'content': json.dumps(tool_call_result),\n        'tool_call_id': tool_call.id\n    })\n\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=messages,\n    temperature=0.8,\n    top_p=0.8,\n    stream=False,\n    tools=tools)\nprint(response.choices[0].message.content)\n```\n\nWith Qwen2.5-14B-Instruct, output similar to the following is obtained:\n\n```\n[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"location\": \"San Francisco, California, USA\"}', name='get_current_temperature'), type='function'),\n ChatCompletionMessageToolCall(id='1', function=Function(arguments='{\"location\": \"San Francisco, California, USA\", \"date\": \"2024-11-15\"}', name='get_temperature_date'), type='function')]\n\nThe current temperature in San Francisco, California, USA is 26.1°C. For tomorrow, 2024-11-15, the temperature is expected to be 25.9°C.\n```\n\nNote that with multiple tool calls, the order of the tool results affects the quality of the answer, because `tool_call_id` is not passed to the LLM correctly.\n"
  },
  {
    "path": "docs/zh_cn/llm/codellama.md",
    "content": "# Code Llama\n\n## Introduction\n\n[codellama](https://github.com/facebookresearch/codellama) supports many programming languages, including Python, C++, Java, PHP, Typescript (Javascript), C#, Bash and more. It offers four capabilities: code completion, code infilling, chat, and Python specialization.\n\nBase models, Python fine-tuned models, and instruction-tuned models are released on [HuggingFace](https://huggingface.co/codellama):\n\n| Base model                                                                      | Python fine-tuned model                                                                       | Instruct model                                                                                    |\n| ------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- |\n| [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf)   | [codellama/CodeLlama-7b-Python-hf](https://huggingface.co/codellama/CodeLlama-7b-Python-hf)   | [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf)   |\n| [codellama/CodeLlama-13b-hf](https://huggingface.co/codellama/CodeLlama-13b-hf) | [codellama/CodeLlama-13b-Python-hf](https://huggingface.co/codellama/CodeLlama-13b-Python-hf) | [codellama/CodeLlama-13b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) |\n| [codellama/CodeLlama-34b-hf](https://huggingface.co/codellama/CodeLlama-34b-hf) | [codellama/CodeLlama-34b-Python-hf](https://huggingface.co/codellama/CodeLlama-34b-Python-hf) | [codellama/CodeLlama-34b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf) |\n\nThe capabilities of each model type are:\n\n| Model                   | Code completion | Code infilling    | Chat | Python specialization |\n| ----------------------- | --------------- | ----------------- | ---- | --------------------- |\n| Base model              | Y               | Y(7B,13B), N(34B) | N    | N                     |\n| Python fine-tuned model | Y               | N                 | N    | Y                     |\n| Instruct model          | Y               | Y(7B,13B), N(34B) | Y    | N                     |\n\n## Inference\n\nFollowing the capability table above, this section shows how to use each CodeLlama capability with concrete examples.\n\n### Code completion\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig\n\npipe = pipeline('meta-llama/CodeLlama-7b-hf',\n                chat_template_config=ChatTemplateConfig(\n                    model_name='codellama',\n                    capability='completion'\n                ))\n\nresponse = pipe(\n    'import socket\\n\\ndef ping_exponential_backoff(host: str):',\n    gen_config=GenerationConfig(\n        top_k=10,\n        temperature=0.1,\n        top_p=0.95\n    )\n)\nprint(response.text)\n```\n\n### Code infilling\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig\n\npipe = pipeline('meta-llama/CodeLlama-7b-hf',\n                chat_template_config=ChatTemplateConfig(\n                    model_name='codellama',\n                    capability='infilling'\n                ))\n\nprompt = \"\"\"\ndef remove_non_ascii(s: str) -> str:\n    \\\"\\\"\\\"\n    <FILL>\n    \\\"\\\"\\\"\n    return result\n\"\"\"\nresponse = pipe(\n    prompt,\n    gen_config=GenerationConfig(\n        top_k=10,\n        temperature=0.1,\n        top_p=0.95,\n        max_new_tokens=500\n    )\n)\nprint(response.text)\n```\n\n### Chat\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig\n\npipe = pipeline('meta-llama/CodeLlama-7b-Instruct-hf',\n                chat_template_config=ChatTemplateConfig(\n                    model_name='codellama',\n                    capability='chat'\n                ))\n\nresponse = pipe(\n    'implement quick sort in C++',\n    gen_config=GenerationConfig(\n        top_k=10,\n        temperature=0.1,\n        top_p=0.95\n    )\n)\nprint(response.text)\n```\n\n### Python specialization\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig\n\npipe = pipeline('meta-llama/CodeLlama-7b-Python-hf',\n                chat_template_config=ChatTemplateConfig(\n                   
 model_name='codellama',\n                    capability='python'\n                ))\n\nresponse = pipe(\n    'implement quick sort',\n    gen_config=GenerationConfig(\n        top_k=10,\n        temperature=0.1,\n        top_p=0.95\n    )\n)\nprint(response.text)\n```\n\n## Quantization\n\nTBD\n\n## Serving\n\nPrepare a chat template file, e.g. `codellama.json`, and fill in the CodeLlama capability as in the following example:\n\n```json\n{\n    \"model_name\": \"codellama\",\n    \"capability\": \"completion\"\n}\n```\n\nThen launch the inference service:\n\n```shell\nlmdeploy serve api_server meta-llama/CodeLlama-7b-Instruct-hf --chat-template codellama.json\n```\n\nOnce the service has started, it can be accessed through the `openai` client:\n\n```python\nfrom openai import OpenAI\nclient = OpenAI(\n    api_key='YOUR_API_KEY',\n    base_url=\"http://0.0.0.0:23333/v1\"\n)\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[\n        {\"role\": \"user\", \"content\": \"import socket\\n\\ndef ping_exponential_backoff(host: str):\"},\n    ],\n    temperature=0.1,\n    top_p=0.95,\n    max_tokens=500\n)\nprint(response)\n```\n\nFor a detailed introduction to api_server, see [this document](../llm/api_server.md).\n"
  },
  {
    "path": "docs/zh_cn/llm/pipeline.md",
    "content": "# LLM Offline Inference Pipeline\n\nThis document shows the basic usage of the pipeline through a few examples.\n\nFor a detailed description of the pipeline API, see [here](https://lmdeploy.readthedocs.io/zh-cn/latest/api/pipeline.html).\n\n## Usage\n\n### \"Hello, world\" example\n\n```python\nfrom lmdeploy import pipeline\n\npipe = pipeline('internlm/internlm2_5-7b-chat')\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'])\nprint(response)\n```\n\nIn this example, the pipeline by default reserves a proportion of GPU memory to store the k/v produced during inference. The proportion is controlled by `TurbomindEngineConfig.cache_max_entry_count`.\n\nThe policy for setting the k/v cache ratio has changed during LMDeploy's development. The change log is as follows:\n\n1. `v0.2.0 <= lmdeploy <= v0.2.1`\n\n   The default ratio is 0.5, meaning 50% of the **total GPU memory** is allocated to the k/v cache. For a 7B model, OOM occurs if the GPU has less than 40G of memory. If you hit OOM, lower the k/v cache ratio as needed, as follows:\n\n   ```python\n   from lmdeploy import pipeline, TurbomindEngineConfig\n\n   # Lower the k/v cache ratio to 20% of total GPU memory\n   backend_config = TurbomindEngineConfig(cache_max_entry_count=0.2)\n\n   pipe = pipeline('internlm/internlm2_5-7b-chat',\n                   backend_config=backend_config)\n   response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n   print(response)\n   ```\n\n2. 
`lmdeploy > v0.2.1`\n\n   The allocation policy changed to reserving space for the k/v cache as a proportion of **free GPU memory**, and the default ratio became 0.8. If you hit OOM, reduce the ratio as shown above to lower the k/v cache's memory usage.\n\n### Setting multi-GPU parallelism\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(tp=2)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'])\nprint(response)\n```\n\n### Setting sampling parameters\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(tp=2)\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'],\n                gen_config=gen_config)\nprint(response)\n```\n\n### Using OpenAI-format prompts\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(tp=2)\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nprompts = [[{\n    'role': 'user',\n    'content': 'Hi, pls intro yourself'\n}], [{\n    'role': 'user',\n    'content': 'Shanghai is'\n}]]\nresponse = pipe(prompts,\n                gen_config=gen_config)\nprint(response)\n```\n\n### Streaming output\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\n\nbackend_config = TurbomindEngineConfig(tp=2)\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = 
pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nprompts = [[{\n    'role': 'user',\n    'content': 'Hi, pls intro yourself'\n}], [{\n    'role': 'user',\n    'content': 'Shanghai is'\n}]]\nfor item in pipe.stream_infer(prompts, gen_config=gen_config):\n    print(item)\n```\n\n### Getting the logits of generated tokens\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('internlm/internlm2_5-7b-chat')\n\ngen_config = GenerationConfig(output_logits='generation',\n                              max_new_tokens=10)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'],\n                gen_config=gen_config)\nlogits = [x.logits for x in response]\n```\n\n### Getting the last-layer hidden_states of generated tokens\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('internlm/internlm2_5-7b-chat')\n\ngen_config = GenerationConfig(output_last_hidden_state='generation',\n                              max_new_tokens=10)\nresponse = pipe(['Hi, pls intro yourself', 'Shanghai is'],\n                gen_config=gen_config)\nhidden_states = [x.last_hidden_state for x in response]\n```\n\n### Computing ppl\n\n```python\nfrom transformers import AutoTokenizer\nfrom lmdeploy import pipeline\n\n\nmodel_repoid_or_path = 'internlm/internlm2_5-7b-chat'\npipe = pipeline(model_repoid_or_path)\ntokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)\nmessages = [\n    {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n]\ninput_ids = tokenizer.apply_chat_template(messages)\n\n# logits is a list of tensor\nlogits = pipe.get_logits(input_ids)\nprint(logits)\n\n# ppl is a list of float numbers\nppl = pipe.get_ppl(input_ids)\nprint(ppl)\n```\n\n```{note}\nOOM errors may occur when input_ids is too long, so use this with care.\nget_ppl returns the cross-entropy loss without applying exp to it; apply exp yourself to obtain the perplexity.\n```\n\n### Using PyTorchEngine\n\nInstall triton first:\n\n```shell\npip install 'triton>=2.1.0'\n```\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, 
PytorchEngineConfig\n\nbackend_config = PytorchEngineConfig(session_len=2048)\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('internlm/internlm2_5-7b-chat',\n                backend_config=backend_config)\nprompts = [[{\n    'role': 'user',\n    'content': 'Hi, pls intro yourself'\n}], [{\n    'role': 'user',\n    'content': 'Shanghai is'\n}]]\nresponse = pipe(prompts, gen_config=gen_config)\nprint(response)\n```\n\n### LoRA 模型推理\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig\n\nbackend_config = PytorchEngineConfig(session_len=2048,\n                                     adapters=dict(lora_name_1='chenchi/lora-chatglm2-6b-guodegang'))\ngen_config = GenerationConfig(top_p=0.8,\n                              top_k=40,\n                              temperature=0.8,\n                              max_new_tokens=1024)\npipe = pipeline('THUDM/chatglm2-6b',\n                backend_config=backend_config)\nprompts = [[{\n    'role': 'user',\n    'content': '您猜怎么着'\n}]]\nresponse = pipe(prompts, gen_config=gen_config, adapter_name='lora_name_1')\nprint(response)\n```\n\n### 释放 pipeline\n\n您可以通过调用其 `close()` 方法来显式释放 pipeline，或者，也可以使用 `with` 语句，如下所示：\n\n```python\nfrom lmdeploy import pipeline\n\nwith pipeline('internlm/internlm2_5-7b-chat') as pipe:\n    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])\n    print(response)\n```\n\n## 常见问题\n\n- **RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase**.\n\n  如果你在使用 tp>1 和 pytorch 后端的时候，遇到了这个错误。请确保 python 脚本中有下面内容作为入口\n\n  ```python\n  if __name__ == '__main__':\n  ```\n\n  一般来说，在多线程或多进程上下文中，可能需要确保初始化代码只执行一次。这时候，`if __name__ == '__main__':` 可以帮助确保这些初始化代码只在主程序执行，而不会在每个新创建的进程或线程中重复执行。\n\n- 自定义对话模板，请参考[chat_template.md](../advance/chat_template.md)\n\n- 如果 lora 
的权重有对应的对话模板，可以先将该对话模板注册到 lmdeploy，然后将 adapter 名称设置为该对话模板的名称即可。\n"
  },
  {
    "path": "docs/zh_cn/llm/proxy_server.md",
    "content": "# 请求分发服务器\n\n请求分发服务可以将多个 api_server 服务，进行并联。用户可以只需要访问代理 URL，就可以间接访问不同的 api_server 服务。代理服务内部会自动分发请求，做到负载均衡。\n\n## 启动\n\n启动代理服务：\n\n```shell\nlmdeploy serve proxy --server-name {server_name} --server-port {server_port} --routing-strategy \"min_expected_latency\" --serving-strategy Hybrid\n```\n\n启动成功后，代理服务的 URL 也会被脚本打印。浏览器访问这个 URL，可以打开 Swagger UI。\n随后，用户可以在启动 api_server 服务的时候，通过 `--proxy-url` 命令将其直接添加到代理服务中。例如：`lmdeploy serve api_server InternLM/internlm2-chat-1_8b --proxy-url http://0.0.0.0:8000`。\n这样，用户可以通过代理节点访问 api_server 的服务，代理节点的使用方式和 api_server 一模一样，都是兼容 OpenAI 的形式。\n\n- /v1/models\n- /v1/chat/completions\n- /v1/completions\n\n## 节点管理\n\n通过 Swagger UI，我们可以看到多个 API。其中，和 api_server 节点管理相关的有：\n\n- /nodes/status\n- /nodes/add\n- /nodes/remove\n\n他们分别表示，查看所有的 api_server 服务节点，增加某个节点，删除某个节点。他们的使用方式，最直接的可以在浏览器里面直接操作。也可以通过命令行或者 python 操作。\n\n### 通过 command 增删查\n\n```shell\ncurl -X 'GET' \\\n  'http://localhost:8000/nodes/status' \\\n  -H 'accept: application/json'\n```\n\n```shell\ncurl -X 'POST' \\\n  'http://localhost:8000/nodes/add' \\\n  -H 'accept: application/json' \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"url\": \"http://0.0.0.0:23333\"\n}'\n```\n\n```shell\ncurl -X 'POST' \\\n  'http://localhost:8000/nodes/remove?node_url=http://0.0.0.0:23333' \\\n  -H 'accept: application/json' \\\n  -d ''\n```\n\n### 通过 python 脚本增删查\n\n```python\n# 查询所有节点\nimport requests\nurl = 'http://localhost:8000/nodes/status'\nheaders = {'accept': 'application/json'}\nresponse = requests.get(url, headers=headers)\nprint(response.text)\n```\n\n```python\n# 添加新节点\nimport requests\nurl = 'http://localhost:8000/nodes/add'\nheaders = {\n    'accept': 'application/json',\n    'Content-Type': 'application/json'\n}\ndata = {\"url\": \"http://0.0.0.0:23333\"}\nresponse = requests.post(url, headers=headers, json=data)\nprint(response.text)\n```\n\n```python\n# 删除某个节点\nimport requests\nurl = 'http://localhost:8000/nodes/remove'\nheaders = {'accept': 
'application/json',}\nparams = {'node_url': 'http://0.0.0.0:23333',}\nresponse = requests.post(url, headers=headers, data='', params=params)\nprint(response.text)\n```\n\n## 服务策略\n\nLMDeploy 当前支持混合部署服务（Hybrid）以及 PD 分离部署服务（DistServe）：\n\n- Hybrid: 不区分 Prefill 和 Decoding 实例，即传统的推理部署模式。\n- DistServe: 将 Prefill 和 Decoding 实例分离，部署在不同的服务节点上，以实现更灵活、高效的资源调度和扩展。\n\n## 分发策略\n\n代理服务目前的分发策略如下：\n\n- random：根据用户提供的各个 api_server 节点处理请求的能力，进行加权随机分配。节点的吞吐量越大，被分配到请求的概率就越高。未提供吞吐量的节点，按其他节点的平均吞吐量对待。\n- min_expected_latency：根据每个节点当前尚未处理完的请求和各节点的吞吐能力，计算预期完成响应所需的时间，时间最短的节点将被分配。未提供吞吐量的节点，处理方式同上。\n- min_observed_latency：根据每个节点最近一定数量请求的平均处理用时，用时最短的节点将被分配。\n"
  },
  {
    "path": "docs/zh_cn/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset SOURCEDIR=.\nset BUILDDIR=_build\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\n\techo.installed, then set the SPHINXBUILD environment variable to point\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\n\techo.may add the Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.http://sphinx-doc.org/\n\texit /b 1\n)\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\n\n:end\npopd\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/api_server_vl.md",
    "content": "# 部署 VLM 类 openai 服务\n\n本文主要介绍单个VL模型在单机多卡环境下，部署兼容 openai 接口服务的方式，以及服务接口的用法。为行文方便，我们把该服务名称为 `api_server`。对于多模型的并行服务，请阅读[请求分发服务器](../llm/proxy_server.md)一文。\n\n在这篇文章中， 我们首先介绍服务启动的两种方法，你可以根据应用场景，选择合适的。\n\n其次，我们重点介绍服务的 RESTful API 定义，以及接口使用的方式，并展示如何通过 Swagger UI、LMDeploy CLI 工具体验服务功能\n\n## 启动服务\n\n以 huggingface hub 上的 [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) 模型为例，你可以任选以下方式之一，启动推理服务。\n\n### 方式一：使用 lmdeploy cli 工具\n\n```shell\nlmdeploy serve api_server liuhaotian/llava-v1.6-vicuna-7b --server-port 23333\n```\n\napi_server 启动时的参数可以通过命令行`lmdeploy serve api_server -h`查看。\n比如，`--tp` 设置张量并行，`--session-len` 设置推理的最大上下文窗口长度，`--cache-max-entry-count` 调整 k/v cache 的内存使用比例等等。\n\n### 方式二：使用 docker\n\n使用 LMDeploy 官方[镜像](https://hub.docker.com/r/openmmlab/lmdeploy/tags)，可以运行兼容 OpenAI 的服务。下面是使用示例：\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server liuhaotian/llava-v1.6-vicuna-7b\n```\n\n在这个例子中，`lmdeploy server api_server` 的命令参数与方式一一致。\n\n## RESTful API\n\nLMDeploy 的 RESTful API 兼容了 OpenAI 以下 3 个接口：\n\n- /v1/chat/completions\n- /v1/models\n- /v1/completions\n\n其中使用图片交互的接口是 `/v1/chat/completions`，与 OpenAI 的一致。\n服务启动后，你可以在浏览器中打开网页 http://0.0.0.0:23333，通过 Swagger UI 查看接口的详细说明，并且也可以直接在网页上操作，体验每个接口的用法，如下图所示。\n\n![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459)\n\n若需要把服务集成到自己的项目或者产品中，我们推荐以下用法：\n\n### 使用 openai 接口\n\n以下代码是通过 openai 包使用 `v1/chat/completions` 服务的例子。运行之前，请先安装 openai 包: `pip install openai`。\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[{\n        'role':\n        
'user',\n        'content': [{\n            'type': 'text',\n            'text': 'Describe the image please',\n        }, {\n            'type': 'image_url',\n            'image_url': {\n                'url':\n                'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n            },\n        }],\n    }],\n    temperature=0.8,\n    top_p=0.8)\nprint(response)\n```\n\n### 使用 lmdeploy `APIClient` 接口\n\n如果你想用 `/v1/chat/completions` 接口，你可以尝试下面代码：\n\n```python\nfrom lmdeploy.serve.openai.api_client import APIClient\n\napi_client = APIClient(f'http://0.0.0.0:23333')\nmodel_name = api_client.available_models[0]\nmessages = [{\n    'role':\n    'user',\n    'content': [{\n        'type': 'text',\n        'text': 'Describe the image please',\n    }, {\n        'type': 'image_url',\n        'image_url': {\n            'url':\n            'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n        },\n    }]\n}]\nfor item in api_client.chat_completions_v1(model=model_name,\n                                           messages=messages):\n    print(item)\n```\n\n### 使用 Java/Golang/Rust\n\n可以使用代码生成工具 [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) 将 `http://{server_ip}:{server_port}/openapi.json` 转成 java/rust/golang 客户端。\n下面是一个使用示例：\n\n```shell\n$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust\n\n$ ls rust/*\nrust/Cargo.toml  rust/git_push.sh  rust/README.md\n\nrust/docs:\nChatCompletionRequest.md  EmbeddingsRequest.md  HttpValidationError.md  LocationInner.md  Prompt.md\nDefaultApi.md             GenerateRequest.md    Input.md                Messages.md       ValidationError.md\n\nrust/src:\napis  lib.rs  models\n```\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/cogvlm.md",
    "content": "# cogvlm\n\n## 简介\n\nCogVLM 是一个强大的开源视觉语言模型（VLM）. LMDeploy 已在PyTorch后端支持 CogVLM-17B 模型 [THUDM/cogvlm-chat-hf](https://huggingface.co/THUDM/cogvlm-chat-hf) 和 CogVLM2-19B 模型如[THUDM/cogvlm2-llama3-chat-19B](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B)\n\n## 快速开始\n\n请参考[安装文档](../get_started/installation.md)安装 LMDeploy\n\n### 准备\n\n当使用LMDeploy部署 **CogVLM** 模型时，需要下载模型至本地目录。由于 **CogVLM** 模型使用外部Tokenizer，因而需要将相关文件下载至模型目录。然而对于**CogVLM2**模型，则可跳过此步骤。\n\n以 **CogVLM** 模型 `cogvlm-chat-hf` 为例，可执行如下脚本下载模型：\n\n```shell\nhuggingface-cli download THUDM/cogvlm-chat-hf --local-dir ./cogvlm-chat-hf --local-dir-use-symlinks False\nhuggingface-cli download lmsys/vicuna-7b-v1.5 special_tokens_map.json tokenizer.model tokenizer_config.json --local-dir ./cogvlm-chat-hf --local-dir-use-symlinks False\n```\n\n### 离线推理 pipeline\n\n以下是使用pipeline进行离线推理的示例，更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\n\nif __name__ == \"__main__\":\n    pipe = pipeline('cogvlm-chat-hf')\n\n    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n    response = pipe(('describe this image', image))\n    print(response)\n```\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/deepseek_vl2.md",
    "content": "# DeepSeek-VL2\n\n## 简介\n\nDeepSeek-VL2 是一系列先进的 MoE 视觉-语言模型，相较于其前身 DeepSeek-VL 有了显著的改进。\nDeepSeek-VL2 在各种任务中展现出卓越的能力，包括但不限于视觉问答、OCR、文档/表格/图表理解以及视觉定位。\n\nLMDeploy 目前在 Pytorch 引擎中支持 [deepseek-vl2-tiny](https://huggingface.co/deepseek-ai/deepseek-vl2-tiny), [deepseek-vl2-small](https://huggingface.co/deepseek-ai/deepseek-vl2-small) 和 [deepseek-vl2](https://huggingface.co/deepseek-ai/deepseek-vl2) 。\n\n## 快速开始\n\n请参考[安装文档](../get_started/installation.md)安装 LMDeploy。\n\n### 准备\n\n在使用 LMDeploy 部署 **DeepSeek-VL2** 模型时，您必须安装官方的 GitHub 仓库以及一些相关的第三方库。这是因为 LMDeploy 会复用官方仓库中提供的图像处理功能。\n\n```\npip install git+https://github.com/deepseek-ai/DeepSeek-VL2.git --no-deps\npip install attrdict timm 'transformers<4.48.0'\n```\n\n值得注意的是，如果使用 transformers>=4.48.0，可能会出现失败的情况，详情可以参考此 [Issue](https://github.com/deepseek-ai/DeepSeek-VL2/issues/45)。\n\n### 离线推理 pipeline\n\n以下是使用pipeline进行离线推理的示例，更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)。\n\n为了构建有效的、包含图像输入的 DeepSeek-VL2 提示词，用户应手动插入 `<IMAGE_TOKEN>`\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\n\nif __name__ == \"__main__\":\n    pipe = pipeline('deepseek-ai/deepseek-vl2-tiny')\n\n    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n    response = pipe(('<IMAGE_TOKEN>describe this image', image))\n    print(response)\n```\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/gemma3.md",
    "content": "# Gemma3\n\n## 简介\n\nGemma 是 Google 推出的轻量级、最先进的开放模型系列，采用与创建 Gemini 模型相同的研究和技术构建而成。Gemma3 模型是多模态模型，可处理文本和图像输入并生成文本输出，对预训练和指令微调均具有开源的权重。Gemma3 具有 128K 的大型上下文窗口，支持 140 多种语言，并且比以前的版本提供更多尺寸。Gemma3 模型非常适合各种文本生成和图像理解任务，包括问答、总结和推理。它们的尺寸相对较小，因此可以将其部署在资源有限的环境中，例如笔记本电脑、台式机或您自己的云基础设施，从而让每个人都能轻松访问最先进的 AI 模型，并帮助促进创新。\n\n## 快速开始\n\n请参考[安装文档](../get_started/installation.md)安装 LMDeploy。\n\n### 准备\n\n在使用 LMDeploy 部署 **Gemma3** 模型时，请安装最新的 transformers。\n\n### 离线推理 pipeline\n\n以下是使用pipeline进行离线推理的示例，更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)。\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\n\nif __name__ == \"__main__\":\n    pipe = pipeline('google/gemma-3-12b-it')\n\n    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n    response = pipe(('describe this image', image))\n    print(response)\n```\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/index.rst",
    "content": "视觉语言模型\n=================================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: 示例\n\n   deepseek_vl2.md\n   llava.md\n   internvl.md\n   xcomposer2d5.md\n   cogvlm.md\n   minicpmv.md\n   phi3.md\n   qwen2_vl.md\n   qwen2_5_vl.md\n   molmo.md\n   gemma3.md\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/internvl.md",
    "content": "# InternVL\n\nLMDeploy 支持 InternVL 系列模型，具体如下：\n\n|         Model         |     Size      | Supported Inference Engine |\n| :-------------------: | :-----------: | :------------------------: |\n|       InternVL        |    13B-19B    |         TurboMind          |\n|      InternVL1.5      |    2B-26B     |     TurboMind, PyTorch     |\n|       InternVL2       |      4B       |          PyTorch           |\n|       InternVL2       | 1B-2B, 8B-76B |     TurboMind, PyTorch     |\n| InternVL2.5/2.5-MPO/3 |    1B-78B     |     TurboMind, PyTorch     |\n|     Mono-InternVL     |      2B       |          PyTorch           |\n\n本文将以[InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B)为例，演示使用 LMDeploy 部署 InternVL 系列模型的方法。\n\n## 安装\n\n请参考[安装文档](../get_started/installation.md)安装 LMDeploy，并安装上游 InternVL 模型库需的依赖。\n\n```shell\npip install timm\n# 建议从https://github.com/Dao-AILab/flash-attention/releases寻找和环境匹配的whl包\npip install flash-attn\n```\n\n或者，你可以为 InternVL 的推理构建 docker image。如果，宿主机器上的 CUDA 版本 `>=12.4`，你可以执行如下命令构建镜像：\n\n```\ngit clone https://github.com/InternLM/lmdeploy.git\ncd lmdeploy\ndocker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:internvl . -f ./docker/InternVL_Dockerfile\n```\n\n否则的话，可以基于 LMDeploy cu11 的镜像来构建：\n\n```shell\ndocker build --build-arg CUDA_VERSION=cu11 -t openmmlab/lmdeploy:internvl . 
-f ./docker/InternVL_Dockerfile\n```\n\n## 离线推理\n\n以下是使用 pipeline 进行离线推理的示例，更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2-8B')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe((f'describe this image', image))\nprint(response)\n```\n\n更多例子如下：\n\n<details>\n  <summary>\n    <b>多图多轮对话，拼接图像</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\n\npipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\\nDescribe the two images in detail.'),\n        dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),\n        dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>多图多轮对话，独立图像</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\n\npipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text=f'Image-1: {IMAGE_TOKEN}\\nImage-2: {IMAGE_TOKEN}\\nDescribe the two images in detail.'),\n        dict(type='image_url', image_url=dict(max_dynamic_patch=12, 
url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),\n        dict(type='image_url', image_url=dict(max_dynamic_patch=12, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>视频多轮对话</b>\n  </summary>\n\n```python\nimport numpy as np\nfrom lmdeploy import pipeline, GenerationConfig\nfrom decord import VideoReader, cpu\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\nfrom lmdeploy.vl import encode_image_base64\nfrom PIL import Image\npipe = pipeline('OpenGVLab/InternVL2-8B', log_level='INFO')\n\n\ndef get_index(bound, fps, max_frame, first_idx=0, num_segments=32):\n    if bound:\n        start, end = bound[0], bound[1]\n    else:\n        start, end = -100000, 100000\n    start_idx = max(first_idx, round(start * fps))\n    end_idx = min(round(end * fps), max_frame)\n    seg_size = float(end_idx - start_idx) / num_segments\n    frame_indices = np.array([\n        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))\n        for idx in range(num_segments)\n    ])\n    return frame_indices\n\n\ndef load_video(video_path, bound=None, num_segments=32):\n    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)\n    max_frame = len(vr) - 1\n    fps = float(vr.get_avg_fps())\n    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)\n    imgs = []\n    for frame_index in frame_indices:\n        img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')\n        imgs.append(img)\n    return imgs\n\n\nvideo_path = 'red-panda.mp4'\nimgs = load_video(video_path, 
num_segments=8)\n\nquestion = ''\nfor i in range(len(imgs)):\n    question = question + f'Frame{i+1}: {IMAGE_TOKEN}\\n'\n\nquestion += 'What is the red panda doing?'\n\ncontent = [{'type': 'text', 'text': question}]\nfor img in imgs:\n    content.append({'type': 'image_url', 'image_url': {'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'}})\n\nmessages = [dict(role='user', content=content)]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='Describe this video in detail. Don\\'t repeat.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n## 在线服务\n\n你可以通过 `lmdeploy serve api_server` CLI 工具启动服务：\n\n```shell\nlmdeploy serve api_server OpenGVLab/InternVL2-8B\n```\n\n也可以基于前文构建的 docker image 启动服务：\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:internvl \\\n    lmdeploy serve api_server OpenGVLab/InternVL2-8B\n```\n\nDocker compose 的方式也是一种选择。在 LMDeploy 代码库的根目录下创建`docker-compose.yml`文件，内容参考如下：\n\n```yaml\nversion: '3.5'\n\nservices:\n  lmdeploy:\n    container_name: lmdeploy\n    image: openmmlab/lmdeploy:internvl\n    ports:\n      - \"23333:23333\"\n    environment:\n      HUGGING_FACE_HUB_TOKEN: <secret>\n    volumes:\n      - ~/.cache/huggingface:/root/.cache/huggingface\n    stdin_open: true\n    tty: true\n    ipc: host\n    command: lmdeploy serve api_server OpenGVLab/InternVL2-8B\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: \"all\"\n              capabilities: [gpu]\n```\n\n然后，你就可以执行命令启动服务了：\n\n```shell\ndocker-compose up -d\n```\n\n通过`docker logs -f lmdeploy`可以查看启动的日志信息，如果发现类似下方的日志信息，就表明服务启动成功了。\n\n```text\nHINT:    Please open  
http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\n有关 `lmdeploy serve api_server` 的详细参数可以通过`lmdeploy serve api_server -h`查阅。\n\n关于 `api_server` 更多的介绍，以及访问 `api_server` 的方法，请阅读[此处](api_server_vl.md)\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/llava.md",
    "content": "# LLaVA\n\nLMDeploy 支持以下 LLaVA 系列模型，具体如下表所示：\n\n|                 模型                 | 大小 |   支持的推理引擎   |\n| :----------------------------------: | :--: | :----------------: |\n| llava-hf/Llava-interleave-qwen-7b-hf |  7B  | TurboMind, PyTorch |\n|       llava-hf/llava-1.5-7b-hf       |  7B  | TurboMind, PyTorch |\n|  llava-hf/llava-v1.6-mistral-7b-hf   |  7B  |      PyTorch       |\n|   llava-hf/llava-v1.6-vicuna-7b-hf   |  7B  |      PyTorch       |\n|   liuhaotian/llava-v1.6-vicuna-7b    |  7B  |     TurboMind      |\n|   liuhaotian/llava-v1.6-mistral-7b   |  7B  |     TurboMind      |\n\n接下来的章节将演示如何使用 LMDeploy 部署 LLaVA 模型，并以 [llava-hf/llava-interleave](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf) 为例。\n\n```{note}\n自 0.6.4 之后，PyTorch 引擎移除了对 llava 原始模型的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到\n```\n\n## 安装\n\n请按照[安装指南](../get_started/installation.md)安装 LMDeploy。\n\n或者，您也可以使用官方的 Docker 镜像：\n\n```shell\ndocker pull openmmlab/lmdeploy:latest\n```\n\n## 离线推理\n\n以下示例代码展示了 VLM pipeline 的基本用法。有关详细信息，请参考 [VLM 离线推理流程](./vl_pipeline.md)。\n\n```python\nfrom lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline(\"llava-hf/llava-interleave-qwen-7b-hf\", backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5),\n    gen_config=GenerationConfig(max_new_tokens=512))\n\nimage = load_image('https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg')\nprompt = 'Describe the image.'\nprint(f'prompt:{prompt}')\nresponse = pipe((prompt, image))\nprint(response)\n```\n\n更多示例：\n\n<details>\n  <summary><b>多图片多轮对话，组合图片</b></summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('llava-hf/llava-interleave-qwen-7b-hf', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', 
image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n## 在线服务\n\n可以使用 `lmdeploy serve api_server` CLI 启动服务器：\n\n```shell\nlmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf\n```\n\n或者，使用前面提到的 Docker 镜像启动服务：\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf\n```\n\n采用 Docker Compose 部署也是一种常见选择。在 lmdeploy 项目的根目录创建 `docker-compose.yml` 文件，如下：\n\n```yaml\nversion: '3.5'\n\nservices:\n  lmdeploy:\n    container_name: lmdeploy\n    image: openmmlab/lmdeploy:latest\n    ports:\n      - \"23333:23333\"\n    environment:\n      HUGGING_FACE_HUB_TOKEN: <secret>\n    volumes:\n      - ~/.cache/huggingface:/root/.cache/huggingface\n    stdin_open: true\n    tty: true\n    ipc: host\n    command: lmdeploy serve api_server llava-hf/llava-interleave-qwen-7b-hf\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: \"all\"\n              capabilities: [gpu]\n```\n\n然后，可以执行以下命令启动服务：\n\n```shell\ndocker-compose up -d\n```\n\n当运行 `docker logs -f lmdeploy` 后看到如下日志，说明服务启动成功：\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a 
browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\n可以通过 `lmdeploy serve api_server -h` 查看 `lmdeploy serve api_server` 的参数详情。\n\n关于 `api_server` 以及如何访问服务的更多信息可以在[这里](api_server_vl.md)找到。\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/minicpmv.md",
    "content": "# MiniCPM-V\n\nLMDeploy 支持 MiniCPM-V 系列模型，具体如下：\n\n|        Model         | Supported Inference Engine |\n| :------------------: | :------------------------: |\n| MiniCPM-Llama3-V-2_5 |         TurboMind          |\n|    MiniCPM-V-2_6     |         TurboMind          |\n\n本文将以[MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)为例，演示使用 LMDeploy 部署 MiniCPM-V 系列模型的方法\n\n## 安装\n\n请参考[安装文档](../get_started/installation.md)安装 LMDeploy。\n\n## 离线推理\n\n以下是使用pipeline进行离线推理的示例，更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('openbmb/MiniCPM-V-2_6')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n更多例子如下：\n\n<details>\n  <summary>\n    <b>多张图片，多轮对话</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(max_slice_nums=9, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg')),\n        dict(type='image_url', image_url=dict(max_slice_nums=9, url='https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\nprint(out.text)\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\nprint(out.text)\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>上下文小样本学习</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = 
pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO')\n\nquestion = \"production date\"\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text=question),\n        dict(type='image_url', image_url=dict(url='example1.jpg')),\n    ]),\n    dict(role='assistant', content='2023.08.04'),\n    dict(role='user', content=[\n        dict(type='text', text=question),\n        dict(type='image_url', image_url=dict(url='example2.jpg')),\n    ]),\n    dict(role='assistant', content='2007.04.24'),\n    dict(role='user', content=[\n        dict(type='text', text=question),\n        dict(type='image_url', image_url=dict(url='test.jpg')),\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\nprint(out.text)\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>视频对话</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl import encode_image_base64\nimport torch\nfrom PIL import Image\nfrom transformers import AutoModel, AutoTokenizer\nfrom decord import VideoReader, cpu    # pip install decord\n\npipe = pipeline('openbmb/MiniCPM-V-2_6', log_level='INFO')\n\nMAX_NUM_FRAMES=64 # if cuda OOM set a smaller number\ndef encode_video(video_path):\n    def uniform_sample(l, n):\n        gap = len(l) / n\n        idxs = [int(i * gap + gap / 2) for i in range(n)]\n        return [l[i] for i in idxs]\n    vr = VideoReader(video_path, ctx=cpu(0))\n    sample_fps = round(vr.get_avg_fps() / 1)  # FPS\n    frame_idx = [i for i in range(0, len(vr), sample_fps)]\n    if len(frame_idx) > MAX_NUM_FRAMES:\n        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)\n    frames = vr.get_batch(frame_idx).asnumpy()\n    frames = [Image.fromarray(v.astype('uint8')) for v in frames]\n    print('num frames:', len(frames))\n    return frames\n\nvideo_path=\"video_test.mp4\"\nframes = encode_video(video_path)\nquestion = \"Describe the video\"\n\ncontent=[dict(type='text', text=question)]\nfor frame in frames:\n    
content.append(dict(type='image_url', image_url=dict(use_image_id=False, max_slice_nums=2,\n        url=f'data:image/jpeg;base64,{encode_image_base64(frame)}')))\n\nmessages = [dict(role='user', content=content)]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\nprint(out.text)\n```\n\n</details>\n\n## 在线服务\n\n你可以通过 `lmdeploy serve api_server` CLI 工具启动服务：\n\n```shell\nlmdeploy serve api_server openbmb/MiniCPM-V-2_6\n```\n\n也可以基于 LMDeploy 的 docker 启动服务：\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server openbmb/MiniCPM-V-2_6\n```\n\nDocker compose 的方式也是一种选择。在 LMDeploy 代码库的根目录下创建`docker-compose.yml`文件，内容参考如下：\n\n```yaml\nversion: '3.5'\n\nservices:\n  lmdeploy:\n    container_name: lmdeploy\n    image: openmmlab/lmdeploy:latest\n    ports:\n      - \"23333:23333\"\n    environment:\n      HUGGING_FACE_HUB_TOKEN: <secret>\n    volumes:\n      - ~/.cache/huggingface:/root/.cache/huggingface\n    stdin_open: true\n    tty: true\n    ipc: host\n    command: lmdeploy serve api_server openbmb/MiniCPM-V-2_6\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: \"all\"\n              capabilities: [gpu]\n```\n\n然后，你就可以执行命令启动服务了：\n\n```shell\ndocker-compose up -d\n```\n\n通过`docker logs -f lmdeploy`可以查看启动的日志信息，如果发现类似下方的日志信息，就表明服务启动成功了。\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press 
CTRL+C to quit)\n```\n\n有关 `lmdeploy serve api_server` 的详细参数可以通过`lmdeploy serve api_server -h`查阅。\n\n关于 `api_server` 更多的介绍，以及访问 `api_server` 的方法，请阅读[此处](api_server_vl.md)\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/molmo.md",
    "content": "# Qwen2-VL\n\nLMDeploy 支持 Molmo 系列模型，具体如下：\n\n|      Model      | Size | Supported Inference Engine |\n| :-------------: | :--: | :------------------------: |\n| Molmo-7B-D-0924 |  7B  |         TurboMind          |\n|  Molmo-72-0924  | 72B  |         TurboMind          |\n\n本文将以[Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) 为例，演示使用 LMDeploy 部署 Molmo 系列模型的方法\n\n## 安装\n\n请参考[安装文档](../get_started/installation.md)安装 LMDeploy。\n\n## 离线推理\n\n以下是使用 pipeline 进行离线推理的示例，更多用法参考[VLM离线推理 pipeline](./vl_pipeline.md)\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('allenai/Molmo-7B-D-0924')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe((f'describe this image', image))\nprint(response)\n```\n\n更多例子如下：\n\n<details>\n  <summary>\n    <b>多图多轮对话</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n## 在线服务\n\n你可以通过 `lmdeploy serve api_server` CLI 工具启动服务：\n\n```shell\nlmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct\n```\n\n也可以基于 docker image 启动服务：\n\n```shell\ndocker run --runtime nvidia --gpus 
all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:latest \\\n    lmdeploy serve api_server allenai/Molmo-7B-D-0924\n```\n\nIf the log shows messages like the following, the service has started successfully.\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\nRun `lmdeploy serve api_server -h` to see the detailed arguments of `lmdeploy serve api_server`.\n\nFor more information about `api_server` and how to access it, please read [here](api_server_vl.md).\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/phi3.md",
    "content": "# Phi-3 Vision\n\n## Introduction\n\n[Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) is a family of lightweight models released by Microsoft. LMDeploy supports the following multimodal models from this family:\n\n|                                                Model                                                | Size | Supported Inference Engine |\n| :-------------------------------------------------------------------------------------------------: | :--: | :------------------------: |\n| [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) | 4.2B |          PyTorch           |\n|    [microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)    | 4.2B |          PyTorch           |\n\nThis document takes [microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct) as an example to demonstrate how to deploy Phi-3 series multimodal models with LMDeploy.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the dependencies of this model.\n\n```shell\n# It is recommended to find a prebuilt wheel matching your environment at https://github.com/Dao-AILab/flash-attention/releases\npip install flash-attn\n```\n\n## Offline inference pipeline\n\nThe following is an example of offline inference with the pipeline. For more usage, please refer to [VLM offline inference pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('microsoft/Phi-3.5-vision-instruct')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## Online serving\n\n### Launch the service\n\nYou can launch the service with the `lmdeploy serve api_server` CLI:\n\n```shell\nlmdeploy serve api_server microsoft/Phi-3.5-vision-instruct\n```\n\n### Using the OpenAI interface\n\nThe following code is an example of calling the `v1/chat/completions` service through the openai package. Before running it, please install the openai package: `pip install openai`.\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')\nmodel_name = client.models.list().data[0].id\nresponse = client.chat.completions.create(\n    model=model_name,\n    messages=[{\n        'role':\n        
'user',\n        'content': [{\n            'type': 'text',\n            'text': 'Describe the image please',\n        }, {\n            'type': 'image_url',\n            'image_url': {\n                'url':\n                'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg',\n            },\n        }],\n    }],\n    temperature=0.8,\n    top_p=0.8)\nprint(response)\n```\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/qwen2_5_vl.md",
    "content": "# Qwen2.5-VL\n\nLMDeploy supports the following Qwen-VL series models:\n\n|   Model    |       Size       | Supported Inference Engine |\n| :--------: | :--------------: | :------------------------: |\n| Qwen2.5-VL | 3B, 7B, 32B, 72B |          PyTorch           |\n\nThis document takes [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) as an example to demonstrate how to deploy Qwen2.5-VL series models with LMDeploy.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the dependencies required by the upstream Qwen2.5-VL model repository.\n\n```shell\n# Qwen2.5-VL requires the latest transformers (transformers >= 4.49.0)\npip install git+https://github.com/huggingface/transformers\n# It's highly recommended to use the `[decord]` feature for faster video loading.\npip install qwen-vl-utils[decord]==0.0.8\n```\n\n## Offline inference\n\nThe following is an example of offline inference with the pipeline. For more usage, please refer to [VLM offline inference pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nMore examples:\n\n<details>\n  <summary>\n    <b>Multi-image, multi-round conversation</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two 
images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>Control image resolution to speed up inference</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO')\n\nmin_pixels = 64 * 28 * 28\nmax_pixels = 64 * 28 * 28\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>Multi-round video conversation</b>\n  </summary>\n\n```python\nimport numpy as np\nfrom lmdeploy import pipeline, GenerationConfig\nfrom decord import VideoReader, cpu\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\nfrom lmdeploy.vl import encode_image_base64\nfrom PIL import Image\npipe = pipeline('Qwen/Qwen2.5-VL-7B-Instruct', log_level='INFO')\n\n\ndef get_index(bound, fps, max_frame, first_idx=0, num_segments=32):\n    if bound:\n        start, end = bound[0], bound[1]\n    else:\n        start, end = -100000, 100000\n    start_idx = max(first_idx, round(start * fps))\n    end_idx = min(round(end * fps), max_frame)\n    seg_size = float(end_idx - start_idx) / num_segments\n    frame_indices = np.array([\n        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))\n        for idx in 
range(num_segments)\n    ])\n    return frame_indices\n\n\ndef load_video(video_path, bound=None, num_segments=32):\n    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)\n    max_frame = len(vr) - 1\n    fps = float(vr.get_avg_fps())\n    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)\n    imgs = []\n    for frame_index in frame_indices:\n        img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')\n        imgs.append(img)\n    return imgs\n\n\nvideo_path = 'red-panda.mp4'\nimgs = load_video(video_path, num_segments=8)\n\nquestion = ''\nfor i in range(len(imgs)):\n    question = question + f'Frame{i+1}: {IMAGE_TOKEN}\\n'\n\nquestion += 'What is the red panda doing?'\n\ncontent = [{'type': 'text', 'text': question}]\nfor img in imgs:\n    content.append({'type': 'image_url', 'image_url': {'max_dynamic_patch': 1, 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}'}})\n\nmessages = [dict(role='user', content=content)]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='Describe this video in detail. Don\\'t repeat.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/qwen2_vl.md",
    "content": "# Qwen2-VL\n\nLMDeploy supports the following Qwen-VL series models:\n\n|    Model     |  Size  | Supported Inference Engine |\n| :----------: | :----: | :------------------------: |\n| Qwen-VL-Chat |   -    |         TurboMind          |\n|   Qwen2-VL   | 2B, 7B |          PyTorch           |\n\nThis document takes [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) as an example to demonstrate how to deploy Qwen2-VL series models with LMDeploy.\n\n## Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the dependencies required by the upstream Qwen2-VL model repository.\n\n```shell\npip install qwen_vl_utils\n```\n\nAlternatively, you can build a docker image for Qwen2-VL inference. If the CUDA version on the host machine is `>=12.4`, you can build the image with the following commands:\n\n```shell\ngit clone https://github.com/InternLM/lmdeploy.git\ncd lmdeploy\ndocker build --build-arg CUDA_VERSION=cu12 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile\n```\n\nOtherwise, you can build it on top of the LMDeploy cu11 image:\n\n```shell\ndocker build --build-arg CUDA_VERSION=cu11 -t openmmlab/lmdeploy:qwen2vl . -f ./docker/Qwen2VL_Dockerfile\n```\n\n## Offline inference\n\nThe following is an example of offline inference with the pipeline. For more usage, please refer to [VLM offline inference pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('Qwen/Qwen2-VL-2B-Instruct')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nMore examples:\n\n<details>\n  <summary>\n    <b>Multi-image, multi-round conversation</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO')\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, 
gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>Control image resolution to speed up inference</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('Qwen/Qwen2-VL-2B-Instruct', log_level='INFO')\n\nmin_pixels = 64 * 28 * 28\nmax_pixels = 64 * 28 * 28\nmessages = [\n    dict(role='user', content=[\n        dict(type='text', text='Describe the two images in detail.'),\n        dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg')),\n        dict(type='image_url', image_url=dict(min_pixels=min_pixels, max_pixels=max_pixels, url='https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg'))\n    ])\n]\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n\nmessages.append(dict(role='assistant', content=out.text))\nmessages.append(dict(role='user', content='What are the similarities and differences between these two images.'))\nout = pipe(messages, gen_config=GenerationConfig(top_k=1))\n```\n\n</details>\n\n## Online serving\n\nYou can launch the service with the `lmdeploy serve api_server` CLI:\n\n```shell\nlmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct\n```\n\nYou can also launch the service from the docker image built above:\n\n```shell\ndocker run --runtime nvidia --gpus all \\\n    -v ~/.cache/huggingface:/root/.cache/huggingface \\\n    --env \"HUGGING_FACE_HUB_TOKEN=<secret>\" \\\n    -p 23333:23333 \\\n    --ipc=host \\\n    openmmlab/lmdeploy:qwen2vl \\\n    lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct\n```\n\nDocker Compose is another option. Create a `docker-compose.yml` file in the root directory of the LMDeploy repository with content like the following:\n\n```yaml\nversion: '3.5'\n\nservices:\n  lmdeploy:\n    container_name: 
lmdeploy\n    image: openmmlab/lmdeploy:qwen2vl\n    ports:\n      - \"23333:23333\"\n    environment:\n      HUGGING_FACE_HUB_TOKEN: <secret>\n    volumes:\n      - ~/.cache/huggingface:/root/.cache/huggingface\n    stdin_open: true\n    tty: true\n    ipc: host\n    command: lmdeploy serve api_server Qwen/Qwen2-VL-2B-Instruct\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: \"all\"\n              capabilities: [gpu]\n```\n\nThen start the service with:\n\n```shell\ndocker-compose up -d\n```\n\nRun `docker logs -f lmdeploy` to view the startup logs. If you see messages like the following, the service has started successfully.\n\n```text\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nHINT:    Please open  http://0.0.0.0:23333   in a browser for detailed api usage!!!\nINFO:     Started server process [2439]\nINFO:     Waiting for application startup.\nINFO:     Application startup complete.\nINFO:     Uvicorn running on  http://0.0.0.0:23333  (Press CTRL+C to quit)\n```\n\nRun `lmdeploy serve api_server -h` to see the detailed arguments of `lmdeploy serve api_server`.\n\nFor more information about `api_server` and how to access it, please read [here](api_server_vl.md).\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/vl_pipeline.md",
    "content": "# VLM offline inference pipeline\n\nLMDeploy abstracts the complex inference process of vision-language models (VLMs) into an easy-to-use pipeline. Its usage is similar to the large language model (LLM) inference [pipeline](../llm/pipeline.md).\n\nYou can check the VLM models supported by each inference engine in [this list](../supported_models/supported_models.md). We sincerely invite the community to contribute more VLM models to LMDeploy.\n\nThis document uses the [OpenGVLab/InternVL2_5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) model as an example to demonstrate the VLM pipeline. You will learn its most basic usage and how to progressively unlock more advanced features by adjusting engine parameters and the generation config, such as tensor parallelism, context window size, random sampling, and chat template customization.\n\nIn addition, we provide practical inference examples for scenarios such as multi-image inputs and batched prompts.\n\nUsing the pipeline interface with other VLM models works much the same way; the main differences lie in model-specific dependencies and configuration. You can read [here](https://lmdeploy.readthedocs.io/zh-cn/latest/multi_modal/) for the environment setup and configuration of different models.\n\n## \"Hello, world\" example\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nIf an `ImportError` occurs when running this example, please install the required dependencies as prompted.\n\nIn the example above, the prompt is a (prompt, image) tuple. Besides this structure, the pipeline also supports prompts in the OpenAI format:\n\n```python\nfrom lmdeploy import pipeline\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B')\n\nprompts = [\n    {\n        'role': 'user',\n        'content': [\n            {'type': 'text', 'text': 'describe this image'},\n            {'type': 'image_url', 'image_url': {'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'}}\n        ]\n    }\n]\nresponse = pipe(prompts)\nprint(response)\n```\n\n### Set multi-GPU parallelism\n\nSet the engine parameter `tp` to enable multi-GPU parallelism.\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(tp=2))\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n### Set context length\n\nWhen creating the pipeline, you can customize the maximum length of the context window by setting the engine parameter 
`session_len`.\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(session_len=8192))\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n### Set sampling parameters\n\nYou can change the default sampling parameters of the pipeline's generation interface by passing a `GenerationConfig`.\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(tp=2, session_len=8192))\ngen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6)\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image), gen_config=gen_config)\nprint(response)\n```\n\n### Customize the image token position\n\nBy default, LMDeploy inserts the special token representing an image into the user prompt according to the chat template provided by the upstream algorithm repo. However, some models, such as deepseek-vl, do not restrict where the image token goes, and users may also want to choose its position themselves. In such cases, users need to insert the image token into the prompt manually. LMDeploy uses `<IMAGE_TOKEN>` as the special token representing an image.\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\n\npipe = pipeline('deepseek-ai/deepseek-vl-1.3b-chat')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe((f'describe this image{IMAGE_TOKEN}', image))\nprint(response)\n```\n\n### Set chat template\n\nDuring inference, LMDeploy matches a built-in chat template according to the model path and applies it to the input prompt. If you use a local model whose folder name differs from the official model name, you need to specify the chat template manually. Take [llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b) as an example: it officially uses the ['llava-v1'](https://github.com/haotian-liu/LLaVA/blob/v1.2.2/llava/conversation.py#L325-L335) chat template. If your local folder is not named `llava-v1.5-7b`, you can specify the template as follows.\n\n```python\nfrom lmdeploy import pipeline, 
ChatTemplateConfig\nfrom lmdeploy.vl import load_image\npipe = pipeline('local_model_folder',\n                chat_template_config=ChatTemplateConfig(model_name='llava-v1'))\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\nFor how to customize chat templates, please refer to [here](../advance/chat_template.md).\n\n### Set vision model parameters\n\nYou can change the default parameters of the vision model by setting `VisionConfig`.\n\n```python\nfrom lmdeploy import pipeline, VisionConfig\nfrom lmdeploy.vl import load_image\nvision_config = VisionConfig(max_batch_size=16)\npipe = pipeline('liuhaotian/llava-v1.5-7b', vision_config=vision_config)\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n### Get the logits of generated tokens\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl import load_image\npipe = pipeline('OpenGVLab/InternVL2_5-8B')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n\nresponse = pipe(('describe this image', image),\n                gen_config=GenerationConfig(output_logits='generation'))\nlogits = response.logits\nprint(logits)\n```\n\n## Multi-image inference\n\nFor multi-image inputs, simply put the images in a list. However, multiple images mean more input tokens, so you usually need to [increase the context length](#set-context-length).\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(session_len=8192))\n\nimage_urls = [\n    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg',\n    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg'\n]\n\nimages = [load_image(img_url) for img_url in image_urls]\nresponse = pipe(('describe these images', images))\nprint(response)\n```\n\n## Batch prompts\n\nBatch prompt inference is simple: just put the prompts in a 
list:\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(session_len=8192))\n\nimage_urls = [\n    \"https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg\",\n    \"https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg\"\n]\nprompts = [('describe this image', load_image(img_url)) for img_url in image_urls]\nresponse = pipe(prompts)\nprint(response)\n```\n\n## Multi-round conversation\n\nThere are two ways to conduct multi-round conversations with the pipeline: one is to construct messages in the OpenAI format, and the other is to use the `pipeline.chat` interface, as shown below.\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('OpenGVLab/InternVL2_5-8B',\n                backend_config=TurbomindEngineConfig(session_len=8192))\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')\ngen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6)\nsess = pipe.chat(('describe this image', image), gen_config=gen_config)\nprint(sess.response.text)\nsess = pipe.chat('What is the woman doing?', session=sess, gen_config=gen_config)\nprint(sess.response.text)\n```\n\n### Release the pipeline\n\nYou can release the pipeline explicitly by calling its `close()` method, or use a `with` statement as shown below:\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\nwith pipeline('OpenGVLab/InternVL2_5-8B') as pipe:\n    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\n    response = pipe(('describe this image', image))\n    print(response)\n\n# Clear the torch cache and perform garbage collection if needed\nimport gc\nimport torch\ntorch.cuda.empty_cache()\ngc.collect()\n```\n"
  },
  {
    "path": "docs/zh_cn/multi_modal/xcomposer2d5.md",
    "content": "# InternLM-XComposer-2.5\n\n## Introduction\n\n[InternLM-XComposer-2.5](https://github.com/InternLM/InternLM-XComposer) is a groundbreaking vision-language large model built on the InternLM2 large language model. It achieves GPT-4V-level capabilities with merely a 7B LLM backend. InternLM-XComposer-2.5 is trained with 24K interleaved image-text contexts and can seamlessly extend to 96K-long contexts via RoPE extrapolation. This long-context capability allows it to excel in tasks that require extensive input and output contexts. LMDeploy supports the [internlm/internlm-xcomposer2d5-7b](https://huggingface.co/internlm/internlm-xcomposer2d5-7b) model, with inference through the TurboMind engine.\n\n## Quick start\n\n### Installation\n\nPlease install LMDeploy by following the [installation guide](../get_started/installation.md), and install the dependencies required by the upstream InternLM-XComposer-2.5 repository.\n\n```shell\npip install decord\n```\n\n### Offline inference pipeline\n\nThe following is an example of offline inference with the pipeline. For more usage, please refer to [VLM offline inference pipeline](./vl_pipeline.md).\n\n```python\nfrom lmdeploy import pipeline\nfrom lmdeploy.vl import load_image\n\npipe = pipeline('internlm/internlm-xcomposer2d5-7b')\n\nimage = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')\nresponse = pipe(('describe this image', image))\nprint(response)\n```\n\n## LoRA models\n\nInternLM-XComposer-2.5 provides LoRA models trained for webpage generation and article writing. Since TurboMind does not support the S-LoRA feature, only one LoRA model can be deployed at a time, and its weights must first be merged into the base model. LMDeploy provides a conversion script for this:\n\n```shell\nexport HF_MODEL=internlm/internlm-xcomposer2d5-7b\nexport WORK_DIR=internlm/internlm-xcomposer2d5-7b-web\nexport TASK=web\npython -m lmdeploy.vl.tools.merge_xcomposer2d5_task $HF_MODEL $WORK_DIR --task $TASK\n```\n\n## Quantization\n\nThe following uses the base model to demonstrate quantization. To quantize a LoRA model, first obtain the merged LoRA model as described in the previous section.\n\n```shell\nexport HF_MODEL=internlm/internlm-xcomposer2d5-7b\nexport WORK_DIR=internlm/internlm-xcomposer2d5-7b-4bit\n\nlmdeploy lite auto_awq \\\n    $HF_MODEL \\\n    --work-dir $WORK_DIR\n```\n\n## More examples\n\n<details>\n  <summary>\n    <b>Video Understanding</b>\n  </summary>\n\nThe following example uses `pipeline.chat`. Other interfaces also support this kind of inference, but you need to assemble the conversation content manually.\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom transformers.dynamic_module_utils import get_class_from_dynamic_module\n\nHF_MODEL = 'internlm/internlm-xcomposer2d5-7b'\nload_video = 
get_class_from_dynamic_module('ixc_utils.load_video', HF_MODEL)\nframe2img = get_class_from_dynamic_module('ixc_utils.frame2img', HF_MODEL)\nVideo_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', HF_MODEL)\nget_font = get_class_from_dynamic_module('ixc_utils.get_font', HF_MODEL)\n\nvideo = load_video('liuxiang.mp4')  # https://github.com/InternLM/InternLM-XComposer/raw/main/examples/liuxiang.mp4\nimg = frame2img(video, get_font())\nimg = Video_transform(img)\n\npipe = pipeline(HF_MODEL)\ngen_config = GenerationConfig(top_k=50, top_p=0.8, temperature=1.0)\nquery = 'Here are some frames of a video. Describe this video in detail'\nsess = pipe.chat((query, img), gen_config=gen_config)\nprint(sess.response.text)\n\nquery = 'tell me the athlete code of Liu Xiang'\nsess = pipe.chat(query, session=sess, gen_config=gen_config)\nprint(sess.response.text)\n```\n\n</details>\n\n<details>\n  <summary>\n    <b>Multi-Image</b>\n  </summary>\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\nfrom lmdeploy.vl.constants import IMAGE_TOKEN\nfrom lmdeploy.vl import load_image\n\nquery = f'Image1 {IMAGE_TOKEN}; Image2 {IMAGE_TOKEN}; Image3 {IMAGE_TOKEN}; I want to buy a car from the three given cars, analyze their advantages and weaknesses one by one'\n\nurls = ['https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars1.jpg',\n        'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars2.jpg',\n        'https://raw.githubusercontent.com/InternLM/InternLM-XComposer/main/examples/cars3.jpg']\nimages = [load_image(url) for url in urls]\n\npipe = pipeline('internlm/internlm-xcomposer2d5-7b', log_level='INFO')\noutput = pipe((query, images), gen_config=GenerationConfig(top_k=0, top_p=0.8, random_seed=89247526689433939))\n```\n\nSince LMDeploy does not support beam search, the generated results may differ considerably from those produced by transformers with beam search. We recommend disabling top_k or sampling with a larger top_k to increase diversity.\n\n</details>\n\n<details>\n  <summary>\n    <b>Instruction to 
Webpage</b>\n  </summary>\n\nPlease first convert the web model by following the instructions in the LoRA section above.\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-web', log_level='INFO')\npipe.chat_template.meta_instruction = None\n\nquery = 'A website for Research institutions. The name is Shanghai AI lab. Top Navigation Bar is blue.Below left, an image shows the logo of the lab. In the right, there is a passage of text below that describes the mission of the laboratory.There are several images to show the research projects of Shanghai AI lab.'\noutput = pipe(query, gen_config=GenerationConfig(max_new_tokens=2048))\n```\n\nWhen testing with transformers, we found that if repetition_penalty is set, generation with a beam size of 1 is quite likely to never stop. Since LMDeploy does not support beam search, we recommend disabling repetition_penalty when running inference with LMDeploy.\n\n</details>\n\n<details>\n  <summary>\n    <b>Write Article</b>\n  </summary>\n\nPlease first convert the write model by following the instructions in the LoRA section above.\n\n```python\nfrom lmdeploy import pipeline, GenerationConfig\n\npipe = pipeline('/nvme/shared/internlm-xcomposer2d5-7b-write', log_level='INFO')\npipe.chat_template.meta_instruction = None\n\nquery = 'Please write a blog based on the title: French Pastries: A Sweet Indulgence'\noutput = pipe(query, gen_config=GenerationConfig(max_new_tokens=8192))\n```\n\n</details>\n"
  },
  {
    "path": "docs/zh_cn/quantization/kv_quant.md",
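    "content": "# Key-Value (KV) Cache Quantization\n\nSince v0.4.0, LMDeploy supports **online** kv cache int4/int8 quantization, using per-head, per-token asymmetric quantization. The previous offline kv quantization method has been removed.\n\nIntuitively, quantizing the kv cache increases the number of available kv blocks. Compared with fp16, an int4 or int8 kv cache can hold 4x or 2x as many kv blocks, respectively. This means that, under the same memory budget, kv quantization can significantly increase the number of concurrent requests the system can support, which ultimately improves throughput.\n\nHowever, quantization usually comes with some loss of model accuracy. We used opencompass to evaluate the accuracy of several models with int4/int8 kv quantization applied: int8 kv is nearly lossless, while int4 kv shows a slight loss. The detailed results are in the [Evaluation](#evaluation) section; please refer to them and choose according to your needs.\n\nLMDeploy kv 4/8-bit quantization and inference support the following NVIDIA GPU models:\n\n- Volta (sm70): V100\n- Turing (sm75): 20 series, T4\n- Ampere (sm80, sm86): 30 series, A10, A16, A30, A100\n- Ada Lovelace (sm89): 40 series\n- Hopper (sm90): H100, H200\n\nIn summary, LMDeploy kv quantization has the following advantages:\n\n1. Quantization does not require a calibration dataset\n2. It supports all GPU models of the Volta architecture (sm70) and above\n3. kv int8 quantization is nearly lossless in accuracy, and kv int4 quantization accuracy is within an acceptable range\n4. Inference is efficient: applying int8/int4 kv quantization to llama2-7b improves RPS by nearly 30% and 40% respectively compared with fp16\n\nNext, we take the internlm2_5-7b-chat model as an example to introduce several applications of kv quantization and inference. Before that, please install lmdeploy:\n\n```shell\npip install lmdeploy\n```\n\n## Application examples\n\nApplying kv quantization with LMDeploy is simple: just set the `quant_policy` parameter.\n\n**In LMDeploy, `quant_policy=4` means kv int4 quantization and `quant_policy=8` means kv int8 quantization.**\n\n### Offline inference\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nengine_config = TurbomindEngineConfig(quant_policy=8)\npipe = pipeline(\"internlm/internlm2_5-7b-chat\", backend_config=engine_config)\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\n### Inference service\n\n```shell\nlmdeploy serve api_server internlm/internlm2_5-7b-chat --quant-policy 8\n```\n\n## Evaluation\n\nWe applied LMDeploy kv quantization to several LLMs and evaluated their inference accuracy with opencompass. The results are shown in the table below:\n\n| -           | -       | -             | llama2-7b-chat | -       | -       | internlm2-chat-7b | -       | -       | internlm2.5-chat-7b | -       | -       | qwen1.5-7b-chat | -       | -       |\n| ----------- | ------- | ------------- | -------------- | ------- | ------- | ----------------- | ------- | ------- | ------------------- | ------- | ------- | --------------- | ------- | ------- |\n| dataset     | version | metric        | kv fp16   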
     | kv int8 | kv int4 | kv fp16           | kv int8 | kv int4 | kv fp16             | kv int8 | kv int4 | kv fp16         | kv int8 | kv int4 |\n| ceval       | -       | naive_average | 28.42          | 27.96   | 27.58   | 60.45             | 60.88   | 60.28   | 78.06               | 77.87   | 77.05   | 70.56           | 70.49   | 68.62   |\n| mmlu        | -       | naive_average | 35.64          | 35.58   | 34.79   | 63.91             | 64      | 62.36   | 72.30               | 72.27   | 71.17   | 61.48           | 61.56   | 60.65   |\n| triviaqa    | 2121ce  | score         | 56.09          | 56.13   | 53.71   | 58.73             | 58.7    | 58.18   | 65.09               | 64.87   | 63.28   | 44.62           | 44.77   | 44.04   |\n| gsm8k       | 1d7fe4  | accuracy      | 28.2           | 28.05   | 27.37   | 70.13             | 69.75   | 66.87   | 85.67               | 85.44   | 83.78   | 54.97           | 56.41   | 54.74   |\n| race-middle | 9a54b6  | accuracy      | 41.57          | 41.78   | 41.23   | 88.93             | 88.93   | 88.93   | 92.76               | 92.83   | 92.55   | 87.33           | 87.26   | 86.28   |\n| race-high   | 9a54b6  | accuracy      | 39.65          | 39.77   | 40.77   | 85.33             | 85.31   | 84.62   | 90.51               | 90.42   | 90.42   | 82.53           | 82.59   | 82.02   |\n\nFor the detailed evaluation method, please refer to [this guide](../benchmark/evaluate_with_opencompass.md). When evaluating, add the `quant_policy` parameter to the inference engine in the config file.\n\n## Inference efficiency\n\n| model             | kv type | test settings                            | RPS   | v.s. 
kv fp16 |\n| ----------------- | ------- | ---------------------------------------- | ----- | ------------ |\n| llama2-chat-7b    | fp16    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 14.98 | 1.0          |\n| -                 | int8    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 19.01 | 1.27         |\n| -                 | int4    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 20.81 | 1.39         |\n| llama2-chat-13b   | fp16    | tp1 / ratio 0.9 / bs 128 / prompts 10000 | 8.55  | 1.0          |\n| -                 | int8    | tp1 / ratio 0.9 / bs 256 / prompts 10000 | 10.96 | 1.28         |\n| -                 | int4    | tp1 / ratio 0.9 / bs 256 / prompts 10000 | 11.91 | 1.39         |\n| internlm2-chat-7b | fp16    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 24.13 | 1.0          |\n| -                 | int8    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 25.28 | 1.05         |\n| -                 | int4    | tp1 / ratio 0.8 / bs 256 / prompts 10000 | 25.80 | 1.07         |\n\nThe results above were produced with the test script `benchmark/profile_throughput.py`.\n"
  },
  {
    "path": "docs/zh_cn/quantization/llm_compressor.md",
    "content": "# llm-compressor 支持\n\n本指南旨在介绍如何使用 LMDeploy 的 TurboMind 推理引擎，运行经由 [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)工具量化后的模型。\n目前支持的 `llm-compressor` 量化模型包括：\n\n- int4 量化（例如 AWQ、GPTQ）\n\n上述量化模型通过 TurboMind 引擎可以在以下 NVIDIA GPU 架构上运行：\n\n| Compute Capability | Micro-architecture | GPUs                            |\n| ------------------ | ------------------ | ------------------------------- |\n| 7.0                | Volta              | V100                            |\n| 7.2                | Volta              | Jetson Xavier                   |\n| 7.5                | Turing             | GeForce RTX 20 series, T4       |\n| 8.0                | Ampere             | A100, A800, A30                 |\n| 8.6                | Ampere             | GeForce RTX 30 series, A40, A10 |\n| 8.7                | Ampere             | Jetson Orin                     |\n| 8.9                | Ada Lovelace       | GeForce RTX 40 series, L40, L20 |\n| 9.0                | Hopper             | H20, H200, H100, GH200          |\n| 12.0               | Blackwell          | GeForce RTX 50 series           |\n\nLMDeploy 将持续跟进并扩展对 `llm-compressor` 项目的支持。\n\n本文的其余部分由以下章节组成：\n\n<!-- toc -->\n\n- [模型量化](#模型量化)\n- [模型部署](#模型部署)\n- [精度评测](#精度评测)\n\n<!-- tocstop -->\n\n## 模型量化\n\n`llm-compressor` 提供了丰富的模型量化[用例](https://github.com/vllm-project/llm-compressor/tree/main/examples)，请参考其教程选择 LMDeploy 支持的量化算法，完成模型量化工作。\nLMDeploy 也内置了通过 `llm-compressor` 对 Qwen3-30B-A3B 进行 AWQ 量化的[脚本](https://github.com/InternLM/lmdeploy/blob/main/examples/lite/qwen3_30b_a3b_awq.py)，供大家进行参考：\n\n```shell\n# 创建 conda 环境\nconda create -n lmdeploy python=3.10 -y\nconda activate lmdeploy\n\n# 安装 llm-compressor\npip install llmcompressor\n\n# 下载 lmdeploy 源码，运行量化用用例\ngit clone https://github.com/InternLM/lmdeploy\ncd lmdeploy\npython examples/lite/qwen3_30b_a3b_awq.py --work-dir ./qwen3_30b_a3b_awq\n\n```\n\n在接下来的章节中，我们以此量化模型为例，介绍模型部署、评测精度等方法\n\n## 模型部署\n\n### 
离线推理\n\n量化后的模型，通过以下几行简单的代码，可以实现离线批处理：\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nengine_config = TurbomindEngineConfig()\nwith pipeline(\"./qwen3_30b_a3b_4bit\", backend_config=engine_config) as pipe:\n    response = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\n    print(response)\n```\n\n关于 pipeline 的详细介绍，请参考[这里](https://lmdeploy.readthedocs.io/zh-cn/latest/llm/pipeline.html)。\n\n### 在线服务\n\nLMDeploy api_server 支持把模型一键封装为服务，对外提供的 RESTful API 兼容 OpenAI 接口。以下为服务启动的示例：\n\n```shell\nlmdeploy serve api_server ./qwen3_30b_a3b_4bit --backend turbomind\n```\n\n服务默认端口是 23333。在 server 启动后，你可以通过 OpenAI SDK 访问服务。关于服务的命令参数以及访问服务的方式，可以阅读[这份](https://lmdeploy.readthedocs.io/zh-cn/latest/llm/api_server.html)文档。\n\n## 精度评测\n\n我们将 Qwen3-8B (Dense) 与 Qwen3-30B-A3B (MoE) 的 AWQ 对称/非对称量化模型通过 LMDeploy 部署为服务，并使用 [opencompass](https://github.com/open-compass/opencompass) 在多个学术数据集上评测。结果显示：Qwen3-8B 的非对称量化整体优于对称量化，而 Qwen3-30B-A3B 在两种量化方式间差异不显著；Qwen3-8B 在对称/非对称量化下与 BF16 模型的精度差异小于 Qwen3-30B-A3B。与 BF16 相比，量化模型在长输出数据集（比如 aime2025，平均 17,635 tokens；LCB，平均 14,157 tokens）上精度下降更明显；在中短输出数据集（比如 ifeval，平均 1,885 tokens；mmlu_pro，平均 2,826 tokens）上，精度符合预期。\n\n| dataset           | Qwen3-8B |         |          | Qwen3-30B-A3B |         |          |\n| ----------------- | -------- | ------- | -------- | ------------- | ------- | -------- |\n|                   | bf16     | awq sym | awq asym | bf16          | awq sym | awq asym |\n| ifeval            | 85.58    | 83.73   | 85.77    | 86.32         | 84.10   | 84.29    |\n| hle               | 5.05     | 5.05    | 5.24     | 7.00          | 5.47    | 5.65     |\n| gpqa              | 59.97    | 56.57   | 59.47    | 61.74         | 57.95   | 57.07    |\n| aime2025          | 69.48    | 64.38   | 63.96    | 73.44         | 64.79   | 66.67    |\n| mmlu_pro          | 73.69    | 71.73   | 72.34    | 77.85         | 75.77   | 75.69    |\n| LCBCodeGeneration | 50.86    | 44.10   | 46.95    | 56.67         | 50.86   | 49.24    
|\n\n复现方式可以参考[这份](https://lmdeploy.readthedocs.io/zh-cn/latest/benchmark/evaluate_with_opencompass.html)文档\n"
  },
  {
    "path": "docs/zh_cn/quantization/w4a16.md",
    "content": "# INT4 模型量化和部署\n\nLMDeploy TurboMind 引擎支持由 [AWQ](https://arxiv.org/abs/2306.00978) 和 [GPTQ](https://github.com/AutoGPTQ/AutoGPTQ) 两种量化方法量化的 4bit 模型的推理。然而，LMDeploy 量化模块目前仅支持 AWQ 量化算法。\n\n可用于 AWQ/GPTQ INT4 推理的 NVIDIA GPU 包括：\n\n- V100(sm70): V100\n- Turing(sm75): 20 系列，T4\n- Ampere(sm80,sm86): 30 系列，A10, A16, A30, A100\n- Ada Lovelace(sm89): 40 系列\n\n在进行量化和推理之前，请确保按照[安装指南](../get_started/installation.md)安装了 lmdeploy。\n\n本文的其余部分由以下章节组成：\n\n<!-- toc -->\n\n- [模型量化](#模型量化)\n- [模型评测](#模型评测)\n- [模型推理](#模型推理)\n- [推理服务](#推理服务)\n- [推理性能](#推理性能)\n\n<!-- tocstop -->\n\n## 模型量化\n\n仅需执行一条命令，就可以完成模型量化工作。量化结束后，权重文件存放在 `$WORK_DIR` 下。\n\n```shell\nexport HF_MODEL=internlm/internlm2_5-7b-chat\nexport WORK_DIR=internlm/internlm2_5-7b-chat-4bit\n\nlmdeploy lite auto_awq \\\n   $HF_MODEL \\\n  --calib-dataset 'wikitext2' \\\n  --calib-samples 128 \\\n  --calib-seqlen 2048 \\\n  --w-bits 4 \\\n  --w-group-size 128 \\\n  --batch-size 1 \\\n  --work-dir $WORK_DIR\n```\n\n绝大多数情况下，在执行上述命令时，可选参数可不用填写，使用默认的即可。比如量化 [internlm/internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) 模型，命令可以简化为：\n\n```shell\nlmdeploy lite auto_awq internlm/internlm2_5-7b-chat --work-dir internlm2_5-7b-chat-4bit\n```\n\n**Note:**\n\n- 我们建议 --work-dir 参数带有模型名字，就像上面的例子展示的那样。这样在推理时，就不用指定对话模板了。因为推理接口会以模糊搜索方式，选出和 --work-dir 近似的对话模板\n- 如果量化模型精度有损，建议开启 --search-scale 重新量化，并调大 --batch-size，比如 8。search_scale 开启后，量化过程会比较耗时。--batch-size 会影响内存占用量，可以根据实际情况，酌情调整。\n\n量化后的模型，可以用一些工具快速验证对话效果。\n\n比如，直接在控制台和模型对话，\n\n```shell\nlmdeploy chat ./internlm2_5-7b-chat-4bit --model-format awq\n```\n\n## 模型评测\n\n我们使用 [OpenCompass](https://opencompass.readthedocs.io/zh-cn/latest/index.html) 评测量化模型在各个维度上的能力。方法请参考[此处](https://opencompass.readthedocs.io/zh-cn/latest/advanced_guides/evaluation_lmdeploy.html)\n\n## 模型推理\n\n量化后的模型，通过以下几行简单的代码，可以实现离线推理：\n\n```python\nfrom lmdeploy import pipeline, TurbomindEngineConfig\nengine_config = TurbomindEngineConfig(model_format='awq')\npipe = 
pipeline(\"./internlm2_5-7b-chat-4bit\", backend_config=engine_config)\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\n关于 pipeline 的详细介绍，请参考[这里](../llm/pipeline.md)。\n\n除了推理本地量化模型外，LMDeploy 还支持直接推理 huggingface hub 上通过 AWQ 量化的 4bit 权重模型，比如 [lmdeploy 空间](https://huggingface.co/lmdeploy)和 [TheBloke 空间](https://huggingface.co/TheBloke)下的模型。\n\n```python\n# 推理 lmdeploy 空间下的模型\nfrom lmdeploy import pipeline, TurbomindEngineConfig\npipe = pipeline(\"lmdeploy/llama2-chat-70b-4bit\",\n                backend_config=TurbomindEngineConfig(model_format='awq', tp=4))\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n\n# 推理 TheBloke 空间下的模型\nfrom lmdeploy import pipeline, TurbomindEngineConfig, ChatTemplateConfig\npipe = pipeline(\"TheBloke/LLaMA2-13B-Tiefighter-AWQ\",\n                backend_config=TurbomindEngineConfig(model_format='awq'),\n                chat_template_config=ChatTemplateConfig(model_name='llama2'))\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\n## 推理服务\n\nLMDeploy `api_server` 支持把模型一键封装为服务，对外提供的 RESTful API 兼容 OpenAI 接口。以下为服务启动的示例：\n\n```shell\nlmdeploy serve api_server ./internlm2_5-7b-chat-4bit --backend turbomind --model-format awq\n```\n\n服务默认端口是 23333。在 server 启动后，你可以在终端通过 `api_client` 与 server 进行对话：\n\n```shell\nlmdeploy serve api_client http://0.0.0.0:23333\n```\n\n还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口，也可直接查阅[文档](../llm/api_server.md)，了解各接口的定义和使用方法。\n\n## 推理性能\n\n我们在 NVIDIA GeForce RTX 4090 上分别测试了 4-bit Llama-2-7B-chat 和 Llama-2-13B-chat 模型的 token 生成速度（单位：tokens/s）。测试配置为 batch size = 1，(prompt_tokens, completion_tokens) = (1, 512)。\n\n| model            | llm-awq | mlc-llm | turbomind |\n| ---------------- | ------- | ------- | --------- |\n| Llama-2-7B-chat  | 112.9   | 159.4   | 206.4     |\n| Llama-2-13B-chat | N/A     | 90.7    | 115.8     |\n\n## 快速问答\n\n1. 
量化时出现 Out of Memory（显存不足）：可以减小 `--calib-seqlen`，增大 `--calib-samples`，并将 `--batch-size` 设为 1。\n2. 量化时无法连接 huggingface 下载数据集：可以尝试使用镜像，`export HF_ENDPOINT=https://hf-mirror.com`。\n"
  },
  {
    "path": "docs/zh_cn/quantization/w8a8.md",
    "content": "# W8A8 LLM 模型部署\n\nLMDeploy 提供了使用 8-bit 整数(INT8)和浮点数(FP8)对神经网络模型进行量化和推理的功能。\n\n可用于 INT8 和 FP8 推理的 NVIDIA GPU 分别为：\n\n- INT8\n  - V100(sm70): V100\n  - Turing(sm75): 20 series, T4\n  - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100\n  - Ada Lovelace(sm89): 40 series\n  - Hopper(sm90): H100\n- FP8\n  - Ada Lovelace(sm89): 40 series\n  - Hopper(sm90): H100\n\n首先，执行如下命令安装 lmdeploy：\n\n```shell\npip install lmdeploy[all]\n```\n\n## 8-bit 权重量化\n\n进行 8-bit 权重量化需要经历以下三步：\n\n1. **权重平滑**：首先对语言模型的权重进行平滑处理，以便更好地进行量化。\n2. **模块替换**：使用 `QRMSNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RMSNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。\n3. **保存量化模型**：完成上述必要的替换后，我们即可保存新的量化模型。\n\nlmdeploy 提供的命令行工具 `lmdeploy lite smooth_quant` 实现了以上三个步骤，其中参数 `--quant-dtype` 用于选择进行 8-bit 整数（int8）还是 8-bit 浮点数（fp8）量化。更多使用方式请执行 `lmdeploy lite smooth_quant --help` 查看。\n\n以下示例分别演示了 int8 和 fp8 的量化命令。\n\n- int8\n\n  ```shell\n  lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8\n  ```\n\n- fp8\n\n  ```shell\n  lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8\n  ```\n\n## 模型推理\n\n量化后的模型，通过以下几行简单的代码，可以实现离线推理：\n\n```python\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nengine_config = PytorchEngineConfig(tp=1)\npipe = pipeline(\"internlm2_5-7b-chat-int8\", backend_config=engine_config)\nresponse = pipe([\"Hi, pls intro yourself\", \"Shanghai is\"])\nprint(response)\n```\n\n关于 pipeline 的详细介绍，请参考[这里](../llm/pipeline.md)。\n\n## 推理服务\n\nLMDeploy `api_server` 支持把模型一键封装为服务，对外提供的 RESTful API 兼容 OpenAI 接口。以下为服务启动的示例：\n\n```shell\nlmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch\n```\n\n服务默认端口是 23333。在 server 启动后，你可以在终端通过 `api_client` 与 server 进行对话：\n\n```shell\nlmdeploy serve api_client http://0.0.0.0:23333\n```\n\n还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口，也可直接查阅[文档](../llm/api_server.md)，了解各接口的定义和使用方法。\n"
  },
  {
    "path": "docs/zh_cn/supported_models/reward_models.md",
    "content": "# 奖励模型\n\nLMDeploy 支持以下奖励模型：\n\n|      Model       |     Size      | Supported Inference Engine |\n| :--------------: | :-----------: | :------------------------: |\n| Qwen2.5-Math-RM  |      72B      |          PyTorch           |\n| InternLM2-Reward | 1.8B, 7B, 20B |          PyTorch           |\n|      POLAR       |   1.8B, 7B    |          PyTorch           |\n\n## 离线推理\n\n我们以 `internlm/internlm2-1_8b-reward` 为例：\n\n```python\nfrom transformers import AutoTokenizer\nfrom lmdeploy import pipeline, PytorchEngineConfig\n\nmodel_path = \"internlm/internlm2-1_8b-reward\"\nchat = [\n    {\"role\": \"system\", \"content\": \"Please reason step by step, and put your final answer within \\\\boxed{}.\"},\n    {\"role\": \"user\", \"content\": \"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\"},\n    {\"role\": \"assistant\", \"content\": \"To determine how much Janet makes from selling the duck eggs at the farmers' market, we need to follow these steps:\\n\\n1. Calculate the total number of eggs laid by the ducks each day.\\n2. Determine how many eggs Janet eats and bakes for herself each day.\\n3. Find out how many eggs are left to be sold.\\n4. Calculate the revenue from selling the remaining eggs at $2 per egg.\\n\\nLet's start with the first step:\\n\\n1. Janet's ducks lay 16 eggs per day.\\n\\nNext, we calculate how many eggs Janet eats and bakes for herself each day:\\n\\n2. Janet eats 3 eggs for breakfast every morning.\\n3. 
Janet bakes 4 eggs for her friends every day.\\n\\nSo, the total number of eggs Janet eats and bakes for herself each day is:\\n\\\\[ 3 + 4 = 7 \\\\text{ eggs} \\\\]\\n\\nNow, we find out how many eggs are left to be sold:\\n\\\\[ 16 - 7 = 9 \\\\text{ eggs} \\\\]\\n\\nFinally, we calculate the revenue from selling the remaining eggs at $2 per egg:\\n\\\\[ 9 \\\\times 2 = 18 \\\\text{ dollars} \\\\]\\n\\nTherefore, Janet makes 18 dollars every day at the farmers' market.\"}\n]\n\ntokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n\nconversation_str = tokenizer.apply_chat_template(\n    chat,\n    tokenize=False,\n    add_generation_prompt=False\n)\n\ninput_ids = tokenizer.encode(\n    conversation_str,\n    add_special_tokens=False\n)\n\n\nif __name__ == '__main__':\n    engine_config = PytorchEngineConfig(tp=1)\n    with pipeline(model_path, backend_config=engine_config) as pipe:\n        score = pipe.get_reward_score(input_ids)\n        print(f'score: {score}')\n```\n\n## 在线推理\n\n启动 API 服务：\n\n```bash\nlmdeploy serve api_server internlm/internlm2-1_8b-reward --backend pytorch\n```\n\n通过 `/pooling` 接口获取奖励分数：\n\n```bash\ncurl http://0.0.0.0:23333/pooling \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"internlm/internlm2-1_8b-reward\",\n    \"input\": \"Who are you?\"\n  }'\n```\n"
  },
  {
    "path": "docs/zh_cn/supported_models/supported_models.md",
    "content": "# 支持的模型\n\n以下列表分别为 LMDeploy TurboMind 引擎和 PyTorch 引擎在不同软硬件平台下支持的模型。\n\n## TurboMind CUDA 平台\n\n|              Model               |      Size      | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |\n| :------------------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: |\n|              Llama               |    7B - 65B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|              Llama2              |    7B - 70B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|              Llama3              |    8B, 70B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|             Llama3.1             |    8B, 70B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|     Llama3.2<sup>\\[2\\]</sup>     |     1B, 3B     | LLM  |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|             InternLM             |    7B - 20B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            InternLM2             |    7B - 20B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|           InternLM2.5            |       7B       | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            InternLM3             |       8B       | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|       InternLM-XComposer2        |  7B, 4khd-7B   | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|      InternLM-XComposer2.5       |       7B       | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|            Intern-S1             |      241B      | MLLM |    Yes    |   Yes   |   Yes   |  No   |\n|          Intern-S1-mini          |      8.3B      | MLLM |    Yes    |   Yes   |   Yes   |  No   |\n|          Intern-S1-Pro           |      1T        | MLLM |    Yes    |    -    |    -    |  No   |\n|               Qwen               |   1.8B - 72B   | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|     Qwen1.5<sup>\\[1\\]</sup>      |  1.8B - 110B   | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|      Qwen2<sup>\\[2\\]</sup>       |   0.5B - 
72B   | LLM  |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|            Qwen2-MoE             |    57BA14B     | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|     Qwen2.5<sup>\\[2\\]</sup>      |   0.5B - 72B   | LLM  |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|              Qwen3               |   0.6B-235B    | LLM  |    Yes    |   Yes   |  Yes\\*  |  Yes  |\n|     Qwen3.5<sup>\\[3\\]</sup>      |   0.8B-397B    | LLM  |    Yes    |   Yes   |   No    |  Yes  |\n|     Mistral<sup>\\[1\\]</sup>      |       7B       | LLM  |    Yes    |   Yes   |   Yes   |  No   |\n|             Mixtral              |  8x7B, 8x22B   | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|           DeepSeek-V2            |   16B, 236B    | LLM  |    Yes    |   Yes   |   Yes   |  No   |\n|          DeepSeek-V2.5           |      236B      | LLM  |    Yes    |   Yes   |   Yes   |  No   |\n|             Qwen-VL              |       7B       | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|           DeepSeek-VL            |       7B       | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|             Baichuan             |       7B       | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            Baichuan2             |       7B       | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            Code Llama            |    7B - 34B    | LLM  |    Yes    |   Yes   |   Yes   |  No   |\n|                YI                |    6B - 34B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|          LLaVA(1.5,1.6)          |    7B - 34B    | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|             InternVL             |  v1.1 - v1.5   | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|            InternVL2             | 1-2B, 8B - 76B | MLLM |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n| InternVL2.5(MPO)<sup>\\[2\\]</sup> |    1 - 78B     | MLLM |    Yes    |  Yes\\*  |  Yes\\*  |  Yes  |\n|    InternVL3<sup>\\[2\\]</sup>     |    1 - 78B     | MLLM |    Yes    |  Yes\\*  |  
Yes\\*  |  Yes  |\n|   InternVL3.5<sup>\\[3\\]</sup>    |  1 - 241BA28B  | MLLM |    Yes    |  Yes\\*  |  Yes\\*  |  No   |\n|             ChemVLM              |    8B - 26B    | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|       MiniCPM-Llama3-V-2_5       |       -        | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|          MiniCPM-V-2_6           |       -        | MLLM |    Yes    |   Yes   |   Yes   |  Yes  |\n|               GLM4               |       9B       | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n|            CodeGeeX4             |       9B       | LLM  |    Yes    |   Yes   |   Yes   |   -   |\n|              Molmo               |    7B-D,72B    | MLLM |    Yes    |   Yes   |   Yes   |  No   |\n|             gpt-oss              |    20B,120B    | LLM  |    Yes    |   Yes   |   Yes   |  Yes  |\n\n“-” 表示还没有验证。\n\n```{note}\n* [1] turbomind 引擎不支持 window attention。所以，对于应用了 window attention，并开启了对应的开关\"use_sliding_window\"的模型，比如 Mistral、Qwen1.5 等，在推理时，请选择 pytorch engine\n* [2] 当模型的 head_dim 非 128 时，turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如，llama3.2-1B，qwen2-0.5B，internvl2-1B 等等\n* [3] turbomind 目前暂不支持 Qwen3.5 系列的视觉编码器。\n```\n\n## PyTorchEngine CUDA 平台\n\n|             Model              |      Size       | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |\n| :----------------------------: | :-------------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |\n|             Llama              |    7B - 65B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|             Llama2             |    7B - 70B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|             Llama3             |     8B, 70B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            Llama3.1            |     8B, 70B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            Llama3.2            |     1B, 3B      | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|             Llama4             | Scout, 
Maverick | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|            InternLM            |    7B - 20B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|           InternLM2            |    7B - 20B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|          InternLM2.5           |       7B        | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|           InternLM3            |       8B        | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|           Intern-S1            |      241B       | MLLM |    Yes    |   Yes   |   Yes   | Yes  |   -   |\n|         Intern-S1-mini         |      8.3B       | MLLM |    Yes    |   Yes   |   Yes   | Yes  |   -   |\n|           Baichuan2            |       7B        | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  No   |\n|           Baichuan2            |       13B       | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|            ChatGLM2            |       6B        | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|               YI               |    6B - 34B     | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            Mistral             |       7B        | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            Mixtral             |   8x7B, 8x22B   | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|              QWen              |   1.8B - 72B    | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|            QWen1.5             |   0.5B - 110B   | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|          QWen1.5-MoE           |      A2.7B      | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|             QWen2              |   0.5B - 72B    | LLM  |    Yes    |   Yes   |   No    | Yes  |  Yes  |\n|            Qwen2.5             |   0.5B - 72B    | LLM  |    Yes    |   Yes   |   No    | Yes  |  Yes  |\n|             Qwen3              |   0.6B - 235B   | LLM  |    Yes    |   Yes   |  Yes\\*  |  -   
|  Yes  |\n|            QWen3.5             |    0.8B-397B    | MLLM |    Yes    |   No    |   No    |  No  |  No   |\n|           QWen3-Next           |       80B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|            QWen2-VL            |     2B, 7B      | MLLM |    Yes    |   Yes   |   No    |  No  |  Yes  |\n|           QWen2.5-VL           |    3B - 72B     | MLLM |    Yes    |   No    |   No    |  No  |  No   |\n|            QWen3-VL            |    2B - 235B    | MLLM |    Yes    |   No    |   No    |  No  |  No   |\n|          DeepSeek-MoE          |       16B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|          DeepSeek-V2           |    16B, 236B    | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|         DeepSeek-V2.5          |      236B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|          DeepSeek-V3           |      685B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|         DeepSeek-V3.2          |      685B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|          DeepSeek-VL2          |    3B - 27B     | MLLM |    Yes    |   No    |   No    |  No  |  No   |\n|            MiniCPM3            |       4B        | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|         MiniCPM-V-2_6          |       8B        | LLM  |    Yes    |   No    |   No    |  No  |  Yes  |\n|             Gemma              |      2B-7B      | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|           StarCoder2           |     3B-15B      | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|           Phi-3-mini           |      3.8B       | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|          Phi-3-vision          |      4.2B       | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|           Phi-4-mini           |      3.8B       | LLM  |    Yes    |   Yes   |   Yes   | Yes  |  Yes  |\n|          CogVLM-Chat           |       17B 
      | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|          CogVLM2-Chat          |       19B       | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n| LLaVA(1.5,1.6)<sup>\\[2\\]</sup> |     7B-34B      | MLLM |    No     |   No    |   No    |  No  |  No   |\n|         InternVL(v1.5)         |     2B-26B      | MLLM |    Yes    |   Yes   |   Yes   |  No  |  Yes  |\n|           InternVL2            |     1B-76B      | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|        InternVL2.5(MPO)        |     1B-78B      | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|           InternVL3            |     1B-78B      | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|          InternVL3.5           |   1B-241BA28B   | MLLM |    Yes    |   Yes   |   Yes   |  No  |  No   |\n| Mono-InternVL<sup>\\[1\\]</sup>  |       2B        | MLLM |   Yes\\*   |   Yes   |   Yes   |  -   |   -   |\n|            ChemVLM             |     8B-26B      | MLLM |    Yes    |   Yes   |   No    |  -   |   -   |\n|             Gemma2             |     9B-27B      | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|             Gemma3             |     1B-27B      | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|             GLM-4              |       9B        | LLM  |    Yes    |   Yes   |   Yes   |  No  |  No   |\n|           GLM-4-0414           |       9B        | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|             GLM-4V             |       9B        | MLLM |    Yes    |   Yes   |   Yes   |  No  |  Yes  |\n|       GLM-4.1V-Thinking        |       9B        | MLLM |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|            GLM-4.5             |      355B       | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|          GLM-4.5-Air           |      106B       | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|         GLM-4.7-Flash          |       30B       | LLM  |    Yes    |   No    |   No    |  No 
 |  No   |\n|             GLM-5              |      754B       | LLM  |    Yes    |   No    |   No    |  No  |  No   |\n|           CodeGeeX4            |       9B        | LLM  |    Yes    |   Yes   |   Yes   |  -   |   -   |\n|          Phi-3.5-mini          |      3.8B       | LLM  |    Yes    |   Yes   |   No    |  -   |   -   |\n|          Phi-3.5-MoE           |     16x3.8B     | LLM  |    Yes    |   Yes   |   No    |  -   |   -   |\n|         Phi-3.5-vision         |      4.2B       | MLLM |    Yes    |   Yes   |   No    |  -   |   -   |\n|              SDAR              |    1.7B-30B     | LLM  |    Yes    |   Yes   |   No    |  -   |   -   |\n\n```{note}\n* [1] 目前，Mono-InternVL 不支持 FP16，因为数值不稳定，请改用 BF16。\n* [2] 自 0.6.4 之后，PyTorch 引擎移除了对 llava 模型原始格式的支持。我们建议使用它们对应的 transformers 格式的模型，这些模型可以在 https://huggingface.co/llava-hf 中找到。\n* 自 0.11.1 起，PyTorchEngine 移除了对 mllama 的支持。\n```\n\n## PyTorchEngine 其他平台\n\n|                |           |      |  Atlas 800T A2   |  Atlas 800T A2   | Atlas 800T A2 | Atlas 800T A2 | Atlas 300I Duo |  Atlas 800T A3   | Maca C500 | Cambricon |\n| :------------: | :-------: | :--: | :--------------: | :--------------: | :-----------: | :-----------: | :------------: | :--------------: | :-------: | :-------: |\n|     Model      |   Size    | Type | FP16/BF16(eager) | FP16/BF16(graph) |  W8A8(graph)  | W4A16(eager)  |  FP16(graph)   | FP16/BF16(eager) |  BF/FP16  |  BF/FP16  |\n|     Llama2     | 7B - 70B  | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |       -        |       Yes        |    Yes    |    Yes    |\n|     Llama3     |    8B     | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|    Llama3.1    |    8B     | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|   InternLM2    | 7B - 20B  | LLM  |       Yes        |       Yes   
     |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|  InternLM2.5   | 7B - 20B  | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|   InternLM3    |    8B     | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |       Yes        |    Yes    |    Yes    |\n|    Mixtral     |   8x7B    | LLM  |       Yes        |       Yes        |      No       |      No       |      Yes       |        -         |    Yes    |    Yes    |\n|  QWen1.5-MoE   |   A2.7B   | LLM  |       Yes        |        -         |      No       |      No       |       -        |        -         |    Yes    |     -     |\n|   QWen2(.5)    |    7B     | LLM  |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |        -         |    Yes    |    Yes    |\n|    QWen2-VL    |  2B, 7B   | MLLM |       Yes        |       Yes        |       -       |       -       |       -        |        -         |    Yes    |    No     |\n|   QWen2.5-VL   | 3B - 72B  | MLLM |       Yes        |       Yes        |       -       |       -       |      Yes       |        -         |    Yes    |    No     |\n|   QWen2-MoE    |  A14.57B  | LLM  |       Yes        |        -         |      No       |      No       |       -        |        -         |    Yes    |     -     |\n|     QWen3      | 0.6B-235B | LLM  |       Yes        |       Yes        |      No       |      No       |      Yes       |       Yes        |    Yes    |    Yes    |\n|  DeepSeek-V2   |    16B    | LLM  |        No        |       Yes        |      No       |      No       |       -        |        -         |     -     |     -     |\n| InternVL(v1.5) |  2B-26B   | MLLM |       Yes        |        -         |      Yes      |      Yes      |       -        |        -         |    Yes    |     -     |\n|   InternVL2    |  1B-40B   | MLLM |    
   Yes        |       Yes        |      Yes      |      Yes      |      Yes       |        -         |    Yes    |    Yes    |\n|  InternVL2.5   |  1B-78B   | MLLM |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |        -         |    Yes    |    Yes    |\n|   InternVL3    |  1B-78B   | MLLM |       Yes        |       Yes        |      Yes      |      Yes      |      Yes       |        -         |    Yes    |    Yes    |\n|  CogVLM2-chat  |    19B    | MLLM |       Yes        |        No        |       -       |       -       |       -        |        -         |    Yes    |     -     |\n|     GLM4V      |    9B     | MLLM |       Yes        |        No        |       -       |       -       |       -        |        -         |     -     |     -     |\n"
  },
  {
    "path": "eval/config.py",
    "content": "# flake8: noqa\n\nfrom mmengine.config import read_base\nfrom opencompass.models import OpenAISDK\nfrom opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner\nfrom opencompass.runners import LocalRunner\nfrom opencompass.tasks import OpenICLEvalTask, OpenICLInferTask\nfrom opencompass.utils.text_postprocessors import extract_non_reasoning_content\n\n# Dataset Configurations\nwith read_base():\n    # Datasets\n    from opencompass.configs.datasets.aime2025.aime2025_llmjudge_academic import aime2025_datasets\n    from opencompass.configs.datasets.gpqa.gpqa_cascade_eval_academic import gpqa_datasets\n    from opencompass.configs.datasets.HLE.hle_llmverify_academic import hle_datasets\n    from opencompass.configs.datasets.IFEval.IFEval_gen_353ae7 import ifeval_datasets\n    from opencompass.configs.datasets.livecodebench.livecodebench_v6_academic import LCBCodeGeneration_dataset\n    from opencompass.configs.datasets.mmlu_pro.mmlu_pro_0shot_nocot_genericllmeval_gen_08c1de import mmlu_pro_datasets\n    # Summary Groups\n    from opencompass.configs.summarizers.groups.mmlu_pro import mmlu_pro_summary_groups\n\n# <dataset_replace_tag>\ndatasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])\ndatasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), []) + [LCBCodeGeneration_dataset]\n# </dataset_replace_tag>\n\nTASK_TAG = ''\nAPI_SERVER_ADDR = 'http://<API_SERVER>'\nSERVED_MODEL_PATH = ''\n\nmodels = [\n    dict(abbr=TASK_TAG,\n         key='dummy',\n         openai_api_base=f'{API_SERVER_ADDR}/v1',\n         type=OpenAISDK,\n         path=SERVED_MODEL_PATH,\n         temperature=0.6,\n         meta_template=dict(round=[\n             dict(role='HUMAN', api_role='HUMAN'),\n             dict(role='BOT', api_role='BOT', generate=True),\n         ], ),\n         query_per_second=10,\n         max_out_len=64000,\n         max_seq_len=65536,\n         batch_size=32,\n         retry=10,\n         
pred_postprocessor=dict(type=extract_non_reasoning_content),\n         verbose=False)\n]\n\nJUDGER_ADDR = 'http://<JUDGER_SERVER>'\nJUDGER_MODEL_PATH = ''\njudge_cfg = dict(\n    abbr='CompassVerifier',\n    type=OpenAISDK,\n    path=JUDGER_MODEL_PATH,\n    key='YOUR_API_KEY',\n    openai_api_base=f'{JUDGER_ADDR}/v1',\n    meta_template=dict(round=[\n        dict(role='HUMAN', api_role='HUMAN'),\n        dict(role='BOT', api_role='BOT', generate=True),\n    ]),\n    query_per_second=8,\n    batch_size=32,\n    temperature=0.001,\n    max_out_len=8192,\n    max_seq_len=65536,\n    mode='mid',\n)\n\nfor item in datasets:\n    if 'judge_cfg' in item['eval_cfg']['evaluator']:\n        item['eval_cfg']['evaluator']['judge_cfg'] = judge_cfg\n    if 'llm_evaluator' in item['eval_cfg']['evaluator'].keys(\n    ) and 'judge_cfg' in item['eval_cfg']['evaluator']['llm_evaluator']:\n        item['eval_cfg']['evaluator']['llm_evaluator']['judge_cfg'] = judge_cfg\n\n#######################################################################\n#                         Dataset Summarizer                          #\n#######################################################################\n\ncore_summary_groups = [\n    {\n        'name':\n        'core_average',\n        'subsets': [\n            ['IFEval', 'Prompt-level-strict-accuracy'],\n            ['hle_llmjudge', 'accuracy'],\n            ['aime2025_repeat_32', 'accuracy (32 runs average)'],\n            ['GPQA_diamond_repeat_4', 'accuracy (4 runs average)'],\n            ['mmlu_pro', 'naive_average'],\n            ['lcb_code_generation_repeat_6', 'pass@1 (6 runs average)'],\n        ],\n    },\n]\n\nsummarizer = dict(\n    dataset_abbrs=[\n        ['core_average', 'naive_average'],\n        '',\n        'Instruction Following',\n        ['IFEval', 'Prompt-level-strict-accuracy'],\n        '',\n        'General Reasoning',\n        ['hle_llmjudge', 'accuracy'],\n        ['GPQA_diamond_repeat_4', 'accuracy (4 runs average)'],\n      
  '',\n        'Math Calculation',\n        ['aime2025_repeat_32', 'accuracy (32 runs average)'],\n        '',\n        'Knowledge',\n        ['mmlu_pro', 'naive_average'],\n        '',\n        'Code',\n        ['lcb_code_generation_repeat_6', 'pass@1 (6 runs average)'],\n    ],\n    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),\n)\n\n#######################################################################\n#                   Inference/Evaluation Configuration                #\n#######################################################################\n\n# infer with local runner\ninfer = dict(\n    partitioner=dict(type=NumWorkerPartitioner, num_worker=8),\n    runner=dict(\n        type=LocalRunner,\n        max_num_workers=16,\n        retry=0,  # Modify if needed\n        task=dict(type=OpenICLInferTask),\n    ),\n)\n\n# eval with local runner\neval = dict(\n    partitioner=dict(type=NaivePartitioner, n=10),\n    runner=dict(type=LocalRunner, max_num_workers=16, task=dict(type=OpenICLEvalTask)),\n)\n"
  },
  {
    "path": "eval/eval.py",
    "content": "import argparse\nimport os\nimport signal\nimport subprocess\nimport sys\nfrom datetime import datetime\n\n\nclass ProcessManager:\n    \"\"\"Manager for subprocess execution with proper signal handling.\"\"\"\n\n    def __init__(self):\n        self.process = None\n        self.original_handlers = {}\n\n    def __enter__(self):\n        \"\"\"Context manager entry - setup signal handlers\"\"\"\n        # Save original signal handlers\n        self.original_handlers[signal.SIGINT] = signal.getsignal(signal.SIGINT)\n        self.original_handlers[signal.SIGTERM] = signal.getsignal(signal.SIGTERM)\n\n        # Register new signal handlers\n        signal.signal(signal.SIGINT, self._signal_handler)\n        signal.signal(signal.SIGTERM, self._signal_handler)\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        \"\"\"Context manager exit - restore original signal handlers\"\"\"\n        # Restore original signal handlers\n        for sig, handler in self.original_handlers.items():\n            signal.signal(sig, handler)\n\n    def _signal_handler(self, sig, frame):\n        \"\"\"Handle termination signals.\"\"\"\n        signal_name = 'SIGINT' if sig == signal.SIGINT else 'SIGTERM'\n        print(f'\\nReceived {signal_name}, cleaning up subprocess...')\n        self.cleanup()\n        sys.exit(0)\n\n    def start_process(self, cmd):\n        self.process = subprocess.Popen(cmd)\n        return self.process\n\n    def cleanup(self):\n        if self.process and self.process.poll() is None:\n            print('Terminating subprocess...')\n            self.process.terminate()\n            try:\n                self.process.wait(timeout=5)\n                print('Subprocess terminated successfully')\n            except subprocess.TimeoutExpired:\n                print('Subprocess did not terminate normally, forcing kill...')\n                self.process.kill()\n                self.process.wait()\n                
print('Subprocess killed')\n\n\ndef read_config():\n    \"\"\"Get configuration content from config file in script directory.\n\n    Returns:\n        str: Configuration file content, returns None if reading fails\n    \"\"\"\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    config_path = os.path.join(script_dir, 'config.py')\n\n    # Read config file content\n    try:\n        with open(config_path, 'r', encoding='utf-8') as f:\n            config_content = f.read()\n        return config_content\n    except FileNotFoundError:\n        print(f'Error: Config file not found at {config_path}')\n        return None\n    except Exception as e:\n        print(f'Error reading config file: {e}')\n        return None\n\n\ndef update_datasets(config, datasets):\n    \"\"\"Update datasets part in config according to datasets list.\n\n    Args:\n        config (str): Original configuration content\n        datasets (list[str]): List of dataset names to include\n    Returns:\n        str: Updated configuration content\n    \"\"\"\n    if 'all' in datasets:\n        # datasets part of the config file specifies all datasets, no need to update\n        return config\n\n    selected_datasets = []\n    if 'code' in datasets:\n        selected_datasets.append('[LCBCodeGeneration_dataset]')\n        datasets.remove('code')\n    for d in datasets:\n        selected_datasets.append(f'{d}_datasets')\n    selected_datasets = ' + '.join(selected_datasets)\n    selected_datasets = f'datasets = {selected_datasets}'\n\n    # replace datasets part in config\n    start_tag = '# <dataset_replace_tag>'\n    end_tag = '# </dataset_replace_tag>'\n\n    start_index = config.find(start_tag)\n    end_index = config.find(end_tag)\n\n    if start_index == -1 or end_index == -1:\n        raise ValueError('replace tag not found in config file')\n\n    end_index += len(end_tag)\n    replacement = f'{start_tag}\\n{selected_datasets}\\n{end_tag}'\n    result = config[:start_index] + 
replacement + config[end_index:]\n    return result\n\n\ndef get_model_name_from_server(server: str, tag: str) -> str:\n    from openai import OpenAI\n    try:\n        client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{server}/v1')\n        model_name = client.models.list().data[0].id\n        return model_name\n    except Exception as e:\n        raise RuntimeError(f'Failed to get model name from {tag}_server {server}: {e}')\n\n\ndef save_config(work_dir: str, config: str):\n    \"\"\"Save configuration content to a file in the specified directory.\n\n    Args:\n        work_dir (str): Directory to save the configuration file\n        config (str): Configuration content to save\n    \"\"\"\n    if not work_dir:\n        return\n    os.makedirs(work_dir, exist_ok=True)\n    output_file = os.path.join(work_dir, 'config.py')\n    with open(output_file, 'w', encoding='utf-8') as f:\n        f.write(config)\n    print(f'Config written to {output_file}')\n\n\ndef perform_evaluation(config, api_server, judger_server, mode, work_dir, reuse):\n    \"\"\"Perform model evaluation by opencompass.\n\n    Args:\n        config (str): Configuration content\n        api_server (str): API server address for inference\n        judger_server (str): Judger server address for evaluation\n        mode (str): Running mode selection, options: infer, eval, all, config\n        work_dir (str): Output directory for evaluation results. 
If not specified,\n            config will not be saved and execution will not be performed.\n        reuse (str | None): Previous results to reuse: 'latest' for the most recent run in work_dir,\n            or a specific timestamp in YYYYMMDD_HHMMSS format. None disables reuse.\n    \"\"\"\n    if mode in ['infer', 'all', 'config']:\n        served_model_name = get_model_name_from_server(api_server, 'api')\n        config = config.replace(\"SERVED_MODEL_PATH = ''\", f\"SERVED_MODEL_PATH = '{served_model_name}'\")\n    if mode in ['eval', 'all', 'config']:\n        judger_model_name = get_model_name_from_server(judger_server, 'judger')\n        config = config.replace(\"JUDGER_MODEL_PATH = ''\", f\"JUDGER_MODEL_PATH = '{judger_model_name}'\")\n\n    # write updated config to work_dir\n    if work_dir:\n        save_config(work_dir, config)\n        if mode == 'config':\n            return\n    else:\n        print(config)\n        return\n\n    # execute opencompass command\n    cmd = ['opencompass', f'{work_dir}/config.py', '-m', mode, '-w', work_dir]\n    if reuse:\n        # reuse previous outputs & results: 'latest' selects the most recent run in work_dir,\n        # any other value must be a specific timestamp\n        if reuse != 'latest':\n            try:\n                datetime.strptime(reuse, '%Y%m%d_%H%M%S')\n            except ValueError as e:\n                raise ValueError(f'Invalid reuse timestamp format: {reuse}. Expected format: YYYYMMDD_HHMMSS') from e\n        cmd.extend(['-r', reuse])
\n    try:\n        print(f'Executing command: {\" \".join(cmd)}')\n        with ProcessManager() as manager:\n            process = manager.start_process(cmd)\n            result = process.wait()\n            return subprocess.CompletedProcess(cmd, result)\n    except Exception as e:\n        print(f'Executing command failed with {e}')\n        return\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Perform model evaluation')\n    parser.add_argument('task_name', type=str, help='The name of an evaluation task')\n    parser.add_argument('-a', '--api-server', type=str, default='', help='API server address for inference')\n    parser.add_argument('-j', '--judger-server', type=str, default='', help='Judger server address for evaluation')\n    dataset_choices = ['aime2025', 'gpqa', 'ifeval', 'code', 'mmlu_pro', 'hle', 'all']\n    parser.add_argument('-d',\n                        '--datasets',\n                        nargs='+',\n                        choices=dataset_choices,\n                        default=['all'],\n                        help=f\"List of datasets. Available options: {', '.join(dataset_choices)}. \"\n                        'Use \"all\" to include all datasets.')\n    parser.add_argument('-w',\n                        '--work-dir',\n                        type=str,\n                        default='',\n                        help='Output directory of evaluation. If not specified, outputs will not be saved.')\n    parser.add_argument('-r',\n                        '--reuse',\n                        nargs='?',\n                        type=str,\n                        const='latest',\n                        help='Reuse previous outputs & results, and run any missing jobs present in the config. '
\n                        'If its argument is not specified, the latest results in the work_dir will be reused. '\n                        'The argument can also be a specific timestamp, e.g. 20230516_144254')\n    parser.add_argument('-m',\n                        '--mode',\n                        type=str,\n                        help='Running mode selection. '\n                        'all: complete pipeline including both inference and evaluation (default). '\n                        'infer: only perform model inference to generate results. '\n                        'eval: only evaluate previously generated results. '\n                        'config: generate configuration files without execution.',\n                        choices=['all', 'infer', 'eval', 'config'],\n                        default='all')\n    args = parser.parse_args()\n    task_name = args.task_name\n    api_server = args.api_server\n    judger_server = args.judger_server\n    datasets = args.datasets\n    mode = args.mode\n    work_dir = args.work_dir\n\n    # Process server addresses\n    if api_server and not api_server.startswith('http'):\n        api_server = f'http://{api_server}'\n    if judger_server and not judger_server.startswith('http'):\n        judger_server = f'http://{judger_server}'\n\n    # read config file; read_config() already reports the error and returns None on failure\n    config = read_config()\n    if config is None:\n        sys.exit(1)\n\n    # update task name in config\n    config = config.replace(\"TASK_TAG = ''\", f\"TASK_TAG = '{task_name}'\")\n\n    # update datasets part of config according to args.datasets\n    config = update_datasets(config, datasets)\n\n    # update api_server part of config according to args.api_server\n    if api_server:\n        config = config.replace(\"API_SERVER_ADDR = 'http://<API_SERVER>'\", f\"API_SERVER_ADDR = '{api_server}'\")\n    if judger_server:\n        # update judger_server part of config according to args.judger_server\n        config = config.replace(\"JUDGER_ADDR = 'http://<JUDGER_SERVER>'\", f\"JUDGER_ADDR = '{judger_server}'\")
\n\n    # perform evaluation\n    perform_evaluation(config, api_server, judger_server, mode, work_dir, args.reuse)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/lite/qwen3_30b_a3b_awq.py",
    "content": "import argparse\n\nfrom datasets import load_dataset\nfrom llmcompressor import oneshot\nfrom llmcompressor.modifiers.awq import AWQModifier\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Run AWQ quantization for Qwen3 model')\n\n    parser.add_argument('--work-dir',\n                        type=str,\n                        required=True,\n                        help='The directory to save the quantized model')\n\n    parser.add_argument('--model-id',\n                        type=str,\n                        default='Qwen/Qwen3-30B-A3B',\n                        help='The Hugging Face model ID to quantize')\n    return parser.parse_args()\n\n\ndef main():\n    # 1. Parse command-line args\n    args = parse_args()\n    MODEL_ID = args.model_id\n    SAVE_DIR = args.work_dir\n\n    print(f'Loading model: {MODEL_ID}')\n    print(f'Saving to: {SAVE_DIR}')\n\n    # 2. Load model and tokenizer\n    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype='auto', device_map='auto', trust_remote_code=True)\n    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n\n    # 3. 
Prepare calibration dataset\n    DATASET_ID = 'neuralmagic/calibration'\n    DATASET_SPLIT = 'train'\n    NUM_CALIBRATION_SAMPLES = 256\n    MAX_SEQUENCE_LENGTH = 512\n\n    def get_calib_dataset(tokenizer):\n        ds = load_dataset(\n            DATASET_ID,\n            'LLM',\n            split=f'{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]',\n        )\n\n        def preprocess(example):\n            messages = []\n            for message in example['messages']:\n                if message['role'] == 'user':\n                    messages.append({'role': 'user', 'content': message['content']})\n                elif message['role'] == 'assistant':\n                    messages.append({'role': 'assistant', 'content': message['content']})\n\n            return tokenizer(\n                tokenizer.apply_chat_template(\n                    messages,\n                    tokenize=False,\n                ),\n                padding=False,\n                max_length=MAX_SEQUENCE_LENGTH,\n                truncation=True,\n                add_special_tokens=False,\n            )\n\n        ds = (ds.shuffle(seed=42).map(preprocess,\n                                      remove_columns=ds.column_names).select(range(NUM_CALIBRATION_SAMPLES)))\n        return ds\n\n    # 4. Configure quant args (W4A16_ASYM AWQ)\n    recipe = [\n        AWQModifier(\n            ignore=['lm_head', 're:.*mlp.gate$'],\n            scheme='W4A16_ASYM',\n            targets=['Linear'],\n            duo_scaling='both',\n        ),\n    ]\n\n    # 5. Run quantization\n    print('Starting quantization...')\n    oneshot(\n        model=model,\n        dataset=get_calib_dataset(tokenizer),\n        recipe=recipe,\n        max_seq_length=MAX_SEQUENCE_LENGTH,\n        num_calibration_samples=NUM_CALIBRATION_SAMPLES,\n        log_dir=None,\n    )\n\n    # 6. 
Save quantized model\n    print('Saving model...')\n    model.save_pretrained(SAVE_DIR)\n    tokenizer.save_pretrained(SAVE_DIR)\n    print(f'Successfully saved to {SAVE_DIR}')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/lite/qwen3_30b_a3b_gptq.py",
    "content": "import argparse\n\nfrom datasets import load_dataset\nfrom llmcompressor import oneshot\nfrom llmcompressor.modifiers.quantization import GPTQModifier\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Run GPTQ quantization for Qwen3 model')\n\n    parser.add_argument('--work-dir',\n                        type=str,\n                        required=True,\n                        help='The directory to save the quantized model')\n\n    parser.add_argument('--model-id',\n                        type=str,\n                        default='Qwen/Qwen3-30B-A3B',\n                        help='The Hugging Face model ID to quantize')\n    return parser.parse_args()\n\n\ndef main():\n    # 1. Parse command-line args\n    args = parse_args()\n    MODEL_ID = args.model_id\n    SAVE_DIR = args.work_dir\n\n    print(f'Loading model: {MODEL_ID}')\n    print(f'Saving to: {SAVE_DIR}')\n\n    # 2. Load model and tokenizer\n    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype='auto', device_map='auto', trust_remote_code=True)\n    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n\n    # 3. 
Prepare calibration dataset\n    DATASET_ID = 'neuralmagic/calibration'\n    DATASET_SPLIT = 'train'\n    NUM_CALIBRATION_SAMPLES = 256\n    MAX_SEQUENCE_LENGTH = 512\n\n    def get_calib_dataset(tokenizer):\n        ds = load_dataset(\n            DATASET_ID,\n            'LLM',\n            split=f'{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]',\n        )\n\n        def preprocess(example):\n            messages = []\n            for message in example['messages']:\n                if message['role'] == 'user':\n                    messages.append({'role': 'user', 'content': message['content']})\n                elif message['role'] == 'assistant':\n                    messages.append({'role': 'assistant', 'content': message['content']})\n\n            return tokenizer(\n                tokenizer.apply_chat_template(\n                    messages,\n                    tokenize=False,\n                ),\n                padding=False,\n                max_length=MAX_SEQUENCE_LENGTH,\n                truncation=True,\n                add_special_tokens=False,\n            )\n\n        ds = (ds.shuffle(seed=42).map(preprocess,\n                                      remove_columns=ds.column_names).select(range(NUM_CALIBRATION_SAMPLES)))\n        return ds\n\n    # 4. Configure quant args (W4A16_ASYM GPTQ)\n    recipe = [\n        GPTQModifier(targets='Linear', scheme='W4A16_ASYM', ignore=['lm_head', 're:.*mlp.gate$']),\n    ]\n\n    # 5. Run quantization\n    print('Starting quantization...')\n    oneshot(\n        model=model,\n        dataset=get_calib_dataset(tokenizer),\n        recipe=recipe,\n        max_seq_length=MAX_SEQUENCE_LENGTH,\n        num_calibration_samples=NUM_CALIBRATION_SAMPLES,\n        log_dir=None,\n    )\n\n    # 6. Save quantized model\n    print('Saving model...')\n    model.save_pretrained(SAVE_DIR)\n    tokenizer.save_pretrained(SAVE_DIR)\n    print(f'Successfully saved to {SAVE_DIR}')\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "generate.sh",
    "content": "#!/bin/bash\nWORKSPACE_PATH=$(dirname \"$(readlink -f \"$0\")\")\n\nbuilder=\"-G Ninja\"\n\nif [ \"$1\" == \"make\" ]; then\n    builder=\"\"\nfi\n\ncmake ${builder} .. \\\n    -DCMAKE_BUILD_TYPE=RelWithDebInfo \\\n    -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \\\n    -DCMAKE_INSTALL_PREFIX=${WORKSPACE_PATH}/install \\\n    -DBUILD_PY_FFI=ON \\\n    -DBUILD_MULTI_GPU=ON \\\n    -DCMAKE_CUDA_FLAGS=\"-lineinfo\" \\\n    -DUSE_NVTX=ON \\\n    -DFETCHCONTENT_QUIET=OFF\n"
  },
  {
    "path": "k8s/deployment.yaml",
    "content": "apiVersion: apps/v1\nkind: Deployment\nmetadata:\n  labels:\n    app: internlm2-chat-7b\n  name: internlm2-chat-7b\nspec:\n  replicas: 1\n  selector:\n    matchLabels:\n      app: internlm2-chat-7b\n  strategy: {}\n  template:\n    metadata:\n      labels:\n        app: internlm2-chat-7b\n    spec:\n      containers:\n      - name: internlm2-chat-7b\n        image: openmmlab/lmdeploy:latest\n        command:\n        - /bin/sh\n        - -c\n        args:\n        - \"lmdeploy serve api_server internlm/internlm2-chat-7b --server-port 23333\"\n        env:\n        - name: HUGGING_FACE_HUB_TOKEN\n          value: \"{{HUGGING_FACE_HUB_TOKEN}}\"\n        ports:\n        - containerPort: 23333\n          protocol: TCP\n          name: main\n        resources:\n          limits:\n            cpu: \"16\"\n            memory: 64Gi\n            nvidia.com/gpu: \"1\"\n          requests:\n            cpu: \"16\"\n            memory: 64Gi\n            nvidia.com/gpu: \"1\"\n        readinessProbe:\n          failureThreshold: 3\n          initialDelaySeconds: 400\n          periodSeconds: 10\n          successThreshold: 1\n          tcpSocket:\n            port: main\n          timeoutSeconds: 1\n        livenessProbe:\n          failureThreshold: 3\n          initialDelaySeconds: 900\n          periodSeconds: 20\n          successThreshold: 1\n          tcpSocket:\n            port: main\n          timeoutSeconds: 1\n        volumeMounts:\n        - mountPath: /root/.cache/huggingface\n          name: model-data\n        - mountPath: /dev/shm\n          name: dshm\n      volumes:\n      - name: model-data\n        hostPath:\n          path: /root/.cache/huggingface\n          type: DirectoryOrCreate\n      - emptyDir:\n          medium: Memory\n        name: dshm\n"
  },
  {
    "path": "k8s/service.yaml",
    "content": "apiVersion: v1\nkind: Service\nmetadata:\n  labels:\n    app: internlm2-chat-7b\n  name: internlm2-chat-7b-svc\nspec:\n  ports:\n  - name: main\n    port: 23333\n    protocol: TCP\n    targetPort: main\n  selector:\n    app: internlm2-chat-7b\n  type: ClusterIP\n"
  },
  {
    "path": "lmdeploy/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .api import client, pipeline, serve\nfrom .messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, VisionConfig\nfrom .model import ChatTemplateConfig\nfrom .pipeline import Pipeline\nfrom .tokenizer import Tokenizer\nfrom .version import __version__, version_info\n\n__all__ = [\n    'pipeline', 'serve', 'client', 'Tokenizer', 'GenerationConfig', '__version__', 'version_info', 'ChatTemplateConfig',\n    'PytorchEngineConfig', 'TurbomindEngineConfig', 'VisionConfig', 'Pipeline'\n]\n"
  },
  {
    "path": "lmdeploy/__main__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .cli import run\n\nif __name__ == '__main__':\n    run()\n"
  },
  {
    "path": "lmdeploy/api.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, List, Literal\n\nfrom typing_extensions import deprecated\n\nfrom .pipeline import Pipeline\n\nif TYPE_CHECKING:\n    from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig\n    from .model import ChatTemplateConfig\n\n\ndef pipeline(model_path: str,\n             backend_config: 'TurbomindEngineConfig' | 'PytorchEngineConfig' | None = None,\n             chat_template_config: 'ChatTemplateConfig' | None = None,\n             log_level: str = 'WARNING',\n             max_log_len: int | None = None,\n             speculative_config: 'SpeculativeConfig' | None = None,\n             **kwargs):\n    \"\"\"\n    Args:\n        model_path: the path of a model. It could be one of the following options:\n\n            - i) A local directory path of a turbomind model which is\n              converted by the ``lmdeploy convert`` command or downloaded from\n              ii) and iii).\n            - ii) The model_id of an lmdeploy-quantized model hosted\n              inside a model repo on huggingface.co, such as\n              ``InternLM/internlm-chat-20b-4bit``,\n              ``lmdeploy/llama2-chat-70b-4bit``, etc.\n            - iii) The model_id of a model hosted inside a model repo\n              on huggingface.co, such as ``internlm/internlm-chat-7b``,\n              ``Qwen/Qwen-7B-Chat``, ``baichuan-inc/Baichuan2-7B-Chat``\n              and so on.\n        backend_config: backend config instance, either ``TurbomindEngineConfig``\n            or ``PytorchEngineConfig``. Defaults to None.\n        chat_template_config: chat template configuration.\n            Defaults to None.\n        log_level: log level, one of [``CRITICAL``, ``ERROR``,\n            ``WARNING``, ``INFO``, ``DEBUG``]. Defaults to ``WARNING``.\n        max_log_len: max number of prompt characters or prompt tokens\n            printed in the log. Defaults to None.\n        speculative_config: speculative decoding configuration.\n            Defaults to None.\n\n    Examples:\n\n        .. 
code-block:: python\n\n            # LLM\n            import lmdeploy\n            pipe = lmdeploy.pipeline('internlm/internlm-chat-7b')\n            response = pipe(['hi','say this is a test'])\n            print(response)\n\n            # VLM\n            from lmdeploy.vl import load_image\n            from lmdeploy import pipeline, TurbomindEngineConfig, ChatTemplateConfig\n            pipe = pipeline('liuhaotian/llava-v1.5-7b',\n                            backend_config=TurbomindEngineConfig(session_len=8192),\n                            chat_template_config=ChatTemplateConfig(model_name='vicuna'))\n            im = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')\n            response = pipe([('describe this image', [im])])\n            print(response)\n\n    \"\"\" # noqa E501\n\n    return Pipeline(model_path,\n                    backend_config=backend_config,\n                    chat_template_config=chat_template_config,\n                    log_level=log_level,\n                    max_log_len=max_log_len,\n                    speculative_config=speculative_config,\n                    **kwargs)\n\n\n@deprecated('This function is no longer available. Please use CLI command \"lmdeploy serve api_server\" instead.')\ndef serve(model_path: str,\n          model_name: str | None = None,\n          backend: Literal['turbomind', 'pytorch'] = 'turbomind',\n          backend_config: 'TurbomindEngineConfig' | 'PytorchEngineConfig' | None = None,\n          chat_template_config: 'ChatTemplateConfig' | None = None,\n          server_name: str = '0.0.0.0',\n          server_port: int = 23333,\n          log_level: str = 'ERROR',\n          api_keys: List[str] | str | None = None,\n          ssl: bool = False,\n          **kwargs):\n    \"\"\"This function is deprecated and no longer available.\n\n    .. deprecated::\n        This function has been removed. 
Please use the ``lmdeploy serve api_server`` CLI command instead.\n    \"\"\" # noqa E501\n    raise NotImplementedError(\"The 'serve' function is no longer available. \"\n                              'This function has been deprecated and removed.')\n\n\n@deprecated('This function is no longer available. Please use \"from lmdeploy.serve import APIClient\" instead.')\ndef client(api_server_url: str = 'http://0.0.0.0:23333', api_key: str | None = None, **kwargs):\n    \"\"\"This function is deprecated and no longer available.\n\n    .. deprecated::\n        This function has been removed. Please use ``from lmdeploy.serve import APIClient`` instead.\n\n    Args:\n        api_server_url: address ``http://<ip>:<port>`` of the\n            api_server\n        api_key: api key. Defaults to None, which means no\n            api key will be used.\n    Raises:\n        NotImplementedError: always, since this function has been removed.\n    \"\"\"\n    raise NotImplementedError(\"The 'client' function is no longer available. This function has been deprecated. \"\n                              'Please use \"from lmdeploy.serve import APIClient\" instead.')\n"
  },
  {
    "path": "lmdeploy/archs.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nfrom typing import Dict, List, Literal, Tuple\n\nfrom transformers import AutoConfig\n\nfrom .messages import PytorchEngineConfig, TurbomindEngineConfig\nfrom .utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\ndef autoget_backend(model_path: str) -> Literal['turbomind', 'pytorch']:\n    \"\"\"Get backend type in auto backend mode.\n\n    Args:\n        model_path (str): the path of a model.\n            It could be one of the following options:\n                - i) A local directory path of a turbomind model which is\n                    converted by the `lmdeploy convert` command or downloaded from\n                    ii) and iii).\n                - ii) The model_id of an lmdeploy-quantized model hosted\n                    inside a model repo on huggingface.co, such as\n                    \"InternLM/internlm-chat-20b-4bit\",\n                    \"lmdeploy/llama2-chat-70b-4bit\", etc.\n                - iii) The model_id of a model hosted inside a model repo\n                    on huggingface.co, such as \"internlm/internlm-chat-7b\",\n                    \"Qwen/Qwen-7B-Chat\", \"baichuan-inc/Baichuan2-7B-Chat\"\n                    and so on.\n\n    Returns:\n        str: the backend type, either 'turbomind' or 'pytorch'.\n    \"\"\"\n\n    turbomind_has = False\n    is_turbomind_installed = True\n    try:\n        from lmdeploy.turbomind.supported_models import is_supported as is_supported_turbomind\n        turbomind_has = is_supported_turbomind(model_path)\n    except ImportError:\n        is_turbomind_installed = False\n\n    if is_turbomind_installed:\n        if not turbomind_has:\n            logger.warning('Fallback to pytorch engine because '\n                           f'{model_path!r} is not supported by turbomind'\n                           ' engine.')\n    else:\n        logger.warning('Fallback to pytorch engine because turbomind engine is not '\n                       'installed correctly. 
If you insist to use turbomind engine, '\n                       'you may need to reinstall lmdeploy from pypi or build from '\n                       'source and try again.')\n\n    backend = 'turbomind' if turbomind_has else 'pytorch'\n    return backend\n\n\ndef autoget_backend_config(\n    model_path: str,\n    backend_config: PytorchEngineConfig | TurbomindEngineConfig | None = None\n) -> Tuple[Literal['turbomind', 'pytorch'], PytorchEngineConfig | TurbomindEngineConfig]:\n    \"\"\"Get backend config automatically.\n\n    Args:\n        model_path (str): The input model path.\n        backend_config (TurbomindEngineConfig | PytorchEngineConfig): The\n            input backend config. Default to None.\n\n    Returns:\n        (PytorchEngineConfig | TurbomindEngineConfig): The auto-determined\n            backend engine config.\n    \"\"\"\n    from dataclasses import asdict\n\n    if isinstance(backend_config, PytorchEngineConfig):\n        return 'pytorch', backend_config\n\n    backend = autoget_backend(model_path)\n    config = PytorchEngineConfig() if backend == 'pytorch' else TurbomindEngineConfig()\n    if backend_config is not None:\n        if type(backend_config) == type(config):\n            config = backend_config\n        else:\n            data = asdict(backend_config)\n            for k, v in data.items():\n                if v and hasattr(config, k):\n                    setattr(config, k, v)\n            # map attributes with different names\n            if type(backend_config) is TurbomindEngineConfig:\n                config.block_size = backend_config.cache_block_seq_len\n            else:\n                config.cache_block_seq_len = backend_config.block_size\n    return backend, config\n\n\ndef check_vl_llm(backend: str, config: dict) -> bool:\n    \"\"\"Check if the model is a vl model from model config.\"\"\"\n    if 'auto_map' in config:\n        for _, v in config['auto_map'].items():\n            if 'InternLMXComposer2ForCausalLM' in 
v:\n                return True\n\n    if 'language_config' in config and 'vision_config' in config and config['language_config'].get(\n            'architectures', [None])[0] == 'DeepseekV2ForCausalLM':\n        return True\n\n    arch = config['architectures'][0]\n    supported_archs = set([\n        'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM', 'CogVLMForCausalLM', 'InternLMXComposer2ForCausalLM',\n        'InternVLChatModel', 'MiniCPMV', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration',\n        'Phi3VForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration',\n        'Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration', 'Qwen3_5ForConditionalGeneration',\n        'Qwen3_5MoeForConditionalGeneration', 'MllamaForConditionalGeneration', 'MolmoForCausalLM',\n        'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', 'InternVLForConditionalGeneration',\n        'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration',\n        'InternS1_1_ForConditionalGeneration', 'Glm4vForConditionalGeneration'\n    ])\n    if arch == 'QWenLMHeadModel' and 'visual' in config:\n        return True\n    elif arch == 'MultiModalityCausalLM' and 'language_config' in config:\n        return True\n    elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] and 'vision_config' in config:\n        return True\n    elif arch in ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration'] and backend == 'turbomind':\n        return False\n    elif arch in supported_archs:\n        return True\n    return False\n\n\ndef get_task(backend: str, model_path: str):\n    \"\"\"Get pipeline type and pipeline class from model config.\"\"\"\n    from lmdeploy.serve.core import AsyncEngine\n\n    if os.path.exists(os.path.join(model_path, 'triton_models', 'weights')):\n        # workspace model\n        return 'llm', AsyncEngine\n    _, config = get_model_arch(model_path)\n 
   if check_vl_llm(backend, config.to_dict()):\n        from lmdeploy.serve.core import VLAsyncEngine\n        return 'vlm', VLAsyncEngine\n\n    # default task, pipeline_class\n    return 'llm', AsyncEngine\n\n\ndef get_model_arch(model_path: str):\n    \"\"\"Get a model's architecture and configuration.\n\n    Args:\n        model_path(str): the model path\n    \"\"\"\n    try:\n        cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n    except Exception as e:  # noqa\n        from transformers import PretrainedConfig\n        cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)\n\n    _cfg = cfg.to_dict()\n    if _cfg.get('architectures', None):\n        arch = _cfg['architectures'][0]\n        if _cfg.get('auto_map'):\n            for _, v in _cfg['auto_map'].items():\n                if 'InternLMXComposer2ForCausalLM' in v:\n                    arch = 'InternLMXComposer2ForCausalLM'\n    elif _cfg.get('auto_map', None) and 'AutoModelForCausalLM' in _cfg['auto_map']:\n        arch = _cfg['auto_map']['AutoModelForCausalLM'].split('.')[-1]\n    elif _cfg.get('language_config', None) and _cfg['language_config'].get(\n            'auto_map', None) and 'AutoModelForCausalLM' in _cfg['language_config']['auto_map']:\n        arch = _cfg['language_config']['auto_map']['AutoModelForCausalLM'].split('.')[-1]\n    else:\n        raise RuntimeError(f'Could not find model architecture from config: {_cfg}')\n    return arch, cfg\n\n\ndef search_nested_config(config, key):\n    \"\"\"Recursively searches for the value associated with the given key in a\n    nested configuration of a model.\"\"\"\n    if isinstance(config, Dict):\n        for k, v in config.items():\n            if k == key:\n                return v\n            if isinstance(v, (Dict, List)):\n                result = search_nested_config(v, key)\n                if result is not None:\n                    return result\n    elif isinstance(config, List):\n        
for item in config:\n            result = search_nested_config(item, key)\n            if result is not None:\n                return result\n    return None\n"
  },
  {
    "path": "lmdeploy/cli/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .entrypoint import run\n\n__all__ = ['run']\n"
  },
  {
    "path": "lmdeploy/cli/chat.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom contextlib import closing\n\nimport fire\n\nfrom lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline\nfrom lmdeploy.archs import autoget_backend\n\n\ndef input_prompt():\n    \"\"\"Input a prompt in the console interface.\"\"\"\n    print('\\ndouble enter to end input >>> ', end='')\n    sentinel = ''  # ends when this string is seen\n    return '\\n'.join(iter(input, sentinel))\n\n\ndef build_pipe(model_path, backend, **kwargs):\n    engine_config = None\n    if kwargs.get('enable_prefix_caching', False):\n        print('interactive chat cannot be used when prefix caching is enabled')\n        exit(-1)\n    if backend == 'turbomind':\n        engine_config = TurbomindEngineConfig()\n        for key, value in kwargs.items():\n            if hasattr(TurbomindEngineConfig, key):\n                setattr(engine_config, key, value)\n    else:\n        engine_config = PytorchEngineConfig()\n        for key, value in kwargs.items():\n            key = 'device_type' if key == 'device' else key\n            if hasattr(PytorchEngineConfig, key):\n                setattr(engine_config, key, value)\n        if kwargs.get('adapters', None):\n            from .utils import get_lora_adapters\n            adapters = get_lora_adapters(kwargs['adapters'])\n            engine_config.adapters = adapters\n    # disable metrics to avoid installing prometheus_client, which is not needed\n    # in interactive chat\n    engine_config.enable_metrics = False\n\n    # set chat template config\n    chat_template = kwargs.get('chat_template', None)\n    chat_template_config = None\n    if chat_template:\n        from .utils import get_chat_template\n        chat_template_config = get_chat_template(chat_template, model_path)\n    pipe = pipeline(model_path,\n                    backend_config=engine_config,\n                    chat_template_config=chat_template_config,\n                    
log_level='ERROR',\n                    **kwargs)\n    return pipe\n\n\ndef build_gen_config(**kwargs):\n    gen_config = GenerationConfig(do_sample=True, max_new_tokens=4096)\n    for key, value in kwargs.items():\n        if hasattr(GenerationConfig, key):\n            setattr(gen_config, key, value)\n    return gen_config\n\n\ndef get_adapter_name(adapters=None, **kwargs):\n    if adapters is None:\n        return None\n    from .utils import get_lora_adapters\n    adapters = get_lora_adapters(adapters)\n    return list(adapters.keys())[0]\n\n\ndef main(model_path, backend, **kwargs):\n    if backend != 'pytorch':\n        # set auto backend mode\n        backend = autoget_backend(model_path)\n    quit = False\n    with build_pipe(model_path, backend, **kwargs) as pipe:\n        gen_config = build_gen_config(**kwargs)\n        adapter_name = get_adapter_name(**kwargs)\n        while not quit:\n            with closing(pipe.session()) as sess:\n                while True:\n                    try:\n                        prompt = input_prompt()\n                    except KeyboardInterrupt:\n                        quit = True\n                        break\n                    if prompt == 'end':\n                        sess.close()\n                        break\n                    if prompt == 'exit':\n                        quit = True\n                        break\n                    if prompt.strip() == '':\n                        continue\n                    resps = pipe.chat(prompt,\n                                      session=sess,\n                                      gen_config=gen_config,\n                                      adapter_name=adapter_name,\n                                      stream_response=True)\n                    try:\n                        for resp in resps:\n                            print(resp.text, end='', flush=True)\n                    except KeyboardInterrupt:\n                        sess.abort()\n        
else:\n            print('exiting...')\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n"
  },
  {
    "path": "lmdeploy/cli/cli.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport os\n\nfrom ..version import __version__\nfrom .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args,\n                    get_speculative_config)\n\n\nclass CLI(object):\n    _desc = 'The CLI provides a unified API for converting, ' \\\n            'compressing and deploying large language models.'\n    parser = FlexibleArgumentParser(prog='lmdeploy', description=_desc, add_help=True)\n    parser.add_argument('-v', '--version', action='version', version=__version__)\n    subparsers = parser.add_subparsers(title='Commands', description='lmdeploy has the following commands:', dest='command')\n\n    @staticmethod\n    def add_parser_chat():\n        \"\"\"Add parser for chat command.\"\"\"\n        parser = CLI.subparsers.add_parser('chat',\n                                           formatter_class=DefaultsAndTypesHelpFormatter,\n                                           description=CLI.chat.__doc__,\n                                           help=CLI.chat.__doc__)\n        parser.set_defaults(run=CLI.chat)\n        parser.add_argument('model_path',\n                            type=str,\n                            help='The path of a model. It could be one of the following '\n                            'options: - i) a local directory path of a turbomind model'\n                            ' which is converted by `lmdeploy convert` command or '\n                            'downloaded from ii) and iii). - ii) the model_id of a '\n                            'lmdeploy-quantized model hosted inside a model repo on '\n                            'huggingface.co, such as \"internlm/internlm-chat-20b-4bit\",'\n                            ' \"lmdeploy/llama2-chat-70b-4bit\", etc. 
- iii) the model_id'\n                            ' of a model hosted inside a model repo on huggingface.co,'\n                            ' such as \"internlm/internlm-chat-7b\", \"qwen/qwen-7b-chat\"'\n                            ', \"baichuan-inc/baichuan2-7b-chat\" and so on')\n        # common args\n        ArgumentHelper.backend(parser)\n        # chat template args\n        ArgumentHelper.chat_template(parser)\n        # model args\n        ArgumentHelper.revision(parser)\n        ArgumentHelper.download_dir(parser)\n\n        # pytorch engine args\n        pt_group = parser.add_argument_group('PyTorch engine arguments')\n        ArgumentHelper.adapters(pt_group)\n        ArgumentHelper.device(pt_group)\n        ArgumentHelper.eager_mode(pt_group)\n        ArgumentHelper.dllm_block_length(pt_group)\n        # common engine args\n        dtype_act = ArgumentHelper.dtype(pt_group)\n        tp_act = ArgumentHelper.tp(pt_group)\n        session_len_act = ArgumentHelper.session_len(pt_group)\n        cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)\n        prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)\n        quant_policy = ArgumentHelper.quant_policy(pt_group)\n\n        # turbomind args\n        tb_group = parser.add_argument_group('TurboMind engine arguments')\n        # common engine args\n        tb_group._group_actions.append(dtype_act)\n        tb_group._group_actions.append(tp_act)\n        tb_group._group_actions.append(session_len_act)\n        tb_group._group_actions.append(cache_max_entry_act)\n        tb_group._group_actions.append(prefix_caching_act)\n        tb_group._group_actions.append(quant_policy)\n        ArgumentHelper.model_format(tb_group)\n        ArgumentHelper.rope_scaling_factor(tb_group)\n        ArgumentHelper.communicator(tb_group)\n        ArgumentHelper.cp(tb_group)\n        ArgumentHelper.async_(tb_group)\n\n        # speculative decoding\n        
ArgumentHelper.add_spec_group(parser)\n\n    @staticmethod\n    def add_parser_checkenv():\n        \"\"\"Add parser for check_env command.\"\"\"\n        parser = CLI.subparsers.add_parser('check_env',\n                                           formatter_class=DefaultsAndTypesHelpFormatter,\n                                           description=CLI.check_env.__doc__,\n                                           help=CLI.check_env.__doc__)\n        parser.set_defaults(run=CLI.check_env)\n        parser.add_argument('--dump-file',\n                            type=str,\n                            default=None,\n                            help='The file path to save env info. Only '\n                            'support file format in `json`, `yml`,'\n                            ' `pkl`')\n\n    @staticmethod\n    def check_env(args):\n        \"\"\"Check the environmental information.\"\"\"\n        import importlib\n\n        import mmengine\n        from mmengine.utils import get_git_hash\n        from mmengine.utils.dl_utils import collect_env\n\n        from lmdeploy.version import __version__\n\n        env_info = collect_env()\n        env_info['LMDeploy'] = __version__ + '+' + get_git_hash()[:7]\n\n        # remove some unnecessary info\n        remove_reqs = ['MMEngine', 'OpenCV']\n        for req in remove_reqs:\n            if req in env_info:\n                env_info.pop(req)\n\n        # extra important dependencies\n        extra_reqs = ['transformers', 'fastapi', 'pydantic', 'triton']\n\n        for req in extra_reqs:\n            try:\n                env_info[req] = importlib.import_module(req).__version__\n            except Exception:\n                env_info[req] = 'Not Found'\n\n        def get_gpu_topo():\n            import subprocess\n            import sys\n            if sys.platform.startswith('linux'):\n                try:\n                    res = subprocess.run(['nvidia-smi', 'topo', '-m'],\n                                        
 stdout=subprocess.PIPE,\n                                         stderr=subprocess.PIPE,\n                                         text=True,\n                                         check=True)\n                    if res.returncode == 0:\n                        return '\\n' + res.stdout\n                    else:\n                        return None\n                except (FileNotFoundError, subprocess.CalledProcessError):\n                    return None\n            else:\n                return None\n\n        gpu_topo = get_gpu_topo()\n        if gpu_topo is not None:\n            env_info['NVIDIA Topology'] = gpu_topo\n\n        # print env info\n        for k, v in env_info.items():\n            print(f'{k}: {v}')\n\n        # dump to local file\n        dump_file = args.dump_file\n        if dump_file is not None:\n            work_dir, _ = os.path.split(dump_file)\n            if work_dir:\n                os.makedirs(work_dir, exist_ok=True)\n            mmengine.dump(env_info, dump_file)\n\n    @staticmethod\n    def chat(args):\n        \"\"\"Chat with an LLM in the terminal.\"\"\"\n        from .chat import main\n\n        kwargs = convert_args(args)\n        speculative_config = get_speculative_config(args)\n        to_remove = ['speculative_algorithm', 'speculative_draft_model', 'speculative_num_draft_tokens']\n        for key in to_remove:\n            kwargs.pop(key)\n        kwargs['speculative_config'] = speculative_config\n        main(**kwargs)\n\n    @staticmethod\n    def add_parsers():\n        \"\"\"Add all parsers.\"\"\"\n        CLI.add_parser_checkenv()\n        CLI.add_parser_chat()\n"
  },
  {
    "path": "lmdeploy/cli/entrypoint.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nimport sys\n\nfrom .cli import CLI\nfrom .lite import SubCliLite\nfrom .serve import SubCliServe\n\n\ndef run():\n    \"\"\"The entry point of running LMDeploy CLI.\"\"\"\n    args = sys.argv[1:]\n    CLI.add_parsers()\n    SubCliServe.add_parsers()\n    SubCliLite.add_parsers()\n    parser = CLI.parser\n    args = parser.parse_args()\n\n    if hasattr(args, 'model_name'):\n        # if `model_name` is not specified, use the model_path instead. The\n        # 'model_path' could be a local path, or a repo id from hub\n        args.model_name = args.model_name if args.model_name else \\\n            args.model_path\n\n    if 'run' in dir(args):\n        from lmdeploy.utils import get_model\n        model_path = getattr(args, 'model_path', None)\n        revision = getattr(args, 'revision', None)\n        download_dir = getattr(args, 'download_dir', None)\n        if model_path is not None and not os.path.exists(args.model_path):\n            args.model_path = get_model(args.model_path, download_dir=download_dir, revision=revision)\n        model_path_or_server = getattr(args, 'model_path_or_server', None)\n        if model_path_or_server is not None and (':' not in model_path_or_server\n                                                 and not os.path.exists(model_path_or_server)):\n            args.model_path_or_server = get_model(args.model_path_or_server,\n                                                  download_dir=download_dir,\n                                                  revision=revision)\n\n        args.run(args)\n    else:\n        try:\n            args.print_help()\n        except AttributeError:\n            command = args.command\n            if command == 'serve':\n                SubCliServe.parser.print_help()\n            elif command == 'lite':\n                SubCliLite.parser.print_help()\n            else:\n                parser.print_help()\n"
  },
  {
    "path": "lmdeploy/cli/lite.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .cli import CLI\nfrom .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args\n\n\nclass SubCliLite(object):\n    \"\"\"CLI for compressing LLMs.\"\"\"\n    _help = 'Compressing and accelerating LLMs with lmdeploy.lite module'\n    _desc = _help\n    parser = CLI.subparsers.add_parser(\n        'lite',\n        help=_help,\n        description=_desc,\n    )\n    subparsers = parser.add_subparsers(title='Commands', description='This group has the following commands:')\n\n    @staticmethod\n    def add_parser_auto_awq():\n        \"\"\"Add parser for auto_awq command.\"\"\"\n        parser = SubCliLite.subparsers.add_parser('auto_awq',\n                                                  formatter_class=DefaultsAndTypesHelpFormatter,\n                                                  description=SubCliLite.auto_awq.__doc__,\n                                                  help=SubCliLite.auto_awq.__doc__)\n        parser.set_defaults(run=SubCliLite.auto_awq)\n        parser.add_argument('model', type=str, help='The path of model in hf format')\n        ArgumentHelper.revision(parser)\n        ArgumentHelper.download_dir(parser)\n        ArgumentHelper.work_dir(parser)\n        ArgumentHelper.calib_dataset(parser)\n        ArgumentHelper.calib_samples(parser)\n        ArgumentHelper.calib_seqlen(parser)\n        ArgumentHelper.calib_batchsize(parser)\n        ArgumentHelper.calib_search_scale(parser)\n        ArgumentHelper.dtype(parser)\n        parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)')\n        parser.add_argument('--w-bits', type=int, default=4, help='Bit number for weight quantization')\n        parser.add_argument('--w-sym', action='store_true', help='Whether to do symmetric quantization')\n        parser.add_argument('--w-group-size',\n                            type=int,\n                            
default=128,\n                            help='Group size for weight quantization statistics')\n\n    @staticmethod\n    def add_parser_auto_gptq():\n        \"\"\"Add parser for auto_gptq command.\"\"\"\n        parser = SubCliLite.subparsers.add_parser('auto_gptq',\n                                                  formatter_class=DefaultsAndTypesHelpFormatter,\n                                                  description=SubCliLite.auto_gptq.__doc__,\n                                                  help=SubCliLite.auto_gptq.__doc__)\n        parser.set_defaults(run=SubCliLite.auto_gptq)\n        parser.add_argument('model', type=str, help='The path of model in hf format')\n        ArgumentHelper.revision(parser)\n        ArgumentHelper.work_dir(parser)\n        ArgumentHelper.calib_dataset(parser)\n        ArgumentHelper.calib_samples(parser)\n        ArgumentHelper.calib_seqlen(parser)\n        ArgumentHelper.calib_batchsize(parser)\n        ArgumentHelper.dtype(parser)\n        parser.add_argument('--w-bits', type=int, default=4, help='Bit number for weight quantization')\n        parser.add_argument('--w-group-size',\n                            type=int,\n                            default=128,\n                            help='Group size for weight quantization statistics')\n\n    @staticmethod\n    def add_parser_calibrate():\n        \"\"\"Add parser for calibrate command.\"\"\"\n        parser = SubCliLite.subparsers.add_parser('calibrate',\n                                                  formatter_class=DefaultsAndTypesHelpFormatter,\n                                                  description=SubCliLite.calibrate.__doc__,\n                                                  help=SubCliLite.calibrate.__doc__)\n        parser.set_defaults(run=SubCliLite.calibrate)\n        parser.add_argument('model', type=str, help='The name or path of the model to be loaded')\n        ArgumentHelper.work_dir(parser)\n        
ArgumentHelper.calib_dataset(parser)\n        ArgumentHelper.calib_samples(parser)\n        ArgumentHelper.calib_seqlen(parser)\n        ArgumentHelper.calib_batchsize(parser)\n        ArgumentHelper.calib_search_scale(parser)\n        ArgumentHelper.dtype(parser)\n\n    @staticmethod\n    def add_parser_smooth_quant():\n        \"\"\"Add parser for smooth_quant command.\"\"\"\n        parser = SubCliLite.subparsers.add_parser('smooth_quant',\n                                                  formatter_class=DefaultsAndTypesHelpFormatter,\n                                                  description=SubCliLite.smooth_quant.__doc__,\n                                                  help=SubCliLite.smooth_quant.__doc__)\n        parser.set_defaults(run=SubCliLite.smooth_quant)\n        parser.add_argument('model', type=str, help='The name or path of the model to be loaded')\n        parser.add_argument('--work-dir',\n                            type=str,\n                            default='./work_dir',\n                            help='The working directory for outputs. 
Defaults to \"./work_dir\"')\n        parser.add_argument('--device', type=str, default='cuda', help='Device for weight quantization (cuda or npu)')\n        ArgumentHelper.calib_dataset(parser)\n        ArgumentHelper.calib_samples(parser)\n        ArgumentHelper.calib_seqlen(parser)\n        ArgumentHelper.calib_batchsize(parser)\n        ArgumentHelper.calib_search_scale(parser)\n        ArgumentHelper.dtype(parser)\n        ArgumentHelper.quant_dtype(parser)\n        ArgumentHelper.revision(parser)\n        ArgumentHelper.download_dir(parser)\n\n    @staticmethod\n    def auto_awq(args):\n        \"\"\"Perform weight quantization using the AWQ algorithm.\"\"\"\n        from lmdeploy.lite.apis.auto_awq import auto_awq\n        kwargs = convert_args(args)\n        auto_awq(**kwargs)\n\n    @staticmethod\n    def auto_gptq(args):\n        \"\"\"Perform weight quantization using the GPTQ algorithm.\"\"\"\n        from lmdeploy.lite.apis.gptq import auto_gptq\n        kwargs = convert_args(args)\n        auto_gptq(**kwargs)\n\n    @staticmethod\n    def calibrate(args):\n        \"\"\"Perform calibration on a given dataset.\"\"\"\n        from lmdeploy.lite.apis.calibrate import calibrate\n        kwargs = convert_args(args)\n        calibrate(**kwargs)\n\n    @staticmethod\n    def smooth_quant(args):\n        \"\"\"Perform w8a8 quantization using SmoothQuant.\"\"\"\n        from lmdeploy.lite.apis.smooth_quant import smooth_quant\n        kwargs = convert_args(args)\n        smooth_quant(**kwargs)\n\n    @staticmethod\n    def add_parsers():\n        \"\"\"Add all parsers.\"\"\"\n        SubCliLite.add_parser_auto_awq()\n        SubCliLite.add_parser_auto_gptq()\n        SubCliLite.add_parser_calibrate()\n        SubCliLite.add_parser_smooth_quant()\n"
  },
  {
    "path": "lmdeploy/cli/serve.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend\nfrom lmdeploy.utils import get_max_batch_size\n\nfrom .cli import CLI\nfrom .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters,\n                    get_speculative_config)\n\n\nclass SubCliServe:\n    \"\"\"Serve LLMs and interact on terminal.\"\"\"\n    _help = 'Serve LLMs with an OpenAI-compatible API'\n    _desc = _help\n    parser = CLI.subparsers.add_parser(\n        'serve',\n        help=_help,\n        description=_desc,\n    )\n    subparsers = parser.add_subparsers(title='Commands', description='This group has the following commands:')\n\n    @staticmethod\n    def add_parser_api_server():\n        \"\"\"Add parser for api_server command.\"\"\"\n        parser = SubCliServe.subparsers.add_parser('api_server',\n                                                   formatter_class=DefaultsAndTypesHelpFormatter,\n                                                   description=SubCliServe.api_server.__doc__,\n                                                   help=SubCliServe.api_server.__doc__)\n        parser.set_defaults(run=SubCliServe.api_server)\n        parser.add_argument('model_path',\n                            type=str,\n                            help='The path of a model. It could be one of the following '\n                            'options: - i) a local directory path of a turbomind model'\n                            ' which is converted by `lmdeploy convert` command or '\n                            'downloaded from ii) and iii). - ii) the model_id of a '\n                            'lmdeploy-quantized model hosted inside a model repo on '\n                            'huggingface.co, such as \"internlm/internlm-chat-20b-4bit\",'\n                            ' \"lmdeploy/llama2-chat-70b-4bit\", etc. 
- iii) the model_id'\n                            ' of a model hosted inside a model repo on huggingface.co,'\n                            ' such as \"internlm/internlm-chat-7b\", \"qwen/qwen-7b-chat\"'\n                            ', \"baichuan-inc/baichuan2-7b-chat\" and so on')\n        parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Host IP for serving')\n        parser.add_argument('--server-port', type=int, default=23333, help='Server port')\n        parser.add_argument('--allow-origins',\n                            nargs='+',\n                            type=str,\n                            default=['*'],\n                            help='A list of allowed origins for CORS')\n        parser.add_argument('--allow-credentials', action='store_true', help='Whether to allow credentials for CORS')\n        parser.add_argument('--allow-methods',\n                            nargs='+',\n                            type=str,\n                            default=['*'],\n                            help='A list of allowed HTTP methods for CORS')\n        parser.add_argument('--allow-headers',\n                            nargs='+',\n                            type=str,\n                            default=['*'],\n                            help='A list of allowed HTTP headers for CORS')\n        parser.add_argument('--proxy-url', type=str, default=None, help='The proxy URL for the API server.')\n        parser.add_argument('--max-concurrent-requests',\n                            type=int,\n                            default=None,\n                            help='This refers to the number of concurrent requests that '\n                            'the server can handle. 
The server is designed to process the '\n                            'engine’s tasks once the maximum number of concurrent requests is '\n                            'reached, regardless of any additional requests sent by clients '\n                            'concurrently during that time. Defaults to None.')\n        # common args\n        ArgumentHelper.backend(parser)\n        ArgumentHelper.log_level(parser)\n        ArgumentHelper.api_keys(parser)\n        ArgumentHelper.ssl(parser)\n        ArgumentHelper.model_name(parser)\n        ArgumentHelper.max_log_len(parser)\n        ArgumentHelper.disable_fastapi_docs(parser)\n        ArgumentHelper.allow_terminate_by_client(parser)\n        ArgumentHelper.enable_abort_handling(parser)\n        # chat template args\n        ArgumentHelper.chat_template(parser)\n\n        # parsers\n        ArgumentHelper.tool_call_parser(parser)\n        ArgumentHelper.reasoning_parser(parser)\n\n        # model args\n        ArgumentHelper.revision(parser)\n        ArgumentHelper.download_dir(parser)\n\n        # pytorch engine args\n        pt_group = parser.add_argument_group('PyTorch engine arguments')\n\n        ArgumentHelper.adapters(pt_group)\n        ArgumentHelper.device(pt_group)\n        ArgumentHelper.eager_mode(pt_group)\n        ArgumentHelper.disable_vision_encoder(pt_group)\n        ArgumentHelper.logprobs_mode(pt_group)\n        ArgumentHelper.dllm_block_length(pt_group)\n        ArgumentHelper.dllm_unmasking_strategy(pt_group)\n        ArgumentHelper.dllm_denoising_steps(pt_group)\n        ArgumentHelper.dllm_confidence_threshold(pt_group)\n        ArgumentHelper.enable_return_routed_experts(pt_group)\n        ArgumentHelper.distributed_executor_backend(pt_group)\n\n        # common engine args\n        dtype_act = ArgumentHelper.dtype(pt_group)\n        tp_act = ArgumentHelper.tp(pt_group)\n        session_len_act = ArgumentHelper.session_len(pt_group)\n        max_batch_size_act = 
ArgumentHelper.max_batch_size(pt_group)\n        cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)\n        cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)\n        prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)\n        max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group)\n        quant_policy = ArgumentHelper.quant_policy(pt_group)\n        model_format = ArgumentHelper.model_format(pt_group)\n        hf_overrides = ArgumentHelper.hf_overrides(pt_group)\n        disable_metrics = ArgumentHelper.disable_metrics(pt_group)\n        dp = ArgumentHelper.dp(pt_group)\n        ArgumentHelper.ep(pt_group)\n        ArgumentHelper.enable_microbatch(pt_group)\n        ArgumentHelper.enable_eplb(pt_group)\n        ArgumentHelper.role(pt_group)\n        ArgumentHelper.migration_backend(pt_group)\n        # multi-node serving args\n        node_rank_act = ArgumentHelper.node_rank(pt_group)\n        num_nodes_act = ArgumentHelper.num_nodes(pt_group)\n\n        # turbomind args\n        tb_group = parser.add_argument_group('TurboMind engine arguments')\n        # common engine args\n        tb_group._group_actions.append(dtype_act)\n        tb_group._group_actions.append(tp_act)\n        tb_group._group_actions.append(session_len_act)\n        tb_group._group_actions.append(max_batch_size_act)\n        tb_group._group_actions.append(cache_max_entry_act)\n        tb_group._group_actions.append(cache_block_seq_len_act)\n        tb_group._group_actions.append(prefix_caching_act)\n        tb_group._group_actions.append(max_prefill_token_num_act)\n        tb_group._group_actions.append(quant_policy)\n        tb_group._group_actions.append(model_format)\n        tb_group._group_actions.append(num_nodes_act)\n        tb_group._group_actions.append(node_rank_act)\n        tb_group._group_actions.append(hf_overrides)\n        tb_group._group_actions.append(disable_metrics)\n        
tb_group._group_actions.append(dp)\n        ArgumentHelper.cp(tb_group)\n        ArgumentHelper.rope_scaling_factor(tb_group)\n        ArgumentHelper.num_tokens_per_iter(tb_group)\n        ArgumentHelper.max_prefill_iters(tb_group)\n        ArgumentHelper.async_(tb_group)\n        ArgumentHelper.communicator(tb_group)\n        ArgumentHelper.dist_init_addr(tb_group)\n\n        # vlm args\n        vision_group = parser.add_argument_group('Vision model arguments')\n        ArgumentHelper.vision_max_batch_size(vision_group)\n\n        # spec decode\n        ArgumentHelper.add_spec_group(parser)\n\n    @staticmethod\n    def add_parser_proxy():\n        \"\"\"Add parser for proxy server command.\"\"\"\n        parser = SubCliServe.subparsers.add_parser('proxy',\n                                                   formatter_class=DefaultsAndTypesHelpFormatter,\n                                                   description=SubCliServe.proxy.__doc__,\n                                                   help=SubCliServe.proxy.__doc__)\n        parser.set_defaults(run=SubCliServe.proxy)\n        parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Host IP for proxy serving')\n        parser.add_argument('--server-port', type=int, default=8000, help='Server port of the proxy')\n        parser.add_argument('--serving-strategy',\n                            type=str,\n                            choices=['Hybrid', 'DistServe'],\n                            default='Hybrid',\n                            help='The serving strategy: Hybrid colocates Prefill and Decode '\n                            'workloads in the same engine; DistServe uses Prefill-Decode disaggregation')\n        parser.add_argument('--dummy-prefill', action='store_true', help='Dummy prefill for performance profiling')\n        parser.add_argument('--routing-strategy',\n                            type=str,\n                            choices=['random', 'min_expected_latency', 
'min_observed_latency'],\n                            default='min_expected_latency',\n                            help='the strategy to dispatch requests to nodes')\n        parser.add_argument('--disable-cache-status',\n                            action='store_true',\n                            help='Whether to disable the cache status of the '\n                            'proxy. If set, the proxy will forget the status '\n                            'of the previous run')\n\n        # For Disaggregation\n        parser.add_argument('--migration-protocol',\n                            type=str,\n                            choices=['RDMA', 'NVLINK'],\n                            default='RDMA',\n                            help='transport protocol of KV migration')\n        parser.add_argument('--link-type', type=str, choices=['RoCE', 'IB'], default='RoCE', help='RDMA Link Type')\n        parser.add_argument('--disable-gdr', action='store_true', help='Disable GPUDirect RDMA (GDR)')\n        ArgumentHelper.api_keys(parser)\n        ArgumentHelper.ssl(parser)\n        ArgumentHelper.log_level(parser)\n\n    @staticmethod\n    def api_server(args):\n        \"\"\"Serve LLMs with restful api using fastapi.\"\"\"\n        from lmdeploy.archs import autoget_backend\n\n        max_batch_size = args.max_batch_size if args.max_batch_size \\\n            else get_max_batch_size(args.device)\n        backend = args.backend\n        if backend != 'pytorch':\n            # set auto backend mode\n            backend = autoget_backend(args.model_path)\n\n        if backend == 'pytorch':\n            from lmdeploy.messages import PytorchEngineConfig\n            adapters = get_lora_adapters(args.adapters)\n            backend_config = PytorchEngineConfig(\n                dtype=args.dtype,\n                tp=args.tp,\n                dp=args.dp,\n                ep=args.ep,\n                max_batch_size=max_batch_size,\n                
cache_max_entry_count=args.cache_max_entry_count,\n                block_size=args.cache_block_seq_len,\n                session_len=args.session_len,\n                adapters=adapters,\n                enable_prefix_caching=args.enable_prefix_caching,\n                device_type=args.device,\n                quant_policy=args.quant_policy,\n                eager_mode=args.eager_mode,\n                max_prefill_token_num=args.max_prefill_token_num,\n                enable_microbatch=args.enable_microbatch,\n                enable_eplb=args.enable_eplb,\n                enable_metrics=not args.disable_metrics,\n                role=EngineRole[args.role],\n                migration_backend=MigrationBackend[args.migration_backend],\n                model_format=args.model_format,\n                hf_overrides=args.hf_overrides,\n                disable_vision_encoder=args.disable_vision_encoder,\n                logprobs_mode=args.logprobs_mode,\n                dllm_block_length=args.dllm_block_length,\n                dllm_unmasking_strategy=args.dllm_unmasking_strategy,\n                dllm_denoising_steps=args.dllm_denoising_steps,\n                dllm_confidence_threshold=args.dllm_confidence_threshold,\n                enable_return_routed_experts=args.enable_return_routed_experts,\n                distributed_executor_backend=args.distributed_executor_backend,\n            )\n        else:\n            from lmdeploy.messages import TurbomindEngineConfig\n            backend_config = TurbomindEngineConfig(dtype=args.dtype,\n                                                   tp=args.tp,\n                                                   dp=args.dp,\n                                                   cp=args.cp,\n                                                   nnodes=args.nnodes,\n                                                   node_rank=args.node_rank,\n                                                   dist_init_addr=args.dist_init_addr,\n           
                                        max_batch_size=max_batch_size,\n                                                   session_len=args.session_len,\n                                                   model_format=args.model_format,\n                                                   quant_policy=args.quant_policy,\n                                                   rope_scaling_factor=args.rope_scaling_factor,\n                                                   cache_max_entry_count=args.cache_max_entry_count,\n                                                   cache_block_seq_len=args.cache_block_seq_len,\n                                                   enable_prefix_caching=args.enable_prefix_caching,\n                                                   max_prefill_token_num=args.max_prefill_token_num,\n                                                   num_tokens_per_iter=args.num_tokens_per_iter,\n                                                   max_prefill_iters=args.max_prefill_iters,\n                                                   async_=args.async_,\n                                                   communicator=args.communicator,\n                                                   enable_metrics=not args.disable_metrics,\n                                                   hf_overrides=args.hf_overrides)\n        chat_template_config = get_chat_template(args.chat_template, args.model_path)\n        speculative_config = get_speculative_config(args)\n\n        from lmdeploy.messages import VisionConfig\n        vision_config = VisionConfig(args.vision_max_batch_size)\n        if args.dp == 1 or backend == 'turbomind':\n            from lmdeploy.serve.openai.api_server import serve as run_api_server\n\n            run_api_server(\n                args.model_path,\n                model_name=args.model_name,\n                backend=backend,\n                backend_config=backend_config,\n                
chat_template_config=chat_template_config,\n                vision_config=vision_config,\n                server_name=args.server_name,\n                server_port=args.server_port,\n                allow_origins=args.allow_origins,\n                allow_credentials=args.allow_credentials,\n                allow_methods=args.allow_methods,\n                allow_headers=args.allow_headers,\n                allow_terminate_by_client=args.allow_terminate_by_client,\n                enable_abort_handling=args.enable_abort_handling,\n                log_level=args.log_level.upper(),\n                api_keys=args.api_keys,\n                ssl=args.ssl,\n                proxy_url=args.proxy_url,\n                max_log_len=args.max_log_len,\n                disable_fastapi_docs=args.disable_fastapi_docs,\n                max_concurrent_requests=args.max_concurrent_requests,\n                reasoning_parser=args.reasoning_parser,\n                tool_call_parser=args.tool_call_parser,\n                speculative_config=speculative_config,\n            )\n        else:\n            from lmdeploy.serve.openai.launch_server import launch_server\n\n            launch_server(\n                args.nnodes,\n                args.node_rank,\n                args.model_path,\n                model_name=args.model_name,\n                backend=backend,\n                backend_config=backend_config,\n                chat_template_config=chat_template_config,\n                vision_config=vision_config,\n                server_name=args.server_name,\n                server_port=args.server_port,\n                allow_origins=args.allow_origins,\n                allow_credentials=args.allow_credentials,\n                allow_methods=args.allow_methods,\n                allow_headers=args.allow_headers,\n                allow_terminate_by_client=args.allow_terminate_by_client,\n                enable_abort_handling=args.enable_abort_handling,\n                
log_level=args.log_level.upper(),\n                api_keys=args.api_keys,\n                ssl=args.ssl,\n                proxy_url=args.proxy_url,\n                max_log_len=args.max_log_len,\n                disable_fastapi_docs=args.disable_fastapi_docs,\n                max_concurrent_requests=args.max_concurrent_requests,\n                reasoning_parser=args.reasoning_parser,\n                tool_call_parser=args.tool_call_parser,\n                speculative_config=speculative_config,\n            )\n\n    @staticmethod\n    def proxy(args):\n        \"\"\"Proxy server that manages distributed api_server nodes.\"\"\"\n        from lmdeploy.serve.proxy.proxy import proxy\n        kwargs = convert_args(args)\n        proxy(**kwargs)\n\n    @staticmethod\n    def add_parsers():\n        SubCliServe.add_parser_api_server()\n        SubCliServe.add_parser_proxy()\n"
  },
  {
    "path": "lmdeploy/cli/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport argparse\nimport json\nimport re\nimport sys\nfrom collections import defaultdict\nfrom typing import Any, List\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass DefaultsAndTypesHelpFormatter(argparse.HelpFormatter):\n    \"\"\"Formatter to output default value and type in help information.\"\"\"\n\n    def _get_help_string(self, action):\n        \"\"\"Add default and type info into help.\"\"\"\n        help = action.help\n        if '%(default)' not in action.help:\n            if action.default is not argparse.SUPPRESS:\n                defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE]\n                if (action.option_strings or action.nargs in defaulting_nargs) and 'default' not in help.lower():\n                    if not help.endswith('.'):\n                        help += '.'\n                    help += ' Default: %(default)s'\n                if action.type:\n                    if not help.endswith('.'):\n                        help += '.'\n                    help += ' Type: %(type)s'\n        return help\n\n\ndef convert_args(args):\n    \"\"\"Convert args to dict format.\"\"\"\n    special_names = ['run', 'command']\n    kwargs = {k[0]: k[1] for k in args._get_kwargs() if k[0] not in special_names}\n    return kwargs\n\n\ndef get_lora_adapters(adapters: List[str]):\n    \"\"\"Parse lora adapters from cli input.\n\n    Args:\n        adapters (List[str]): CLI input string of lora adapter path(s).\n\n    Returns:\n        Dict[str,str] or None: Parsed lora adapter path(s).\n    \"\"\"\n    if not adapters:\n        return None\n    n = len(adapters)\n    output = {}\n    if n == 1:\n        name = 'default'\n        path = adapters[0].strip()\n        if '=' in path:\n            name, path = path.split('=', 1)\n        output[name] = path\n    else:\n        for pair in adapters:\n            assert '=' in pair, f'Multiple lora 
paths must be in the format of ' \\\n                                 f'xxx=yyy. But given: {pair}'\n            name, path = pair.strip().split('=', 1)\n            assert name not in output, f'Multiple lora paths with repeated lora name: {name}'\n            output[name] = path\n    return output\n\n\ndef get_chat_template(chat_template: str, model_path: str = None):\n    \"\"\"Get chat template config.\n\n    Args:\n        chat_template(str): it could be a builtin chat template name, or a chat template json file\n        model_path(str): the model path, used to check deprecated chat template names\n    \"\"\"\n    import os\n\n    from lmdeploy.model import ChatTemplateConfig\n    if chat_template:\n        if os.path.isfile(chat_template):\n            return ChatTemplateConfig.from_json(chat_template)\n        else:\n            from lmdeploy.model import DEPRECATED_CHAT_TEMPLATE_NAMES, MODELS, REMOVED_CHAT_TEMPLATE_NAMES\n            if chat_template in REMOVED_CHAT_TEMPLATE_NAMES:\n                raise ValueError(f\"The chat template '{chat_template}' has been removed. \"\n                                 f'Please refer to the latest chat templates in '\n                                 f'https://lmdeploy.readthedocs.io/en/latest/advance/chat_template.html')\n            if chat_template in DEPRECATED_CHAT_TEMPLATE_NAMES:\n                logger.warning(f\"The chat template '{chat_template}' is deprecated and falls back to hf chat template.\")\n                chat_template = 'hf'\n            assert chat_template in MODELS.module_dict.keys(), \\\n                f\"chat template '{chat_template}' is not \" \\\n                f'registered. 
The builtin chat templates are: ' \\\n                f'{MODELS.module_dict.keys()}'\n            return ChatTemplateConfig(model_name=chat_template, model_path=model_path)\n    else:\n        return None\n\n\ndef get_speculative_config(args):\n    \"\"\"Get speculative config from args.\"\"\"\n    from lmdeploy.messages import SpeculativeConfig\n    speculative_config = None\n    if args.speculative_algorithm is not None:\n        speculative_config = SpeculativeConfig(\n            method=args.speculative_algorithm,\n            model=args.speculative_draft_model,\n            num_speculative_tokens=args.speculative_num_draft_tokens,\n        )\n    return speculative_config\n\n\nclass ArgumentHelper:\n    \"\"\"Helper class to add unified argument.\"\"\"\n\n    @staticmethod\n    def model_name(parser):\n        \"\"\"Add argument model_name to parser.\"\"\"\n\n        return parser.add_argument('--model-name',\n                                   type=str,\n                                   default=None,\n                                   help='The name of the served model. It can be accessed '\n                                   'by the RESTful API `/v1/models`. If it is not specified, '\n                                   '`model_path` will be adopted')\n\n    @staticmethod\n    def dtype(parser, default: str = 'auto'):\n        return parser.add_argument('--dtype',\n                                   type=str,\n                                   default=default,\n                                   choices=['auto', 'float16', 'bfloat16'],\n                                   help='data type for model weights and activations. '\n                                   'The \"auto\" option will use FP16 precision '\n                                   'for FP32 and FP16 models, and BF16 precision '\n                                   'for BF16 models. 
This option will be ignored if '\n                                   'the model is a quantized model')\n\n    @staticmethod\n    def quant_dtype(parser, default: str = 'int8'):\n        return parser.add_argument('--quant-dtype',\n                                   type=str,\n                                   default=default,\n                                   choices=['int8', 'float8_e4m3fn', 'float8_e5m2', 'fp8'],\n                                   help='data type for the quantized model weights and activations. '\n                                   'Note \"fp8\" is shorthand for \"float8_e4m3fn\"')\n\n    @staticmethod\n    def model_format(parser, default: str = None):\n        return parser.add_argument('--model-format',\n                                   type=str,\n                                   default=default,\n                                   choices=['hf', 'awq', 'gptq', 'fp8', 'mxfp4'],\n                                   help='The format of the input model. `hf` means `hf_llama`, '\n                                   '`awq` represents a model quantized by AWQ,'\n                                   ' and `gptq` refers to a model quantized by GPTQ')\n\n    @staticmethod\n    def revision(parser, default: str = None):\n        return parser.add_argument('--revision',\n                                   type=str,\n                                   default=default,\n                                   help='The specific model version to use. '\n                                   'It can be a branch name, a tag name, or a commit id. 
'\n                                   'If unspecified, the default version will be used.')\n\n    @staticmethod\n    def download_dir(parser, default: str = None):\n        return parser.add_argument('--download-dir',\n                                   type=str,\n                                   default=default,\n                                   help='Directory to download and load the weights. '\n                                   'Defaults to the default cache directory of HuggingFace.')\n\n    @staticmethod\n    def tp(parser):\n        \"\"\"Add argument tp to parser.\"\"\"\n\n        return parser.add_argument('--tp',\n                                   type=int,\n                                   default=1,\n                                   help='Number of GPUs used for tensor parallelism. Should be a power of 2')\n\n    @staticmethod\n    def dp(parser):\n        \"\"\"Add argument dp to parser.\"\"\"\n\n        return parser.add_argument('--dp',\n                                   type=int,\n                                   default=1,\n                                   help='data parallelism. dp_rank is required when pytorch engine is used.')\n\n    @staticmethod\n    def ep(parser):\n        \"\"\"Add argument ep to parser.\"\"\"\n\n        return parser.add_argument('--ep',\n                                   type=int,\n                                   default=1,\n                                   help='expert parallelism. 
dp is required when pytorch engine is used.')\n\n    @staticmethod\n    def cp(parser):\n        \"\"\"Add argument cp to parser.\"\"\"\n\n        return parser.add_argument(\n            '--cp',\n            type=int,\n            default=1,\n            help='context parallelism size in attention for turbomind backend, tp must be a multiple of cp.')\n\n    @staticmethod\n    def dp_rank(parser):\n        \"\"\"Add argument dp_rank to parser.\"\"\"\n\n        return parser.add_argument('--dp-rank',\n                                   type=int,\n                                   default=0,\n                                   help='data parallelism rank; all ranks from 0 to dp - 1 should be created.')\n\n    @staticmethod\n    def node_rank(parser):\n        \"\"\"Add argument node_rank to parser.\"\"\"\n\n        return parser.add_argument('--node-rank', type=int, default=0, help='The current node rank.')\n\n    @staticmethod\n    def num_nodes(parser):\n        \"\"\"Add argument num_nodes to parser.\"\"\"\n\n        return parser.add_argument('--nnodes', type=int, default=1, help='The total number of nodes')\n\n    @staticmethod\n    def dist_init_addr(parser):\n        \"\"\"Add argument dist_init_addr to parser.\"\"\"\n\n        return parser.add_argument('--dist-init-addr', type=str, default=None)\n\n    @staticmethod\n    def session_id(parser):\n        \"\"\"Add argument session_id to parser.\"\"\"\n\n        return parser.add_argument('--session-id', type=int, default=1, help='The unique id of a session')\n\n    @staticmethod\n    def session_len(parser, default: int = None):\n        return parser.add_argument('--session-len',\n                                   type=int,\n                                   default=default,\n                                   help='The max session length of a sequence')\n\n    @staticmethod\n    def max_batch_size(parser):\n        \"\"\"Add argument max_batch_size to parser.\"\"\"\n\n        return 
parser.add_argument('--max-batch-size',\n                                   type=int,\n                                   default=None,\n                                   help='Maximum batch size. If not specified, the engine will '\n                                   'automatically set it according to the device')\n\n    @staticmethod\n    def quant_policy(parser, default: int = 0):\n        \"\"\"Add argument quant_policy to parser.\"\"\"\n\n        return parser.add_argument('--quant-policy',\n                                   type=int,\n                                   default=default,\n                                   choices=[0, 4, 8],\n                                   help='Quantize kv or not. 0: no quant; 4: 4bit kv; 8: 8bit kv')\n\n    @staticmethod\n    def rope_scaling_factor(parser):\n        \"\"\"Add argument rope_scaling_factor to parser.\"\"\"\n\n        return parser.add_argument('--rope-scaling-factor', type=float, default=0.0, help='Rope scaling factor')\n\n    @staticmethod\n    def hf_overrides(parser):\n        \"\"\"Add argument hf_overrides to parser.\"\"\"\n        return parser.add_argument('--hf-overrides',\n                                   type=json.loads,\n                                   default=None,\n                                   help='Extra arguments (a JSON string) to be forwarded to the HuggingFace config.')\n\n    @staticmethod\n    def use_logn_attn(parser):\n        \"\"\"Add argument use_logn_attn to parser.\"\"\"\n\n        return parser.add_argument('--use-logn-attn',\n                                   action='store_true',\n                                   default=False,\n                                   help='Whether to use logn attention scaling')\n\n    @staticmethod\n    def block_size(parser):\n        \"\"\"Add argument block_size to parser.\"\"\"\n\n        return parser.add_argument('--block-size', type=int, default=64, help='The block size for paging cache')\n\n    @staticmethod\n    def top_p(parser):\n        
\"\"\"Add argument top_p to parser.\"\"\"\n\n        return parser.add_argument('--top-p',\n                                   type=float,\n                                   default=0.8,\n                                   help='An alternative to sampling with temperature,'\n                                   ' called nucleus sampling, where the model '\n                                   'considers the results of the tokens with '\n                                   'top_p probability mass')\n\n    @staticmethod\n    def top_k(parser):\n        \"\"\"Add argument top_k to parser.\"\"\"\n\n        return parser.add_argument('--top-k',\n                                   type=int,\n                                   default=1,\n                                   help='An alternative to sampling with temperature, '\n                                   'where the model considers the top_k tokens '\n                                   'with the highest probability')\n\n    @staticmethod\n    def temperature(parser, default: float = 0.8):\n        return parser.add_argument('-temp', '--temperature', type=float, default=default, help='Sampling temperature')\n\n    @staticmethod\n    def repetition_penalty(parser):\n        \"\"\"Add argument repetition_penalty to parser.\"\"\"\n\n        return parser.add_argument('--repetition-penalty',\n                                   type=float,\n                                   default=1.0,\n                                   help='Parameter to penalize repetition')\n\n    @staticmethod\n    def log_level(parser):\n        \"\"\"Add argument log_level to parser.\"\"\"\n\n        import logging\n        return parser.add_argument('--log-level',\n                                   type=str,\n                                   default='WARNING',\n                                   choices=list(logging._nameToLevel.keys()),\n                                   help='Set the log level')\n\n    @staticmethod\n    def api_keys(parser):\n 
       return parser.add_argument(\n            '--api-keys',\n            type=str,\n            nargs='*',\n            default=None,\n            help='Optional list of space separated API keys',\n        )\n\n    @staticmethod\n    def ssl(parser):\n        return parser.add_argument(\n            '--ssl',\n            action='store_true',\n            required=False,\n            default=False,\n            help='Enable SSL. Requires OS Environment variables'\n            \" 'SSL_KEYFILE' and 'SSL_CERTFILE'\",\n        )\n\n    @staticmethod\n    def backend(parser):\n        \"\"\"Add argument backend to parser.\"\"\"\n\n        return parser.add_argument('--backend',\n                                   type=str,\n                                   default='turbomind',\n                                   choices=['pytorch', 'turbomind'],\n                                   help='Set the inference backend')\n\n    @staticmethod\n    def stream_output(parser):\n        \"\"\"Add argument stream_output to parser.\"\"\"\n\n        return parser.add_argument('--stream-output', action='store_true', help='Indicator for streaming output or not')\n\n    @staticmethod\n    def calib_dataset(parser):\n        \"\"\"Add argument calib_dataset to parser.\"\"\"\n\n        return parser.add_argument(\n            '--calib-dataset',\n            type=str,\n            default='wikitext2',\n            choices=['wikitext2', 'c4', 'pileval', 'gsm8k', 'neuralmagic_calibration', 'open-platypus', 'openwebtext'],\n            help='The calibration dataset name.')\n\n    @staticmethod\n    def calib_samples(parser):\n        \"\"\"Add argument calib_samples to parser.\"\"\"\n\n        return parser.add_argument('--calib-samples',\n                                   type=int,\n                                   default=128,\n                                   help='The number of samples for calibration')\n\n    @staticmethod\n    def calib_seqlen(parser):\n        \"\"\"Add argument 
calib_seqlen to parser.\"\"\"\n\n        return parser.add_argument('--calib-seqlen', type=int, default=2048, help='The sequence length for calibration')\n\n    @staticmethod\n    def calib_batchsize(parser):\n        \"\"\"Add argument batch_size to parser.\"\"\"\n\n        return parser.add_argument(\n            '--batch-size',\n            type=int,\n            default=1,\n            help=\\\n            'The batch size for running the calib samples. Low GPU memory requires a small batch_size. A large batch_size reduces calibration time but costs more VRAM'  # noqa\n        )\n\n    @staticmethod\n    def calib_search_scale(parser):\n        \"\"\"Add argument search_scale to parser.\"\"\"\n\n        return parser.add_argument(\n            '--search-scale',\n            action='store_true',\n            default=False,\n            help=\\\n            'Whether to search for the scale ratio. Disabled by default, in which case only smooth quant with a 0.5 ratio is applied'  # noqa\n        )\n\n    @staticmethod\n    def device(parser, default: str = 'cuda', choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):\n        \"\"\"Add argument device to parser.\"\"\"\n\n        return parser.add_argument('--device',\n                                   type=str,\n                                   default=default,\n                                   choices=choices,\n                                   help='The device type to run on')\n\n    @staticmethod\n    def chat_template(parser):\n        \"\"\"Add chat template config to parser.\"\"\"\n\n        return parser.add_argument(\n            '--chat-template',\n            type=str,\n            default=None,\n            help=\\\n            'A JSON file or string that specifies the chat template configuration. 
'  # noqa\n            'Please refer to https://lmdeploy.readthedocs.io/en/latest/advance/chat_template.html for the specification'  # noqa\n        )\n\n    @staticmethod\n    def reasoning_parser(parser):\n        \"\"\"Add reasoning parser to parser.\"\"\"\n        from lmdeploy.serve.openai.reasoning_parser import ReasoningParserManager\n        return parser.add_argument(\n            '--reasoning-parser',\n            type=str,\n            default=None,\n            help=f'The registered reasoning parser name from {ReasoningParserManager.module_dict.keys()}. '\n            'Default to None.')\n\n    @staticmethod\n    def tool_call_parser(parser):\n        \"\"\"Add tool call parser to parser.\"\"\"\n        from lmdeploy.serve.openai.tool_parser import ToolParserManager\n\n        return parser.add_argument(\n            '--tool-call-parser',\n            type=str,\n            default=None,\n            help=f'The registered tool parser name {ToolParserManager.module_dict.keys()}. 
Default to None.')\n\n    @staticmethod\n    def allow_terminate_by_client(parser):\n        \"\"\"Add argument allow_terminate_by_client to parser.\"\"\"\n\n        return parser.add_argument('--allow-terminate-by-client',\n                                   action='store_true',\n                                   default=False,\n                                   help='Allow the server to be terminated by a request from the client')\n\n    @staticmethod\n    def enable_abort_handling(parser):\n        \"\"\"Add --enable-abort-handling argument to configure server abort\n        request processing.\"\"\"\n\n        return parser.add_argument('--enable-abort-handling',\n                                   action='store_true',\n                                   default=False,\n                                   help='Enable server to handle client abort requests')\n\n    @staticmethod\n    def cache_max_entry_count(parser):\n        \"\"\"Add argument cache_max_entry_count to parser.\"\"\"\n\n        return parser.add_argument('--cache-max-entry-count',\n                                   type=float,\n                                   default=0.8,\n                                   help='The ratio of free GPU memory occupied by the k/v '\n                                   'cache, excluding weights')\n\n    @staticmethod\n    def adapters(parser):\n        \"\"\"Add argument adapters to parser.\"\"\"\n\n        return parser.add_argument('--adapters',\n                                   nargs='*',\n                                   type=str,\n                                   default=None,\n                                   help='Used to set path(s) of lora adapter(s). One can input '\n                                   'key-value pairs in xxx=yyy format for multiple lora '\n                                   'adapters. 
If there is only one adapter, one can input just '\n                                   'the path of the adapter.')\n\n    @staticmethod\n    def work_dir(parser):\n        \"\"\"Add argument work_dir to parser.\"\"\"\n\n        return parser.add_argument('--work-dir',\n                                   type=str,\n                                   default='./work_dir',\n                                   help='The working directory to save results')\n\n    @staticmethod\n    def cache_block_seq_len(parser):\n        \"\"\"Add argument cache_block_seq_len to parser.\"\"\"\n\n        return parser.add_argument('--cache-block-seq-len',\n                                   type=int,\n                                   default=64,\n                                   help='The length of the token sequence in a k/v block. '\n                                   'For Turbomind Engine, if the GPU compute capability '\n                                   'is >= 8.0, it should be a multiple of 32, otherwise '\n                                   'it should be a multiple of 64. 
For Pytorch Engine, '\n                                   'if Lora Adapter is specified, this parameter will '\n                                   'be ignored')\n\n    @staticmethod\n    def enable_prefix_caching(parser):\n        \"\"\"Add argument enable_prefix_caching to parser.\"\"\"\n\n        return parser.add_argument('--enable-prefix-caching',\n                                   action='store_true',\n                                   default=False,\n                                   help='Enable cache and match prefix')\n\n    @staticmethod\n    def num_tokens_per_iter(parser):\n        return parser.add_argument('--num-tokens-per-iter',\n                                   type=int,\n                                   default=0,\n                                   help='the number of tokens processed in a forward pass')\n\n    @staticmethod\n    def max_prefill_iters(parser):\n        return parser.add_argument('--max-prefill-iters',\n                                   type=int,\n                                   default=1,\n                                   help='the max number of forward passes in prefill stage')\n\n    @staticmethod\n    def async_(parser):\n        return parser.add_argument('--async',\n                                   type=int,\n                                   default=1,\n                                   choices=[0, 1],\n                                   dest='async_',\n                                   help='Enable async execution (default: 1, enabled). 
'\n                                   'Set to 0 to disable async mode, 1 to enable it.')\n\n    @staticmethod\n    def max_prefill_token_num(parser):\n        return parser.add_argument('--max-prefill-token-num',\n                                   type=int,\n                                   default=8192,\n                                   help='the max number of tokens per iteration during prefill')\n\n    @staticmethod\n    def vision_max_batch_size(parser):\n        return parser.add_argument('--vision-max-batch-size', type=int, default=1, help='the vision model batch size')\n\n    @staticmethod\n    def max_log_len(parser):\n        return parser.add_argument('--max-log-len',\n                                   type=int,\n                                   default=None,\n                                   help='Max number of prompt characters or prompt tokens being '\n                                   'printed in log. Default: Unlimited')\n\n    @staticmethod\n    def disable_fastapi_docs(parser):\n        return parser.add_argument('--disable-fastapi-docs',\n                                   action='store_true',\n                                   default=False,\n                                   help=\"Disable FastAPI's OpenAPI schema,\"\n                                   ' Swagger UI, and ReDoc endpoint')\n\n    @staticmethod\n    def eager_mode(parser):\n        \"\"\"Add argument eager_mode to parser.\"\"\"\n\n        return parser.add_argument('--eager-mode',\n                                   action='store_true',\n                                   default=False,\n                                   help='Whether to enable eager mode. 
'\n                                   'If set, CUDA graph is disabled')\n\n    @staticmethod\n    def communicator(parser):\n        return parser.add_argument('--communicator',\n                                   type=str,\n                                   default='nccl',\n                                   choices=['nccl', 'native', 'cuda-ipc'],\n                                   help='Communication backend for multi-GPU inference. The \"native\" option is '\n                                   'deprecated and serves as an alias for \"cuda-ipc\"')\n\n    @staticmethod\n    def enable_microbatch(parser):\n        \"\"\"Add argument enable_microbatch to parser.\"\"\"\n\n        return parser.add_argument('--enable-microbatch',\n                                   action='store_true',\n                                   help='enable microbatch for the specified model')\n\n    @staticmethod\n    def enable_eplb(parser):\n        \"\"\"Add argument enable_eplb to parser.\"\"\"\n\n        return parser.add_argument('--enable-eplb', action='store_true', help='enable eplb for the specified model')\n\n    @staticmethod\n    def disable_metrics(parser):\n        \"\"\"Add argument disable_metrics to parser.\"\"\"\n        return parser.add_argument('--disable-metrics',\n                                   action='store_true',\n                                   default=False,\n                                   help='disable the metrics system')\n\n    # For Disaggregation\n    @staticmethod\n    def role(parser):\n        return parser.add_argument('--role',\n                                   type=str,\n                                   default='Hybrid',\n                                   choices=['Hybrid', 'Prefill', 'Decode'],\n                                   help='Hybrid for Non-Disaggregated Engine; '\n                                   'Prefill for Disaggregated Prefill Engine; '\n                                   'Decode for Disaggregated Decode Engine')\n\n    
@staticmethod\n    def migration_backend(parser):\n        return parser.add_argument('--migration-backend',\n                                   type=str,\n                                   default='DLSlime',\n                                   choices=['DLSlime', 'Mooncake'],\n                                   help='KV cache migration backend for prefill-decode (PD) disaggregation')\n\n    @staticmethod\n    def disable_vision_encoder(parser):\n        \"\"\"Disable loading vision encoder.\"\"\"\n        return parser.add_argument('--disable-vision-encoder',\n                                   action='store_true',\n                                   default=False,\n                                   help='disable multimodal encoder')\n\n    @staticmethod\n    def logprobs_mode(parser):\n        \"\"\"The mode of logprobs.\"\"\"\n        return parser.add_argument('--logprobs-mode',\n                                   type=str,\n                                   default=None,\n                                   choices=[None, 'raw_logits', 'raw_logprobs'],\n                                   help='The mode of logprobs.')\n\n    @staticmethod\n    def dllm_block_length(parser):\n        \"\"\"dllm_block_length for dllm.\"\"\"\n        return parser.add_argument('--dllm-block-length', type=int, default=None, help='Block length for dllm')\n\n    @staticmethod\n    def dllm_unmasking_strategy(parser):\n        \"\"\"Dllm unmasking strategy.\"\"\"\n        return parser.add_argument('--dllm-unmasking-strategy',\n                                   type=str,\n                                   default='low_confidence_dynamic',\n                                   choices=['low_confidence_dynamic', 'low_confidence_static', 'sequential'],\n                                   help='The unmasking strategy for dllm.')\n\n    @staticmethod\n    def dllm_denoising_steps(parser):\n        \"\"\"Dllm denoising steps.\"\"\"\n        return parser.add_argument('--dllm-denoising-steps',\n 
                                  type=int,\n                                   default=None,\n                                   help='The number of denoising steps for dllm.')\n\n    @staticmethod\n    def dllm_confidence_threshold(parser):\n        \"\"\"Dllm confidence threshold.\"\"\"\n        return parser.add_argument('--dllm-confidence-threshold',\n                                   type=float,\n                                   default=0.85,\n                                   help='The confidence threshold for dllm.')\n\n    @staticmethod\n    def enable_return_routed_experts(parser):\n        \"\"\"Add argument return routed experts to parser.\"\"\"\n\n        return parser.add_argument('--enable-return-routed-experts',\n                                   action='store_true',\n                                   default=False,\n                                   help='Whether to output routed expert ids for replay')\n\n    @staticmethod\n    def add_spec_group(parser):\n        spec_group = parser.add_argument_group('Speculative decoding arguments')\n        spec_group.add_argument('--speculative-algorithm',\n                                type=str,\n                                default=None,\n                                choices=['eagle', 'eagle3', 'deepseek_mtp'],\n                                help='The speculative algorithm to use. 
`None` means speculative decoding is disabled')\n\n        spec_group.add_argument('--speculative-draft-model',\n                                type=str,\n                                default=None,\n                                help='The path to the speculative draft model')\n\n        spec_group.add_argument('--speculative-num-draft-tokens',\n                                type=int,\n                                default=1,\n                                help='The number of speculative tokens to generate per step')\n\n        return spec_group\n\n    @staticmethod\n    def distributed_executor_backend(parser):\n        \"\"\"Add argument distributed_executor_backend to parser.\"\"\"\n        return parser.add_argument('--distributed-executor-backend',\n                                   type=str,\n                                   default=None,\n                                   choices=['uni', 'mp', 'ray'],\n                                   help='The distributed executor backend for the PyTorch engine.')\n\n\n# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py\nclass FlexibleArgumentParser(argparse.ArgumentParser):\n    \"\"\"More flexible argument parser.\"\"\"\n\n    def parse_args(self, args=None, namespace=None):\n        # If args is not provided, use arguments from the command line\n        if args is None:\n            args = sys.argv[1:]\n\n        def repl(match: re.Match) -> str:\n            \"\"\"Replaces underscores with dashes in the matched string.\"\"\"\n            return match.group(0).replace('_', '-')\n\n        # Everything between the first -- and the first .\n        pattern = re.compile(r'(?<=--)[^\\.]*')\n\n        # Convert underscores to dashes in argument names, leaving any dotted suffix untouched\n        processed_args = []\n        for arg in args:\n            if arg.startswith('--'):\n                if '=' in arg:\n                    key, value = arg.split('=', 1)\n                    key = 
pattern.sub(repl, key, count=1)\n        
            processed_args.append(f'{key}={value}')\n                else:\n                    key = pattern.sub(repl, arg, count=1)\n                    processed_args.append(key)\n            elif arg.startswith('-O') and len(arg) > 2:\n                # allow -O flag to be used without space, e.g. -O3\n                processed_args.append('-O')\n                processed_args.append(arg[2:])\n            else:\n                processed_args.append(arg)\n\n        def _try_convert(value: str):\n            \"\"\"Try to parse a string as JSON (e.g. a number, bool or list); return it unchanged on failure.\"\"\"\n            if not isinstance(value, str):\n                return value\n            # Try to decode the value as JSON\n            try:\n                return json.loads(value)\n            except json.JSONDecodeError:\n                pass\n            return value\n\n        def create_nested_dict(keys: list[str], value: str):\n            \"\"\"Creates a nested dictionary from a list of keys and a value.\n\n            For example, `keys = [\"a\", \"b\", \"c\"]` and `value = 1` will create: `{\"a\": {\"b\": {\"c\": 1}}}`\n            \"\"\"\n            nested_dict: Any = _try_convert(value)\n            for key in reversed(keys):\n                nested_dict = {key: nested_dict}\n            return nested_dict\n\n        def recursive_dict_update(original: dict, update: dict):\n            \"\"\"Recursively updates a dictionary with another dictionary.\"\"\"\n            for k, v in update.items():\n                if isinstance(v, dict) and isinstance(original.get(k), dict):\n                    recursive_dict_update(original[k], v)\n                else:\n                    original[k] = v\n\n        delete = set()\n        dict_args: dict[str, dict] = defaultdict(dict)\n        for i, processed_arg in enumerate(processed_args):\n            if processed_arg.startswith('--') and '.' 
in processed_arg:\n                if '=' in processed_arg:\n                    processed_arg, value = processed_arg.split('=', 1)\n                    if '.' not in processed_arg:\n                        # False positive, . was only in the value\n                        continue\n                else:\n                    value = processed_args[i + 1]\n                    delete.add(i + 1)\n                key, *keys = processed_arg.split('.')\n                # Merge all values with the same key into a single dict\n                arg_dict = create_nested_dict(keys, value)\n                recursive_dict_update(dict_args[key], arg_dict)\n                delete.add(i)\n        # Drop the dotted args (and their separate values) consumed above\n        processed_args = [a for i, a in enumerate(processed_args) if i not in delete]\n        # Add the dict args back as if they were originally passed as JSON\n        for dict_arg, dict_value in dict_args.items():\n            processed_args.append(dict_arg)\n            processed_args.append(json.dumps(dict_value))\n\n        return super().parse_args(processed_args, namespace)\n"
  },
  {
    "path": "lmdeploy/lite/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .apis import *  # noqa: F401,F403\nfrom .quantization import *  # noqa: F401,F403\nfrom .utils import *  # noqa: F401,F403\n"
  },
  {
    "path": "lmdeploy/lite/apis/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/lite/apis/auto_awq.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport os\nimport os.path as osp\nimport shutil\nfrom typing import Literal\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, quant_weights, smooth_layers\nfrom lmdeploy.lite.utils import collect_target_modules\nfrom lmdeploy.utils import try_import_deeplink\n\nfrom .calibrate import LAYER_TYPE_MAP, calibrate\n\n\ndef save_vl_model(vl_model, model_path, dst_path):\n    vl_model.save_pretrained(dst_path, safe_serialization=True)\n    candidate = [\n        'preprocessor_config.json', 'processor_config.json', 'vit', 'generation_config.json', 'added_tokens.json'\n    ]\n    for name in candidate:\n        tmp_path = osp.join(model_path, name)\n        if osp.exists(tmp_path):\n            if osp.isfile(tmp_path):\n                shutil.copy(tmp_path, osp.join(dst_path, name))\n            elif osp.isdir(tmp_path):\n                shutil.copytree(tmp_path, osp.join(dst_path, name))\n    # AutoProcessor files\n    allfiles = os.listdir(model_path)\n    for file in allfiles:\n        if not file.endswith('.py'):\n            continue\n        copy_src = osp.join(model_path, file)\n        copy_dst = osp.join(dst_path, file)\n        if not osp.exists(copy_dst):\n            shutil.copyfile(copy_src, copy_dst)\n\n\ndef auto_awq(model: str,\n             work_dir: str = './work_dir',\n             calib_dataset: str = 'wikitext2',\n             calib_samples: int = 128,\n             batch_size: int = 1,\n             calib_seqlen: int = 2048,\n             w_bits: int = 4,\n             w_sym: bool = False,\n             w_group_size: int = 128,\n             search_scale: bool = False,\n             device: str = 'cuda',\n             revision: str = None,\n             dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',\n             download_dir: str = None):\n    \"\"\"Perform weight quantization using AWQ algorithm.\n\n    
Args:\n        model (str): The path of the model in Hugging Face format.\n        work_dir (str): The working directory to save results.\n        calib_dataset (str): The calibration dataset name.\n            Defaults to 'wikitext2'.\n        calib_samples (int): The number of samples for calibration.\n        batch_size (int): The batch size for running the calib samples.\n            Low GPU memory requires a small batch_size. A large batch_size\n            reduces calibration time at the cost of more VRAM.\n        calib_seqlen (int): The sequence length for calibration.\n        w_bits (int): Bit number for weight quantization.\n        w_sym (bool): Whether to do symmetric quantization.\n        w_group_size (int): Group size for weight quantization statistics.\n        search_scale (bool): Whether to search the scale ratio. Defaults to False,\n            which means only smooth quant with a 0.5 ratio is applied.\n        device (str): The device to run on.\n        revision (str): The specific model version to use. It can be a\n            branch name, a tag name, or a commit id. 
If unspecified,\n            the default version is used.\n        dtype (str): Data type for loading model weights and running calibration inference.\n        download_dir (str): Directory to download and load the weights,\n            defaults to the default cache directory of Hugging Face.\n    \"\"\"\n    try_import_deeplink(device)\n    if not osp.exists(model):\n        print(f'Cannot find model at local path {model}, '\n              'trying to download it from remote')\n        from lmdeploy.utils import get_model\n        model = get_model(model, revision=revision, download_dir=download_dir)\n    model_path = model\n    vl_model, model, tokenizer, work_dir = calibrate(model,\n                                                     calib_dataset,\n                                                     calib_samples,\n                                                     calib_seqlen,\n                                                     work_dir,\n                                                     device,\n                                                     w_bits=w_bits,\n                                                     w_group_size=w_group_size,\n                                                     search_scale=search_scale,\n                                                     dtype=dtype,\n                                                     batch_size=batch_size)\n\n    layer_type = LAYER_TYPE_MAP[type(model).__name__]\n    fc2fcs = FC_FCS_MAP[layer_type]\n    norm2fcs = NORM_FCS_MAP[layer_type]\n    input_stats = torch.load(osp.join(work_dir, 'inputs_stats.pth'), weights_only=True)\n    layers = collect_target_modules(model, layer_type)\n    fcs = {}\n    for l_name, layer in layers.items():\n        name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)\n        fcs.update(name2fc)\n\n    if search_scale:\n        awq_ratios = input_stats['ratios']\n        act_scales = input_stats['absmean']\n        awq_layers(layers, fc2fcs, norm2fcs, act_scales, 
awq_ratios, w_group_size, device)\n    else:\n        act_scales = input_stats['absmax']\n        smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size, device)\n    quant_weights(model, fcs, w_bits, w_sym, w_group_size, device)\n    quantization_config = dict(quant_method='awq',\n                               version='gemm',\n                               bits=w_bits,\n                               group_size=w_group_size,\n                               zero_point=not w_sym)\n    model.config.update(dict(quantization_config=quantization_config))\n\n    if vl_model:\n        save_vl_model(vl_model, model_path, work_dir)\n    else:\n        model.save_pretrained(work_dir, safe_serialization=True)\n    tokenizer.save_pretrained(work_dir)\n\n\nif __name__ == '__main__':\n    import fire\n\n    fire.Fire(auto_awq)\n"
  },
  {
    "path": "lmdeploy/lite/apis/calibrate.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom pathlib import Path\nfrom typing import Literal, Union\n\nimport torch\nfrom torch import nn\nfrom transformers import AutoTokenizer\n\nfrom lmdeploy.archs import get_task\nfrom lmdeploy.lite.quantization import CalibrationContext, CalibrationContextV2\nfrom lmdeploy.lite.utils import collect_target_modules, get_calib_loaders, load_hf_from_pretrained\nfrom lmdeploy.vl.model.builder import load_vl_model\n\nLAYER_TYPE_MAP = {\n    'InternLMForCausalLM': 'InternLMDecoderLayer',\n    'InternLM2ForCausalLM': 'InternLM2DecoderLayer',\n    'InternLM3ForCausalLM': 'InternLM3DecoderLayer',\n    'QWenLMHeadModel': 'QWenBlock',\n    'Qwen2ForCausalLM': 'Qwen2DecoderLayer',\n    'Qwen3ForCausalLM': 'Qwen3DecoderLayer',\n    'BaiChuanForCausalLM': 'DecoderLayer',  # Baichuan 7B\n    'BaichuanForCausalLM': 'DecoderLayer',  # Baichuan2 7B\n    'LlamaForCausalLM': 'LlamaDecoderLayer',\n    'LlavaLlamaForCausalLM': 'LlamaDecoderLayer',\n    'MGMLlamaForCausalLM': 'LlamaDecoderLayer',  # mini gemini\n    'InternLMXComposer2ForCausalLM': 'InternLM2DecoderLayer',\n    'Phi3ForCausalLM': 'Phi3DecoderLayer',\n    'ChatGLMForConditionalGeneration': 'GLMBlock',\n    'MixtralForCausalLM': 'MixtralDecoderLayer',\n    'Qwen2VLForConditionalGeneration': 'Qwen2VLDecoderLayer',\n    'Qwen2_5_VLForConditionalGeneration': 'Qwen2_5_VLDecoderLayer',\n    'MistralForCausalLM': 'MistralDecoderLayer',\n}\n\nNORM_TYPE_MAP = {\n    'InternLMForCausalLM': 'InternLMRMSNorm',\n    'InternLM2ForCausalLM': 'InternLM2RMSNorm',\n    'InternLM3ForCausalLM': 'InternLM3RMSNorm',\n    'QWenLMHeadModel': 'RMSNorm',\n    'Qwen2ForCausalLM': 'Qwen2RMSNorm',\n    'Qwen3ForCausalLM': 'Qwen3RMSNorm',\n    'BaiChuanForCausalLM': 'RMSNorm',  # Baichuan 7B\n    'BaichuanForCausalLM': 'RMSNorm',  # Baichuan2 7B\n    'LlamaForCausalLM': 'LlamaRMSNorm',\n    'LlavaLlamaForCausalLM': 'LlamaRMSNorm',\n    'MGMLlamaForCausalLM': 'LlamaRMSNorm',  # mini 
gemini\n    'InternLMXComposer2ForCausalLM': 'InternLM2RMSNorm',\n    'Phi3ForCausalLM': 'Phi3RMSNorm',\n    'ChatGLMForConditionalGeneration': 'RMSNorm',\n    'MixtralForCausalLM': 'MixtralRMSNorm',\n    'Qwen2VLForConditionalGeneration': 'Qwen2RMSNorm',\n    'Qwen2_5_VLForConditionalGeneration': 'Qwen2RMSNorm',\n    'MistralForCausalLM': 'MistralRMSNorm',\n}\n\nHEAD_NAME_MAP = {\n    'InternLMForCausalLM': 'lm_head',\n    'InternLM2ForCausalLM': 'output',\n    'InternLM3ForCausalLM': 'output',\n    'QWenLMHeadModel': 'lm_head',\n    'Qwen2ForCausalLM': 'lm_head',\n    'Qwen3ForCausalLM': 'lm_head',\n    'BaiChuanForCausalLM': 'lm_head',  # Baichuan 7B\n    'BaichuanForCausalLM': 'lm_head',  # Baichuan2 7B\n    'LlamaForCausalLM': 'lm_head',\n    'LlavaLlamaForCausalLM': 'lm_head',\n    'MGMLlamaForCausalLM': 'lm_head',  # mini gemini\n    'InternLMXComposer2ForCausalLM': 'output',\n    'Phi3ForCausalLM': 'lm_head',\n    'ChatGLMForConditionalGeneration': 'output_layer',\n    'MixtralForCausalLM': 'lm_head',\n    'Qwen2VLForConditionalGeneration': 'lm_head',\n    'Qwen2_5_VLForConditionalGeneration': 'lm_head',\n    'MistralForCausalLM': 'lm_head',\n}\n\n\ndef _prepare_for_calibrate(model: nn.Module,\n                           layer_type: Union[str, type],\n                           head_name: str = 'lm_head',\n                           device: str = 'cuda',\n                           prefix: str = '') -> None:\n    \"\"\"Prepare the model for calibration by moving specific modules to CPU.\n\n    This function goes through each child of a given model and checks whether\n    it is an instance of a certain layer type or has the name equal to\n    `head_name`.\n    If yes, it moves the module to CPU, otherwise to the specified device\n    (default is CUDA).\n\n    If the child contains the target layer type in its sub-modules, the\n    function performs the same operation recursively.\n\n    Parameters\n    ----------\n    model : nn.Module\n        The PyTorch 
model to prepare for calibration.\n    layer_type : Union[str, Type]\n        The type of the layer to be moved to CPU. Can be either a string of\n        class name or the class type itself.\n    head_name : str, optional\n        The name of the module to be moved to CPU. Default is 'lm_head'.\n    device : str, optional\n        The device to which modules not matching the `layer_type` or\n        `head_name` will be moved. Default is 'cuda'.\n    prefix : str, optional\n        The prefix used when printing the names of the moved modules.\n        Default is ''.\n\n    Raises\n    ------\n    TypeError\n        If `layer_type` is neither a string nor a type.\n    \"\"\"\n\n    for name, child in model.named_children():\n\n        # Check if the child is an instance of the given layer type\n        if isinstance(layer_type, str):\n            is_layer = type(child).__name__ == layer_type\n        elif isinstance(layer_type, type):\n            is_layer = isinstance(child, layer_type)\n        else:\n            raise TypeError('layer_type should be a string (class name) or a type')\n\n        # Check if the child contains the target module type\n        contain_layer = len(collect_target_modules(child, layer_type, [head_name]).keys()) > 0\n\n        # Check if the child matches the head name\n        is_head = name == head_name\n        # skip moving head layer to CPU when tie_word_embeddings is True\n        is_head = is_head and not getattr(model.config, 'tie_word_embeddings', False)\n\n        mod_name = f'{prefix}.{name}' if prefix else name\n\n        # If the child is either an instance of the layer type or has the\n        # head name, move it to CPU, otherwise move it to the specified device\n        if is_layer or is_head:\n            child.to('cpu')\n            print(f'Move {mod_name} to CPU.')\n        elif contain_layer:\n            _prepare_for_calibrate(child, layer_type, head_name, device, mod_name)\n        else:\n            
child.to(device)\n            print(f'Move {mod_name} to GPU.')\n\n\n# TODO to be removed\ndef make_compatible_internvl_config(model_path):\n    \"\"\"Patch model.config since after transformers v4.45.0, InternVL models\n    can't use `save_pretrained`\"\"\"\n    from lmdeploy.archs import get_model_arch\n    arch, _ = get_model_arch(model_path)\n    if arch == 'InternVLChatModel':\n        import transformers\n        from packaging import version\n        if version.parse(transformers.__version__) >= version.parse('4.45.0'):\n\n            def _get_non_default_generation_parameters(self):\n                return {}\n\n            from transformers import PretrainedConfig\n            PretrainedConfig._get_non_default_generation_parameters = _get_non_default_generation_parameters  # noqa\n\n\ndef update_moe_mapping(model, model_type):\n    \"\"\"Update moe mapping.\"\"\"\n    from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP\n\n    # get experts num\n    num_experts = 0\n    for n, m in model.named_modules():\n        if type(m).__name__ == LAYER_TYPE_MAP[model_type]:\n            fc2fcs = FC_FCS_MAP[LAYER_TYPE_MAP[model_type]]\n            for k, v in fc2fcs.items():\n                if '{i}' in k:\n                    break\n            num_experts = len(m.get_submodule(k.split('.{i}')[0]))\n            break\n\n    # update FC_FCS_MAP\n    updated_fc2fcs = dict()\n    for prev_fc, post_fc in fc2fcs.items():\n        if '{i}' in prev_fc:\n            for i in range(num_experts):\n                updated_fc2fcs.update({prev_fc.format(i=i): [v.format(i=i) for v in post_fc]})\n        else:\n            updated_fc2fcs.update({prev_fc: post_fc})\n    FC_FCS_MAP[LAYER_TYPE_MAP[model_type]] = updated_fc2fcs\n    # update NORM_FCS_MAP\n    norm2fcs = NORM_FCS_MAP[LAYER_TYPE_MAP[model_type]]\n    updated_norm2fcs = dict()\n    for norm, fc in norm2fcs.items():\n        updated_norm2fcs.update({norm: list(set([v.format(i=i) for v in fc for i in 
range(num_experts)]))})\n    NORM_FCS_MAP[LAYER_TYPE_MAP[model_type]] = updated_norm2fcs\n\n\ndef calibrate(model: str,\n              calib_dataset: str = 'wikitext2',\n              calib_samples: int = 128,\n              calib_seqlen: int = 2048,\n              work_dir: str = './work_dir',\n              device: str = 'cuda',\n              w_bits: int = 4,\n              w_group_size: int = 128,\n              search_scale: bool = False,\n              dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',\n              batch_size: int = 1) -> tuple:\n    \"\"\"The main function for loading the model and performing calibration on a\n    given dataset.\n\n    Args:\n        model (str): The name or path of the model to be loaded.\n        calib_dataset (str, optional): The calibration dataset name.\n            Defaults to 'wikitext2'.\n        calib_samples (int, optional): The number of samples for calibration.\n            Defaults to 128.\n        calib_seqlen (int, optional): The sequence length for calibration.\n            Defaults to 2048.\n        work_dir (str): The working directory for outputs.\n            Defaults to './work_dir'.\n        device (str, optional): The device to be used for calculation.\n            Defaults to 'cuda'.\n        w_bits (int): Bit number for weight quantization.\n        w_group_size (int): Group size for weight quantization statistics.\n        search_scale (bool): Whether to search the scale ratio. Defaults to False,\n            which means only smooth quant with a 0.5 ratio is applied.\n        dtype (str): Data type for loading model weights and running calibration inference.\n        batch_size (int): The batch size for running the calib samples.\n            Low GPU memory requires a small batch_size. 
A large batch_size\n            reduces calibration time at the cost of more VRAM.\n\n    Returns:\n        vl_model (nn.Module | None): The loaded vision-language model, or None\n            for a language-only model.\n        model (nn.Module): The loaded Hugging Face language model.\n        tokenizer: The loaded Hugging Face tokenizer.\n        work_dir (Path): The working directory for outputs.\n    \"\"\"\n\n    assert calib_dataset in ['wikitext2', 'c4', 'pileval',\n                             'gsm8k', 'neuralmagic_calibration', 'open-platypus', 'openwebtext'], \\\n        'Only `wikitext2`, `c4`, `pileval`, `gsm8k`, ' \\\n        '`neuralmagic_calibration`, `open-platypus` and `openwebtext` are supported.'\n\n    model_type, _ = get_task(backend='turbomind', model_path=model)\n    make_compatible_internvl_config(model)\n\n    # Load tokenizer and configuration\n    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)\n\n    if model_type == 'llm':\n        model = load_hf_from_pretrained(model, dtype=dtype, trust_remote_code=True)\n        vl_model = None\n    elif model_type == 'vlm':\n        vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model\n        model = vl_model\n        if hasattr(vl_model, 'language_model'):  # deepseek-vl, ...\n            model = vl_model.language_model\n        if hasattr(vl_model, 'llm'):  # MiniCPMV, ...\n            model = vl_model.llm\n        model.config.use_cache = False\n        if dtype == 'float16':\n            model.half()\n        elif dtype == 'bfloat16':\n            assert torch.cuda.is_bf16_supported(\n            ), 'Your device does not support bfloat16, please set --dtype float16'  # noqa\n            model.to(torch.bfloat16)\n        elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16:\n            print('Warning: we cast model to float16 to prevent OOM. 
You'\n                  ' may enforce bfloat16 with `--dtype bfloat16`')\n            model.half()\n        model.eval()\n\n    model_type = type(model).__name__\n    if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:\n        raise RuntimeError(f'Currently, quantization and calibration of {model_type} are '\n                           f'not supported. The supported model types are '\n                           f\"{', '.join(LAYER_TYPE_MAP.keys())}.\")\n\n    if model_type in ['MixtralForCausalLM']:\n        update_moe_mapping(model, model_type)\n\n    if model_type == 'QWenLMHeadModel':\n        try:\n            import flash_attn  # noqa: F401\n        except ImportError:\n            raise RuntimeError('When using Qwen, you need to `pip install flash-attn` first, '\n                               'otherwise calibration and quantization will not work '\n                               'properly.')\n\n    layer_type = LAYER_TYPE_MAP[type(model).__name__]\n    norm_type = NORM_TYPE_MAP[type(model).__name__]\n\n    _prepare_for_calibrate(model, layer_type, HEAD_NAME_MAP[type(model).__name__], device)\n\n    print('Loading calibration dataset ...')\n    calib_loader = get_calib_loaders(calib_dataset, tokenizer, nsamples=calib_samples, seqlen=calib_seqlen)\n\n    # Initialize calibration context\n    if search_scale:\n        calib_ctx = CalibrationContextV2(model,\n                                         tokenizer,\n                                         layer_type=layer_type,\n                                         norm_type=norm_type,\n                                         device=device,\n                                         w_bits=w_bits,\n                                         w_group_size=w_group_size,\n                                         batch_size=batch_size,\n                                         search_scale=search_scale)\n    else:\n        calib_ctx = CalibrationContext(model,\n                                   
    tokenizer,\n                                       layer_type=layer_type,\n                                       norm_type=norm_type,\n                                       batch_size=batch_size,\n                                       device=device)\n\n    with calib_ctx:\n        all_data = torch.cat(calib_loader).to(device)\n        calib_ctx.calibrate(all_data)\n\n    # Create the work directory if it does not exist\n    work_dir = Path(work_dir)\n    work_dir.mkdir(parents=True, exist_ok=True)\n    calib_ctx.export(work_dir)\n\n    return vl_model, model, tokenizer, work_dir\n\n\nif __name__ == '__main__':\n    import fire\n\n    fire.Fire(calibrate)\n"
  },
  {
    "path": "lmdeploy/lite/apis/get_small_sharded_hf.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport argparse\nimport copy\nimport json\nimport os\nimport shutil\n\nimport torch\nfrom mmengine.utils import mkdir_or_exist\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Convert a hugging face model to the smallest sharded one')\n    parser.add_argument('src_dir', help='the directory of the model')\n    parser.add_argument('dst_dir', help='the directory to save the new model')\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    args = parse_args()\n    mkdir_or_exist(args.dst_dir)\n\n    all_files = os.listdir(args.src_dir)\n    for name in all_files:\n        if not name.startswith(('pytorch_model', '.')):\n            src_path = os.path.join(args.src_dir, name)\n            dst_path = os.path.join(args.dst_dir, name)\n            shutil.copy(src_path, dst_path)\n\n    with open(os.path.join(args.src_dir, 'pytorch_model.bin.index.json')) as f:\n        index = json.load(f)\n\n    n_shard = len(index['weight_map'])\n    new_index = copy.deepcopy(index)\n    new_index['weight_map'] = {}\n    cnt = 1\n\n    checkpoints = set(index['weight_map'].values())\n    for ckpt in checkpoints:\n        state_dict = torch.load(os.path.join(args.src_dir, ckpt), map_location='cuda', weights_only=True)\n        keys = sorted(list(state_dict.keys()))\n        for k in keys:\n            new_state_dict_name = 'pytorch_model-{:05d}-of-{:05d}.bin'.format(cnt, n_shard)\n            new_index['weight_map'][k] = new_state_dict_name\n            new_state_dict = {k: state_dict[k]}\n            torch.save(new_state_dict, os.path.join(args.dst_dir, new_state_dict_name))\n            cnt += 1\n        del state_dict\n        torch.cuda.empty_cache()\n    with open(os.path.join(args.dst_dir, 'pytorch_model.bin.index.json'), 'w') as f:\n        json.dump(new_index, f)\n    assert new_index['weight_map'].keys() == index['weight_map'].keys(), 'Mismatch on `weight_map`!'\n\n\nif 
__name__ == '__main__':\n    main()\n"
  },
  {
    "path": "lmdeploy/lite/apis/gptq.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport logging\nfrom typing import Literal\n\nimport torch\nfrom transformers import AutoConfig, AutoTokenizer\n\nfrom lmdeploy.lite.utils.calib_dataloader import get_calib_loaders\n\n\ndef auto_gptq(model: str,\n              work_dir: str = './work_dir',\n              w_bits: int = 4,\n              w_group_size: int = 128,\n              calib_dataset: str = 'wikitext2',\n              calib_samples: int = 128,\n              calib_seqlen: int = 2048,\n              batch_size: int = 1,\n              dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',\n              revision: str = None):\n    \"\"\"Perform weight quantization using the GPTQ algorithm.\n\n    Args:\n        model (str): The path of the model in HF format.\n        work_dir (str): The working directory to save results.\n        calib_dataset (str): The calibration dataset name.\n            Defaults to 'wikitext2'.\n        calib_samples (int): The number of samples for calibration.\n        batch_size (int): The batch size for running the calib samples.\n            Low GPU mem requires small batch_size. Large batch_size\n            reduces the calibration time but costs more VRAM.\n        calib_seqlen (int): The sequence length for calibration.\n        w_bits (int): Bit number for weight quantization.\n        w_group_size (int): Group size for weight quantization statistics.\n        dtype (str): Data type for loading model weights and running\n            calibration inference.\n        revision (str): The specific model version to use. It can be a\n            branch name, a tag name, or a commit id. If unspecified,\n            the default version is used.\n    \"\"\"\n    try:\n        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig\n    except ImportError:\n        raise ImportError('To use auto_gptq, please install it with '\n                          '`pip install auto-gptq`')\n    logging.basicConfig(\n        format='%(asctime)s %(levelname)s [%(name)s] %(message)s',\n        level=logging.INFO,\n        datefmt='%Y-%m-%d %H:%M:%S',\n    )\n    # register internlm2 and internlm3 with auto_gptq\n    from auto_gptq.modeling import GPTQ_CAUSAL_LM_MODEL_MAP\n    from auto_gptq.modeling._const import SUPPORTED_MODELS\n\n    from ..modeling.internlm2_gptq import InternLM2GPTQForCausalLM\n    from ..modeling.internlm3_gptq import InternLM3GPTQForCausalLM\n    SUPPORTED_MODELS.append('internlm2')\n    GPTQ_CAUSAL_LM_MODEL_MAP.update(dict(internlm2=InternLM2GPTQForCausalLM))\n    SUPPORTED_MODELS.append('internlm3')\n    GPTQ_CAUSAL_LM_MODEL_MAP.update(dict(internlm3=InternLM3GPTQForCausalLM))\n\n    pretrained_model_dir = model\n    quantized_model_dir = work_dir\n\n    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, trust_remote_code=True)\n    print('Loading calibration dataset ...')\n    calib_loader = get_calib_loaders(calib_dataset, tokenizer, nsamples=calib_samples, seqlen=calib_seqlen)\n    attention_mask = [1] * calib_seqlen\n    examples = [dict(input_ids=data.flatten().tolist(), attention_mask=attention_mask) for data in calib_loader]\n\n    quantize_config = BaseQuantizeConfig(\n        bits=w_bits,  # weight bit width, e.g. 4\n        group_size=w_group_size,  # it is recommended to set the value to 128\n        desc_act=False,  # lmdeploy only supports False\n        sym=True,  # lmdeploy only supports True\n    )\n\n    # load the un-quantized model; auto-gptq loads it into CPU memory\n    # by default, so it is moved to the GPU explicitly below\n    hf_config = AutoConfig.from_pretrained(pretrained_model_dir, revision=revision, trust_remote_code=True)\n    torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)\n    if dtype == 'float16':\n        torch_dtype = torch.float16\n    elif dtype == 'bfloat16':\n        torch_dtype = torch.bfloat16\n    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir,\n                                                quantize_config,\n                                                revision=revision,\n                                                torch_dtype=torch_dtype,\n                                                trust_remote_code=True).cuda()\n\n    # quantize the model; `examples` must be a list of dicts whose only\n    # keys are \"input_ids\" and \"attention_mask\"\n    model.quantize(examples, batch_size=batch_size)\n\n    # save the quantized model\n    model.save_quantized(quantized_model_dir)\n\n    tokenizer.save_pretrained(quantized_model_dir)\n\n\nif __name__ == '__main__':\n    import fire\n\n    fire.Fire(auto_gptq)\n"
  },
  {
    "path": "lmdeploy/lite/apis/smooth_quant.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport os.path as osp\nfrom typing import Literal\n\nimport fire\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.lite.apis.calibrate import LAYER_TYPE_MAP, NORM_TYPE_MAP, calibrate\nfrom lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, skipped_module, smooth_layers\nfrom lmdeploy.lite.utils import collect_target_modules\nfrom lmdeploy.pytorch.models import QLinear, QRMSNorm\nfrom lmdeploy.utils import try_import_deeplink\n\n\ndef smooth_quant(model: str,\n                 work_dir: str = './work_dir',\n                 calib_dataset: str = 'wikitext2',\n                 calib_samples: int = 128,\n                 calib_seqlen: int = 2048,\n                 search_scale: bool = False,\n                 batch_size: int = 1,\n                 w_bits: int = 8,\n                 dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',\n                 device: str = 'cuda',\n                 quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn', 'float8_e5m2'] = 'int8',\n                 revision: str = None,\n                 download_dir: str = None):\n    try_import_deeplink(device)\n    if quant_dtype == 'fp8':\n        quant_dtype = 'float8_e4m3fn'\n\n    quant_dtype = getattr(torch, quant_dtype, torch.int8)\n    if quant_dtype.is_floating_point:\n        q_dtype_info = torch.finfo(quant_dtype)\n    else:\n        q_dtype_info = torch.iinfo(quant_dtype)\n\n    assert q_dtype_info.bits == w_bits, f'w_bits ({w_bits}) must match the bit width of {quant_dtype}'\n    if not osp.exists(model):\n        print(f'Cannot find model at local path {model}, '\n              'trying to download it from remote.')\n        from lmdeploy.utils import get_model\n        model = get_model(model, revision=revision, download_dir=download_dir)\n    model_path = model\n    vl_model, model, tokenizer, work_dir = calibrate(model,\n                                                     calib_dataset,\n                                                     calib_samples,\n                                                     calib_seqlen,\n                                                     work_dir,\n                                                     device,\n                                                     w_bits=w_bits,\n                                                     w_group_size=-1,\n                                                     search_scale=search_scale,\n                                                     dtype=dtype,\n                                                     batch_size=batch_size)\n\n    # calibrate function exports the calibration statistics\n    # (inputs, outputs, keys and values) to `work_dir`.\n    inp_stats = torch.load(work_dir / 'inputs_stats.pth', weights_only=True)\n    act_scales = inp_stats['absmax']\n\n    model_type = type(model).__name__\n    if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:\n        raise RuntimeError(f'Currently, quantization and calibration of {model_type} are '\n                           f'not supported. The supported model types are '\n                           f\"{', '.join(LAYER_TYPE_MAP.keys())}.\")\n\n    if model_type == 'QWenLMHeadModel':\n        try:\n            import flash_attn  # noqa: F401\n        except ImportError:\n            raise RuntimeError('When using Qwen, you need to `pip install flash-attn` first, '\n                               'otherwise calibration and quantization will not work '\n                               'properly.')\n\n    layer_type = LAYER_TYPE_MAP[type(model).__name__]\n    norm_type = NORM_TYPE_MAP[type(model).__name__]\n    fc2fcs = FC_FCS_MAP[layer_type]\n    norm2fcs = NORM_FCS_MAP[layer_type]\n\n    layers = collect_target_modules(model, layer_type)\n    fcs = {}\n    for l_name, layer in layers.items():\n        name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)\n        fcs.update(name2fc)\n\n    if search_scale:\n        awq_ratios = inp_stats['ratios']\n        act_scales = inp_stats['absmean']\n        awq_layers(layers, fc2fcs, norm2fcs, act_scales, awq_ratios, -1, device)\n    else:\n        smooth_layers(layers, fc2fcs, norm2fcs, act_scales, -1, device)\n\n    rmsnorms = collect_target_modules(model, norm_type)\n\n    for name, linear in fcs.items():\n        if skipped_module(name):\n            continue\n        linear.to(device)\n        q_linear = QLinear.from_float(linear, quant_dtype=quant_dtype)\n        parent_name, _, child_name = name.rpartition('.')\n        parent = model.get_submodule(parent_name)\n        setattr(parent, child_name, q_linear)\n        linear.to('cpu')\n        q_linear.to('cpu')\n        torch.cuda.empty_cache()\n\n    for name, norm in rmsnorms.items():\n        if skipped_module(name):\n            continue\n        norm.to(device)\n        q_norm = QRMSNorm.from_float(norm, quant_dtype=quant_dtype)\n        parent_name, _, child_name = name.rpartition('.')\n        parent = model.get_submodule(parent_name)\n        setattr(parent, child_name, q_norm)\n        norm.to('cpu')\n        q_norm.to('cpu')\n        torch.cuda.empty_cache()\n\n    quant_dtype_s = str(quant_dtype).split('.')[1]\n    model.config.update(dict(quantization_config=dict(quant_method='smooth_quant', quant_dtype=f'{quant_dtype_s}')))\n\n    if vl_model:\n        from .auto_awq import save_vl_model\n        save_vl_model(vl_model, model_path, work_dir)\n    else:\n        model.save_pretrained(work_dir, safe_serialization=True)\n    tokenizer.save_pretrained(work_dir)\n\n\nif __name__ == '__main__':\n    fire.Fire(smooth_quant)\n"
  },
  {
    "path": "lmdeploy/lite/defaults.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom torch import nn\n\nOFFLOAD_MOD = (nn.Linear, )\nKV_CACHE_SIGNATURE = 'past_key_value'\n"
  },
  {
    "path": "lmdeploy/lite/modeling/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/lite/modeling/internlm2_gptq.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom auto_gptq.modeling import BaseGPTQForCausalLM\n\n\nclass InternLM2GPTQForCausalLM(BaseGPTQForCausalLM):\n    layer_type = 'InternLM2DecoderLayer'\n    layers_block_name = 'model.layers'\n    outside_layer_modules = ['model.tok_embeddings', 'model.norm']\n    inside_layer_modules = [\n        ['attention.wqkv'],\n        ['attention.wo'],\n        ['feed_forward.w3', 'feed_forward.w1'],\n        ['feed_forward.w2'],\n    ]\n"
  },
  {
    "path": "lmdeploy/lite/modeling/internlm3_gptq.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom auto_gptq.modeling import BaseGPTQForCausalLM\n\n\nclass InternLM3GPTQForCausalLM(BaseGPTQForCausalLM):\n    layer_type = 'InternLM3DecoderLayer'\n    layers_block_name = 'model.layers'\n    outside_layer_modules = ['model.embed_tokens', 'model.norm']\n    inside_layer_modules = [\n        ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],\n        ['self_attn.o_proj'],\n        ['mlp.up_proj', 'mlp.gate_proj'],\n        ['mlp.down_proj'],\n    ]\n"
  },
  {
    "path": "lmdeploy/lite/quantization/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .activation import ActivationObserver, KVCacheObserver\nfrom .calibration import CalibrationContext, CalibrationContextV2\nfrom .weight import WeightQuantizer\n\n__all__ = ['WeightQuantizer', 'ActivationObserver', 'KVCacheObserver', 'CalibrationContext', 'CalibrationContextV2']\n"
  },
  {
    "path": "lmdeploy/lite/quantization/activation/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .observer import ActivationObserver, KVCacheObserver\n\n__all__ = ['ActivationObserver', 'KVCacheObserver']\n"
  },
  {
    "path": "lmdeploy/lite/quantization/activation/observer.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport torch\n\nfrom lmdeploy.lite.utils.global_avail import GlobalAvailMixin\n\n\nclass KVCacheObserver(GlobalAvailMixin):\n    \"\"\"A class to observe and record the max, min, and absolute max values of a\n    given tensor.\"\"\"\n\n    def __init__(self, num_head: int, head_dim: int) -> None:\n        \"\"\"Constructor for KVCacheObserver.\n\n        Args:\n            num_head : Number of heads\n            head_dim : Dimension of each head\n        \"\"\"\n        self.num_head = num_head\n        self.head_dim = head_dim\n        self.max_val = torch.full((num_head, head_dim), -torch.inf, dtype=torch.float16)\n        self.min_val = torch.full((num_head, head_dim), torch.inf, dtype=torch.float16)\n        self.absmax_val = torch.full((num_head, head_dim), 0, dtype=torch.float16)\n\n    @torch.no_grad()\n    def observe(self, x: torch.Tensor) -> None:\n        \"\"\"Function to observe the input tensor and update the max, min, and\n        absolute max values.\n\n        Args:\n            x : Input tensor\n        \"\"\"\n        assert len(x.shape) == 4\n\n        if x.size(2) == self.num_head and x.size(3) == self.head_dim:\n            # layout: (bs, seqlen, heads, dims); already the expected layout\n            pass\n        elif x.size(1) == self.num_head and x.size(3) == self.head_dim:\n            # layout: (bs, heads, seqlen, dims)\n            x = x.transpose(1, 2)\n        else:\n            raise RuntimeError(f'Unexpected tensor shape {tuple(x.shape)}: expected num_head={self.num_head} '\n                               f'and head_dim={self.head_dim}.')\n\n        cur_max = x.flatten(0, 1).max(0)[0].cpu()\n        cur_min = x.flatten(0, 1).min(0)[0].cpu()\n        cur_absmax = x.flatten(0, 1).abs().max(0)[0].cpu()\n\n        self.max_val = torch.maximum(self.max_val, cur_max)\n        self.min_val = torch.minimum(self.min_val, cur_min)\n        self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)\n\n\nclass ActivationObserver(GlobalAvailMixin):\n    \"\"\"A class to observe and record the max, min, mean, absolute max, and\n    absolute mean values of a given tensor.\n\n    Also keeps track of the number of batches observed.\n    \"\"\"\n    observed = False\n\n    def __init__(self, dim: int) -> None:\n        \"\"\"Constructor for ActivationObserver.\n\n        Args:\n            dim : Dimension of the tensor\n        \"\"\"\n        self.dim = dim\n        self.max_val = torch.full((dim, ), -torch.inf, dtype=torch.float16)\n        self.min_val = torch.full((dim, ), torch.inf, dtype=torch.float16)\n        self.absmax_val = torch.full((dim, ), 0, dtype=torch.float16)\n        self.absmean_val = torch.full((dim, ), 0, dtype=torch.float16)\n        self.mean_val = torch.full((dim, ), 0, dtype=torch.float16)\n        self.num_batches_tracked = 0\n        self.value = None\n        self.ratio = None\n        self.num_ratio_tracked = 0\n\n    @classmethod\n    def disable(cls):\n        \"\"\"Disable observation for all instances to avoid recomputation\n        during the scale search.\"\"\"\n        cls.observed = True\n\n    @classmethod\n    def enable(cls):\n        \"\"\"Re-enable observation for all instances after the scale search.\"\"\"\n        cls.observed = False\n\n    @torch.no_grad()\n    def observe(self, x: torch.Tensor, save_input: bool = False) -> None:\n        \"\"\"Function to observe the input tensor and update the max, min, mean,\n        absolute max, absolute mean values and number of batches tracked.\n\n        Args:\n            x : Input tensor\n            save_input : Whether to keep a reference to the input in `self.value`\n        \"\"\"\n        assert torch.isnan(x).sum() == 0\n        if self.observed:\n            return\n        assert x.size(-1) == self.dim\n        cur_val = x.flatten(0, 1)\n        if any([s == 0 for s in cur_val.shape]):\n            return\n        cur_max = cur_val.max(0)[0].cpu()\n        cur_min = cur_val.min(0)[0].cpu()\n        cur_mean = cur_val.mean(0).cpu()\n\n        cur_abs = cur_val.abs()\n        cur_absmax = cur_abs.max(0)[0].cpu()\n        cur_absmean = cur_abs.mean(0).cpu()\n\n        self.max_val = torch.maximum(self.max_val, cur_max)\n        self.min_val = torch.minimum(self.min_val, cur_min)\n        self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)\n        if save_input:\n            self.value = x\n\n        # Update the running mean and absmean over all batches observed so far\n        self.mean_val = ((self.mean_val * self.num_batches_tracked + cur_mean) / (self.num_batches_tracked + 1))\n        self.absmean_val = ((self.absmean_val * self.num_batches_tracked + cur_absmean) /\n                            (self.num_batches_tracked + 1))\n\n        # Increment the count of batches tracked\n        self.num_batches_tracked += 1\n\n    @torch.no_grad()\n    def save_ratio(self, ratio: float) -> None:\n        \"\"\"Update the running mean of the searched smoothing ratio.\"\"\"\n        if self.ratio is None:\n            self.ratio = 0\n        self.ratio = (self.ratio * self.num_ratio_tracked + ratio) / (self.num_ratio_tracked + 1)\n        self.num_ratio_tracked += 1\n"
  },
  {
    "path": "lmdeploy/lite/quantization/awq.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List\n\nimport torch\n\n# Maps that describe the structure of your model.\nNORM_FCS_MAP = {\n    'LlamaDecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n    'InternLMDecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n    'InternLM2DecoderLayer': {\n        'attention_norm': ['attention.wqkv'],\n        'ffn_norm': ['feed_forward.w1', 'feed_forward.w3']\n    },\n    'InternLM3DecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n    'QWenBlock': {\n        'ln_1': ['attn.c_attn'],\n        'ln_2': ['mlp.w1', 'mlp.w2']\n    },\n    'Qwen2DecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n    'Qwen3DecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n    'DecoderLayer': {\n        'input_layernorm': ['self_attn.W_pack'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n    'Phi3DecoderLayer': {\n        'input_layernorm': ['self_attn.qkv_proj'],\n        'post_attention_layernorm': ['mlp.gate_up_proj']\n    },\n    'GLMBlock': {\n        'input_layernorm': ['self_attention.query_key_value'],\n        'post_attention_layernorm': ['mlp.dense_h_to_4h']\n    },\n    'MixtralDecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        
'post_attention_layernorm':\n        ['block_sparse_moe.gate', 'block_sparse_moe.experts.{i}.w1', 'block_sparse_moe.experts.{i}.w3']\n    },\n    'Qwen2VLDecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n    'Qwen2_5_VLDecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n    'MistralDecoderLayer': {\n        'input_layernorm': ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],\n        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']\n    },\n}\n\nFC_FCS_MAP = {\n    'LlamaDecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    },\n    'InternLMDecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    },\n    'InternLM2DecoderLayer': {\n        'feed_forward.w3': ['feed_forward.w2']\n    },\n    'InternLM3DecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    },\n    'QWenBlock': {\n        'attn.c_attn': ['attn.c_proj'],\n        'mlp.w1': ['mlp.c_proj']\n    },\n    'Qwen2DecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    },\n    'Qwen3DecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    },\n    'DecoderLayer': {\n        'self_attn.W_pack': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    },\n    'Phi3DecoderLayer': {\n        'self_attn.qkv_proj': ['self_attn.o_proj'],\n        'mlp.gate_up_proj': ['mlp.down_proj']\n    },\n    'GLMBlock': {\n        # 'self_attention.query_key_value': ['self_attention.dense']\n        # 'mlp.dense_h_to_4h': 
['mlp.dense_4h_to_h']\n    },\n    'MixtralDecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'block_sparse_moe.experts.{i}.w3': ['block_sparse_moe.experts.{i}.w2']\n    },\n    'Qwen2VLDecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    },\n    'Qwen2_5_VLDecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    },\n    'MistralDecoderLayer': {\n        'self_attn.v_proj': ['self_attn.o_proj'],\n        'mlp.up_proj': ['mlp.down_proj']\n    }\n}\n\nSKIPPED_MODULE = ['lora', 'block_sparse_moe.gate']\n\n\ndef skipped_module(name: str):\n    \"\"\"Whether the module should be skipped from quantization.\"\"\"\n    for m in SKIPPED_MODULE:\n        if m in name:\n            return True\n    return False\n\n\n@torch.no_grad()\ndef get_weight_scale(weight, q_group_size=-1):\n    org_shape = weight.shape\n    if q_group_size > 0:\n        weight = weight.view(-1, q_group_size)\n    abs_weight = weight.abs()\n    abs_weight_amax = abs_weight.amax(dim=1, keepdim=True)\n    if abs_weight_amax.min().item() == 0:\n        print('weight.amax.min is zero, clamping weight.amax to 1e-4')\n        abs_weight_amax = abs_weight_amax.clamp(min=1e-4)\n    scale = abs_weight / abs_weight_amax\n    scale = scale.view(org_shape)\n    scale = scale.mean(0)\n    return scale\n\n\n@torch.no_grad()\ndef smooth_ln_fcs(ln: torch.nn.Module,\n                  fcs: List[torch.nn.Module],\n                  act_scales: torch.Tensor,\n                  group_size: int = -1,\n                  alpha: float = 0.5) -> torch.Tensor:\n    \"\"\"Smooth weights of a layer normalization and its fully connected layers.\n\n    :param ln: Layer Normalization module\n    :param fcs: List of Fully Connected modules\n    :param act_scales: Activation scales\n    :param alpha: Scaling factor (default is 0.5)\n    :return: Scales\n    \"\"\"\n    device, dtype = 
fcs[0].weight.device, fcs[0].weight.dtype\n\n    # If zeros exist within the weight of the layer norm, it becomes\n    # unnecessary to perform smooth quantization at the positions where\n    # these zeros occur.\n    zero_positions = (ln.weight == 0).nonzero(as_tuple=True)[0]\n    nonzero_positions = (ln.weight != 0).nonzero(as_tuple=True)[0]\n\n    act_scales = act_scales.to(device=device, dtype=dtype)\n\n    concat_w = torch.cat([fc.weight for fc in fcs], dim=0)\n    w_scales = get_weight_scale(concat_w, group_size)\n\n    w_scales_pow = w_scales.pow(1 - alpha)\n    if w_scales_pow.min().item() == 0:\n        print('w_scales.pow(1 - alpha).min is zero, '\n              'clamping w_scales.pow(1 - alpha) to 1e-4')\n        w_scales_pow = w_scales_pow.clamp(min=1e-4)\n    scales = (act_scales.pow(alpha) / w_scales_pow).clamp(min=1e-4).to(device).to(dtype)\n\n    scales = scales / (scales[nonzero_positions].max() * scales[nonzero_positions].min()).sqrt()\n\n    scales[zero_positions] = 1\n\n    ln.weight.div_(scales)\n    # norms may expose `bias = None` (e.g. LayerNorm(bias=False)), so check the value, not just the attribute\n    if getattr(ln, 'bias', None) is not None:\n        ln.bias.div_(scales)\n\n    for fc in fcs:\n        fc.weight.mul_(scales.view(1, -1))\n\n    for p in ln.parameters():\n        assert torch.isnan(p).sum() == 0\n    for fc in fcs:\n        for p in fc.parameters():\n            assert torch.isnan(p).sum() == 0\n    return scales\n\n\n@torch.no_grad()\ndef smooth_fc_fcs(pre_fc: torch.nn.Module,\n                  fcs: List[torch.nn.Module],\n                  act_scales: torch.Tensor,\n                  group_size: int = -1,\n                  alpha: float = 0.5) -> torch.Tensor:\n    \"\"\"Smooth weights of a fully connected layer and its downstream layers.\n\n    :param pre_fc: Previous Fully Connected layer\n    :param fcs: List of Fully Connected modules\n    :param act_scales: Activation scales\n    :param group_size: Group size used when computing weight scales (-1 disables grouping)\n    :param alpha: Scaling factor (default is 0.5)\n    :return: Scales\n    \"\"\"\n    device, dtype = pre_fc.weight.device, pre_fc.weight.dtype\n\n    size_a = 
act_scales.size(0)\n    size_pre_fc = pre_fc.weight.size(0)\n\n    # (for llama2) use group query attention, pre_fc is v_proj, fc is o_proj\n    if size_pre_fc < size_a and size_a % size_pre_fc == 0:\n        return\n\n    act_scales = act_scales.to(device=device, dtype=dtype)\n\n    concat_w = torch.cat([fc.weight for fc in fcs], dim=0)\n    w_scales = get_weight_scale(concat_w, group_size)\n\n    w_scales_pow = w_scales.pow(1 - alpha)\n    if w_scales_pow.min().item() == 0:\n        print('w_scales.pow(1 - alpha).min is zero, '\n              'clamping w_scales.pow(1 - alpha) to 1e-4')\n        w_scales_pow = w_scales_pow.clamp(min=1e-4)\n    scales = (act_scales.pow(alpha) / w_scales_pow).clamp(min=1e-4).to(device).to(dtype)\n    scales = scales / (scales.max() * scales.min()).sqrt()\n\n    # (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale\n    # phi3 fused qkv and gate_up\n    if size_pre_fc > size_a and size_pre_fc % size_a == 0 \\\n            and size_pre_fc // size_a in [2, 3]:\n\n        pre_fc.weight[-size_a:].div_(scales.view(-1, 1))\n\n        if getattr(pre_fc, 'bias', None) is not None:\n            pre_fc.bias[-size_a:].div_(scales)\n    else:\n\n        pre_fc.weight.div_(scales.view(-1, 1))\n\n        if getattr(pre_fc, 'bias', None) is not None:\n            pre_fc.bias.div_(scales)\n\n    for fc in fcs:\n        fc.weight.mul_(scales.view(1, -1))\n\n    for p in pre_fc.parameters():\n        assert torch.isnan(p).sum() == 0\n    for fc in fcs:\n        for p in fc.parameters():\n            assert torch.isnan(p).sum() == 0\n\n    return scales\n\n\ndef check_awq_supported(layer_type):\n    \"\"\"Check if the smooth function is supported by inspecting layer type.\"\"\"\n    norm_fcs_found = False\n    fc_fcs_found = False\n\n    if isinstance(layer_type, str):\n        if layer_type in NORM_FCS_MAP:\n            norm_fcs_found = True\n        if layer_type in FC_FCS_MAP:\n            fc_fcs_found = True\n\n    elif 
isinstance(layer_type, type):\n        if layer_type.__name__ in NORM_FCS_MAP:\n            norm_fcs_found = True\n        if layer_type.__name__ in FC_FCS_MAP:\n            fc_fcs_found = True\n\n    else:\n        raise NotImplementedError\n\n    if not norm_fcs_found:\n        raise NotImplementedError\n\n    if not fc_fcs_found:\n        raise NotImplementedError\n\n\ndef quant_weights(model, fcs, bits, symmetry, group_size=-1, device='cuda'):\n    \"\"\"Quantize the weights of the target model's linear layers.\"\"\"\n    from lmdeploy.lite.quantization import WeightQuantizer\n    from lmdeploy.lite.quantization.modules import WeightOnlyQLinear\n    from lmdeploy.lite.utils import QParams\n    for name, fc in fcs.items():\n        fc.to(device)\n        parent_name, _, child_name = name.rpartition('.')\n        parent = model.get_submodule(parent_name)\n        pack_or_skip = 'packed'\n        if skipped_module(name):\n            q_linear = fc\n            pack_or_skip = 'skipped'\n        else:\n            quantizer = WeightQuantizer(bits, symmetry, 'per_group', group_size)\n            fc.weight.data, scales, zeros = pseudo_quantize_tensor(fc.weight.data,\n                                                                   bits,\n                                                                   group_size,\n                                                                   return_scale_zeros=True)\n            q_linear = WeightOnlyQLinear.from_linear(fc, quantizer, qparams=QParams(scales, zeros))\n        setattr(parent, child_name, q_linear)\n        fc.to('cpu')\n        torch.cuda.empty_cache()\n\n        print(f'{name} weight {pack_or_skip}.')\n\n\ndef smooth_layers(layers, fc2fcs, norm2fcs, a_scales, group_size=-1, device='cuda'):\n    \"\"\"Apply weight smoothing based on input scales.\"\"\"\n\n    for l_name, layer in layers.items():\n        layer.to(device)\n        submodule_names = [name for name, _ in layer.named_modules()]\n        for ln_name, 
fc_names in norm2fcs.items():\n            a_name = [f'{l_name}.{n}' for n in fc_names if n in submodule_names][0]\n\n            ln = layer.get_submodule(ln_name)\n            fcs = [layer.get_submodule(n) for n in fc_names if n in submodule_names]\n            smooth_ln_fcs(ln, fcs, a_scales[a_name], group_size)\n\n        for f_name, fc_names in fc2fcs.items():\n            a_name = [f'{l_name}.{n}' for n in fc_names if n in submodule_names][0]\n\n            fc = layer.get_submodule(f_name)\n            fcs = [layer.get_submodule(n) for n in fc_names if n in submodule_names]\n\n            smooth_fc_fcs(fc, fcs, a_scales[a_name], group_size)\n\n        layer.to('cpu')\n        torch.cuda.empty_cache()\n        max_memory = torch.cuda.max_memory_allocated(device=device) / 1024 / 1024 / 1024\n        print(f'{l_name} smooth weight done.'\n              f' max gpu memory: {max_memory:.2f} GB')\n\n\ndef pseudo_quantize_tensor(w, w_bit=8, w_group_size=-1, return_scale_zeros=False):\n    \"\"\"Pseudo quantize tensor.\"\"\"\n    org_w_shape = w.shape\n    if w_group_size > 0:\n        assert org_w_shape[-1] % w_group_size == 0\n        w = w.reshape(-1, w_group_size)\n    assert w.dim() == 2\n    max_val = w.amax(dim=1, keepdim=True)\n    min_val = w.amin(dim=1, keepdim=True)\n    max_int = 2**w_bit - 1\n    min_int = 0\n    scales = (max_val - min_val).clamp(min=1e-5) / max_int\n    zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)\n    assert torch.isnan(scales).sum() == 0\n    assert torch.isnan(w).sum() == 0\n\n    q_w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int)\n    w = (q_w - zeros) * scales\n    assert torch.isnan(w).sum() == 0\n\n    if return_scale_zeros:\n        zeros = zeros.view(org_w_shape[0], org_w_shape[-1] // w_group_size, -1)\n        scales = scales.view(org_w_shape[0], org_w_shape[-1] // w_group_size, -1)\n        q_w = q_w.reshape(org_w_shape)\n        return q_w, scales, zeros\n    w = w.reshape(org_w_shape)\n 
   return w\n\n\ndef awq_layers(layers, fc2fcs, norm2fcs, a_scales, a_ratios=None, group_size=-1, device='cuda'):\n    \"\"\"Apply awq based on input scales.\"\"\"\n\n    for l_name, layer in layers.items():\n        layer.to(device)\n        for ln_name, fc_names in norm2fcs.items():\n            a_name = [f'{l_name}.{n}' for n in fc_names][0]\n            ratios = [a_ratios[f'{l_name}.{n}'] for n in fc_names]\n            ratio = [s for s in ratios if s is not None][0]\n\n            ln = layer.get_submodule(ln_name)\n            fcs = [layer.get_submodule(n) for n in fc_names]\n            smooth_ln_fcs(ln, fcs, a_scales[a_name], group_size, ratio)\n\n        for f_name, fc_names in fc2fcs.items():\n            a_name = [f'{l_name}.{n}' for n in fc_names][0]\n            ratios = [a_ratios[f'{l_name}.{n}'] for n in fc_names]\n            ratios = [s for s in ratios if s is not None]\n            ratio = 0.5 if not len(ratios) else ratios[0]\n\n            fc = layer.get_submodule(f_name)\n            fcs = [layer.get_submodule(n) for n in fc_names]\n\n            smooth_fc_fcs(fc, fcs, a_scales[a_name], group_size, ratio)\n\n        layer.to('cpu')\n        torch.cuda.empty_cache()\n        max_memory = torch.cuda.max_memory_allocated(device=device) / 1024 / 1024 / 1024\n        print(f'{l_name} smooth weight done.'\n              f' max gpu memory: {max_memory:.2f} GB')\n"
  },
  {
    "path": "lmdeploy/lite/quantization/calibration.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom functools import partial\nfrom typing import Union\n\nimport torch\nfrom torch import nn\nfrom transformers import PreTrainedTokenizer\n\nfrom lmdeploy.lite.quantization.activation import ActivationObserver\nfrom lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP\nfrom lmdeploy.lite.utils import (bimap_name_mod, collect_target_modules, concat_decoder_layer_outputs,\n                                 split_decoder_layer_inputs)\n\n\nclass CalibrationContext():\n    \"\"\"Calibration context manager for model quantization.\n\n    Parameters:\n      - model: The target model to be calibrated and quantized\n      - tokenizer: The tokenizer used in the model training\n      - layer_type: Layer type to be targeted for calibration\n      - norm_type: Normalization type used for calibration\n      - device: Device on which model is to be calibrated ('cpu' or 'cuda')\n    \"\"\"\n\n    inp_obs_group = 'inputs'\n    out_obs_group = 'outputs'\n\n    def __init__(self,\n                 model: nn.Module,\n                 tokenizer: PreTrainedTokenizer,\n                 layer_type: Union[str, type],\n                 norm_type: Union[str, type],\n                 batch_size: int = 1,\n                 device: str = 'cuda',\n                 **kwargs) -> None:\n        \"\"\"Initiate calibration context.\n\n        Args:\n            model (nn.Module): Model to be calibrated.\n            tokenizer (PreTrainedTokenizer): Tokenizer of the given model.\n            layer_type (Union[str, type]): Type of the layers to be observed.\n            norm_type (Union[str, type]): Norm type used in the model.\n            batch_size (int): The batch size for running the calib samples.\n                Low GPU mem requires small batch_size. 
Large batch_size\n                reduces the calibration time but costs more VRAM.\n            device (str, optional): Device where the model should run.\n                Defaults to 'cuda'.\n        \"\"\"\n\n        self.layer_type = layer_type\n        self.norm_type = norm_type\n        self.batch_size = batch_size\n\n        num_kv_heads, num_attn_heads = self._guess_num_heads(model)\n        self.num_kv_heads = num_kv_heads\n        self.head_dim = model.config.hidden_size // num_attn_heads\n        self.model = model\n\n        self.tokenizer = tokenizer\n\n        # Collect modules to observe\n        self.name2layer = collect_target_modules(self.model, layer_type)\n        self.name2fc = {}\n        for l_name, layer in self.name2layer.items():\n            name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)\n            self.name2fc.update(name2fc)\n        self.name2norm = collect_target_modules(self.model, norm_type)\n\n        maps = bimap_name_mod([self.name2layer, self.name2fc, self.name2norm])\n        self.name2mod, self.mod2name = maps\n\n        # Initialize observers\n        self._init_input_observers(self.name2fc)\n        self._init_output_observers(self.name2norm)\n        self._init_output_observers(self.name2fc)\n\n        self.device = device\n\n    def _guess_num_heads(self, model):\n\n        if hasattr(model.config, 'num_key_value_heads'):\n            num_kv_heads = model.config.num_key_value_heads\n        else:\n            num_kv_heads = model.config.num_attention_heads\n\n        num_attn_heads = model.config.num_attention_heads\n\n        return num_kv_heads, num_attn_heads\n\n    def _init_input_observers(self, name2mod):\n        \"\"\"Initialize input observers for given modules.\"\"\"\n        for name, mod in name2mod.items():\n            obs = ActivationObserver(mod.weight.size(-1))\n            obs.global_available(name, group=self.inp_obs_group)\n\n    def _init_output_observers(self, name2mod):\n       
 \"\"\"Initialize output observers for given modules.\"\"\"\n        for name, mod in name2mod.items():\n            obs = ActivationObserver(mod.weight.size(0))\n            obs.global_available(name, group=self.out_obs_group)\n\n    def _insert_input_observers(self):\n        \"\"\"Insert input observers into the target modules.\n\n        This function registers a forward pre-hook on each target module to observe the inputs.\n        \"\"\"\n\n        def _input_hook(mod: nn.Module, inp: torch.Tensor):\n            m_name = self.mod2name[mod]\n            obs = ActivationObserver.find(m_name, group=self.inp_obs_group)\n            obs.observe(inp[0])\n\n        group = ActivationObserver.find_group(self.inp_obs_group)\n        for name in group.keys():\n            mod = self.name2mod[name]\n            hook_fn = mod.register_forward_pre_hook(_input_hook)\n            self._hooks.append(hook_fn)\n\n    def _insert_output_observers(self):\n        \"\"\"Insert output observers into the target modules.\n\n        This function registers a forward hook on each target module to observe the outputs.\n        \"\"\"\n\n        def _output_hook(mod: nn.Module, inp: torch.Tensor, out: torch.Tensor):\n            m_name = self.mod2name[mod]\n            obs = ActivationObserver.find(m_name, group=self.out_obs_group)\n            obs.observe(out)\n\n        group = ActivationObserver.find_group(self.out_obs_group)\n        for name in group.keys():\n            mod = self.name2mod[name]\n            hook_fn = mod.register_forward_hook(_output_hook)\n            self._hooks.append(hook_fn)\n\n    def _wrap_decoder_layers(self):\n        \"\"\"Method to wrap the decoder layers' forward functions for observing\n        their key/value cache during batched forward passes.\"\"\"\n\n        def _forward(mod, *args, **kwargs):\n\n            mod.to(self.device)\n            batch_args, batch_kwargs = split_decoder_layer_inputs(self.batch_size, *args, **kwargs)\n            
batch_outputs = []\n            samples = len(batch_args)\n\n            m_name = self.mod2name[mod]\n\n            for i in range(len(batch_args)):\n                batch_outputs.append(self._ori_forwards[mod](*batch_args[i], **batch_kwargs[i]))\n\n            outputs = concat_decoder_layer_outputs(batch_outputs)\n\n            del batch_outputs, batch_args, batch_kwargs, args\n            mod.to('cpu')\n            torch.cuda.empty_cache()\n            max_memory = torch.cuda.max_memory_allocated(device=self.device) / 1024 / 1024 / 1024\n            print(f'{m_name}, samples: {samples}, '\n                  f'max gpu memory: {max_memory:.2f} GB')\n            return outputs\n\n        for layer in self.name2layer.values():\n            self._ori_forwards[layer] = layer.forward\n            layer.forward = partial(_forward, layer)\n\n    def collect_inputs_stats(self):\n        \"\"\"Collect statistics (min, max, absmax values) of the observed inputs.\n\n        Returns a dictionary with these collected stats.\n        \"\"\"\n        inputs_stats = {'max': {}, 'min': {}, 'mean': {}, 'absmax': {}, 'absmean': {}}\n        obs_group = ActivationObserver.find_group(self.inp_obs_group)\n        for name, obs in obs_group.items():\n            inputs_stats['max'][name] = obs.max_val\n            inputs_stats['min'][name] = obs.min_val\n            inputs_stats['mean'][name] = obs.mean_val\n            inputs_stats['absmax'][name] = obs.absmax_val\n            inputs_stats['absmean'][name] = obs.absmean_val\n        return inputs_stats\n\n    def collect_outputs_stats(self):\n        \"\"\"Collect statistics (min, max, absmax values) of the observed\n        outputs.\n\n        Returns a dictionary with these collected stats.\n        \"\"\"\n        outputs_stats = {'max': {}, 'min': {}, 'mean': {}, 'absmax': {}, 'absmean': {}}\n        obs_group = ActivationObserver.find_group(self.out_obs_group)\n        for name, obs in obs_group.items():\n            
outputs_stats['max'][name] = obs.max_val\n            outputs_stats['min'][name] = obs.min_val\n            outputs_stats['mean'][name] = obs.mean_val\n            outputs_stats['absmax'][name] = obs.absmax_val\n            outputs_stats['absmean'][name] = obs.absmean_val\n        return outputs_stats\n\n    def export(self, out_dir):\n        \"\"\"Export the calibration statistics (inputs, outputs, keys and values)\n        to specified directory.\n\n        Args:\n            out_dir (Union[str, Path]): The directory path where the stats\n                will be saved.\n        \"\"\"\n\n        inp_stats = self.collect_inputs_stats()\n        torch.save(inp_stats, out_dir / 'inputs_stats.pth')\n        torch.cuda.empty_cache()\n\n        out_stats = self.collect_outputs_stats()\n        torch.save(out_stats, out_dir / 'outputs_stats.pth')\n        torch.cuda.empty_cache()\n\n    def calibrate(self, data):\n        \"\"\"Forward pass through the model in inference mode with given data.\"\"\"\n\n        if type(self.model).__name__ in ('QWenLMHeadModel', 'ChatGLMForConditionalGeneration'):\n            model = self.model.transformer\n        else:\n            model = self.model.model\n        with torch.inference_mode():\n            _ = model(data.to(self.device))\n        torch.cuda.empty_cache()\n\n    def __enter__(self):\n        \"\"\"Prepares the Calibration object for a 'with' statement by\n        registering hooks and wrapping layer forward methods.\"\"\"\n\n        self._hooks = list()\n\n        self._ori_forwards = {}\n        for layer in self.name2layer.values():\n            self._ori_forwards[layer] = layer.forward\n\n        self._insert_input_observers()\n        self._insert_output_observers()\n        self._wrap_decoder_layers()\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        \"\"\"Clean up after a 'with' statement by removing registered hooks,\n        restoring original forward methods, and if no exception occurred,\n   
     collecting all gathered statistics and saving them.\"\"\"\n        for h in self._hooks:\n            h.remove()\n\n        for layer in self.name2layer.values():\n            layer.forward = self._ori_forwards[layer]\n\n\n@torch.no_grad()\ndef auto_scale_block(module, module_kwargs, w_bit, w_group_size, input_feat, mod_name):\n    if 'use_cache' in module_kwargs:\n        module_kwargs.pop('use_cache')\n\n    # find the best scale ratio\n    def _search_module_scale(block, linears2scale: list, x, kwargs={}):\n        x = x.to(next(block.parameters()).device)\n        with torch.no_grad():\n            org_out = block(x, **kwargs)\n            if isinstance(org_out, tuple):\n                org_out = org_out[0]\n\n        x_max = x.abs().view(-1, x.shape[-1]).mean(0)\n\n        best_error = float('inf')\n        best_ratio = -1\n        n_grid = 20\n        history = []\n\n        concat_w = torch.cat([_m.weight for _m in linears2scale], dim=0)\n        from .awq import get_weight_scale, pseudo_quantize_tensor\n        w_mean = get_weight_scale(concat_w, w_group_size)\n\n        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}\n        for ratio in range(0, n_grid):\n            ratio = ratio / n_grid\n            w_mean_pow = w_mean.pow(1 - ratio)\n            if w_mean_pow.min().item() == 0:\n                print('w_mean.pow(1 - ratio).min is zero, '\n                      'clamping w_mean.pow(1 - ratio) to 1e-4')\n                w_mean_pow = w_mean_pow.clamp(min=1e-4)\n            scales = (x_max.pow(ratio) / w_mean_pow).clamp(min=1e-4).view(-1)\n\n            scales = scales / (scales.max() * scales.min()).sqrt()\n            for fc in linears2scale:\n                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))\n                fc.weight.data = pseudo_quantize_tensor(fc.weight.data, w_bit, w_group_size) / (scales.view(1, -1))\n            out = block(x, **kwargs)\n            if isinstance(out, tuple):\n                out = 
out[0]\n\n            # float prevents overflow\n            loss = (org_out - out).float().pow(2).mean().item()\n            history.append(loss)\n            if loss < best_error:\n                best_error = loss\n                best_ratio = ratio\n            block.load_state_dict(org_sd)\n        if best_ratio == -1:\n            print(history)\n            raise Exception\n        return best_ratio\n\n    def _auto_get_scale(layers, inp, module2inspect=None, kwargs=None):\n        # module2inspect: if given, check the output diff of this module\n        # instead of the output diff of layers\n        # use a fresh dict per call, since kwargs may be modified below\n        if kwargs is None:\n            kwargs = {}\n        if module2inspect is None:\n            assert len(layers) == 1\n            module2inspect = layers[0]\n        # internlm-xcomposer2-vl applies plora, which requires im_mask arg\n        if module2inspect._get_name() == 'InternLM2MLP':\n            from inspect import signature\n            if 'im_mask' in signature(module2inspect.forward).parameters:\n                kwargs['im_mask'] = None\n\n        best_ratio = _search_module_scale(module2inspect, layers, inp.value, kwargs)\n        inp.save_ratio(best_ratio)\n\n    for i, (prev_name, layer_names) in enumerate(NORM_FCS_MAP[module._get_name()].items()):\n        # norm -> linear groups; the first group is the attention input\n        _auto_get_scale(\n            layers=[module.get_submodule(name) for name in layer_names],\n            inp=input_feat[f'{mod_name}.{layer_names[0]}'],\n            module2inspect=module.get_submodule(layer_names[0].split('.')[0]),\n            kwargs=module_kwargs if i == 0 else {},  # only the attention block needs module_kwargs\n        )\n    for prev_name, layer_names in FC_FCS_MAP[module._get_name()].items():\n        # linear -> linear groups\n        _auto_get_scale(\n            layers=[module.get_submodule(name) for name in layer_names],\n            inp=input_feat[f'{mod_name}.{layer_names[0]}'],\n        )\n\n\nclass CalibrationContextV2(CalibrationContext):\n\n    def __init__(self,\n                 model: nn.Module,\n                 tokenizer: 
PreTrainedTokenizer,\n                 layer_type: Union[str, type],\n                 norm_type: Union[str, type],\n                 batch_size: int = 1,\n                 device: str = 'cuda',\n                 search_scale: bool = True,\n                 w_bits: int = 4,\n                 w_group_size: int = 128,\n                 **kwargs) -> None:\n        super().__init__(model, tokenizer, layer_type, norm_type, batch_size, device)\n        self.w_bits = w_bits\n        self.w_group_size = w_group_size\n        self.search_scale = search_scale\n\n    def _insert_input_observers(self):\n        \"\"\"Insert input observers into the target modules.\n\n        This function registers a forward pre-hook on each target module to observe the inputs.\n        \"\"\"\n\n        def _input_hook(mod: nn.Module, inp: torch.Tensor):\n            m_name = self.mod2name[mod]\n            obs = ActivationObserver.find(m_name, group=self.inp_obs_group)\n            obs.observe(inp[0], self.search_scale)\n\n        group = ActivationObserver.find_group(self.inp_obs_group)\n        for name in group.keys():\n            mod = self.name2mod[name]\n            hook_fn = mod.register_forward_pre_hook(_input_hook)\n            self._hooks.append(hook_fn)\n\n    def export(self, out_dir):\n        \"\"\"Export the calibration statistics (inputs, outputs, keys and values)\n        to specified directory.\n\n        Args:\n            out_dir (Union[str, Path]): The directory path where the stats\n                will be saved.\n        \"\"\"\n        inputs_stats = {\n            'max': {},\n            'min': {},\n            'mean': {},\n            'absmax': {},\n            'absmean': {},\n            'ratios': {},\n        }\n        obs_group = ActivationObserver.find_group(self.inp_obs_group)\n        for name, obs in obs_group.items():\n            inputs_stats['max'][name] = obs.max_val\n            inputs_stats['min'][name] = obs.min_val\n            
inputs_stats['mean'][name] = obs.mean_val\n            inputs_stats['absmax'][name] = obs.absmax_val\n            inputs_stats['absmean'][name] = obs.absmean_val\n            inputs_stats['ratios'][name] = obs.ratio\n        torch.save(inputs_stats, out_dir / 'inputs_stats.pth')\n        torch.cuda.empty_cache()\n\n    def _wrap_decoder_layers_for_search(self):\n        \"\"\"Method to wrap the decoder layers' forward functions for observing\n        their key/value cache during batched forward passes.\"\"\"\n\n        @torch.no_grad()\n        def _forward(mod, *args, **kwargs):\n\n            mod.to(self.device)\n            batch_args, batch_kwargs = split_decoder_layer_inputs(self.batch_size, *args, **kwargs)\n            batch_outputs = []\n            samples = len(batch_args)\n\n            m_name = self.mod2name[mod]\n            for i in range(len(batch_args)):\n                batch_outputs.append(self._ori_forwards[mod](*batch_args[i], **batch_kwargs[i]))\n                obs_group = ActivationObserver.find_group(self.inp_obs_group)\n                mod_name = self.mod2name[mod]\n                ActivationObserver.disable()\n                auto_scale_block(mod, batch_kwargs[i], self.w_bits, self.w_group_size, obs_group, mod_name)\n                ActivationObserver.enable()\n            for key, item in obs_group.items():\n                if key.startswith(f'{mod_name}.') and item.value is not None:\n                    item.value.cpu()\n                    del item.value\n\n            outputs = concat_decoder_layer_outputs(batch_outputs)\n\n            del batch_outputs, batch_args, batch_kwargs, args\n            mod.cpu()\n            import gc\n            gc.collect()\n            torch.cuda.empty_cache()\n            max_memory = torch.cuda.max_memory_allocated(device=self.device) / (1 << 30)\n            print(f'{m_name}, samples: {samples}, '\n                  f'max gpu memory: {max_memory:.2f} GB')\n            return outputs\n\n        for 
layer in self.name2layer.values():\n            self._ori_forwards[layer] = layer.forward\n            layer.forward = partial(_forward, layer)\n            layer.cpu()\n\n    def __enter__(self):\n        \"\"\"Prepares the Calibration object for a 'with' statement by\n        registering hooks and wrapping layer forward methods.\"\"\"\n\n        self._hooks = list()\n\n        self._insert_input_observers()\n        self._ori_forwards = {}\n        for layer in self.name2layer.values():\n            self._ori_forwards[layer] = layer.forward\n\n        if self.search_scale:\n            self._wrap_decoder_layers_for_search()\n"
  },
  {
    "path": "lmdeploy/lite/quantization/modules/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .linear import WeightOnlyQLinear\n\n__all__ = ['WeightOnlyQLinear']\n"
  },
  {
    "path": "lmdeploy/lite/quantization/modules/linear.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional, Type, TypeVar\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.lite.utils.cal_qparams import QParams\n\ntry:\n    import awq_inference_engine\nexcept ModuleNotFoundError:\n    awq_inference_engine = None\n\n\nclass WeightOnlyQLinear(nn.Module):\n    \"\"\"This class implements weight only quantization linear.\n\n    Args:\n        w_bit (int): number of bits for quantization.\n        symmetry (bool): If true, use symmetric quantization,\n            otherwise use asymmetric quantization.\n        group_size (int): size of the quantization group.\n        in_features (int): size of each input sample.\n        out_features (int): size of each output sample.\n        bias (Tensor, optional): Defaults to None.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        bias: Optional[torch.Tensor] = True,\n        w_bit: int = 4,\n        symmetry: bool = False,\n        group_size: int = 128,\n    ) -> None:\n        super().__init__()\n\n        if w_bit not in [2, 4, 8]:\n            raise NotImplementedError('Only 2,4,8 bit are supported for now.')\n\n        self.in_features = in_features\n        self.out_features = out_features\n        self.w_bit = w_bit\n        self.group_size = group_size if group_size != -1 else in_features\n\n        assert self.in_features % self.group_size == 0\n        assert out_features % (32 // self.w_bit) == 0\n\n        w_pack_oc = out_features // (32 // self.w_bit)\n        w_inc = in_features\n        weight = torch.zeros((w_inc, w_pack_oc), dtype=torch.int32)\n        self.register_buffer('qweight', weight)\n\n        if bias:\n            self.register_buffer('bias', torch.zeros(out_features))\n        else:\n            self.bias = None\n\n        s_inc = in_features // self.group_size\n        s_oc = out_features\n        scales = torch.zeros((s_inc, s_oc), 
dtype=torch.float16)\n        self.register_buffer('scales', scales)\n\n        if not symmetry:\n            z_inc = in_features // self.group_size\n            z_oc = out_features // (32 // self.w_bit)\n            zeros = torch.zeros((z_inc, z_oc), dtype=torch.int32)\n            self.register_buffer('qzeros', zeros)\n        else:\n            self.qzeros = None\n\n    @classmethod\n    def from_linear(cls: Type['WeightOnlyQLinear'],\n                    linear: nn.Linear,\n                    quantizer: TypeVar('Quantizer'),\n                    awq_layout: bool = True,\n                    qparams: Optional[QParams] = None) -> 'WeightOnlyQLinear':\n        \"\"\"Create a WeightOnlyQLinear object from a PyTorch Linear object.\n\n        Args:\n            linear (nn.Linear): PyTorch Linear object.\n            quantizer (Quantizer): Object that handles quantization.\n            awq_layout (bool): AWQ layout. Defaults to True.\n\n        Returns:\n            WeightOnlyQLinear: A WeightOnlyQLinear object.\n        \"\"\"\n        device = linear.weight.device\n\n        w_bit = quantizer.bits\n        pack_num = 32 // w_bit\n        if awq_layout:\n            assert w_bit == 4\n            pack_order = [0, 2, 4, 6, 1, 3, 5, 7]\n        else:\n            pack_order = torch.arange(pack_num)\n        group_size = quantizer.group_size\n        symmetry = quantizer.symmetry\n\n        in_features = linear.in_features\n        out_features = linear.out_features\n        bias = False if linear.bias is None else True\n\n        qlinear = cls(in_features, out_features, bias, w_bit, symmetry, group_size)\n        qlinear.bias = linear.bias\n\n        if qparams is None:\n            qparams = quantizer.calculate_qparams(linear.weight)\n            i32_w = quantizer.quant(linear.weight, qparams, real=True)\n        else:\n            i32_w = linear.weight.to(torch.int32)\n        i32_w = i32_w.t().contiguous()\n\n        pack_int_w = 
torch.zeros_like(qlinear.qweight).to(device)\n\n        for col in range(pack_int_w.shape[1]):\n            for i in range(pack_num):\n                pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]\n                pack_int_w[:, col] |= pack_int_w_col << (i * w_bit)\n\n        qlinear.qweight = pack_int_w\n        qlinear.scales = qparams.scales.squeeze(-1).t().contiguous()\n\n        if qparams.zero_points is not None:\n            zeros = qparams.zero_points.to(torch.int32).to(device)\n            zeros = zeros.squeeze(-1).t().contiguous()\n            pack_int_zeros = torch.zeros_like(qlinear.qzeros).to(device)\n\n            for col in range(pack_int_zeros.shape[1]):\n                for i in range(pack_num):\n                    qzero_col = zeros[:, col * pack_num + pack_order[i]]\n                    pack_int_zeros[:, col] |= qzero_col << (i * w_bit)\n            qlinear.qzeros = pack_int_zeros\n\n        qlinear.to('cpu')\n\n        return qlinear\n\n    @torch.no_grad()\n    def forward(self, x):\n        if awq_inference_engine is None:\n            raise RuntimeError('Run the following command to install '\n                               'the kernel for 4bit inference\\n\\n'\n                               'git clone https://github.com/mit-han-lab/llm-awq.git\\n'\n                               'cd awq/kernels\\n'\n                               'python setup.py install\\n')\n        out_shape = x.shape[:-1] + (self.out_features, )\n        inputs = x.reshape(-1, x.shape[-1])\n\n        out = awq_inference_engine.gemm_forward_cuda(inputs.half(), self.qweight, self.scales.half(), self.qzeros,\n                                                     self.group_size)\n        out = out + self.bias if self.bias is not None else out\n\n        return out.reshape(out_shape)\n"
  },
  {
    "path": "lmdeploy/lite/quantization/weight/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .quantizer import WeightQuantizer\n\n__all__ = ['WeightQuantizer']\n"
  },
  {
    "path": "lmdeploy/lite/quantization/weight/quant_utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional, Sequence, Union\n\nimport torch\n\n\ndef _aligned_size(a, b):\n    return (a + b - 1) // b * b\n\n\ndef fast_log2_ceil_torch(x: torch.Tensor) -> torch.Tensor:\n    bits_x = x.view(torch.int32)\n    exp_x = (bits_x >> 23) & 0xFF\n    man_bits = bits_x & ((1 << 23) - 1)\n    result = (exp_x - 127).to(torch.int32)\n    result = result + torch.where(man_bits != 0, 1, 0)\n\n    return result.to(torch.int32)\n\n\ndef fast_pow2_torch(x: torch.Tensor) -> torch.Tensor:\n    bits_x = (x + 127) << 23\n    return bits_x.view(torch.float32)\n\n\ndef fast_round_scale_torch(amax: torch.Tensor, fp8_max: torch.Tensor) -> torch.Tensor:\n    return fast_pow2_torch(fast_log2_ceil_torch(amax / fp8_max))\n\n\ndef _get_quant_scaling(weight: torch.Tensor,\n                       fp8_dtype: torch.dtype,\n                       dim: Union[int, Sequence[int]],\n                       scale_fmt: Optional[str] = None):\n    \"\"\"Get the scaling factor for FP8 quantization.\"\"\"\n    finfo = torch.finfo(fp8_dtype)\n    fmax = finfo.max\n    amax = weight.abs().amax(dim, keepdim=True).clamp_min(1e-6).float()\n\n    if scale_fmt == 'ue8m0':\n        return fast_round_scale_torch(amax, fmax)\n    else:\n        # default\n        scaling = amax / fmax\n    return scaling\n\n\ndef quant_blocked_fp8(weight: torch.Tensor,\n                      fp8_dtype: torch.dtype,\n                      block_size: int = 128,\n                      scale_fmt: Optional[str] = None):\n    \"\"\"Quantize the weight tensor to blocked FP8 format.\"\"\"\n    assert scale_fmt in (None, 'ue8m0'), f'Unsupported scale_fmt: {scale_fmt}'\n\n    weight_shape = weight.shape\n    K, N = weight_shape[-2:]\n    aligned_k = _aligned_size(K, block_size)\n    aligned_n = _aligned_size(N, block_size)\n\n    # fill the weight tensor with zeros if it is not aligned\n    if aligned_k != K or aligned_n != N:\n        new_weight = 
weight.new_zeros(weight_shape[:-2] + (aligned_k, aligned_n))\n        new_weight[..., :K, :N] = weight\n        weight = new_weight\n    aligned_shape = weight.shape\n\n    # reverse pixel shuffle\n    weight = weight.unflatten(-2, (-1, block_size)).unflatten(-1, (-1, block_size))\n    weight = weight.to(torch.float32)\n\n    # get scaling\n    scaling = _get_quant_scaling(weight, fp8_dtype, dim=(-3, -1), scale_fmt=scale_fmt)\n\n    # get quantized weight\n    quantized_weight = weight / scaling\n    quantized_weight = quantized_weight.to(fp8_dtype)\n    quantized_weight = quantized_weight.view(aligned_shape)\n    quantized_weight = quantized_weight[..., :K, :N]\n\n    # reshape scaling\n    scaling = scaling.squeeze(-3, -1)\n\n    return quantized_weight, scaling\n"
  },
  {
    "path": "lmdeploy/lite/quantization/weight/quantizer.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Callable, Dict, Optional\n\nimport torch\n\nfrom lmdeploy.lite.utils import (QParams, cal_qparams_per_channel_absmax, cal_qparams_per_channel_minmax,\n                                 cal_qparams_per_group_absmax, cal_qparams_per_group_minmax,\n                                 cal_qparams_per_tensor_absmax, cal_qparams_per_tensor_minmax, precise_round)\nfrom lmdeploy.lite.utils.global_avail import GlobalAvailMixin\n\n\nclass WeightQuantizer(GlobalAvailMixin):\n    \"\"\"A class for performing weight quantization of neural networks.\n\n    The WeightQuantizer class provides various methods to quantize the weights\n    of a neural network. This helps in reducing the memory requirements and\n    computational complexity of the model, potentially offering faster\n    inference and lower power consumption.\n\n    Attributes:\n        bits (int): The bit width for quantization.\n        symmetry (bool): If True, use absmax scaling; if False,\n            use min-max scaling.\n        granularity (str): The granularity of quantization. 
Available options\n            are 'per_channel', 'per_tensor', and 'per_group'.\n        group_size (Optional[int]): If using 'per_group' quantization, this is\n            the number of channels in each group.\n\n    Example:\n\n        # Instantiate the weight quantizer with specific quantization settings\n        quantizer = WeightQuantizer(bits=8,\n                                     symmetry=True,\n                                     granularity='per_tensor')\n\n        # Calculate the quantization parameters for given weights\n        qparams = quantizer.calculate_qparams(weights)\n\n        # Perform fake quantization on the weights\n        quantized_weights = quantizer.quant(weights, qparams)\n    \"\"\"\n\n    CAL_FUNC_MAP: Dict[str, Dict[str, Callable]] = {\n        'per_group': {\n            'absmax': cal_qparams_per_group_absmax,\n            'minmax': cal_qparams_per_group_minmax,\n        },\n        'per_channel': {\n            'absmax': cal_qparams_per_channel_absmax,\n            'minmax': cal_qparams_per_channel_minmax,\n        },\n        'per_tensor': {\n            'absmax': cal_qparams_per_tensor_absmax,\n            'minmax': cal_qparams_per_tensor_minmax,\n        },\n    }\n\n    def __init__(self, bits: int, symmetry: bool, granularity: str, group_size: Optional[int] = -1):\n\n        assert bits in [4, 8], \"The 'bits' argument must be either 4 or 8.\"\n        self.bits = bits\n\n        if granularity not in ['per_channel', 'per_tensor', 'per_group']:\n            raise NotImplementedError(\"The 'granularity' argument must be one of 'per_channel', \"\n                                      \"'per_tensor', or 'per_group'.\")\n\n        self.granularity = granularity\n\n        if self.granularity == 'per_group':\n            assert group_size > 0, \\\n                \"The 'group_size' argument must be greater than 0.\"\n\n        self.group_size = group_size\n\n        # If symmetry is True, use absmax to compute scales\n     
   # If symmetry is False, use minmax to compute scales and zeor-points\n        self.symmetry = symmetry\n        self.observer = 'absmax' if symmetry else 'minmax'\n\n    def calculate_qparams(self, weight: torch.Tensor) -> QParams:\n        \"\"\"Calculate the quantization parameters for the given weight tensor.\n\n        Args:\n            weight (torch.Tensor): The weight tensor with shape\n                (out_features, in_features).\n\n        Returns:\n            QParams: A namedtuple containing 'scales' and 'zero_points'.\n        \"\"\"\n\n        cal_func = self.CAL_FUNC_MAP[self.granularity][self.observer]\n        if self.granularity == 'per_group':\n            return cal_func(weight, self.bits, self.group_size)\n        else:\n            return cal_func(weight, self.bits)\n\n    def quant(self, weight: torch.Tensor, qparams: Optional[QParams] = None, real: bool = False) -> torch.Tensor:\n        \"\"\"Perform fake quantization on the given weight tensor.\n\n        Args:\n            weight (torch.Tensor): The weight tensor with shape\n                (out_features, in_features).\n            qparams (Optional[QParams]): A namedtuple containing 'scales'\n                and 'zero_points'.\n            real (bool): If True, return the tensor with quantized type.\n\n        Returns:\n            torch.Tensor: The fake quantized weight tensor.\n        \"\"\"\n\n        float_w = weight.float()\n\n        if qparams is None:\n            qparams = self.calculate_qparams(float_w)\n\n        scales = qparams.scales\n        zero_points = qparams.zero_points\n\n        out_c, in_c = weight.shape\n\n        # Reshape the weights if using per_group quantization\n        # per tensor scales shape: [1]\n        # per channel scales shape: [out_c, 1]\n        # per group scales shape: [out_c, in_c//group_size, 1]\n        if len(scales.shape) > 2:\n            # scales shape: [out_c, in_c//group_size, 1]\n            float_w = float_w.reshape(out_c, 
scales.shape[1], -1)\n\n        if zero_points is None:\n            assert self.symmetry\n            real_qweight = (float_w / scales).round()\n            fake_qweight = real_qweight * scales\n\n        else:\n            assert not self.symmetry\n\n            real_qweight = precise_round((float_w - float_w.min(-1, keepdim=True)[0]) / scales)\n            fake_qweight = (real_qweight - zero_points) * scales\n\n        if len(scales.shape) > 2:\n            real_qweight = real_qweight.reshape(out_c, in_c)\n            fake_qweight = fake_qweight.reshape(out_c, in_c)\n\n        if real:\n            return real_qweight.to(torch.int32)\n        else:\n            return fake_qweight.to(weight.dtype)\n"
  },
  {
    "path": "lmdeploy/lite/utils/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .batch_split import concat_decoder_layer_outputs, split_decoder_layer_inputs\nfrom .cal_qparams import (QParams, cal_qparams_per_channel_absmax, cal_qparams_per_channel_minmax,\n                          cal_qparams_per_group_absmax, cal_qparams_per_group_minmax, cal_qparams_per_tensor_absmax,\n                          cal_qparams_per_tensor_minmax, precise_round)\nfrom .calib_dataloader import get_calib_loaders\nfrom .collect import bimap_name_mod, collect_target_modules, collect_target_weights\nfrom .global_avail import GlobalAvailMixin\nfrom .load import load_hf_from_pretrained\n\n__all__ = [\n    'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax', 'cal_qparams_per_group_absmax',\n    'cal_qparams_per_group_minmax', 'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax', 'QParams',\n    'get_calib_loaders', 'collect_target_modules', 'precise_round', 'collect_target_weights', 'GlobalAvailMixin',\n    'split_decoder_layer_inputs', 'bimap_name_mod', 'concat_decoder_layer_outputs', 'load_hf_from_pretrained'\n]\n"
  },
  {
    "path": "lmdeploy/lite/utils/batch_split.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, List, Tuple, Union\n\nimport torch\n\n\ndef split_decoder_layer_inputs(batch_size, *args: Union[torch.Tensor, Any],\n                               **kwargs: Union[torch.Tensor, Any]) -> Tuple[List[List[Any]], List[Dict[str, Any]]]:\n    \"\"\"Split batched decoder layer inputs into smaller chunks along the\n    batch dimension.\n\n    Args:\n        batch_size (int): Number of samples in each chunk; the last chunk\n            may be smaller.\n        *args (Union[torch.Tensor, Any]): Positional arguments which could\n            be a mix of tensors and other types.\n        **kwargs (Union[torch.Tensor, Any]): Keyword arguments which could\n            be a mix of tensors and other types.\n\n    Returns:\n        Tuple[List[List[Any]], List[Dict[str, Any]]]: A tuple containing two\n            lists, one for positional arguments, one for keyword arguments.\n            Each list holds one entry per chunk: tensors whose batch\n            dimension matches the batch size are sliced, and all other\n            values are passed through unchanged.\n    \"\"\"\n\n    if not isinstance(args[0], torch.Tensor):\n        raise ValueError('The first argument must be a Tensor')\n\n    bs = args[0].size(0)\n\n    batch_args = []\n    batch_kwargs = []\n    for i in range(0, bs, batch_size):\n        new_args = []\n        # Iterate over each argument. 
If it's a torch.Tensor and its first\n        # dimension equals the batch size, then get the value corresponding\n        # to the current index, else directly add the whole value.\n        for val in args:\n            if isinstance(val, torch.Tensor) and val.size(0) == bs:\n                new_args.append(val[i:i + batch_size])\n            else:\n                new_args.append(val)\n\n        new_kwargs = {}\n        # Execute the same operation for the keyword arguments.\n        for name, val in kwargs.items():\n            if isinstance(val, torch.Tensor) and val.size(0) == bs:\n                new_kwargs[name] = val[i:i + batch_size]\n            elif isinstance(val, torch.Tensor) and len(val.shape) > 1 and val.size(1) == bs:  # qwen2-vl\n                new_kwargs[name] = val[:, i:i + batch_size]\n            elif name == 'position_embeddings' and isinstance(val, Tuple) and len(\n                    val[0].shape) > 1 and val[0].size(1) == bs:  # qwen2-vl\n                new_kwargs[name] = (val[0][:, i:i + batch_size], val[1][:, i:i + batch_size])\n            else:\n                new_kwargs[name] = val\n\n        batch_args.append(new_args)\n        batch_kwargs.append(new_kwargs)\n\n    return batch_args, batch_kwargs\n\n\ndef concat_decoder_layer_outputs(batch_outputs: List[Any]) -> Any:\n    \"\"\"This function concatenates individual decoder layer outputs into a\n    batched output.\n\n    Args:\n        batch_outputs (List[Any]): A list, where each tuple\n            represents the output from an individual element in the batch.\n\n    Returns:\n        Any: Batched output.\n    \"\"\"\n\n    output_is_tuple = True\n    if not isinstance(batch_outputs[0], tuple):\n        output_is_tuple = False\n        batch_outputs = [(output, ) for output in batch_outputs]\n\n    num_returns = len(batch_outputs[0])\n\n    def is_past_key_value(data: Any) -> bool:\n        \"\"\"Check whether data is a past key-value pair.\n\n        Args:\n            data 
(Any): The data to check.\n\n        Returns:\n            bool: True if data is a past key-value pair, False otherwise.\n        \"\"\"\n        flag = isinstance(data, tuple)\n        flag = flag and len(data) == 2\n        flag = flag and isinstance(data[0], torch.Tensor)\n        flag = flag and isinstance(data[1], torch.Tensor)\n        return flag\n\n    new_outputs = []\n\n    # Iterate over all types of return values.\n    for i in range(num_returns):\n        # Check if the current element is a past key-value pair.\n        flag = is_past_key_value(batch_outputs[0][i])\n        if flag:\n            # Concatenate the keys and values separately.\n            key = torch.cat([out[i][0] for out in batch_outputs])\n            value = torch.cat([out[i][1] for out in batch_outputs])\n            out_i = (key, value)\n        elif batch_outputs[0][i] is None:  # glm4\n            out_i = None\n        else:\n            # If it's not a past key-value pair, concatenate directly.\n            out_i = torch.cat([out[i] for out in batch_outputs])\n        new_outputs.append(out_i)\n\n    if output_is_tuple:\n        return tuple(new_outputs)\n    else:\n        return new_outputs[0]\n"
  },
  {
    "path": "lmdeploy/lite/utils/cal_qparams.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import NamedTuple, Optional\n\nimport torch\n\n\nclass QParams(NamedTuple):\n    \"\"\"A class to hold the quantization parameters.\"\"\"\n\n    scales: torch.Tensor\n    zero_points: Optional[torch.Tensor]\n\n\n@torch.no_grad()\ndef precise_round(x):\n    return x.sign() * (x.abs() + 0.5).floor()\n\n\n@torch.no_grad()\ndef cal_qparams_per_channel_absmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams:\n    \"\"\"Calculate quantization parameters for each channel using absolute max\n    value.\"\"\"\n    float_w = w.float()\n\n    absmax = float_w.abs().max(dim=-1, keepdim=True)[0]\n    q_max = 2**(n_bits - 1) - 1\n    scales = absmax.div(q_max)\n\n    if return_stats:\n        return QParams(scales=scales, zero_points=None), absmax\n    else:\n        return QParams(scales=scales, zero_points=None)\n\n\n@torch.no_grad()\ndef cal_qparams_per_channel_minmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams:\n    \"\"\"Calculate quantization parameters for each channel using min and max\n    values.\"\"\"\n\n    float_w = w.float()\n\n    w_min = float_w.min(dim=-1, keepdim=True)[0]\n    w_max = float_w.max(dim=-1, keepdim=True)[0]\n\n    q_max = 2**n_bits - 1\n    scales = (w_max - w_min)\n    scales = scales.div_(q_max)\n\n    zero_points = precise_round(-w_min / scales)\n\n    if return_stats:\n        return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)\n    else:\n        return QParams(scales=scales, zero_points=zero_points)\n\n\n@torch.no_grad()\ndef cal_qparams_per_group_absmax(w: torch.Tensor, n_bits: int, group_size: int, return_stats: bool = False) -> QParams:\n    \"\"\"Calculate quantization parameters for each group using absolute max\n    value.\"\"\"\n\n    outc, inc = w.shape\n    assert inc >= group_size, \\\n        'Input channels should be greater than or equal to group_size.'\n    assert inc % group_size == 0, 
\\\n        'Input channels should be divisible by group_size.'\n\n    float_w = w.float()\n    absmax = float_w.abs().reshape(outc, -1, group_size).max(dim=-1, keepdim=True)[0]\n    q_max = 2**(n_bits - 1) - 1\n    scales = absmax.div(q_max)\n    if return_stats:\n        return QParams(scales=scales, zero_points=None), absmax\n    else:\n        return QParams(scales=scales, zero_points=None)\n\n\n@torch.no_grad()\ndef cal_qparams_per_group_minmax(w: torch.Tensor, n_bits: int, group_size: int, return_stats: bool = False) -> QParams:\n    \"\"\"Calculate quantization parameters for each group using min and max\n    values.\"\"\"\n\n    outc, inc = w.shape\n    assert inc >= group_size, \\\n        'Input channels should be greater than or equal to group_size.'\n    assert inc % group_size == 0, \\\n        'Input channels should be divisible by group_size.'\n\n    float_w = w.float()\n    w_group_wise = float_w.reshape(outc, -1, group_size)\n    w_min = w_group_wise.min(dim=-1, keepdim=True)[0]\n    w_max = w_group_wise.max(dim=-1, keepdim=True)[0]\n\n    q_max = 2**n_bits - 1\n    scales = (w_max - w_min)\n    scales = scales.div_(q_max)\n    zero_points = precise_round(-w_min / scales)\n    if return_stats:\n        return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)\n    else:\n        return QParams(scales=scales, zero_points=zero_points)\n\n\n@torch.no_grad()\ndef cal_qparams_per_tensor_minmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams:\n    \"\"\"Calculate quantization parameters for the entire tensor using min and\n    max values.\"\"\"\n\n    float_w = w.float()\n\n    w_min = float_w.min()\n    w_max = float_w.max()\n\n    q_max = 2**n_bits - 1\n    scales = (w_max - w_min)\n    scales = scales.clamp_(min=1e-5).div_(q_max)\n    zero_points = precise_round(-w_min / scales)\n    if return_stats:\n        return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)\n    else:\n        return 
QParams(scales=scales, zero_points=zero_points)\n\n\n@torch.no_grad()\ndef cal_qparams_per_tensor_absmax(w: torch.Tensor, n_bits: int, return_stats: bool = False) -> QParams:\n    \"\"\"Calculate quantization parameters for the entire tensor using absolute\n    max value.\"\"\"\n    float_w = w.float()\n    absmax = float_w.abs().max()\n    q_max = 2**(n_bits - 1) - 1\n    scales = absmax.div(q_max)\n\n    if return_stats:\n        return QParams(scales=scales, zero_points=None), absmax\n    else:\n        return QParams(scales=scales, zero_points=None)\n"
  },
  {
    "path": "lmdeploy/lite/utils/calib_dataloader.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport numpy as np\nimport torch\n\nNUM_LOADED_SAMPLES = 30000\n\n\ndef set_seed(seed):\n    np.random.seed(seed)\n    torch.random.manual_seed(seed)\n\n\n# adapted from https://github.com/vllm-project/llm-compressor/blob/main/tests/testing_utils.py\ndef process_dataset(ds, tokenizer, max_seq_length):\n    \"\"\"Helper function to preprocess and tokenize a dataset according to\n    presets.\n\n    Args:\n        ds: Language dataset to preprocess and tokenize.\n        tokenizer: Tokenizer to encode text.\n        max_seq_length: Maximum sequence length of samples.\n\n    Returns:\n        ds: Tokenized dataset.\n    \"\"\"\n    ds_name = ds.info.dataset_name.lower()\n    if ds_name == 'gsm8k':\n\n        def tokenize(sample):\n            return tokenizer(\n                sample['question'],\n                padding=False,\n                max_length=max_seq_length,\n                truncation=True,\n                add_special_tokens=False,\n            )\n\n    elif ds_name == 'open-platypus':\n        # use the output rather than the instruction\n        def tokenize(sample):\n            messages = [{\n                'role': 'user',\n                'content': sample['instruction'] + ' ' + sample['input']\n            }, {\n                'role': 'assistant',\n                'content': sample['output']\n            }]\n            return tokenizer(\n                tokenizer.apply_chat_template(\n                    messages,\n                    tokenize=False,\n                ),\n                padding=False,\n                max_length=max_seq_length,\n                truncation=True,\n                add_special_tokens=False,\n            )\n\n    # \"neuralmagic/calibration\"\n    elif ds_name == 'calibration':\n\n        def tokenize(sample):\n            messages = []\n            for message in sample['messages']:\n                if message['role'] == 'user':\n                    
messages.append({'role': 'user', 'content': message['content']})\n                elif message['role'] == 'assistant':\n                    messages.append({'role': 'assistant', 'content': message['content']})\n\n            return tokenizer(\n                tokenizer.apply_chat_template(\n                    messages,\n                    tokenize=False,\n                ),\n                padding=False,\n                max_length=max_seq_length,\n                truncation=True,\n                add_special_tokens=False,\n            )\n\n    elif ds_name == 'openwebtext':\n\n        def tokenize(sample):\n            return tokenizer(\n                sample['text'],\n                padding=False,\n                max_length=max_seq_length,\n                truncation=True,\n                add_special_tokens=False,\n            )\n\n    else:\n        raise NotImplementedError(f'Cannot preprocess dataset {ds.info.dataset_name}. '\n                                  f'Only `gsm8k`, `open-platypus`, `calibration`, `openwebtext` '\n                                  f'are supported by preprocess. 
')\n\n    ds = ds.map(tokenize, remove_columns=ds.column_names)\n\n    return ds\n\n\ndef get_wikitext2(dataset, tokenizer, nsamples, seed, seqlen):\n    \"\"\"Load Wikitext-2 train and test datasets and tokenize.\n\n    Args:\n        dataset: calib dataset\n        tokenizer: Tokenizer to encode text.\n        nsamples: Number of samples to take from train set.\n        seed: Random seed for sampling.\n        seqlen: Maximum sequence length.\n\n    Returns:\n        List of sampled and tokenized training examples.\n    \"\"\"\n    trainenc = tokenizer('\\n\\n'.join(dataset['text']), return_tensors='pt')\n\n    import random\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        trainloader.append(inp)\n    return trainloader\n\n\ndef get_c4(dataset, tokenizer, nsamples, seed, seqlen):\n    \"\"\"Load C4 train and validation datasets and tokenize.\n\n    Args:\n        dataset: calib dataset\n        tokenizer: Tokenizer to encode text.\n        nsamples: Number of samples to take from train set.\n        seed: Random seed for sampling.\n        seqlen: Maximum sequence length.\n\n    Returns:\n        List of sampled and tokenized training examples.\n    \"\"\"\n    import random\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        while True:\n            i = random.randint(0, len(dataset) - 1)\n            trainenc = tokenizer(dataset[i]['text'], return_tensors='pt')\n            if trainenc.input_ids.shape[1] >= seqlen:\n                break\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        trainloader.append(inp)\n\n    return trainloader\n\n\ndef get_pileval(dataset, tokenizer, nsamples, seed, seqlen=512):\n    \"\"\"Load pileval train dataset and tokenize.\n\n    
Args:\n        dataset: calib dataset\n        tokenizer: Tokenizer to encode text.\n        nsamples: Number of samples to take from train set.\n        seed: Random seed for sampling.\n        seqlen: Maximum sequence length.\n\n    Returns:\n        List of sampled and tokenized training examples.\n    \"\"\"\n\n    # pileval samples have far fewer tokens than seqlen; recompute how many\n    # train items to select so it can still yield enough samples after concatenation.\n    samples_encode = []\n    lengths = []\n    for data in dataset:\n        ids = tokenizer.encode(data['text'].strip())\n        if not ids or len(ids) > 512:\n            continue\n        samples_encode.append(torch.tensor([ids]))\n        lengths.append(len(ids))\n        if len(samples_encode) >= len(dataset):\n            break\n\n    avg_tokens = sum(lengths) / len(lengths)\n    needed_samples = max(1, int((seqlen * nsamples) // avg_tokens))\n\n    dataset = dataset.shuffle(seed=seed)\n    samples = []\n    n_run = 0\n    for data in dataset:\n        line = data['text']\n        line = line.strip()\n        line_encoded = tokenizer.encode(line)\n        if len(line_encoded) > 512:\n            continue\n        sample = torch.tensor([line_encoded])\n        if sample.numel() == 0:\n            continue\n        samples.append(sample)\n        n_run += 1\n        if n_run == needed_samples:\n            break\n    # now concatenate all samples and split according to block size\n    cat_samples = torch.cat(samples, dim=1)\n    n_split = cat_samples.shape[1] // seqlen\n    print(f' * Split into {n_split} blocks')\n    return [cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)]\n\n\ndef get_gsm8k(dataset, tokenizer, nsamples, seed, seqlen):\n    \"\"\"Load GSM8K train and test datasets and tokenize.\n\n    Args:\n        dataset: calib dataset\n        tokenizer: Tokenizer to encode text.\n        nsamples: Number of samples to take from train set.\n        seed: Random 
seed for sampling.\n        seqlen: Maximum sequence length.\n\n    Returns:\n        List of sampled and tokenized training examples.\n    \"\"\"\n    dataset = dataset.shuffle(seed=seed)\n    dataset = process_dataset(dataset, tokenizer, seqlen)\n\n    # GSM8K samples have far fewer tokens than seqlen; recompute how many\n    # train items to select so it can still yield enough samples after concatenation.\n    lengths = torch.tensor([len(sample['input_ids']) for sample in dataset], dtype=torch.long)\n    avg_tokens = lengths.sum().item() // len(dataset)\n    needed_samples = max(1, int((seqlen * nsamples) // avg_tokens))\n\n    samples = []\n    n_run = 0\n    for i in range(len(dataset)):\n        line = dataset[i]['input_ids']\n        sample = torch.tensor([line])\n        if sample.numel() == 0:\n            continue\n        samples.append(sample)\n        n_run += 1\n        if n_run == needed_samples:\n            break\n    cat_samples = torch.cat(samples, dim=1)\n    n_split = cat_samples.shape[1] // seqlen\n    print(f' * Split into {n_split} blocks')\n    return [cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)]\n\n\ndef get_neuralmagic_calibration(dataset, tokenizer, nsamples, seed, seqlen):\n    \"\"\"Load neuralmagic_calibration train and test datasets and tokenize.\n\n    Args:\n        dataset: calib dataset\n        tokenizer: Tokenizer to encode text.\n        nsamples: Number of samples to take from train set.\n        seed: Random seed for sampling.\n        seqlen: Maximum sequence length.\n\n    Returns:\n        List of sampled and tokenized training examples.\n    \"\"\"\n    dataset = dataset.shuffle(seed=seed)\n    dataset = process_dataset(dataset, tokenizer, seqlen)\n\n    # neuralmagic_calibration samples have far fewer tokens than seqlen; recompute how many\n    # train items to select so it can still yield enough samples after concatenation.\n    lengths = torch.tensor([len(sample['input_ids']) for sample in 
dataset], dtype=torch.long)\n    avg_tokens = lengths.sum().item() / len(dataset)\n    needed_samples = max(1, int((seqlen * nsamples) // avg_tokens))\n\n    samples = []\n    n_run = 0\n    for i in range(len(dataset)):\n        line = dataset[i]['input_ids']\n        sample = torch.tensor([line])\n        if sample.numel() == 0:\n            continue\n        samples.append(sample)\n        n_run += 1\n        if n_run == needed_samples:\n            break\n    cat_samples = torch.cat(samples, dim=1)\n    n_split = cat_samples.shape[1] // seqlen\n    print(f' * Split into {n_split} blocks')\n    return [cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)]\n\n\ndef get_open_platypus(dataset, tokenizer, nsamples, seed, seqlen):\n    \"\"\"Load open-platypus train and test datasets and tokenize.\n\n    Args:\n        dataset: calib dataset\n        tokenizer: Tokenizer to encode text.\n        nsamples: Number of samples to take from train set.\n        seed: Random seed for sampling.\n        seqlen: Maximum sequence length.\n\n    Returns:\n        List of sampled and tokenized training examples.\n    \"\"\"\n    dataset = dataset.shuffle(seed=seed)\n    dataset = process_dataset(dataset, tokenizer, seqlen)\n\n    # open-platypus samples have far fewer tokens than seqlen; recompute how many\n    # train items to select so it can still yield enough samples after concatenation.\n    lengths = torch.tensor([len(sample['input_ids']) for sample in dataset], dtype=torch.long)\n    avg_tokens = lengths.sum().item() / len(dataset)\n    needed_samples = max(1, int((seqlen * nsamples) // avg_tokens))\n\n    samples = []\n    n_run = 0\n    for i in range(len(dataset)):\n        line = dataset[i]['input_ids']\n        sample = torch.tensor([line])\n        if sample.numel() == 0:\n            continue\n        samples.append(sample)\n        n_run += 1\n        if n_run == needed_samples:\n            break\n    cat_samples = torch.cat(samples, dim=1)\n    
n_split = cat_samples.shape[1] // seqlen\n    print(f' * Split into {n_split} blocks')\n    return [cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)]\n\n\ndef get_openwebtext(dataset, tokenizer, nsamples, seed, seqlen):\n    \"\"\"Load openwebtext train and test datasets and tokenize.\n\n    Args:\n        dataset: calib dataset\n        tokenizer: Tokenizer to encode text.\n        nsamples: Number of samples to take from train set.\n        seed: Random seed for sampling.\n        seqlen: Maximum sequence length.\n\n    Returns:\n        List of sampled and tokenized training examples.\n    \"\"\"\n    dataset = dataset.shuffle(seed=seed)\n    dataset = process_dataset(dataset, tokenizer, seqlen)\n\n    import random\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        while True:\n            i = random.randint(0, len(dataset) - 1)\n            trainenc = dataset[i]\n            if len(trainenc['input_ids']) >= seqlen:\n                break\n        i = random.randint(0, len(trainenc['input_ids']) - seqlen)\n        j = i + seqlen\n        inp = trainenc['input_ids'][i:j]\n        inp = torch.tensor([inp])\n        trainloader.append(inp)\n\n    return trainloader\n\n\ndef get_calib_loaders(name, tokenizer, nsamples=128, seed=0, seqlen=2048):\n    \"\"\"Get calibration data loaders for a dataset.\n\n    Args:\n        name: Dataset name ('wikitext2', 'c4', 'pileval', 'gsm8k',\n                'neuralmagic_calibration', 'open-platypus', 'openwebtext').\n        tokenizer: Tokenizer to encode text.\n        nsamples: Number of samples to take from train set.\n        seed: Random seed for sampling.\n        seqlen: Maximum sequence length.\n\n    Returns:\n        List of sampled and tokenized training examples.\n    \"\"\"\n    from datasets import VerificationMode, load_dataset\n    if 'wikitext2' in name:\n        dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')\n        return 
get_wikitext2(dataset, tokenizer, nsamples, seed, seqlen)\n\n    if 'c4' in name:\n        dataset = load_dataset('allenai/c4',\n                               'en',\n                               data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},\n                               split='train',\n                               verification_mode=VerificationMode.NO_CHECKS)\n        return get_c4(dataset, tokenizer, nsamples, seed, seqlen)\n\n    if 'pileval' in name:\n        from datasets.builder import DatasetGenerationError\n        try:\n            dataset = load_dataset('mit-han-lab/pile-val-backup', split=f'validation[:{NUM_LOADED_SAMPLES}]')\n        except DatasetGenerationError:\n            raise InterruptedError('There have been some issues when generating '\n                                   'the dataset. You could try to download it '\n                                   'locally first and replace the `data_files` '\n                                   'with local addresses, or use other datasets '\n                                   '(e.g. c4, wikitext2).')\n        return get_pileval(dataset, tokenizer, nsamples, seed, seqlen)\n\n    if 'gsm8k' in name:\n        dataset = load_dataset('openai/gsm8k', 'main', split='train')\n        return get_gsm8k(dataset, tokenizer, nsamples, seed, seqlen)\n\n    if 'neuralmagic_calibration' in name:\n        dataset = load_dataset('neuralmagic/calibration', 'LLM', split='train')\n        return get_neuralmagic_calibration(dataset, tokenizer, nsamples, seed, seqlen)\n\n    if 'open-platypus' in name:\n        dataset = load_dataset('garage-bAInd/Open-Platypus', split='train')\n        return get_open_platypus(dataset, tokenizer, nsamples, seed, seqlen)\n\n    if 'openwebtext' in name:\n        dataset = load_dataset('Skylion007/openwebtext',\n                               data_files={'train': 'plain_text/train-00000-of-00080.parquet'},\n                               split=f'train[:{NUM_LOADED_SAMPLES}]',\n       
                        verification_mode=VerificationMode.NO_CHECKS)\n        return get_openwebtext(dataset, tokenizer, nsamples, seed, seqlen)\n"
  },
  {
    "path": "lmdeploy/lite/utils/collect.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List, Tuple, Union\n\nfrom torch import nn\n\n\ndef collect_target_modules(model: nn.Module,\n                           target: Union[str, type],\n                           skip_names: List[str] = [],\n                           prefix: str = '') -> Dict[str, nn.Module]:\n    \"\"\"Collects the specific target modules from the model.\n\n    Args:\n        model : The PyTorch module from which to collect the target modules.\n        target : The specific target to be collected. It can be a class of a\n            module or the name of a module.\n        skip_names : List of names of modules to be skipped during collection.\n        prefix : A string to be added as a prefix to the module names.\n\n    Returns:\n        A dictionary mapping from module names to module instances.\n    \"\"\"\n\n    if not isinstance(target, (type, str)):\n        raise TypeError('Target must be a string (name of the module) '\n                        'or a type (class of the module)')\n\n    def _is_target(n, m):\n        if isinstance(target, str):\n            return target == type(m).__name__ and n not in skip_names\n        return isinstance(m, target) and n not in skip_names\n\n    name2mod = {}\n    for name, mod in model.named_modules():\n        m_name = f'{prefix}.{name}' if prefix else name\n        if _is_target(name, mod):\n            name2mod[m_name] = mod\n    return name2mod\n\n\ndef collect_target_weights(model: nn.Module, target: Union[str, type], skip_names: List[str]) -> Dict[str, nn.Module]:\n    \"\"\"Collects weights of the specific target modules from the model.\n\n    Args:\n        model : The PyTorch module from which to collect the weights of\n            target modules.\n        target : The specific target whose weights are to be collected. 
It can be\n            a class of a module or the name of a module.\n        skip_names : Names of modules to be skipped during weight collection.\n\n    Returns:\n        A dictionary mapping from module instances to their\n            corresponding weights.\n    \"\"\"\n\n    named_modules = collect_target_modules(model, target, skip_names)\n    mod2weight = {}\n    for _, mod in named_modules.items():\n        assert hasattr(mod, 'weight'), \"The module does not have a 'weight' attribute\"\n        mod2weight[mod] = mod.weight\n    return mod2weight\n\n\ndef bimap_name_mod(name2mod_mappings: List[Dict[str, nn.Module]]) -> Tuple[Dict[str, nn.Module], Dict[nn.Module, str]]:\n    \"\"\"Generates bidirectional maps from module names to module instances and\n    vice versa.\n\n    Args:\n        name2mod_mappings : List of dictionaries each mapping from module\n            names to module instances.\n\n    Returns:\n        Two dictionaries providing bidirectional mappings between module\n            names and module instances.\n    \"\"\"\n\n    name2mod = {}\n    mod2name = {}\n    for mapping in name2mod_mappings:\n        mod2name.update({v: k for k, v in mapping.items()})\n        name2mod.update(mapping)\n    return name2mod, mod2name\n"
  },
  {
    "path": "lmdeploy/lite/utils/global_avail.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, Union\n\nfrom torch import nn\n\n\nclass GlobalAvailMixin:\n    \"\"\"Mixin class to make instances globally available.\"\"\"\n\n    _instances: Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']] = {'default': {}}\n\n    def global_available(self, key: Union[str, nn.Module] = 'default', group: str = 'default') -> None:\n        \"\"\"Make the instance globally available.\n\n        Args:\n            key (Union[str, nn.Module], optional): Key to save the instance.\n                Defaults to 'default'.\n            group (str, optional): Group to save the instance.\n                Defaults to 'default'.\n        \"\"\"\n        self._save_instance(self, key, group)\n\n    @classmethod\n    def _save_instance(cls,\n                       instance: 'GlobalAvailMixin',\n                       key: Union[str, nn.Module] = 'default',\n                       group: str = 'default') -> None:\n        \"\"\"Save the instance.\n\n        Args:\n            instance (GlobalAvailMixin): Instance to save.\n            key (Union[str, nn.Module], optional): Key to save the instance.\n                Defaults to 'default'.\n            group (str, optional): Group to save the instance.\n                Defaults to 'default'.\n        \"\"\"\n        if group not in cls._instances:\n            assert isinstance(group, str)\n            cls._instances[group] = {}\n\n        cls._instances[group][key] = instance\n\n    @classmethod\n    def find(cls, key: Union[str, nn.Module] = 'default', group: str = 'default') -> Union[None, 'GlobalAvailMixin']:\n        \"\"\"Find an instance by its key and group.\n\n        Args:\n            key (Union[str, nn.Module], optional): Key of the instance.\n                Defaults to 'default'.\n            group (str, optional): Group of the instance.\n                Defaults to 'default'.\n\n        Returns:\n            Union[None, 
GlobalAvailMixin]: The found instance, or None if\n                it does not exist.\n        \"\"\"\n        return cls._instances.get(group, {}).get(key)\n\n    @classmethod\n    def find_group(cls, group: str) -> Dict[Union[str, nn.Module], 'GlobalAvailMixin']:\n        \"\"\"Find all instances in a group.\n\n        Args:\n            group (str): Group of the instances.\n\n        Returns:\n            Dict[Union[str, nn.Module], GlobalAvailMixin]: All instances in\n                the group.\n        \"\"\"\n        return cls._instances.get(group, {})\n\n    @classmethod\n    def instances(cls) -> Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']]:\n        \"\"\"Get all instances.\"\"\"\n        return cls._instances\n"
  },
  {
    "path": "lmdeploy/lite/utils/load.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Literal\n\nimport torch\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\n\nclass LoadNoInit:\n    \"\"\"Initialize model without parameter initialization.\"\"\"\n\n    def __init__(self):\n        self.constant_ = torch.nn.init.constant_\n        self.zeros_ = torch.nn.init.zeros_\n        self.ones_ = torch.nn.init.ones_\n        self.uniform_ = torch.nn.init.uniform_\n        self.normal_ = torch.nn.init.normal_\n        self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_\n        self.kaiming_normal_ = torch.nn.init.kaiming_normal_\n        self.tensor_normal_ = torch.Tensor.normal_\n\n    def __enter__(self, *args, **kwargs):\n        \"\"\"Replace initializers with no-op.\"\"\"\n\n        torch.nn.init.constant_ = lambda *args, **kwargs: None\n        torch.nn.init.zeros_ = lambda *args, **kwargs: None\n        torch.nn.init.ones_ = lambda *args, **kwargs: None\n        torch.nn.init.uniform_ = lambda *args, **kwargs: None\n        torch.nn.init.normal_ = lambda *args, **kwargs: None\n        torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None\n        torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None\n        torch.Tensor.normal_ = lambda *args, **kwargs: None\n\n    def __exit__(self, *args, **kwargs):\n        \"\"\"Recover.\"\"\"\n\n        torch.nn.init.constant_ = self.constant_\n        torch.nn.init.zeros_ = self.zeros_\n        torch.nn.init.ones_ = self.ones_\n        torch.nn.init.uniform_ = self.uniform_\n        torch.nn.init.normal_ = self.normal_\n        torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_\n        torch.nn.init.kaiming_normal_ = self.kaiming_normal_\n        torch.Tensor.normal_ = self.tensor_normal_\n\n\ndef load_hf_from_pretrained(pretrained_model_name_or_path, dtype: Literal['float16', 'bfloat16', 'auto'], **kwargs):\n\n    if dtype == 'bfloat16' and not torch.cuda.is_bf16_supported():\n        raise 
RuntimeError('Your device does not support bf16 (bfloat16), '\n                           'please change to fp16 (float16)')\n\n    kwargs.pop('config', None)\n\n    hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)\n\n    # HACK: hard-coded for Qwen; other configs do not have the `fp16` attribute.\n    if hasattr(hf_config, 'fp16') or hasattr(hf_config, 'bf16'):\n        if dtype == 'bfloat16':\n            hf_config.bf16 = True\n        else:\n            hf_config.fp16 = True\n\n    torch_dtype = getattr(hf_config, 'torch_dtype', torch.float16)\n    if dtype == 'bfloat16':\n        torch_dtype = torch.bfloat16\n    elif dtype == 'float16':\n        torch_dtype = torch.float16\n    elif dtype == 'auto' and torch_dtype == torch.bfloat16:\n        print('Warning: casting the model to float16 to prevent OOM. '\n              'You can force bfloat16 with `--dtype bfloat16`')\n        torch_dtype = torch.float16\n\n    with LoadNoInit():\n        # Load model\n        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,\n                                                     config=hf_config,\n                                                     torch_dtype=torch_dtype,\n                                                     **kwargs)\n        model.config.use_cache = False\n\n    return model\n"
  },
  {
    "path": "lmdeploy/lite/utils/memory_efficient.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport inspect\nimport re\nimport warnings\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom typing import List\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.lite.defaults import KV_CACHE_SIGNATURE, OFFLOAD_MOD\n\n\ndef extract_return_values(module: nn.Module) -> List[str]:\n    \"\"\"Extracts return values from given module's forward method.\n\n    Args:\n        module (nn.Module): Module to inspect\n\n    Returns:\n        list[str]: List of return values\n    \"\"\"\n\n    last_line = inspect.getsource(module.forward).rstrip('\\n').split('\\n')[-1]\n    pattern = r'return ([\\w\\s,]+)'\n    match = re.search(pattern, last_line)\n\n    if match:\n        return_values = match.group(1).split(',')\n        return [value.strip() for value in return_values]\n    else:\n        return []\n\n\ndef find_kv_cache_idx(module: nn.Module) -> int:\n    \"\"\"Finds index of kv cache signature in module's forward parameters.\"\"\"\n\n    signatures = list(inspect.signature(module.forward).parameters.keys())\n    if KV_CACHE_SIGNATURE not in signatures:\n        raise ValueError(f'{KV_CACHE_SIGNATURE} not in signatures of '\n                         f'{type(module)} forward.')\n    return signatures.index(KV_CACHE_SIGNATURE)\n\n\ndef find_modules_by_return_value(model: nn.Module, value: str) -> List[nn.Module]:\n    \"\"\"Finds modules in model that return given value.\n\n    Args:\n        model (nn.Module): Model to inspect\n        value (str): Return value to search for\n\n    Returns:\n        list[nn.Module]: List of matching modules\n\n    Raises:\n        ValueError: If no matching modules found\n    \"\"\"\n\n    modules = []\n    for name, module in model.named_modules():\n        returns = extract_return_values(module)\n        if value in returns:\n            print(f'Found {name} returning {value}')\n            modules.append(module)\n\n    if not modules:\n        
error_msg = f'No modules found returning {value}. '\n        error_msg += 'Please check whether the default KV_CACHE_SIGNATURE '\n        error_msg += f\"'{KV_CACHE_SIGNATURE}' matches what is used in your \"\n        error_msg += 'model code. If not, you can modify KV_CACHE_SIGNATURE '\n        error_msg += 'in `lmdeploy.lite.defaults`.'\n        raise ValueError(error_msg)\n\n    return modules\n\n\n@contextmanager\ndef offload_kv_cache(model: nn.Module, device: str = 'cuda') -> None:\n    \"\"\"Moves the kv cache to the given device during each forward pass and back to CPU afterwards.\n\n    Args:\n        model (nn.Module): Model for inference\n        device (str): Device to offload to\n\n    Yields:\n        None\n    \"\"\"\n\n    modules = find_modules_by_return_value(model, KV_CACHE_SIGNATURE)\n\n    original_forwards = {mod: mod.forward for mod in modules}\n    input_idxs = {mod: find_kv_cache_idx(mod) for mod in modules}\n    output_idxs = {mod: extract_return_values(mod).index(KV_CACHE_SIGNATURE) for mod in modules}\n\n    def wrap_forward(module, *args, **kwargs):\n\n        idx = input_idxs[module]\n        if idx >= len(args):\n            # kv cache in kwargs\n            if KV_CACHE_SIGNATURE in kwargs:\n                if kwargs[KV_CACHE_SIGNATURE]:\n                    kwargs[KV_CACHE_SIGNATURE] = kwargs[KV_CACHE_SIGNATURE].to(device)\n            else:\n                raise ValueError(f'No kv cache input found at index {idx}')\n        else:\n            # kv cache in args\n            args = list(args)\n            args[idx] = args[idx].to(device)\n            args = tuple(args)\n\n        result = original_forwards[module](*args, **kwargs)\n\n        result = list(result)\n        idx = output_idxs[module]\n\n        # Move kv cache outputs back to CPU\n        key = result[idx][0].to('cpu')\n        value = result[idx][1].to('cpu')\n        torch.cuda.empty_cache()\n\n        result[idx] = (key, value)\n        result = tuple(result)\n\n        return result\n\n    try:\n        
for module in modules:\n            original_forwards[module] = module.forward\n            module.forward = partial(wrap_forward, module)\n\n        yield\n\n    finally:\n        for module in modules:\n            module.forward = original_forwards[module]\n            del original_forwards[module]\n\n\n@contextmanager\ndef offload_weights(model: nn.Module, device: str = 'cuda') -> None:\n    \"\"\"Offloads specified modules to given device during forward pass.\n\n    Args:\n        model (nn.Module): Model for inference\n        device (str): Device to offload to\n\n    Yields:\n        None\n    \"\"\"\n\n    target_modules = OFFLOAD_MOD\n\n    def before_forward(module: nn.Module, inp: torch.Tensor):\n        module.to(device)\n\n    def after_forward(module: nn.Module, inp: torch.Tensor, out: torch.Tensor):\n        module.to('cpu')\n        torch.cuda.empty_cache()\n\n    def _to_device(m, spec_modules, dev):\n        if len(spec_modules) == 0 or len(list(m.children())) == 0:\n            m.to(dev)\n            return\n\n        for child in m.children():\n            if isinstance(child, spec_modules):\n                child.to('cpu')\n            else:\n                _to_device(child, spec_modules, dev)\n                # m.to(dev)\n\n    warnings.warn('By default, offloading will be done on '\n                  '`nn.Linear`. 
To offload other modules, add them to '\n                  'the `lmdeploy.lite.defaults.OFFLOAD_MOD`.')\n    target = OFFLOAD_MOD\n\n    _to_device(model, target, device)\n\n    handles = []\n    for module in model.modules():\n        if isinstance(module, target_modules):\n            handle1 = module.register_forward_pre_hook(before_forward)\n            handle2 = module.register_forward_hook(after_forward)\n            handles.extend([handle1, handle2])\n\n    try:\n        yield\n    finally:\n        for handle in handles:\n            handle.remove()\n\n        model.to('cpu')\n        torch.cuda.empty_cache()\n\n\n@contextmanager\ndef memory_efficient_inference(model: nn.Module, offload: bool = True, device: str = 'cuda') -> None:\n    \"\"\"Memory efficient inference context manager.\n\n    Moves model to device for inference, with option to offload\n    specific modules.\n\n    Args:\n        model (nn.Module): Model for inference\n        offload (bool): Whether to offload modules\n        device (str): Device for inference\n\n    Yields:\n        None\n    \"\"\"\n\n    if offload:\n        warnings.warn('Using offload mode - modules defined in OFFLOAD_MOD '\n                      'will be moved to GPU during forward pass only.')\n        warnings.warn('Using offload mode will incur a performance penalty due to '\n                      'frequent CPU-GPU data transfers.')\n        with torch.inference_mode():\n            with offload_kv_cache(model, device):\n                with offload_weights(model, device):\n                    yield\n    else:\n        model.to(device)\n        with torch.inference_mode():\n            yield\n"
  },
  {
    "path": "lmdeploy/logger.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/logger.py  # noqa\nfrom typing import List, Optional\n\nfrom .messages import GenerationConfig\nfrom .utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass RequestLogger:\n    \"\"\"A class responsible for logging requests, ensuring that logs do not\n    exceed a specified maximum length.\n\n    Args:\n        max_log_len (Optional[int]): The maximum length of the log entries.\n            If None, no maximum length is enforced.\n    \"\"\"\n\n    def __init__(self, max_log_len: Optional[int]) -> None:\n        self.max_log_len = max_log_len\n\n    def log_prompt(self, session_id: int, prompt: str) -> None:\n        if not isinstance(prompt, str):\n            # Prompt may be a GPT4V message with base64 images;\n            # logging might be impractical due to length\n            return\n        if self.max_log_len is not None:\n            if prompt is not None:\n                prompt = prompt[:self.max_log_len]\n        logger.info(f'session={session_id}, '\n                    f'prompt={prompt!r}')\n\n    def log_inputs(self, session_id: int, prompt: Optional[str], prompt_token_ids: Optional[List[int]],\n                   gen_config: GenerationConfig, adapter_name: str) -> None:\n        max_log_len = self.max_log_len\n        # prompt_token_ids is Optional, so guard against None before taking len()\n        input_tokens = len(prompt_token_ids) if prompt_token_ids is not None else None\n        if max_log_len is not None:\n            if prompt is not None:\n                prompt = prompt[:max_log_len]\n\n            if prompt_token_ids is not None:\n                prompt_token_ids = prompt_token_ids[:max_log_len]\n\n        logger.info(f'session={session_id}, '\n                    f'adapter_name={adapter_name}, '\n                    f'input_tokens={input_tokens}, '\n                    f'gen_config={gen_config}, '\n                    f'prompt={prompt!r}, '\n                    f'prompt_token_id={prompt_token_ids}')\n"
  },
  {
    "path": "lmdeploy/messages.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport enum\nimport time\nfrom dataclasses import dataclass, field\nfrom typing import Any, Callable, Dict, List, Literal, Optional\n\nimport torch\nfrom pydantic.dataclasses import dataclass as pydantic_dataclass\n\nfrom lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend\nfrom lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest\n\nfrom .tokenizer import Tokenizer\nfrom .utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\nLogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]\n\"\"\"LogitsProcessor is a function that takes a tensor of input_ids, the logits\ntensor for the next token, and returns a modified tensor of logits to sample\nfrom.\"\"\"\n\n\n@dataclass\nclass GenerationConfig:\n    \"\"\"Generation parameters used by inference engines.\n\n    Args:\n        n: How many chat completion choices to generate for each\n            input message. **Only 1** is supported now.\n        max_new_tokens: The maximum number of tokens that can be\n            generated in the chat completion\n        do_sample: Whether or not to use sampling; use greedy\n            decoding otherwise. Defaults to False.\n        top_p: An alternative to sampling with temperature, called\n            nucleus sampling, where the model considers the results of the\n            tokens with top_p probability mass\n        top_k: An alternative to sampling with temperature, where\n            the model considers the top_k tokens with the highest probability\n        min_p: Minimum token probability, which will be scaled by the\n            probability of the most likely token. It must be a value between\n            0 and 1. 
Typical values are in the 0.01-0.2 range, which is comparably\n            selective to setting `top_p` in the 0.99-0.8 range (use the\n            opposite of normal `top_p` values)\n        temperature: Sampling temperature\n        repetition_penalty: Penalty to prevent the model from\n            generating repeated words or phrases. A value larger than\n            1 discourages repetition\n        ignore_eos: Whether or not to ignore the eos_token_id\n        random_seed: Seed used when sampling a token\n        stop_words: Words that stop generating further tokens\n        bad_words: Words that the engine will never generate\n        stop_token_ids: List of tokens that stop the generation\n            when they are generated. The returned output will not contain\n            the stop tokens.\n        bad_token_ids: List of tokens that the engine will never\n            generate.\n        min_new_tokens: The minimum number of tokens to generate,\n            ignoring the number of tokens in the prompt.\n        skip_special_tokens: Whether or not to remove special tokens\n            in the decoding. Defaults to True.\n        spaces_between_special_tokens: Whether or not to add spaces\n            around special tokens. Fast tokenizers behave as if this were\n            False, while slow tokenizers set it to True.\n        logprobs: Number of log probabilities to return per output token.\n        response_format: Generate responses according to the given format.\n            Examples:\n\n            .. 
code-block:: json\n\n                {\n                    \"type\": \"json_schema\",\n                    \"json_schema\": {\n                        \"name\": \"test\",\n                        \"schema\": {\n                            \"properties\": {\n                                \"name\": {\n                                    \"type\": \"string\"\n                                }\n                            },\n                            \"required\": [\"name\"],\n                            \"type\": \"object\"\n                        }\n                    }\n                }\n\n            or,\n\n            .. code-block:: json\n\n                {\n                    \"type\": \"regex_schema\",\n                    \"regex_schema\": \"call me [A-Za-z]{1,10}\"\n                }\n\n        logits_processors: Custom logit processors.\n        repetition_ngram_size: The size of n-grams to consider for repetition early stop.\n        repetition_ngram_threshold: The number of times an n-gram must be repeated to trigger early stop.\n    \"\"\"\n\n    n: int = 1\n    max_new_tokens: int = 512\n    do_sample: bool = False\n    top_p: float = 1.0\n    top_k: int = 50\n    min_p: float = 0.0\n    temperature: float = 0.8\n    repetition_penalty: float = 1.0\n    ignore_eos: bool = False\n    random_seed: int = None\n    stop_words: List[str] = None\n    bad_words: List[str] = None\n    stop_token_ids: List[int] = None\n    bad_token_ids: List[int] = None\n    min_new_tokens: int = None\n    skip_special_tokens: bool = True\n    spaces_between_special_tokens: bool = True\n    logprobs: int = None\n    response_format: Optional[Dict] = None\n    logits_processors: Optional[List[LogitsProcessor]] = None\n    output_logits: Literal['all', 'generation'] = None\n    output_last_hidden_state: Literal['all', 'generation'] = None\n    include_stop_str_in_output: bool = False\n\n    # for disaggregation\n    with_cache: bool = False\n    preserve_cache: bool = False\n    migration_request: 
Optional[MigrationRequest] = None\n\n    # router replay\n    return_routed_experts: bool = False\n\n    # n-gram early stop: generation stops if the latest [size] tokens are repeated [threshold] times\n    repetition_ngram_size: int = 0\n    repetition_ngram_threshold: int = 0\n\n    def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):\n        \"\"\"Convert stop_words/bad_words to ids and append the ids to\n        stop_token_ids/bad_token_ids.\"\"\"\n\n        def special_word_token_ids(words):\n            if words is not None:\n                assert isinstance(words, List) and \\\n                    all(isinstance(elem, str) for elem in words), \\\n                    f'stop_words and bad_words must be a list of str but got {type(words)}'\n                indexes = []\n                for word in words:\n                    indexes += tokenizer.indexes_containing_token(word)\n                return indexes\n            return None\n\n        stop_token_ids = special_word_token_ids(self.stop_words) or []\n        bad_token_ids = special_word_token_ids(self.bad_words) or []\n        stop_token_ids.extend(self.stop_token_ids or [])\n        bad_token_ids.extend(self.bad_token_ids or [])\n        self.stop_token_ids = list(set(stop_token_ids)) or None\n        self.bad_token_ids = list(set(bad_token_ids)) or None\n\n    def update_from_hf_gen_cfg(self, generation_config, tokenizer_eos_token_id):\n        \"\"\"Update the stop_token_ids.\"\"\"\n        stop_token_ids = set(self.stop_token_ids or [])\n\n        # add tokenizer's eos_token_id\n        if tokenizer_eos_token_id is not None:\n            stop_token_ids.add(tokenizer_eos_token_id)\n\n        # add eos_token_id from model's generation_config.json file if there\n        # is any.\n        eos_token_id = generation_config.get('eos_token_id')\n        if eos_token_id is not None:\n            if isinstance(eos_token_id, int):\n                stop_token_ids.add(eos_token_id)\n            else:\n                
stop_token_ids.update(eos_token_id)\n\n        self.stop_token_ids = list(stop_token_ids)\n\n    def __post_init__(self):\n        \"\"\"Check input validation.\"\"\"\n        assert type(self.n) == int and self.n > 0, 'n is not a positive integer'\n        assert 0 <= self.top_p <= 1, \\\n            f'top_p should be in range [0, 1], but found {self.top_p}'\n        assert self.top_k >= 0, 'top_k can not be a negative integer'\n        assert 0 <= self.temperature <= 2, \\\n            f'temperature should be in range [0, 2], but found {self.temperature}'\n        assert 0 <= self.min_p <= 1, \\\n            f'min_p should be in range [0, 1], but found {self.min_p}'\n\n\n@pydantic_dataclass\nclass TurbomindEngineConfig:\n    \"\"\"TurboMind Engine config.\n\n    Args:\n        dtype: data type for model weights and activations. It can be\n            one of the following values: ['auto', 'float16', 'bfloat16'].\n            The `auto` option will use FP16 precision for FP32 and FP16\n            models, and BF16 precision for BF16 models.\n        model_format: the layout of the deployed model. It can be one\n            of the following values: [hf, awq, gptq]. `hf` means a\n            huggingface model (.bin, .safetensors); `awq` and `gptq` mean\n            a model quantized by AWQ and GPTQ, respectively. If it is not\n            specified, i.e. None, it will be extracted from the input model\n        tp: the number of GPU cards used in tensor parallelism,\n            defaults to 1\n        session_len: the max session length of a sequence, defaults to\n            None\n        max_batch_size: the max batch size during inference. 
If it is\n            not specified, the engine will automatically set it according to\n            the device\n        cache_max_entry_count: the percentage of gpu memory occupied\n            by the k/v cache.\n            For versions of lmdeploy between `v0.2.0` and `v0.2.1`, it\n            defaults to 0.5, depicting the percentage of TOTAL GPU memory to\n            be allocated to the k/v cache.\n            For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8,\n            signifying the percentage of FREE GPU memory to be reserved for\n            the k/v cache.\n            When it's an integer > 0, it represents the total number of k/v\n            blocks.\n        cache_chunk_size: The policy to apply for KV block from\n            the block manager, defaults to -1.\n        cache_block_seq_len: the length of the token sequence in\n            a k/v block, defaults to 64\n        enable_prefix_caching: enable caching prompt blocks for reuse,\n            defaults to False\n        quant_policy: defaults to 0. When k/v is quantized into 4 or 8\n            bit, set it to 4 or 8, respectively\n        rope_scaling_factor: scaling factor used for dynamic ntk,\n            defaults to 0. TurboMind follows the implementation of transformers'\n            LlamaAttention\n        use_logn_attn: whether or not to use logn attention, defaults to False\n        download_dir: Directory to download and load the weights,\n            defaults to the default cache directory of huggingface.\n        revision: The specific model version to use. It can be a branch\n            name, a tag name, or a commit id. If unspecified, will use the\n            default version.\n        max_prefill_token_num: the number of tokens per iteration during\n            prefill, defaults to 8192\n        num_tokens_per_iter: the number of tokens processed in each\n            forward pass. 
Working with `max_prefill_iters` enables the\n            \"Dynamic SplitFuse\"-like scheduling\n        max_prefill_iters: the max number of forward passes during the\n            prefill stage\n        async_: enable async execution, defaults to 1 (enabled)\n        devices: the used devices\n        empty_init: Whether to skip loading the model weights. Set it to\n            True if you want to update the weights after creating the pipeline\n        hf_overrides: Huggingface overrides for the model.\n            It can be used to override the default config of the model\n        enable_metrics: enable metrics system\n    \"\"\"\n\n    dtype: str = 'auto'\n    model_format: Optional[str] = None\n    tp: int = 1\n    dp: int = 1\n    cp: int = 1\n    device_num: int = None\n    attn_tp_size: int = None\n    attn_cp_size: int = None\n    attn_dp_size: int = None\n    mlp_tp_size: int = None\n    mlp_dp_size: int = None\n    outer_dp_size: int = None\n    nnodes: int = 1\n    node_rank: int = 0\n    dist_init_addr: Optional[str] = None\n    devices: List[int] = None\n    session_len: Optional[int] = None\n    max_batch_size: int = None\n    cache_max_entry_count: float = 0.8\n    cache_chunk_size: int = -1\n    cache_block_seq_len: int = 64\n    enable_prefix_caching: bool = False\n    quant_policy: int = 0\n    rope_scaling_factor: float = 0.0\n    use_logn_attn: bool = False\n    download_dir: Optional[str] = None\n    revision: Optional[str] = None\n    max_prefill_token_num: int = 8192\n    num_tokens_per_iter: int = 0\n    max_prefill_iters: int = 1\n    async_: int = 1\n    devices: Optional[List[int]] = None\n    empty_init: bool = False\n    communicator: str = 'nccl'\n    hf_overrides: Optional[Dict[str, Any]] = None\n    enable_metrics: bool = True\n\n    def __post_init__(self):\n        \"\"\"Check input validation.\"\"\"\n        assert self.dtype in ['auto', 'float16', 'bfloat16']\n        assert self.tp >= 1, 'tp must be a positive integer'\n        assert 
self.cache_max_entry_count > 0, 'invalid cache_max_entry_count'\n        assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'\n        assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor'\n        assert self.max_prefill_token_num >= 0, \\\n            'invalid max_prefill_token_num'\n        assert self.num_tokens_per_iter >= 0, 'invalid num_tokens_per_iter'\n        assert self.async_ in (0, 1), 'async_ must be 0 (disabled) or 1 (enabled)'\n\n\n@dataclass\nclass PytorchEngineConfig:\n    \"\"\"PyTorch Engine Config.\n\n    Args:\n        dtype: data type for model weights and activations. It can be\n            one of the following values: ['auto', 'float16', 'bfloat16'].\n            The `auto` option will use FP16 precision for FP32 and FP16\n            models, and BF16 precision for BF16 models.\n        tp: Tensor Parallelism. Defaults to 1.\n        dp: Data Parallelism. Defaults to 1.\n        dp_rank: rank of dp.\n        ep: Expert Parallelism. Defaults to 1.\n        session_len: Max session length. Defaults to None.\n        max_batch_size: Max batch size. If it is not specified,\n            the engine will automatically set it according to the device\n        attn_tp_size: tp size for attention, only works for dp>1\n        mlp_tp_size: tp size for mlp, only works for dp>1\n        moe_tp_size: tp size for moe, only works for dp>1\n        cache_max_entry_count: the percentage of gpu memory occupied\n            by the k/v cache. For lmdeploy versions greater than `v0.2.1`,\n            it defaults to 0.8, signifying the percentage of FREE GPU memory\n            to be reserved for the k/v cache\n        prefill_interval: Interval to perform prefill.\n            Defaults to 16.\n        block_size: paging cache block size, defaults to 64.\n        num_cpu_blocks: Num cpu blocks. If num is 0, cache\n            will be allocated according to the current environment.\n        num_gpu_blocks: Num gpu blocks. 
If num is 0, cache\n            will be allocated according to the current environment.\n        adapters: The path configs of LoRA adapters.\n        max_prefill_token_num: tokens per iteration.\n        thread_safe: thread safe engine instance.\n        enable_prefix_caching: Enable token matching and cache sharing.\n        device_type: The inference device type, options ['cuda', 'ascend', 'maca', 'camb']\n        eager_mode: Enable \"eager\" mode or not\n        custom_module_map: nn module map customized by users. Once\n            provided, the original nn modules of the model will be\n            substituted by the mapping ones\n        download_dir: Directory to download and load the weights,\n            defaults to the default cache directory of huggingface.\n        revision: The specific model version to use.\n            It can be a branch name, a tag name, or a commit id.\n            If unspecified, will use the default version.\n        quant_policy: defaults to 0. When k/v is quantized into 4 or 8\n            bit, set it to 4 or 8, respectively\n        distributed_executor_backend: backend of the distributed executor,\n            options: ['uni', 'mp', 'ray']\n        empty_init: Whether to skip loading the model weights. Set it to\n            True if you want to update the weights after creating the pipeline\n        enable_microbatch: enable microbatch for specified model\n        enable_eplb: enable eplb for specified model\n        enable_metrics: enable metrics system\n        role: role of the engine, options: ['Hybrid', 'Prefill',\n            'Decode']. Defaults to `EngineRole.Hybrid`.\n        migration_backend: migration backend. options: ['DLSlime'].\n            Defaults to `MigrationBackend.DLSlime`.\n        enable_mp_engine: run engine in multi-process mode.\n        mp_engine_backend: backend of mp engine, options:\n            ['mp', 'ray']. 
Defaults to `mp`.\n        model_format: weight quantization policy, options: ['fp8'].\n        hf_overrides: Huggingface overrides for the model.\n            It can be used to override the default config of the model.\n        disable_vision_encoder: Whether to disable loading vision\n            encoder. Defaults to False.\n        logprobs_mode: The mode of logprob, options: ['raw_logits', 'raw_logprobs']\n        dllm_block_length: Block size of block diffusion model.\n        dllm_unmasking_strategy: Dllm unmasking strategy, options:\n            ['low_confidence_dynamic', 'low_confidence_static', 'sequential'].\n        dllm_denoising_steps: Dllm denoising steps.\n        dllm_confidence_threshold: dllm unmasking threshold for\n            dynamic unmasking.\n    \"\"\"\n    dtype: str = 'auto'\n    tp: int = 1\n    dp: int = 1\n    dp_rank: int = 0\n    ep: int = 1\n    session_len: int = None\n    max_batch_size: int = None\n    attn_tp_size: int = None\n    mlp_tp_size: int = None\n    moe_tp_size: int = None\n    cache_max_entry_count: float = 0.8\n    prefill_interval: int = 16\n    block_size: int = 64\n    num_cpu_blocks: int = 0\n    num_gpu_blocks: int = 0\n    adapters: Dict[str, str] = None\n    max_prefill_token_num: int = 4096\n    thread_safe: bool = False\n    enable_prefix_caching: bool = False\n    device_type: str = 'cuda'\n    eager_mode: bool = False\n    custom_module_map: Dict[str, str] = None\n    download_dir: str = None\n    revision: str = None\n    quant_policy: Literal[0, 4, 8] = 0\n    distributed_executor_backend: str = None\n    empty_init: bool = False\n    enable_microbatch: bool = False\n    enable_eplb: bool = False\n    enable_mp_engine: bool = False\n    mp_engine_backend: str = 'mp'\n    model_format: str = None\n    enable_metrics: bool = True\n    hf_overrides: Optional[Dict[str, Any]] = None\n    disable_vision_encoder: bool = False\n    logprobs_mode: str = None\n    # router replay\n    enable_return_routed_experts: 
bool = False\n    enable_transfer_obj_ref: bool = False\n\n    # dllm\n    dllm_block_length: int = None\n    dllm_unmasking_strategy: str = 'low_confidence_dynamic'\n    dllm_denoising_steps: int = None\n    dllm_confidence_threshold: float = 0.85\n\n    role: EngineRole = EngineRole.Hybrid\n    migration_backend: MigrationBackend = MigrationBackend.DLSlime\n\n    def __post_init__(self):\n        \"\"\"Check input validation.\"\"\"\n        assert self.dtype in ['auto', 'float16', 'bfloat16']\n        assert self.tp >= 1, 'invalid tp'\n        assert self.dp >= 1, 'invalid dp'\n        assert self.ep >= 1, 'invalid ep'\n        assert 0 < self.cache_max_entry_count < 1, \\\n            'invalid cache_max_entry_count'\n        assert self.num_cpu_blocks >= 0, 'invalid num_cpu_blocks'\n        assert self.max_prefill_token_num >= 0, \\\n            'invalid max_prefill_token_num'\n        assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'\n        assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'\n        assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')\n        assert self.block_size >= 16 and (self.block_size & (self.block_size - 1)) == 0, \\\n            f'block_size must be >= 16 and a power of 2, but got {self.block_size}'\n        if self.quant_policy > 0 and self.device_type not in ['cuda', 'ascend']:\n            assert False, \\\n                   'kv cache quantization only works for CUDA and ASCEND.'\n        if self.device_type == 'camb' and self.block_size != 16:\n            self.block_size = 16\n            logger.warning('Currently, camb device requires block size to be 16, '\n                           'setting block size to 16.')\n\n\nclass ResponseType(enum.Enum):\n    \"\"\"Response type.\"\"\"\n\n    SUCCESS = enum.auto()\n    FINISH = enum.auto()\n    ENGINE_STOP_ERROR = enum.auto()\n    SESSION_REPEAT = enum.auto()\n    SESSION_NOT_EXIST = enum.auto()\n    
HANDLER_NOT_EXIST = enum.auto()\n    INPUT_LENGTH_ERROR = enum.auto()\n    INTERNAL_ENGINE_ERROR = enum.auto()\n    CANCEL = enum.auto()\n    PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE = enum.auto()\n    NO_QUEUE = enum.auto()\n\n\n@dataclass\nclass Response:\n    \"\"\"Pack all response information together.\n\n    Args:\n        text: the response text from the server. If the output text is\n            an empty str and the finish_reason is length, it means the session\n            length is reached.\n        generate_token_len: the response token length.\n        input_token_len: the input prompt token length. Note that it may\n            contain the chat template part.\n        session_id: the id for running the session.\n        finish_reason: the reason the model stopped\n            generating tokens. This will be 'stop' if the model hit a natural\n            stop point or a provided stop sequence, 'length' if the maximum\n            number of tokens specified in the request was reached.\n        token_ids: the output token ids.\n        logprobs: the top logprobs for each output\n            position.\n        index: the position index of the request in the input\n            batch.\n    \"\"\"\n    text: str\n    generate_token_len: int\n    input_token_len: int\n    finish_reason: Optional[Literal['stop', 'length']] = None\n    token_ids: List[int] = field(default_factory=list)\n    logprobs: List[Dict[int, float]] = None\n    logits: torch.Tensor = None\n    last_hidden_state: torch.Tensor = None\n    index: int = 0\n    routed_experts: Any = None\n\n    def __str__(self):\n        return f'text={self.text}\\n{self._format_none_text_fields()}'\n\n    def __repr__(self):\n        return f'text={self.text!r}\\n{self._format_none_text_fields()}'\n\n    def _format_none_text_fields(self):\n        fields = []\n        fields.append(f'input_token_len={self.input_token_len}')\n        fields.append(f'generate_token_len={self.generate_token_len}')\n        
fields.append(f'finish_reason=\"{self.finish_reason}\"')\n        fields.append(f'token_ids={self.token_ids}')\n        fields.append(f'logprobs={self.logprobs}')\n\n        # Helper function to format tensor information\n        def _format_tensor(name: str, tensor: Optional[torch.Tensor]) -> List[str]:\n            if tensor is None:\n                return [f'{name}=None']\n            try:\n                return [f'{name}.shape={tensor.shape}', f'{name}={tensor}']\n            except:  # noqa\n                # in case tensor is not torch.Tensor or has no shape\n                return [f'{name}={tensor}']\n\n        # Format tensor fields\n        fields.extend(_format_tensor('logits', self.logits))\n        fields.extend(_format_tensor('last_hidden_state', self.last_hidden_state))\n        fields.extend(_format_tensor('routed_experts', self.routed_experts))\n        return '\\n'.join(fields)\n\n    def extend(self, other: 'Response') -> 'Response':\n        \"\"\"Extend this response with another response.\n\n        This method merges the content of another Response into this one,\n        similar to list.extend(). 
The text, token_ids, and logprobs are\n        concatenated, while other fields are updated from the other response.\n\n        Args:\n            other: Another Response to append to this one.\n\n        Returns:\n            Self (for method chaining).\n        \"\"\"\n        self.text += other.text\n        self.generate_token_len = other.generate_token_len\n        self.input_token_len = other.input_token_len\n        self.finish_reason = other.finish_reason\n        self.index = other.index\n        if other.token_ids:\n            self.token_ids += other.token_ids\n        if other.logprobs:\n            self.logprobs = self.logprobs or []\n            self.logprobs += other.logprobs\n        self.routed_experts = other.routed_experts\n        return self\n\n\n# modified from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py\nclass EventType(enum.IntEnum):\n    \"\"\"The type of request event.\n\n    QUEUED - when the request was enqueued by the engine\n    SCHEDULED - when the request was first scheduled for execution\n    PREEMPTED - the request has been put back in the waiting queue in order to make room for other requests to complete.\n                It will be rescheduled in the future and restart its prefill phase\n    \"\"\"\n    QUEUED = 1\n    SCHEDULED = 2\n    PREEMPTED = 3  # FIXME, currently ignored for simplicity\n\n\n# modified from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py\n@dataclass\nclass EngineEvent:\n    \"\"\"A timestamped engine event associated with a request.\n\n    Attributes:\n        type: the type of an event associated with a request during its life cycle\n        timestamp: the WALL-CLOCK time when the event happens.\n    \"\"\"\n    type: EventType\n    timestamp: float\n\n    @classmethod\n    def new_event(cls, event_type: EventType, timestamp: Optional[float] = None) -> 'EngineEvent':\n        # Timestamps MUST use wall-clock time (time.time()) to maintain consistency\n   
     # between csrc(std::chrono::system_clock) and python\n        timestamp = time.time() if timestamp is None else timestamp\n        return cls(event_type, timestamp)\n\n\n@dataclass\nclass ScheduleMetrics:\n    active_seqs: int = 0\n    waiting_seqs: int = 0\n    total_blocks: int = 0\n    active_blocks: int = 0\n    cached_blocks: int = 0\n    free_blocks: int = 0\n    prefix_cache_hit_rate: float = 0\n\n\n@dataclass\nclass RequestMetrics:\n    \"\"\"Basic metrics for a request.\n\n    Attributes:\n        token_timestamp: A wall-clock time when a token is generated.\n        engine_events: List of engine events during inference.\n    \"\"\"\n    token_timestamp: float = 0.0\n    engine_events: List[EngineEvent] = field(default_factory=list)\n    spec_info: Optional[Dict[str, Any]] = None\n\n\n@dataclass\nclass EngineOutput:\n    \"\"\"Engine output from turbomind/pytorch engine.\n\n    Args:\n        status: the response type.\n        token_ids: the newly generated token ids in each iteration.\n        logprobs: the top logprobs for each output\n            position.\n        cache_block_ids: the cache block ids sent back for migration in\n            disaggregated LLM serving once the prefill engine is done.\n        req_metrics: request metrics information.\n    \"\"\"\n    status: ResponseType\n    token_ids: List[int]\n    logprobs: List[Dict[int, float]] = None\n    logits: torch.Tensor = None\n    last_hidden_state: torch.Tensor = None\n    cache_block_ids: Optional[List[int]] = None\n    req_metrics: Optional[RequestMetrics] = None\n    routed_experts: torch.Tensor = None\n\n\n@dataclass\nclass VisionConfig:\n    \"\"\"Vision model configs.\n\n    Args:\n        max_batch_size: the max number of images passed to the model. Since\n            some models split images into patches, the actual running batch could\n            be larger than this value.\n        thread_safe: Specifies whether the engine instance is\n            thread-safe. 
Please set it to True when using the pipeline\n            in a multi-threaded environment.\n    \"\"\"\n    max_batch_size: int = 1\n    thread_safe: bool = False\n\n\n@dataclass\nclass SpeculativeConfig:\n    \"\"\"Speculative decoding config.\n\n    Args:\n        method: the speculative decoding method.\n        model: the path of the speculative (draft) model.\n        num_speculative_tokens: number of tokens generated by the draft model per step.\n    \"\"\"\n    method: str\n    model: str = ''\n    num_speculative_tokens: int = 1\n"
  },
  {
    "path": "lmdeploy/metrics/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/metrics/loggers.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/v1/metrics/loggers.py\n\nimport time\nfrom abc import ABC, abstractmethod\nfrom datetime import datetime\nfrom typing import List\n\nimport numpy as np\n\nfrom lmdeploy.metrics.stats import IterationStats, RequestStats, SchedulerStats, SpeculativeDecodingStats\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass StatLoggerBase(ABC):\n\n    @abstractmethod\n    def record_schedule(self, stats: SchedulerStats) -> None:\n        ...\n\n    @abstractmethod\n    def record_iteration(self, stats: IterationStats) -> None:\n        ...\n\n    @abstractmethod\n    def record_specdecode(self, stats: SpeculativeDecodingStats) -> None:\n        ...\n\n    def log(self):  # noqa\n        pass\n\n\nclass LoggingStatLogger(StatLoggerBase):\n\n    def __init__(self, dp_rank: int = 0):\n        self.dp_rank = dp_rank\n        self._reset(time.perf_counter())\n        self.last_scheduler_stats = SchedulerStats()\n\n    def _reset(self, now):\n        self.last_log_time = now\n        self.total_prompt_tokens = 0\n        self.total_generation_tokens = 0\n        # spec decode\n        self.num_drafts: int = 0\n        self.num_draft_tokens: int = 0\n        self.num_accepted_tokens: int = 0\n        self.num_accepted_tokens_per_pos: np.ndarray = None\n\n    def record_schedule(self, stats: SchedulerStats):\n        self.last_scheduler_stats = stats\n\n    def record_iteration(self, stats: IterationStats):\n        # In the first iteration of a sequence, stats.prompt_tokens is the\n        # prompt token number of a sequence. In subsequent iterations,\n        # the value is 0. 
This enables cumulative counting in `total_prompt_tokens`\n        self.total_prompt_tokens += stats.prompt_tokens\n        self.total_generation_tokens += stats.new_generation_tokens\n\n    def record_specdecode(self, stats: SpeculativeDecodingStats):\n        \"\"\"Record spec decoding stats.\"\"\"\n        if stats.num_drafts <= 0:\n            return\n        if self.num_accepted_tokens_per_pos is None:\n            self.num_accepted_tokens_per_pos = np.zeros(stats.num_spec_tokens)\n        self.num_drafts += stats.num_drafts\n        self.num_draft_tokens += stats.num_draft_tokens\n        self.num_accepted_tokens += stats.num_accepted_tokens\n        self.num_accepted_tokens_per_pos += stats.num_accepted_tokens_per_pos\n\n    def record_finish(self, stats: RequestStats):\n        pass\n\n    def get_spec_msg(self):\n        \"\"\"Get spec decoding logging msg.\"\"\"\n        if self.num_drafts == 0:\n            return None\n\n        draft_acceptance_rate = (self.num_accepted_tokens / self.num_draft_tokens *\n                                 100 if self.num_draft_tokens > 0 else float('nan'))\n\n        # conventionally, mean acceptance length includes the bonus token\n        mean_acceptance_length = 1 + (self.num_accepted_tokens / self.num_drafts)\n\n        acceptance_rates = self.num_accepted_tokens_per_pos / self.num_drafts\n        rates_str = ', '.join(f'{p:.3f}' for p in acceptance_rates)\n\n        log_msg = ('SpecDecoding metrics: '\n                   f'Draft acceptance rate: {draft_acceptance_rate:.2f}%, '\n                   f'Mean acceptance length: {mean_acceptance_length:.2f}, '\n                   f'Accepted: {self.num_accepted_tokens} tokens, '\n                   f'Drafted: {self.num_draft_tokens} tokens, '\n                   f'Per-position acceptance rate: {rates_str}')\n        return log_msg\n\n    def log(self):\n        now = time.perf_counter()\n\n        # skip logging if no tokens were processed\n        if 
self.total_prompt_tokens == 0 and self.total_generation_tokens == 0:\n            self._reset(now)\n            return\n\n        # derive log information\n        prompt_throughput = self.total_prompt_tokens / (now - self.last_log_time)\n        generation_throughput = self.total_generation_tokens / (now - self.last_log_time)\n        scheduler_stats = self.last_scheduler_stats\n        scheduler_stats.num_api_waiting_reqs = scheduler_stats.num_total_reqs - \\\n            scheduler_stats.num_completed_reqs - scheduler_stats.num_api_routed_reqs\n        spec_msg = self.get_spec_msg()\n\n        # format and print\n        log_msg = (\n            f\"[{datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')} DP{self.dp_rank}] \"\n            f'Avg thr (in/out): {prompt_throughput:.1f} / {generation_throughput:.1f} tokens/s, '\n            f'API server (completed/routed/waiting): {scheduler_stats.num_completed_reqs} / '\n            f'{scheduler_stats.num_api_routed_reqs} / {scheduler_stats.num_api_waiting_reqs}, '\n            f'Engine (running/waiting): {scheduler_stats.num_running_reqs} / {scheduler_stats.num_waiting_reqs}, '\n            f'KV cache: {scheduler_stats.gpu_cache_usage * 100 :.1f}%, ')\n\n        if scheduler_stats.prefix_cache_hit_rate != 0:\n            log_msg += f'Prefix cache hit rate: {scheduler_stats.prefix_cache_hit_rate * 100 :.1f}%, '\n\n        if spec_msg is not None:\n            log_msg += spec_msg\n\n        print(log_msg, flush=True)\n        self._reset(now)\n\n\nclass PrometheusStatLogger(StatLoggerBase):\n\n    def __init__(self, model_name: str, max_model_len: int, dp_rank: int = 0):\n        try:\n            import prometheus_client\n            prometheus_client.disable_created_metrics()  # disable noisy creation timestamp gauge in prometheus\n        except ImportError:\n            raise ImportError(\n                'To use the metrics system, please install prometheus_client with `pip install prometheus_client`')\n\n  
      self.dp_rank = dp_rank\n\n        # unregister any existing lmdeploy collectors\n        for collector in list(prometheus_client.REGISTRY._collector_to_names):\n            if hasattr(collector, '_name') and 'lmdeploy' in collector._name:\n                prometheus_client.REGISTRY.unregister(collector)\n\n        # config information\n        self.info_backend_config = prometheus_client.Info(name='lmdeploy:backend_config',\n                                                          documentation='information of backend_config')\n\n        labelnames = ['model_name', 'engine']\n        labelvalues = [model_name, str(dp_rank)]\n\n        #\n        # Scheduler stats\n        #\n        self.gauge_scheduler_completed = prometheus_client.Gauge(name='lmdeploy:num_requests_completed',\n                                                                 documentation='Number of current completed requests.',\n                                                                 labelnames=labelnames).labels(*labelvalues)\n\n        self.gauge_scheduler_api_routed = prometheus_client.Gauge(\n            name='lmdeploy:num_api_requests_routed',\n            documentation='Number of requests routed to request handles.',\n            labelnames=labelnames).labels(*labelvalues)\n\n        self.gauge_scheduler_api_waiting = prometheus_client.Gauge(\n            name='lmdeploy:num_api_requests_waiting',\n            documentation='Number of requests waiting for free request handles.',\n            labelnames=labelnames).labels(*labelvalues)\n\n        self.gauge_scheduler_running = prometheus_client.Gauge(\n            name='lmdeploy:num_requests_running',\n            documentation='Number of requests in model execution batches.',\n            labelnames=labelnames).labels(*labelvalues)\n\n        self.gauge_scheduler_waiting = prometheus_client.Gauge(\n            name='lmdeploy:num_requests_waiting',\n            documentation='Number of requests waiting to be processed.',\n     
       labelnames=labelnames).labels(*labelvalues)\n\n        #\n        # GPU cache\n        #\n        self.gauge_gpu_cache_usage = prometheus_client.Gauge(\n            name='lmdeploy:gpu_cache_usage_perc',\n            documentation='GPU KV-cache usage. 1 means 100 percent usage.',\n            labelnames=labelnames).labels(*labelvalues)\n\n        #\n        # Counters\n        #\n        self.counter_prompt_tokens = prometheus_client.Counter(name='lmdeploy:prompt_tokens_total',\n                                                               documentation='Number of prefill tokens processed.',\n                                                               labelnames=labelnames).labels(*labelvalues)\n\n        self.counter_generation_tokens = prometheus_client.Counter(\n            name='lmdeploy:generation_tokens_total',\n            documentation='Number of generation tokens processed.',\n            labelnames=labelnames).labels(*labelvalues)\n\n        from lmdeploy.messages import ResponseType\n        self.counter_request_success: dict[ResponseType, prometheus_client.Counter] = {}\n        counter_request_success_base = prometheus_client.Counter(\n            name='lmdeploy:request_success_total',\n            documentation='Count of successfully processed requests.',\n            labelnames=labelnames + ['finished_reason'])\n        for reason in ResponseType:\n            self.counter_request_success[reason] = counter_request_success_base.labels(*(labelvalues + [str(reason)]))\n\n        #\n        # Histograms of counts\n        #\n        self.histogram_num_prompt_tokens_request = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:request_prompt_tokens',\n                documentation='Number of prefill tokens processed.',\n                buckets=build_1_2_5_buckets(max_model_len),\n                labelnames=labelnames).labels(*labelvalues)\n\n        self.histogram_num_generation_tokens_request = \\\n            
prometheus_client.Histogram(\n                name='lmdeploy:request_generation_tokens',\n                documentation='Number of generation tokens processed.',\n                buckets=build_1_2_5_buckets(max_model_len),\n                labelnames=labelnames).labels(*labelvalues)\n\n        self.histogram_iteration_tokens = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:iteration_tokens_total',\n                documentation='Histogram of number of tokens per engine_step.',\n                buckets=[\n                    1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,\n                    16384\n                ],\n                labelnames=labelnames).labels(*labelvalues)\n\n        #\n        # Histogram of timing intervals\n        #\n        self.histogram_time_to_first_token = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:time_to_first_token_seconds',\n                documentation='Histogram of time to first token in seconds.',\n                buckets=[\n                    0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,\n                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,\n                    640.0, 2560.0\n                ],\n                labelnames=labelnames).labels(*labelvalues)\n\n        self.histogram_time_per_output_token = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:time_per_output_token_seconds',\n                documentation='Histogram of time per output token in seconds.',\n                buckets=[\n                    0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,\n                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0\n                ],\n                labelnames=labelnames).labels(*labelvalues)\n\n        self.histogram_iter_token_latency = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:iter_token_latency',\n                
documentation='Histogram of inter-token latency',\n                buckets=[\n                    0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,\n                    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0\n                ],\n                labelnames=labelnames).labels(*labelvalues)\n\n        request_latency_buckets = [\n            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0, 120.0, 240.0, 480.0,\n            960.0, 1920.0, 7680.0\n        ]\n        self.histogram_e2e_time_request = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:e2e_request_latency_seconds',\n                documentation='Histogram of e2e request latency in seconds.',\n                buckets=request_latency_buckets,\n                labelnames=labelnames).labels(*labelvalues)\n        self.histogram_queue_time_request = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:request_queue_time_seconds',\n                documentation='Histogram of time spent in WAITING phase for request.',\n                buckets=request_latency_buckets,\n                labelnames=labelnames).labels(*labelvalues)\n        self.histogram_inference_time_request = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:request_inference_time_seconds',\n                documentation='Histogram of time spent in RUNNING phase for request.',\n                buckets=request_latency_buckets,\n                labelnames=labelnames).labels(*labelvalues)\n        self.histogram_prefill_time_request = \\\n            prometheus_client.Histogram(\n                name='lmdeploy:request_prefill_time_seconds',\n                documentation='Histogram of time spent in PREFILL phase for request.',\n                buckets=request_latency_buckets,\n                labelnames=labelnames).labels(*labelvalues)\n        self.histogram_decode_time_request = \\\n            
prometheus_client.Histogram(\n                name='lmdeploy:request_decode_time_seconds',\n                documentation='Histogram of time spent in DECODE phase for request.',\n                buckets=request_latency_buckets,\n                labelnames=labelnames).labels(*labelvalues)\n\n    def record_schedule(self, stats: SchedulerStats) -> None:\n        \"\"\"Report schedule metrics to prometheus.\"\"\"\n        self.gauge_scheduler_completed.set(stats.num_completed_reqs)\n        self.gauge_scheduler_api_routed.set(stats.num_api_routed_reqs)\n        self.gauge_scheduler_api_waiting.set(stats.num_total_reqs - stats.num_completed_reqs -\n                                             stats.num_api_routed_reqs)\n        self.gauge_scheduler_running.set(stats.num_running_reqs)\n        self.gauge_scheduler_waiting.set(stats.num_waiting_reqs)\n        self.gauge_gpu_cache_usage.set(stats.gpu_cache_usage)\n\n    def record_iteration(self, stats: IterationStats) -> None:\n        \"\"\"Report token-related metrics to prometheus.\"\"\"\n\n        self.counter_prompt_tokens.inc(stats.prompt_tokens)\n        self.counter_generation_tokens.inc(stats.new_generation_tokens)\n        self.histogram_iteration_tokens.observe(stats.prompt_tokens + stats.new_generation_tokens)\n\n        if stats.ttft:\n            self.histogram_time_to_first_token.observe(stats.ttft)\n\n        if stats.tpot:\n            self.histogram_time_per_output_token.observe(stats.tpot)\n\n        if stats.itl:\n            self.histogram_iter_token_latency.observe(stats.itl)\n\n    def record_finish(self, stats: RequestStats) -> None:\n        self.counter_request_success[stats.finish_reason].inc()\n        self.histogram_e2e_time_request.observe(stats.e2e_latency)\n        self.histogram_queue_time_request.observe(stats.queued_time_interval)\n        self.histogram_prefill_time_request.observe(stats.prefill_time_interval)\n        
self.histogram_inference_time_request.observe(stats.inference_time_interval)\n        self.histogram_decode_time_request.observe(stats.decode_time_interval)\n        self.histogram_num_prompt_tokens_request.observe(stats.prompt_tokens)\n        self.histogram_num_generation_tokens_request.observe(stats.generation_tokens)\n\n    def record_specdecode(self, stats: SpeculativeDecodingStats) -> None:\n        pass\n\n\ndef build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:\n    \"\"\"Builds a list of buckets with increasing powers of 10 multiplied by\n    mantissa values until the value exceeds the specified maximum.\"\"\"\n    exponent = 0\n    buckets: List[int] = []\n    while True:\n        for m in mantissa_lst:\n            value = m * 10**exponent\n            if value <= max_value:\n                buckets.append(value)\n            else:\n                return buckets\n        exponent += 1\n\n\ndef build_1_2_5_buckets(max_value: int) -> List[int]:\n    \"\"\"\n    Example:\n    >>> build_1_2_5_buckets(100)\n    [1, 2, 5, 10, 20, 50, 100]\n    \"\"\"\n    return build_buckets([1, 2, 5], max_value)\n"
  },
  {
    "path": "lmdeploy/metrics/metrics_processor.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\n\nfrom lmdeploy.messages import ResponseType, ScheduleMetrics\nfrom lmdeploy.pytorch.utils import singleton\nfrom lmdeploy.utils import get_logger\n\nfrom .stats import SchedulerStats\n\nlogger = get_logger('lmdeploy')\n\n\n@singleton\nclass MetricsProcessor():\n    \"\"\"Metrics processor.\"\"\"\n\n    def __init__(self):\n        \"\"\"Init metrics processor.\"\"\"\n        self.enable_metrics: bool = False\n        self.scheduler_stats = SchedulerStats()\n        self.stat_loggers = []\n        self.metrics_queue: asyncio.Queue = None\n        self.metrics_handler: asyncio.Task = None\n\n    def start_metrics_handler(self, enable_metrics: bool):\n        \"\"\"Start metrics handler.\"\"\"\n        self.enable_metrics = enable_metrics\n        if enable_metrics and self.metrics_handler is None:\n            self.metrics_queue = asyncio.Queue()\n            self.metrics_handler = asyncio.create_task(self._run_metrics_handler())\n            logger.info('Metrics handler task started.')\n\n    async def stop_metrics_handler(self):\n        \"\"\"Stop metrics handler.\"\"\"\n        if self.metrics_handler is not None:\n            self.metrics_handler.cancel()\n            try:\n                await self.metrics_handler\n            except asyncio.CancelledError:\n                pass  # Expected cancellation\n            finally:\n                self.metrics_handler = None\n                logger.info('Metrics handler task stopped.')\n\n    async def _run_metrics_handler(self):\n        \"\"\"A background task that consumes and processes metrics data.\"\"\"\n        while True:\n            try:\n                # fetch data from the queue\n                update_data = await self.metrics_queue.get()\n                outputs, req_stats, iteration_stats, specdecode_stats = update_data\n\n                # update request stats\n                if outputs and outputs.req_metrics:\n      
              # when users visit \"/abort_request\" endpoint, `req_metrics` might be None\n                    req_stats.update_from_events(outputs.req_metrics.engine_events)\n\n                # update iteration stats\n                # some attributes of req_stats will also be updated, e.g., lastest_token_time\n                iteration_stats.update_from_output(outputs, req_stats)\n\n                # update spec decode stats\n                if specdecode_stats is not None:\n                    specdecode_stats.update_from_output(outputs)\n\n                # record iteration stats\n                for stat_logger in self.stat_loggers:\n                    stat_logger.record_iteration(iteration_stats)\n                    if specdecode_stats is not None:\n                        stat_logger.record_specdecode(specdecode_stats)\n\n                # record finished request stats\n                if outputs.status == ResponseType.FINISH:\n                    for stat_logger in self.stat_loggers:\n                        stat_logger.record_finish(req_stats)\n\n                self.metrics_queue.task_done()\n            except asyncio.CancelledError:\n                break\n            except Exception as e:\n                logger.exception(f'Metrics handler background task failed: {e}')\n\n    async def update_schedule_stats(self, schedule_metrics: ScheduleMetrics):\n        \"\"\"Update schedule stats.\"\"\"\n        self.scheduler_stats.update_from_schedule_metrics(schedule_metrics)\n        # record schedule stats\n        for stat_logger in self.stat_loggers:\n            stat_logger.record_schedule(self.scheduler_stats)\n\n    def queue_update(self, update_data: tuple):\n        \"\"\"Queue update.\"\"\"\n        if not self.enable_metrics or self.metrics_queue is None:\n            return\n        self.metrics_queue.put_nowait(update_data)\n\n    def increase_total_requests(self):\n        \"\"\"Increase total requests.\"\"\"\n        
self.scheduler_stats.num_total_reqs += 1\n\n    def increase_completed_requests(self):\n        \"\"\"Increase completed requests.\"\"\"\n        self.scheduler_stats.num_completed_reqs += 1\n\n    def increase_api_routed_requests(self):\n        \"\"\"Increase API routed requests.\"\"\"\n        self.scheduler_stats.num_api_routed_reqs += 1\n\n    def decrease_api_routed_requests(self):\n        \"\"\"Decrease API routed requests.\"\"\"\n        self.scheduler_stats.num_api_routed_reqs -= 1\n\n\nmetrics_processor = MetricsProcessor()\n"
  },
  {
    "path": "lmdeploy/metrics/stats.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/metrics/stats.py\n\nimport time\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport numpy as np\n\nfrom lmdeploy.messages import EngineEvent, EngineOutput, ResponseType, ScheduleMetrics\n\n\n@dataclass\nclass SchedulerStats:\n    \"\"\"Stats associated with the scheduler.\n    Desc:\n        Dataflow: client --> API server --> Engine core\n        API server total  = completed + uncompleted = completed + (api_routed + api_waiting)\n        Engine core total = running + waiting = api_routed\n\n    Attributes:\n        num_total_reqs: API server, the number of all requests received since server start.\n        num_completed_reqs: API server, the number of successfully completed requests since server start.\n        num_api_routed_reqs: API server, the number of requests routed to request handles.\n        num_api_waiting_reqs: API server, the number of requests waiting for free request handles.\n        num_running_reqs: Engine core, currently executing requests.\n        num_waiting_reqs: Engine core, requests queued waiting for execution.\n        gpu_cache_usage: Fraction of GPU KV blocks utilized (0.0 to 1.0).\n        prefix_cache_hit_rate: Prefix caching hit rate.\n    \"\"\"\n\n    # api server\n    num_total_reqs: int = 0\n    num_completed_reqs: int = 0\n    num_api_routed_reqs: int = 0\n    num_api_waiting_reqs: int = 0\n\n    # engine core\n    num_running_reqs: int = 0\n    num_waiting_reqs: int = 0\n    gpu_cache_usage: float = 0.0\n    prefix_cache_hit_rate: float = 0.0\n\n    def __repr__(self):\n        return ('SchedulerStats(\\n'\n                f'  num_total_reqs={self.num_total_reqs},\\n'\n                f'  num_completed_reqs={self.num_completed_reqs},\\n'\n                f'  num_api_routed_reqs={self.num_api_routed_reqs},\\n'\n                f'  
num_api_waiting_reqs={self.num_api_waiting_reqs},\\n'\n                f'  num_running_reqs={self.num_running_reqs},\\n'\n                f'  num_waiting_reqs={self.num_waiting_reqs},\\n'\n                f'  gpu_cache_usage={self.gpu_cache_usage:.6f},\\n'\n                f'  prefix_cache_hit_rate={self.prefix_cache_hit_rate:.6f},\\n'\n                ')')\n\n    def update_from_schedule_metrics(self, scheduled_metrics: ScheduleMetrics):\n        self.num_running_reqs = scheduled_metrics.active_seqs\n        self.num_waiting_reqs = scheduled_metrics.waiting_seqs\n        self.gpu_cache_usage = 1.0 - (scheduled_metrics.free_blocks / scheduled_metrics.total_blocks)\n        self.prefix_cache_hit_rate = scheduled_metrics.prefix_cache_hit_rate\n\n\nclass RequestStats:\n    \"\"\"Stats associated with a request.\"\"\"\n\n    def __init__(self, arrival_time: float = None, prompt_tokens: int = 0):\n        \"\"\"Initialize the stats of a request.\n\n        Args:\n            arrival_time (float, optional): The timestamp when the request arrives.\n                If not provided, the current time will be used. Defaults to None.\n            prompt_tokens (int, optional): The number of tokens in the prompt. 
Defaults to 0.\n\n        Attributes:\n            generation_tokens (int): The number of tokens generated during the request inference.\n                It will be updated by IterationStats.update_from_output.\n            queued_time (float): Time when the request is put into the inference engine's queue.\n                It will be updated according to the EngineEvent.\n            scheduled_time (float): Time when the request is scheduled to run.\n                It will be updated according to the EngineEvent.\n            first_token_time (float): Time when the first token is generated.\n                It will be updated by IterationStats.update_from_output.\n            lastest_token_time (float): Time when the latest token is generated.\n                It will be updated by IterationStats.update_from_output.\n            finish_time (float): Time when a request finishes generation.\n                It will be updated by IterationStats.update_from_output.\n            finish_reason (ResponseType): The reason why the request finished.\n                It will be updated by IterationStats.update_from_output.\n        \"\"\"\n        self.arrival_time = time.time() if arrival_time is None else arrival_time\n        self.prompt_tokens = prompt_tokens\n\n        self.generation_tokens: int = 0\n        self.queued_time: float = 0.0\n        self.scheduled_time: float = 0.0\n        self.first_token_time: float = 0.0\n        self.lastest_token_time: float = 0.0\n        self.finish_time: float = 0.0\n        self.finish_reason: ResponseType = None\n\n    def __repr__(self):\n        return ('RequestStats(\\n'\n                f'  arrival_time={self.arrival_time:.6f},\\n'\n                f'  prompt_tokens={self.prompt_tokens},\\n'\n                f'  generation_tokens={self.generation_tokens},\\n'\n                f'  queued_time={self.queued_time:.6f},\\n'\n                f'  scheduled_time={self.scheduled_time:.6f},\\n'\n                f'  first_token_time={self.first_token_time:.6f},\\n'\n                f'  
latest_token_time={self.lastest_token_time:.6f},\\n'\n                ')')\n\n    def update_from_events(self, engine_events: List[EngineEvent]):\n        # avoid circular dependency\n        from lmdeploy.messages import EventType\n\n        for event in engine_events:\n            if event.type == EventType.QUEUED:\n                self.queued_time = event.timestamp\n            elif event.type == EventType.SCHEDULED:\n                if self.scheduled_time == 0.0:  # ignore preemptions\n                    self.scheduled_time = event.timestamp\n            # FIXME: deal with preempted case\n            # elif event.type == EventType.PREEMPTED:\n            #     self.num_preempted_reqs += 1\n\n    @property\n    def e2e_latency(self) -> float:\n        \"\"\"End-to-end latency.\"\"\"\n        return self.finish_time - self.arrival_time\n\n    @property\n    def queued_time_interval(self) -> float:\n        \"\"\"Queued interval is from first QUEUED event to first SCHEDULED.\"\"\"\n        return self.scheduled_time - self.queued_time\n\n    @property\n    def prefill_time_interval(self) -> float:\n        \"\"\"Prefill interval is from first SCHEDULED to first NEW_TOKEN.\n\n        Any preemptions during prefill are included in the interval.\n        \"\"\"\n        return self.first_token_time - self.scheduled_time\n\n    @property\n    def decode_time_interval(self) -> float:\n        \"\"\"Decode interval is from first NEW_TOKEN to last NEW_TOKEN.\n\n        Any preemptions during decode are included.\n        \"\"\"\n        return self.finish_time - self.first_token_time\n\n    @property\n    def inference_time_interval(self) -> float:\n        \"\"\"Inference interval is from first SCHEDULED to last NEW_TOKEN.\n\n        Any preemptions during prefill or decode are included.\n        \"\"\"\n        return self.finish_time - self.scheduled_time\n\n\nclass IterationStats:\n    \"\"\"Stats associated with one token generation iteration of a request.\"\"\"\n\n 
   def __init__(self):\n        \"\"\"Initialize the stats of one iteration.\n\n        Attributes:\n            iteration_timestamp (float): The timestamp when this iteration finishes.\n            new_generation_tokens (int): The number of newly generated tokens in this iteration.\n            prompt_tokens (int): The number of prompt tokens processed in this iteration.\n            ttft (float | None): Time to First Token (TTFT).\n            tpot (float | None): Time per Output Token (TPOT).\n            itl (float | None): Inter-Token Latency (ITL).\n        \"\"\"\n        self.iteration_timestamp = time.time()\n        self.new_generation_tokens = 0\n        self.prompt_tokens = 0\n        self.ttft: Optional[float] = None\n        self.tpot: Optional[float] = None\n        self.itl: Optional[float] = None\n\n    def __repr__(self):\n        return ('IterationStats(\\n'\n                f'  iteration_timestamp={self.iteration_timestamp:.6f},\\n'\n                f'  new_generation_tokens={self.new_generation_tokens},\\n'\n                f'  prompt_tokens={self.prompt_tokens},\\n'\n                f'  ttft={self.ttft},\\n'\n                f'  tpot={self.tpot},\\n'\n                f'  itl={self.itl},\\n'\n                ')')\n\n    def _time_since(self, start: float) -> float:\n        \"\"\"Calculate an interval relative to this iteration's timestamp.\"\"\"\n        return self.iteration_timestamp - start\n\n    def update_from_output(self, outputs: EngineOutput, req_stats: RequestStats):\n        \"\"\"Update the iteration statistics.\n\n        Args:\n            outputs (EngineOutput): The output from the engine containing information about the current iteration.\n            req_stats (RequestStats): The stats of the request, including timestamps and token counts.\n        \"\"\"\n        if outputs.req_metrics is None:\n            # when users visit \"/abort_request\" endpoint, `req_metrics` might be None\n            return\n\n        
new_generation_tokens = len(outputs.token_ids)\n        if new_generation_tokens == 0:\n            return\n\n        self.new_generation_tokens = new_generation_tokens\n\n        if req_stats.first_token_time == 0:\n            # the first token is generated in this iteration\n            req_stats.first_token_time = outputs.req_metrics.token_timestamp\n            self.prompt_tokens = req_stats.prompt_tokens\n            self.ttft = self._time_since(req_stats.arrival_time)\n        else:\n            self.itl = self._time_since(req_stats.lastest_token_time)\n            self.tpot = self._time_since(req_stats.lastest_token_time) / self.new_generation_tokens\n\n        req_stats.lastest_token_time = outputs.req_metrics.token_timestamp\n        req_stats.generation_tokens += new_generation_tokens\n\n        if outputs.status != ResponseType.SUCCESS:\n            req_stats.finish_reason = outputs.status\n            req_stats.finish_time = self.iteration_timestamp\n\n\n# modified from vllm\n@dataclass\nclass SpeculativeDecodingStats:\n    \"\"\"Speculative decoding stats.\"\"\"\n\n    num_spec_tokens: int\n    num_drafts: int = 0\n    num_draft_tokens: int = 0\n    num_accepted_tokens: int = 0\n    num_accepted_tokens_per_pos: np.ndarray = None\n\n    def __post_init__(self):\n        assert self.num_spec_tokens > 0\n        self.num_accepted_tokens_per_pos = np.zeros(self.num_spec_tokens)\n\n    def update_from_output(self, outputs: EngineOutput):\n        \"\"\"Update from engine output.\"\"\"\n        spec_info = getattr(outputs.req_metrics, 'spec_info', None)\n        if spec_info:\n            self.num_drafts += 1\n            self.num_draft_tokens += spec_info['num_draft_tokens']\n            self.num_accepted_tokens += spec_info['num_accepted_tokens']\n            self.num_accepted_tokens_per_pos[:spec_info['num_accepted_tokens']] += 1\n\n    def update_per_draft(self, num_draft_tokens: int, num_accepted_tokens: int):\n        \"\"\"Update with per draft 
stats.\"\"\"\n        if num_draft_tokens > 0:\n            self.num_drafts += 1\n            self.num_draft_tokens += num_draft_tokens\n            self.num_accepted_tokens += num_accepted_tokens\n            self.num_accepted_tokens_per_pos[:num_accepted_tokens] += 1\n\n    def __repr__(self):\n        draft_acceptance_rate = (self.num_accepted_tokens / self.num_draft_tokens *\n                                 100 if self.num_draft_tokens > 0 else float('nan'))\n\n        # conventionally, mean acceptance length includes the bonus token\n        mean_acceptance_length = 1 + (self.num_accepted_tokens /\n                                      self.num_drafts) if self.num_drafts > 0 else float('nan')\n\n        # before any draft is recorded, report NaN for every speculative position\n        acceptance_rates = self.num_accepted_tokens_per_pos / self.num_drafts if self.num_drafts > 0 else [\n            float('nan')\n        ] * self.num_spec_tokens\n        rates_str = ', '.join(f'{p:.3f}' for p in acceptance_rates)\n\n        return ('SpeculativeDecodingStats('\n                f'num_spec_tokens={self.num_spec_tokens}, '\n                f'num_drafts={self.num_drafts}, '\n                f'num_draft_tokens={self.num_draft_tokens}, '\n                f'num_accepted_tokens={self.num_accepted_tokens}, '\n                f'draft_acceptance_rate={draft_acceptance_rate:.2f}%, '\n                f'mean_acceptance_length={mean_acceptance_length:.2f}, '\n                f'per_position_acceptance_rate={rates_str})')\n"
  },
  {
    "path": "lmdeploy/model.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport dataclasses\nimport json\nimport uuid\nfrom typing import List, Literal, Optional, Union\n\nfrom mmengine import Registry\n\nfrom lmdeploy.archs import get_model_arch\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\nMODELS = Registry('model', locations=['lmdeploy.model'])\n\n\ndef random_uuid() -> str:\n    \"\"\"Return a random uuid.\"\"\"\n    return str(uuid.uuid4().hex)\n\n\ndef get_text(content: Union[str, List[dict]]):\n    \"\"\"Within the OpenAI API, the content field may be specified as either a\n    string or a list of ChatCompletionContentPartTextParam (defined in openai).\n\n    When a list is provided, lmdeploy selects the first element to incorporate into the chat template, as the manner in\n    which OpenAI processes lists is not explicitly defined.\n    \"\"\"\n\n    if isinstance(content, str):\n        return content\n    return content[0]['text']\n\n\n@dataclasses.dataclass\nclass ChatTemplateConfig:\n    \"\"\"Parameters for chat template.\n\n    Args:\n        model_name (str): the name of the deployed model. 
Determine which chat template will be applied.\n            All the chat template names: `lmdeploy list`\n        system (str | None): begin of the system prompt\n        meta_instruction (str | None): system prompt\n        eosys (str | None): end of the system prompt\n        user (str | None): begin of the user prompt\n        eoh (str | None): end of the user prompt\n        assistant (str | None): begin of the assistant prompt\n        eoa (str | None): end of the assistant prompt\n        tool (str | None): begin of the tool prompt\n        eotool (str | None): end of the tool prompt\n        capability: ('completion' | 'infilling' | 'chat' | 'python') = None\n    \"\"\"  # noqa: E501\n\n    model_name: str\n    model_path: Optional[str] = None\n    system: Optional[str] = None\n    meta_instruction: Optional[str] = None\n    eosys: Optional[str] = None\n    user: Optional[str] = None\n    eoh: Optional[str] = None\n    assistant: Optional[str] = None\n    eoa: Optional[str] = None\n    tool: Optional[str] = None\n    eotool: Optional[str] = None\n    separator: Optional[str] = None\n    capability: Optional[Literal['completion', 'infilling', 'chat', 'python']] = None\n    stop_words: Optional[List[str]] = None\n\n    @property\n    def chat_template(self):\n        attrs = {key: value for key, value in dataclasses.asdict(self).items() if value is not None}\n        attrs.pop('model_name', None)\n        if self.model_name in MODELS.module_dict.keys():\n            model = MODELS.get(self.model_name)(**attrs)\n        else:\n            logger.warning(f'Could not find {self.model_name} in registered models. 
'\n                           f'Register {self.model_name} using the BaseChatTemplate.')\n            model = BaseChatTemplate(**attrs)\n        return model\n\n    def to_json(self, file_path=None):\n        \"\"\"Convert the dataclass instance to a JSON formatted string and\n        optionally save to a file.\"\"\"\n        json_str = json.dumps(dataclasses.asdict(self), ensure_ascii=False, indent=4)\n        if file_path:\n            with open(file_path, 'w', encoding='utf-8') as file:\n                file.write(json_str)\n        return json_str\n\n    @classmethod\n    def from_json(cls, file_or_string):\n        \"\"\"Construct a dataclass instance from a JSON file or JSON string.\"\"\"\n        try:\n            # Try to open the input_data as a file path\n            with open(file_or_string, 'r', encoding='utf-8') as file:\n                json_data = file.read()\n        except FileNotFoundError:\n            # If it's not a file path, assume it's a JSON string\n            json_data = file_or_string\n        except IOError:\n            # If it's not a file path and not a valid JSON string, raise error\n            raise ValueError('Invalid input. 
Must be a file path or a valid JSON string.')\n        json_data = json.loads(json_data)\n        if json_data.get('model_name', None) is None:\n            json_data['model_name'] = random_uuid()\n        if json_data['model_name'] not in MODELS.module_dict.keys():\n            MODELS.register_module(json_data['model_name'], module=BaseChatTemplate)\n        return cls(**json_data)\n\n\n@MODELS.register_module(name='base')\nclass BaseChatTemplate:\n    \"\"\"Base Chat template.\"\"\"\n\n    def __init__(self,\n                 system='',\n                 meta_instruction='',\n                 eosys='',\n                 user='',\n                 eoh='',\n                 assistant='',\n                 eoa='',\n                 separator='',\n                 tool='',\n                 eotool='',\n                 capability='chat',\n                 stop_words=None,\n                 **kwargs):\n        self.system = system\n        self.meta_instruction = meta_instruction\n        self.user = user\n        self.eoh = eoh\n        self.eoa = eoa\n        self.separator = separator\n        self.eosys = eosys\n        self.assistant = assistant\n        self.tool = tool\n        self.eotool = eotool\n        self.stop_words = stop_words\n        self.capability = capability\n\n    def get_prompt(self, prompt, sequence_start=True):\n        \"\"\"Return the prompt that is concatenated with other elements in the\n        chat template.\n\n        Args:\n            prompt (str): user's input prompt\n            sequence_start (bool): indicator for the first round chat of a\n               session sequence\n        Returns:\n            str: the concatenated prompt\n        \"\"\"\n        if self.capability == 'completion':\n            return prompt\n        if sequence_start:\n            # None is different from ''\n            if self.meta_instruction is not None:\n                return f'{self.system}{self.meta_instruction}{self.eosys}' \\\n                  
  f'{self.user}{prompt}{self.eoh}' \\\n                    f'{self.assistant}'\n            else:\n                return f'{self.user}{prompt}{self.eoh}' \\\n                       f'{self.assistant}'\n        else:\n            return f'{self.separator}{self.user}{prompt}{self.eoh}' \\\n                   f'{self.assistant}'\n\n    def messages2prompt(self, messages, sequence_start=True, **kwargs):\n        \"\"\"Return the prompt that is concatenated with other elements in the\n        chat template.\n\n        Args:\n            messages (str | List): user's input prompt\n        Returns:\n            str: the concatenated prompt\n        \"\"\"\n        if isinstance(messages, str):\n            return self.get_prompt(messages, sequence_start)\n        box_map = dict(user=self.user, assistant=self.assistant, system=self.system, tool=self.tool)\n        eox_map = dict(user=self.eoh, assistant=self.eoa + self.separator, system=self.eosys, tool=self.eotool)\n        ret = ''\n        if self.meta_instruction is not None and sequence_start:\n            if len(messages) and messages[0]['role'] != 'system':\n                ret += f'{self.system}{self.meta_instruction}{self.eosys}'\n        for message in messages:\n            role = message['role']\n            content = get_text(message['content'])\n            ret += f'{box_map[role]}{content}{eox_map[role]}'\n        if len(messages) and messages[-1]['role'] == 'assistant' and len(eox_map['assistant']) > 0:\n            return ret[:-len(eox_map['assistant'])]  # prefix of response\n        ret += f'{self.assistant}'\n        return ret\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        return None\n\n\n@MODELS.register_module(name='cogvlm')\nclass CogVLM(BaseChatTemplate):\n    \"\"\"Chat template of CogVLM 
model.\"\"\"\n\n    def __init__(self,\n                 meta_instruction='',\n                 eosys='',\n                 user='Question: ',\n                 separator='\\n',\n                 eoh=' ',\n                 assistant='Answer:',\n                 eoa='</s>',\n                 stop_words=['</s>'],\n                 **kwargs):\n        super().__init__(meta_instruction=meta_instruction,\n                         eosys=eosys,\n                         user=user,\n                         eoh=eoh,\n                         separator=separator,\n                         assistant=assistant,\n                         eoa=eoa,\n                         stop_words=stop_words,\n                         **kwargs)\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if 'cogvlm' in path and 'cogvlm2' not in path:\n            return 'cogvlm'\n\n\n@MODELS.register_module(name='vicuna')\nclass Vicuna(BaseChatTemplate):\n    \"\"\"Chat template of vicuna model.\"\"\"\n\n    def __init__(\n            self,\n            meta_instruction=\"\"\"A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\"\"\",  # noqa: E501\n            eosys=' ',\n            user='USER: ',\n            eoh=' ',\n            assistant='ASSISTANT: ',\n            eoa='</s>',\n            stop_words=['</s>'],\n            **kwargs):\n        super().__init__(meta_instruction=meta_instruction,\n                         eosys=eosys,\n                         user=user,\n                         eoh=eoh,\n                         assistant=assistant,\n                         eoa=eoa,\n                         stop_words=stop_words,\n                         **kwargs)\n\n    def get_prompt(self, prompt, sequence_start=True):\n        if self.capability == 'chat':\n            return super().get_prompt(prompt, sequence_start)[:-1]\n        return super().get_prompt(prompt, sequence_start)\n\n    def messages2prompt(self, messages, sequence_start=True, **kwargs):\n        if isinstance(messages, str):\n            return self.get_prompt(messages, sequence_start)\n        return super().messages2prompt(messages, sequence_start, **kwargs)[:-1]\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if 'vicuna' in path and 'llava' not in path:\n            return 'vicuna'\n        if 'wizardlm' in path:\n            return 'wizardlm'\n\n\n@MODELS.register_module(name='llava-v1')\nclass Llavav1(Vicuna):\n    \"\"\"Chat template of llava-v1 model.\"\"\"\n\n    def __init__(\n            self,\n            meta_instruction=\"\"\"A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\"\"\",  # noqa: E501\n            **kwargs):\n        super().__init__(meta_instruction=meta_instruction, **kwargs)\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if 'llava' in path and 'v1' in path and 'v1.6-34b' not in path \\\n                and 'mistral' not in path:\n            return 'llava-v1'\n        elif 'llava-1.5' in path:\n            return 'llava-v1'\n\n\n@MODELS.register_module(name='internlm')\nclass InternLMChat7B(BaseChatTemplate):\n    \"\"\"Chat template of InternLM model.\"\"\"\n\n    def __init__(\n            self,\n            system='<|System|>:',\n            meta_instruction=\"\"\"You are an AI assistant whose name is InternLM (书生·浦语).\n- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). 
It is designed to be helpful, honest, and harmless.\n- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.\n\"\"\",  # noqa: E501\n            eosys='\\n',\n            user='<|User|>:',\n            eoh='\\n',\n            assistant='<|Bot|>:',\n            eoa='<eoa>',\n            separator='\\n',\n            stop_words=['<eoa>'],\n            **kwargs):\n        super().__init__(system=system,\n                         meta_instruction=meta_instruction,\n                         eosys=eosys,\n                         user=user,\n                         eoh=eoh,\n                         assistant=assistant,\n                         eoa=eoa,\n                         separator=separator,\n                         stop_words=stop_words,\n                         **kwargs)\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if all([c not in path for c in ['internlm3', 'internlm2', '8k']]) and \\\n                all([c in path for c in ['internlm', 'chat']]):\n            return 'internlm'\n\n\n@MODELS.register_module(name='baichuan2')\nclass Baichuan2(BaseChatTemplate):\n    \"\"\"Chat template and generation parameters of Baichuan2-7B-Base and\n    Baichuan2-7B-Chat models.\"\"\"\n\n    def __init__(self, user='<reserved_106>', assistant='<reserved_107>', **kwargs):\n        super().__init__(user=user, assistant=assistant, **kwargs)\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if 'baichuan2' in path and 'chat' in path:\n            
return 'baichuan2'\n\n\n@MODELS.register_module(name='llama2')\nclass Llama2(BaseChatTemplate):\n    \"\"\"Chat template of LLaMA2 model.\"\"\"\n\n    def __init__(\n            self,\n            system='[INST] <<SYS>>\\n',\n            meta_instruction=\"\"\"\\\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\"\"\",  # noqa: E501\n            eosys='\\n<</SYS>>\\n\\n',\n            assistant=' [/INST] ',\n            eoa='</s>',\n            separator='<s>[INST] ',\n            session_len=4096,\n            **kwargs):\n        super().__init__(system=system,\n                         meta_instruction=meta_instruction,\n                         eosys=eosys,\n                         assistant=assistant,\n                         eoa=eoa,\n                         separator=separator,\n                         session_len=session_len,\n                         **kwargs)\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        if 'llama-2' in model_path.lower() or 'llama2' in model_path.lower():\n            return 'llama2'\n\n\n@MODELS.register_module(name='codellama')\nclass CodeLlama(Llama2):\n\n    def __init__(self, meta_instruction='', suffix_first=False, stop_words=None, **kwargs):\n        super().__init__(meta_instruction=meta_instruction, stop_words=stop_words, **kwargs)\n        caps = ['completion', 
'infilling', 'chat', 'python']\n        assert self.capability in caps, \\\n            f'{self.capability} is not supported. ' \\\n            f'The supported capabilities are: {caps}'\n        self.meta_instruction = meta_instruction\n        self.suffix_first = suffix_first\n        self.stop_words = stop_words\n        if self.capability == 'infilling':\n            if self.stop_words is None:\n                self.stop_words = ['<EOT>']\n\n    def get_prompt(self, prompt, sequence_start=True):\n        if self.capability == 'infilling':\n            return self._infill_prompt(prompt)\n        elif self.capability == 'chat':\n            return super().get_prompt(prompt, sequence_start)\n        else:  # python specialist\n            return prompt\n\n    def _infill_prompt(self, prompt):\n        prefix, suffix = prompt.split('<FILL>')\n        if self.suffix_first:\n            # format as \"<PRE> <SUF>{suf} <MID> {pre}\"\n            prompt = f'<PRE> <SUF>{suffix} <MID> {prefix}'\n        else:\n            # format as \"<PRE> {pre} <SUF>{suf} <MID>\"\n            prompt = f'<PRE> {prefix} <SUF>{suffix} <MID>'\n        return prompt\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        if 'codellama' in model_path.lower():\n            return 'codellama'\n\n\n@MODELS.register_module(name='chatglm')\nclass ChatGLM2(BaseChatTemplate):\n\n    def __init__(self, user='问：', eoh='\\n\\n', assistant='答：', eoa='\\n\\n', **kwargs):\n        super().__init__(**kwargs)\n        self._user = user\n        self._assistant = assistant\n        self._eoh = eoh\n        self._eoa = eoa\n        self.count = 0\n\n    def get_prompt(self, prompt, sequence_start=True):\n        \"\"\"Get prompt.\"\"\"\n        # need more check\n        # 
https://github.com/THUDM/ChatGLM2-6B/issues/48\n        # [64790, 64792] to be prepended\n        self.count += 1\n        ret = f'[Round {self.count}]\\n\\n'\n        ret += f'{self._user}{prompt}{self._eoh}'\n        ret += f'{self._assistant}'\n        return ret\n\n    def messages2prompt(self, messages, sequence_start=True, **kwargs):\n        \"\"\"Message to prompt.\"\"\"\n        if isinstance(messages, str):\n            return self.get_prompt(messages, sequence_start)\n        ret = ''\n        count = 0\n        for message in messages:\n            role = message['role']\n            content = get_text(message['content'])\n            if role == 'user':\n                count += 1\n                ret += f'[Round {count}]\\n\\n'\n                ret += f'{self._user}{content}{self._eoh}'\n                ret += f'{self._assistant}'\n            if role == 'assistant':\n                ret += f'{content}'\n        return ret\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if 'chatglm2' in path:\n            return 'chatglm'\n\n\n@MODELS.register_module(name=['mistral', 'mixtral'])\nclass MistralChat(BaseChatTemplate):\n    \"\"\"Template of Mistral and Mixtral Instruct models.\n\n    `https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1`\n    `https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1`\n    \"\"\"\n\n    def __init__(self, user='[INST] ', eoh=' [/INST]', eoa='</s>', **kwargs):\n        super().__init__(user=user, eoh=eoh, eoa=eoa, **kwargs)\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        model_path = 
model_path.lower()\n        if 'instruct' in model_path or 'llava' in model_path:\n            if 'mistral' in model_path:\n                return 'mistral'\n            if 'mixtral' in model_path:\n                return 'mixtral'\n\n\n@MODELS.register_module(name=['internvl-zh'])\nclass InternVLZH(BaseChatTemplate):\n\n    def __init__(self, user='<human>: ', eoh=' ', assistant='<bot>: ', eoa='</s>', **kwargs):\n        super().__init__(user=user, eoh=eoh, assistant=assistant, eoa=eoa, **kwargs)\n\n    def get_prompt(self, prompt, sequence_start=True):\n        if self.capability == 'chat':\n            return super().get_prompt(prompt, sequence_start)[:-1]\n        return super().get_prompt(prompt, sequence_start)\n\n    def messages2prompt(self, messages, sequence_start=True, **kwargs):\n        if isinstance(messages, str):\n            return self.get_prompt(messages, sequence_start)\n        return super().messages2prompt(messages, sequence_start, **kwargs)[:-1]\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if 'internvl-chat' in path and 'v1-1' in path:\n            return 'internvl-zh'\n\n\n@MODELS.register_module(name=['deepseek-vl'])\nclass DeepseekVL(BaseChatTemplate):\n\n    def __init__(\n            self,\n            meta_instruction=\"\"\"You are a helpful language and vision assistant. 
You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\"\"\",  # noqa: E501\n            eosys='\\n\\n',\n            user='User: ',\n            eoh='\\n\\n',\n            assistant='Assistant: ',\n            eoa='<｜end▁of▁sentence｜>',\n            **kwargs):\n        super().__init__(meta_instruction=meta_instruction,\n                         eosys=eosys,\n                         user=user,\n                         eoh=eoh,\n                         assistant=assistant,\n                         eoa=eoa,\n                         **kwargs)\n\n    def get_prompt(self, prompt, sequence_start=True):\n        if self.capability == 'chat':\n            return super().get_prompt(prompt, sequence_start)[:-1]\n        return super().get_prompt(prompt, sequence_start)\n\n    def messages2prompt(self, messages, sequence_start=True, **kwargs):\n        if isinstance(messages, str):\n            return self.get_prompt(messages, sequence_start)\n        return super().messages2prompt(messages, sequence_start, **kwargs)[:-1]\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if 'deepseek-vl' in path and 'chat' in path:\n            return 'deepseek-vl'\n\n\n@MODELS.register_module(name=['deepseek-vl2'])\nclass DeepseekVL2(BaseChatTemplate):\n\n    def __init__(self,\n                 meta_instruction='',\n                 eosys='',\n                 user='<|User|>: ',\n                 eoh='\\n\\n',\n                 assistant='<|Assistant|>: ',\n                 eoa='<｜end▁of▁sentence｜>',\n                 **kwargs):\n        super().__init__(meta_instruction=meta_instruction,\n                         eosys=eosys,\n                         user=user,\n         
                eoh=eoh,\n                         assistant=assistant,\n                         eoa=eoa,\n                         **kwargs)\n\n    def get_prompt(self, prompt, sequence_start=True):\n        return super().get_prompt(prompt, sequence_start)[:-1]\n\n    def messages2prompt(self, messages, sequence_start=True, **kwargs):\n        if isinstance(messages, str):\n            return self.get_prompt(messages, sequence_start)\n        return super().messages2prompt(messages, sequence_start, **kwargs)[:-1]\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = model_path.lower()\n        if 'deepseek-vl2' in path:\n            return 'deepseek-vl2'\n\n\n@MODELS.register_module(name=['llava-chatml'])\nclass ChatmlDirect(BaseChatTemplate):\n\n    def __init__(self,\n                 system='<|im_start|>system\\n',\n                 meta_instruction='Answer the questions.',\n                 eosys='<|im_end|>',\n                 user='<|im_start|>user\\n',\n                 eoh='<|im_end|>',\n                 assistant='<|im_start|>assistant\\n',\n                 eoa='<|im_end|>',\n                 separator='',\n                 **kwargs):\n        super().__init__(system,\n                         meta_instruction=meta_instruction,\n                         eosys=eosys,\n                         user=user,\n                         eoh=eoh,\n                         assistant=assistant,\n                         eoa=eoa,\n                         separator=separator,\n                         **kwargs)\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        \"\"\"Return the model_name that was registered to MODELS.\n\n        Args:\n            model_path (str): the model path used for matching.\n        \"\"\"\n        path = 
model_path.lower()\n        if 'llava' in path and 'v1.6-34b' in path:\n            return 'llava-chatml'\n\n\n@MODELS.register_module(name=['hf'])\nclass HFChatTemplate(BaseChatTemplate):\n    \"\"\"Chat template for HuggingFace models with an `apply_chat_template` method.\n\n    It MUST be at the end of the @MODELS registry.\n    \"\"\"\n\n    def __init__(self, model_path: str = '', **kwargs):\n        self.model_path = model_path\n        try:\n            from transformers import AutoTokenizer\n            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n            # Verify that the model can perform apply_chat_template with different roles.\n            self.user_start, self.user_end, _, _ = self._user_instruction()\n            self.assistant_start, self.assistant_end, _, _ = self._assistant_instruction()\n            _, _, self.sentinel_system_messages, self.sentinel_system_prompt = self._system_instruction()\n            self.stop_words = []\n            if hasattr(self.tokenizer, 'eos_token') and self.tokenizer.eos_token is not None:\n                self.stop_words.append(self.tokenizer.eos_token)\n            if hasattr(self.tokenizer, 'eot_token') and self.tokenizer.eot_token is not None:\n                self.stop_words.append(self.tokenizer.eot_token)\n            arch, _ = get_model_arch(model_path)\n            self.is_gpt_oss = arch == 'GptOssForCausalLM'\n            if self.is_gpt_oss:\n                self.stop_words.append('<|call|>')\n        except Exception as e:\n            raise ValueError(f'Try apply_chat_template failed: {e}')\n\n    def get_prompt(self, prompt, sequence_start=True, **kwargs):\n        messages = [{'role': 'user', 'content': prompt}]\n        return self.messages2prompt(messages, sequence_start, **kwargs)\n\n    def messages2prompt(self, messages, sequence_start=True, **kwargs):\n        if isinstance(messages, str):\n            messages = [{'role': 'user', 'content': messages}]\n        
assert all(isinstance(m, dict) and 'role' in m and 'content' in m for m in messages), \\\n            'Each message should be a dict with \"role\" and \"content\" keys.'\n\n        if 'enable_thinking' in kwargs and kwargs['enable_thinking'] is None:\n            # Workaround for internlm/Intern-S1: when enable_thinking=None is passed to apply_chat_template,\n            # the <think> tag is not generated.\n            kwargs.pop('enable_thinking')\n        if 'reasoning_effort' in kwargs and kwargs['reasoning_effort'] is None:\n            kwargs.pop('reasoning_effort')\n        add_generation_prompt = messages[-1]['role'] != 'assistant'\n        if sequence_start:\n            prompt = self.tokenizer.apply_chat_template(messages,\n                                                        tokenize=False,\n                                                        add_generation_prompt=add_generation_prompt,\n                                                        **kwargs)\n        else:\n            # Use a sentinel position to avoid the influence of default system role in the tokenizer's chat template\n            # in interactive chat mode\n            messages = self.sentinel_system_messages + messages if self.sentinel_system_messages else messages\n            prompt = self.tokenizer.apply_chat_template(messages,\n                                                        tokenize=False,\n                                                        add_generation_prompt=add_generation_prompt,\n                                                        **kwargs)\n            # Remove the sentinel part.\n            prompt = prompt[len(self.sentinel_system_prompt):] if len(self.sentinel_system_prompt) > 0 else prompt\n        if messages[-1]['role'] == 'assistant' and len(self.assistant_end) > 0:\n            prompt = prompt[:-len(self.assistant_end)]  # prefix of response to let the model complete the response\n        if self.is_gpt_oss and not kwargs.get('tools'):\n            # 
for gpt-oss models, removing this seems more conducive to instruction following.\n            prompt = prompt.replace('commentary, ', '', 1)\n        return prompt\n\n    def _user_instruction(self):\n        \"\"\"Extract user message template markers from the tokenizer's chat\n        template.\"\"\"\n\n        messages = [{'role': 'user', 'content': 'sentinel'}]\n        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)\n        user_pos = prompt.find('sentinel')\n        user_start = prompt[:user_pos]\n        user_end = prompt[user_pos + len('sentinel'):]\n        return user_start, user_end, messages, prompt\n\n    def _assistant_instruction(self):\n        \"\"\"Extract assistant message template markers from the tokenizer's chat\n        template.\"\"\"\n\n        # Some models, such as google/gemma-2-2b-it, require conversation roles to strictly\n        # alternate between 'user' and 'assistant' (e.g., user/assistant/user/assistant...).\n        # Consequently, we construct test messages containing both user and assistant roles\n        # with special tokens, and parse the assistant tag according to user markers and\n        # special tokens.\n        messages = [{'role': 'user', 'content': 'placeholder'}, {'role': 'assistant', 'content': 'sentinel'}]\n        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)\n        user_end_pos = prompt.find(self.user_end)\n        assistant_pos = prompt.find('sentinel')\n        assistant_start = prompt[user_end_pos + len(self.user_end):assistant_pos]\n        assistant_end = prompt[assistant_pos + len('sentinel'):]\n        return assistant_start, assistant_end, messages, prompt\n\n    def _system_instruction(self):\n        \"\"\"Extract system message template markers from the tokenizer's chat\n        template.\"\"\"\n        messages = [{'role': 'system', 'content': 'sentinel'}]\n        try:\n            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)\n  
          system_pos = prompt.find('sentinel')\n            if system_pos == -1:\n                return None, None, [], self.tokenizer.bos_token or ''\n            system_start = prompt[:system_pos]\n            system_end = prompt[system_pos + len('sentinel'):]\n            return system_start, system_end, messages, prompt\n        except Exception:\n            # Some models, such as google/gemma-2-2b-it, do not support a system role in the message structure.\n            return None, None, [], self.tokenizer.bos_token or ''\n\n    @classmethod\n    def match(cls, model_path: str) -> Optional[str]:\n        try:\n            cls(model_path)\n        except Exception:\n            return False\n        return True\n\n\ndef get_chat_template(model_path: str, config: Optional[ChatTemplateConfig] = None) -> BaseChatTemplate:\n    \"\"\"Get the chat template for the model.\n\n    Args:\n        model_path (str): the model path.\n        config (Optional[ChatTemplateConfig]): the chat template config.\n    Returns:\n        BaseChatTemplate: the chat template.\n    \"\"\"\n    if config is not None:\n        return config.chat_template\n    chat_template_name = 'base'\n    for name, model in MODELS.module_dict.items():\n        if model.match(model_path):\n            chat_template_name = name\n            break\n    config = ChatTemplateConfig(chat_template_name, model_path=model_path)\n    return config.chat_template\n"
  },
  {
    "path": "lmdeploy/monitoring/docker-compose.yaml",
    "content": "# Copied from https://github.com/sgl-project/sglang/blob/main/examples/monitoring/docker-compose.yaml\nversion: '3'\nservices:\n  prometheus:\n    image: prom/prometheus:latest\n    container_name: prometheus\n    network_mode: host\n    volumes:\n      - ./prometheus.yaml:/etc/prometheus/prometheus.yml\n    command:\n      - '--config.file=/etc/prometheus/prometheus.yml'\n      - '--storage.tsdb.path=/prometheus'\n\n  grafana:\n    image: grafana/grafana:latest\n    container_name: grafana\n    network_mode: host\n    volumes:\n      - ./grafana/datasources:/etc/grafana/provisioning/datasources\n      - ./grafana/dashboards/config:/etc/grafana/provisioning/dashboards\n      - ./grafana/dashboards/json:/var/lib/grafana/dashboards\n    environment:\n      - GF_AUTH_ANONYMOUS_ENABLED=true\n      - GF_AUTH_ANONYMOUS_ORG_ROLE=Viewer\n      - GF_AUTH_BASIC_ENABLED=false\n      - GF_USERS_ALLOW_SIGN_UP=false\n      - GF_DASHBOARDS_DEFAULT_HOME_DASHBOARD_PATH=/var/lib/grafana/dashboards/lmdeploy-dashboard.json\n    depends_on:\n      - prometheus\n"
  },
  {
    "path": "lmdeploy/monitoring/grafana/dashboards/config/dashboard.yaml",
    "content": "apiVersion: 1\nproviders:\n  - name: 'LMDeploy'\n    orgId: 1\n    folder: 'LMDeploy Monitoring'\n    type: file\n    disableDeletion: false\n    updateIntervalSeconds: 10\n    allowUiUpdates: false\n    options:\n      path: /var/lib/grafana/dashboards\n"
  },
  {
    "path": "lmdeploy/monitoring/grafana/dashboards/json/lmdeploy-dashboard.json",
    "content": "{\n  \"_comment\": \"json file adapted from https://github.com/vllm-project/vllm/blob/main/examples/online_serving/prometheus_grafana/grafana.json\",\n  \"annotations\": {\n    \"list\": [\n      {\n        \"builtIn\": 1,\n        \"datasource\": {\n          \"type\": \"grafana\",\n          \"uid\": \"-- Grafana --\"\n        },\n        \"enable\": true,\n        \"hide\": true,\n        \"iconColor\": \"rgba(0, 211, 255, 1)\",\n        \"name\": \"Annotations & Alerts\",\n        \"target\": {\n          \"limit\": 100,\n          \"matchAny\": false,\n          \"tags\": [],\n          \"type\": \"dashboard\"\n        },\n        \"type\": \"dashboard\"\n      }\n    ]\n  },\n  \"description\": \"Monitoring LMDeploy Inference Server\",\n  \"editable\": true,\n  \"fiscalYearStartMonth\": 0,\n  \"graphTooltip\": 0,\n  \"id\": 1,\n  \"links\": [],\n  \"liveNow\": false,\n  \"panels\": [\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"description\": \"End to end request latency measured in seconds.\",\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              
\"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 0\n      },\n      \"id\": 9,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.99, sum by(le) (rate(lmdeploy:e2e_request_latency_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P99\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          
\"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.95, sum by(le) (rate(lmdeploy:e2e_request_latency_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P95\",\n          \"range\": true,\n          \"refId\": \"B\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.9, sum by(le) (rate(lmdeploy:e2e_request_latency_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P90\",\n          \"range\": true,\n          \"refId\": \"C\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.5, sum by(le) (rate(lmdeploy:e2e_request_latency_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P50\",\n          \"range\": true,\n          \"refId\": \"D\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"editorMode\": 
\"code\",\n          \"expr\": \"rate(lmdeploy:e2e_request_latency_seconds_sum{model_name=\\\"$model_name\\\"}[$__rate_interval])\\n/\\nrate(lmdeploy:e2e_request_latency_seconds_count{model_name=\\\"$model_name\\\"}[$__rate_interval])\",\n          \"hide\": false,\n          \"instant\": false,\n          \"legendFormat\": \"Average\",\n          \"range\": true,\n          \"refId\": \"E\"\n        }\n      ],\n      \"title\": \"E2E Request Latency\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"description\": \"Number of tokens processed per second\",\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            
\"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 0\n      },\n      \"id\": 8,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"rate(lmdeploy:prompt_tokens_total{model_name=\\\"$model_name\\\"}[$__rate_interval])\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"Prompt Tokens/Sec\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"rate(lmdeploy:generation_tokens_total{model_name=\\\"$model_name\\\"}[$__rate_interval])\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"Generation Tokens/Sec\",\n          \"range\": true,\n          \"refId\": \"B\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": 
\"Token Throughput\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"description\": \"Time per output token (TPOT) latency in seconds.\",\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 8\n      },\n      \"id\": 
10,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.99, sum by(le) (rate(lmdeploy:time_per_output_token_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P99\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.95, sum by(le) (rate(lmdeploy:time_per_output_token_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P95\",\n          \"range\": true,\n          \"refId\": \"B\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.9, sum by(le) 
(rate(lmdeploy:time_per_output_token_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P90\",\n          \"range\": true,\n          \"refId\": \"C\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.5, sum by(le) (rate(lmdeploy:time_per_output_token_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P50\",\n          \"range\": true,\n          \"refId\": \"D\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"rate(lmdeploy:time_per_output_token_seconds_sum{model_name=\\\"$model_name\\\"}[$__rate_interval])\\n/\\nrate(lmdeploy:time_per_output_token_seconds_count{model_name=\\\"$model_name\\\"}[$__rate_interval])\",\n          \"hide\": false,\n          \"instant\": false,\n          \"legendFormat\": \"Mean\",\n          \"range\": true,\n          \"refId\": \"E\"\n        }\n      ],\n      \"title\": \"Time Per Output Token Latency\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"description\": \"Inter-token latency in seconds.\",\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": 
\"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 8\n      },\n      \"id\": 10,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        
{\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.99, sum by(le) (rate(lmdeploy:iter_token_latency_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P99\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.95, sum by(le) (rate(lmdeploy:iter_token_latency_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P95\",\n          \"range\": true,\n          \"refId\": \"B\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.9, sum by(le) (rate(lmdeploy:iter_token_latency_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P90\",\n          \"range\": true,\n          \"refId\": \"C\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n     
       \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.5, sum by(le) (rate(lmdeploy:iter_token_latency_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P50\",\n          \"range\": true,\n          \"refId\": \"D\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"rate(lmdeploy:iter_token_latency_sum{model_name=\\\"$model_name\\\"}[$__rate_interval])\\n/\\nrate(lmdeploy:iter_token_latency_count{model_name=\\\"$model_name\\\"}[$__rate_interval])\",\n          \"hide\": false,\n          \"instant\": false,\n          \"legendFormat\": \"Mean\",\n          \"range\": true,\n          \"refId\": \"E\"\n        }\n      ],\n      \"title\": \"Inter-Token Latency\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"description\": \"Number of requests in RUNNING, WAITING, and SWAPPED state\",\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              
\"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"none\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 8\n      },\n      \"id\": 3,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"lmdeploy:num_requests_running{model_name=\\\"$model_name\\\"}\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"Num Running\",\n          \"range\": true,\n    
      \"refId\": \"C\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"lmdeploy:num_requests_waiting{model_name=\\\"$model_name\\\"}\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"Num Waiting\",\n          \"range\": true,\n          \"refId\": \"D\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"Scheduler Stats\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"description\": \"P50, P90, P95, and P99 TTFT latency in seconds.\",\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n 
             \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 16\n      },\n      \"id\": 5,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.99, sum by(le) (rate(lmdeploy:time_to_first_token_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P99\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.95, sum by(le) 
(rate(lmdeploy:time_to_first_token_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P95\",\n          \"range\": true,\n          \"refId\": \"B\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.9, sum by(le) (rate(lmdeploy:time_to_first_token_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P90\",\n          \"range\": true,\n          \"refId\": \"C\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"histogram_quantile(0.5, sum by(le) (rate(lmdeploy:time_to_first_token_seconds_bucket{model_name=\\\"$model_name\\\"}[$__rate_interval])))\",\n          \"fullMetaSearch\": false,\n          \"hide\": false,\n          \"includeNullMetadata\": false,\n          \"instant\": false,\n          \"legendFormat\": \"P50\",\n          \"range\": true,\n          \"refId\": \"D\",\n          \"useBackend\": false\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": 
\"rate(lmdeploy:time_to_first_token_seconds_sum{model_name=\\\"$model_name\\\"}[$__rate_interval])\\n/\\nrate(lmdeploy:time_to_first_token_seconds_count{model_name=\\\"$model_name\\\"}[$__rate_interval])\",\n          \"hide\": false,\n          \"instant\": false,\n          \"legendFormat\": \"Average\",\n          \"range\": true,\n          \"refId\": \"E\"\n        }\n      ],\n      \"title\": \"Time To First Token Latency\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"description\": \"Percentage of used cache blocks by LMDeploy.\",\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n      
        {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"percentunit\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 16\n      },\n      \"id\": 4,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"lmdeploy:gpu_cache_usage_perc{model_name=\\\"$model_name\\\"}\",\n          \"instant\": false,\n          \"legendFormat\": \"GPU Cache Usage\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Cache Utilization\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"description\": \"Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.\",\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": 
\"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\"\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 32\n      },\n      \"id\": 11,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"builder\",\n          \"expr\": \"sum by(finished_reason) (increase(lmdeploy:request_success_total{model_name=\\\"$model_name\\\"}[$__rate_interval]))\",\n          \"fullMetaSearch\": false,\n          
\"includeNullMetadata\": true,\n          \"instant\": false,\n          \"interval\": \"\",\n          \"legendFormat\": \"__auto\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"Finish Reason\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"default\": false,\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"seconds\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"barWidthFactor\": 0.6,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\"\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n        
  }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 32\n      },\n      \"id\": 14,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS}\"\n          },\n          \"disableTextWrap\": false,\n          \"editorMode\": \"code\",\n          \"expr\": \"rate(lmdeploy:request_queue_time_seconds_sum{model_name=\\\"$model_name\\\"}[$__rate_interval])\",\n          \"fullMetaSearch\": false,\n          \"includeNullMetadata\": true,\n          \"instant\": false,\n          \"legendFormat\": \"__auto\",\n          \"range\": true,\n          \"refId\": \"A\",\n          \"useBackend\": false\n        }\n      ],\n      \"title\": \"Queue Time\",\n      \"type\": \"timeseries\"\n    }\n  ],\n  \"refresh\": \"\",\n  \"schemaVersion\": 39,\n  \"tags\": [],\n  \"templating\": {\n    \"list\": [\n      {\n        \"current\": {\n          \"selected\": false,\n          \"text\": \"prometheus\",\n          \"value\": \"edx8memhpd9tsa\"\n        },\n        \"hide\": 0,\n        \"includeAll\": false,\n        \"label\": \"datasource\",\n        \"multi\": false,\n        \"name\": \"DS_PROMETHEUS\",\n        \"options\": [],\n        \"query\": \"prometheus\",\n        \"queryValue\": \"\",\n        \"refresh\": 1,\n        \"regex\": \"\",\n        \"skipUrlSync\": false,\n        \"type\": \"datasource\"\n      },\n      {\n        \"current\": {\n          \"selected\": false,\n          \"text\": \"/share/datasets/public_models/Meta-Llama-3-8B-Instruct\",\n          \"value\": 
\"/share/datasets/public_models/Meta-Llama-3-8B-Instruct\"\n        },\n        \"datasource\": {\n          \"type\": \"prometheus\",\n          \"uid\": \"${DS_PROMETHEUS}\"\n        },\n        \"definition\": \"label_values(model_name)\",\n        \"hide\": 0,\n        \"includeAll\": false,\n        \"label\": \"model_name\",\n        \"multi\": false,\n        \"name\": \"model_name\",\n        \"options\": [],\n        \"query\": {\n          \"query\": \"label_values(model_name)\",\n          \"refId\": \"StandardVariableQuery\"\n        },\n        \"refresh\": 1,\n        \"regex\": \"\",\n        \"skipUrlSync\": false,\n        \"sort\": 0,\n        \"type\": \"query\"\n      }\n    ]\n  },\n  \"time\": {\n    \"from\": \"now-5m\",\n    \"to\": \"now\"\n  },\n  \"timepicker\": {},\n  \"timezone\": \"\",\n  \"title\": \"LMDeploy\",\n  \"uid\": \"b281712d-8bff-41ef-9f3f-71ad43c05e9b\",\n  \"version\": 8,\n  \"weekStart\": \"\"\n}\n"
  },
  {
    "path": "lmdeploy/monitoring/grafana/datasources/datasource.yaml",
    "content": "apiVersion: 1\ndatasources:\n  - name: Prometheus\n    type: prometheus\n    access: proxy\n    url: http://localhost:9090\n    isDefault: true\n    editable: false\n"
  },
  {
    "path": "lmdeploy/monitoring/prometheus.yaml",
    "content": "# prometheus.yaml\nglobal:\n  scrape_interval: 5s\n  evaluation_interval: 30s\n\nscrape_configs:\n  - job_name: lmdeploy\n    static_configs:\n      - targets:\n          - '127.0.0.1:23333'\n"
  },
  {
    "path": "lmdeploy/pipeline.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport atexit\nimport concurrent.futures\nimport os\nfrom contextlib import closing\nfrom functools import partial\nfrom queue import Queue\nfrom threading import Thread\nfrom typing import TYPE_CHECKING, Dict, Iterator, List, Tuple\n\nimport torch\nimport tqdm\nfrom typing_extensions import deprecated\n\nfrom .archs import autoget_backend_config, get_task\nfrom .messages import GenerationConfig, PytorchEngineConfig, Response, SpeculativeConfig, TurbomindEngineConfig\nfrom .model import ChatTemplateConfig\nfrom .serve.processors import MultimodalProcessor\nfrom .utils import get_logger, get_model\n\nif TYPE_CHECKING:\n    from PIL.Image import Image\n\n    from .serve.managers import Session\n\nlogger = get_logger('lmdeploy')\n\n\nclass Pipeline:\n    \"\"\"Pipeline - User-facing API layer for inference.\"\"\"\n\n    def __init__(self,\n                 model_path: str,\n                 backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None,\n                 chat_template_config: ChatTemplateConfig | None = None,\n                 log_level: str = 'WARNING',\n                 max_log_len: int | None = None,\n                 speculative_config: SpeculativeConfig | None = None,\n                 **kwargs):\n        \"\"\"Initialize Pipeline.\n\n        Args:\n            model_path: Path to the model.\n            backend_config: Backend configuration.\n            chat_template_config: Chat template configuration.\n            log_level: Log level.\n            max_log_len: Max number of prompt characters or prompt tokens being printed in log.\n            speculative_config: Speculative decoding configuration.\n            **kwargs: Additional keyword arguments.\n        \"\"\"\n\n        os.environ.setdefault('TM_LOG_LEVEL', log_level)\n        logger.setLevel(log_level)\n\n        # Download model if the path does not exist locally\n        if not 
os.path.exists(model_path):\n            download_dir = backend_config.download_dir if backend_config else None\n            revision = backend_config.revision if backend_config else None\n            model_path = get_model(model_path, download_dir, revision)\n\n        # Download speculative model if the path does not exist locally\n        if speculative_config and speculative_config.model and not os.path.exists(speculative_config.model):\n            download_dir = backend_config.download_dir if backend_config else None\n            speculative_config.model = get_model(speculative_config.model, download_dir)\n\n        # Create inference engine\n        backend, backend_config = autoget_backend_config(model_path, backend_config)\n        _, pipeline_class = get_task(backend, model_path)\n        self.async_engine = pipeline_class(model_path,\n                                           backend=backend,\n                                           backend_config=backend_config,\n                                           chat_template_config=chat_template_config,\n                                           max_log_len=max_log_len,\n                                           speculative_config=speculative_config,\n                                           **kwargs)\n        self.internal_thread = _EventLoopThread(daemon=True)\n        self.limiter: asyncio.Semaphore = None\n        self.session_mgr = self.async_engine.session_mgr\n        self.backend_config = self.async_engine.backend_config\n        self.async_engine.start_loop(self.internal_thread.loop, use_async_api=False)\n\n    def infer(self,\n              prompts: List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple],\n              gen_config: GenerationConfig | List[GenerationConfig] | None = None,\n              do_preprocess: bool = True,\n              adapter_name: str | None = None,\n              use_tqdm: bool = False,\n              **kwargs):\n        \"\"\"Inference prompts.\n\n 
       Args:\n            prompts: Prompts to inference. It can be a single prompt, a list of prompts, a list of tuples, or a tuple.\n                Tuple can be (prompt, image or [images]) or (image or [images], prompt).\n            gen_config(GenerationConfig | List[GenerationConfig] | None): Generation configuration(s).\n            do_preprocess(bool): Whether to pre-process messages.\n            adapter_name(str | None): Adapter name.\n            use_tqdm(bool): Whether to use progress bar.\n            **kwargs(dict): Additional keyword arguments.\n        \"\"\"\n        is_single = self._is_single(prompts)\n        # format prompts to openai message format, which is a list of dicts\n        prompts = MultimodalProcessor.format_prompts(prompts)\n        pbar = tqdm.tqdm(total=len(prompts)) if use_tqdm else None\n        outputs = []\n        try:\n            requests = self._request_generator(prompts,\n                                               gen_config=gen_config,\n                                               do_preprocess=do_preprocess,\n                                               adapter_name=adapter_name,\n                                               stream_response=False,\n                                               **kwargs)\n            for g in self._infer(requests, multiplex=False, pbar=pbar):\n                res = None\n                for out in g:\n                    res = res.extend(out) if res else out\n                outputs.append(res)\n        finally:\n            if pbar: pbar.close()  # noqa\n        if is_single:\n            return outputs[0]\n        return outputs\n\n    @deprecated('This method is deprecated. 
Please use \"Pipeline.infer\" instead.')\n    def batch_infer(self, *args, **kwargs):\n        return self.infer(*args, **kwargs)\n\n    def stream_infer(self,\n                     prompts: List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple],\n                     sessions: 'Session' | List['Session'] | None = None,\n                     gen_config: GenerationConfig | List[GenerationConfig] | None = None,\n                     do_preprocess: bool = True,\n                     adapter_name: str | None = None,\n                     stream_response: bool = True,\n                     **kwargs):\n        \"\"\"Stream inference.\n\n        Args:\n            prompts(List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple]): Prompts to inference.\n                It can be a single prompt, a list of prompts, a list of tuples, or a tuple.\n                Tuple can be (prompt, image or [images]) or (image or [images], prompt).\n            sessions(Session | List[Session] | None): Sessions. Each of which corresponds to a prompt.\n            gen_config(GenerationConfig | List[GenerationConfig] | None): Generation configuration(s).\n            do_preprocess(bool): Whether to pre-process messages.\n            adapter_name(str | None): Adapter name.\n            stream_response(bool): Whether to stream the response. If True, the generator will stream the response.\n                Otherwise, the generator will run until finish and return the final response. This argument\n                is introduced to support the streaming and non-streaming modes of Pipeline.chat.\n            **kwargs(dict): Additional keyword arguments.\n\n        Returns:\n            Generator: A generator that yields the output (i.e. 
instance of class `Response`) of the inference.\n        \"\"\"\n        prompts = MultimodalProcessor.format_prompts(prompts)\n        requests = self._request_generator(prompts,\n                                           sessions=sessions,\n                                           gen_config=gen_config,\n                                           do_preprocess=do_preprocess,\n                                           adapter_name=adapter_name,\n                                           stream_response=stream_response,\n                                           **kwargs)\n        return self._infer(requests, multiplex=True)\n\n    def close(self):\n        \"\"\"Close the pipeline.\"\"\"\n        self.internal_thread.close()\n        self.async_engine.close()\n\n    def chat(self,\n             prompt: str | Tuple[str, 'Image' | List['Image']],\n             session=None,\n             gen_config: GenerationConfig | None = None,\n             stream_response=False,\n             adapter_name=None,\n             **kwargs) -> 'Session' | Iterator:\n        \"\"\"Chat.\n\n        Args:\n            prompt (str): prompt\n            session (Session): the chat session\n            gen_config (GenerationConfig | None): a instance of\n                GenerationConfig. 
Default to None.\n            stream_response (bool): whether to stream the response.\n            adapter_name (str): adapter name.\n            **kwargs (dict): additional keyword arguments.\n        \"\"\"\n        if session is None:\n            session = self.session_mgr.get()\n        session.update(prompt=prompt, response=None)\n\n        prompt = MultimodalProcessor.format_prompts(prompt)\n\n        sequence_start = session.step == 0\n        generator = self.stream_infer(prompts=prompt,\n                                      sessions=session,\n                                      gen_config=gen_config,\n                                      stream_response=stream_response,\n                                      adapter_name=adapter_name,\n                                      multiplex=True,\n                                      sequence_start=sequence_start,\n                                      sequence_end=False,\n                                      step=session.step,\n                                      **kwargs)\n\n        def _gen():\n            resp = None\n            try:\n                for out in generator:\n                    resp = resp.extend(out) if resp else out\n                    yield out\n            except:  # noqa\n                self._run(coro=session.async_abort())\n                raise\n            else:\n                session.response = resp\n                session.step += resp.generate_token_len + resp.input_token_len\n                session.history.append((session.prompt, resp.text))\n\n        if stream_response:\n            return _gen()\n        else:\n            # run the generator until finish\n            with closing(_gen()) as gen:\n                for _ in gen:\n                    pass\n            session.generator = None\n\n        return session\n\n    def session(self) -> 'Session':\n        \"\"\"Create a new session.\"\"\"\n        return self.session_mgr.get()\n\n    def 
get_reward_score(self, input_ids: List) -> List[float]:\n        \"\"\"Get reward score.\n\n        Args:\n            input_ids(List): a list of token_id or a list of token_id list or token_id tensor\n        Return:\n            reward score in a list. If the input_ids is a list of token_id, the return value\n            is still a list with length 1.\n        \"\"\"\n        supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel']\n        arch = self.async_engine.arch\n        if arch not in supported_reward_models:\n            raise ValueError(f'{arch} is not in reward model list: {supported_reward_models}')\n        assert isinstance(input_ids, List)\n        assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)\n        # Make input_ids a list of token_id list\n        input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids\n        logits = self._run(coro=self.async_engine.async_get_logits(input_ids=input_ids)).result()\n        logits = [x.squeeze() for x in logits]\n        scores = [x[-1].cpu().item() for x in logits]\n        return scores\n\n    def get_ppl(self, input_ids: List[int] | List[List[int]]) -> List[float]:\n        \"\"\"Get perplexity scores given a list of input tokens that have to be\n        of the same length.\n\n        Args:\n            input_ids (List[int] | List[List[int]]): the batch of input token ids\n\n        Returns:\n            List[float]: A list of perplexity scores.\n        \"\"\"\n        assert isinstance(input_ids, List)\n        if isinstance(input_ids[0], int):\n            input_ids = [input_ids]\n        assert all(len(_) > 1 for _ in input_ids)\n\n        # TODO: a better way to determine `max_input_len`, at most allocate\n        # 2G mem for logits with shape [bs, max_input_len, vocab_size]\n        vocab_size = self.async_engine.hf_cfg.vocab_size\n        max_input_len = 2 * 1024**3 // (vocab_size * 4)\n        sizes = [len(_) 
for _ in input_ids]\n        result = []\n        sorted_index_values = sorted(list(enumerate(sizes)), key=lambda x: x[1], reverse=True)\n        sizes = [value for index, value in sorted_index_values]\n        indices = [index for index, value in sorted_index_values]\n        logger.info(f'sorted sizes: {sizes}')\n        logger.info(f'sorted indices: {indices}')\n        for (start, end) in self._batch_iterator(sizes, max_input_len):\n            logger.info(f'start: {start}, end: {end}')\n            if start == end:\n                _input_ids = input_ids[indices[start]]\n                session = self.session_mgr.get()\n                res = self._get_long_text_ppl(session, input_ids=_input_ids, max_input_len=max_input_len)\n                result.append(res)\n                self.session_mgr.remove(session)\n            else:\n                _input_ids = [input_ids[indices[i]] for i in range(start, end)]\n                sessions = [self.session_mgr.get() for _ in range(start, end)]\n                res = self._get_ppl(\n                    sessions=sessions,\n                    input_ids=_input_ids,\n                    max_input_len=max_input_len,\n                )\n                result.extend(res)\n                for session in sessions:\n                    self.session_mgr.remove(session)\n        output = list(range(len(result)))\n        for index, sorted_index in enumerate(indices):\n            output[sorted_index] = result[index]\n        return output\n\n    def __call__(self,\n                 prompts: List[str] | str | List[Dict] | List[List[Dict]],\n                 gen_config: GenerationConfig | List[GenerationConfig] | None = None,\n                 **kwargs):\n        return self.infer(prompts, gen_config=gen_config, **kwargs)\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.close()\n\n    @deprecated('This method is deprecated. 
Please use \"AsyncEngine.generate\" instead.')\n    async def generate(self, *args, **kwargs):\n        \"\"\"Generate responses as an async generator.\n\n        This method delegates to async_engine.generate and forwards all yielded values.\n        \"\"\"\n        async for item in self.async_engine.generate(*args, **kwargs):\n            yield item\n\n    @staticmethod\n    def _is_single(prompts):\n        \"\"\"Check if prompts is a single prompt.\"\"\"\n        return (isinstance(prompts, str) or (isinstance(prompts, tuple) and len(prompts) == 2)\n                or (isinstance(prompts, list) and len(prompts) > 0 and isinstance(prompts[0], Dict)))\n\n    def _request_generator(self,\n                           prompts: List[str] | str | List[Dict] | List[List[Dict]],\n                           sessions: List['Session'] | 'Session' | None = None,\n                           gen_config: GenerationConfig | List[GenerationConfig] | None = None,\n                           **kwargs):\n        \"\"\"Generate requests.\"\"\"\n        is_single = self._is_single(prompts)\n        prompts = [prompts] if is_single else prompts\n\n        if sessions is None:\n            sessions = [self.session_mgr.get() for _ in prompts]\n        elif isinstance(sessions, list):\n            sessions = sessions\n        else:\n            sessions = [sessions]\n\n        if len(prompts) != len(sessions):\n            raise ValueError(f'prompts and sessions should have the same length. '\n                             f'Got {len(prompts)} prompts and {len(sessions)} sessions')\n\n        if gen_config is None:\n            gen_configs = [GenerationConfig()] * len(prompts)\n        elif isinstance(gen_config, list):\n            gen_configs = gen_config\n        else:\n            gen_configs = [gen_config] * len(prompts)\n\n        if len(prompts) != len(gen_configs):\n            raise ValueError(f'input gen_config length differs from the length of prompts. 
'\n                             f'Got {len(prompts)} prompts and {len(gen_configs)} gen_configs')\n\n        for prompt, gen_cfg, session in zip(prompts, gen_configs, sessions):\n            # Use session_id is for backward compatibility. We will remove it in the future.\n            # Since AsyncEngine.generate defines session_id in the argument lists, here we\n            # use session_id to pass the session to the AsyncEngine.generate. It's\n            yield dict(session_id=session, messages=prompt, gen_config=gen_cfg, **kwargs)\n\n    def _get_limiter(self):\n        if not self.limiter:\n            self.limiter = asyncio.Semaphore(self.backend_config.max_batch_size)\n        return self.limiter\n\n    def _infer(self, requests: Iterator[Dict], multiplex: bool, pbar=None, loop=None) -> Iterator[Iterator[Response]]:\n\n        async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):\n            async for out in g:\n                que.put(out.to_response(idx))\n            sem.release()\n            if not multiplex:\n                que.put(None)  # sentinel of inner generator\n            if pbar:\n                pbar.update(1)\n\n        que = Queue()\n\n        async def _infer():\n            sem = self._get_limiter()\n            tasks = []\n            for idx, req in enumerate(requests):\n                await sem.acquire()\n                gen = self.async_engine.generate(**req)\n                dst = que if multiplex else Queue()\n                if not multiplex:\n                    que.put(iter(dst.get, None))\n                # create a task to send the responses\n                task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))\n                tasks.append(task)\n            if not multiplex:  # sentinel of outer generator\n                que.put(None)\n            await asyncio.gather(*tasks)\n            if multiplex:\n                que.put(None)  # sentinel of inner generator\n\n        loop = loop or 
self.internal_thread.loop\n        # submit the coroutine to the event loop running in the internal thread\n        asyncio.run_coroutine_threadsafe(_infer(),\n                                         loop).add_done_callback(lambda f: None if f.cancelled() else f.result())\n\n        return iter(que.get, None)\n\n    def _run(self, fn=None, coro=None):\n        assert (fn or coro) and not (fn and coro)\n        loop = self.internal_thread.loop\n        if fn:\n\n            async def _coro():\n                return fn()\n\n            coro = _coro()\n        return asyncio.run_coroutine_threadsafe(coro, loop)\n\n    def _batch_iterator(self, sizes, max_value):\n        \"\"\"Return an iterator over intervals (start, end) of a list sorted in\n        descending order. Each interval is the longest one starting at start whose\n        \"sum of values\", defined as $$len(sizes[start:end]) * sizes[start]$$,\n        does not exceed max_value. If a single value already exceeds max_value,\n        the degenerate interval (start, start) is yielded and the caller handles\n        that item on its own.\n        \"\"\"\n        i = 0\n        while i < len(sizes):\n            current_sum = 0\n            start_index = i\n\n            while i < len(sizes) and current_sum + sizes[start_index] <= max_value:\n                current_sum += sizes[start_index]\n                i += 1\n\n            yield (start_index, i)\n            if i > start_index:\n                continue\n            else:\n                i += 1\n\n    def _get_long_text_ppl(self, session, input_ids, max_input_len):\n        assert all(isinstance(_, int) for _ in input_ids)\n        seq_len = len(input_ids)\n        assert seq_len > max_input_len\n        logger.info(f'get long text ppl: seq_len {seq_len}')\n\n        losses = []\n        target_counts = []\n        for i in range(0, seq_len, max_input_len):\n            token_ids = input_ids[i:i + max_input_len]\n            session.update(step=i)\n            # targets are the token ids shifted left by one position\n            target_ids = input_ids[i + 1:i + 1 + max_input_len]\n            loss = self._get_ppl(sessions=[session],\n                 
                input_ids=[token_ids],\n                                 max_input_len=len(token_ids),\n                                 target_ids=[target_ids],\n                                 sequence_start=(i == 0),\n                                 sequence_end=False)\n            losses.extend(loss)\n            target_counts.append(len(target_ids))\n        losses = [loss * target_count for loss, target_count in zip(losses, target_counts)]\n        loss_sum = sum(losses)\n        target_count = sum(target_counts)\n        return loss_sum / target_count\n\n    def _get_ppl(self,\n                 sessions: List['Session'],\n                 input_ids: List[List[int]],\n                 max_input_len: int,\n                 target_ids=None,\n                 sequence_start: bool = True,\n                 sequence_end: bool = True):\n        assert (isinstance(input_ids, List) and all(isinstance(_, List) for _ in input_ids))\n        assert target_ids is None or len(target_ids) == len(input_ids)\n        assert len(sessions) == len(input_ids)\n\n        lens = [len(_) for _ in input_ids]\n        total_len = sum(lens)\n        assert sum(lens) <= max_input_len\n\n        logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '\n                    f'total_len: {total_len}')\n        torch.cuda.empty_cache()\n\n        logits = self._run(coro=self.async_engine.async_get_logits(\n            input_ids=input_ids, sessions=sessions, sequence_start=sequence_start, sequence_end=sequence_end)).result()\n        padding_token_id = -100\n        if target_ids is None:\n            target_ids = [x[1:] + [padding_token_id] for x in input_ids]\n        else:\n            target_ids = [\n                target_ids[i] + [padding_token_id] if len(target_ids[i]) < len(input_ids[i]) else target_ids[i]\n                for i in range(len(input_ids))\n            ]\n        target_ids = [torch.Tensor(torch.LongTensor(_target_ids)) for _target_ids in target_ids]\n\n        
result = []\n        for _logits, _target_ids in zip(logits, target_ids):\n            _logits = _logits.float()\n            vocab_size = _logits.shape[-1]\n            _target_ids = _target_ids.to(_logits.device)\n            target_mask = _target_ids != padding_token_id\n            # compute cross entropy loss\n            flat_logits = _logits.contiguous().view(-1, vocab_size)\n            flat_target_ids = _target_ids.contiguous().view(-1)\n            flat_loss_matrix = torch.nn.functional.cross_entropy(flat_logits,\n                                                                 flat_target_ids,\n                                                                 reduction='none',\n                                                                 ignore_index=padding_token_id)\n            loss = flat_loss_matrix.sum()\n            target_count = target_mask.sum()\n            result.append(loss.item() / target_count.item())\n        logger.info(f'ppl result: {result}')\n        return result\n\n\nclass _EventLoopThread:\n\n    def __init__(self, daemon=False):\n        fut = concurrent.futures.Future()\n        self.thread = Thread(target=partial(self._thread_entry, fut), daemon=daemon)\n        self.thread.start()\n        self.loop: asyncio.AbstractEventLoop = fut.result()\n        self.closed = False\n        if daemon:\n            atexit.register(self.close)\n\n    def _thread_entry(self, fut):\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        fut.set_result(loop)\n        try:\n            loop.run_forever()\n        except BaseException as e:\n            logger.error(f'[internal_thread] {type(e).__name__} {e}')\n        finally:\n            try:\n                self._cancel_all_tasks()\n                loop.run_until_complete(loop.shutdown_asyncgens())\n            finally:\n                asyncio.set_event_loop(None)\n                loop.close()\n\n    def _cancel_all_tasks(self):\n        \"\"\"Modified from 
asyncio/runners.py.\"\"\"\n        to_cancel = asyncio.all_tasks(self.loop)\n        if not to_cancel:\n            return\n\n        for task in to_cancel:\n            task.cancel()\n\n        async def _gather():\n            await asyncio.gather(*to_cancel, return_exceptions=True)\n\n        self.loop.run_until_complete(_gather())\n\n        for task in to_cancel:\n            if task.cancelled():\n                continue\n            if task.exception() is not None:\n                self.loop.call_exception_handler({\n                    'message': 'unhandled exception during worker thread shutdown',\n                    'exception': task.exception(),\n                    'task': task,\n                })\n\n    def close(self):\n        if self.closed:\n            return\n        self.closed = True\n        self.loop.call_soon_threadsafe(self.loop.stop)\n        self.thread.join()\n"
  },
  {
    "path": "lmdeploy/profiler.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport csv\nimport os\nimport time\nfrom typing import List\n\nimport numpy as np\n\n\nclass Session:\n\n    UNKNOWN = 0\n    SUCCESS = 1\n    FAIL = 2\n\n    def __init__(self, input_len, req_output_len):\n        self.ts = []\n        self.ns = []\n        self.input_len = input_len\n        self.req_output_len = req_output_len\n        self.status = Session.UNKNOWN\n\n    def tick(self, n_token):\n        self.ts.append(time.perf_counter())\n        self.ns.append(n_token)\n\n    def finish(self, status):\n        self.status = status\n\n\nclass Profiler:\n\n    def __init__(self, stream_output: bool, percentages: List[int]):\n        self.sessions: List[Session] = []\n        self.stream_output = stream_output\n        self.percentages = percentages\n\n    def new_session(self, *args, **kwargs):\n        sess = Session(*args, **kwargs)\n        self.sessions.append(sess)\n        return sess\n\n    def start(self):\n        self.t_start = time.perf_counter()\n\n    def finish(self):\n        self.elapsed_time = time.perf_counter() - self.t_start\n\n    def compute_metrics(self):\n        self.ttfts: List[float] = []\n        self.tpots: List[float] = []\n        self.e2es: List[float] = []\n        self.itls: List[float] = []\n        self.tpts: List[int] = []\n        self.total_output = 0\n        self.total_input = 0\n        self.success = 0\n\n        for sess in self.sessions:\n            if sess.status != Session.SUCCESS:\n                continue\n            ns = sess.ns\n            ts = sess.ts\n            if ns[-1] < sess.req_output_len:\n                continue\n            self.success += 1\n            self.total_output += ns[-1]\n            self.total_input += sess.input_len\n            self.e2es.append(ts[-1] - ts[0])\n            self.ttfts.append(ts[1] - ts[0])\n            if ns[-1] > ns[1]:\n                self.tpots.append((ts[-1] - ts[1]) / (ns[-1] - ns[1]))\n          
  else:  # no-stream-output\n                self.tpots.append((ts[-1] - ts[0]) / (ns[-1] - ns[0]))\n            t_dif = np.subtract(ts[1:], ts[:-1])\n            n_dif = np.subtract(ns[1:], ns[:-1])\n            self.itls.extend(t_dif[1:])\n            self.tpts.extend(n_dif)\n\n        self.output_throughput = self.total_output / self.elapsed_time\n        self.input_throughput = self.total_input / self.elapsed_time\n\n        qs = self.percentages\n\n        self.e2es = self.e2es or [float('inf')]\n        self.tpots = self.tpots or [float('inf')]\n        self.ttfts = self.ttfts or [float('inf')]\n        self.itls = self.itls or [float('inf')]\n        self.tpts = self.tpts or [0]\n\n        self.tpot_mean = np.mean(self.tpots)\n        self.tpot_stat = tuple(np.percentile(self.tpots, qs))\n        self.e2e_mean = np.mean(self.e2es)\n        self.e2e_stat = tuple(np.percentile(self.e2es, qs))\n\n        if self.stream_output:\n            self.ttft_mean = np.mean(self.ttfts)\n            self.ttft_stat = tuple(np.percentile(self.ttfts, qs))\n            self.itls_mean = np.mean(self.itls)\n            self.itls_stat = tuple(np.percentile(self.itls, qs))\n            self.tpts_mean = np.mean(self.tpts)\n            self.tpts_stat = tuple(np.percentile(self.tpts, qs).astype(int))\n\n        self.rps = self.success / self.elapsed_time\n\n    def summarize(self, title: str, hyperparams: List = None, header=40, digits=10):\n\n        width = header + digits * (1 + len(self.percentages))\n\n        def tab_row(name, *items):\n\n            def fmt(x):\n                return '{:>{d}.3f}'.format(x, d=digits) if isinstance(x, float) else '{:>{d}}'.format(x, d=digits)\n\n            print('{:<{p}}{}'.format(name, ''.join([fmt(x) for x in items]), p=header))\n\n        print('\\n{s:{c}^{n}}'.format(s=f' {title} ', n=width, c='='))\n        tab_row('Benchmark duration', self.elapsed_time)\n        tab_row('Total requests', len(self.sessions))\n        tab_row('Successful 
requests', self.success)\n        if hyperparams:\n            for k, v in hyperparams:\n                tab_row(k, v)\n        tab_row('Total input tokens', self.total_input)\n        tab_row('Total generated tokens', self.total_output)\n        tab_row('Input throughput (tok/s)', self.input_throughput)\n        tab_row('Output throughput (tok/s)', self.output_throughput)\n        tab_row('Request throughput (req/s)', self.rps)\n        print('-' * width)\n        tab_row('', 'mean', *(f'P{q}' for q in self.percentages))\n        tab_row('End-to-end Latency', self.e2e_mean, *self.e2e_stat)\n        if self.stream_output:\n            tab_row('Time to First Token (TTFT)', self.ttft_mean, *self.ttft_stat)\n        tab_row('Time per Output Token (TPOT)', self.tpot_mean, *self.tpot_stat)\n        if self.stream_output:\n            tab_row('Inter-token Latency (ITL)', self.itls_mean, *self.itls_stat)\n            tab_row('Tokens per Tick', self.tpts_mean, *self.tpts_stat)\n        print('=' * width)\n\n    def save_csv(self, csv_file: str, hyperparams):\n        \"\"\"Export legacy metrics to CSV.\"\"\"\n        file_exists = os.path.isfile(csv_file)\n        with open(csv_file, mode='a', newline='') as csvfile:\n            writer = csv.writer(csvfile)\n            keys, vals = zip(*hyperparams)\n            if not file_exists:\n                writer.writerow([\n                    *keys,\n                    'completed',\n                    'total_input_tokens',\n                    'total_output_tokens',\n                    'duration',\n                    'request_throughput',\n                    'input_throughput',\n                    'output_throughput',\n                    'mean_e2e_latency_ms',\n                    'mean_ttft_ms',\n                    'mean_tpot_ms',\n                    'mean_itl_ms',\n                ])\n            writer.writerow([\n                *vals,\n                self.success,\n                self.total_input,\n             
   self.total_output,\n                self.elapsed_time,\n                f'{self.rps:.3f}',\n                f'{(self.input_throughput):.3f}',\n                f'{self.output_throughput:.3f}',\n                f'{self.e2e_mean*1000:.3f}',\n                f'{self.ttft_mean*1000:.3f}' if self.stream_output else '-',\n                f'{self.tpot_mean*1000:.3f}',\n                f'{self.itls_mean*1000:.3f}' if self.stream_output else '-',\n            ])\n"
  },
  {
    "path": "lmdeploy/pytorch/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/adapter/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/adapter/adapter.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport re\nfrom typing import Dict, Iterable, List, Tuple\n\nimport torch\nfrom torch import nn\n\n\ndef get_ranks_and_scalings(target_name: str, cfgs: Iterable, device: torch.device = None):\n    \"\"\"Get ranks and scalings.\"\"\"\n    ranks = []\n    scalings = []\n    for cfg in cfgs:\n        if target_name not in cfg.target_modules:\n            ranks.append(0)\n            scalings.append(1)\n            continue\n        ranks.append(cfg.r)\n        scalings.append(float(cfg.lora_alpha / cfg.r))\n    ranks = torch.tensor(ranks, device=device)\n    scalings = torch.tensor(scalings, device=device)\n    return ranks, scalings\n\n\ndef find_all_target(model: torch.nn.Module, target_name: str):\n    \"\"\"Find all targets.\"\"\"\n    # find packed name\n    packed_name = target_name\n    pack_idx = None\n    packed_modules_mapping = getattr(model, 'packed_modules_mapping', dict())\n    for name, sub_names in packed_modules_mapping.items():\n        if target_name in sub_names:\n            pack_idx = sub_names.index(target_name)\n            packed_name = name\n            break\n\n    found_mods = []\n    name_postfix = f'.{packed_name}'\n    for name, mod in model.named_modules():\n        if not name.endswith(name_postfix):\n            continue\n        found_mods.append((name, mod))\n\n    return found_mods, pack_idx\n\n\ndef get_layer_index(key: str, layers_pattern: str = None):\n    \"\"\"Get layer index of the lora linear.\"\"\"\n    if isinstance(layers_pattern, str):\n        layers_pattern = [layers_pattern]\n    if layers_pattern is None or len(layers_pattern) == 0:\n        layer_index = re.match(r'.*\\.[^.]*\\.(\\d+)\\.', key)\n        return int(layer_index[1])\n    else:\n        for pattern in layers_pattern:\n            layer_index = re.match(f'.*.{pattern}\\\\.(\\\\d+)\\\\.*', key)\n\n            if layer_index is not None:\n                return int(layer_index[1])\n\n\ndef 
_get_reverse_pack_map(model: nn.Module):\n    \"\"\"Get reverse pack map.\"\"\"\n    packed_modules_mapping = getattr(model, 'packed_modules_mapping', dict())\n    reverse_map = dict()\n    for pack_name, names in packed_modules_mapping.items():\n        for name in names:\n            reverse_map[name] = pack_name\n    return reverse_map\n\n\ndef _get_key_map(reverse_map: Dict[str, str]):\n    \"\"\"Get key map.\"\"\"\n    key_map = dict()\n    for name, pack_name in reverse_map.items():\n        key = f'.{name}'\n        val = f'.{pack_name}.lora_adapters.{name}'\n        key_map[key] = val\n\n    return key_map\n\n\ndef load_lora_weights(model: nn.Module, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):\n    \"\"\"Load lora weights.\"\"\"\n    from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n    prefix_len = len('base_model.model.')\n    w_len = len('.weight')\n    reverse_map = _get_reverse_pack_map(model)\n    key_map = _get_key_map(reverse_map)\n\n    params_dict = dict(model.named_parameters())\n    for name, loaded_weight in weights:\n        name = name[prefix_len:]\n        splited_name = name.split('.')\n        assert splited_name[-1] == 'weight'\n        assert splited_name[-2] in ['lora_A', 'lora_B']\n        mod_name = splited_name[-3]\n        dot_mod_name = f'.{mod_name}'\n        if dot_mod_name in key_map:\n            replace_name = key_map[dot_mod_name]\n        else:\n            replace_name = f'.{mod_name}.lora_adapters.{mod_name}'\n        name = name[:-w_len]\n        param_name = name.replace(dot_mod_name, replace_name)\n\n        param = params_dict[param_name]\n        load_weight(param, loaded_weight, adapter_id=adapter_id)\n\n\nclass AdapterManager:\n    \"\"\"Adapter manager.\"\"\"\n\n    def __init__(self, adapters: Dict[str, str]):\n        if adapters is None:\n            adapters = dict()\n\n        adapter_names = list(adapters.keys())\n        adapter_names = sorted(adapter_names)\n     
   adapter_names = [None] + adapter_names\n\n        adapter_id_map = dict(zip(adapter_names, range(len(adapter_names))))\n        self.adapter_id_map = adapter_id_map\n\n    def get_adapter_ids(self, names: List[str]):\n        return [self.adapter_id_map[name] for name in names]\n\n    def num_adapters(self):\n        return len(self.adapter_id_map)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .base import OpType  # noqa: F401\nfrom .selector import get_backend  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/activation.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\n\nclass SiluAndMulImpl(ABC):\n    \"\"\"Fused SiLU-and-multiply implementation.\"\"\"\n\n    @abstractmethod\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass SiluAndMulBuilder(ABC):\n    \"\"\"Fused SiLU-and-multiply implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(inplace: bool = False):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n\n\nclass GeluAndMulImpl(ABC):\n    \"\"\"Fused GELU-and-multiply implementation.\"\"\"\n\n    @abstractmethod\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass GeluAndMulBuilder(ABC):\n    \"\"\"Fused GELU-and-multiply implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(approximate: str = 'none'):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/apply_rotary_emb.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\nfrom torch import Tensor\n\n\nclass ApplyRotaryEmbImpl(ABC):\n    \"\"\"Apply rotary embedding implementation.\"\"\"\n\n    @abstractmethod\n    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass ApplyRotaryEmbBuilder(ABC):\n    \"\"\"Apply rotary embedding implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build():\n        \"\"\"Build implementation.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/attention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Generic, Literal, TypeVar\n\nimport torch\n\n\n@dataclass\nclass AttentionMetadata:\n    \"\"\"Base Attention metadata.\"\"\"\n    is_decoding: bool\n    block_offsets: torch.Tensor\n    q_start_loc: torch.Tensor = None\n    q_seqlens: torch.Tensor = None\n    kv_seqlens: torch.Tensor = None\n    fill_seqlens: torch.Tensor = None\n    cu_seqlens_q: torch.Tensor = None\n    cu_seqlens_k: torch.Tensor = None\n    quant_policy: Literal[0, 4, 8] = 0\n\n\nT = TypeVar('T', bound=AttentionMetadata)\n\n\nclass AttentionImpl(ABC, Generic[T]):\n    \"\"\"Attention implementation.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi: bool = None,\n        sliding_window: int = None,\n        logit_softcapping: float = 0.0,\n        causal: bool = True,\n        use_flash_mla: bool = False,\n        **kwargs,\n    ) -> None:\n        if scale is None:\n            scale = 1.0 / (head_size**0.5)\n\n        if num_kv_heads is None:\n            num_kv_heads = num_heads\n\n        if v_head_size is None:\n            v_head_size = head_size\n\n        self.num_heads = num_heads\n        self.head_size = head_size\n        self.scale = scale\n        self.num_kv_heads = num_kv_heads\n        self.v_head_size = v_head_size\n        self.alibi = alibi\n        self.sliding_window = sliding_window\n        self.logit_softcapping = logit_softcapping\n        self.causal = causal\n        self.use_flash_mla = use_flash_mla\n        self.alibi_slopes = None\n\n    @staticmethod\n    @lru_cache(maxsize=4)\n    def make_alibi_slopes(head_start: int, head_end: int, num_heads: int, alibi_scale: float, dtype: torch.dtype,\n                          device: 
torch.device):\n        \"\"\"Make alibi slopes.\"\"\"\n        head_ids = torch.arange(head_start, head_end, dtype=dtype, device=device)\n        num_heads_tensor = head_ids.new_full([1], num_heads)\n        num_heads_p2 = num_heads_tensor.log2().to(torch.int64).exp2()\n\n        # update head_ids and closest_power_of_2\n        mask = head_ids < num_heads_p2\n        head_ids = torch.where(mask, head_ids, (head_ids - num_heads_p2) * 2)\n        closest_power_of_2 = torch.where(mask, num_heads_p2, num_heads_p2 * 2)\n\n        # get slope\n        start = torch.sub(3, closest_power_of_2.log2()).exp2().neg()\n        start = start.exp2()\n        ratio = start\n        return start * torch.pow(ratio, head_ids) * alibi_scale\n\n    def set_alibi_slopes(self, slopes: torch.Tensor):\n        self.alibi_slopes = slopes\n\n    @abstractmethod\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: T,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n        learnable_sink: torch.Tensor = None,\n        nsa_indices: torch.Tensor = None,\n        inplace: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass AttentionBuilder(ABC, Generic[T]):\n    \"\"\"Attention implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi: bool = False,\n        sliding_window: int = None,\n        logit_softcapping: float = 0.0,\n        causal: bool = True,\n        use_flash_mla: bool = False,\n        learnable_sink: bool = False,\n        block_sparse_size: int = 1,\n        **kwargs,\n    ) -> AttentionImpl[T]:\n        \"\"\"build.\"\"\"\n     
   raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/awq_modules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import Optional\n\nimport torch\n\n\nclass LinearW4A16Impl(ABC):\n    \"\"\"W4a16 linear implementation.\"\"\"\n\n    def update_weights(self,\n                       qweight: torch.Tensor,\n                       scales: torch.Tensor,\n                       qzeros: torch.Tensor,\n                       bias: Optional[torch.Tensor] = None):\n        \"\"\"Update weights.\"\"\"\n        return qweight, scales, qzeros, bias\n\n    @abstractmethod\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[torch.distributed.ProcessGroup] = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass LinearW4A16Builder(ABC):\n    \"\"\"W4a16 linear implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(in_features: int,\n              out_features: int,\n              w_bit: int,\n              group_size: int,\n              bias: bool = False,\n              dtype: torch.dtype = None):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modify from:\n# https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/abstract.py\nfrom abc import ABC, abstractmethod\nfrom enum import Enum, auto\nfrom typing import Tuple\n\nimport torch\n\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig\n\n\nclass OpType(Enum):\n    \"\"\"Layer type enumerate.\"\"\"\n    PagedAttention = auto()\n    FlashAttention = auto()\n    Linear = auto()\n    RotaryEmbedding = auto()\n    ApplyRotaryEmb = auto()\n    SiluAndMul = auto()\n    GeluAndMul = auto()\n    RMSNorm = auto()\n    LayerNorm = auto()\n    LoRA = auto()\n    LinearW8A8 = auto()\n    RMSNormW8A8 = auto()\n    MultinomialSampling = auto()\n    LinearW4A16 = auto()\n    SoftmaxTopK = auto()\n    FusedMoE = auto()\n    FusedMoEW8A8 = auto()\n    LinearBlockedF8 = auto()\n    FusedMoEBlockedF8 = auto()\n    NSAIndexFP8 = auto()\n    Embedding = auto()\n\n    # MoE router\n    RouterNoauxTC = auto()\n\n    # Gated Delta\n    CausalConv1d = auto()\n    GatedDeltaRule = auto()\n\n\nclass OpsBackend(ABC):\n    \"\"\"Layer backend abstract.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def get_name() -> str:\n        \"\"\"Get backend name.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def get_layer_impl_builder(cls, layer_type: OpType):\n        \"\"\"Get builder of given layer type.\"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    @abstractmethod\n    def get_attention_metadata_cls():\n        \"\"\"Get attention metadata class.\"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    @abstractmethod\n    def get_k_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        \"\"\"Get block shape of k.\"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    @abstractmethod\n    def 
get_v_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        \"\"\"Get block shape of v.\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def update_step_context(cls, step_context):\n        \"\"\"Update StepContext for inference.\n\n        Attention metadata should be built here.\n        \"\"\"\n        return step_context\n\n    @staticmethod\n    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,\n                           backend_config: BackendConfig, device: torch.device):\n        \"\"\"Build graph runner.\"\"\"\n        from .graph_runner import GraphRunner\n        return GraphRunner(model, model_config, cache_config, backend_config, device)\n\n    @staticmethod\n    def device_count():\n        \"\"\"Get the number of available devices.\"\"\"\n        return None\n\n    @staticmethod\n    def support_ray():\n        \"\"\"Whether the backend supports Ray.\"\"\"\n        return False\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/blockedf8_modules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import List, Optional\n\nimport torch\nimport torch.distributed as dist\n\n\nclass LinearBlockedF8Impl(ABC):\n    \"\"\"Linear BlockedF8 implementation api.\"\"\"\n\n    def __init__(self):\n        self.scale_fmt: Optional[str] = None\n\n    def update_weights(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None):\n        \"\"\"Update weights.\"\"\"\n        return weight, scale, bias\n\n    def set_scale_fmt(self, scale_fmt: Optional[str]):\n        \"\"\"Set scale fmt.\"\"\"\n        self.scale_fmt = scale_fmt\n\n    @abstractmethod\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                scale: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[dist.ProcessGroup] = None,\n                rank: int = 0,\n                scatter_size: Optional[List[int]] = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass LinearBlockedF8Builder(ABC):\n    \"\"\"Linear BlockedF8 implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(in_features: int, out_features: int, bias: bool = True, dtype: Optional[torch.dtype] = None):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/causal_conv1d.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\nimport torch\n\n\nclass CausalConv1dImpl(ABC):\n    \"\"\"CausalConv1d implementation api.\"\"\"\n\n    @abstractmethod\n    def conv1d_fn(self,\n                  x: torch.Tensor,\n                  weight: torch.Tensor,\n                  bias: torch.Tensor | None = None,\n                  seq_idx: torch.Tensor | None = None,\n                  return_final_states: bool = False,\n                  activation: str | None = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def update_fn(self,\n                  x: torch.Tensor,\n                  conv_state: torch.Tensor,\n                  weight: torch.Tensor,\n                  bias: torch.Tensor | None = None,\n                  activation: str | None = None,\n                  conv_state_indices: torch.Tensor | None = None):\n        \"\"\"Update conv state.\"\"\"\n        raise NotImplementedError\n\n\nclass CausalConv1dBuilder(ABC):\n    \"\"\"CausalConv1d implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build():\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .op_backend import CudaOpsBackend  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/activation.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul\n\nfrom ..activation import SiluAndMulBuilder, SiluAndMulImpl\n\n\nclass TritonSiluAndMulImpl(SiluAndMulImpl):\n    \"\"\"Fused SiLU and multiply implementation.\"\"\"\n\n    def __init__(self, inplace: bool):\n        self.inplace = inplace\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        out = None\n        x_shape = None\n        if x.dim() != 2:\n            x_shape = x.shape\n            x = x.flatten(0, -2)\n        if self.inplace:\n            out = x.chunk(2, -1)[0]\n\n        out = silu_and_mul(x, out)\n\n        if x_shape is not None:\n            out = out.unflatten(0, x_shape[:-1])\n        return out\n\n\nclass TritonSiluAndMulBuilder(SiluAndMulBuilder):\n    \"\"\"Silu and mul implementation builder.\"\"\"\n\n    @staticmethod\n    def build(inplace: bool = False):\n        \"\"\"build.\"\"\"\n        return TritonSiluAndMulImpl(inplace)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/apply_rotary_emb.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nfrom torch import Tensor\n\nfrom lmdeploy.pytorch.kernels.cuda import apply_rotary_pos_emb\n\nfrom ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl\n\n\nclass TritonApplyRotaryEmbImpl(ApplyRotaryEmbImpl):\n    \"\"\"Apply rotary embedding implementation.\"\"\"\n\n    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):\n        \"\"\"forward.\"\"\"\n        if inplace:\n            q_embed = query\n            k_embed = key\n        else:\n            q_embed = torch.empty_like(query)\n            k_embed = torch.empty_like(key)\n        return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed)\n\n\nclass TritonApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):\n    \"\"\"Apply rotary embedding implementation builder.\"\"\"\n\n    @staticmethod\n    def build():\n        \"\"\"Build implementation.\"\"\"\n        return TritonApplyRotaryEmbImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/attention/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\n\nimport torch\n\nfrom lmdeploy.pytorch.backends.attention import AttentionBuilder\nfrom lmdeploy.utils import get_logger\n\nfrom .default import TritonAttentionImpl, TritonAttentionMetadata\n\nlogger = get_logger('lmdeploy')\n\nuse_fa3 = False\ntry:\n    # Now flash-attention only support FA3 for sm90a && cuda >= 12.3\n    if (torch.cuda.get_device_capability()[0] == 9) and (torch.version.cuda >= '12.3'):\n        import lmdeploy.pytorch.third_party.flash_attn_interface  # noqa: F401\n        assert torch.ops.flash_attn_3 is not None\n        use_fa3 = True\nexcept Exception:\n    logger.debug('For higher performance, please install FlashAttention-3 '\n                 'https://github.com/Dao-AILab/flash-attention')\n\n\n@functools.lru_cache\ndef use_fa3_warning():\n    if use_fa3:\n        return True\n    logger.warning('For higher performance, please install FlashAttention-3 '\n                   'https://github.com/Dao-AILab/flash-attention')\n    return False\n\n\n@functools.lru_cache\ndef _enable_fa3(alibi: bool, learnable_sink: bool, block_sparse_size: int, head_size: int) -> bool:\n    \"\"\"Check if FA3 should be enabled.\n\n    FA3 is enabled when:\n    - No alibi\n    - No learnable sink\n    - block_sparse_size == 1\n    - FA3 is available (checked by use_fa3_warning)\n\n    Returns:\n        True if FA3 should be enabled, False otherwise.\n    \"\"\"\n    enable = not alibi and not learnable_sink and block_sparse_size == 1 and head_size <= 256\n    if enable and not use_fa3_warning():\n        enable = False\n    return enable\n\n\ndef _normalize_sliding_window(sliding_window):\n    \"\"\"Normalize sliding window to tuple format.\n\n    Args:\n        sliding_window: None, int, or tuple of (left, right).\n\n    Returns:\n        Tuple of (left, right) or (-1, -1) if None.\n    \"\"\"\n    if sliding_window is None:\n        return (-1, -1)\n    if 
isinstance(sliding_window, int):\n        return (sliding_window, sliding_window)\n    return sliding_window\n\n\nclass TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]):\n    \"\"\"Triton attention builder.\n\n    This builder selects the appropriate attention implementation based on:\n    1. use_flash_mla: Use FlashMLAImpl for MLA models\n    2. enable_fa3: Use FA3Impl if FA3 is available and supported\n    3. Default: Use TritonAttentionImpl as fallback\n    \"\"\"\n\n    @staticmethod\n    def build(\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi: bool = False,\n        sliding_window: int = None,\n        logit_softcapping: float = 0.0,\n        causal: bool = True,\n        use_flash_mla: bool = False,\n        learnable_sink: bool = False,\n        block_sparse_size: int = 1,\n        **kwargs,\n    ) -> TritonAttentionImpl:\n        \"\"\"Build appropriate attention implementation.\n\n        Args:\n            num_heads: Number of attention heads.\n            head_size: Size of each attention head.\n            scale: Scaling factor for attention scores.\n            num_kv_heads: Number of key-value heads (for GQA).\n            v_head_size: Size of value head (for MLA).\n            alibi: Whether to use ALiBi positional encoding.\n            sliding_window: Sliding window size for local attention.\n            logit_softcapping: Logit softcapping value (for Gemma 2).\n            causal: Whether to use causal attention.\n            use_flash_mla: Whether to use Flash MLA implementation.\n            learnable_sink: Whether to use learnable sink tokens.\n            block_sparse_size: Block sparse attention size.\n            **kwargs: Additional arguments.\n\n        Returns:\n            Appropriate AttentionImpl instance.\n        \"\"\"\n        # Normalize sliding window format\n        sliding_window = 
_normalize_sliding_window(sliding_window)\n\n        # Common arguments for all implementations\n        common_args = dict(\n            num_heads=num_heads,\n            head_size=head_size,\n            scale=scale,\n            num_kv_heads=num_kv_heads,\n            v_head_size=v_head_size,\n            alibi=alibi,\n            sliding_window=sliding_window,\n            logit_softcapping=logit_softcapping,\n            causal=causal,\n            **kwargs,\n        )\n        enable_fa3 = _enable_fa3(alibi, learnable_sink, block_sparse_size, head_size)\n\n        if use_flash_mla is True:\n            logger.debug('Build FlashMLAImpl Attention')\n            from .mla import FlashMLAImpl\n            return FlashMLAImpl(use_fa3=use_fa3, **common_args)\n        elif enable_fa3:\n            logger.debug('Build FA3Impl Attention')\n            from .fa3 import FA3Impl\n            return FA3Impl(**common_args)\n        else:\n            logger.debug('Build TritonAttentionImpl Attention')\n            return TritonAttentionImpl(block_sparse_size=block_sparse_size, **common_args)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/attention/default.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass\nfrom typing import Literal\n\nimport torch\n\nfrom lmdeploy.pytorch.backends.attention import AttentionImpl, AttentionMetadata\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\n@dataclass\nclass TritonAttentionMetadata(AttentionMetadata):\n    \"\"\"Triton attention metadata.\n\n    This dataclass contains all metadata needed for attention computation\n    across different stages (prefill/decoding) and implementations.\n\n    Attributes:\n        is_decoding: True for decoding stage, False for prefill.\n        block_offsets: Block indices for paged KV cache [batch_size, max_blocks].\n        q_start_loc: Start location of each query sequence [batch_size].\n        q_seqlens: Length of each query sequence [batch_size].\n        kv_start_loc: Start location of each KV sequence [batch_size].\n        kv_seqlens: Length of each KV sequence [batch_size].\n        quant_policy: Quantization policy (0=none, 4=int4, 8=int8/fp8).\n        kv_flatten_size: Total size of flattened KV cache.\n        tile_scheduler_metadata: Scheduler metadata for Flash MLA.\n        num_splits: Number of splits for Flash MLA.\n        cu_seqlens_q: Cumulative query sequence lengths [batch_size + 1].\n        cu_seqlens_k: Cumulative KV sequence lengths [batch_size + 1].\n        scheduler_metadata: Scheduler metadata for FA3.\n        max_kv_seqlen: Maximum KV sequence length in the batch.\n        max_q_seqlen: Maximum query sequence length in the batch.\n    \"\"\"\n    is_decoding: bool\n    block_offsets: torch.Tensor\n    q_start_loc: torch.Tensor = None\n    q_seqlens: torch.Tensor = None\n    kv_start_loc: torch.Tensor = None\n    kv_seqlens: torch.Tensor = None\n    quant_policy: Literal[0, 4, 8] = 0\n    kv_flatten_size: int = None\n    # flash mla\n    tile_scheduler_metadata: torch.Tensor = None\n    num_splits: torch.Tensor = None\n    cu_seqlens_q: 
torch.Tensor = None\n    cu_seqlens_k: torch.Tensor = None\n    # flash attn\n    scheduler_metadata: torch.Tensor = None\n    max_kv_seqlen: int = None\n    max_q_seqlen: int = None\n\n\ndef _cdiv(a, b):\n    \"\"\"Perform ceiling division (division rounded up).\n\n    Args:\n        a: Dividend.\n        b: Divisor.\n\n    Returns:\n        Ceiling of a / b.\n    \"\"\"\n    return (a + b - 1) // b\n\n\nclass TritonAttentionImpl(AttentionImpl[TritonAttentionMetadata]):\n    \"\"\"Triton attention implementation.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi: bool = False,\n        sliding_window: int = None,\n        logit_softcapping: float = 0.0,\n        causal: bool = True,\n        block_sparse_size: int = 1,\n        **kwargs,\n    ):\n        super().__init__(\n            num_heads=num_heads,\n            head_size=head_size,\n            scale=scale,\n            num_kv_heads=num_kv_heads,\n            v_head_size=v_head_size,\n            alibi=alibi,\n            sliding_window=sliding_window,\n            logit_softcapping=logit_softcapping,\n            causal=causal,\n            **kwargs,\n        )\n        self.logit_softcapping = -1 if self.logit_softcapping <= 0.0 else self.logit_softcapping\n        assert not (alibi and not causal)\n\n        from lmdeploy.pytorch.kernels.cuda import (fill_kv_cache, flash_attn_varlen_func, flash_attn_with_kvcache,\n                                                   flatten_kv_cache)\n\n        self.fill_kv_cache = fill_kv_cache\n        self.paged_attention_fwd = flash_attn_with_kvcache\n        self.flatten_kv_cache = flatten_kv_cache\n        self.flash_attention_fwd = flash_attn_varlen_func\n\n        self.block_sparse_size = block_sparse_size\n\n    def _get_max_q_seqlen(\n        self,\n        query: torch.Tensor,\n        attn_metadata: 
TritonAttentionMetadata,\n    ) -> int:\n        \"\"\"Get max q seqlen.\"\"\"\n        if attn_metadata.is_decoding:\n            max_q_seqlen = self.block_sparse_size\n        else:\n            if attn_metadata.max_q_seqlen is not None:\n                max_q_seqlen = attn_metadata.max_q_seqlen\n            else:\n                max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))\n        return max_q_seqlen\n\n    def _get_fill_meta(\n        self,\n        key: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        max_q_seqlen: int,\n    ):\n        \"\"\"Get fill meta.\"\"\"\n        fill_seqlens = attn_metadata.q_seqlens\n        fill_max_q_seqlen = max_q_seqlen\n        fill_q_start_loc = attn_metadata.q_start_loc\n        return fill_seqlens, fill_max_q_seqlen, fill_q_start_loc\n\n    def _fill_kv_cache_impl(\n        self,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        max_q_seqlen: int,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n    ):\n        \"\"\"Fill kv cache.\"\"\"\n        kv_seqlens = attn_metadata.kv_seqlens\n        block_offsets = attn_metadata.block_offsets\n        quant_policy = attn_metadata.quant_policy\n\n        # fill seqlen args\n        fill_seqlens, fill_max_q_seqlen, fill_q_start_loc = self._get_fill_meta(\n            key,\n            attn_metadata,\n            max_q_seqlen,\n        )\n\n        # fill kv cache\n        self.fill_kv_cache(\n            key,\n            value,\n            k_cache,\n            v_cache,\n            fill_q_start_loc,\n            fill_seqlens,\n            kv_seq_length=kv_seqlens,\n            max_q_seq_length=fill_max_q_seqlen,\n            block_offsets=block_offsets,\n            k_scales_zeros=k_scales_zeros,\n            v_scales_zeros=v_scales_zeros,\n            
quant_policy=quant_policy,\n        )\n\n    def _forward_decoding(\n        self,\n        query: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        max_q_seqlen: int,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n        learnable_sink: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for decoding stage.\n\n        Args:\n            query: Query tensor.\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata.\n            max_q_seqlen: Maximum query sequence length.\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n            learnable_sink: Learnable sink tokens.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        block_offsets = attn_metadata.block_offsets\n        quant_policy = attn_metadata.quant_policy\n\n        attn_output = self.paged_attention_fwd(\n            query,\n            k_cache,\n            v_cache,\n            cache_seqlens=attn_metadata.kv_seqlens,\n            page_table=block_offsets,\n            cu_seqlens_q=attn_metadata.cu_seqlens_q,\n            max_seqlen_q=max_q_seqlen,\n            softmax_scale=self.scale,\n            softcap=self.logit_softcapping,\n            window_size=self.sliding_window,\n            # custom args\n            sinks=learnable_sink,\n            alibi_slopes=self.alibi_slopes,\n            quant_policy=quant_policy,\n            k_scales_zeros=k_scales_zeros,\n            v_scales_zeros=v_scales_zeros,\n        )\n        return attn_output\n\n    def _forward_prefill(\n        self,\n        query: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        max_q_seqlen: int,\n        k_scales_zeros: 
torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n        learnable_sink: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for prefill stage.\n\n        Args:\n            query: Query tensor.\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata.\n            max_q_seqlen: Maximum query sequence length.\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n            learnable_sink: Learnable sink tokens.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        block_offsets = attn_metadata.block_offsets\n        kv_start_loc = attn_metadata.kv_start_loc\n        kv_seqlens = attn_metadata.kv_seqlens\n        kv_flatten_size = attn_metadata.kv_flatten_size\n        quant_policy = attn_metadata.quant_policy\n\n        # Prepare flattened KV cache\n        BLOCK_BS = k_cache.size(1)\n        # pad one more block to avoid invalid kv visit\n        out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS)\n        kv_layout = 'hsd'  # custom triton kernel requires 'hsd' while fa3 requires 'shd'\n\n        flatten_k, flatten_v = self.flatten_kv_cache(\n            k_cache,\n            v_cache,\n            kv_seqlens,\n            block_offsets,\n            start_loc=kv_start_loc,\n            out_size=out_size,\n            out_dtype=query.dtype,\n            k_scales_zeros=k_scales_zeros,\n            v_scales_zeros=v_scales_zeros,\n            quant_policy=quant_policy,\n            flatten_kv_layout=kv_layout,\n        )\n\n        attn_output = self.flash_attention_fwd(\n            query,\n            flatten_k,\n            flatten_v,\n            cu_seqlens_q=attn_metadata.cu_seqlens_q,\n            cu_seqlens_k=attn_metadata.cu_seqlens_k,\n            max_seqlen_q=max_q_seqlen,\n            max_seqlen_k=attn_metadata.max_kv_seqlen,\n       
     window_size=self.sliding_window,\n            softmax_scale=self.scale,\n            softcap=self.logit_softcapping,\n            causal=self.causal,\n            # custom args\n            sinks=learnable_sink,\n            alibi_slopes=self.alibi_slopes,\n            block_sparse_size=self.block_sparse_size,\n            kv_layout=kv_layout,\n        )\n        return attn_output\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n        learnable_sink: torch.Tensor = None,\n        inplace: bool = True,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for attention computation.\n\n        This method handles both prefill and decoding stages by:\n        1. Computing max query sequence length\n        2. Filling KV cache if new key/value are provided\n        3. 
Dispatching to appropriate stage-specific method\n\n        Args:\n            query: Query tensor.\n            key: Key tensor (None for decoding-only).\n            value: Value tensor (None for decoding-only).\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata containing stage info and indices.\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n            learnable_sink: Learnable sink tokens.\n            inplace: Whether to modify query inplace (unused, kept for compatibility).\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        # Shared preparation\n        max_q_seqlen = self._get_max_q_seqlen(query, attn_metadata)\n\n        # Fill KV cache with new key/value if provided\n        if key is not None and value is not None:\n            self._fill_kv_cache_impl(\n                key,\n                value,\n                k_cache=k_cache,\n                v_cache=v_cache,\n                attn_metadata=attn_metadata,\n                max_q_seqlen=max_q_seqlen,\n                k_scales_zeros=k_scales_zeros,\n                v_scales_zeros=v_scales_zeros,\n            )\n\n        # Validate alibi configuration\n        if self.alibi:\n            assert self.alibi_slopes is not None, 'alibi_slopes is not set.'\n\n        # Dispatch to stage-specific forward method\n        if attn_metadata.is_decoding:\n            return self._forward_decoding(\n                query,\n                k_cache,\n                v_cache,\n                attn_metadata,\n                max_q_seqlen,\n                k_scales_zeros=k_scales_zeros,\n                v_scales_zeros=v_scales_zeros,\n                learnable_sink=learnable_sink,\n            )\n        else:\n            return self._forward_prefill(\n                query,\n                k_cache,\n                v_cache,\n         
       attn_metadata,\n                max_q_seqlen,\n                k_scales_zeros=k_scales_zeros,\n                v_scales_zeros=v_scales_zeros,\n                learnable_sink=learnable_sink,\n            )\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/attention/fa3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom lmdeploy.utils import get_logger\n\nfrom .default import TritonAttentionImpl, TritonAttentionMetadata\n\nlogger = get_logger('lmdeploy')\n\n\nclass FA3Impl(TritonAttentionImpl):\n    \"\"\"Flash Attention 3 implementation.\n\n    This implementation leverages Flash Attention 3's optimized kernels for both\n    prefill and decoding stages. FA3 provides significant performance improvements\n    on Hopper architecture (SM90) with CUDA >= 12.3.\n\n    Key features:\n    - Optimized prefill using flash_attn_varlen_func\n    - Speculative decoding support with multi-token queries\n    - Standard single-token decoding with paged attention\n    \"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi: bool = False,\n        sliding_window: tuple = None,\n        logit_softcapping: float = 0.0,\n        causal: bool = True,\n        **kwargs,\n    ):\n        assert alibi is False, 'alibi not supported for FA3'\n        super().__init__(\n            num_heads=num_heads,\n            head_size=head_size,\n            scale=scale,\n            num_kv_heads=num_kv_heads,\n            v_head_size=v_head_size,\n            alibi=alibi,\n            sliding_window=sliding_window,\n            logit_softcapping=logit_softcapping,\n            causal=causal,\n            **kwargs,\n        )\n        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache\n        self.flash_attn_varlen_func_v3 = flash_attn_varlen_func\n        self.flash_attn_with_kvcache_v3 = flash_attn_with_kvcache\n\n    def _get_max_q_seqlen(\n        self,\n        query: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n    ) -> int:\n        \"\"\"Get max q seqlen.\"\"\"\n        max_q_seqlen = query.numel() 
// (query.size(-1) * query.size(-2))\n        if attn_metadata.is_decoding:\n            batch_size = attn_metadata.q_seqlens.size(0)\n            max_q_seqlen = max_q_seqlen // batch_size\n        return max_q_seqlen\n\n    def _normalize_sliding_window(self, sliding_window):\n        \"\"\"Normalize sliding window to tuple format.\n\n        Args:\n            sliding_window: Sliding window size (None, int, or tuple).\n\n        Returns:\n            Tuple of (left_window, right_window) or (-1, -1) if None.\n        \"\"\"\n        if sliding_window is None:\n            return (-1, -1)\n        if isinstance(sliding_window, int):\n            return (sliding_window, sliding_window)\n        return sliding_window\n\n    def _decoding_speculative(\n        self,\n        query: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        max_q_seqlen: int,\n    ) -> torch.Tensor:\n        \"\"\"Speculative decoding with multi-token queries.\n\n        This path handles speculative decoding where multiple tokens are generated\n        in parallel (max_q_seqlen > 1). 
Uses FA3's flash_attn_with_kvcache for\n        efficient batched computation.\n\n        Args:\n            query: Query tensor to unflatten.\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata.\n            max_q_seqlen: Maximum query sequence length (> 1).\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        block_offsets = attn_metadata.block_offsets\n        sliding_window = self._normalize_sliding_window(self.sliding_window)\n\n        # Reshape query for batched processing\n        query = query.unflatten(0, (-1, max_q_seqlen))\n\n        attn_output = self.flash_attn_with_kvcache_v3(\n            query,\n            k_cache,\n            v_cache,\n            cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),\n            max_seqlen_q=max_q_seqlen,\n            scheduler_metadata=attn_metadata.scheduler_metadata,\n            page_table=block_offsets,\n            softmax_scale=self.scale,\n            causal=self.causal,\n            window_size=sliding_window,\n            softcap=self.logit_softcapping,\n        )\n        return attn_output\n\n    def _decoding_standard(\n        self,\n        query: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        max_q_seqlen: int,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"Standard single-token decoding.\n\n        This path handles standard decoding where only one token is generated\n        per request (max_q_seqlen = 1). 
Uses paged attention for memory efficiency.\n\n        Args:\n            query: Query tensor (single token per request).\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata.\n            max_q_seqlen: Maximum query sequence length (= 1).\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        block_offsets = attn_metadata.block_offsets\n        quant_policy = attn_metadata.quant_policy\n\n        attn_output = self.paged_attention_fwd(\n            query,\n            k_cache,\n            v_cache,\n            cache_seqlens=attn_metadata.kv_seqlens,\n            page_table=block_offsets,\n            cu_seqlens_q=attn_metadata.cu_seqlens_q,\n            max_seqlen_q=max_q_seqlen,\n            scheduler_metadata=attn_metadata.scheduler_metadata,\n            softmax_scale=self.scale,\n            causal=self.causal,\n            softcap=self.logit_softcapping,\n            window_size=self.sliding_window,\n            # custom args\n            k_scales_zeros=k_scales_zeros,\n            v_scales_zeros=v_scales_zeros,\n            quant_policy=quant_policy,\n        )\n        return attn_output\n\n    def _forward_decoding(\n        self,\n        query: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        max_q_seqlen: int,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for decoding stage.\n\n        Supports two decoding modes:\n        1. Speculative decoding: Multiple tokens (max_q_seqlen > 1)\n        2. 
Standard decoding: Single token (max_q_seqlen = 1)\n\n        Args:\n            query: Query tensor.\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata.\n            max_q_seqlen: Maximum query sequence length.\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        if max_q_seqlen > 1:\n            return self._decoding_speculative(query, k_cache, v_cache, attn_metadata, max_q_seqlen)\n        else:\n            return self._decoding_standard(query, k_cache, v_cache, attn_metadata, max_q_seqlen, k_scales_zeros,\n                                           v_scales_zeros)\n\n    def _forward_prefill(\n        self,\n        query: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        max_q_seqlen: int,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for prefill stage.\n\n        Uses FA3's flash_attn_varlen_func for efficient variable-length attention\n        computation during the prefill phase.\n\n        Args:\n            query: Query tensor.\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata.\n            max_q_seqlen: Maximum query sequence length.\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        block_offsets = attn_metadata.block_offsets\n        kv_start_loc = attn_metadata.kv_start_loc\n        kv_seqlens = attn_metadata.kv_seqlens\n        kv_flatten_size = attn_metadata.kv_flatten_size\n        quant_policy = 
attn_metadata.quant_policy\n\n        # Flatten KV cache for varlen attention\n        flatten_k, flatten_v = self.flatten_kv_cache(\n            k_cache,\n            v_cache,\n            kv_seqlens,\n            block_offsets,\n            start_loc=kv_start_loc,\n            out_size=kv_flatten_size,\n            out_dtype=query.dtype,\n            k_scales_zeros=k_scales_zeros,\n            v_scales_zeros=v_scales_zeros,\n            quant_policy=quant_policy,\n            flatten_kv_layout='shd',\n        )\n\n        sliding_window = self._normalize_sliding_window(self.sliding_window)\n\n        attn_output = self.flash_attn_varlen_func_v3(\n            q=query,\n            k=flatten_k,\n            v=flatten_v,\n            cu_seqlens_q=attn_metadata.cu_seqlens_q,\n            cu_seqlens_k=attn_metadata.cu_seqlens_k,\n            max_seqlen_q=max_q_seqlen,\n            max_seqlen_k=kv_flatten_size,\n            softmax_scale=self.scale,\n            causal=self.causal,\n            window_size=sliding_window,\n            softcap=self.logit_softcapping,\n        )\n        return attn_output\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n        learnable_sink: torch.Tensor = None,\n        inplace: bool = True,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for FA3 attention computation.\n\n        This method handles both prefill and decoding stages by:\n        1. Computing max query sequence length\n        2. Filling KV cache if new key/value are provided\n        3. 
Dispatching to appropriate stage-specific method\n\n        Architecture:\n        - Decoding: Supports both speculative (multi-token) and standard (single-token)\n        - Prefill: Uses flash_attn_varlen_func for efficient varlen attention\n\n        Args:\n            query: Query tensor.\n            key: Key tensor (None for decoding-only).\n            value: Value tensor (None for decoding-only).\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata containing stage info and indices.\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n            learnable_sink: Learnable sink tokens (unused in FA3).\n            inplace: Whether to modify query inplace (unused, kept for compatibility).\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        # Shared preparation\n        max_q_seqlen = self._get_max_q_seqlen(query, attn_metadata)\n\n        # Fill KV cache with new key/value if provided\n        if key is not None and value is not None:\n            self._fill_kv_cache_impl(\n                key,\n                value,\n                k_cache=k_cache,\n                v_cache=v_cache,\n                attn_metadata=attn_metadata,\n                max_q_seqlen=max_q_seqlen,\n                k_scales_zeros=k_scales_zeros,\n                v_scales_zeros=v_scales_zeros,\n            )\n\n        # Dispatch to stage-specific forward method\n        if attn_metadata.is_decoding:\n            return self._forward_decoding(\n                query,\n                k_cache,\n                v_cache,\n                attn_metadata,\n                max_q_seqlen,\n                k_scales_zeros,\n                v_scales_zeros,\n            )\n        else:\n            return self._forward_prefill(\n                query,\n                k_cache,\n                v_cache,\n                
attn_metadata,\n                max_q_seqlen,\n                k_scales_zeros,\n                v_scales_zeros,\n            )\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/attention/mla.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport functools\n\nimport torch\n\nfrom lmdeploy.utils import get_logger\n\nfrom .default import TritonAttentionImpl, TritonAttentionMetadata\n\nlogger = get_logger('lmdeploy')\n\n\ndef _cdiv(a, b):\n    \"\"\"Ceiling division: return ceil(a / b) for positive integers.\"\"\"\n    return (a + b - 1) // b\n\n\ndef _try_dynamic_compile(func, *args, **kwargs):\n    \"\"\"Try torch.compile(dynamic=True) with a trial call; return the eager func on failure.\"\"\"\n    try:\n        compiled_func = torch.compile(func, dynamic=True)\n        compiled_func(*args, **kwargs)\n        return compiled_func\n    except Exception:\n        return func\n\n\nclass NSAIndicesUpdater:\n    \"\"\"NSA indices updater.\n\n    Flash MLA sparse attention requires different index formats for prefill and decoding. This module converts the\n    indices into the format each stage expects.\n    \"\"\"\n\n    def __init__(self):\n        self._update_decode_func = None\n        self._update_prefill_func = None\n\n    def _update_decode_impl(self, nsa_indices: torch.Tensor, block_offsets: torch.Tensor,\n                            block_size: int) -> torch.Tensor:\n        \"\"\"Map logical token positions to physical slots in the paged KV cache; negative indices stay -1.\"\"\"\n        block_ids = nsa_indices // block_size\n        block_ids = block_ids.clamp_min(0)\n        block_ids = block_offsets.gather(1, block_ids)\n        block_remain = nsa_indices % block_size\n        ret = block_ids * block_size + block_remain\n        ret[nsa_indices < 0] = -1\n        return ret[:, None]\n\n    def update_decode(self, nsa_indices: torch.Tensor, block_offsets: torch.Tensor, block_size: int) -> torch.Tensor:\n        \"\"\"Update for decode.\"\"\"\n        if self._update_decode_func is None:\n            self._update_decode_func = _try_dynamic_compile(self._update_decode_impl, nsa_indices, block_offsets,\n                                                            block_size)\n\n        return self._update_decode_func(nsa_indices, block_offsets, block_size)\n\n    def _update_prefill_impl(self, nsa_indices: 
torch.Tensor, q_seqlens: torch.Tensor, cu_seqlens_k: torch.Tensor):\n        \"\"\"Update for prefill impl.\"\"\"\n        num_tokens = nsa_indices.size(0)\n        repeat_cu_seqlens_k = torch.repeat_interleave(cu_seqlens_k[:-1], q_seqlens, output_size=num_tokens)\n        neg_mask = nsa_indices < 0\n        nsa_indices = nsa_indices + repeat_cu_seqlens_k[:, None]\n        nsa_indices[neg_mask] = -1\n        return nsa_indices[:, None]\n\n    def update_prefill(self, nsa_indices: torch.Tensor, q_seqlens: torch.Tensor, cu_seqlens_k: torch.Tensor):\n        \"\"\"Update for prefill.\"\"\"\n        if self._update_prefill_func is None:\n            self._update_prefill_func = _try_dynamic_compile(self._update_prefill_impl, nsa_indices, q_seqlens,\n                                                             cu_seqlens_k)\n\n        return self._update_prefill_func(nsa_indices, q_seqlens, cu_seqlens_k)\n\n    @staticmethod\n    @functools.lru_cache(maxsize=None)\n    def build():\n        return NSAIndicesUpdater()\n\n\nclass FlashMLAImpl(TritonAttentionImpl):\n    \"\"\"Flash MLA (Multi-head Latent Attention) implementation.\n\n    This implementation supports multiple execution paths:\n    - Decoding: Uses flash_mla_with_kvcache with paged KV cache\n    - Prefill with NSA: Uses flash_mla_sparse_fwd for sparse attention\n    - Prefill with FA3: Uses flash_attn_varlen_func with split q_rope/q_nope\n    - Prefill fallback: Uses custom Triton kernel\n    \"\"\"\n\n    # MLA-specific constants\n    _MLA_HEAD_ALIGNMENT = 64  # Query heads must be multiple of 64 for flash_mla\n    _MLA_NOPE_SIZE = 512  # Size of non-positional embeddings\n    _MLA_SCALE_SIZE = 16  # Size of FP8 quantization scales\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi: bool = False,\n        sliding_window: tuple = None,\n        logit_softcapping: 
float = 0.0,\n        causal: bool = True,\n        use_fa3: bool = False,\n        **kwargs,\n    ):\n        assert (sliding_window is None\n                or all(win == -1 for win in sliding_window)), ('sliding window not supported for FlashMLA')\n        assert alibi is False, 'alibi not supported for FlashMLA'\n        if logit_softcapping > 0.0:\n            logger.warning('logit_softcapping not properly supported for FlashMLA, using -1.0')\n            logit_softcapping = -1.0\n        super().__init__(\n            num_heads=num_heads,\n            head_size=head_size,\n            scale=scale,\n            num_kv_heads=num_kv_heads,\n            v_head_size=v_head_size,\n            alibi=alibi,\n            sliding_window=sliding_window,\n            logit_softcapping=logit_softcapping,\n            causal=causal,\n            **kwargs,\n        )\n\n        import flash_mla\n\n        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8\n        from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache_mla_fp8\n        self.flash_mla_with_kvcache = flash_mla.flash_mla_with_kvcache\n        self.flash_mla_sparse_fwd = None\n        self.fill_kv_cache_blocked_fp8 = fill_kv_cache_blocked_fp8\n        self.flatten_kv_cache_mla_fp8 = flatten_kv_cache_mla_fp8\n        assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1'\n        self.use_fa3 = use_fa3\n\n        self.nsa_updater = NSAIndicesUpdater.build()\n\n    def _get_flash_mla_sparse_fwd(self):\n        if self.flash_mla_sparse_fwd is not None:\n            return self.flash_mla_sparse_fwd\n\n        try:\n            import flash_mla\n            self.flash_mla_sparse_fwd = flash_mla.flash_mla_sparse_fwd\n            return self.flash_mla_sparse_fwd\n        except Exception:\n            logger.exception('Can not import flash_mla_sparse_fwd from flash_mla.')\n\n    def flash_mla_decoding(\n        self,\n        query: torch.Tensor,\n        
k_cache: torch.Tensor,\n        nsa_indices: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n    ):\n        \"\"\"Flash mla decoding.\"\"\"\n        causal = self.causal\n        kv_seqlens = attn_metadata.kv_seqlens\n        block_offsets = attn_metadata.block_offsets\n        is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn\n\n        q_seqlens = attn_metadata.q_seqlens\n        batch_size = q_seqlens.size(0)\n        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))\n        max_q_seqlen = max_q_seqlen // batch_size\n        query = query.unflatten(0, (batch_size, max_q_seqlen))\n        if kv_seqlens.dtype == torch.int64:\n            kv_seqlens = kv_seqlens.to(torch.int32)\n\n        # update nsa indice according to flash-mla requirement\n        if nsa_indices is not None:\n            block_size = k_cache.size(1)\n            nsa_indices = self.nsa_updater.update_decode(nsa_indices, block_offsets, block_size)\n            causal = False\n\n        attn_output, _ = self.flash_mla_with_kvcache(query,\n                                                     k_cache=k_cache,\n                                                     block_table=block_offsets,\n                                                     cache_seqlens=kv_seqlens,\n                                                     head_dim_v=self.v_head_size,\n                                                     softmax_scale=self.scale,\n                                                     tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,\n                                                     num_splits=attn_metadata.num_splits,\n                                                     causal=causal,\n                                                     is_fp8_kvcache=is_fp8_kvcache,\n                                                     indices=nsa_indices)\n\n        attn_output = attn_output.flatten(0, 1)\n        return attn_output\n\n    def 
_prefill_sparse(self, query: torch.Tensor, flatten_k: torch.Tensor, nsa_indices: torch.Tensor,\n                        attn_metadata: TritonAttentionMetadata) -> torch.Tensor:\n        \"\"\"Sparse prefill using flash_mla_sparse_fwd.\n\n        This path is used when NSA (Native Sparse Attention) indices are provided.\n        Requires an FP8 KV cache and the flash_mla library.\n\n        Args:\n            query: Query tensor.\n            flatten_k: Flattened key cache.\n            nsa_indices: Sparse attention indices.\n            attn_metadata: Attention metadata.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        q_seqlens = attn_metadata.q_seqlens\n        flash_mla_sparse_fwd = self._get_flash_mla_sparse_fwd()\n\n        num_q_heads = query.size(1)\n        # flash_mla_sparse_fwd requires the query head count to be a multiple of _MLA_HEAD_ALIGNMENT\n        if num_q_heads % self._MLA_HEAD_ALIGNMENT != 0:\n            padding = self._MLA_HEAD_ALIGNMENT - num_q_heads % self._MLA_HEAD_ALIGNMENT\n            query = torch.nn.functional.pad(query, (0, 0, 0, padding))\n\n        nsa_indices = self.nsa_updater.update_prefill(nsa_indices, q_seqlens, attn_metadata.cu_seqlens_k)\n        output = flash_mla_sparse_fwd(\n            query,\n            flatten_k,\n            nsa_indices,\n            sm_scale=self.scale,\n        )\n        attn_output = output[0]\n        attn_output = attn_output[:, :num_q_heads]\n        return attn_output\n\n    def _prefill_triton(\n        self,\n        query: torch.Tensor,\n        flatten_k: torch.Tensor,\n        flatten_v: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n    ) -> torch.Tensor:\n        \"\"\"Triton-based prefill fallback.\n\n        This is the fallback path used when FA3 is disabled (use_fa3=False) and no\n        NSA indices are provided. Uses a custom Triton kernel for attention computation.\n\n        Args:\n            query: Query tensor.\n            flatten_k: Flattened key cache.\n            flatten_v: 
Flattened value cache.\n            attn_metadata: Attention metadata.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))\n\n        attn_output = self.flash_attention_fwd(\n            query,\n            flatten_k,\n            flatten_v,\n            cu_seqlens_q=attn_metadata.cu_seqlens_q,\n            cu_seqlens_k=attn_metadata.cu_seqlens_k,\n            max_seqlen_q=max_q_seqlen,\n            max_seqlen_k=attn_metadata.max_kv_seqlen,\n            window_size=self.sliding_window,\n            softmax_scale=self.scale,\n            softcap=self.logit_softcapping,\n            causal=self.causal,\n        )\n\n        return attn_output\n\n    def _prefill_fa3(\n        self,\n        query: torch.Tensor,\n        flatten_k: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n    ) -> torch.Tensor:\n        \"\"\"Flash Attention 3 optimized prefill.\n\n        This path uses Flash Attention 3's optimized kernels with split\n        rope (positional) and nope (non-positional) components.\n\n        Args:\n            query: Query tensor.\n            flatten_k: Flattened key cache.\n            attn_metadata: Attention metadata.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))\n        kv_flatten_size = attn_metadata.kv_flatten_size\n        causal = self.causal\n\n        # Split query and key into rope (positional) and nope (non-positional) parts\n        q_rope = query[:, :, self.v_head_size:]\n        q_nope = query[:, :, :self.v_head_size]\n        k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:]\n        c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size]\n        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func\n        attn_output = 
flash_attn_varlen_func(\n            q=q_rope,\n            k=k_rope,\n            v=c_kv,\n            qv=q_nope,\n            cu_seqlens_q=attn_metadata.cu_seqlens_q,\n            cu_seqlens_k=attn_metadata.cu_seqlens_k,\n            max_seqlen_q=max_q_seqlen,\n            max_seqlen_k=kv_flatten_size,\n            softmax_scale=self.scale,\n            causal=causal,\n            window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,\n        )\n        return attn_output\n\n    def run_flatten_kv_cache(self,\n                             k_cache: torch.Tensor,\n                             v_cache: torch.Tensor,\n                             attn_metadata: TritonAttentionMetadata,\n                             out_dtype: torch.dtype,\n                             is_nsa: bool,\n                             k_scales_zeros: torch.Tensor = None,\n                             v_scales_zeros: torch.Tensor = None):\n        \"\"\"Flatten kv cache for prefill.\"\"\"\n\n        kv_start_loc = attn_metadata.kv_start_loc\n        kv_seqlens = attn_metadata.kv_seqlens\n        block_offsets = attn_metadata.block_offsets\n        kv_flatten_size = attn_metadata.kv_flatten_size\n        quant_policy = attn_metadata.quant_policy\n        is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn\n        BLOCK_BS = k_cache.size(1)\n\n        # pad one more block to avoid invalid kv visit\n        if self.use_fa3 or is_nsa:\n            out_size = kv_flatten_size\n            flatten_kv_layout = 'shd'\n        else:\n            out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS)\n            flatten_kv_layout = 'hsd'\n\n        if is_fp8_kvcache:\n            flatten_k = self.flatten_kv_cache_mla_fp8(\n                k_cache,\n                kv_seqlens,\n                block_offsets,\n                start_loc=kv_start_loc,\n                out_size=out_size,\n                out_dtype=out_dtype,\n                
flatten_kv_layout=flatten_kv_layout,\n            )\n            flatten_v = flatten_k[..., :self._MLA_NOPE_SIZE]\n        else:\n            flatten_k, flatten_v = self.flatten_kv_cache(\n                k_cache,\n                v_cache,\n                kv_seqlens,\n                block_offsets,\n                start_loc=kv_start_loc,\n                out_size=out_size,\n                out_dtype=out_dtype,\n                k_scales_zeros=k_scales_zeros,\n                v_scales_zeros=v_scales_zeros,\n                quant_policy=quant_policy,\n                flatten_kv_layout=flatten_kv_layout,\n            )\n\n        return flatten_k, flatten_v\n\n    def _get_max_q_seqlen(\n        self,\n        query: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n    ) -> int:\n        \"\"\"Get max q seqlen.\"\"\"\n        q_seqlens = attn_metadata.q_seqlens\n        max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))\n        batch_size = q_seqlens.size(0)\n        if attn_metadata.is_decoding:\n            max_q_seqlen = max_q_seqlen // batch_size\n        return max_q_seqlen\n\n    def _fill_kv_cache_impl(self,\n                            key: torch.Tensor,\n                            value: torch.Tensor,\n                            k_cache: torch.Tensor,\n                            v_cache: torch.Tensor,\n                            attn_metadata: TritonAttentionMetadata,\n                            max_q_seqlen: int,\n                            k_scales_zeros: torch.Tensor = None,\n                            v_scales_zeros: torch.Tensor = None):\n        \"\"\"Fill kv cache.\"\"\"\n        is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn\n        if not is_fp8_kvcache:\n            return super()._fill_kv_cache_impl(\n                key,\n                value,\n                k_cache,\n                v_cache,\n                attn_metadata,\n                max_q_seqlen,\n                
k_scales_zeros=k_scales_zeros,\n                v_scales_zeros=v_scales_zeros,\n            )\n\n        block_offsets = attn_metadata.block_offsets\n        kv_seqlens = attn_metadata.kv_seqlens\n        quant_policy = attn_metadata.quant_policy\n        assert quant_policy == 0\n\n        # fill seqlen args\n        fill_seqlens, fill_max_q_seqlen, fill_q_start_loc = self._get_fill_meta(\n            key,\n            attn_metadata,\n            max_q_seqlen,\n        )\n\n        # Split k_cache into nope, scale, and pe components\n        scale_offset = self._MLA_NOPE_SIZE\n        scale_end = scale_offset + self._MLA_SCALE_SIZE\n        k_cache_scale = k_cache[..., scale_offset:scale_end].view(torch.float32)\n        k_cache_nope = k_cache[..., :self._MLA_NOPE_SIZE]\n        k_cache_pe = k_cache[..., scale_end:].view(key.dtype)\n        self.fill_kv_cache_blocked_fp8(\n            key[..., :self._MLA_NOPE_SIZE],\n            None,\n            k_cache_nope,\n            None,\n            k_cache_scale,\n            None,\n            cu_seqlen_q=attn_metadata.cu_seqlens_q,\n            kv_seqlens=attn_metadata.kv_seqlens,\n            max_q_seqlen=max_q_seqlen,\n            block_offsets=block_offsets,\n            group_size=128,\n            scale_fmt='ue8m0',\n        )\n        self.fill_kv_cache(\n            key[..., self._MLA_NOPE_SIZE:],\n            None,\n            k_cache_pe,\n            None,\n            fill_q_start_loc,\n            fill_seqlens,\n            kv_seq_length=kv_seqlens,\n            max_q_seq_length=fill_max_q_seqlen,\n            block_offsets=block_offsets,\n        )\n\n    def _forward_decoding(\n        self,\n        query: torch.Tensor,\n        k_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        nsa_indices: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for decoding stage.\n\n        Uses flash_mla_with_kvcache for efficient decoding with paged KV cache.\n        
Supports both regular and sparse (NSA) attention patterns.\n\n        Args:\n            query: Query tensor.\n            k_cache: Key cache tensor.\n            attn_metadata: Attention metadata.\n            nsa_indices: Optional sparse attention indices.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        return self.flash_mla_decoding(query, k_cache, nsa_indices, attn_metadata)\n\n    def _forward_prefill(\n        self,\n        query: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        nsa_indices: torch.Tensor = None,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for prefill stage.\n\n        Supports three execution paths:\n        1. Sparse (NSA + FP8): flash_mla_sparse_fwd for sparse attention\n        2. FA3 optimized: flash_attn_varlen_func with split q_rope/q_nope\n        3. 
Triton fallback: Custom Triton kernel implementation\n\n        Args:\n            query: Query tensor.\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata.\n            nsa_indices: Optional sparse attention indices.\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        # Flatten KV cache once for all prefill paths\n        flatten_k, flatten_v = self.run_flatten_kv_cache(\n            k_cache,\n            v_cache,\n            attn_metadata,\n            out_dtype=query.dtype,\n            is_nsa=nsa_indices is not None,\n            k_scales_zeros=k_scales_zeros,\n            v_scales_zeros=v_scales_zeros,\n        )\n\n        # Dispatch to appropriate prefill implementation\n        if nsa_indices is not None:\n            return self._prefill_sparse(query, flatten_k, nsa_indices, attn_metadata)\n        elif self.use_fa3:\n            return self._prefill_fa3(query, flatten_k, attn_metadata)\n        else:\n            return self._prefill_triton(query, flatten_k, flatten_v, attn_metadata)\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: TritonAttentionMetadata,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n        nsa_indices: torch.Tensor = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass for MLA attention computation.\n\n        This method handles both prefill and decoding stages by:\n        1. Validating NSA requirements (FP8 KV cache)\n        2. Computing max query sequence length\n        3. Filling KV cache if new key/value are provided\n        4. 
Dispatching to appropriate stage-specific method\n\n        Architecture:\n        - Decoding: Uses flash_mla_with_kvcache with paged KV cache\n        - Prefill: Three paths based on availability and requirements\n          * Sparse (NSA + FP8): flash_mla_sparse_fwd\n          * FA3 optimized: flash_attn_varlen_func with split q_rope/q_nope\n          * Triton fallback: Custom triton kernel\n\n        Args:\n            query: Query tensor.\n            key: Key tensor (None for decoding-only).\n            value: Value tensor (None for decoding-only).\n            k_cache: Key cache tensor.\n            v_cache: Value cache tensor.\n            attn_metadata: Attention metadata containing stage info and indices.\n            k_scales_zeros: Key quantization scales/zeros.\n            v_scales_zeros: Value quantization scales/zeros.\n            nsa_indices: Optional sparse attention indices.\n\n        Returns:\n            Attention output tensor.\n        \"\"\"\n        # Validate NSA requirements\n        is_nsa = nsa_indices is not None\n        if is_nsa:\n            is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn\n            assert is_fp8_kvcache, 'NSA sparse attention requires FP8 KV cache'\n\n        # Shared preparation\n        max_q_seqlen = self._get_max_q_seqlen(query, attn_metadata)\n\n        # Fill KV cache with new key/value if provided\n        self._fill_kv_cache_impl(\n            key,\n            value,\n            k_cache,\n            v_cache,\n            attn_metadata,\n            max_q_seqlen,\n            k_scales_zeros=k_scales_zeros,\n            v_scales_zeros=v_scales_zeros,\n        )\n\n        # Dispatch to stage-specific forward method\n        if attn_metadata.is_decoding:\n            return self._forward_decoding(query, k_cache, attn_metadata, nsa_indices)\n        else:\n            return self._forward_prefill(\n                query,\n                k_cache,\n                v_cache,\n                
attn_metadata,\n                nsa_indices,\n                k_scales_zeros,\n                v_scales_zeros,\n            )\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/awq_modules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional\n\nimport torch\n\nimport lmdeploy.pytorch.distributed as dist\n\nfrom ..awq_modules import LinearW4A16Builder, LinearW4A16Impl\n\n\ndef wq_gemm_forward(\n    x,\n    qweight,\n    qzeros,\n    scales,\n    w_bit=4,\n    group_size=128,\n    bias=None,\n    out_features=0,\n):\n    \"\"\"Wq gemm forward.\"\"\"\n    from lmdeploy.pytorch.kernels.cuda.awq_kernels import awq_linear\n    out_shape = x.shape[:-1] + (out_features, )\n    input_dtype = x.dtype\n    if input_dtype != torch.float16:\n        x = x.half()\n\n    x = x.flatten(0, -2)\n    out = awq_linear(x, qweight, scales, qzeros)\n\n    out = out + bias if bias is not None else out\n    out = out.reshape(out_shape)\n\n    # always want 3D tensor if tensor is 2D\n    if len(out.shape) == 2:\n        out = out.unsqueeze(0)\n\n    if input_dtype != torch.float16:\n        out = out.to(dtype=input_dtype)\n    return out\n\n\nclass AwqLinearW4A16Impl(LinearW4A16Impl):\n    \"\"\"Awq kernel linear.\"\"\"\n\n    def __init__(self, in_features: int, out_features: int, w_bit: int, group_size: int):\n        self.in_features = in_features\n        self.out_features = out_features\n        self.w_bit = w_bit\n        self.group_size = group_size\n\n    def forward(self,\n                x,\n                qweight: torch.Tensor,\n                scales: torch.Tensor,\n                qzeros: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[torch.distributed.ProcessGroup] = None):\n        \"\"\"forward.\"\"\"\n        out_features = scales.size(1)\n        out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features)\n        if all_reduce:\n            dist.all_reduce(out, group=group)\n        return out\n\n\nclass AwqLinearW4A16Builder(LinearW4A16Builder):\n    \"\"\"Awq linear 
builder.\"\"\"\n\n    @staticmethod\n    def build(in_features: int,\n              out_features: int,\n              w_bit: int,\n              group_size: int,\n              bias: bool = False,\n              dtype: torch.dtype = None):\n        \"\"\"build.\"\"\"\n        return AwqLinearW4A16Impl(in_features, out_features, w_bit, group_size)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/blockedf8_modules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List, Optional\n\nimport torch\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import blocked_gemm_fp8, deep_gemm_fp8, quant_fp8, quant_fp8_tma\nfrom lmdeploy.utils import get_logger\n\nfrom ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl\nfrom .warmup_manager import WarmupMeta, get_warmup_manager\n\nlogger = get_logger('lmdeploy')\n\n\nclass TritonLinearBlockedF8Impl(LinearBlockedF8Impl):\n    \"\"\"Triton linear blocked f8 implementation.\"\"\"\n\n    def __init__(self, in_features: int, out_features: int, block_size: int, out_dtype: torch.dtype = torch.float16):\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype\n        self.block_size = block_size\n\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                scale: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[dist.ProcessGroup] = None,\n                rank: int = 0,\n                scatter_size: List[int] = None):\n        \"\"\"forward.\"\"\"\n        x_shape = x.shape\n        x = x.flatten(0, -2)\n        input_quant, input_scale = quant_fp8(x,\n                                             self.block_size,\n                                             dtype=weight.dtype,\n                                             trans_scale=True,\n                                             scale_fmt=self.scale_fmt)\n\n        out = blocked_gemm_fp8(input_quant, input_scale, weight.t(), scale.t(), out_dtype=x.dtype)\n        if bias is not None:\n            out += bias\n\n        out = out.unflatten(0, x_shape[:-1])\n\n        if all_reduce:\n            if scatter_size is not None:\n                out = 
dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)\n            else:\n                dist.all_reduce(out, group=group)\n        return out\n\n\nclass TritonLinearBlockedF8Builder(LinearBlockedF8Builder):\n    \"\"\"Triton linear blocked f8 implementation builder.\"\"\"\n\n    @staticmethod\n    def build(in_features: int, out_features: int, block_size: int = 128, bias: bool = True, dtype: torch.dtype = None):\n        \"\"\"build.\"\"\"\n        try:\n            import deep_gemm  # noqa\n            logger.debug('build with DeepGemmLinearBlockedF8Impl')\n            return DeepGemmLinearBlockedF8Impl(in_features, out_features, block_size, dtype)\n        except:  # noqa\n            logger.warning('Failed to import deep_gemm; LinearBlockedF8 falls back to the Triton implementation.')\n            return TritonLinearBlockedF8Impl(in_features, out_features, block_size, dtype)\n\n\nclass DeepGemmLinearBlockedF8Impl(LinearBlockedF8Impl):\n    \"\"\"Deep gemm blocked f8 implementation.\"\"\"\n\n    def __init__(self, in_features: int, out_features: int, block_size: int, out_dtype: torch.dtype = torch.float16):\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype\n        self.block_size = block_size\n\n        warmup_mgr = get_warmup_manager()\n        key = ('deepgemm_blockedfp8_gemm_'\n               f'{in_features}_{out_features}_{block_size}_{out_dtype}')\n        if key not in warmup_mgr:\n            warmup_mgr[key] = self.warmup\n\n    def warmup(self, warmup_meta: WarmupMeta):\n        \"\"\"warmup.\"\"\"\n        import random\n\n        from lmdeploy.pytorch.third_party.deep_gemm import get_m_alignment_for_contiguous_layout\n        device = 'cuda'\n        max_num_tokens = warmup_meta.max_num_tokens\n        alignment = get_m_alignment_for_contiguous_layout()\n        range_end = max_num_tokens + alignment - 1\n        k, n = self.in_features, self.out_features\n   
     block_size = self.block_size\n        weight = torch.empty(n, k, dtype=torch.float8_e4m3fn, device=device)\n        scale = torch.empty(((n + block_size - 1) // block_size, (k + block_size - 1) // block_size),\n                            dtype=torch.float32,\n                            device=device)\n        # shuffle ranges so ranks might compile different kernels concurrently.\n        ranges = list(range(alignment, range_end, alignment))\n        random.shuffle(ranges)\n        for m in ranges:\n            inputs = torch.empty(m, k, dtype=self.out_dtype, device=device)\n            input_quant, input_scale = quant_fp8_tma(inputs, self.block_size, dtype=weight.dtype)\n            deep_gemm_fp8(input_quant, input_scale, weight, scale, out_dtype=inputs.dtype)\n\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                scale: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[dist.ProcessGroup] = None,\n                rank: int = 0,\n                scatter_size: List[int] = None):\n        \"\"\"forward.\"\"\"\n        x_shape = x.shape\n        x = x.flatten(0, -2)\n        input_quant, input_scale = quant_fp8_tma(x, self.block_size, dtype=weight.dtype, scale_fmt=self.scale_fmt)\n\n        out = deep_gemm_fp8(input_quant, input_scale, weight, scale, out_dtype=x.dtype)\n        out = out[:x.size(0)]\n        if bias is not None:\n            out += bias\n        out = out.unflatten(0, x_shape[:-1])\n\n        if all_reduce:\n            if scatter_size is not None:\n                out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)\n            else:\n                dist.all_reduce(out, group=group)\n        return out\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/causal_conv1d.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom functools import lru_cache\n\nimport torch\n\nfrom ..causal_conv1d import CausalConv1dBuilder, CausalConv1dImpl\nfrom .utils import has_tilelang\n\n\nclass CausalConv1dTilelangImpl(CausalConv1dImpl):\n    \"\"\"CausalConv1d update implementation.\"\"\"\n\n    def __init__(self):\n        from lmdeploy.pytorch.kernels.cuda.causal_conv1d import causal_conv1d_fn, causal_conv1d_update\n        self.causal_conv1d_fn = causal_conv1d_fn\n        self.causal_conv1d_update = causal_conv1d_update\n\n    def conv1d_fn(self,\n                  x: torch.Tensor,\n                  weight: torch.Tensor,\n                  bias: torch.Tensor | None = None,\n                  seq_idx: torch.Tensor | None = None,\n                  return_final_states: bool = False,\n                  activation: str | None = None):\n        return self.causal_conv1d_fn(x,\n                                     weight,\n                                     bias=bias,\n                                     seq_idx=seq_idx,\n                                     return_final_states=return_final_states,\n                                     activation=activation)\n\n    def update_fn(self,\n                  x: torch.Tensor,\n                  conv_state: torch.Tensor,\n                  weight: torch.Tensor,\n                  bias: torch.Tensor | None = None,\n                  activation: str | None = None,\n                  conv_state_indices: torch.Tensor | None = None):\n        \"\"\"Update conv state.\"\"\"\n        return self.causal_conv1d_update(x,\n                                         conv_state,\n                                         weight,\n                                         bias=bias,\n                                         activation=activation,\n                                         conv_state_indices=conv_state_indices)\n\n\nclass CausalConv1dDaoImpl(CausalConv1dTilelangImpl):\n\n    def 
__init__(self):\n        try:\n            import causal_conv1d\n            self.causal_conv1d_fn = causal_conv1d.causal_conv1d_fn\n            self.causal_conv1d_update = causal_conv1d.causal_conv1d_update\n        except Exception as e:\n            raise RuntimeError(\n                'causal_conv1d is not installed, please refer to https://github.com/Dao-AILab/causal-conv1d') from e\n\n\n@lru_cache\ndef has_dao():\n    try:\n        import causal_conv1d  # noqa: F401\n        causal_conv1d_fn = causal_conv1d.causal_conv1d_fn  # noqa: F841\n        causal_conv1d_update = causal_conv1d.causal_conv1d_update  # noqa: F841\n        return True\n    except Exception:\n        return False\n\n\nclass CausalConv1dCudaBuilder(CausalConv1dBuilder):\n    \"\"\"CausalConv1d implementation builder.\"\"\"\n\n    @staticmethod\n    def build() -> CausalConv1dImpl:\n        \"\"\"build.\"\"\"\n        if has_tilelang():\n            return CausalConv1dTilelangImpl()\n        elif has_dao():\n            return CausalConv1dDaoImpl()\n        else:\n            raise RuntimeError('No available implementation for CausalConv1d, '\n                               'please install https://tilelang.com/ or https://github.com/Dao-AILab/causal-conv1d')\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/flash_attention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom torch import Tensor\n\nfrom ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl\n\n\nclass TritonFlashAttentionImpl(FlashAttentionImpl):\n    \"\"\"Triton flash attention implementation.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_dim: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_dim: int = None,\n        causal: bool = True,\n        sliding_window: int = None,\n        logit_softcapping: float = None,\n    ):\n        if scale is None:\n            scale = 1.0 / (head_dim**0.5)\n\n        if num_kv_heads is None:\n            num_kv_heads = num_heads\n\n        if v_head_dim is None:\n            v_head_dim = head_dim\n\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.scale = scale\n        self.num_kv_heads = num_kv_heads\n        self.v_head_dim = v_head_dim\n        self.causal = causal\n        self.sliding_window = sliding_window\n        self.logit_softcapping = logit_softcapping\n\n        from lmdeploy.pytorch.kernels.cuda import flash_attn_varlen_func\n        self.flash_attention_fwd = flash_attn_varlen_func\n\n    def forward(self,\n                query: Tensor,\n                key: Tensor,\n                value: Tensor,\n                q_start_loc: Tensor,\n                q_seqlens: Tensor,\n                kv_start_loc: Tensor,\n                kv_seqlens: Tensor,\n                max_q_seqlen: int = None):\n        \"\"\"forward.\"\"\"\n        out = self.flash_attention_fwd(\n            query,\n            key,\n            value,\n            q_start_loc=q_start_loc,\n            q_seqlens=q_seqlens,\n            kv_start_loc=kv_start_loc,\n            kv_seqlens=kv_seqlens,\n            max_seqlen_q=max_q_seqlen,\n            window_size=self.sliding_window,\n            softmax_scale=self.scale,\n            softcap=self.logit_softcapping,\n     
       causal=self.causal,\n            kv_layout='shd',\n        )\n\n        return out\n\n\nclass TritonFlashAttentionBuilder(FlashAttentionBuilder):\n    \"\"\"Triton attention builder.\"\"\"\n\n    @staticmethod\n    def build(\n        num_heads: int,\n        head_dim: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_dim: int = None,\n        causal: bool = True,\n        sliding_window: int = None,\n        logit_softcapping: float = None,\n        **kwargs,\n    ) -> FlashAttentionImpl:\n        \"\"\"build.\"\"\"\n        return TritonFlashAttentionImpl(\n            num_heads=num_heads,\n            head_dim=head_dim,\n            scale=scale,\n            num_kv_heads=num_kv_heads,\n            v_head_dim=v_head_dim,\n            causal=causal,\n            sliding_window=sliding_window,\n            logit_softcapping=logit_softcapping,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/gated_delta_rule.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom functools import lru_cache\n\nimport torch\n\nfrom ..gated_delta_rule import GatedDeltaRuleBuilder, GatedDeltaRuleImpl\nfrom .utils import has_tilelang\n\n\n@lru_cache\ndef has_fla():\n    try:\n        from fla.ops.gated_delta_rule import chunk_gated_delta_rule  # noqa: F401\n        return True\n    except Exception:\n        return False\n\n\nclass CudaGatedDeltaRuleImpl(GatedDeltaRuleImpl):\n\n    def __init__(self):\n        if not has_fla() or not has_tilelang():\n            raise ImportError('fla and tilelang is required for CudaGatedDeltaRuleImpl')\n        from fla.ops.gated_delta_rule import chunk_gated_delta_rule\n\n        from lmdeploy.pytorch.kernels.cuda.gated_delta_rule import fused_recurrent_gated_delta_rule\n        self.chunk_func = chunk_gated_delta_rule\n        self.recurrent_func = fused_recurrent_gated_delta_rule\n\n    def chunk_gated_delta_rule(self,\n                               q: torch.Tensor,\n                               k: torch.Tensor,\n                               v: torch.Tensor,\n                               g: torch.Tensor | None = None,\n                               beta: torch.Tensor | None = None,\n                               initial_state: torch.Tensor | None = None,\n                               state_indices: torch.Tensor | None = None,\n                               scale: float | None = None,\n                               use_qk_l2norm_in_kernel: bool = False,\n                               cu_seqlens: torch.Tensor | None = None,\n                               output_final_state: bool = False):\n\n        assert initial_state is not None\n        recurrent_state = initial_state\n        init_state = recurrent_state.index_select(0, state_indices)\n        if use_qk_l2norm_in_kernel:\n            # l2norm in fla would recompile when seqlen changed.\n            q = torch.nn.functional.normalize(q, p=2, dim=-1)\n            k = 
torch.nn.functional.normalize(k, p=2, dim=-1)\n        core_attn_out, last_state = self.chunk_func(\n            q,\n            k,\n            v,\n            g=g,\n            beta=beta,\n            scale=scale,\n            initial_state=init_state,\n            output_final_state=output_final_state,\n            use_qk_l2norm_in_kernel=False,\n            cu_seqlens=cu_seqlens,\n        )\n\n        last_state = recurrent_state.index_copy_(0, state_indices, last_state.to(recurrent_state.dtype))\n        if not output_final_state:\n            last_state = None\n        return core_attn_out, last_state\n\n    def fused_recurrent_gated_delta_rule(self,\n                                         q: torch.Tensor,\n                                         k: torch.Tensor,\n                                         v: torch.Tensor,\n                                         g: torch.Tensor | None = None,\n                                         beta: torch.Tensor | None = None,\n                                         initial_state: torch.Tensor | None = None,\n                                         state_indices: torch.Tensor | None = None,\n                                         scale: float | None = None,\n                                         use_qk_l2norm_in_kernel: bool = False,\n                                         output_final_state: bool = False):\n        return self.recurrent_func(\n            q,\n            k,\n            v,\n            g=g,\n            beta=beta,\n            scale=scale,\n            initial_state=initial_state,\n            state_indices=state_indices,\n            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,\n            output_final_state=output_final_state,\n        )\n\n\nclass CudaGatedDeltaRuleBuilder(GatedDeltaRuleBuilder):\n\n    @staticmethod\n    def build() -> GatedDeltaRuleImpl:\n        return CudaGatedDeltaRuleImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/graph_runner.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\nfrom typing import Any, Dict, List, Tuple\n\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.backends.deepep_moe_checker import get_moe_backend\nfrom lmdeploy.pytorch.backends.selector import get_backend\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig\nfrom lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager\nfrom lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta\nfrom lmdeploy.pytorch.strategies.base import StrategyFactoryBase\nfrom lmdeploy.utils import get_logger\n\nfrom ..graph_runner import GraphRunner\nfrom .attention import TritonAttentionMetadata\n\nlogger = get_logger('lmdeploy')\n\n\ndef next_power_of_2(n: int):\n    \"\"\"Return the smallest power of 2 greater than or equal to n.\"\"\"\n    n -= 1\n    n |= n >> 1\n    n |= n >> 2\n    n |= n >> 4\n    n |= n >> 8\n    n |= n >> 16\n    n |= n >> 32\n    n += 1\n    return n\n\n\n@functools.lru_cache\ndef _get_capture_batch_size_impl(max_batches: int):\n    \"\"\"Capture batch size.\"\"\"\n    ret = []\n    batch_size = 1\n    batch_step = 256\n    # power of 2\n    while batch_size <= min(batch_step, max_batches):\n        ret.append(batch_size)\n        batch_size *= 2\n\n    # step\n    ret += list(range(batch_size, max_batches + 1, batch_step))\n\n    if max_batches != ret[-1]:\n        ret.append(max_batches)\n    return ret\n\n\ndef _false(*args, **kwargs):\n    \"\"\"Default value of not support cuda graph.\"\"\"\n    return False\n\n\nclass CUDASingleGraphRunner:\n    \"\"\"Cuda single graph runner.\"\"\"\n\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        max_batches: int,\n        max_tokens: int,\n        num_blocks: int,\n        is_decoding: bool,\n        pool: Tuple[int, int],\n        model_config: ModelConfig,\n        device: torch.device,\n        decode_query_len: int = 1,\n    ):\n     
   self.model = model\n        self.ctx_mgr = model.ctx_mgr\n        self.model_config = model_config\n\n        self.meta = CudaGraphMeta(\n            max_batchs=max_batches,\n            max_tokens=max_tokens,\n            num_blocks=num_blocks,\n            is_decoding=is_decoding,\n            device=device,\n            input_buffers=dict(),\n            output_buffers=dict(),\n            vocab_size=self.model_config.vocab_size,\n            use_mla_fp8_cache=getattr(self.model_config, 'use_mla_fp8_cache', False),\n            use_flash_mla=getattr(self.model_config, 'use_flash_mla', False),\n            mla_index_topk=getattr(self.model_config, 'mla_index_topk', None),\n            decode_query_len=decode_query_len,\n            use_fa3_decoding=model_config.model_paradigm == 'ar_spec',\n        )\n        self.device = device\n        self.max_batches = max_batches\n        self.max_tokens = max_tokens\n        self.num_blocks = num_blocks\n        self.is_decoding = is_decoding\n        self.pool = pool\n        self._graph: torch.cuda.CUDAGraph = None\n\n    @record_function('capture_cudagraph')\n    def capture(self, **kwargs):\n        \"\"\"Capture graph.\"\"\"\n        logger.debug(f'Capturing graph with meta: {self.meta}')\n        self.meta.input_buffers = self.model.make_buffers_cudagraph(self.meta, **kwargs)\n        padded_kwargs = self.model.fill_buffers_cudagraph(self.meta, **kwargs)\n        context = self.ctx_mgr.current_context()\n        self.model.update_context_cudagraph(self.meta, context)\n        current_stream = torch.cuda.current_stream()\n\n        # warmup\n        warmup_output = self.model(**padded_kwargs)\n        warmup_buffers = self.model.make_output_buffers(warmup_output)\n\n        self._graph = torch.cuda.CUDAGraph()\n        # unsafe kernel calls from other threads might invalidate the capture,\n        # so we use the thread_local capture error mode here.\n        with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, 
capture_error_mode='thread_local'):\n            output = self.model(**padded_kwargs)\n\n        output_buffers = self.model.make_output_buffers(output)\n        self.meta.output_buffers = output_buffers\n        output = self.model.get_outputs_cudagraph(warmup_buffers, **kwargs)\n        return output\n\n    @record_function('forward_cudagraph')\n    def forward(self, **kwargs):\n        \"\"\"forward.\"\"\"\n        assert self._graph is not None\n        self.model.fill_buffers_cudagraph(self.meta, **kwargs)\n        context = self.ctx_mgr.current_context()\n        self.model.update_context_cudagraph(self.meta, context)\n        self._graph.replay()\n        output_buffers = self.meta.output_buffers\n        output = self.model.get_outputs_cudagraph(output_buffers, **kwargs)\n        return output\n\n    def __del__(self):\n        \"\"\"del.\"\"\"\n        del self._graph\n\n\nclass CUDAGraphRunner(GraphRunner):\n    \"\"\"Cuda graph runner.\"\"\"\n\n    def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,\n                 backend_config: BackendConfig, device: torch.device):\n        super().__init__(model, model_config, cache_config, backend_config, device)\n        self.max_batches = cache_config.max_batches\n        self.max_tokens = cache_config.max_prefill_token_num\n        self.num_blocks = cache_config.num_gpu_blocks\n\n        self.enable_graph = self.check_enable_graph()\n\n        self.graph_pool_handle = torch.cuda.graph_pool_handle()\n        self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict()\n        self.has_try_compile_model: bool = False\n\n        # strategy factory\n        build_ctx = model.ctx_mgr.build_ctx\n        strategy_factory: StrategyFactoryBase = build_ctx.strategy_factory\n        self.cudagraph_strategy = strategy_factory.build_cudagraph_strategy()\n\n    def check_enable_graph(self):\n        \"\"\"Check enable graph.\"\"\"\n        if self.backend_config.eager_mode:\n      
      return _false\n\n        return getattr(self.model, 'support_cuda_graph', _false)\n\n    def _try_compile_model_once(self):\n        if self.has_try_compile_model:\n            return\n\n        # TODO: re-enable this once torch.compile is stable (should we add a flag to enable it?)\n        # if hasattr(self.model, 'compile_model'):\n        #     method = getattr(self.model, 'compile_model')\n        #     method()\n\n        self.has_try_compile_model = True\n\n    def _get_capture_tokens(self, batch_size: int):\n        \"\"\"Get capture tokens.\"\"\"\n        cap_sizes = self.get_capture_batch_sizes()\n        for size in cap_sizes:\n            if size >= batch_size:\n                return size\n        assert False, f'Unsupported batch_size={batch_size}'\n\n    def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,\n                      attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs):\n        \"\"\"Get graph key.\"\"\"\n        context = self.ctx_mgr.current_context()\n        is_decoding = context.is_decoding\n        batch_size = attn_metadata.q_seqlens.size(0)\n        meta = self.get_meta()\n        enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch\n        # query_len lets a draft model distinguish the target model's inputs from its own\n        query_len = input_ids.size(1) // batch_size\n        if meta.padding_batch_size is None:\n            batch_size = self._get_capture_tokens(batch_size)\n        else:\n            batch_size = self._get_capture_tokens(meta.padding_batch_size)\n        return (batch_size, is_decoding, enable_microbatch, query_len)\n\n    def _prepare_inputs(self, **kwargs):\n        \"\"\"Prepare inputs.\"\"\"\n        assert 'attn_metadata' in kwargs, 'attn_metadata is required for cudagraph.'\n        attn_metadata: TritonAttentionMetadata = kwargs['attn_metadata']\n        if attn_metadata.block_offsets.dtype != torch.int32:\n 
           attn_metadata.block_offsets = attn_metadata.block_offsets.to(torch.int32)\n        return kwargs\n\n    def _get_max_tokens(self, graph_key: tuple, input_ids: torch.Tensor, q_seqlens: torch.Tensor):\n        max_batches = graph_key[0]\n        is_decoding = graph_key[1]\n        assert is_decoding\n        origin_batch_size = q_seqlens.size(0)\n        num_tokens = input_ids.size(1)\n        return self.cudagraph_strategy.get_max_tokens(max_batches, origin_batch_size, num_tokens)\n\n    def __call__(self, **kwargs):\n        \"\"\"call.\"\"\"\n        if not self.backend_config.eager_mode and get_backend().get_name() == 'cuda':\n            self._try_compile_model_once()\n\n        kwargs = self._prepare_inputs(**kwargs)\n        enable_graph = self.enable_graph(**kwargs)\n\n        if not enable_graph:\n            with record_function('forward_eager'):\n                output = self.model(**kwargs)\n                return self.model.make_output_buffers(output)\n\n        graph_key = self.get_graph_key(**kwargs)\n        max_batches = graph_key[0]\n        is_decoding = graph_key[1]\n        decode_query_len = graph_key[3]\n        if graph_key not in self._runner_map:\n            max_tokens = self._get_max_tokens(graph_key, kwargs['input_ids'], kwargs['attn_metadata'].q_seqlens)\n            runner = CUDASingleGraphRunner(\n                self.model,\n                max_batches=max_batches,\n                max_tokens=max_tokens,\n                num_blocks=self.num_blocks,\n                is_decoding=is_decoding,\n                pool=self.graph_pool_handle,\n                model_config=self.model_config,\n                device=self.device,\n                decode_query_len=decode_query_len,\n            )\n            output = runner.capture(**kwargs)\n            self._runner_map[graph_key] = runner\n            # SSM layers already update their state during capture (warmup), so replaying the\n            # graph here would apply the state update a second time.\n            return output\n    
    else:\n            runner = self._runner_map[graph_key]\n            output = runner.forward(**kwargs)\n            return output\n\n    @record_function('prepare_inputs_for_generation')\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare inputs.\"\"\"\n\n        if get_moe_backend().use_deepep_moe_backend():\n            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode\n            deepep_mode = DeepEPMode.LOW_LATENCY if context.is_decoding else DeepEPMode.NORMAL\n            DeepEPBuffer.set_deepep_mode(deepep_mode)\n\n        return self.model.prepare_inputs_for_generation(\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            context=context,\n        )\n\n    def reset(self):\n        \"\"\"Remove all graphs to prevent hanging on exit.\"\"\"\n        self._runner_map.clear()\n        if get_moe_backend().use_deepep_moe_backend():\n            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer\n\n            if hasattr(DeepEPBuffer, 'destroy'):\n                from torch import distributed as dist\n\n                DeepEPBuffer.destroy()\n                dist.barrier()\n\n    def update_inputs(self, inputs):\n        \"\"\"Update inputs.\"\"\"\n        if self.backend_config.eager_mode:\n            return inputs\n        is_decoding = inputs.is_decoding\n        dp_meta = inputs.dp_meta\n        if is_decoding and dp_meta is not None:\n            meta = self.get_meta()\n            padding_batch_size = meta.padding_batch_size\n            tp_size = self._get_capture_tokens(padding_batch_size)\n            dp_meta.sync_tp_size(tp_size)\n        return inputs\n\n    def get_capture_batch_sizes(self) -> List[int]:\n        \"\"\"Capture batch sizes.\"\"\"\n        return 
_get_capture_batch_size_impl(self.cache_config.max_batches)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/lora.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass\n\nimport torch\n\nfrom lmdeploy.pytorch.kernels.cuda.fused_lora import fused_lora\nfrom lmdeploy.pytorch.model_inputs import StepContextManager\n\nfrom ..lora import AdapterInfo, LoRABuilder, LoRAImpl\n\n\n@dataclass\nclass PackedLoRAInput:\n    \"\"\"Packed lora input.\"\"\"\n    x: torch.Tensor\n    q_start_loc: torch.Tensor\n    q_seqlens: torch.Tensor\n    adapter_ids: torch.Tensor\n    max_seq_len: int\n    is_decoding: bool\n\n\nclass TritonLoRAImpl(LoRAImpl):\n    \"\"\"Triton lora implementation.\"\"\"\n\n    @staticmethod\n    def _make_packed_lora_input(x, ctx_mgr):\n        \"\"\"Make PackedLoRAInput.\"\"\"\n        context = ctx_mgr.current_context()\n\n        # adapter cache\n        max_q_seq_length = x.numel() // x.size(-1)\n\n        return PackedLoRAInput(x=x.flatten(0, -2).contiguous(),\n                               q_start_loc=context.q_start_loc,\n                               q_seqlens=context.q_seqlens,\n                               adapter_ids=context.local_adapter_ids,\n                               max_seq_len=max_q_seq_length,\n                               is_decoding=context.is_decoding)\n\n    def forward(self,\n                x: torch.Tensor,\n                lora_A: torch.Tensor,\n                lora_B: torch.Tensor,\n                base_output: torch.Tensor,\n                adapter_info: AdapterInfo,\n                ctx_mgr: StepContextManager,\n                colwise: bool,\n                is_tp: bool = True):\n        \"\"\"forward.\"\"\"\n        lora_input = self._make_packed_lora_input(x, ctx_mgr)\n\n        base_slice = adapter_info.base_slice\n        sliced_base = base_output[..., base_slice]\n\n        if base_output.is_contiguous():\n            kernel_output = sliced_base.flatten(0, -2)\n            cum = True\n        else:\n            kernel_output = None\n            cum = False\n        lora_out = 
fused_lora(\n            lora_input.x,\n            lora_A,\n            lora_B,\n            scaling=adapter_info.scalings,\n            rank_start=adapter_info.rank_offsets,\n            ranks=adapter_info.ranks,\n            seq_start=lora_input.q_start_loc,\n            seq_lens=lora_input.q_seqlens,\n            adapter_ids=lora_input.adapter_ids,\n            max_rank=adapter_info.max_rank,\n            max_seqlen=lora_input.max_seq_len,\n            output=kernel_output,\n            cum=cum,\n        )\n\n        if not base_output.is_contiguous():\n            lora_out = lora_out.reshape(sliced_base.shape)\n            sliced_base.add_(lora_out)\n        return base_output\n\n\nclass TritonLoRABuilder(LoRABuilder):\n    \"\"\"Triton lora layer builder.\"\"\"\n\n    @staticmethod\n    def build():\n        \"\"\"build.\"\"\"\n        return TritonLoRAImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/moe/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .blocked_fp8 import TritonFusedMoEBlockedF8Builder  # noqa: F401\nfrom .default import TritonFusedMoEBuilder  # noqa: F401\nfrom .w8a8 import TritonFusedMoEW8A8Builder  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Callable, List\n\nimport torch\nimport torch.distributed as dist\n\nfrom lmdeploy.pytorch.backends.deepep_moe_checker import get_moe_backend\nfrom lmdeploy.pytorch.backends.moe import FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl\nfrom lmdeploy.pytorch.distributed import get_dist_manager\nfrom lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8\nfrom lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8\nfrom lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize\nfrom lmdeploy.pytorch.model_inputs import get_step_ctx_manager\nfrom lmdeploy.utils import get_logger\n\nfrom .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp\n\nlogger = get_logger('lmdeploy')\n\n\nclass TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):\n    \"\"\"Triton fused moe blocked f8 implementation.\"\"\"\n\n    def __init__(self,\n                 top_k: int,\n                 num_experts: int,\n                 renormalize: bool = False,\n                 block_size: int = 128,\n                 out_dtype: torch.dtype = torch.float16):\n        super().__init__()\n        self.num_experts = num_experts\n        self.top_k = top_k\n        self.renormalize = renormalize\n        self.block_size = block_size\n        self.out_dtype = out_dtype\n\n    def ep_expert_list(self, world_size: int, rank: int):\n        \"\"\"Experts list of current rank.\"\"\"\n        num_experts = self.num_experts\n        expert_per_rank = (num_experts + world_size - 1) // world_size\n        first_expert = rank * expert_per_rank\n        last_expert = min(first_expert + expert_per_rank, num_experts)\n        return list(range(first_expert, last_expert))\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n       
         gate_up_scale: torch.Tensor,\n                down_weights: torch.Tensor,\n                down_scale: torch.Tensor,\n                gate_up_bias: torch.Tensor = None,\n                down_bias: torch.Tensor = None,\n                expert_list: List[int] = None,\n                act_func: Callable = None):\n        \"\"\"forward.\"\"\"\n        input_size = hidden_states.shape\n        hidden_states = hidden_states.flatten(0, -2)\n        input_quant, input_scale = quant_fp8(hidden_states,\n                                             self.block_size,\n                                             dtype=gate_up_weights.dtype,\n                                             scale_fmt=self.scale_fmt)\n        expert_offset = 0\n        num_experts = None\n        if expert_list is not None and len(expert_list) != self.num_experts:\n            expert_offset = expert_list[0]\n            num_experts = self.num_experts\n        output = fused_moe_blocked_fp8(input_quant,\n                                       input_scale,\n                                       gate_up_weights,\n                                       gate_up_scale,\n                                       down_weights,\n                                       down_scale,\n                                       topk_weights=topk_weights,\n                                       topk_ids=topk_ids,\n                                       topk=self.top_k,\n                                       w1_bias=gate_up_bias,\n                                       w2_bias=down_bias,\n                                       out_dtype=hidden_states.dtype,\n                                       expert_offset=expert_offset,\n                                       num_experts=num_experts,\n                                       renormalize=self.renormalize,\n                                       act_func=act_func)\n        output = output.unflatten(0, input_size[:-1])\n        return output\n\n\nclass 
FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):\n\n    def __init__(self,\n                 ep_size: int,\n                 ep_group: dist.ProcessGroup,\n                 top_k: int,\n                 num_experts: int,\n                 hidden_dim: int,\n                 renormalize: bool = False,\n                 block_size: int = 128,\n                 out_dtype: torch.dtype = torch.bfloat16,\n                 layer_idx: int = 0):\n        super().__init__(top_k, num_experts, renormalize, block_size, out_dtype)\n        self.num_experts = num_experts\n        self.ep_size = ep_size\n        self.ep_group = ep_group\n        self.hidden_dim = hidden_dim\n        self.block_size = block_size\n        self.out_dtype = out_dtype\n        self.layer_idx = layer_idx\n        try:\n            import deep_gemm  # noqa: F401\n            self.use_deep_gemm = True\n        except ImportError:\n            self.use_deep_gemm = False\n            logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')\n\n        try:\n            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep  # noqa: F401\n            get_moe_backend().set_deepep_moe_backend()\n            if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):\n                DeepEPBuffer.set_explicitly_destroy()\n        except ImportError:\n            logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')\n\n        # pre-allocate buffer\n        self.fusedmoe_build(True)\n\n    def ep_expert_list(self, world_size: int, rank: int):\n        \"\"\"Experts list of current rank.\"\"\"\n        if get_dist_manager().current_context().dist_config.enable_eplb:\n            from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer\n            phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx)\n            expert_per_rank = (self.num_experts + world_size - 1) // 
world_size\n            first_expert = rank * expert_per_rank\n            last_expert = min(first_expert + expert_per_rank, self.num_experts)\n            sliced_phy2log = phy2log[first_expert:last_expert].tolist()\n            return sliced_phy2log\n        else:\n            return super().ep_expert_list(world_size=world_size, rank=rank)\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n                gate_up_scale: torch.Tensor,\n                down_weights: torch.Tensor,\n                down_scale: torch.Tensor,\n                gate_up_bias: torch.Tensor = None,\n                down_bias: torch.Tensor = None,\n                expert_list: List[int] = None,\n                act_func: Callable = None,\n                **kwargs):\n        \"\"\"forward.\"\"\"\n        hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights,\n                                                                                    topk_ids)\n\n        topk_weights = self.do_renormalize(topk_weights)\n        step_ctx = get_step_ctx_manager().current_context()\n        low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm\n        moe = self.fusedmoe_build(low_latency_mode)\n        out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,\n                                 down_scale, expert_list)\n\n        out_states = gather_outputs_by_attn_tp(out_states, split_size)\n        return out_states\n\n    def do_renormalize(self, topk_weights):\n        return _renormalize(topk_weights, self.renormalize)\n\n    def fusedmoe_build(self, low_latency_mode: bool = False):\n        from dlblas.layers.moe.ep_moe import build_deepep_moe\n        deepep_moe = build_deepep_moe(low_latency_mode,\n                                   
   self.ep_size,\n                                      self.ep_group,\n                                      self.num_experts,\n                                      self.hidden_dim,\n                                      self.block_size,\n                                      self.top_k,\n                                      self.out_dtype,\n                                      layer_idx=self.layer_idx,\n                                      chunk_size=16 * 1024)\n        return deepep_moe\n\n\nclass TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):\n    \"\"\"Triton fused moe blocked f8 builder.\"\"\"\n\n    @staticmethod\n    def build(top_k: int,\n              num_experts: int,\n              hidden_dim: int = 1,\n              renormalize: bool = False,\n              block_size: int = 128,\n              ep_size: int = 1,\n              ep_group: dist.ProcessGroup = None,\n              out_dtype: torch.dtype = torch.float16,\n              layer_idx: int = 0,\n              custom_gateup_act: bool = False):\n        \"\"\"Build from mlp.\"\"\"\n        if ep_size > 1:\n            assert custom_gateup_act is False, 'Custom gate up activation is not supported in EP MoE.'\n            return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size,\n                                               ep_group=ep_group,\n                                               top_k=top_k,\n                                               num_experts=num_experts,\n                                               hidden_dim=hidden_dim,\n                                               renormalize=renormalize,\n                                               block_size=block_size,\n                                               out_dtype=out_dtype,\n                                               layer_idx=layer_idx)\n        else:\n            return TritonFusedMoEBlockedF8Impl(top_k=top_k,\n                                               num_experts=num_experts,\n                       
                        renormalize=renormalize,\n                                               block_size=block_size,\n                                               out_dtype=out_dtype)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/moe/default.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Callable, List, Optional\n\nimport torch\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.backends.deepep_moe_checker import get_moe_backend\nfrom lmdeploy.pytorch.backends.moe import FusedMoEBuilder, FusedMoEImpl\nfrom lmdeploy.pytorch.distributed import get_dist_manager\nfrom lmdeploy.pytorch.kernels.cuda import fused_moe\nfrom lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize\nfrom lmdeploy.pytorch.model_inputs import get_step_ctx_manager\nfrom lmdeploy.utils import get_logger\n\nfrom .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp\n\nlogger = get_logger('lmdeploy')\n\n\nclass TritonFusedMoEImpl(FusedMoEImpl):\n    \"\"\"Triton fused moe implementation.\"\"\"\n\n    def __init__(self, top_k: int, num_experts: int, renormalize: bool = False):\n        self.num_experts = num_experts\n        self.top_k = top_k\n        self.renormalize = renormalize\n\n    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):\n        gate_up_weights = gate_up_weights.transpose(1, 2).contiguous().transpose(1, 2)\n        down_weights = down_weights.transpose(1, 2).contiguous().transpose(1, 2)\n        return gate_up_weights, down_weights\n\n    def ep_expert_list(self, world_size: int, rank: int):\n        \"\"\"Experts list of current rank.\"\"\"\n        num_experts = self.num_experts\n        expert_per_rank = (num_experts + world_size - 1) // world_size\n        first_expert = rank * expert_per_rank\n        last_expert = min(first_expert + expert_per_rank, num_experts)\n        return list(range(first_expert, last_expert))\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n                down_weights: torch.Tensor,\n                gate_up_bias: 
torch.Tensor = None,\n                down_bias: torch.Tensor = None,\n                expert_list: List[int] = None,\n                act_func: Callable = None):\n        \"\"\"forward.\"\"\"\n        expert_offset = 0\n        num_experts = None\n        if expert_list is not None and len(expert_list) != self.num_experts:\n            expert_offset = expert_list[0]\n            num_experts = self.num_experts\n        return fused_moe(hidden_states,\n                         gate_up_weights,\n                         down_weights,\n                         topk_weights=topk_weights,\n                         topk_ids=topk_ids,\n                         topk=self.top_k,\n                         w1_bias=gate_up_bias,\n                         w2_bias=down_bias,\n                         expert_offset=expert_offset,\n                         num_experts=num_experts,\n                         renormalize=self.renormalize,\n                         act_func=act_func)\n\n\n# modify from dlblas: https://github.com/DeepLink-org/DLBlas\nclass FusedMoENormal:\n\n    def __init__(\n        self,\n        ep_size: int,\n        ep_group: dist.ProcessGroup,\n        num_experts: int,\n        hidden_dim: int,\n        layer_index: int = 0,\n        top_k: int = 8,\n        out_dtype: torch.dtype = torch.bfloat16,\n    ):\n        from dlblas.layers.moe.token_dispatcher import DeepEPTokenDispatcherNormal\n        self.layer_index = layer_index\n        self.top_k = top_k\n        self.num_experts = num_experts\n        self.num_local_experts = num_experts // ep_size\n        self.out_dtype = out_dtype\n        self.token_dispatcher = DeepEPTokenDispatcherNormal(\n            group=ep_group,\n            num_experts=num_experts,\n            num_local_experts=self.num_local_experts,\n            hidden_size=hidden_dim,\n            params_dtype=out_dtype,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        topk_weights: torch.Tensor,\n   
     topk_ids: torch.LongTensor,\n        up_weights: torch.Tensor,\n        down_weights: torch.Tensor,\n        expert_list: List[int] = None,\n    ):\n        \"\"\"forward.\"\"\"\n        from lmdeploy.pytorch.kernels.cuda.fused_moe_ep import fused_moe_v3\n        x, recv_topk_ids, recv_topk_weights, recv_tokens_per_expert = self.token_dispatcher.dispatch(\n            hidden_states,\n            topk_ids,\n            topk_weights,\n            expert_list,\n        )\n        topk_ids, topk_weights = None, None\n        out_states = fused_moe_v3(x, recv_topk_ids, recv_topk_weights, up_weights, down_weights, recv_tokens_per_expert)\n        out_states = self.token_dispatcher.combine(out_states)\n        return out_states\n\n    def capture(self):\n        return self.token_dispatcher.buffer_normal.capture()\n\n    def wait(self, event):\n        self.token_dispatcher.release()\n        event.current_stream_wait()\n\n    def dispatch_async(self,\n                       x: torch.Tensor,\n                       topk_idx: torch.Tensor,\n                       topk_weights: torch.Tensor,\n                       num_experts: Optional[int] = None,\n                       previous_event=None,\n                       async_finish=True):\n        return self.token_dispatcher.dispatch_normal_async(x, topk_idx, topk_weights, num_experts, previous_event,\n                                                           async_finish)\n\n    def combine_async(self, x: torch.Tensor, handle: tuple, previous_event=None, async_finish=True):\n        return self.token_dispatcher.combine_normal_async(x, handle, previous_event, async_finish)\n\n    def release(self):\n        return self.token_dispatcher.release()\n\n    def fusedmoe_forward(self, state, up_weight, down_weight):\n        from lmdeploy.pytorch.kernels.cuda.fused_moe_ep import fused_moe_v3\n        return fused_moe_v3(state['recv_hidden_states'], state['recv_topk_idx'], state['recv_topk_weights'], up_weight,\n              
              down_weight, state['recv_tokens_per_expert'])\n\n\ndef _disposible_tensor(tensor):\n    from dlblas.utils.utils import DisposibleTensor\n    if isinstance(tensor, torch.Tensor):\n        tensor = DisposibleTensor(tensor)\n    else:\n        tensor = [DisposibleTensor(x) for x in tensor]\n    return tensor\n\n\ndef dispatch_ll(\n    self,\n    hidden_states: torch.Tensor,\n    topk_idx: torch.Tensor,\n    topk_weights: torch.Tensor,\n    num_experts: int,\n    use_fp8: bool = True,\n):\n    \"\"\"Dispatch low latency.\"\"\"\n    if num_experts is not None and self.num_experts is not None:\n        assert self.num_experts == num_experts\n    topk_idx = topk_idx.to(torch.int64)\n    expected_m = (hidden_states.shape[0] * self.get_buffer().group_size * topk_idx.shape[1] +\n                  num_experts) // num_experts\n\n    (\n        packed_recv_hidden,\n        masked_m,\n        self.handle,\n        event,\n        hook,\n    ) = self.get_buffer().low_latency_dispatch(\n        hidden_states,\n        topk_idx,\n        self.num_max_dispatch_tokens_per_rank,\n        num_experts,\n        use_fp8=use_fp8,\n        async_finish=not self.return_recv_hook,\n        return_recv_hook=self.return_recv_hook,\n    )\n    hook() if self.return_recv_hook else event.current_stream_wait()\n    packed_recv_hidden = _disposible_tensor(packed_recv_hidden)\n    return (\n        packed_recv_hidden,\n        topk_idx,\n        topk_weights,\n        masked_m,\n        expected_m,\n    )\n\n\ndef dispatch_async_ll(\n    self,\n    hidden_states: torch.Tensor,\n    topk_idx: torch.Tensor,\n    num_experts: Optional[int] = None,\n    use_fp8: bool = True,\n    async_finish: bool = True,\n):\n    assert topk_idx.dtype == torch.int64\n    if num_experts is not None and self.num_experts is not None:\n        assert self.num_experts == num_experts\n    (\n        recv_hidden_states,\n        recv_expert_count,\n        handle,\n        event,\n        hook,\n    ) = 
self.get_buffer().low_latency_dispatch(\n        hidden_states,\n        topk_idx,\n        self.num_max_dispatch_tokens_per_rank,\n        num_experts=self.num_experts,\n        use_fp8=use_fp8,\n        async_finish=async_finish,\n        return_recv_hook=not async_finish,\n    )\n    recv_hidden_states = _disposible_tensor(recv_hidden_states)\n    return recv_hidden_states, recv_expert_count, handle, event, hook\n\n\nclass FusedMoELowLatency:\n\n    def __init__(\n        self,\n        ep_size: int,\n        ep_group: dist.ProcessGroup,\n        num_experts: int,\n        hidden_dim: int,\n        layer_index: int,\n        out_dtype: torch.dtype = torch.bfloat16,\n    ):\n        from dlblas.layers.moe.token_dispatcher import DeepEPTokenDispatcherLowLatency\n        self.num_experts = num_experts\n        self.layer_index = layer_index\n        self.out_dtype = out_dtype\n        self.token_dispatcher = DeepEPTokenDispatcherLowLatency(\n            group=ep_group,\n            num_experts=num_experts,\n            num_local_experts=num_experts // ep_size,\n            hidden_size=hidden_dim,\n            params_dtype=out_dtype,\n        )\n\n    def experts(\n        self,\n        hidden_states: torch.Tensor,\n        gate_up_weight: torch.Tensor,\n        gate_down_weight: torch.Tensor,\n        masked_m: torch.Tensor,\n        expected_m: int,\n    ):\n        from dlblas.utils.utils import DisposibleTensor\n\n        from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul_moe_ep\n        from lmdeploy.pytorch.third_party.deep_gemm import m_grouped_bf16_gemm_nt_masked\n        num_groups, m, _ = hidden_states.shape\n        n = gate_up_weight.size(1)\n        expected_m = min(expected_m, m)\n        gateup_output = gate_up_weight.new_empty((num_groups, m, n))\n        m_grouped_bf16_gemm_nt_masked(DisposibleTensor.maybe_unwrap(hidden_states), gate_up_weight, gateup_output,\n                                      masked_m, expected_m)\n        
DisposibleTensor.maybe_dispose(hidden_states)\n        down_input = silu_and_mul_moe_ep(gateup_output, masked_m)\n        del gateup_output\n        n = gate_down_weight.size(1)\n        down_output = down_input.new_empty((num_groups, m, n))\n        m_grouped_bf16_gemm_nt_masked(down_input, gate_down_weight, down_output, masked_m, expected_m)\n        return down_output\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                up_weights: torch.Tensor,\n                down_weights: torch.Tensor,\n                expert_list: List[int] = None):\n        \"\"\"forward.\"\"\"\n        recv_hidden_states, topk_idx, topk_weights, masked_m, expected_m = dispatch_ll(\n            self.token_dispatcher,\n            hidden_states,\n            topk_ids,\n            topk_weights,\n            self.num_experts,\n            use_fp8=False,\n        )\n        hidden_states = None\n        out_states = self.experts(recv_hidden_states, up_weights, down_weights, masked_m, expected_m)\n        out_states = self.token_dispatcher.combine(out_states, topk_idx, topk_weights)\n        return out_states\n\n    def wait(self, event):\n        event.current_stream_wait()\n\n    def dispatch_async(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        num_experts: Optional[int] = None,\n        use_fp8: bool = False,\n        async_finish: bool = True,\n    ):\n        return dispatch_async_ll(self.token_dispatcher, hidden_states, topk_idx, num_experts, use_fp8, async_finish)\n\n    def combine_async(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_weights: torch.Tensor,\n        handle: tuple,\n        async_finish: bool,\n    ):\n        return self.token_dispatcher.combine_async(hidden_states, topk_idx, topk_weights, handle, async_finish)\n\n    def fusedmoe_forward(self, 
state, up_weight, down_weight):\n        recv_hidden_states = state['recv_hidden_states']\n        masked_m = state['recv_expert_count']\n        hidden_shape = state['raw_hidden_shape']\n        topk_idx = state['topk_idx']\n        expected_m = (hidden_shape[0] * self.token_dispatcher.buffer_low_latency.group_size * topk_idx.shape[1] +\n                      self.token_dispatcher.num_experts) // self.token_dispatcher.num_experts\n        return self.experts(recv_hidden_states, up_weight, down_weight, masked_m, expected_m)\n\n\ndef build_deepep_moe(\n    low_latency_mode: bool,\n    ep_size: int,\n    ep_group: dist.ProcessGroup,\n    num_experts: int,\n    hidden_dim: int,\n    top_k: int,\n    layer_idx: int = 0,\n    out_dtype: torch.dtype = torch.bfloat16,\n):\n    if low_latency_mode:\n        return FusedMoELowLatency(ep_size=ep_size,\n                                  ep_group=ep_group,\n                                  num_experts=num_experts,\n                                  hidden_dim=hidden_dim,\n                                  layer_index=layer_idx,\n                                  out_dtype=out_dtype)\n    else:\n        return FusedMoENormal(ep_size=ep_size,\n                              ep_group=ep_group,\n                              num_experts=num_experts,\n                              hidden_dim=hidden_dim,\n                              layer_index=layer_idx,\n                              top_k=top_k,\n                              out_dtype=out_dtype)\n\n\nclass FusedMoEEPImpl(TritonFusedMoEImpl):\n    \"\"\"Fused moe implementation.\"\"\"\n\n    def __init__(\n        self,\n        ep_size: int,\n        ep_group: dist.ProcessGroup,\n        top_k: int,\n        num_experts: int,\n        hidden_dim: int,\n        renormalize: bool = False,\n        layer_idx: int = 0,\n        out_dtype: torch.dtype = torch.bfloat16,\n    ):\n        super().__init__(top_k, num_experts, renormalize)\n        self.num_experts = num_experts\n       
 self.ep_size = ep_size\n        self.ep_group = ep_group\n        self.hidden_dim = hidden_dim\n        self.layer_idx = layer_idx\n        self.out_dtype = out_dtype\n\n        try:\n            import deep_gemm  # noqa: F401\n        except ImportError:\n            logger.exception('DeepGEMM is required for DeepEP MoE implementation.')\n\n        try:\n            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep  # noqa: F401\n            get_moe_backend().set_deepep_moe_backend()\n            if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):\n                DeepEPBuffer.set_explicitly_destroy()\n        except ImportError:\n            logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')\n\n        # pre-allocate buffer\n        self.fusedmoe_build(True)\n\n    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):\n        return gate_up_weights, down_weights\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n                down_weights: torch.Tensor,\n                gate_up_bias: torch.Tensor = None,\n                down_bias: torch.Tensor = None,\n                expert_list: List[int] = None,\n                act_func: Callable = None):\n        \"\"\"forward.\"\"\"\n        assert act_func is None, 'Activation function is not supported in DeepEP MoE.'\n        hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights,\n                                                                                    topk_ids)\n\n        topk_weights = self.do_renormalize(topk_weights)\n        step_ctx = get_step_ctx_manager().current_context()\n        low_latency_mode = step_ctx.is_decoding\n        moe = self.fusedmoe_build(low_latency_mode)\n      
  out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, down_weights, expert_list)\n\n        out_states = gather_outputs_by_attn_tp(out_states, split_size)\n        return out_states\n\n    def ep_expert_list(self, world_size: int, rank: int):\n        \"\"\"Experts list of current rank.\"\"\"\n        if get_dist_manager().current_context().dist_config.enable_eplb:\n            raise NotImplementedError('enable_eplb is not implemented for float16/bfloat16 MoE.')\n        else:\n            return super().ep_expert_list(world_size=world_size, rank=rank)\n\n    def do_renormalize(self, topk_weights):\n        return _renormalize(topk_weights, self.renormalize)\n\n    def fusedmoe_build(self, low_latency_mode: bool = False):\n        deepep_moe = build_deepep_moe(low_latency_mode,\n                                      self.ep_size,\n                                      self.ep_group,\n                                      self.num_experts,\n                                      self.hidden_dim,\n                                      self.top_k,\n                                      layer_idx=self.layer_idx,\n                                      out_dtype=self.out_dtype)\n        return deepep_moe\n\n\nclass TritonFusedMoEBuilder(FusedMoEBuilder):\n    \"\"\"Triton fused moe builder.\"\"\"\n\n    @staticmethod\n    def build(\n        top_k: int,\n        num_experts: int,\n        renormalize: bool = False,\n        hidden_dim: int = 1,\n        ep_size: int = 1,\n        ep_group: dist.ProcessGroup = None,\n        layer_idx: int = 0,\n        out_dtype: torch.dtype = torch.bfloat16,\n    ):\n        \"\"\"Build from mlp.\"\"\"\n        if ep_size > 1:\n            return FusedMoEEPImpl(ep_size=ep_size,\n                                  ep_group=ep_group,\n                                  top_k=top_k,\n                                  num_experts=num_experts,\n                                  hidden_dim=hidden_dim,\n                       
           renormalize=renormalize,\n                                  layer_idx=layer_idx,\n                                  out_dtype=out_dtype)\n        return TritonFusedMoEImpl(top_k=top_k, num_experts=num_experts, renormalize=renormalize)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/moe/ep_utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List\n\nimport torch\nfrom torch import distributed as dist\n\nfrom lmdeploy.pytorch.distributed import get_dist_manager\n\n\ndef split_inputs_by_attn_tp(\n    hidden_states: torch.Tensor,\n    topk_weights: torch.Tensor,\n    topk_ids: torch.Tensor,\n):\n    \"\"\"Split input by attn tp.\"\"\"\n    dist_ctx = get_dist_manager().current_context()\n    attn_tp = dist_ctx.dist_config.attn_tp\n    attn_rank = dist_ctx.attn_tp_group.rank\n    num_states = hidden_states.size(0)\n\n    if attn_tp == 1 or attn_tp > num_states:\n        return hidden_states, topk_weights, topk_ids, None\n\n    # split size\n    base = num_states // attn_tp\n    remain = num_states % attn_tp\n    split_size = [base + 1] * remain + [base] * (attn_tp - remain)\n\n    # split inputs\n    hidden_states = torch.split(hidden_states, split_size, dim=0)[attn_rank]\n    topk_weights = torch.split(topk_weights, split_size, dim=0)[attn_rank]\n    topk_ids = torch.split(topk_ids, split_size, dim=0)[attn_rank]\n\n    return hidden_states, topk_weights, topk_ids, split_size\n\n\ndef gather_outputs_by_attn_tp(out_states: torch.Tensor, split_size: List[int]):\n    \"\"\"Gather output by attn tp.\"\"\"\n    if split_size is None:\n        return out_states\n\n    dist_ctx = get_dist_manager().current_context()\n    gpu_group = dist_ctx.attn_tp_group.gpu_group\n    new_out_states = out_states.new_empty((sum(split_size), out_states.shape[1]))\n    new_out_states_list = list(new_out_states.split(split_size, dim=0))\n    dist.all_gather(new_out_states_list, out_states, group=gpu_group)\n    return new_out_states\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/moe/w8a8.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import List\n\nimport torch\n\nfrom lmdeploy.pytorch.backends.moe import FusedMoEW8A8Builder, FusedMoEW8A8Impl\nfrom lmdeploy.pytorch.kernels.cuda import fused_moe_w8a8\nfrom lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8\nfrom lmdeploy.pytorch.models.q_modules import QTensor\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl):\n    \"\"\"Triton fused moe w8a8 implementation.\"\"\"\n\n    def __init__(\n        self,\n        top_k: int,\n        num_experts: int,\n        renormalize: bool = False,\n        out_dtype: torch.dtype = torch.float16,\n        quant_dtype: torch.dtype = torch.int8,\n    ):\n        self.num_experts = num_experts\n        self.top_k = top_k\n        self.renormalize = renormalize\n        self.out_dtype = out_dtype\n        self.quant_dtype = quant_dtype\n\n    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, gate_up_scale: torch.Tensor,\n                       down_scale: torch.Tensor):\n        # do not transpose weight for int8/fp8\n        return gate_up_weights, down_weights, gate_up_scale, down_scale\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n                gate_up_scale: torch.Tensor,\n                down_weights: torch.Tensor,\n                down_scale: torch.Tensor,\n                expert_list: List[int] = None):\n        \"\"\"forward.\"\"\"\n\n        if isinstance(hidden_states, torch.Tensor):\n            hidden_states = hidden_states.contiguous()\n            input_quant, input_scale = per_token_quant_int8(hidden_states, 1e-7, quant_dtype=self.quant_dtype)\n        else:\n            assert isinstance(hidden_states, QTensor)\n            
input_quant, input_scale = (hidden_states.tensor, hidden_states.scale)\n\n        expert_offset = 0\n        num_experts = None\n        if expert_list is not None and len(expert_list) != self.num_experts:\n            expert_offset = expert_list[0]\n            num_experts = self.num_experts\n        return fused_moe_w8a8(input_quant,\n                              input_scale,\n                              gate_up_weights,\n                              gate_up_scale,\n                              down_weights,\n                              down_scale,\n                              topk_weights=topk_weights,\n                              topk_ids=topk_ids,\n                              topk=self.top_k,\n                              out_dtype=self.out_dtype,\n                              quant_dtype=self.quant_dtype,\n                              expert_offset=expert_offset,\n                              num_experts=num_experts,\n                              renormalize=self.renormalize)\n\n\nclass TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder):\n    \"\"\"Triton fused moe w8a8 builder.\"\"\"\n\n    @staticmethod\n    def build(\n        top_k: int,\n        num_experts: int,\n        renormalize: bool = False,\n        out_dtype: torch.dtype = torch.float16,\n        quant_dtype: torch.dtype = torch.int8,\n    ):\n        \"\"\"Build from mlp.\"\"\"\n        return TritonFusedMoEW8A8Impl(top_k=top_k,\n                                      num_experts=num_experts,\n                                      renormalize=renormalize,\n                                      out_dtype=out_dtype,\n                                      quant_dtype=quant_dtype)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/moe_router.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport torch\n\nfrom lmdeploy.pytorch.kernels.cuda.fused_noaux_tc import fused_noaux_tc_routing\n\nfrom ..default.moe_router import DefaultRouterNoauxTCImpl\nfrom ..moe_router import RouterNoauxTCBuilder, RouterNoauxTCImpl\n\n\ndef is_power_of_two(n):\n    return n > 0 and (n & (n - 1)) == 0\n\n\nclass TritonRouterNoauxTCImpl(DefaultRouterNoauxTCImpl):\n\n    def __init__(\n        self,\n        scoring_func: str,\n        top_k: int,\n        n_group: int,\n        topk_group: int,\n        n_routed_experts: int,\n        routed_scaling_factor: float,\n        renormalize: bool = True,\n        router_n_groups: int = -1,\n    ):\n        super().__init__(\n            scoring_func=scoring_func,\n            top_k=top_k,\n            n_group=n_group,\n            topk_group=topk_group,\n            n_routed_experts=n_routed_experts,\n            routed_scaling_factor=routed_scaling_factor,\n            renormalize=renormalize,\n            router_n_groups=router_n_groups,\n        )\n\n        self.enable_custom_kernel = self.should_enable_custom_kernel()\n\n    def should_enable_custom_kernel(self) -> bool:\n        if self.router_n_groups > 0:\n            return False\n\n        if self.scoring_func != 'sigmoid':\n            return False\n\n        if self.n_routed_experts % 32 != 0:\n            return False\n\n        if not is_power_of_two(self.n_routed_experts):\n            return False\n\n        if not is_power_of_two(self.n_group):\n            return False\n\n        return True\n\n    def forward(self, logits: torch.Tensor, bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Router forward.\"\"\"\n        if self.enable_custom_kernel:\n            return fused_noaux_tc_routing(\n                logits,\n                bias,\n                num_experts=self.n_routed_experts,\n                n_group=self.n_group,\n                
topk_group=self.topk_group,\n                top_k=self.top_k,\n                renormalize=self.renormalize,\n                routed_scaling_factor=self.routed_scaling_factor,\n            )\n        else:\n            return super().forward(logits, bias)\n\n\nclass TritonRouterNoauxTCBuilder(RouterNoauxTCBuilder):\n\n    @staticmethod\n    def build(\n        scoring_func: str,\n        top_k: int,\n        n_group: int,\n        topk_group: int,\n        n_routed_experts: int,\n        routed_scaling_factor: float,\n        renormalize: bool = True,\n        router_n_groups: int = -1,\n    ) -> RouterNoauxTCImpl:\n        return TritonRouterNoauxTCImpl(\n            scoring_func=scoring_func,\n            top_k=top_k,\n            n_group=n_group,\n            topk_group=topk_group,\n            n_routed_experts=n_routed_experts,\n            routed_scaling_factor=routed_scaling_factor,\n            renormalize=renormalize,\n            router_n_groups=router_n_groups,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/multinomial_sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport torch\n\nfrom lmdeploy.pytorch.kernels.cuda import multinomial_sampling\n\nfrom ..multinomial_sampling import MultinomialSamplingBuilder, MultinomialSamplingImpl\n\n\nclass TritonMultinomialSamplingImpl(MultinomialSamplingImpl):\n\n    def forward(self,\n                scores: torch.Tensor,\n                seeds: torch.LongTensor,\n                offsets: torch.LongTensor,\n                indices: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        return multinomial_sampling(scores, seeds, offsets, indices)\n\n\nclass TritonMultinomialSamplingBuilder(MultinomialSamplingBuilder):\n    \"\"\"Triton multinomial sampling builder.\"\"\"\n\n    @staticmethod\n    def build():\n        \"\"\"build.\"\"\"\n        return TritonMultinomialSamplingImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/norm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom lmdeploy.pytorch.kernels.cuda import rms_norm\n\nfrom ..norm import RMSNormBuilder, RMSNormImpl\n\n\nclass TritonRMSNormImpl(RMSNormImpl):\n    \"\"\"Triton RMS norm implementation.\"\"\"\n\n    def __init__(self, hidden_size: int, eps: float = 1e-6):\n        self.hidden_size = hidden_size\n        self.eps = eps\n\n    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        if residual is None:\n            x = rms_norm(x, weight, self.eps)\n            return x\n        else:\n            x, residual = rms_norm(x, weight, self.eps, residual=residual)\n            return x, residual\n\n\nclass TritonRMSNormBuilder(RMSNormBuilder):\n    \"\"\"Triton RMS norm implementation builder.\"\"\"\n\n    @staticmethod\n    def build(hidden_size: int, eps: float = 1e-6):\n        \"\"\"build.\"\"\"\n        return TritonRMSNormImpl(hidden_size, eps)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/nsa.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom torch import Tensor\n\nfrom lmdeploy.pytorch.kernels.cuda.bitonic_topk import bitonic_topk\nfrom lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8\nfrom lmdeploy.pytorch.kernels.cuda.ds_index import fp8_index\nfrom lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8\n\nfrom ..nsa import BaseNSAIndexFP8, BaseNSAIndexFP8Builder, NSAIndexMeta\n\n\nclass TritonNSAIndexFP8(BaseNSAIndexFP8):\n\n    def __init__(self, topk: int, softmax_scale: float, block_size: int, fill: int) -> None:\n        super().__init__()\n        self.topk = topk\n        self.softmax_scale = softmax_scale\n        self.block_size = block_size\n        self.fill = fill\n        # TODO: make scale fmt configurable\n        self.scale_fmt = 'ue8m0'\n\n    def forward(self, q: Tensor, k: Tensor, weights: Tensor, k_cache: Tensor, k_s_cache: Tensor,\n                meta: NSAIndexMeta) -> Tensor:\n\n        assert q.dim() == 3\n        assert k.dim() == 2\n        cu_seqlen_q = meta.cu_seqlen_q\n        q_seqlens = meta.q_seqlens\n        k_seqlens = meta.k_seqlens\n        block_offset = meta.block_offset\n        max_q_seqlen = meta.max_q_seqlen\n        max_kv_seqlen = meta.max_kv_seqlen\n\n        q_shape = q.shape\n        q = q.reshape(-1, q_shape[-1])\n        q, q_s = quant_fp8(q, self.block_size, dtype=k_cache.dtype, trans_scale=True, scale_fmt=self.scale_fmt)\n        q = q.reshape(*q_shape)\n        q_s = q_s.reshape(weights.shape)\n        q_s = q_s * self.softmax_scale * weights\n\n        fill_kv_cache_blocked_fp8(k[:, None],\n                                  None,\n                                  k_cache[..., None, :],\n                                  None,\n                                  k_s_cache[..., None, :],\n                                  None,\n                                  cu_seqlen_q=cu_seqlen_q,\n                                  kv_seqlens=k_seqlens,\n         
                         max_q_seqlen=max_q_seqlen,\n                                  block_offsets=block_offset,\n                                  group_size=self.block_size,\n                                  scale_fmt=self.scale_fmt)\n\n        scores = fp8_index(q,\n                           q_s,\n                           k_cache,\n                           k_s_cache[..., 0],\n                           cu_seqlen_q,\n                           k_seqlens,\n                           block_offset,\n                           max_q_seqlen=max_q_seqlen,\n                           max_k_seqlen=max_kv_seqlen,\n                           causal=True)\n        return bitonic_topk(scores, q_seqlens, k_seqlens, self.topk, fill=self.fill, descending=True)\n\n\nclass TritonNSAIndexFP8Builder(BaseNSAIndexFP8Builder):\n\n    @staticmethod\n    def build(topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1) -> BaseNSAIndexFP8:\n        return TritonNSAIndexFP8(topk, softmax_scale=softmax_scale, block_size=block_size, fill=fill)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/op_backend.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport torch\n\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig\nfrom lmdeploy.utils import get_logger\n\nfrom ..base import OpType\nfrom ..default import DefaultOpsBackend\n\nlogger = get_logger('lmdeploy')\n\n\nclass CudaOpsBackend(DefaultOpsBackend):\n    \"\"\"Cuda layer backend.\"\"\"\n\n    @staticmethod\n    def get_name() -> str:\n        \"\"\"Backend name.\"\"\"\n        return 'cuda'\n\n    @classmethod\n    def get_layer_impl_builder(cls, layer_type: OpType):\n        \"\"\"Get cuda layer builder.\"\"\"\n        if layer_type == OpType.PagedAttention:\n            from .attention import TritonAttentionBuilder\n            return TritonAttentionBuilder\n        elif layer_type == OpType.FlashAttention:\n            from .flash_attention import TritonFlashAttentionBuilder\n            return TritonFlashAttentionBuilder\n        elif layer_type == OpType.ApplyRotaryEmb:\n            from .apply_rotary_emb import TritonApplyRotaryEmbBuilder\n            return TritonApplyRotaryEmbBuilder\n        elif layer_type == OpType.RMSNorm:\n            from .norm import TritonRMSNormBuilder\n            return TritonRMSNormBuilder\n        elif layer_type == OpType.LoRA:\n            from .lora import TritonLoRABuilder\n            return TritonLoRABuilder\n        elif layer_type == OpType.LinearW8A8:\n            from .qmodules import TritonLinearW8A8Builder\n            return TritonLinearW8A8Builder\n        elif layer_type == OpType.RMSNormW8A8:\n            from .qmodules import TritonRMSNormBuilder\n            return TritonRMSNormBuilder\n        elif layer_type == OpType.MultinomialSampling:\n            from .multinomial_sampling import TritonMultinomialSamplingBuilder\n            return TritonMultinomialSamplingBuilder\n        elif layer_type == OpType.SiluAndMul:\n            from .activation import TritonSiluAndMulBuilder\n          
  return TritonSiluAndMulBuilder\n        elif layer_type == OpType.LinearW4A16:\n            from .awq_modules import AwqLinearW4A16Builder\n            return AwqLinearW4A16Builder\n        elif layer_type == OpType.FusedMoE:\n            from .moe import TritonFusedMoEBuilder\n            return TritonFusedMoEBuilder\n        elif layer_type == OpType.FusedMoEW8A8:\n            from .moe import TritonFusedMoEW8A8Builder\n            return TritonFusedMoEW8A8Builder\n        elif layer_type == OpType.FusedMoEBlockedF8:\n            from .moe import TritonFusedMoEBlockedF8Builder\n            return TritonFusedMoEBlockedF8Builder\n        elif layer_type == OpType.LinearBlockedF8:\n            from .blockedf8_modules import TritonLinearBlockedF8Builder\n            return TritonLinearBlockedF8Builder\n        elif layer_type == OpType.NSAIndexFP8:\n            from .nsa import TritonNSAIndexFP8Builder\n            return TritonNSAIndexFP8Builder\n        elif layer_type == OpType.RouterNoauxTC:\n            from .moe_router import TritonRouterNoauxTCBuilder\n            return TritonRouterNoauxTCBuilder\n        elif layer_type == OpType.CausalConv1d:\n            from .causal_conv1d import CausalConv1dCudaBuilder\n            return CausalConv1dCudaBuilder\n        elif layer_type == OpType.GatedDeltaRule:\n            from .gated_delta_rule import CudaGatedDeltaRuleBuilder\n            return CudaGatedDeltaRuleBuilder\n        else:\n            logger.debug(f'Op {layer_type} fallback to default implementation.')\n            return super().get_layer_impl_builder(layer_type)\n\n    @staticmethod\n    def get_attention_metadata_cls():\n        \"\"\"Get attention metadata class.\"\"\"\n        from .attention import TritonAttentionMetadata\n        return TritonAttentionMetadata\n\n    @staticmethod\n    def get_k_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n      
  \"\"\"Get k block shape.\"\"\"\n        return (\n            block_size,\n            num_heads,\n            head_size,\n        )\n\n    @staticmethod\n    def get_v_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        \"\"\"Get v block shape.\"\"\"\n        return (\n            block_size,\n            num_heads,\n            head_size,\n        )\n\n    @classmethod\n    def update_meta_flashmla(cls, attn_metadata, model_config: ModelConfig, decoding_query_len: int):\n        \"\"\"Update meta for flashmla.\"\"\"\n        import flash_mla\n        num_attention_heads = model_config.num_attention_heads * decoding_query_len\n        is_fp8_kvcache = model_config.use_mla_fp8_cache\n        index_topk = model_config.mla_index_topk\n        num_heads_q = None if index_topk is None else num_attention_heads\n        tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32),\n                                                                         num_attention_heads,\n                                                                         num_heads_k=1,\n                                                                         num_heads_q=num_heads_q,\n                                                                         is_fp8_kvcache=is_fp8_kvcache,\n                                                                         topk=index_topk)\n        attn_metadata.tile_scheduler_metadata = tile_scheduler_metadata\n        attn_metadata.num_splits = num_splits\n\n        if attn_metadata.block_offsets.dtype != torch.int32:\n            attn_metadata.block_offsets = attn_metadata.block_offsets.to(torch.int32)\n\n    @classmethod\n    def update_meta_flashattn(cls, attn_metadata, step_context):\n        from lmdeploy.pytorch.models.utils.cudagraph import _get_meta_flashattn\n        batch_size = 
attn_metadata.q_seqlens.size(0)\n        max_seqlen_q = step_context.input_ids.size(1) // batch_size\n        block_size = step_context.kv_caches[0][0].size(1)\n        window_size = (step_context.model_config.sliding_window, ) * 2\n        scheduler_metadata = _get_meta_flashattn(\n            batch_size=batch_size,\n            max_seqlen_q=max_seqlen_q,\n            max_seqlen_k=step_context.max_kv_seqlen,\n            num_heads_q=step_context.model_config.num_attention_heads,\n            num_heads_kv=step_context.model_config.num_key_value_heads,\n            headdim=step_context.model_config.head_dim,\n            cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),\n            qkv_dtype=step_context.model_config.dtype,\n            page_size=block_size,\n            window_size=window_size,\n        )\n        attn_metadata.scheduler_metadata = scheduler_metadata\n        attn_metadata.max_kv_seqlen = step_context.max_kv_seqlen\n        return attn_metadata\n\n    @classmethod\n    def update_step_context(cls, step_context):\n        \"\"\"Update step context.\"\"\"\n        attn_meta_cls = cls.get_attention_metadata_cls()\n        q_seqlens = step_context.q_seqlens\n        kv_seqlens = step_context.kv_seqlens\n        kv_start_loc = None\n        kv_flatten_size = None\n        use_flash_mla = step_context.model_config.use_flash_mla\n        use_flash_attn3_decoding = step_context.model_config.model_paradigm == 'ar_spec'\n\n        # separate pad + cumsum for q and kv seqlens would take 4 kernels; stack them to share one cumsum and one pad\n        seqlens = torch.stack([q_seqlens, kv_seqlens], dim=0)\n        cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=1, dtype=torch.int32), (1, 0))\n        cu_seqlens_q = cu_seqlens[0]\n        cu_seqlens_k = cu_seqlens[1]\n        q_start_loc = step_context.q_start_loc\n        if not step_context.is_decoding:\n            kv_start_loc = cu_seqlens_k[:-1].to(kv_seqlens.dtype)\n            kv_flatten_size = 
step_context.sum_kv_seqlen\n\n        attn_metadata = attn_meta_cls(\n            step_context.is_decoding,\n            step_context.block_offsets,\n            q_start_loc=q_start_loc,\n            q_seqlens=q_seqlens,\n            kv_start_loc=kv_start_loc,\n            kv_seqlens=kv_seqlens,\n            kv_flatten_size=kv_flatten_size,\n            quant_policy=step_context.kv_quant_policy,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            max_kv_seqlen=step_context.max_kv_seqlen,\n        )\n        if step_context.is_decoding:\n            if use_flash_mla:\n                model_config = step_context.model_config\n                decode_query_len = step_context.input_ids.size(1) // q_seqlens.size(0)\n                cls.update_meta_flashmla(attn_metadata, model_config, decode_query_len)\n            elif use_flash_attn3_decoding:\n                attn_metadata = cls.update_meta_flashattn(attn_metadata, step_context)\n\n        step_context.attn_metadata = attn_metadata\n        return step_context\n\n    @staticmethod\n    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,\n                           backend_config: BackendConfig, device: torch.device):\n        \"\"\"Build graph runner.\"\"\"\n        from .graph_runner import CUDAGraphRunner\n        from .warmup_manager import WarmupMeta, get_warmup_manager\n\n        # warmup ops.\n        warmup_meta = WarmupMeta(\n            max_num_tokens=cache_config.max_prefill_token_num,\n            max_batch_size=cache_config.max_batches,\n            dtype=model_config.dtype,\n        )\n        get_warmup_manager().warmup(warmup_meta)\n\n        # make graph runner.\n        return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)\n\n    @staticmethod\n    def device_count():\n        \"\"\"Get num available devices.\"\"\"\n        return torch.cuda.device_count()\n\n    @staticmethod\n    def 
support_ray():\n        \"\"\"Support ray.\"\"\"\n        return True\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/qmodules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional\n\nimport torch\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_token_quant_int8,\n                                                               rms_norm_dynamic_quant)\nfrom lmdeploy.pytorch.models.q_modules import QTensor\n\nfrom ..qmodules import LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, RMSNormW8A8Impl\n\n\nclass TritonRMSNormW8A8Impl(RMSNormW8A8Impl):\n    \"\"\"Triton RMS norm w8a8 implementation api.\"\"\"\n\n    def __init__(self, hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.eps = eps\n        self.quant_dtype = quant_dtype\n\n    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        if residual is None:\n            (x, rms_scale) = rms_norm_dynamic_quant(x, weight, self.eps, quant_dtype=self.quant_dtype)\n            x = QTensor(x, rms_scale)\n            return x\n        else:\n            (x, rms_scale, residual) = rms_norm_dynamic_quant(x,\n                                                              weight,\n                                                              self.eps,\n                                                              residual=residual,\n                                                              quant_dtype=self.quant_dtype)\n            x = QTensor(x, rms_scale)\n            return x, residual\n\n\nclass TritonRMSNormBuilder(RMSNormW8A8Builder):\n    \"\"\"Triton RMS norm w8a8 implementation builder.\"\"\"\n\n    @staticmethod\n    def build(hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):\n        \"\"\"build.\"\"\"\n        return TritonRMSNormW8A8Impl(hidden_size, eps, quant_dtype)\n\n\nclass 
TritonLinearW8A8Impl(LinearW8A8Impl):\n    \"\"\"Triton linear w8a8 implementation.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 out_dtype: torch.dtype = torch.float16,\n                 quant_dtype: torch.dtype = torch.int8):\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype\n        self.quant_dtype = quant_dtype\n\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                scale: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[torch.distributed.ProcessGroup] = None):\n        \"\"\"forward.\"\"\"\n        if isinstance(x, torch.Tensor):\n            input_quant, input_scale = per_token_quant_int8(x, 1e-7, quant_dtype=self.quant_dtype)\n        else:\n            assert isinstance(x, QTensor)\n            input_quant, input_scale = x.tensor, x.scale\n\n        out = matmul_kernel_dynamic_quant(input_quant,\n                                          weight,\n                                          input_scale,\n                                          scale,\n                                          output_dtype=self.out_dtype,\n                                          bias=bias)\n\n        if all_reduce:\n            dist.all_reduce(out, group=group)\n        return out\n\n\nclass TritonLinearW8A8Builder(LinearW8A8Builder):\n    \"\"\"Triton linear w8a8 implementation builder.\"\"\"\n\n    @staticmethod\n    def build(in_features: int,\n              out_features: int,\n              bias: bool = True,\n              dtype: torch.dtype = None,\n              quant_dtype: torch.dtype = torch.int8):\n        \"\"\"build.\"\"\"\n        return TritonLinearW8A8Impl(in_features, out_features, dtype, quant_dtype=quant_dtype)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/token_dispatcher.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\ntry:\n    from deep_ep import Buffer\n\n    from lmdeploy.pytorch.envs import deep_ep_buffer_num_sms\n\n    Buffer.set_num_sms(deep_ep_buffer_num_sms)\n    use_deepep = True\nexcept ImportError:\n    use_deepep = False\n\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.distributed as dist\n\nfrom ..default.token_dispatcher import AlltoAllTokenDispatcher\nfrom ..token_dispatcher import TokenDispatcherImpl\n\n_buffer_normal = None\n_buffer_low_latency = None\n_buffer_common = None\n\n\ndef get_buffer_common(\n    group: dist.ProcessGroup,\n    num_max_dispatch_tokens_per_rank: int,\n    hidden: int,\n    num_experts: int,\n    hidden_bytes: int,\n):\n    global _buffer_common\n    num_nvl_bytes, num_rdma_bytes = 0, 0\n    for config in (\n            Buffer.get_dispatch_config(group.size()),\n            Buffer.get_combine_config(group.size()),\n    ):\n        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)\n        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)\n\n    num_rdma_bytes = max(\n        Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts),\n        num_rdma_bytes)\n\n    if (_buffer_common is None or _buffer_common.group != group or _buffer_common.num_nvl_bytes < num_nvl_bytes\n            or _buffer_common.num_rdma_bytes < num_rdma_bytes):\n        _buffer_common = Buffer(\n            group,\n            num_nvl_bytes=num_nvl_bytes,\n            num_rdma_bytes=num_rdma_bytes,\n            low_latency_mode=True,\n            num_qps_per_rank=max(num_experts // group.size(), Buffer.num_sms // 2),\n        )\n    return _buffer_common\n\n\ndef get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):\n    \"\"\"Copy from DeepEP example usage in model inference prefilling.\n\n    
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling\n    \"\"\"\n    global _buffer_normal\n    num_nvl_bytes, num_rdma_bytes = 0, 0\n    for config in (\n            Buffer.get_dispatch_config(group.size()),\n            Buffer.get_combine_config(group.size()),\n    ):\n        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)\n        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)\n\n    if (_buffer_normal is None or _buffer_normal.group != group or _buffer_normal.num_nvl_bytes < num_nvl_bytes\n            or _buffer_normal.num_rdma_bytes < num_rdma_bytes):\n        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)\n    return _buffer_normal\n\n\ndef get_buffer_low_latency(\n    group: dist.ProcessGroup,\n    num_max_dispatch_tokens_per_rank: int,\n    hidden: int,\n    num_experts: int,\n):\n    \"\"\"Copy from DeepEP example usage in model inference decoding.\n\n    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding\n    \"\"\"\n\n    global _buffer_low_latency\n    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(),\n                                                           num_experts)\n\n    if (_buffer_low_latency is None or _buffer_low_latency.group != group or not _buffer_low_latency.low_latency_mode\n            or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes):\n        assert num_experts % group.size(\n        ) == 0, f'num_experts: {num_experts} must be divisible by ep_size: {group.size()}'\n        _buffer_low_latency = Buffer(\n            group,\n            num_rdma_bytes=num_rdma_bytes,\n            low_latency_mode=True,\n            num_qps_per_rank=max(num_experts // group.size(), Buffer.num_sms // 2),\n        )\n    return _buffer_low_latency\n\n\nclass 
DeepEPTokenDispatcher(TokenDispatcherImpl):\n    \"\"\"Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher.\n\n    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py\n    \"\"\"\n\n    def __init__(\n        self,\n        group: torch.distributed.ProcessGroup,\n        num_experts: int = None,\n        num_local_experts: int = None,\n        hidden_size: int = None,\n        params_dtype: torch.dtype = None,\n        num_max_dispatch_tokens_per_rank=128,\n    ):\n        self.group = group\n        self.num_experts = num_experts\n        self.num_local_experts = num_local_experts\n        self.hidden_size = hidden_size\n        self.params_bytes = params_dtype.itemsize\n        # Handle used for combine operation\n        self.handle = None\n        if not use_deepep:\n            raise ImportError('DeepEP is not installed. Please install DeepEP package from '\n                              'https://github.com/deepseek-ai/deepep.')\n        self.buffer_normal = get_buffer_common(self.group,\n                                               num_max_dispatch_tokens_per_rank,\n                                               self.hidden_size,\n                                               self.num_experts,\n                                               hidden_bytes=self.hidden_size * self.params_bytes)\n\n    def dispatch(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_weights: torch.Tensor,\n        expert_list: List[int] = None,\n        previous_event=None,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        self.hidden_shape = hidden_states.shape\n        topk_idx = topk_idx.to(torch.int64)\n        (\n            hidden_states,\n            topk_idx,\n            topk_weights,\n            recv_tokens_per_expert,\n            handle,\n            event,\n        ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, 
self.num_experts, previous_event)\n        self.tokens_per_expert = torch.tensor(\n            recv_tokens_per_expert,\n            device=hidden_states.device,\n            dtype=torch.int64,\n        )\n        tokens_per_expert = self.get_number_of_tokens_per_expert()\n        self.handle = handle\n        self.topk_idx = topk_idx\n        self.topk_weights = topk_weights\n        if hidden_states.shape[0] > 0:\n            hidden_states, _, _, _, _ = self.get_permuted_hidden_states_by_experts(hidden_states)\n        return hidden_states, topk_idx, topk_weights, tokens_per_expert\n\n    def dispatch_normal(\n        self,\n        x: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_weights: torch.Tensor,\n        num_experts: int,\n        previous_event=None,\n    ):\n        (\n            num_tokens_per_rank,\n            num_tokens_per_rdma_rank,\n            num_tokens_per_expert,\n            is_token_in_rank,\n            previous_event,\n        ) = self.buffer_normal.get_dispatch_layout(\n            topk_idx,\n            num_experts,\n            previous_event=previous_event,\n            async_finish=False,\n            allocate_on_comm_stream=False,\n        )\n\n        (\n            recv_x,\n            recv_topk_idx,\n            recv_topk_weights,\n            recv_tokens_per_expert,\n            handle,\n            event,\n        ) = self.buffer_normal.dispatch(\n            x,\n            topk_idx=topk_idx,\n            topk_weights=topk_weights.to(torch.float32),\n            num_tokens_per_rank=num_tokens_per_rank,\n            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,\n            is_token_in_rank=is_token_in_rank,\n            num_tokens_per_expert=num_tokens_per_expert,\n            previous_event=previous_event,\n            async_finish=False,\n            allocate_on_comm_stream=False,\n        )\n\n        return (\n            recv_x,\n            recv_topk_idx,\n            recv_topk_weights,\n            
recv_tokens_per_expert,\n            handle,\n            event,\n        )\n\n    def dispatch_normal_async(self,\n                              x: torch.Tensor,\n                              topk_idx: torch.Tensor,\n                              topk_weights: torch.Tensor,\n                              num_experts: Optional[int] = None,\n                              previous_event=None,\n                              async_finish=True):\n        (\n            num_tokens_per_rank,\n            num_tokens_per_rdma_rank,\n            num_tokens_per_expert,\n            is_token_in_rank,\n            previous_event,\n        ) = self.buffer_normal.get_dispatch_layout(\n            topk_idx,\n            num_experts=self.num_experts if num_experts is None else num_experts,\n            previous_event=previous_event,\n            async_finish=async_finish,\n            allocate_on_comm_stream=previous_event is not None and async_finish,\n        )\n\n        (\n            recv_x,\n            recv_topk_idx,\n            recv_topk_weights,\n            recv_tokens_per_expert,\n            handle,\n            event,\n        ) = self.buffer_normal.dispatch(\n            x,\n            topk_idx=topk_idx,\n            topk_weights=topk_weights,\n            num_tokens_per_rank=num_tokens_per_rank,\n            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,\n            is_token_in_rank=is_token_in_rank,\n            num_tokens_per_expert=num_tokens_per_expert,\n            previous_event=previous_event,\n            async_finish=async_finish,\n            allocate_on_comm_stream=previous_event is not None and async_finish,\n        )\n\n        return (\n            recv_x,\n            recv_topk_idx,\n            recv_topk_weights,\n            recv_tokens_per_expert,\n            handle,\n            event,\n        )\n\n    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        if hidden_states.shape[0] > 0:\n            hidden_states = 
self.get_restored_hidden_states_by_experts(hidden_states)\n        hidden_states, event = self.combine_normal(hidden_states, self.handle)\n        self.handle = None\n        return hidden_states.view(self.hidden_shape)\n\n    def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):\n        combined_x, _, event = self.buffer_normal.combine(\n            x,\n            handle,\n            async_finish=False,\n            previous_event=previous_event,\n            allocate_on_comm_stream=False,\n        )\n        return combined_x, event\n\n    def combine_normal_async(self, x: torch.Tensor, handle: Tuple, previous_event=None, async_finish=True):\n        combined_x, _, event = self.buffer_normal.combine(\n            x,\n            handle,\n            async_finish=async_finish,\n            previous_event=previous_event,\n            allocate_on_comm_stream=previous_event is not None and async_finish,\n        )\n        return combined_x, event\n\n    def release(self):\n        self.tokens_per_expert = None\n        self.handle = None\n        self.topk_idx = None\n        self.topk_weights = None\n        self.hidden_shape_before_permute = None\n        self.dispatched_routing_map = None\n        self.reversed_mapping_for_combine = None\n        return True\n\n    def get_number_of_tokens_per_expert(self) -> torch.Tensor:\n        \"\"\"Get the number of tokens per expert.\"\"\"\n        return self.tokens_per_expert\n\n    def get_permuted_hidden_states_by_experts(self,\n                                              hidden_states: torch.Tensor,\n                                              topk_idx: Optional[torch.Tensor] = None,\n                                              topk_weights: Optional[torch.Tensor] = None,\n                                              num_experts: Optional[int] = None) -> torch.Tensor:\n        (dispatched_routing_map,\n         topk_weights) = super().indices_to_multihot(self.topk_idx if topk_idx is 
None else topk_idx,\n                                                     self.topk_weights if topk_weights is None else topk_weights,\n                                                     self.num_experts if num_experts is None else num_experts)\n        hidden_states_shape = hidden_states.shape\n        (hidden_states, reversed_mapping_for_combine) = super().permute(\n            hidden_states,\n            dispatched_routing_map,\n        )\n        self.hidden_shape_before_permute = hidden_states_shape\n        self.dispatched_routing_map = dispatched_routing_map\n        self.topk_weights = topk_weights\n        self.reversed_mapping_for_combine = reversed_mapping_for_combine\n        return hidden_states, hidden_states_shape, dispatched_routing_map, topk_weights, reversed_mapping_for_combine\n\n    def get_restored_hidden_states_by_experts(\n        self,\n        hidden_states: torch.Tensor,\n        reversed_mapping_for_combine: Optional[torch.Tensor] = None,\n        hidden_shape_before_permute: Optional[torch.Size] = None,\n        dispatched_routing_map: Optional[torch.Tensor] = None,\n        topk_weights: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        input_dtype = hidden_states.dtype\n        # check the probs that are actually used; self.topk_weights may be None after release()\n        probs = self.topk_weights if topk_weights is None else topk_weights\n        assert probs.dtype == torch.float32, 'DeepEP only supports float32 probs'\n        hidden_states = super().unpermute(\n            hidden_states,\n            sorted_indices=self.reversed_mapping_for_combine\n            if reversed_mapping_for_combine is None else reversed_mapping_for_combine,\n            restore_shape=self.hidden_shape_before_permute\n            if hidden_shape_before_permute is None else hidden_shape_before_permute,\n            routing_map=self.dispatched_routing_map if dispatched_routing_map is None else dispatched_routing_map,\n            probs=probs,\n        )\n        return hidden_states.to(input_dtype)\n\n\nclass 
DeepEPTokenDispatcherLowLatency(TokenDispatcherImpl):\n\n    def __init__(\n        self,\n        group: torch.distributed.ProcessGroup,\n        num_experts: int = None,\n        num_local_experts: int = None,\n        hidden_size: int = None,\n        params_dtype: torch.dtype = None,\n        return_recv_hook: bool = False,\n    ):\n        if not use_deepep:\n            raise ImportError('DeepEP is not installed. Please install DeepEP package from '\n                              'https://github.com/deepseek-ai/deepep.')\n        self.group = group\n        self.num_experts = num_experts\n        self.num_local_experts = num_local_experts\n        self.hidden_size = hidden_size\n        self.params_bytes = params_dtype.itemsize\n        self.handle = None\n        self.num_max_dispatch_tokens_per_rank = 128\n        self.buffer_low_latency = get_buffer_common(self.group,\n                                                    self.num_max_dispatch_tokens_per_rank,\n                                                    self.hidden_size,\n                                                    self.num_experts,\n                                                    hidden_bytes=self.hidden_size * self.params_bytes)\n        self.return_recv_hook = return_recv_hook\n\n    def dispatch(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_weights: torch.Tensor,\n        num_experts: int,\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:\n        topk_idx = topk_idx.to(torch.int64)\n        expected_m = (hidden_states.shape[0] * self.buffer_low_latency.group_size * topk_idx.shape[1] +\n                      num_experts) // num_experts\n\n        packed_recv_hidden, masked_m, self.handle, event, hook = (self.buffer_low_latency.low_latency_dispatch(\n            hidden_states,\n            topk_idx,\n            self.num_max_dispatch_tokens_per_rank,\n            num_experts,\n            use_fp8=True,\n            async_finish=not 
self.return_recv_hook,\n            return_recv_hook=self.return_recv_hook,\n        ))\n        if self.return_recv_hook:\n            hook()\n        else:\n            event.current_stream_wait()\n        return (\n            packed_recv_hidden,\n            topk_idx,\n            topk_weights,\n            masked_m,\n            expected_m,\n        )\n\n    def dispatch_async(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        num_experts: Optional[int] = None,\n        use_fp8: bool = True,\n        async_finish: bool = True,\n    ):\n        assert topk_idx.dtype == torch.int64\n        recv_hidden_states, recv_expert_count, handle, event, hook = (self.buffer_low_latency.low_latency_dispatch(\n            hidden_states,\n            topk_idx,\n            self.num_max_dispatch_tokens_per_rank,\n            num_experts=self.num_experts if num_experts is None else num_experts,\n            use_fp8=use_fp8,\n            async_finish=async_finish,\n            return_recv_hook=not async_finish,\n        ))\n        return recv_hidden_states, recv_expert_count, handle, event, hook\n\n    def combine(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_weights: torch.Tensor,\n    ) -> torch.Tensor:\n        combined_hidden_states, event, hook = (self.buffer_low_latency.low_latency_combine(\n            hidden_states,\n            topk_idx,\n            topk_weights.to(torch.float32),\n            self.handle,\n            async_finish=not self.return_recv_hook,\n            return_recv_hook=self.return_recv_hook,\n        ))\n        if self.return_recv_hook:\n            hook()\n        else:\n            event.current_stream_wait()\n        return combined_hidden_states\n\n    def combine_async(\n        self,\n        hidden_states: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_weights: torch.Tensor,\n        handle: Tuple,\n        async_finish: bool,\n    ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor]]:\n        assert topk_idx.dtype == torch.int64\n        assert topk_weights.dtype == torch.float32\n        combined_hidden_states, event, hook = self.buffer_low_latency.low_latency_combine(\n            hidden_states,\n            topk_idx,\n            topk_weights,\n            handle,\n            async_finish=async_finish,\n            return_recv_hook=not async_finish,\n        )\n        return combined_hidden_states, event, hook\n\n\nclass TokenDispatcherBuilder:\n    \"\"\"Token dispatcher builder.\"\"\"\n\n    @staticmethod\n    def build(\n        group,\n        num_experts,\n        num_local_experts,\n        hidden_size,\n        params_dtype,\n    ) -> TokenDispatcherImpl:\n        \"\"\"build.\"\"\"\n        if use_deepep is True:\n            return DeepEPTokenDispatcher(\n                group,\n                num_experts,\n                num_local_experts,\n                hidden_size,\n                params_dtype,\n            )\n        else:\n            return AlltoAllTokenDispatcher(\n                group,\n                num_experts,\n                num_local_experts,\n            )\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom functools import lru_cache\n\n\n@lru_cache\ndef has_tilelang():\n    try:\n        import tilelang  # noqa: F401\n        return True\n    except Exception:\n        return False\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/cuda/warmup_manager.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass\n\nimport torch\n\nfrom lmdeploy.pytorch.utils import singleton\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\n@dataclass\nclass WarmupMeta:\n    \"\"\"Warmup meta.\"\"\"\n    max_num_tokens: int\n    max_batch_size: int\n    dtype: torch.dtype\n\n\n@singleton\nclass WarmupManager:\n\n    def __init__(self):\n        self._warmup_calls = dict()\n\n    def __contains__(self, key: str):\n        \"\"\"Check whether a warmup function is registered under key.\"\"\"\n        return key in self._warmup_calls\n\n    def __getitem__(self, key: str):\n        \"\"\"Get item.\"\"\"\n        return self._warmup_calls.get(key, None)\n\n    def __setitem__(self, key: str, val):\n        \"\"\"Set item.\"\"\"\n        self._warmup_calls[key] = val\n\n    def warmup(self, warmup_meta: WarmupMeta):\n        \"\"\"Run all registered warmup functions in random order.\"\"\"\n        if len(self._warmup_calls) == 0:\n            return\n        import random\n        logger.info('Warming up ops.')\n        funcs = list(self._warmup_calls.values())\n        random.shuffle(funcs)\n        for func in funcs:\n            func(warmup_meta)\n\n\ndef get_warmup_manager():\n    \"\"\"Get warmup manager.\"\"\"\n    return WarmupManager()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/deepep_moe_checker.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.utils import singleton\n\n\n@singleton\nclass MoEBackend:\n\n    def __init__(self):\n        \"\"\"Initialize moe backend.\"\"\"\n        self._use_deepep_moe_backend = False\n\n    def set_deepep_moe_backend(self):\n        \"\"\"Set deepep moe backend.\"\"\"\n        self._use_deepep_moe_backend = True\n\n    def use_deepep_moe_backend(self):\n        \"\"\"Get deepep moe backend.\"\"\"\n        return self._use_deepep_moe_backend\n\n\ndef get_moe_backend():\n    return MoEBackend()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .op_backend import DefaultOpsBackend  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/activation.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom torch import nn\n\nfrom ..activation import GeluAndMulBuilder, GeluAndMulImpl, SiluAndMulBuilder, SiluAndMulImpl\n\n\nclass DefaultSiluAndMulImpl(SiluAndMulImpl):\n    \"\"\"Silu and mul implementation: silu(gate) * up.\"\"\"\n\n    def __init__(self, inplace: bool):\n        self.inplace = inplace\n        self.silu = nn.SiLU(inplace)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate, up = x.chunk(2, -1)\n        return self.silu(gate) * up\n\n\nclass DefaultSiluAndMulBuilder(SiluAndMulBuilder):\n    \"\"\"Silu and mul implementation builder.\"\"\"\n\n    @staticmethod\n    def build(inplace: bool = False):\n        \"\"\"build.\"\"\"\n        return DefaultSiluAndMulImpl(inplace)\n\n\nclass DefaultGeluAndMulImpl(GeluAndMulImpl):\n    \"\"\"Gelu and mul implementation: gelu(gate) * up.\"\"\"\n\n    def __init__(self, approximate: str = 'none'):\n        self.act = nn.GELU(approximate=approximate)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate, up = x.chunk(2, -1)\n        return self.act(gate) * up\n\n\nclass DefaultGeluAndMulBuilder(GeluAndMulBuilder):\n    \"\"\"Gelu and mul implementation builder.\"\"\"\n\n    @staticmethod\n    def build(approximate: str = 'none'):\n        \"\"\"build.\"\"\"\n        return DefaultGeluAndMulImpl(approximate)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/apply_rotary_emb.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nfrom torch import Tensor\n\nfrom ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    half_size = x.shape[-1] // 2\n    x1 = x[..., :half_size]\n    x2 = x[..., half_size:]\n    out = torch.empty_like(x)\n    out[..., :half_size] = -x2\n    out[..., half_size:] = x1\n    return out\n\n\nclass DefaultApplyRotaryEmbImpl(ApplyRotaryEmbImpl):\n    \"\"\"Apply rotary embedding implementation.\"\"\"\n\n    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):\n        \"\"\"forward.\"\"\"\n        unsqueeze_dim = -2\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        if inplace:\n            q_embed = query\n            k_embed = key\n            q_sin = rotate_half(query) * sin\n            q_embed.mul_(cos)\n            q_embed.add_(q_sin)\n            k_sin = rotate_half(key) * sin\n            k_embed.mul_(cos)\n            k_embed.add_(k_sin)\n        else:\n            q_embed = (query * cos) + (rotate_half(query) * sin)\n            k_embed = (key * cos) + (rotate_half(key) * sin)\n        return q_embed, k_embed\n\n\nclass DefaultApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):\n    \"\"\"Apply rotary embedding implementation builder.\"\"\"\n\n    @staticmethod\n    def build():\n        \"\"\"Build implementation.\"\"\"\n        return DefaultApplyRotaryEmbImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/awq_modules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom functools import lru_cache\nfrom typing import Optional\n\nimport torch\n\nimport lmdeploy.pytorch.distributed as dist\n\nfrom ..awq_modules import LinearW4A16Builder, LinearW4A16Impl\n\n\n@lru_cache\ndef get_shifts(bits: int, device: torch.device):\n    \"\"\"Get awq shifts.\"\"\"\n    shifts = torch.arange(0, 32, bits, device=device)\n    shifts = shifts.view(2, 4).t().flatten()\n    return shifts\n\n\ndef unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):\n    shifts = get_shifts(bits, qzeros.device)\n\n    # unpacking columnwise\n    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(torch.int8)\n    iweights = iweights.view(iweights.shape[0], -1)\n\n    # unpacking columnwise\n    izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(torch.int8)\n    izeros = izeros.view(izeros.shape[0], -1)\n\n    # overflow checks\n    iweights = torch.bitwise_and(iweights, (2**bits) - 1)\n    izeros = torch.bitwise_and(izeros, (2**bits) - 1)\n\n    return iweights, izeros\n\n\ndef dequantize_gemm(qweight, qzeros, scales, bits, group_size):\n    # Unpack the qweight and qzeros tensors\n    iweight, izeros = unpack_awq(qweight, qzeros, bits)\n\n    # fp16 weights\n    iweight = iweight.unflatten(0, (-1, group_size))\n    iweight = (iweight - izeros[:, None]) * scales[:, None]\n    iweight = iweight.flatten(0, 1)\n\n    return iweight\n\n\nclass DefaultLinearW4A16Impl(LinearW4A16Impl):\n    \"\"\"W4a16 linear implementation.\"\"\"\n\n    def __init__(self, in_features: int, out_features: int, w_bit: int, group_size: int):\n        self.in_features = in_features\n        self.out_features = out_features\n        self.w_bit = w_bit\n        self.group_size = group_size\n\n    def forward(self,\n                x,\n                qweight: torch.Tensor,\n                scales: torch.Tensor,\n                qzeros: torch.Tensor,\n     
           bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[torch.distributed.ProcessGroup] = None):\n        \"\"\"forward.\"\"\"\n        out_shape = x.shape[:-1] + (self.out_features, )\n        input_dtype = x.dtype\n        if input_dtype != torch.float16:\n            x = x.half()\n        out = dequantize_gemm(qweight, qzeros, scales, self.w_bit, self.group_size)\n        out = torch.matmul(x, out)\n\n        out = out + bias if bias is not None else out\n        out = out.reshape(out_shape)\n\n        if input_dtype != torch.float16:\n            out = out.to(dtype=input_dtype)\n        if all_reduce:\n            dist.all_reduce(out, group=group)\n        return out\n\n\nclass DefaultLinearW4A16Builder(LinearW4A16Builder):\n    \"\"\"W4a16 linear implementation builder.\"\"\"\n\n    @staticmethod\n    def build(in_features: int,\n              out_features: int,\n              w_bit: int,\n              group_size: int,\n              bias: bool = False,\n              dtype: torch.dtype = None):\n        \"\"\"build.\"\"\"\n        return DefaultLinearW4A16Impl(in_features, out_features, w_bit, group_size)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/embedding.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\n\nfrom ..embedding import EmbeddingBuilder, EmbeddingImpl\n\n\ndef get_masked_input_and_mask(input: torch.Tensor, start_index: int, end_index: int):\n    input = input - start_index\n    masked_input = input.clamp(0, end_index - start_index - 1)\n    inv_vocab_mask = masked_input != input\n    return masked_input, inv_vocab_mask\n\n\nclass DefaultEmbeddingImpl(EmbeddingImpl):\n    \"\"\"Embedding implementation api.\"\"\"\n\n    def __init__(self, start_index: int, end_index: int):\n        self.start_index = start_index\n        self.end_index = end_index\n\n    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):\n        \"\"\"forward.\"\"\"\n        if all_reduce:\n            mask_input, inv_vocab_mask = get_masked_input_and_mask(x, self.start_index, self.end_index)\n            out = F.embedding(mask_input, weight)\n            out.masked_fill_(inv_vocab_mask.unsqueeze(-1), 0)\n            dist.all_reduce(out, group=group)\n        else:\n            out = F.embedding(x, weight)\n\n        return out\n\n\nclass DefaultEmbeddingBuilder(EmbeddingBuilder):\n    \"\"\"Embedding implementation builder.\"\"\"\n\n    @staticmethod\n    def build(start_index: int, end_index: int):\n        \"\"\"build.\"\"\"\n        return DefaultEmbeddingImpl(start_index=start_index, end_index=end_index)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/linear.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List, Optional\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\n\nfrom ..linear import LinearBuilder, LinearImpl\n\n\nclass DefaultLinearImpl(LinearImpl):\n    \"\"\"Linear implementation api.\"\"\"\n\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: dist.ProcessGroup = None,\n                rank: int = 0,\n                scatter_size: List[int] = None):\n        \"\"\"forward.\"\"\"\n        out = F.linear(x, weight, bias)\n        if all_reduce:\n            if scatter_size is not None:\n                from lmdeploy.pytorch.distributed import reduce_scatter_by_tp_sizes\n                out = reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)\n            else:\n                dist.all_reduce(out, group=group)\n        return out\n\n\nclass DefaultLinearBuilder(LinearBuilder):\n    \"\"\"Linear implementation builder.\"\"\"\n\n    @staticmethod\n    def build(in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None):\n        \"\"\"build.\"\"\"\n        return DefaultLinearImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom ..moe import SoftmaxTopKBuilder, SoftmaxTopKImpl\n\n\nclass DefaultSoftmaxTopKImpl(SoftmaxTopKImpl):\n    \"\"\"Softmax top-k implementation api.\"\"\"\n\n    def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1):\n        self.top_k = top_k\n        self.dim = dim\n        self.n_groups = n_groups\n        assert self.top_k % self.n_groups == 0, f'{self.top_k} cannot be divided by {self.n_groups}'\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        routing_weights = torch.softmax(x, dim=self.dim, dtype=torch.float32)\n        if self.n_groups > 0:\n            num_logits = routing_weights.shape[self.dim]\n            assert num_logits % self.n_groups == 0, f'{num_logits} cannot be divided by {self.n_groups}'\n            per_group_top_k = self.top_k // self.n_groups\n            group_size = routing_weights.shape[self.dim] // self.n_groups\n            group_offsets = self.get_group_offsets(self.n_groups, group_size, routing_weights.device)\n            routing_weights = routing_weights.unflatten(self.dim, (self.n_groups, group_size))\n            topk_weights, topk_ids = torch.topk(routing_weights, per_group_top_k, dim=-1)\n            topk_ids = (topk_ids + group_offsets).flatten(-2, -1)\n            topk_weights = topk_weights.flatten(-2, -1)\n        else:\n            topk_weights, topk_ids = torch.topk(routing_weights, self.top_k, dim=self.dim)\n        return topk_weights, topk_ids\n\n\nclass DefaultSoftmaxTopKBuilder(SoftmaxTopKBuilder):\n    \"\"\"Softmax top-k implementation builder.\"\"\"\n\n    @staticmethod\n    def build(top_k: int, dim: int = -1, n_groups: int = -1):\n        \"\"\"build.\"\"\"\n        return DefaultSoftmaxTopKImpl(top_k, dim, n_groups=n_groups)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/moe_router.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\nfrom typing import Tuple\n\nimport torch\n\nfrom ..moe_router import RouterNoauxTCBuilder, RouterNoauxTCImpl\n\n\ndef _compute_scores(scoring_func: str, logits: torch.Tensor):\n    \"\"\"Compute scores.\"\"\"\n    if scoring_func == 'softmax':\n        scores = logits.softmax(dim=-1, dtype=torch.float32)\n    elif scoring_func == 'sigmoid':\n        scores = logits.sigmoid()\n    else:\n        raise NotImplementedError('unsupported scoring function '\n                                  f'for MoE gating: {scoring_func}')\n    return scores\n\n\n@functools.lru_cache\ndef get_group_offsets(n_groups: int, group_size: int, device: str | torch.device) -> torch.Tensor:\n    group_offsets = (torch.arange(n_groups, device=device) * group_size).view(1, -1, 1)  # [1, n_groups, 1]\n    return group_offsets\n\n\nclass DefaultRouterNoauxTCImpl(RouterNoauxTCImpl):\n\n    def __init__(\n        self,\n        scoring_func: str,\n        top_k: int,\n        n_group: int,\n        topk_group: int,\n        n_routed_experts: int,\n        routed_scaling_factor: float,\n        renormalize: bool = True,\n        router_n_groups: int = -1,\n    ):\n\n        self.scoring_func = scoring_func\n        self.top_k = top_k\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.n_routed_experts = n_routed_experts\n\n        # renorm\n        self.renormalize = renormalize\n        self.routed_scaling_factor = routed_scaling_factor\n\n        # n_group\n        self.router_n_groups = router_n_groups\n\n    def _forward_router_n_groups(self, scores_for_choice: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        assert scores_for_choice.shape[-1] % self.router_n_groups == 0, \\\n            f'{scores_for_choice.shape[-1]} cannot be divided by {self.router_n_groups}'\n        per_group_top_k = self.top_k // self.router_n_groups\n        group_size = 
scores_for_choice.shape[-1] // self.router_n_groups\n        group_offsets = get_group_offsets(self.router_n_groups, group_size, device=scores_for_choice.device)\n        scores_for_choice = scores_for_choice.unflatten(-1, (self.router_n_groups, group_size))\n        topk_weight, topk_idx = torch.topk(scores_for_choice, per_group_top_k, dim=-1)\n        topk_idx = (topk_idx + group_offsets).flatten(-2, -1)\n        topk_weight = topk_weight.flatten(-2, -1)\n        return topk_weight, topk_idx\n\n    def _forward_default(self, scores: torch.Tensor, scores_for_choice: torch.Tensor,\n                         sequence_length: int) -> Tuple[torch.Tensor, torch.Tensor]:\n        group_scores = (scores_for_choice.view(sequence_length, self.n_group,\n                                               -1).topk(2, dim=-1)[0].sum(dim=-1))  # [n, n_group]\n        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]\n        group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n        score_mask = (group_mask.unsqueeze(-1).expand(sequence_length, self.n_group,\n                                                      self.n_routed_experts // self.n_group).reshape(\n                                                          sequence_length, -1))  # [n, e]\n        tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]\n        _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)\n        topk_weight = scores.gather(1, topk_idx)\n\n        return topk_weight, topk_idx\n\n    def renorm(self, topk_weight: torch.Tensor) -> torch.Tensor:\n        if self.renormalize:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n            if not topk_weight.is_contiguous():\n                topk_weight = topk_weight.contiguous()\n\n        topk_weight = topk_weight 
* self.routed_scaling_factor\n        return topk_weight\n\n    def forward(self, logits: torch.Tensor, bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Router forward.\"\"\"\n        sequence_length = logits.shape[0]\n\n        scores = _compute_scores(self.scoring_func, logits)\n        scores_for_choice = scores.view(sequence_length, -1) + bias[None]\n        if self.router_n_groups > 0:\n            topk_weight, topk_idx = self._forward_router_n_groups(scores_for_choice)\n        else:\n            topk_weight, topk_idx = self._forward_default(scores, scores_for_choice, sequence_length)\n\n        topk_weight = self.renorm(topk_weight)\n        return topk_weight, topk_idx\n\n\nclass DefaultRouterNoauxTCBuilder(RouterNoauxTCBuilder):\n\n    @staticmethod\n    def build(\n        scoring_func: str,\n        top_k: int,\n        n_group: int,\n        topk_group: int,\n        n_routed_experts: int,\n        routed_scaling_factor: float,\n        renormalize: bool = True,\n        router_n_groups: int = -1,\n    ):\n        return DefaultRouterNoauxTCImpl(\n            scoring_func=scoring_func,\n            top_k=top_k,\n            n_group=n_group,\n            topk_group=topk_group,\n            n_routed_experts=n_routed_experts,\n            routed_scaling_factor=routed_scaling_factor,\n            renormalize=renormalize,\n            router_n_groups=router_n_groups,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/multinomial_sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport torch\n\nfrom ..multinomial_sampling import MultinomialSamplingBuilder, MultinomialSamplingImpl\n\n\nclass DefaultMultinomialSamplingImpl(MultinomialSamplingImpl):\n    \"\"\"Multinomial sampling implementation api.\"\"\"\n\n    def forward(self,\n                scores: torch.Tensor,\n                seeds: torch.LongTensor,\n                offsets: torch.LongTensor,\n                indices: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        sampled_index = torch.multinomial(scores, num_samples=1, replacement=True)\n        outputs = torch.gather(indices, dim=1, index=sampled_index)\n        return outputs.view(-1)\n\n\nclass DefaultMultinomialSamplingBuilder(MultinomialSamplingBuilder):\n    \"\"\"Multinomial sampling implementation builder.\"\"\"\n\n    @staticmethod\n    def build():\n        \"\"\"build.\"\"\"\n        return DefaultMultinomialSamplingImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/norm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom ..norm import LayerNormBuilder, LayerNormImpl, RMSNormBuilder, RMSNormImpl\n\n\nclass DefaultRMSNormImpl(RMSNormImpl):\n    \"\"\"RMS norm implementation api.\"\"\"\n\n    def __init__(self, hidden_size: int, eps: float = 1e-6):\n        self.hidden_size = hidden_size\n        self.eps = eps\n\n    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        input_dtype = x.dtype\n        if residual is not None:\n            x = x + residual\n            residual = x\n        x = x.to(torch.float32)\n        variance = x.pow(2).mean(-1, keepdim=True)\n        x = x * torch.rsqrt(variance + self.eps)\n        x = weight * x.to(input_dtype)\n        if residual is None:\n            return x\n        return x, residual\n\n\nclass DefaultRMSNormBuilder(RMSNormBuilder):\n    \"\"\"RMS norm implementation builder.\"\"\"\n\n    @staticmethod\n    def build(hidden_size: int, eps: float = 1e-6):\n        \"\"\"build.\"\"\"\n        return DefaultRMSNormImpl(hidden_size, eps)\n\n\nclass DefaultLayerNormImpl(LayerNormImpl):\n    \"\"\"Layer norm implementation api.\"\"\"\n\n    def __init__(self, normalized_shape: int, eps: float = 1e-6):\n        if isinstance(normalized_shape, int):\n            normalized_shape = (normalized_shape, )\n        self.normalized_shape = normalized_shape\n        self.eps = eps\n\n    def forward(self,\n                x: torch.Tensor,\n                weight: torch.Tensor = None,\n                bias: torch.Tensor = None,\n                residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        if residual is not None:\n            x = x + residual\n            residual = x\n        x = torch.nn.functional.layer_norm(x, self.normalized_shape, weight=weight, bias=bias, eps=self.eps)\n        if residual is None:\n            return x\n        return x, residual\n\n\nclass 
DefaultLayerNormBuilder(LayerNormBuilder):\n    \"\"\"Layer norm implementation builder.\"\"\"\n\n    @staticmethod\n    def build(normalized_shape: int, eps: float = 1e-6):\n        \"\"\"build.\"\"\"\n        return DefaultLayerNormImpl(normalized_shape, eps)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/op_backend.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport torch\n\nfrom ..base import OpsBackend, OpType\n\n\nclass DefaultOpsBackend(OpsBackend):\n\n    @staticmethod\n    def get_name() -> str:\n        return 'default'\n\n    @classmethod\n    def get_layer_impl_builder(cls, layer_type: OpType):\n        \"\"\"Get builder of given layer type.\"\"\"\n        if layer_type == OpType.Linear:\n            from .linear import DefaultLinearBuilder\n            return DefaultLinearBuilder\n        elif layer_type == OpType.RotaryEmbedding:\n            from .rotary_embedding import DefaultRotaryEmbeddingBuilder\n            return DefaultRotaryEmbeddingBuilder\n        elif layer_type == OpType.ApplyRotaryEmb:\n            from .apply_rotary_emb import DefaultApplyRotaryEmbBuilder\n            return DefaultApplyRotaryEmbBuilder\n        elif layer_type == OpType.SiluAndMul:\n            from .activation import DefaultSiluAndMulBuilder\n            return DefaultSiluAndMulBuilder\n        elif layer_type == OpType.GeluAndMul:\n            from .activation import DefaultGeluAndMulBuilder\n            return DefaultGeluAndMulBuilder\n        elif layer_type == OpType.RMSNorm:\n            from .norm import DefaultRMSNormBuilder\n            return DefaultRMSNormBuilder\n        elif layer_type == OpType.LayerNorm:\n            from .norm import DefaultLayerNormBuilder\n            return DefaultLayerNormBuilder\n        elif layer_type == OpType.MultinomialSampling:\n            from .multinomial_sampling import DefaultMultinomialSamplingBuilder\n            return DefaultMultinomialSamplingBuilder\n        elif layer_type == OpType.LinearW4A16:\n            from .awq_modules import DefaultLinearW4A16Builder\n            return DefaultLinearW4A16Builder\n        elif layer_type == OpType.SoftmaxTopK:\n            from .moe import DefaultSoftmaxTopKBuilder\n            return DefaultSoftmaxTopKBuilder\n        elif layer_type == 
OpType.Embedding:\n            from .embedding import DefaultEmbeddingBuilder\n            return DefaultEmbeddingBuilder\n        elif layer_type == OpType.RouterNoauxTC:\n            from .moe_router import DefaultRouterNoauxTCBuilder\n            return DefaultRouterNoauxTCBuilder\n        else:\n            raise RuntimeError(f'{layer_type} not supported.')\n\n    @staticmethod\n    def get_k_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        \"\"\"Get block shape of k.\"\"\"\n        return (\n            block_size,\n            num_heads,\n            head_size,\n        )\n\n    @staticmethod\n    def get_v_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        \"\"\"Get block shape of v.\"\"\"\n        return (\n            block_size,\n            num_heads,\n            head_size,\n        )\n\n    @staticmethod\n    def init():\n        pass\n\n    @staticmethod\n    def ccl_backend() -> str:\n        return 'nccl'\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/rotary_embedding.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport math\nfrom functools import wraps\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom ..rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,\n                                RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters)\n\n\ndef safe_torch_compile(**compile_kwargs):\n    \"\"\"Auto fallback.\"\"\"\n\n    def decorator(func):\n        compiled_func = None\n        compile_failed = False\n\n        @wraps(func)\n        def wrapper(*args, **kwargs):\n            nonlocal compiled_func, compile_failed\n\n            if compile_failed:\n                return func(*args, **kwargs)\n\n            if compiled_func is None:\n                try:\n                    compiled_func = torch.compile(func, **compile_kwargs)\n                    return compiled_func(*args, **kwargs)\n                except Exception:\n                    compile_failed = True\n                    return func(*args, **kwargs)\n\n            return compiled_func(*args, **kwargs)\n\n        return wrapper\n\n    return decorator\n\n\n@safe_torch_compile(dynamic=True)\ndef _rotary_embedding_fwd(position_ids: torch.Tensor,\n                          inv_freq: torch.Tensor,\n                          scaling_factor: float,\n                          mscale: float = None,\n                          dtype: torch.dtype = None,\n                          device_type: torch.device = None):\n    \"\"\"Rotary embedding forward.\"\"\"\n    if dtype is None:\n        dtype = torch.float16\n    if device_type is None:\n        device_type = 'cuda'\n    position_ids = position_ids.float() / scaling_factor\n    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n    position_ids_expanded = position_ids[:, None, :]\n    # Force float32 since bfloat16 loses precision on long contexts\n    # See 
https://github.com/huggingface/transformers/pull/29285\n    device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'\n    with torch.autocast(device_type=device_type, enabled=False):\n        freqs = (inv_freq_expanded.float() * position_ids_expanded.float()).transpose(1, 2)\n        emb = freqs.repeat(1, 1, 2)\n        cos = emb.cos()\n        sin = emb.sin()\n\n        if mscale is not None:\n            cos = cos * mscale\n            sin = sin * mscale\n\n    return cos.to(dtype=dtype), sin.to(dtype=dtype)\n\n\nclass RotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):\n    \"\"\"Base rotary embedding.\"\"\"\n\n    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):\n        super().__init__()\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.base = base\n        inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))\n        self.register_buffer('inv_freq', inv_freq, persistent=False)\n\n    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        device_type = x.device.type\n        dtype = x.dtype\n        if self.inv_freq.device != x.device:\n            self.inv_freq = self.inv_freq.to(x.device)\n        return _rotary_embedding_fwd(position_ids,\n                                     self.inv_freq,\n                                     scaling_factor=self.scaling_factor,\n                                     dtype=dtype,\n                                     device_type=device_type)\n\n\nclass LlamaDynamicNTKScalingRotaryEmbedding(RotaryEmbeddingImpl):\n    \"\"\"LlamaRotaryEmbedding extended with Dynamic NTK scaling.\n\n    Credits to the Reddit users /u/bloc97 and /u/emozilla\n    \"\"\"\n\n    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0, max_position_embeddings: int = 2048):\n        super().__init__(dim, base, scaling_factor)\n        
self.max_position_embeddings = max_position_embeddings\n\n    def _ntk_inv_freq(self, seq_len: torch.Tensor):\n        \"\"\"ntk_inv_freq.\"\"\"\n        device = seq_len.device\n        base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -\n                            (self.scaling_factor - 1))**(self.dim / (self.dim - 2))\n        inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))\n        return inv_freq\n\n    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        device_type = x.device.type\n        dtype = x.dtype\n        seq_len = torch.max(position_ids) + 1\n        ntk_inv_freq = self._ntk_inv_freq(seq_len)\n        if self.inv_freq.device != x.device:\n            self.inv_freq = self.inv_freq.to(x.device)\n        inv_freq = torch.where(seq_len > self.max_position_embeddings, ntk_inv_freq, self.inv_freq)\n\n        cos, sin = _rotary_embedding_fwd(position_ids,\n                                         inv_freq,\n                                         scaling_factor=1.0,\n                                         dtype=dtype,\n                                         device_type=device_type)\n        return cos, sin\n\n\nclass Llama3RotaryEmbeddingImpl(RotaryEmbeddingImpl):\n    \"\"\"Llama3 rotary embedding implementation.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        base: int = 10000,\n        scaling_factor: float = 1.0,\n        low_freq_factor: float = 1.0,\n        high_freq_factor: float = 4.0,\n        original_max_position_embeddings: int = 8192,\n    ):\n        super().__init__(dim, base, scaling_factor)\n        old_context_len = original_max_position_embeddings\n        low_freq_wavelen = old_context_len / low_freq_factor\n        high_freq_wavelen = old_context_len / high_freq_factor\n\n        inv_freq = self.inv_freq\n        factor = self.scaling_factor\n\n        wavelen = 2 
* math.pi / inv_freq\n        # wavelen < high_freq_wavelen: do nothing\n        # wavelen > low_freq_wavelen: divide by factor\n        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)\n        # otherwise: interpolate between the two, using a smooth factor\n        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)\n        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama\n        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)\n        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n        self.scaling_factor = 1.0\n        self.register_buffer('inv_freq', inv_freq_llama)\n\n\ndef yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):\n    \"\"\"yarn_find_correction_dim.\"\"\"\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n\n\n# Find dim range bounds based on rotations\ndef yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048, truncate: bool = True):\n    \"\"\"yarn_find_correction_range.\"\"\"\n    low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)\n    high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)\n    if truncate:\n        low = math.floor(low)\n        high = math.ceil(high)\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    \"\"\"yarn_get_mscale.\"\"\"\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\ndef yarn_linear_ramp_mask(min, max, dim):\n    \"\"\"yarn_linear_ramp_mask.\"\"\"\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = 
torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\n\nclass YarnRotaryEmbeddingImpl(RotaryEmbeddingImpl):\n    \"\"\"Yarn rotary embedding implementation.\"\"\"\n\n    def __init__(self,\n                 dim: int,\n                 base: int = 10000,\n                 scaling_factor: float = 1.0,\n                 original_max_position_embeddings: int = 4096,\n                 yarn_params: YarnParameters = None):\n        super().__init__(dim, base, scaling_factor)\n        self.original_max_position_embeddings = \\\n            original_max_position_embeddings\n        assert yarn_params is not None\n        self.beta_fast = yarn_params.beta_fast\n        self.beta_slow = yarn_params.beta_slow\n        self.mscale = yarn_params.mscale\n        self.mscale_all_dim = yarn_params.mscale_all_dim\n        self.truncate = yarn_params.truncate\n\n        # get inv_freq\n        freq_extra = 1.0 / (self.base**(torch.arange(0, dim, 2, dtype=torch.float32) / dim))\n        freq_inter = 1.0 / (self.scaling_factor * self.base**(torch.arange(0, dim, 2, dtype=torch.float32) / dim))\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            self.original_max_position_embeddings,\n            truncate=self.truncate,\n        )\n        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(dtype=torch.float32)\n        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self.register_buffer('inv_freq', inv_freq, persistent=False)\n\n        # get mscale\n        if yarn_params.attention_factor is not None:\n            self.mscale = yarn_params.attention_factor\n        else:\n            self.mscale = float(\n                yarn_get_mscale(self.scaling_factor, self.mscale) /\n                yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))\n        if self.mscale == 1.0:\n            self.mscale = None\n\n    def 
forward(self, x: torch.Tensor, position_ids: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        device_type = x.device.type\n        dtype = x.dtype\n        if self.inv_freq.device != x.device:\n            self.inv_freq = self.inv_freq.to(x.device)\n        return _rotary_embedding_fwd(position_ids,\n                                     self.inv_freq,\n                                     scaling_factor=1.0,\n                                     mscale=self.mscale,\n                                     dtype=dtype,\n                                     device_type=device_type)\n\n\nclass LongRoPEScalingRotaryEmbeddingImpl(RotaryEmbeddingImpl):\n    \"\"\"LongRoPE scaling rotary embedding implementation.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        base: int = 10000,\n        max_position_embeddings: int = 4096,\n        longrope_params: LongRoPEScalingParameters = None,\n    ):\n        super().__init__(dim, base)\n        short_factor = torch.tensor(longrope_params.short_factor, dtype=torch.float32)\n        long_factor = torch.tensor(longrope_params.long_factor, dtype=torch.float32)\n        self.register_buffer('short_factor', short_factor, persistent=False)\n        self.register_buffer('long_factor', long_factor, persistent=False)\n        self.original_max_position_embeddings = \\\n            longrope_params.original_max_position_embeddings\n        self.mscale = None\n        self.short_mscale = longrope_params.short_mscale\n        self.long_mscale = longrope_params.long_mscale\n        if self.short_mscale is None and self.long_mscale is None:\n            scale = (max_position_embeddings / self.original_max_position_embeddings)\n            if scale <= 1.0:\n                self.mscale = 1.0\n            else:\n                self.mscale = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))\n\n    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):\n        \"\"\"Rope forward.\"\"\"\n        
dtype = x.dtype\n        device = position_ids.device\n        if self.short_factor.device != device:\n            self.register_buffer('short_factor', self.short_factor.to(device), persistent=False)\n            self.register_buffer('long_factor', self.long_factor.to(device), persistent=False)\n\n        max_pos_ids = position_ids.max() + 1\n        mask = max_pos_ids > self.original_max_position_embeddings\n        ext_factors = torch.where(mask, self.long_factor, self.short_factor)\n\n        mscale = self.mscale\n        if mscale is None:\n            mscale = torch.where(mask, self.long_mscale, self.short_mscale)\n\n        inv_freq = self.inv_freq * (1.0 / ext_factors)\n        return _rotary_embedding_fwd(position_ids,\n                                     inv_freq,\n                                     scaling_factor=1.0,\n                                     mscale=mscale,\n                                     dtype=dtype,\n                                     device_type=device.type)\n\n\nclass FopeRotaryEmbeddingImpl(RotaryEmbeddingImpl):\n    \"\"\"Fope rotary embedding implementation.\"\"\"\n\n    def __init__(self,\n                 dim: int,\n                 max_position_embeddings: int = 4096,\n                 scaling_factor: float = 1.0,\n                 params: FopeParameters = None):\n        super().__init__(dim, scaling_factor=scaling_factor)\n        self.head_dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.attention_scaling = scaling_factor\n        self.params = params\n\n        inv_freq = self.params.inv_freq\n        inv_freq_idx_selected = torch.ones_like(inv_freq, dtype=torch.bool)\n        if self.params.num_inv_freq is not None:\n            num_inv_freq = self.params.num_inv_freq\n            inv_freq_idx_selected[num_inv_freq:] = False\n        else:\n            inv_freq_idx_selected = inv_freq > (2.0 * torch.pi / self.max_position_embeddings)\n            num_inv_freq = inv_freq_idx_selected.sum().item()\n\n        self.inv_freq = 
inv_freq[inv_freq_idx_selected]\n        self.register_buffer('inv_freq', self.inv_freq, persistent=False)\n\n    def forward(self, x: torch.Tensor, position_ids: torch.Tensor, sin_coef: torch.Tensor, cos_coef: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        if self.inv_freq.device != x.device:\n            self.inv_freq = self.inv_freq.to(x.device)\n\n        inv_freq = self.inv_freq\n        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n\n        batch_size, seq_len, _ = x.shape\n        if self.params.fope_sep_head:\n            pos_cos = freqs.cos().unsqueeze(1).expand(batch_size, self.params.num_key_value_heads, seq_len, -1)\n            pos_sin = freqs.sin().unsqueeze(1).expand(batch_size, self.params.num_key_value_heads, seq_len, -1)\n        else:\n            pos_cos = freqs.cos()\n            pos_sin = freqs.sin()\n\n        if self.params.fope_sep_head:\n            sin = torch.einsum('bhtD, hDd -> bthd', pos_sin, sin_coef.float())\n            cos = torch.einsum('bhtD, hDd -> bthd', pos_cos, cos_coef.float())\n        else:\n            sin = torch.einsum('btD, Dd -> btd', pos_sin, sin_coef.float())\n            cos = torch.einsum('btD, Dd -> btd', pos_cos, cos_coef.float())\n\n        sin = F.pad(input=sin, pad=(0, self.head_dim // 2 - sin.size(-1)), mode='constant', value=1)\n        cos = F.pad(input=cos, pad=(0, self.head_dim // 2 - cos.size(-1)), mode='constant', value=1)\n\n        sin = torch.cat((sin, sin), dim=-1)\n        cos = torch.cat((cos, cos), dim=-1)\n\n        cos = cos * self.attention_scaling\n        sin = sin * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):\n    \"\"\"Rotary embedding builder.\"\"\"\n\n    @staticmethod\n   
 def build(\n        dim: int,\n        max_position_embeddings: int = 2048,\n        base: int = 10000,\n        scaling_factor: float = 1.0,\n        yarn_params: YarnParameters = None,\n        longrope_params: LongRoPEScalingParameters = None,\n        llama3_params: Llama3Parameters = None,\n        fope_params: FopeParameters = None,\n        emb_type: RopeType = RopeType.Default,\n    ):\n        \"\"\"build.\"\"\"\n        if emb_type in (RopeType.Default, RopeType.LinearScaling):\n            return RotaryEmbeddingImpl(dim, base, scaling_factor)\n        elif emb_type == RopeType.DynamicNTKScaling:\n            return LlamaDynamicNTKScalingRotaryEmbedding(dim, base, scaling_factor, max_position_embeddings)\n        elif emb_type == RopeType.Llama3:\n            return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor, llama3_params.low_freq_factor,\n                                             llama3_params.high_freq_factor,\n                                             llama3_params.original_max_position_embeddings)\n        elif emb_type == RopeType.Yarn:\n            return YarnRotaryEmbeddingImpl(dim, base, scaling_factor, max_position_embeddings, yarn_params=yarn_params)\n        elif emb_type == RopeType.LongRoPEScaling:\n            return LongRoPEScalingRotaryEmbeddingImpl(\n                dim,\n                base,\n                max_position_embeddings=max_position_embeddings,\n                longrope_params=longrope_params,\n            )\n        elif emb_type == RopeType.Fope:\n            return FopeRotaryEmbeddingImpl(\n                dim,\n                max_position_embeddings=max_position_embeddings,\n                scaling_factor=scaling_factor,\n                params=fope_params,\n            )\n        else:\n            raise NotImplementedError(f'Unsupported embedding type: {emb_type}')\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/default/token_dispatcher.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport torch\n\nfrom ..token_dispatcher import TokenDispatcherImpl\n\n\nclass AlltoAllTokenDispatcher(TokenDispatcherImpl):\n\n    def __init__(\n        self,\n        ep_group,\n        num_experts,\n        num_local_experts: int,\n    ) -> None:\n        self.num_local_experts = num_local_experts\n        assert num_experts is not None\n        self.num_experts = num_experts\n        assert self.num_local_experts > 0, 'Expected at least one expert'\n        self.ep_size = num_experts // num_local_experts\n        self.ep_group = ep_group\n        self.tp_size = 1\n        self.input_splits = None\n        self.output_splits = None\n        input_chunk_idxs = torch.arange(self.num_experts, device=torch.device('cpu'))\n        self.sort_input_by_local_experts = input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel()\n        self.restore_output_by_local_experts = input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel()\n\n    def sort_chunks_by_idxs(self, input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor):\n        \"\"\"Split and sort the input tensor based on the split_sizes and sorted\n        indices.\"\"\"\n        input = torch.split(input, split_sizes.tolist(), dim=0)\n        output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0)\n        return output\n\n    def all_to_all(self, group: torch.distributed.group, input_: torch.Tensor, output_split: torch.Tensor,\n                   input_split: torch.Tensor):\n        output_split_sizes_ = output_split.tolist()\n        input_split_sizes = input_split.tolist()\n        output = input_.new_empty(\n            size=[sum(output_split_sizes_)] + list(input_.size()[1:]),\n            dtype=input_.dtype,\n            device=torch.cuda.current_device(),\n        )\n        torch.distributed.all_to_all_single(\n            output,\n            input_,\n            
output_split_sizes=output_split_sizes_,\n            input_split_sizes=input_split_sizes,\n            group=group,\n        )\n        return output\n\n    def preprocess(self, routing_map: torch.Tensor, local_expert_indices) -> torch.Tensor:\n        assert (len(local_expert_indices) == self.num_local_experts), 'Invalid local expert indices'\n        for i in range(len(local_expert_indices) - 1):\n            assert (local_expert_indices[i] == local_expert_indices[i + 1] -\n                    1), 'local_expert_indices must be contiguous'\n\n        num_local_tokens_per_expert = routing_map.sum(dim=0).long()\n        self.input_splits = (num_local_tokens_per_expert.reshape(self.ep_size, self.num_local_experts).sum(axis=1).to(\n            torch.device('cpu'), non_blocking=True))\n        dim_size = list(num_local_tokens_per_expert.size())\n        dim_size[0] = dim_size[0] * torch.distributed.get_world_size(self.ep_group)\n        output = num_local_tokens_per_expert.new_empty(dim_size)\n        torch.distributed.all_gather_into_tensor(output, num_local_tokens_per_expert.contiguous(), group=self.ep_group)\n        num_global_tokens_per_expert = (output.reshape(self.ep_size, self.tp_size, self.num_experts).transpose(0, 1))\n        num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, :, local_expert_indices[0]:\n                                                                          local_expert_indices[-1] + 1].contiguous()\n        num_global_tokens_per_rank = num_global_tokens_per_local_expert.sum(axis=2)\n        self.output_splits = (num_global_tokens_per_rank[0].to(torch.device('cpu'), non_blocking=True))\n        num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))\n        if self.num_local_experts > 1:\n            self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(\n                -1, self.num_local_experts)\n\n            self.num_global_tokens_per_local_expert = 
num_global_tokens_per_local_expert.to(torch.device('cpu'),\n                                                                                            non_blocking=True)\n        return num_tokens_per_local_expert\n\n    def dispatch(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, probs: torch.Tensor,\n                 local_expert_indices) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        self.hidden_shape = hidden_states.shape\n        self.topk_ids = topk_ids\n        self.routing_map, self.topk_weights = super().indices_to_multihot(topk_ids, probs, self.num_experts)\n        assert probs.dim() == 2, 'Expected 2D tensor for probs'\n        assert self.routing_map.dim() == 2, 'Expected 2D tensor for token2expert mask'\n        assert self.routing_map.dtype == torch.bool, 'Expected bool tensor for mask'\n        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])\n        tokens_per_expert = self.preprocess(self.routing_map, local_expert_indices)\n        self.hidden_shape_before_permute = hidden_states.shape\n\n        permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = super().permute(\n            hidden_states,\n            self.routing_map,\n        )\n        global_input_tokens = self.all_to_all(self.ep_group, permutated_local_input_tokens, self.output_splits,\n                                              self.input_splits)\n        if self.num_local_experts > 1:\n            global_input_tokens = self.sort_chunks_by_idxs(\n                global_input_tokens,\n                self.num_global_tokens_per_local_expert.ravel(),\n                self.sort_input_by_local_experts,\n            )\n        return global_input_tokens, None, None, tokens_per_expert\n\n    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        if self.num_local_experts > 1:\n            hidden_states = self.sort_chunks_by_idxs(\n                hidden_states,\n                
self.num_global_tokens_per_local_expert.mT.ravel(),\n                self.restore_output_by_local_experts,\n            )\n        permutated_local_input_tokens = self.all_to_all(self.ep_group, hidden_states, self.input_splits,\n                                                        self.output_splits)\n        output = super().unpermute(\n            permutated_local_input_tokens,\n            self.reversed_local_input_permutation_mapping,\n            restore_shape=self.hidden_shape_before_permute,\n            probs=self.topk_weights,\n            routing_map=self.routing_map,\n        )\n        output = output.view(self.hidden_shape)\n        return output\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/activation.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.kernels.dlinfer.activation import silu_and_mul\n\nfrom ..activation import SiluAndMulBuilder, SiluAndMulImpl\n\n\nclass DlinferSiluAndMulImpl(SiluAndMulImpl):\n    \"\"\"Silu + multiple fused implementation.\"\"\"\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        return silu_and_mul(x)\n\n\nclass DlinferSiluAndMulBuilder(SiluAndMulBuilder):\n    \"\"\"Silu and mul implementation builder.\"\"\"\n\n    @staticmethod\n    def build(inplace: bool = False):\n        \"\"\"build.\"\"\"\n        return DlinferSiluAndMulImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom torch import Tensor\n\nfrom lmdeploy.pytorch.kernels.dlinfer import apply_rotary_pos_emb\n\nfrom ..apply_rotary_emb import ApplyRotaryEmbBuilder, ApplyRotaryEmbImpl\n\n\nclass DlinferApplyRotaryEmbImpl(ApplyRotaryEmbImpl):\n    \"\"\"Apply rotary embedding implementation.\"\"\"\n\n    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):\n        \"\"\"forward.\"\"\"\n        if inplace:\n            q_embed = None\n            k_embed = None\n        else:\n            q_embed = query.new_empty(query.shape)\n            k_embed = key.new_empty(key.shape)\n        return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed)\n\n\nclass DlinferApplyRotaryEmbBuilder(ApplyRotaryEmbBuilder):\n    \"\"\"Apply rotary embedding implementation builder.\"\"\"\n\n    @staticmethod\n    def build():\n        \"\"\"Build implementation.\"\"\"\n        return DlinferApplyRotaryEmbImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/ascend/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .op_backend import AscendOpsBackend, SocVersion  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport itertools\nimport math\nimport os\nimport re\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom pathlib import Path\nfrom typing import Dict, Tuple\n\nimport torch\nimport torch.distributed as dist\n\nfrom lmdeploy.pytorch import envs as _envs\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig\nfrom lmdeploy.pytorch.distributed import get_dist_manager\nfrom lmdeploy.utils import get_logger\n\nfrom ..moe import DlinferMoECommType, DlinferMoeMetadata\nfrom ..op_backend import DlinferOpsBackend\n\nlogger = get_logger('lmdeploy')\n\n\nclass SocVersion:\n    Ascend310P: str = 'Ascend310P'\n    Ascend910: str = 'Ascend910'\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def device_name(cls) -> str:\n        try:\n            return torch.npu.get_device_name()\n        except ImportError:\n            logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly.')\n        except Exception as e:\n            logger.warning(f'Error during Ascend get device name: {str(e)}. 
'\n                           'Please check your Ascend environment configuration.')\n\n    @classmethod\n    def is_Ascend310P(cls) -> bool:\n        return cls.device_name().startswith(cls.Ascend310P)\n\n    @classmethod\n    def is_Ascend910(cls) -> bool:\n        return cls.device_name().startswith(cls.Ascend910)\n\n    @classmethod\n    @lru_cache(maxsize=1)\n    def soc_version(cls) -> int:\n        return torch.npu.get_soc_version()\n\n    @classmethod\n    def is_A2(cls) -> bool:\n        return 220 <= cls.soc_version() <= 225\n\n    @classmethod\n    def is_A3(cls) -> bool:\n        return 250 <= cls.soc_version() <= 255\n\n\n@dataclass\nclass DistMeta:\n    dp_size: int\n    tp_size: int\n    ep_size: int\n    tp_rank: int\n    ep_rank: int\n    tp_group: torch.distributed.ProcessGroup\n    ep_group: torch.distributed.ProcessGroup\n\n\nclass AscendKVQuantMeta:\n    has_set_value: bool = False\n    quant_meta: Dict = {}\n\n    @classmethod\n    def set_value(cls, device: str, dtype: torch.dtype, record_file: str, total_layers: int):\n        with open(record_file, 'r') as file:\n            data = file.read()\n        scale_offset_pairs = re.findall(r'scale:\\s*([\\d\\.\\-]+)\\s*offset:\\s*(-?\\d+)', data)\n        scale_offset_pairs = [(float(scale), float(offset)) for scale, offset in scale_offset_pairs]\n        k_scales, v_scales, kv_scales = [], [], []\n        k_zeros, v_zeros, kv_zeros = [], [], []\n        if len(scale_offset_pairs) == total_layers:\n            for scale, offset in scale_offset_pairs:\n                k_scales.append(torch.tensor([scale], device=device, dtype=dtype))\n                v_scales.append(torch.tensor([scale], device=device, dtype=dtype))\n                kv_scales.append(torch.tensor([scale, scale], device=device, dtype=dtype))\n                k_zeros.append(torch.tensor([offset], device=device, dtype=dtype))\n                v_zeros.append(torch.tensor([offset], device=device, dtype=dtype))\n                
kv_zeros.append(torch.tensor([offset, offset], device=device, dtype=dtype))\n        elif len(scale_offset_pairs) == total_layers * 2:\n            for i in range(total_layers):\n                scale_k, offset_k = scale_offset_pairs[2 * i]\n                scale_v, offset_v = scale_offset_pairs[2 * i + 1]\n                k_scales.append(torch.tensor([scale_k], device=device, dtype=dtype))\n                v_scales.append(torch.tensor([scale_v], device=device, dtype=dtype))\n                kv_scales.append(torch.tensor([scale_k, scale_v], device=device, dtype=dtype))\n                k_zeros.append(torch.tensor([offset_k], device=device, dtype=dtype))\n                v_zeros.append(torch.tensor([offset_v], device=device, dtype=dtype))\n                kv_zeros.append(torch.tensor([offset_k, offset_v], device=device, dtype=dtype))\n        else:\n            raise ValueError(f'num of scale_offset_pairs({len(scale_offset_pairs)}) '\n                             f'must match num of total_layers({total_layers})')\n\n        cls.quant_meta.update({\n            'k_scales': itertools.cycle(k_scales),\n            'k_zeros': itertools.cycle(k_zeros),\n            'v_scales': itertools.cycle(v_scales),\n            'v_zeros': itertools.cycle(v_zeros),\n            'kv_scales': itertools.cycle(kv_scales),\n            'kv_zeros': itertools.cycle(kv_zeros)\n        })\n        cls.has_set_value = True\n\n\nclass AscendOpsBackend(DlinferOpsBackend):\n    \"\"\"Ascend layer backend.\"\"\"\n    enable_graph: bool = False\n    total_slots = None\n    max_batches = None\n    dist_meta: DistMeta = None\n\n    @staticmethod\n    def get_name() -> str:\n        \"\"\"Backend name.\"\"\"\n        return 'ascend'\n\n    @staticmethod\n    def get_k_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        if SocVersion.is_Ascend910():\n            return (block_size, num_heads, 
head_size)\n        else:\n            raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.')\n\n    @staticmethod\n    def get_v_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        if SocVersion.is_Ascend910():\n            return (block_size, num_heads, head_size)\n        else:\n            raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.')\n\n    @classmethod\n    def update_step_context(cls, step_context):\n        \"\"\"Update step context.\"\"\"\n\n        block_num, block_size, *_ = step_context.kv_caches[0][0].shape\n        is_unpaged_prefill = False\n        if not step_context.is_decoding:\n            is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())\n        if step_context.block_offsets.dtype != torch.int32:\n            step_context.block_offsets = step_context.block_offsets.to(torch.int32)\n        if not (step_context.is_decoding or is_unpaged_prefill):\n            step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0)\n        if step_context.kv_seqlens.dtype != torch.int32:\n            step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32)\n        if step_context.q_seqlens.dtype != torch.int32:\n            step_context.q_seqlens = step_context.q_seqlens.to(torch.int32)\n\n        def get_total_slots():\n            if cls.total_slots is None:\n                cls.total_slots = torch.arange(block_num * block_size,\n                                               dtype=torch.int32,\n                                               device=step_context.block_offsets.device)\n                cls.total_slots = cls.total_slots.view(block_num, block_size)\n            return cls.total_slots\n\n        def get_cpu_seqlens(is_decoding, is_unpaged_prefill):\n            \"\"\"Get sequence 
lengths on CPU.\n\n            Returns:\n                q_seqlens_cpu: query sequence lengths (per sequence).\n                kv_seqlens_cpu: kv sequence lengths (per sequence), used for\n                    list/max seqlens calculation.\n                kv_seqlens_expanded: kv sequence lengths expanded per token via\n                    repeat_interleave, used for attention metadata.\n            \"\"\"\n            if is_decoding:\n                q_seqlens_cpu = None\n                kv_seqlens_cpu = kv_seqlens_expanded = step_context.kv_seqlens.cpu()\n            elif is_unpaged_prefill:\n                q_seqlens_cpu = step_context.q_seqlens.cpu()\n                kv_seqlens_cpu = kv_seqlens_expanded = q_seqlens_cpu\n            else:\n                q_seqlens_cpu = step_context.q_seqlens.cpu()\n                kv_seqlens_cpu = step_context.kv_seqlens.cpu()\n                # Expand kv_seqlens to per-token for paged prefill attention\n                kv_seqlens_expanded = kv_seqlens_cpu.repeat_interleave(q_seqlens_cpu, 0)\n            return q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded\n\n        def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None):\n            if is_decoding:\n                q_seqlens_list, kv_seqlens_list = None, None\n            elif is_unpaged_prefill:\n                q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist()\n            else:\n                q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist()\n            return q_seqlens_list, kv_seqlens_list\n\n        def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None):\n            if is_decoding:\n                max_q_seq_len, max_kv_seq_len = 1, None\n            elif is_unpaged_prefill:\n                max_q_seq_len = max_kv_seq_len = max(q_seqlens_list)\n            else:\n                max_q_seq_len = max(q_seqlens_list)\n                
max_kv_seq_len = max(kv_seqlens_list)\n            return max_q_seq_len, max_kv_seq_len\n\n        def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list,\n                                                    max_q_seq_len, max_kv_seq_len):\n            kv_start_indices, attention_mask = [], []\n            if is_decoding:\n                idx = (step_context.kv_seqlens - 1) % block_size\n                block_num = (step_context.kv_seqlens - 1) // block_size\n                last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1)\n                kv_start_indices = last_block * block_size + idx\n            else:\n                for i in range(step_context.q_start_loc.size(0)):\n                    q_seq_len = q_seqlens_list[i]\n                    kv_seq_len = kv_seqlens_list[i]\n\n                    history_length = kv_seq_len - q_seq_len\n                    total_slots = get_total_slots()\n                    slot_tables = total_slots[step_context.block_offsets[i]].view(-1)\n                    slots = slot_tables[history_length:kv_seq_len]\n                    kv_start_indices.append(slots)\n\n                    if not is_unpaged_prefill:\n                        single_attention_mask = torch.triu(\n                            torch.ones(q_seq_len,\n                                       step_context.block_offsets.shape[1] * block_size,\n                                       dtype=torch.bool,\n                                       device=step_context.block_offsets.device),\n                            diagonal=kv_seq_len - q_seq_len + 1,\n                        )\n                        attention_mask.append(single_attention_mask)\n\n                if is_unpaged_prefill:\n                    attention_mask.append(\n                        torch.triu(torch.ones(max_q_seq_len,\n                                              max_kv_seq_len,\n                                     
         dtype=step_context.kv_caches[0][0].dtype,\n                                              device=step_context.block_offsets.device),\n                                   diagonal=max_kv_seq_len - max_q_seq_len + 1))\n                else:\n                    attention_mask = [torch.cat(attention_mask)]\n\n                kv_start_indices = torch.cat(kv_start_indices)\n\n            return kv_start_indices, attention_mask\n\n        def get_dist_meta():\n            if cls.dist_meta is not None:\n                return cls.dist_meta\n            dist_ctx = get_dist_manager().current_context()\n            dp_size, tp_size, ep_size = dist_ctx.dist_config.dp, dist_ctx.dist_config.tp, dist_ctx.dist_config.ep\n            tp_rank, ep_rank = dist_ctx.attn_tp_group.rank, dist_ctx.ep_rank\n            tp_group = dist_ctx.attn_tp_group.gpu_group\n            ep_group = dist_ctx.ep_gpu_group\n            cls.dist_meta = DistMeta(dp_size=dp_size,\n                                     tp_size=tp_size,\n                                     ep_size=ep_size,\n                                     tp_rank=tp_rank,\n                                     ep_rank=ep_rank,\n                                     tp_group=tp_group,\n                                     ep_group=ep_group)\n            return cls.dist_meta\n\n        def get_tokens_info(dp_size, tp_size, ep_size, ep_group):\n            if ep_size <= 1:\n                return 0, 0, 0\n            # get padded_tokens_current_rank\n            is_graph = cls.enable_graph and step_context.is_decoding\n            if is_graph:\n                from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size\n                actual_tokens_current_rank = step_context.q_seqlens.shape[0]\n                padded_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank),\n                                                 cls.max_batches)\n            else:\n                
actual_tokens_current_rank = step_context.q_seqlens.sum().item()\n                padded_tokens_current_rank = actual_tokens_current_rank\n            # get max_tokens_across_dp\n            if dp_size > 1:\n                runtime_tokens_tensor = torch.tensor([padded_tokens_current_rank],\n                                                     dtype=step_context.q_seqlens.dtype,\n                                                     device=torch.npu.current_device())\n                world_size = dp_size * tp_size\n                runtime_tokens_buffer = torch.zeros([world_size],\n                                                    dtype=step_context.q_seqlens.dtype,\n                                                    device=torch.npu.current_device())\n                dist.all_gather_into_tensor(runtime_tokens_buffer, runtime_tokens_tensor, ep_group)\n                max_tokens_across_dp = torch.max(runtime_tokens_buffer).item()\n            else:\n                max_tokens_across_dp = padded_tokens_current_rank\n            return actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp\n\n        @lru_cache\n        def init_mc2_token_capacity(tp_size):\n            max_num_tokens = min(cls.max_batches, 512)\n            num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size\n            return num_tokens_per_tp_rank * tp_size\n\n        def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size):\n            if ep_size <= 1:\n                return DlinferMoECommType.ALLGATHER\n            mc2_token_capacity = init_mc2_token_capacity(tp_size)\n            is_graph = cls.enable_graph and step_context.is_decoding\n            if is_graph:\n                max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size\n            if SocVersion.is_A2():\n                if max_tokens_across_dp <= mc2_token_capacity and dp_size * tp_size >= 16:\n                    return DlinferMoECommType.MC2\n            
    else:\n                    return DlinferMoECommType.ALLGATHER\n            elif SocVersion.is_A3():\n                if max_tokens_across_dp <= mc2_token_capacity:\n                    return DlinferMoECommType.MC2\n                else:\n                    return DlinferMoECommType.ALLTOALL\n            else:\n                raise ValueError(f'Unsupported soc_version: {SocVersion.soc_version()}')\n\n        def get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp, tp_size,\n                         moe_comm_type):\n            x_active_mask = None\n            if moe_comm_type == DlinferMoECommType.MC2:\n                padded_size = math.ceil(max_tokens_across_dp / tp_size) * tp_size\n                pad_size = padded_size - padded_tokens_current_rank\n                x_active_mask = torch.ones(actual_tokens_current_rank,\n                                           dtype=torch.bool,\n                                           device=torch.npu.current_device())\n            elif moe_comm_type == DlinferMoECommType.ALLTOALL:\n                pad_size = tp_size - padded_tokens_current_rank\n            elif moe_comm_type == DlinferMoECommType.ALLGATHER:\n                pad_size = max_tokens_across_dp - padded_tokens_current_rank\n            else:\n                pad_size = 0\n            return pad_size, x_active_mask\n\n        @lru_cache(maxsize=1)\n        def get_moe_group_name(group):\n            if group is None:\n                return None\n            local_rank = torch.distributed.get_rank(group=group)\n            backend = group._get_backend(torch.device('npu'))\n            group_name = backend.get_hccl_comm_name(local_rank)\n            return group_name\n\n        q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,\n                                                                             is_unpaged_prefill)\n        q_seqlens_list, kv_seqlens_list = 
get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,\n                                                           kv_seqlens_cpu)\n        max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list,\n                                                        kv_seqlens_list)\n        kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,\n                                                                                   is_unpaged_prefill, q_seqlens_list,\n                                                                                   kv_seqlens_list, max_q_seq_len,\n                                                                                   max_kv_seq_len)\n\n        if not cls.enable_graph and step_context.kv_quant_policy == 8:\n            record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')\n            assert record_file, 'please specify valid ASCEND_QUANT_RECORD_FILE'\n            path = Path(record_file)\n            is_path = path.is_absolute() or path.is_relative_to('/')\n            exists = path.exists()\n            if not (is_path and exists):\n                raise ValueError('please specify valid ASCEND_QUANT_RECORD_FILE')\n            if not AscendKVQuantMeta.has_set_value:\n                total_layers = len(step_context.kv_caches)\n                AscendKVQuantMeta.set_value(step_context.block_offsets.device, step_context.model_config.dtype,\n                                            record_file, total_layers)\n\n        attn_meta_cls = cls.get_attention_metadata_cls()\n        attn_metadata = attn_meta_cls(\n            step_context.is_decoding,\n            step_context.block_offsets,\n            q_start_loc=None,\n            q_seqlens=q_seqlens_cpu,\n            # kv_seqlens_expanded is only expanded in paged prefill,\n            # otherwise it equals kv_seqlens_cpu\n            kv_seqlens=kv_seqlens_expanded,\n     
       kv_start_indices=kv_start_indices,\n            block_size=block_size,\n            attention_mask=attention_mask,\n            is_unpaged_prefill=is_unpaged_prefill,\n            max_q_seq_len=max_q_seq_len,\n            max_kv_seq_len=max_kv_seq_len,\n            quant_policy=step_context.kv_quant_policy,\n            quant_meta=AscendKVQuantMeta.quant_meta,\n        )\n        step_context.attn_metadata = attn_metadata\n\n        cls.dist_meta = get_dist_meta()\n        actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp = get_tokens_info(\n            cls.dist_meta.dp_size, cls.dist_meta.tp_size, cls.dist_meta.ep_size, cls.dist_meta.ep_group)\n        moe_comm_type = select_moe_comm_type(max_tokens_across_dp, cls.dist_meta.dp_size, cls.dist_meta.tp_size,\n                                             cls.dist_meta.ep_size)\n        pad_size, x_active_mask = get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank,\n                                               max_tokens_across_dp, cls.dist_meta.tp_size, moe_comm_type)\n        moe_group_name = get_moe_group_name(cls.dist_meta.ep_group)\n\n        moe_metadata = DlinferMoeMetadata(\n            max_tokens_across_dp=max_tokens_across_dp,\n            pad_size=pad_size,\n            dp_size=cls.dist_meta.dp_size,\n            tp_size=cls.dist_meta.tp_size,\n            ep_size=cls.dist_meta.ep_size,\n            tp_rank=cls.dist_meta.tp_rank,\n            ep_rank=cls.dist_meta.ep_rank,\n            tp_group=cls.dist_meta.tp_group,\n            ep_group=cls.dist_meta.ep_group,\n            moe_comm_type=moe_comm_type,\n            x_active_mask=x_active_mask,\n            moe_group_name=moe_group_name,\n        )\n        step_context.moe_metadata = moe_metadata\n        return step_context\n\n    @staticmethod\n    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,\n                           backend_config: 
BackendConfig, device: torch.device):\n        \"\"\"Build graph runner.\"\"\"\n        AscendOpsBackend.enable_graph = not backend_config.eager_mode\n        AscendOpsBackend.max_batches = cache_config.max_batches\n        from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import AscendGraphRunner\n        return AscendGraphRunner(model, model_config, cache_config, backend_config, device)\n\n    @staticmethod\n    def init():\n        \"\"\"Initialize Ascend backend.\"\"\"\n        try:\n            from torch_npu.contrib import transfer_to_npu  # noqa: F401\n        except ImportError:\n            logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. '\n                           'Ascend initialization skipped.')\n        except Exception as e:\n            logger.warning(f'Error during Ascend initialization: {str(e)}. '\n                           'Please check your Ascend environment configuration.')\n\n    @staticmethod\n    def ccl_backend():\n        return 'hccl'\n\n    @staticmethod\n    def device_count():\n        \"\"\"Get num available devices.\"\"\"\n        return torch.npu.device_count()\n\n    @staticmethod\n    def support_ray():\n        \"\"\"Support ray.\"\"\"\n        if not _envs.ascend_set_rt_visable_devices_by_ray:\n            os.environ['RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES'] = '1'\n        return True\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/ascend/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport torch_npu\n\nACL_FORMAT_FRACTAL_NZ = 29\n\n\ndef nd_to_nz_spec(tensor: torch.Tensor) -> torch.Tensor:\n    '''\n    This function is copied from vllm-ascend commit hash: 420e794c35fe887db2be81cf9db0461f5b71da0b\n    It converts a tensor in ACL_FORMAT_ND format to ACL_FORMAT_FRACTAL_NZ format for Ascend 310P devices.\n    It behaves similarly to the TransdataOperation and it requires the input tensor to be 2D.\n    '''\n    num_tokens = tensor.shape[0]\n    max_seq_len = tensor.shape[1]\n\n    tokens_pad = (num_tokens + 15) // 16 * 16\n    max_seq_len_pad = (max_seq_len + 15) // 16 * 16\n\n    tensor_pad = \\\n        torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=tensor.dtype, device=tensor.device)\n\n    tensor_pad[0][:num_tokens, :max_seq_len] = tensor\n    tensor_nz = tensor_pad.reshape((1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)\n\n    tensor_nz = torch_npu.npu_format_cast(tensor_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ)\n    return tensor_nz\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/attention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom dataclasses import dataclass\nfrom typing import Dict, Optional, Sequence\n\nfrom torch import Tensor\n\nfrom ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata\n\n\n@dataclass\nclass DlinferAttentionMetadata(AttentionMetadata):\n    kv_start_indices: Optional[Tensor] = None\n    block_size: int = 64\n    attention_mask: Sequence[Tensor] = tuple()\n    is_unpaged_prefill: Optional[bool] = None\n    max_q_seq_len: int = 1\n    max_kv_seq_len: int = 1\n    quant_meta: Dict = None\n    cu_seq_lens_kv: Optional[Tensor] = None\n\n\nclass DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):\n    \"\"\"Dlinfer attention implementation.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi: bool = None,\n        sliding_window: int = None,\n        logit_softcapping: float = None,\n        causal: bool = True,\n        **kwargs,\n    ):\n        assert causal\n        super().__init__(\n            num_heads,\n            head_size,\n            scale,\n            num_kv_heads,\n            v_head_size,\n            alibi,\n            sliding_window,\n            logit_softcapping,\n            causal=causal,\n            **kwargs,\n        )\n\n        from lmdeploy.pytorch.kernels.dlinfer import fill_kv_cache, paged_attention_fwd\n\n        self.fill_kv_cache = fill_kv_cache\n        self.paged_attention_fwd = paged_attention_fwd\n\n    def forward(\n        self,\n        query: Tensor,\n        key: Tensor,\n        value: Tensor,\n        k_cache: Tensor,\n        v_cache: Tensor,\n        attn_metadata: DlinferAttentionMetadata,\n        k_scales_zeros: Tensor = None,\n        v_scales_zeros: Tensor = None,\n        learnable_sink: Tensor = None,\n        nsa_indices: Tensor = None,\n        inplace: bool = True,\n   
 ) -> Tensor:\n        \"\"\"forward.\"\"\"\n\n        block_offsets = attn_metadata.block_offsets\n        q_start_loc = attn_metadata.q_start_loc\n        q_seqlens = attn_metadata.q_seqlens\n        kv_seqlens = attn_metadata.kv_seqlens\n        is_decoding = attn_metadata.is_decoding\n        kv_start_indices = attn_metadata.kv_start_indices\n        block_size = attn_metadata.block_size\n        attn_mask = attn_metadata.attention_mask\n        is_unpaged_prefill = attn_metadata.is_unpaged_prefill\n        max_q_seq_len = attn_metadata.max_q_seq_len\n        max_kv_seq_len = attn_metadata.max_kv_seq_len\n        quant_bits = attn_metadata.quant_policy\n        cu_seq_lens_kv = attn_metadata.cu_seq_lens_kv\n\n        if attn_metadata.quant_meta is not None:\n            k_scales_zeros = [next(attn_metadata.quant_meta['k_scales']),\n                              next(attn_metadata.quant_meta['k_zeros'])\n                              ] if 'k_scales' in attn_metadata.quant_meta else []\n            v_scales_zeros = [next(attn_metadata.quant_meta['v_scales']),\n                              next(attn_metadata.quant_meta['v_zeros'])\n                              ] if 'v_scales' in attn_metadata.quant_meta else []\n            kv_scales = next(attn_metadata.quant_meta['kv_scales']) if 'kv_scales' in attn_metadata.quant_meta else None\n            kv_zeros = next(attn_metadata.quant_meta['kv_zeros']) if 'kv_zeros' in attn_metadata.quant_meta else None\n        else:\n            k_scales_zeros = []\n            v_scales_zeros = []\n            kv_scales = None\n            kv_zeros = None\n\n        # fill kv cache\n        k_cache, v_cache = self.fill_kv_cache(key,\n                                              value,\n                                              k_cache,\n                                              v_cache,\n                                              kv_start_indices,\n                                              
k_scales_zeros=k_scales_zeros,\n                                              v_scales_zeros=v_scales_zeros,\n                                              quant_bits=quant_bits)\n\n        if inplace:\n            attn_output = query[..., :self.v_head_size]\n        else:\n            q_shape = query.shape\n            o_shape = q_shape[:-1] + (self.v_head_size, )\n            attn_output = query.new_empty(o_shape)\n\n        attn_output = self.paged_attention_fwd(\n            query,\n            key,\n            value,\n            attn_output,\n            k_cache,\n            v_cache,\n            block_offsets,\n            q_start_loc=q_start_loc,\n            q_seqlens=q_seqlens,\n            kv_seqlens=kv_seqlens,\n            cu_seq_lens_kv=cu_seq_lens_kv,\n            max_q_seq_len=max_q_seq_len,\n            max_kv_seq_len=max_kv_seq_len,\n            is_decoding=is_decoding,\n            block_size=block_size,\n            num_heads=self.num_heads,\n            num_kv_heads=self.num_kv_heads,\n            v_head_size=self.v_head_size,\n            attn_mask=attn_mask,\n            softmax_scale=self.scale,\n            is_unpaged_prefill=is_unpaged_prefill,\n            kv_scales=kv_scales,\n            kv_zeros=kv_zeros,\n            quant_bits=quant_bits,\n        )\n\n        return attn_output\n\n\nclass DlinferAttentionBuilder(AttentionBuilder[DlinferAttentionMetadata]):\n    \"\"\"Dlinfer attention builder.\"\"\"\n\n    @staticmethod\n    def build(\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi_scale: float = None,\n        sliding_window: int = None,\n        logit_softcapping: float = None,\n        causal: bool = True,\n        learnable_sink: bool = False,\n        **kwargs,\n    ) -> DlinferAttentionImpl:\n        \"\"\"build.\"\"\"\n        return DlinferAttentionImpl(num_heads,\n                                    
head_size,\n                                    scale=scale,\n                                    num_kv_heads=num_kv_heads,\n                                    v_head_size=v_head_size,\n                                    alibi_scale=alibi_scale,\n                                    sliding_window=sliding_window,\n                                    logit_softcapping=logit_softcapping,\n                                    causal=causal,\n                                    **kwargs)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/awq_modules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional\n\nimport torch\n\nfrom lmdeploy.pytorch.kernels.dlinfer import awq_linear\n\nfrom ..awq_modules import LinearW4A16Builder, LinearW4A16Impl\n\n\nclass AwqLinearW4A16Impl(LinearW4A16Impl):\n    \"\"\"Awq kernel linear.\"\"\"\n\n    def __init__(self, in_features: int, out_features: int, w_bit: int, group_size: int):\n        self.in_features = in_features\n        self.out_features = out_features\n        self.w_bit = w_bit\n        self.group_size = group_size\n\n    def forward(self,\n                x,\n                qweight: torch.Tensor,\n                scales: torch.Tensor,\n                qzeros: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[torch.distributed.ProcessGroup] = None):\n        \"\"\"forward.\"\"\"\n        out = awq_linear(x, qweight, scales, qzeros, bias, all_reduce, self.group_size)\n        return out\n\n\nclass AwqLinearW4A16Builder(LinearW4A16Builder):\n    \"\"\"Awq linear builder.\"\"\"\n\n    @staticmethod\n    def build(in_features: int,\n              out_features: int,\n              w_bit: int,\n              group_size: int,\n              bias: bool = False,\n              dtype: torch.dtype = None):\n        \"\"\"build.\"\"\"\n        return AwqLinearW4A16Impl(in_features, out_features, w_bit, group_size)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/camb/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .op_backend import CambOpsBackend  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport torch\n\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig\nfrom lmdeploy.utils import get_logger\n\nfrom ..op_backend import DlinferOpsBackend\n\nlogger = get_logger('lmdeploy')\n\n\nclass CambOpsBackend(DlinferOpsBackend):\n    \"\"\"Camb layer backend.\"\"\"\n    total_slots = None\n\n    @staticmethod\n    def get_name() -> str:\n        \"\"\"Backend name.\"\"\"\n        return 'camb'\n\n    @staticmethod\n    def get_k_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        return (\n            num_heads,\n            block_size,\n            head_size,\n        )\n\n    @staticmethod\n    def get_v_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        return (\n            num_heads,\n            block_size,\n            head_size,\n        )\n\n    @classmethod\n    def update_step_context(cls, step_context):\n        \"\"\"Update step context.\"\"\"\n\n        def get_total_slots():\n            if cls.total_slots is None:\n                cls.total_slots = torch.arange(block_num * block_size,\n                                               dtype=torch.int32,\n                                               device=step_context.block_offsets.device)\n                cls.total_slots = cls.total_slots.view(block_num, block_size)\n            return cls.total_slots\n\n        kv_start_indices = []\n        block_num, _, block_size, _ = step_context.kv_caches[0][0].shape\n\n        is_unpaged_prefill = False\n        q_start_loc = step_context.q_start_loc\n        q_seqlens = step_context.q_seqlens\n        kv_seqlens = step_context.kv_seqlens.to(torch.int32)\n        block_offsets = step_context.block_offsets.to(torch.int32)\n        
max_q_seq_len = torch.max(q_seqlens).cpu().item()\n        max_kv_seq_len = torch.max(kv_seqlens).cpu().item()\n\n        cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int()\n        cu_seq_lens_kv = None\n\n        q_seqlens_list = step_context.q_seqlens.tolist()\n        kv_seqlens_list = step_context.kv_seqlens.tolist()\n        if not step_context.is_decoding:\n            is_unpaged_prefill = q_seqlens_list == kv_seqlens_list\n            # get kv_indices\n            for i in range(q_start_loc.size(0)):\n                q_seq_len = q_seqlens_list[i]\n                kv_seq_len = kv_seqlens_list[i]\n                # collect kv start indices.\n                history_length = kv_seq_len - q_seq_len\n                total_slots = get_total_slots()\n                slot_tables = total_slots[block_offsets[i]].view(-1)\n                slots = slot_tables[history_length:kv_seq_len]\n                kv_start_indices.append(slots)\n            kv_start_indices = torch.cat(kv_start_indices)\n            if not is_unpaged_prefill:\n                cu_seq_lens_kv = torch.cat((torch.tensor([0], device=kv_seqlens.device), kv_seqlens.cumsum(0))).int()\n        else:\n            # collect kv_start_indices without using a for-loop,\n            # (fill kv-cache for just ONE token during the decoding phase)\n            idx = (step_context.kv_seqlens - 1) % block_size\n            block_num = (step_context.kv_seqlens - 1) // block_size\n            last_block = block_offsets.gather(  # dtype of gather must be int64\n                1, block_num.view(-1, 1)).view(-1)\n            kv_start_indices = (last_block * block_size + idx).to(torch.int32)\n\n        attn_meta_cls = cls.get_attention_metadata_cls()\n        attn_metadata = attn_meta_cls(\n            step_context.is_decoding,\n            block_offsets,\n            q_start_loc=cu_seqlens,\n            cu_seq_lens_kv=cu_seq_lens_kv,\n            q_seqlens=q_seqlens,\n            
kv_seqlens=kv_seqlens,\n            kv_start_indices=kv_start_indices,\n            block_size=block_size,\n            attention_mask=None,\n            is_unpaged_prefill=is_unpaged_prefill,\n            max_q_seq_len=max_q_seq_len,\n            max_kv_seq_len=max_kv_seq_len,\n        )\n\n        step_context.attn_metadata = attn_metadata\n        return step_context\n\n    @staticmethod\n    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,\n                           backend_config: BackendConfig, device: torch.device):\n        \"\"\"Build graph runner.\"\"\"\n        from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner\n        return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)\n\n    @staticmethod\n    def support_ray():\n        \"\"\"Support ray.\"\"\"\n        return True\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/flash_attention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom torch import Tensor\n\nfrom ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl\n\n\nclass DlinferFlashAttentionImpl(FlashAttentionImpl):\n    \"\"\"Dlinfer flash attention implementation.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_dim: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_dim: int = None,\n        causal: bool = True,\n        sliding_window: int = None,\n        logit_softcapping: float = None,\n    ):\n        if scale is None:\n            scale = 1.0 / (head_dim**0.5)\n        if num_kv_heads is None:\n            num_kv_heads = num_heads\n        if v_head_dim is None:\n            v_head_dim = head_dim\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.scale = scale\n        self.num_kv_heads = num_kv_heads\n        self.v_head_dim = v_head_dim\n        self.causal = causal\n        self.sliding_window = sliding_window\n        self.logit_softcapping = logit_softcapping\n        from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd\n        self.flash_attention_fwd = flash_attention_fwd\n\n    def forward(self,\n                query: Tensor,\n                key: Tensor,\n                value: Tensor,\n                q_start_loc: Tensor,\n                q_seqlens: Tensor,\n                kv_start_loc: Tensor,\n                kv_seqlens: Tensor,\n                max_q_seqlen: int = None):\n        \"\"\"forward.\"\"\"\n        q_shape = query.shape\n        o_shape = q_shape[:-1] + (self.v_head_dim, )\n        out = query.new_empty(o_shape)\n        self.flash_attention_fwd(\n            query,\n            key,\n            value,\n            out,\n            q_start_loc=q_start_loc,\n            q_seqlens=q_seqlens,\n            kv_start_loc=kv_start_loc,\n            kv_seqlens=kv_seqlens,\n            num_heads=self.num_heads,\n        
    num_kv_heads=self.num_kv_heads,\n            max_q_seqlen=max_q_seqlen,\n            window_size=self.sliding_window,\n            sm_scale=self.scale,\n            logit_softcapping=self.logit_softcapping,\n            causal=self.causal,\n        )\n        return out\n\n\nclass DlinferFlashAttentionBuilder(FlashAttentionBuilder):\n    \"\"\"Dlinfer flash attention builder.\"\"\"\n\n    @staticmethod\n    def build(\n        num_heads: int,\n        head_dim: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_dim: int = None,\n        causal: bool = True,\n        sliding_window: int = None,\n        logit_softcapping: float = None,\n        **kwargs,\n    ) -> FlashAttentionImpl:\n        \"\"\"build.\"\"\"\n        return DlinferFlashAttentionImpl(\n            num_heads=num_heads,\n            head_dim=head_dim,\n            scale=scale,\n            num_kv_heads=num_kv_heads,\n            v_head_dim=v_head_dim,\n            causal=causal,\n            sliding_window=sliding_window,\n            logit_softcapping=logit_softcapping,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/linear.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nfrom typing import List, Optional\n\nimport torch\nimport torch.distributed as dist\n\nfrom lmdeploy.pytorch.kernels.dlinfer import linear\n\nfrom ..linear import LinearBuilder, LinearImpl\n\n\nclass DlinferLinearImpl(LinearImpl):\n    \"\"\"Dlinfer linear implementation api.\"\"\"\n\n    def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):\n        \"\"\"Update weights.\"\"\"\n        if os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') == '1':\n            weight = weight.data.t().contiguous()\n        return weight, bias\n\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: dist.ProcessGroup = None,\n                rank: int = 0,\n                scatter_size: List[int] = None):\n        \"\"\"forward.\"\"\"\n        out = linear(x, weight, bias, False)\n        if all_reduce:\n            dist.all_reduce(out, group=group)\n        return out\n\n\nclass DlinferLinearBuilder(LinearBuilder):\n    \"\"\"Dlinfer linear implementation builder.\"\"\"\n\n    @staticmethod\n    def build(in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None):\n        \"\"\"build.\"\"\"\n        return DlinferLinearImpl()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/maca/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .op_backend import MacaOpsBackend  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport torch\n\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig\nfrom lmdeploy.utils import get_logger\n\nfrom ..op_backend import DlinferOpsBackend\n\nlogger = get_logger('lmdeploy')\n\n\nclass MacaOpsBackend(DlinferOpsBackend):\n    \"\"\"Maca layer backend.\"\"\"\n    total_slots = None\n\n    @staticmethod\n    def get_name() -> str:\n        \"\"\"Backend name.\"\"\"\n        return 'maca'\n\n    @staticmethod\n    def get_k_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        return (block_size, num_heads, head_size)\n\n    @staticmethod\n    def get_v_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        return (block_size, num_heads, head_size)\n\n    @classmethod\n    def update_step_context(cls, step_context):\n        \"\"\"Update step context.\"\"\"\n\n        def get_total_slots():\n            # lazily build (and cache on the class) a [block_num, block_size] table of flat slot indices\n            if cls.total_slots is None:\n                cls.total_slots = torch.arange(block_num * block_size,\n                                               dtype=torch.long,\n                                               device=step_context.block_offsets.device)\n                cls.total_slots = cls.total_slots.view(block_num, block_size)\n            return cls.total_slots\n\n        kv_start_indices, attention_mask = [], []\n        block_num, block_size, _, _ = step_context.kv_caches[0][1].shape\n\n        is_unpaged_prefill = False\n        if not step_context.is_decoding:\n            is_unpaged_prefill = \\\n               all((step_context.q_seqlens ==\n                    step_context.kv_seqlens).tolist())\n        q_start_loc = step_context.q_start_loc\n        cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()\n\n        q_seqlens = step_context.q_seqlens.int()\n        kv_seqlens = step_context.kv_seqlens.int()\n\n        if step_context.is_decoding:\n            # max_q_seq_len and max_kv_seq_len are not used in the decoding stage\n            max_q_seq_len = -1\n            max_kv_seq_len = -1\n\n            # collect kv_start_indices without using a for-loop,\n            # (fill kv-cache for just ONE token during the decoding phase)\n            idx = (step_context.kv_seqlens - 1) % block_size\n            b_num = (step_context.kv_seqlens - 1) // block_size\n            last_block = step_context.block_offsets.gather(1, b_num.view(-1, 1)).view(-1)\n            kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))\n        else:\n            max_q_seq_len = torch.max(q_seqlens).cpu().item()\n            max_kv_seq_len = torch.max(kv_seqlens).cpu().item()\n\n            for i in range(step_context.q_start_loc.size(0)):\n                q_seq_len = int(step_context.q_seqlens[i])\n                kv_seq_len = int(step_context.kv_seqlens[i])\n                # collect kv start indices during the prefill phase.\n                history_length = kv_seq_len - q_seq_len\n                total_slots = get_total_slots()\n                slot_tables = total_slots[step_context.block_offsets[i]].view(-1)\n                slots = slot_tables[history_length:kv_seq_len]\n                kv_start_indices.append(slots)\n            kv_start_indices = torch.cat(kv_start_indices)\n\n        attn_meta_cls = cls.get_attention_metadata_cls()\n        attn_metadata = attn_meta_cls(\n            step_context.is_decoding,\n            step_context.block_offsets.int(),\n            q_start_loc=cu_seqlens,\n            q_seqlens=q_seqlens,\n            kv_seqlens=kv_seqlens,\n            kv_start_indices=kv_start_indices,\n            block_size=block_size,\n            attention_mask=attention_mask,\n            is_unpaged_prefill=is_unpaged_prefill,\n            max_q_seq_len=max_q_seq_len,\n            max_kv_seq_len=max_kv_seq_len,\n        )\n\n        step_context.attn_metadata = attn_metadata\n        return step_context\n\n    @staticmethod\n    def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,\n                           backend_config: BackendConfig, device: torch.device):\n        \"\"\"Build graph runner.\"\"\"\n        from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner\n        return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)\n\n    @staticmethod\n    def support_ray():\n        \"\"\"Support ray.\"\"\"\n        return True\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nfrom typing import Callable, List\n\nimport torch\n\nfrom lmdeploy.pytorch.kernels.dlinfer import DlinferMoECommType  # noqa: F401\nfrom lmdeploy.pytorch.kernels.dlinfer import DlinferMoeMetadata  # noqa: F401\nfrom lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax\nfrom lmdeploy.pytorch.model_inputs import get_step_ctx_manager\n\nfrom ..moe import FusedMoEBuilder, FusedMoEImpl, SoftmaxTopKBuilder, SoftmaxTopKImpl\n\n\nclass DlinferSoftmaxTopKImpl(SoftmaxTopKImpl):\n    \"\"\"Dlinfer softmax topk implementation.\"\"\"\n\n    def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1):\n        self.top_k = top_k\n        self.dim = dim\n        self.n_groups = n_groups\n\n    def forward(self, x: torch.Tensor):\n        step_context = get_step_ctx_manager().current_context()\n        moe_metadata = getattr(step_context, 'moe_metadata', None)\n        if moe_metadata is not None:\n            moe_metadata.router_n_groups = self.n_groups\n        routing_weights, selected_experts = moe_gating_topk_softmax(x, self.top_k, moe_metadata)\n        return routing_weights, selected_experts\n\n\nclass DlinferSoftmaxTopKBuilder(SoftmaxTopKBuilder):\n    \"\"\"Dlinfer softmax topk implementation builder.\"\"\"\n\n    @staticmethod\n    def build(top_k: int, dim: int = -1, n_groups: int = -1):\n        \"\"\"build.\"\"\"\n        return DlinferSoftmaxTopKImpl(top_k, dim, n_groups)\n\n\nclass DlinferFusedMoEImpl(FusedMoEImpl):\n    \"\"\"Dlinfer fused moe implementation.\"\"\"\n\n    def __init__(self,\n                 top_k: int,\n                 num_experts: int,\n                 renormalize: bool = False,\n                 ep_size: int = 1,\n                 ep_group: torch.distributed.ProcessGroup = None):\n        self.top_k = top_k\n        self.num_experts = num_experts\n        self.renormalize = renormalize\n        self.ep_size = ep_size\n        self.ep_group = ep_group\n        self.expert_ids_per_ep_rank = None\n        if self.ep_size > 1:\n            # map each global expert id to its local index within its expert-parallel rank\n            self.expert_ids_per_ep_rank = torch.tensor(\n                [i % (self.num_experts // self.ep_size) for i in range(num_experts)],\n                dtype=torch.int32,\n                device=torch.cuda.current_device(),\n            )\n\n    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):\n        \"\"\"Update weights.\"\"\"\n        device_type = gate_up_weights.device.type\n        if device_type in ['npu']:\n            if os.getenv('DLINFER_RESET_MOE_UPDATE_WEIGHTS', '0') == '1':\n                return gate_up_weights, down_weights\n            return gate_up_weights.transpose(-1, -2).contiguous(), down_weights.transpose(-1, -2).contiguous()\n        return gate_up_weights, down_weights\n\n    def ep_expert_list(self, world_size: int, rank: int):\n        \"\"\"Experts list of the current rank.\"\"\"\n        num_experts = self.num_experts\n        expert_per_rank = (num_experts + world_size - 1) // world_size\n        first_expert = rank * expert_per_rank\n        last_expert = min(first_expert + expert_per_rank, num_experts)\n        return list(range(first_expert, last_expert))\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n                down_weights: torch.Tensor,\n                gate_up_bias: torch.Tensor = None,\n                down_bias: torch.Tensor = None,\n                expert_list: List[int] = None,\n                act_func: Callable = None):\n        \"\"\"forward.\"\"\"\n        assert gate_up_bias is None\n        assert down_bias is None\n\n        step_context = get_step_ctx_manager().current_context()\n        moe_metadata = getattr(step_context, 'moe_metadata', None)\n        if moe_metadata is not None:\n            moe_metadata.expert_ids_per_ep_rank = self.expert_ids_per_ep_rank\n        return fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, self.top_k,\n                         self.renormalize, moe_metadata)\n\n\nclass DlinferFusedMoEBuilder(FusedMoEBuilder):\n    \"\"\"Dlinfer fused moe builder.\"\"\"\n\n    @staticmethod\n    def build(top_k: int,\n              num_experts: int,\n              renormalize: bool = False,\n              hidden_dim: int = 1,\n              ep_size: int = 1,\n              ep_group: torch.distributed.ProcessGroup = None,\n              layer_idx: int = 0,\n              out_dtype: torch.dtype = torch.bfloat16):\n        \"\"\"Build from mlp.\"\"\"\n        return DlinferFusedMoEImpl(top_k=top_k,\n                                   num_experts=num_experts,\n                                   renormalize=renormalize,\n                                   ep_size=ep_size,\n                                   ep_group=ep_group)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/norm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom lmdeploy.pytorch.kernels.dlinfer import rms_norm\n\nfrom ..norm import RMSNormBuilder, RMSNormImpl\n\n\nclass DlinferRMSNormImpl(RMSNormImpl):\n    \"\"\"Dlinfer RMS norm implementation.\"\"\"\n\n    def __init__(self, hidden_size: int, eps: float = 1e-6):\n        self.hidden_size = hidden_size\n        self.eps = eps\n\n    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        if residual is None:\n            x = rms_norm(x, weight, self.eps)\n            return x\n        else:\n            x, residual = rms_norm(x, weight, self.eps, residual=residual)\n            return x, residual\n\n\nclass DlinferRMSNormBuilder(RMSNormBuilder):\n    \"\"\"Dlinfer RMS norm implementation builder.\"\"\"\n\n    @staticmethod\n    def build(hidden_size: int, eps: float = 1e-6):\n        \"\"\"build.\"\"\"\n        return DlinferRMSNormImpl(hidden_size, eps)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/op_backend.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport torch\n\nfrom lmdeploy.utils import get_logger\n\nfrom ..base import OpType\nfrom ..default import DefaultOpsBackend\n\nlogger = get_logger('lmdeploy')\n\n\nclass DlinferOpsBackend(DefaultOpsBackend):\n    \"\"\"Dlinfer layer backend.\"\"\"\n\n    @staticmethod\n    def get_name() -> str:\n        \"\"\"Backend name.\"\"\"\n        return 'dlinfer'\n\n    @classmethod\n    def get_layer_impl_builder(cls, layer_type: OpType):\n        \"\"\"Get dlinfer layer builder.\"\"\"\n        if layer_type == OpType.PagedAttention:\n            from .attention import DlinferAttentionBuilder\n            return DlinferAttentionBuilder\n        elif layer_type == OpType.FlashAttention:\n            from .flash_attention import DlinferFlashAttentionBuilder\n            return DlinferFlashAttentionBuilder\n        elif layer_type == OpType.ApplyRotaryEmb:\n            from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder\n            return DlinferApplyRotaryEmbBuilder\n        elif layer_type == OpType.SiluAndMul:\n            from .activation import DlinferSiluAndMulBuilder\n            return DlinferSiluAndMulBuilder\n        elif layer_type == OpType.RMSNorm:\n            from .norm import DlinferRMSNormBuilder\n            return DlinferRMSNormBuilder\n        elif layer_type == OpType.LinearW8A8:\n            from .qmodules import DlinferLinearW8A8Builder\n            return DlinferLinearW8A8Builder\n        elif layer_type == OpType.RMSNormW8A8:\n            from .qmodules import DlinferRMSNormW8A8Builder\n            return DlinferRMSNormW8A8Builder\n        elif layer_type == OpType.SoftmaxTopK:\n            from .moe import DlinferSoftmaxTopKBuilder\n            return DlinferSoftmaxTopKBuilder\n        elif layer_type == OpType.FusedMoE:\n            from .moe import DlinferFusedMoEBuilder\n            return DlinferFusedMoEBuilder\n        elif layer_type == OpType.Linear:\n            from .linear import DlinferLinearBuilder\n            return DlinferLinearBuilder\n        elif layer_type == OpType.LinearW4A16:\n            from .awq_modules import AwqLinearW4A16Builder\n            return AwqLinearW4A16Builder\n        elif layer_type == OpType.RotaryEmbedding:\n            from .rotary_embedding import DlinferRotaryEmbeddingBuilder\n            return DlinferRotaryEmbeddingBuilder\n        else:\n            logger.debug(f'Op {layer_type} falls back to the default implementation.')\n            return super().get_layer_impl_builder(layer_type)\n\n    @staticmethod\n    def get_attention_metadata_cls():\n        \"\"\"Get attention metadata class.\"\"\"\n        from .attention import DlinferAttentionMetadata\n        return DlinferAttentionMetadata\n\n    @staticmethod\n    def get_k_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        return (\n            block_size,\n            num_heads,\n            head_size,\n        )\n\n    @staticmethod\n    def get_v_block_shape(\n        block_size: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n    ) -> Tuple[int, ...]:\n        return (\n            block_size,\n            num_heads,\n            head_size,\n        )\n\n    @classmethod\n    def update_step_context(cls, step_context):\n        \"\"\"Update step context.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/qmodules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nfrom typing import Optional\n\nimport torch\nimport torch.distributed as dist\n\nfrom lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import dynamic_quant, linear_w8a8, rms_norm_w8a8\nfrom lmdeploy.pytorch.models.q_modules import QTensor\n\nfrom ..qmodules import LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder, RMSNormW8A8Impl\n\n\nclass DlinferLinearW8A8Impl(LinearW8A8Impl):\n    \"\"\"Dlinfer linear w8a8 implementation.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 out_dtype: torch.dtype = torch.float16,\n                 quant_dtype: torch.dtype = torch.int8):\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype\n        self.quant_dtype = quant_dtype\n\n    def update_weights(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None):\n        \"\"\"Update weights.\"\"\"\n        if os.getenv('DLINFER_LINEAR_USE_NN_LAYOUT', '0') == '1':\n            weight = weight.data.t().contiguous()\n            scale = scale.data.t().contiguous()\n        return weight, scale, bias\n\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                scale: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[torch.distributed.ProcessGroup] = None):\n        \"\"\"forward.\"\"\"\n        # quantize activations dynamically unless the previous layer already produced a QTensor\n        if isinstance(x, torch.Tensor):\n            input_quant, input_scale = dynamic_quant(x, self.quant_dtype)\n        else:\n            assert isinstance(x, QTensor)\n            input_quant, input_scale = x.tensor, x.scale\n\n        out = linear_w8a8(input_quant, weight, input_scale, scale, self.out_dtype, self.quant_dtype, bias)\n        if all_reduce:\n            dist.all_reduce(out, group=group)\n        return out\n\n\nclass DlinferLinearW8A8Builder(LinearW8A8Builder):\n    \"\"\"Dlinfer linear w8a8 implementation builder.\"\"\"\n\n    @staticmethod\n    def build(in_features: int,\n              out_features: int,\n              bias: bool = True,\n              dtype: torch.dtype = None,\n              quant_dtype: torch.dtype = torch.int8):\n        \"\"\"build.\"\"\"\n        return DlinferLinearW8A8Impl(in_features, out_features, dtype, quant_dtype)\n\n\nclass DlinferRMSNormW8A8Impl(RMSNormW8A8Impl):\n    \"\"\"Dlinfer RMS norm w8a8 implementation api.\"\"\"\n\n    def __init__(self, hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.eps = eps\n        self.quant_dtype = quant_dtype\n\n    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        if residual is None:\n            (x, rms_scale) = rms_norm_w8a8(x, weight, self.eps, self.quant_dtype)\n            x = QTensor(x, rms_scale)\n            return x\n        else:\n            (x, rms_scale, residual) = rms_norm_w8a8(x, weight, self.eps, self.quant_dtype, residual)\n            x = QTensor(x, rms_scale)\n            return x, residual\n\n\nclass DlinferRMSNormW8A8Builder(RMSNormW8A8Builder):\n    \"\"\"Dlinfer RMS norm w8a8 implementation builder.\"\"\"\n\n    @staticmethod\n    def build(hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):\n        \"\"\"build.\"\"\"\n        return DlinferRMSNormW8A8Impl(hidden_size, eps, quant_dtype)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport math\n\nimport torch\nfrom torch import nn\n\nfrom ..default.rotary_embedding import (FopeRotaryEmbeddingImpl, LlamaDynamicNTKScalingRotaryEmbedding,\n                                        YarnRotaryEmbeddingImpl)\nfrom ..rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,\n                                RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters)\n\n\ndef _rotary_embedding_fwd(position_ids: torch.Tensor,\n                          inv_freq: torch.Tensor,\n                          scaling_factor: float,\n                          mscale: float = None,\n                          dtype: torch.dtype = None):\n    \"\"\"Rotary embedding forward.\"\"\"\n    if dtype is None:\n        dtype = torch.float16\n\n    if scaling_factor != 1.0:\n        position_ids = position_ids.float() / scaling_factor\n    else:\n        position_ids = position_ids.float()\n\n    position_ids = position_ids.unsqueeze(-1)\n    angles = position_ids * inv_freq.view(1, 1, -1)\n    angles = torch.cat((angles, angles), dim=-1)\n\n    sin = angles.sin()\n    cos = angles.cos()\n\n    if mscale is not None:\n        cos = cos * mscale\n        sin = sin * mscale\n    return cos.to(dtype=dtype), sin.to(dtype=dtype)\n\n\nclass DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):\n    \"\"\"Base rotary embedding.\"\"\"\n\n    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):\n        super().__init__()\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.base = base\n        # yapf: disable\n        inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2, dtype=torch.float, device='cuda') / self.dim))\n        # yapf: enable\n        self.register_buffer('inv_freq', inv_freq, persistent=False)\n\n    def forward(self, x, position_ids):\n        \"\"\"forward.\"\"\"\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        dtype = x.dtype\n        if self.inv_freq.device != x.device:\n            self.inv_freq = self.inv_freq.to(x.device)\n        return _rotary_embedding_fwd(position_ids, self.inv_freq, scaling_factor=self.scaling_factor, dtype=dtype)\n\n\nclass DlinferLlamaDynamicNTKScalingRotaryEmbedding(LlamaDynamicNTKScalingRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with Dynamic NTK scaling.\n\n    Credits to the Reddit users /u/bloc97 and /u/emozilla\n    \"\"\"\n\n    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0, max_position_embeddings: int = 2048):\n        super().__init__(dim, base, scaling_factor, max_position_embeddings)\n        self.dim_scale_ratio = self.dim / (self.dim - 2)\n        self.pos_freq_scaling = torch.arange(0, self.dim, 2, dtype=torch.int64).float().cuda() / self.dim\n        self.scale_offset = self.scaling_factor - 1\n        self.pos_scale_factor = self.scaling_factor / \\\n            self.max_position_embeddings\n\n    def _ntk_inv_freq(self, seq_len: torch.Tensor):\n        \"\"\"Calculate inverse frequency with NTK scaling.\"\"\"\n        base = self.base * ((self.pos_scale_factor * seq_len) - self.scale_offset)**self.dim_scale_ratio\n        inv_freq = 1.0 / (base**self.pos_freq_scaling)\n        return inv_freq\n\n    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        dtype = x.dtype\n        seq_len = torch.max(position_ids) + 1\n        ntk_inv_freq = self._ntk_inv_freq(seq_len)\n        if self.inv_freq.device != x.device:\n            self.inv_freq = self.inv_freq.to(x.device)\n        inv_freq = torch.where(seq_len > self.max_position_embeddings, ntk_inv_freq, self.inv_freq)\n\n        cos, sin = _rotary_embedding_fwd(position_ids, inv_freq, scaling_factor=1.0, dtype=dtype)\n        return cos, sin\n\n\nclass DlinferLlama3RotaryEmbeddingImpl(DlinferRotaryEmbeddingImpl):\n    \"\"\"Llama3 rotary embedding implementation.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        base: int = 10000,\n        scaling_factor: float = 1.0,\n        low_freq_factor: float = 1.0,\n        high_freq_factor: float = 4.0,\n        original_max_position_embeddings: int = 8192,\n    ):\n        super().__init__(dim, base, scaling_factor)\n        old_context_len = original_max_position_embeddings\n        low_freq_wavelen = old_context_len / low_freq_factor\n        high_freq_wavelen = old_context_len / high_freq_factor\n\n        inv_freq = self.inv_freq\n        factor = self.scaling_factor\n\n        wavelen = 2 * math.pi / inv_freq\n        # wavelen < high_freq_wavelen: do nothing\n        # wavelen > low_freq_wavelen: divide by factor\n        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)\n        # otherwise: interpolate between the two, using a smooth factor\n        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)\n        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama\n        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)\n        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n        self.scaling_factor = 1.0\n        self.register_buffer('inv_freq', inv_freq_llama)\n\n\nclass DlinferYarnRotaryEmbeddingImpl(YarnRotaryEmbeddingImpl):\n    \"\"\"Yarn rotary embedding implementation.\"\"\"\n\n    def __init__(self,\n                 dim: int,\n                 base: int = 10000,\n                 scaling_factor: float = 1.0,\n                 original_max_position_embeddings: int = 4096,\n                 yarn_params: YarnParameters = None):\n        super().__init__(dim, base, scaling_factor, original_max_position_embeddings, yarn_params)\n\n    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        dtype = x.dtype\n        if self.inv_freq.device != x.device:\n            self.inv_freq = self.inv_freq.to(x.device)\n        return _rotary_embedding_fwd(position_ids, self.inv_freq, scaling_factor=1.0, mscale=self.mscale, dtype=dtype)\n\n\nclass DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):\n    \"\"\"Rotary embedding dlinfer builder.\"\"\"\n\n    @staticmethod\n    def build(\n        dim: int,\n        max_position_embeddings: int = 2048,\n        base: int = 10000,\n        scaling_factor: float = 1.0,\n        yarn_params: YarnParameters = None,\n        longrope_params: LongRoPEScalingParameters = None,\n        llama3_params: Llama3Parameters = None,\n        fope_params: FopeParameters = None,\n        emb_type: RopeType = RopeType.Default,\n    ):\n        \"\"\"build.\"\"\"\n        if emb_type in (RopeType.Default, RopeType.LinearScaling):\n            return DlinferRotaryEmbeddingImpl(dim, base, scaling_factor)\n        elif emb_type == RopeType.DynamicNTKScaling:\n            return DlinferLlamaDynamicNTKScalingRotaryEmbedding(dim, base, scaling_factor, max_position_embeddings)\n        elif emb_type == RopeType.Llama3:\n            return DlinferLlama3RotaryEmbeddingImpl(dim, base, scaling_factor, llama3_params.low_freq_factor,\n                                                    llama3_params.high_freq_factor, max_position_embeddings)\n        elif emb_type == RopeType.Yarn:\n            return DlinferYarnRotaryEmbeddingImpl(dim,\n                                                  base,\n                                                  scaling_factor,\n                                                  max_position_embeddings,\n                                                  yarn_params=yarn_params)\n        elif emb_type == RopeType.Fope:\n            return FopeRotaryEmbeddingImpl(\n                dim,\n                max_position_embeddings=max_position_embeddings,\n                scaling_factor=scaling_factor,\n                params=fope_params,\n            )\n        else:\n            raise NotImplementedError(f'Unsupported embedding type: {emb_type}')\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/embedding.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\nimport torch\nimport torch.distributed as dist\n\n\nclass EmbeddingImpl(ABC):\n    \"\"\"Embedding implementation api.\"\"\"\n\n    @abstractmethod\n    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass EmbeddingBuilder(ABC):\n    \"\"\"Embedding implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(start_index: int, end_index: int):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/flash_attention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\nfrom torch import Tensor\n\n\nclass FlashAttentionImpl(ABC):\n    \"\"\"FlashAttention implementation.\"\"\"\n\n    def forward(self,\n                query: Tensor,\n                key: Tensor,\n                value: Tensor,\n                q_start_loc: Tensor,\n                q_seqlens: Tensor,\n                kv_start_loc: Tensor,\n                kv_seqlens: Tensor,\n                max_q_seqlen: int = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass FlashAttentionBuilder(ABC):\n    \"\"\"FlashAttention implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(\n        num_heads: int,\n        head_dim: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_dim: int = None,\n        causal: bool = True,\n        sliding_window: int = None,\n        logit_softcapping: float = None,\n        **kwargs,\n    ) -> FlashAttentionImpl:\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/gated_delta_rule.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\nimport torch\n\n\nclass GatedDeltaRuleImpl(ABC):\n    \"\"\"Gated Delta Rule implementation api.\"\"\"\n\n    @abstractmethod\n    def chunk_gated_delta_rule(self,\n                               q: torch.Tensor,\n                               k: torch.Tensor,\n                               v: torch.Tensor,\n                               g: torch.Tensor | None = None,\n                               beta: torch.Tensor | None = None,\n                               initial_state: torch.Tensor | None = None,\n                               state_indices: torch.Tensor | None = None,\n                               scale: float | None = None,\n                               use_qk_l2norm_in_kernel: bool = False,\n                               cu_seqlens: torch.Tensor | None = None,\n                               output_final_state: bool = False):\n        \"\"\"Chunked gated delta rule forward.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def fused_recurrent_gated_delta_rule(self,\n                                         q: torch.Tensor,\n                                         k: torch.Tensor,\n                                         v: torch.Tensor,\n                                         g: torch.Tensor | None = None,\n                                         beta: torch.Tensor | None = None,\n                                         initial_state: torch.Tensor | None = None,\n                                         state_indices: torch.Tensor | None = None,\n                                         scale: float | None = None,\n                                         use_qk_l2norm_in_kernel: bool = False,\n                                         output_final_state: bool = False):\n        \"\"\"Fused recurrent gated delta rule forward.\"\"\"\n        raise NotImplementedError\n\n\nclass GatedDeltaRuleBuilder(ABC):\n    \"\"\"Gated Delta Rule implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build() -> GatedDeltaRuleImpl:\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/graph_runner.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\nfrom dataclasses import dataclass\nfrom typing import List\n\nimport torch\n\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig\nfrom lmdeploy.pytorch.model_inputs import StepContext\n\n\n@dataclass\nclass GraphRunnerMeta:\n    padding_batch_size: int = None\n\n\n@functools.lru_cache\ndef _get_capture_batch_size_impl(max_batches: int):\n    \"\"\"Capture batch sizes: powers of two below max_batches, then max_batches itself.\"\"\"\n    ret = []\n    batch_size = 1\n    while batch_size < max_batches:\n        ret.append(batch_size)\n        batch_size *= 2\n    ret.append(max_batches)\n    return ret\n\n\nclass GraphRunner:\n    \"\"\"Graph runner.\"\"\"\n\n    def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,\n                 backend_config: BackendConfig, device: torch.device, **kwargs):\n        self.model = model\n        self.ctx_mgr = model.ctx_mgr\n        self.device = device\n        self.model_config = model_config\n        self.cache_config = cache_config\n        self.backend_config = backend_config\n        self._runner_meta = GraphRunnerMeta()\n\n    def __call__(self, **kwargs):\n        \"\"\"Call graph runner forward.\"\"\"\n        return self.model(**kwargs)\n\n    def get_model(self):\n        \"\"\"Get model.\"\"\"\n        return self.model\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Get logits of model output.\"\"\"\n        if not hasattr(self.model, 'get_logits'):\n            return hidden_states\n        return self.model.get_logits(hidden_states)\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare inputs.\"\"\"\n        return self.model.prepare_inputs_for_generation(\n            past_key_values,\n            inputs_embeds,\n            context,\n        )\n\n    def update_model_metas(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Update model metas.\"\"\"\n        if hasattr(self.model, 'update_model_metas'):\n            return self.model.update_model_metas(\n                past_key_values,\n                inputs_embeds,\n                context,\n            )\n\n        return None\n\n    def get_input_processor(self):\n        \"\"\"Get input processor.\"\"\"\n        if hasattr(self.model, 'get_input_processor'):\n            return self.model.get_input_processor()\n        else:\n            return None\n\n    def reset(self):\n        \"\"\"Remove all graphs to prevent hanging on exit.\"\"\"\n        pass\n\n    def get_meta(self):\n        \"\"\"Get graph runner meta.\"\"\"\n        return self._runner_meta\n\n    def update_inputs(self, inputs):\n        \"\"\"Update inputs (identity by default).\"\"\"\n        return inputs\n\n    def get_capture_batch_sizes(self) -> List[int]:\n        \"\"\"Capture batch sizes.\"\"\"\n        return _get_capture_batch_size_impl(self.cache_config.max_batches)\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/linear.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import List, Optional\n\nimport torch\nimport torch.distributed as dist\n\n\nclass LinearImpl(ABC):\n    \"\"\"Linear implementation api.\"\"\"\n\n    def update_weights(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):\n        \"\"\"Update weights.\"\"\"\n        return weight, bias\n\n    @abstractmethod\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: dist.ProcessGroup = None,\n                rank: int = 0,\n                scatter_size: List[int] = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass LinearBuilder(ABC):\n    \"\"\"Linear implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(in_features: int, out_features: int, bias: bool = True, dtype: torch.dtype = None):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/lora.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom lmdeploy.pytorch.model_inputs import StepContextManager\n\n\n@dataclass\nclass AdapterInfo:\n    \"\"\"Adapter information.\"\"\"\n    in_features: int\n    out_features: int\n    ranks: torch.Tensor\n    scalings: torch.Tensor\n    base_slice: slice\n    rank_offsets: torch.Tensor = field(init=False)\n    max_rank: int = field(init=False)\n\n    def __post_init__(self):\n        \"\"\"Post init.\"\"\"\n        ranks = self.ranks\n        rank_offsets = ranks.cumsum(0) - ranks\n        max_rank = ranks.max().item()\n        self.rank_offsets = rank_offsets\n        self.max_rank = max_rank\n\n\nclass LoRAImpl(ABC):\n    \"\"\"Lora implementation.\"\"\"\n\n    @abstractmethod\n    def forward(self,\n                x: torch.Tensor,\n                base_output: torch.Tensor,\n                lora_A: torch.Tensor,\n                lora_B: torch.Tensor,\n                adapter_info: AdapterInfo,\n                ctx_mgr: StepContextManager,\n                colwise: bool,\n                is_tp: bool = True):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass LoRABuilder(ABC):\n    \"\"\"Lora implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build():\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\nfrom abc import ABC, abstractmethod\nfrom typing import Callable, List, Optional\n\nimport torch\nimport torch.distributed as dist\n\n\nclass SoftmaxTopKImpl(ABC):\n    \"\"\"Softmax topk implementation api.\"\"\"\n\n    @staticmethod\n    @functools.lru_cache\n    def get_group_offsets(n_groups: int, group_size: int, device: str):\n        group_offsets = (torch.arange(n_groups, device=device) * group_size).view(1, -1, 1)  # [1, n_groups, 1]\n        return group_offsets\n\n    @abstractmethod\n    def forward(self, x: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass SoftmaxTopKBuilder(ABC):\n    \"\"\"Softmax topk implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(top_k: int, dim: int = -1, n_groups: int = -1):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n\n\nclass FusedMoEImpl(ABC):\n    \"\"\"Fused moe implementation.\"\"\"\n\n    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):\n        \"\"\"Update weights.\"\"\"\n        return gate_up_weights, down_weights\n\n    def ep_expert_list(self, world_size: int, rank: int):\n        \"\"\"Experts list of current rank.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    @abstractmethod\n    def forward(self,\n                hidden_states: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n                down_weights: torch.Tensor,\n                gate_up_bias: torch.Tensor = None,\n                down_bias: torch.Tensor = None,\n                expert_list: List[int] = None,\n                act_func: Callable = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass FusedMoEBuilder(ABC):\n    \"\"\"Fused moe builder.\"\"\"\n\n    @staticmethod\n    
@abstractmethod\n    def build(top_k: int,\n              num_experts: int,\n              renormalize: bool = False,\n              hidden_dim: int = 1,\n              ep_size: int = 1,\n              ep_group: dist.ProcessGroup = None,\n              layer_idx: int = 0,\n              out_dtype: torch.dtype = torch.bfloat16):\n        \"\"\"Build from mlp.\"\"\"\n        raise NotImplementedError\n\n\nclass FusedMoEW8A8Impl(ABC):\n    \"\"\"Fused moe w8a8 implementation.\"\"\"\n\n    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, gate_up_scale: torch.Tensor,\n                       down_scale: torch.Tensor):\n        \"\"\"Update weights.\"\"\"\n        return gate_up_weights, down_weights, gate_up_scale, down_scale\n\n    def ep_expert_list(self, world_size: int, rank: int):\n        \"\"\"Experts list of current rank.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    @abstractmethod\n    def forward(self,\n                hidden_states: torch.Tensor,\n                input_scale: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n                gate_up_scale: torch.Tensor,\n                down_weights: torch.Tensor,\n                down_scale: torch.Tensor,\n                expert_list: List[int] = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass FusedMoEW8A8Builder(ABC):\n    \"\"\"Fused moe w8a8 builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(top_k: int,\n              num_experts: int,\n              renormalize: bool = False,\n              out_dtype: torch.dtype = torch.float16,\n              quant_dtype: torch.dtype = torch.int8):\n        \"\"\"Build from mlp.\"\"\"\n        raise NotImplementedError\n\n\nclass FusedMoEBlockedF8Impl(ABC):\n    \"\"\"Fused moe blocked f8 implementation.\"\"\"\n\n    def __init__(self):\n        
self.scale_fmt: Optional[str] = None\n\n    def update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor, gate_up_scale: torch.Tensor,\n                       down_scale: torch.Tensor):\n        \"\"\"Update weights.\"\"\"\n        return gate_up_weights, down_weights, gate_up_scale, down_scale\n\n    def ep_expert_list(self, world_size: int, rank: int):\n        \"\"\"Experts list of current rank.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def set_scale_fmt(self, scale_fmt: Optional[str]):\n        \"\"\"Set scale fmt.\"\"\"\n        self.scale_fmt = scale_fmt\n\n    @abstractmethod\n    def forward(self,\n                hidden_states: torch.Tensor,\n                input_scale: torch.Tensor,\n                topk_weights: torch.Tensor,\n                topk_ids: torch.LongTensor,\n                gate_up_weights: torch.Tensor,\n                gate_up_scale: torch.Tensor,\n                down_weights: torch.Tensor,\n                down_scale: torch.Tensor,\n                gate_up_bias: torch.Tensor = None,\n                down_bias: torch.Tensor = None,\n                expert_list: List[int] = None,\n                act_func: Callable = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass FusedMoEBlockedF8Builder(ABC):\n    \"\"\"Fused moe blocked f8 builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(top_k: int,\n              num_experts: int,\n              hidden_dim: int = 1,\n              renormalize: bool = False,\n              block_size: int = 128,\n              ep_size: int = 1,\n              ep_group: dist.ProcessGroup = None,\n              out_dtype: torch.dtype = torch.float16,\n              layer_idx: int = 0,\n              custom_gateup_act: bool = False):\n        \"\"\"Build from mlp.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/moe_router.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import Tuple\n\nimport torch\n\n\nclass RouterNoauxTCImpl(ABC):\n    \"\"\"Noaux tc implementation api.\"\"\"\n\n    @abstractmethod\n    def forward(self, logits: torch.Tensor, bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass RouterNoauxTCBuilder(ABC):\n    \"\"\"Noaux tc implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(\n        scoring_func: str,\n        top_k: int,\n        n_group: int,\n        topk_group: int,\n        n_routed_experts: int,\n        routed_scaling_factor: float,\n        renormalize: bool = True,\n        router_n_groups: int = -1,\n    ):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/multinomial_sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\nimport torch\n\n\nclass MultinomialSamplingImpl(ABC):\n    \"\"\"Multinomial sampling implementation api.\"\"\"\n\n    @abstractmethod\n    def forward(self,\n                scores: torch.Tensor,\n                seeds: torch.LongTensor,\n                offsets: torch.LongTensor,\n                indices: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass MultinomialSamplingBuilder(ABC):\n    \"\"\"Multinomial sampling implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build():\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/norm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\nimport torch\n\n\nclass RMSNormImpl(ABC):\n    \"\"\"RMS norm implementation api.\"\"\"\n\n    @abstractmethod\n    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass RMSNormBuilder(ABC):\n    \"\"\"RMS norm implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(hidden_size: int, eps: float = 1e-6):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n\n\nclass LayerNormImpl(ABC):\n    \"\"\"Layer norm implementation api.\"\"\"\n\n    @abstractmethod\n    def forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass LayerNormBuilder(ABC):\n    \"\"\"Layer norm implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(normalized_shape: int, eps: float = 1e-6):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/nsa.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\n\nfrom torch import Tensor\n\n\n@dataclass\nclass NSAIndexMeta:\n    \"\"\"Meta info of NSAIndex layer.\"\"\"\n    cu_seqlen_q: Tensor\n    q_seqlens: Tensor\n    k_seqlens: Tensor\n    block_offset: Tensor\n    max_q_seqlen: int = None\n    max_kv_seqlen: int = None\n\n\nclass BaseNSAIndexFP8(ABC):\n    \"\"\"NSA index fp8 implementation api.\"\"\"\n\n    @abstractmethod\n    def forward(self, q: Tensor, k: Tensor, weights: Tensor, k_cache: Tensor, k_s_cache: Tensor,\n                meta: NSAIndexMeta) -> Tensor:\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n\nclass BaseNSAIndexFP8Builder(ABC):\n    \"\"\"NSA index fp8 implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1) -> BaseNSAIndexFP8:\n        \"\"\"Build layer implementation.\"\"\"\n        raise NotImplementedError('Not implemented.')\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/qmodules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import Optional\n\nimport torch\n\n\nclass RMSNormW8A8Impl(ABC):\n    \"\"\"RMS norm w8a8 implementation api.\"\"\"\n\n    @staticmethod\n    def create_weight(hidden_size: int, dtype: torch.dtype = None, device: torch.device = None):\n        \"\"\"Create weight.\"\"\"\n        if dtype is None:\n            dtype = torch.float16\n        if device is None:\n            device = 'cuda'\n        weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device), requires_grad=False)\n        return weight\n\n    @abstractmethod\n    def forward(self, x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass RMSNormW8A8Builder(ABC):\n    \"\"\"RMS norm w8a8 implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(hidden_size: int, eps: float = 1e-6, quant_dtype: torch.dtype = torch.int8):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n\n\nclass LinearW8A8Impl(ABC):\n    \"\"\"Linear w8a8 implementation api.\"\"\"\n\n    def update_weights(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None):\n        \"\"\"Update weights.\"\"\"\n        return weight, scale, bias\n\n    @abstractmethod\n    def forward(self,\n                x,\n                weight: torch.Tensor,\n                scale: torch.Tensor,\n                bias: Optional[torch.Tensor] = None,\n                all_reduce: bool = False,\n                group: Optional[torch.distributed.ProcessGroup] = None):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass LinearW8A8Builder(ABC):\n    \"\"\"Linear w8a8 implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(in_features: int,\n              out_features: int,\n              bias: bool = True,\n              
dtype: torch.dtype = None,\n              quant_dtype: torch.dtype = torch.int8):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/rotary_embedding.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom typing import List\n\nimport torch\n\n\nclass RopeType(Enum):\n    \"\"\"Rotary embedding type.\"\"\"\n    Default = auto()\n    LinearScaling = auto()\n    DynamicNTKScaling = auto()\n    Llama3 = auto()\n    Yarn = auto()\n    LongRoPEScaling = auto()\n    Fope = auto()\n\n\n@dataclass\nclass YarnParameters:\n    \"\"\"Yarn parameters.\"\"\"\n    beta_fast: int = 32\n    beta_slow: float = 1\n    mscale: int = 1\n    mscale_all_dim: int = 0\n    attention_factor: int = None\n    truncate: bool = True\n\n\n@dataclass\nclass LongRoPEScalingParameters:\n    \"\"\"LongRoPE scaling parameters.\"\"\"\n    short_factor: List[float]\n    long_factor: List[float]\n    original_max_position_embeddings: int\n    long_mscale: float = None\n    short_mscale: float = None\n\n\n@dataclass\nclass Llama3Parameters:\n    \"\"\"Llama3 rope parameters.\"\"\"\n    low_freq_factor: float = 1.0\n    high_freq_factor: float = 4.0\n    original_max_position_embeddings: int = 8192\n\n\n@dataclass\nclass FopeParameters:\n    \"\"\"Fope parameters.\"\"\"\n    num_inv_freq: int = None\n    num_key_value_heads: int = 1\n    fope_sep_head: bool = False\n    inv_freq: torch.Tensor = None\n\n\nclass RotaryEmbeddingImpl(ABC):\n    \"\"\"Rotary embedding implementation api.\"\"\"\n\n    @abstractmethod\n    def forward(self, x, position_ids, **kwargs):\n        \"\"\"forward.\"\"\"\n        raise NotImplementedError\n\n\nclass RotaryEmbeddingBuilder(ABC):\n    \"\"\"Rotary embedding implementation builder.\"\"\"\n\n    @staticmethod\n    @abstractmethod\n    def build(\n        dim: int,\n        max_position_embeddings: int = 2048,\n        base: int = 10000,\n        scaling_factor: float = 1.0,\n        yarn_params: YarnParameters = None,\n        longrope_params: LongRoPEScalingParameters = None,\n        llama3_params: Llama3Parameters = None,\n        fope_params: FopeParameters = None,\n        emb_type: RopeType = RopeType.Default,\n    ):\n        \"\"\"build.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/selector.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.devices import DeviceContext, get_device_manager\n\n\ndef _get_backend():\n    \"\"\"Get device backend implement.\"\"\"\n    device_mgr = get_device_manager()\n    device_ctx = device_mgr.current_context()\n\n    device_type = device_ctx.device_type\n\n    if device_type == 'cuda':\n        from .cuda import CudaOpsBackend\n        return CudaOpsBackend\n    if device_type == 'ascend':\n        from .dlinfer.ascend import AscendOpsBackend\n        return AscendOpsBackend\n    if device_type == 'maca':\n        from .dlinfer.maca import MacaOpsBackend\n        return MacaOpsBackend\n    if device_type == 'camb':\n        from .dlinfer.camb import CambOpsBackend\n        return CambOpsBackend\n    else:\n        raise RuntimeError(f'Unsupported device type: {device_type}')\n\n\ndef get_backend(backend_type: str = None):\n    \"\"\"Get device backend.\"\"\"\n    if backend_type is None:\n        return _get_backend()\n    else:\n        device_ctx = DeviceContext(backend_type)\n        device_mgr = get_device_manager()\n        with device_mgr.context(device_ctx):\n            return _get_backend()\n\n\ndef init_backend(backend_type: str):\n    \"\"\"Init device backend.\"\"\"\n    backend = get_backend(backend_type)\n    backend.init()\n"
  },
  {
    "path": "lmdeploy/pytorch/backends/token_dispatcher.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import Tuple\n\nimport torch\n\n\nclass TokenDispatcherImpl(ABC):\n    \"\"\"Token dispatcher implementation api.\"\"\"\n\n    def permute(\n        self,\n        tokens,\n        routing_map,\n    ):\n        \"\"\"Copy from Megatron-Core moe for token permutation.\"\"\"\n        num_tokens, _ = tokens.shape\n        num_experts = routing_map.shape[1]\n        routing_map = routing_map.bool().T.contiguous()\n        token_indices = (torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1))\n        sorted_indices = token_indices.masked_select(routing_map)\n        permuted_input = tokens.index_select(0, sorted_indices)\n        return permuted_input, sorted_indices\n\n    def unpermute(\n        self,\n        permuted_tokens: torch.Tensor,\n        sorted_indices: torch.Tensor,\n        restore_shape: torch.Size,\n        probs: torch.Tensor = None,\n        routing_map: torch.Tensor = None,\n    ):\n        \"\"\"Copy from Megatron-Core moe for token unpermutation.\"\"\"\n        _, hidden = restore_shape\n        if probs is not None:\n            assert routing_map is not None, 'Mask must be provided to permute the probs.'\n            permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())\n            permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)\n        output_tokens = torch.zeros(restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype)\n        output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens)\n        return output_tokens\n\n    def indices_to_multihot(self, topk_ids, topk_weight, num_experts):\n        tokens = topk_ids.shape[0]\n        multihot_routing_map = torch.zeros(\n            (tokens, num_experts),\n            dtype=torch.bool,\n            device=topk_ids.device,\n        )\n\n        
multihot_probs = torch.zeros(\n            (tokens, num_experts),\n            dtype=topk_weight.dtype,\n            device=topk_weight.device,\n        )\n\n        mask = topk_ids != -1\n        valid_indices = topk_ids[mask]\n        row_indices = torch.arange(tokens, device=topk_ids.device).repeat_interleave(mask.sum(dim=1))\n        multihot_routing_map[row_indices, valid_indices] = True\n        multihot_probs[row_indices, valid_indices] = topk_weight[mask]\n        return multihot_routing_map, multihot_probs\n\n    @abstractmethod\n    def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor, topk_ids: torch.Tensor,\n                 local_expert_indices) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"dispatch.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"combine.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/block.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport numpy as np\n\n\ndef _div_up(x, n):\n    \"\"\"Perform div up.\"\"\"\n    return (x + n - 1) // n\n\n\ndef _round_up(x, n):\n    \"\"\"Perform round up.\"\"\"\n    return _div_up(x, n) * n\n\n\nclass LogicalTokenBlocks:\n    \"\"\"Logical blocks.\"\"\"\n    ALLOC_SIZE = 128\n\n    def __init__(self, blocks: np.ndarray = None):\n        if blocks is None:\n            self._blocks = np.zeros((self.ALLOC_SIZE, ), dtype=np.int64)\n            self._num_real = 0\n        else:\n            assert blocks.ndim == 1\n            self._blocks = blocks\n            self._num_real = len(blocks)\n        self.last_shared_node = None\n\n    def reserve(self, size: int):\n        \"\"\"Reserve cache size.\"\"\"\n        num_blocks = self._blocks.size\n        if num_blocks >= size:\n            return\n        reserve_size = _round_up(size - num_blocks, self.ALLOC_SIZE)\n        self._blocks = np.pad(self._blocks, (0, reserve_size))\n\n    def __setitem__(self, *args, **kwargs):\n        \"\"\"Set values.\"\"\"\n        return self.get_real_blocks().__setitem__(*args, **kwargs)\n\n    def __getitem__(self, *args, **kwargs):\n        \"\"\"Get values.\"\"\"\n        return self.get_real_blocks().__getitem__(*args, **kwargs)\n\n    def get_real_blocks(self):\n        \"\"\"Get logical blocks.\"\"\"\n        return self._blocks[:self._num_real]\n\n    def append(self, blocks: np.ndarray):\n        \"\"\"Append blocks.\"\"\"\n        num_blocks = len(blocks)\n        self.reserve(num_blocks + self._num_real)\n        slice_start = self._num_real\n        slice_end = slice_start + num_blocks\n        self._num_real += num_blocks\n        self._blocks[slice_start:slice_end] = blocks\n\n    def __len__(self):\n        \"\"\"Get length.\"\"\"\n        return self._num_real\n\n    def resize(self, num_blocks: int):\n        \"\"\"Resize logical blocks.\"\"\"\n        assert num_blocks <= len(self)\n        
self._num_real = num_blocks\n\n    def reset(self):\n        \"\"\"reset.\"\"\"\n        self.resize(0)\n        self.last_shared_node = None\n\n    def clone(self):\n        \"\"\"Clone logical blocks.\"\"\"\n        ret = LogicalTokenBlocks()\n        ret.append(self.get_real_blocks())\n        return ret\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/adapter.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .base import BaseChecker\n\n\nclass AdapterChecker(BaseChecker):\n    \"\"\"Check adapter is available.\"\"\"\n\n    def __init__(self, adapter_path: str, logger=None):\n        super().__init__(logger)\n        self.adapter_path = adapter_path\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        path = self.adapter_path\n\n        try:\n            import peft  # noqa: F401\n        except Exception as e:\n            self.log_and_exit(e, 'Adapter', message='Failed to import peft.')\n\n        try:\n            from peft import PeftConfig\n            PeftConfig.from_pretrained(path)\n        except Exception as e:\n            message = ('Please make sure the adapter can be loaded with '\n                       '`peft.PeftConfig.from_pretrained`\\n')\n            err_msg = '' if len(e.args) == 0 else str(e.args[0])\n            if 'got an unexpected keyword argument' in err_msg:\n                message += ('Or try removing all unexpected keywords '\n                            'in `adapter_config.json`.')\n            self.log_and_exit(e, 'Adapter', message=message)\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport sys\nfrom logging import Logger\nfrom typing import List\n\nfrom lmdeploy.utils import can_colorize, get_logger\n\nRED_COLOR = '\\033[31m'\nRESET_COLOR = '\\033[0m'\n\n\ndef _red_text(text: str):\n    \"\"\"Color text red if colorized output is supported.\"\"\"\n    if not can_colorize():\n        return text\n    return f'{RED_COLOR}{text}{RESET_COLOR}'\n\n\nclass BaseChecker:\n    \"\"\"Base checker.\"\"\"\n\n    def __init__(self, logger: Logger = None):\n        if logger is None:\n            logger = get_logger('lmdeploy')\n        self.logger = logger\n        self._is_passed = False\n        self._required_checker: List[BaseChecker] = list()\n\n    def get_logger(self):\n        \"\"\"Get logger.\"\"\"\n        return self.logger\n\n    def register_required_checker(self, checker: 'BaseChecker'):\n        \"\"\"Register a checker that must run before this one.\"\"\"\n        self._required_checker.append(checker)\n\n    def handle(self):\n        \"\"\"Handle check.\"\"\"\n        is_passed = getattr(self, '_is_passed', False)\n        if not is_passed:\n            checker_name = type(self).__name__\n            self.logger.debug(f'Checking <{checker_name}>:')\n            for checker in self._required_checker:\n                checker.handle()\n            self.check()\n            # mark as passed so repeated handle() calls skip the check\n            self._is_passed = True\n\n    def log_and_exit(self, e: Exception = None, mod_name: str = None, message: str = None):\n        \"\"\"Log the error and exit the process.\"\"\"\n        logger = self.logger\n        if mod_name is None:\n            mod_name = type(self).__name__\n        if message is None:\n            message = 'Please check your environment.'\n        logger.debug('Exception', exc_info=1)\n        if e is not None:\n            logger.error(f'{type(e).__name__}: {e}')\n        logger.error(f'<{mod_name}> check failed!\\n{_red_text(message)}')\n        sys.exit(1)\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        raise NotImplementedError('check not implemented.')\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/cuda.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .base import BaseChecker\n\n\nclass CudaChecker(BaseChecker):\n    \"\"\"Check CUDA is available.\"\"\"\n\n    def __init__(self, model_format: str = None, logger=None) -> None:\n        super().__init__(logger=logger)\n        self.model_format = model_format\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        import torch\n\n        if not torch.cuda.is_available():\n            self.log_and_exit(mod_name='CUDA', message='CUDA is not available.')\n\n        if self.model_format == 'fp8':\n            props = torch.cuda.get_device_properties(0)\n            if props.major < 9:\n                self.log_and_exit(mod_name='CUDA', message='model_format=fp8 requires sm>=9.0.')\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/deeplink.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.utils import try_import_deeplink\n\nfrom .base import BaseChecker\n\n\nclass DeeplinkChecker(BaseChecker):\n    \"\"\"Check deeplink is available.\"\"\"\n\n    def __init__(self, device_type: str, logger=None) -> None:\n        super().__init__(logger=logger)\n        self.device_type = device_type\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        try_import_deeplink(self.device_type)\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/dist.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom lmdeploy.pytorch.config import DistConfig\nfrom lmdeploy.utils import is_dlblas_installed\n\nfrom .base import BaseChecker\n\n\nclass DistChecker(BaseChecker):\n    \"\"\"Check dist environment.\"\"\"\n\n    def __init__(self, tp: int, dp: int, ep: int, distributed_executor_backend: str, device_type: str, logger=None):\n        super().__init__(logger)\n        self.tp = tp\n        self.dp = dp\n        self.ep = ep\n        self.dist_config = DistConfig(dp=dp, tp=tp, ep=ep)\n        self.world_size = self.dist_config.world_size\n        self.distributed_executor_backend = distributed_executor_backend\n        self.device_type = device_type\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        distributed_executor_backend = self.distributed_executor_backend\n\n        if distributed_executor_backend is None:\n            from lmdeploy.pytorch.engine.executor import get_distributed_executor_backend\n            distributed_executor_backend = get_distributed_executor_backend(self.world_size, self.dp, self.device_type)\n\n        if distributed_executor_backend not in [None, 'uni', 'mp', 'ray']:\n            self.log_and_exit(mod_name='Dist',\n                              message=f'Unsupported distributed_executor_backend: {distributed_executor_backend}')\n\n        if distributed_executor_backend == 'uni' and self.world_size > 1:\n            self.log_and_exit(mod_name='Dist',\n                              message='distributed_executor_backend=\"uni\" does not support world_size > 1.')\n\n        if self.dp > 1 and distributed_executor_backend != 'ray':\n            self.log_and_exit(mod_name='Dist',\n                              message='dp>1 requires distributed_executor_backend=\"ray\". '\n                              f'Got distributed_executor_backend={distributed_executor_backend}.')\n\n        if self.ep > 1:\n            if self.device_type == 'cuda' and not is_dlblas_installed():\n                self.log_and_exit(mod_name='Dist',\n                                  message='ep>1 requires installing dlblas (https://github.com/DeepLink-org/dlBLAS).')\n            if self.ep % self.dp != 0:\n                self.log_and_exit(mod_name='Dist',\n                                  message=f'ep>1 requires ep % dp == 0. Got dp={self.dp} and ep={self.ep}.')\n        elif self.dist_config.enable_eplb:\n            self.log_and_exit(mod_name='Dist', message=f'Enabling eplb requires ep > 1. Got ep={self.ep}.')\n\n        if distributed_executor_backend == 'ray':\n            try:\n                import ray  # noqa: F401\n            except Exception:\n                self.log_and_exit(mod_name='Dist', message='Multi-node support requires `ray`.')\n\n            from lmdeploy.pytorch.backends import get_backend\n            backend = get_backend(self.device_type)\n            if not backend.support_ray():\n                self.log_and_exit(mod_name='Dist', message=f'device={self.device_type} does not support ray.')\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/model.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom packaging import version\n\nfrom .base import BaseChecker\n\n\nclass ModelChecker(BaseChecker):\n    \"\"\"Check model is available.\"\"\"\n\n    def __init__(self, model_path: str, trust_remote_code: bool, dtype: str, device_type: str, logger=None) -> None:\n        super().__init__(logger=logger)\n        self.model_path = model_path\n        self.trust_remote_code = trust_remote_code\n        self.device_type = device_type\n        self.dtype = dtype\n\n    def check_config(self, trans_version):\n        \"\"\"Check config.\"\"\"\n        model_path = self.model_path\n        trust_remote_code = self.trust_remote_code\n        try:\n            from lmdeploy.pytorch.transformers import config_from_pretrained\n            config = config_from_pretrained(model_path, trust_remote_code=trust_remote_code)\n        except Exception as e:\n            message = (f'Load model config with transformers=={trans_version}'\n                       ' failed. 
'\n                       'Please make sure model can be loaded with transformers API.')\n            self.log_and_exit(e, 'transformers', message=message)\n        return config\n\n    def check_trans_version(self, config, trans_version):\n        \"\"\"Check transformers version.\"\"\"\n        model_path = self.model_path\n        logger = self.get_logger()\n        model_trans_version = getattr(config, 'transformers_version', None)\n        if model_trans_version is not None:\n            model_trans_version = version.parse(model_trans_version)\n            if trans_version < model_trans_version:\n                message = (f'model `{model_path}` requires '\n                           f'transformers version {model_trans_version} '\n                           f'but transformers {trans_version} is installed.')\n                logger.warning(message)\n\n    def check_dtype(self, config):\n        \"\"\"Check dtype.\"\"\"\n        logger = self.get_logger()\n        model_path = self.model_path\n        device_type = self.device_type\n        dtype = self.dtype\n        try:\n            import torch\n\n            from lmdeploy.pytorch.config import ModelConfig\n            from lmdeploy.utils import is_bf16_supported\n            model_config = ModelConfig.from_hf_config(config,\n                                                      model_path=model_path,\n                                                      dtype=dtype,\n                                                      device_type=device_type)\n            if model_config.dtype == torch.bfloat16:\n                if not is_bf16_supported(device_type):\n                    logger.warning('Device does not support bfloat16.')\n        except Exception as e:\n            message = (f'Checking failed with error {e}. 
Please open an issue on LMDeploy with the error logs.')\n            self.log_and_exit(e, 'Model', message=message)\n\n        try:\n            model_config.check_env_func(device_type)\n        except Exception as e:\n            message = (f'Checking failed with error {e}.')\n            self.log_and_exit(e, 'Model', message=message)\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        import transformers\n        trans_version = version.parse(transformers.__version__)\n\n        # config\n        config = self.check_config(trans_version)\n\n        # transformers version\n        self.check_trans_version(config, trans_version)\n\n        # dtype check\n        self.check_dtype(config)\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/torch.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .base import BaseChecker\n\n\nclass TorchChecker(BaseChecker):\n    \"\"\"Check pytorch is available.\"\"\"\n\n    def __init__(self, device: str = 'cuda', logger=None) -> None:\n        super().__init__(logger=logger)\n        self.device = device\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        try:\n            import torch\n            a = torch.tensor([1, 2], device=self.device)\n            b = a.new_tensor([3, 4], device=self.device)\n            c = a + b\n            torch.testing.assert_close(c, a.new_tensor([4, 6]))\n        except Exception as e:\n            self.log_and_exit(e, 'PyTorch', 'PyTorch is not available.')\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/transformers.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom packaging import version\n\nfrom .base import BaseChecker\n\nMIN_TRANSFORMERS_VERSION = '4.33.0'\nMAX_TRANSFORMERS_VERSION = '5.2.0'\n\n\nclass TransformersChecker(BaseChecker):\n    \"\"\"Check transformers is available.\"\"\"\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        import transformers\n        logger = self.get_logger()\n        try:\n            trans_version = version.parse(transformers.__version__)\n            min_version = version.parse(MIN_TRANSFORMERS_VERSION)\n            max_version = version.parse(MAX_TRANSFORMERS_VERSION)\n            if trans_version < min_version or trans_version > max_version:\n                logger.warning('LMDeploy requires transformers version: '\n                               f'[{MIN_TRANSFORMERS_VERSION} ~ '\n                               f'{MAX_TRANSFORMERS_VERSION}], '\n                               'but found version: '\n                               f'{transformers.__version__}')\n        except Exception as e:\n            self.log_and_exit(e, 'transformers', 'transformers is not available.')\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/triton.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom packaging import version\n\nfrom .base import BaseChecker\n\nMAX_TRITON_VERSION = '3.6.0'\nMIN_TRITON_VERSION = '3.0.0'\n\n\nclass TritonChecker(BaseChecker):\n    \"\"\"Check triton is available.\"\"\"\n\n    def check_version(self):\n        \"\"\"Check version.\"\"\"\n        logger = self.get_logger()\n\n        # version check\n        import triton\n        max_version = version.parse(MAX_TRITON_VERSION)\n        min_version = version.parse(MIN_TRITON_VERSION)\n        triton_version = version.parse(triton.__version__)\n\n        if triton_version > max_version:\n            logger.warning('PytorchEngine has not been tested on '\n                           f'triton>{MAX_TRITON_VERSION}.')\n        if triton_version < min_version:\n            msg = (f'triton>={MIN_TRITON_VERSION} is required. '\n                   f'Found triton=={triton_version}')\n            self.log_and_exit(mod_name='Triton', message=msg)\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        logger = self.get_logger()\n\n        msg = (\n            'Please ensure that your device is functioning properly with <Triton>.\\n'  # noqa: E501\n            'You can verify your environment by running '\n            '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')\n        try:\n            logger.debug('Checking <Triton> environment.')\n            import torch\n\n            from .triton_custom_add import custom_add\n            a = torch.tensor([1, 2], device='cuda')\n            b = a.new_tensor([3, 4], device='cuda')\n            c = custom_add(a, b)\n            torch.testing.assert_close(c, a + b)\n        except RuntimeError as e:\n            ptxas_error = 'device kernel image is invalid'\n            if len(e.args) > 0 and ptxas_error in e.args[0]:\n                msg = (\n                    'This error might be caused by a mismatch between the NVIDIA driver and the nvcc compiler. 
\\n'  # noqa: E501\n                    'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209'  # noqa: E501\n                    ' or reinstall the driver.')\n            self.log_and_exit(e, 'Triton', msg)\n        except Exception as e:\n            self.log_and_exit(e, 'Triton', msg)\n\n        # version check\n        self.check_version()\n"
  },
  {
    "path": "lmdeploy/pytorch/check_env/triton_custom_add.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _add_kernel(A, B, C, size, BLOCK: tl.constexpr):\n    \"\"\"Add kernel.\"\"\"\n    prog_id = tl.program_id(0)\n    offs = prog_id * BLOCK + tl.arange(0, BLOCK)\n    a = tl.load(A + offs, mask=offs < size)\n    b = tl.load(B + offs, mask=offs < size)\n    tl.store(C + offs, a + b, mask=offs < size)\n\n\ndef custom_add(a, b):\n    \"\"\"Add two 1-D tensors element-wise with a Triton kernel.\"\"\"\n    c = torch.empty_like(a)\n    size = c.size(0)\n    BLOCK = 16\n\n    grid = (triton.cdiv(size, BLOCK), )\n    _add_kernel[grid](a, b, c, size, BLOCK=BLOCK)\n    return c\n\n\nif __name__ == '__main__':\n    a = torch.tensor([1, 2], device='cuda')\n    b = a.new_tensor([3, 4], device='cuda')\n    c = custom_add(a, b)\n    torch.testing.assert_close(c, a + b)\n    print('Done.')\n"
  },
  {
    "path": "lmdeploy/pytorch/config.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport enum\nfrom dataclasses import dataclass, field\nfrom typing import Any, Callable, Dict, List, Literal, Optional, Tuple\n\nimport torch\n\nfrom lmdeploy.messages import PytorchEngineConfig\nfrom lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend\nfrom lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value\nfrom lmdeploy.utils import get_logger, is_bf16_supported\n\nlogger = get_logger('lmdeploy')\n\n\ndef _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'auto'):\n    \"\"\"Update the torch dtype from the model config.\n\n    Args:\n        config (ModelConfig): The input model config.\n        dtype (str): user specified data type. Refer to\n            `PyTorchEngineConfig.dtype` for detailed info\n        device_type (str): The device type. Refer to `PyTorchEngineConfig.device_type` for detailed info\n    \"\"\"\n    quantization_config = getattr(config.hf_config, 'quantization_config', dict())\n    quant_method = quantization_config.get('quant_method', None)\n    if quant_method == 'awq':\n        logger.debug('set torch_dtype to float16 for awq.')\n        config.hf_config.torch_dtype = 'float16'\n        config.dtype = torch.float16\n        return config\n\n    torch_dtype = getattr(config.hf_config, 'dtype', None)\n    if torch_dtype is None and hasattr(config.hf_config, 'text_config'):\n        torch_dtype = getattr(config.hf_config.text_config, 'dtype', None)\n\n    if torch_dtype is None:\n        torch_dtype = getattr(config.hf_config, 'torch_dtype', None)\n\n    # deal with case when torch_dtype is not string but torch.dtype\n    if isinstance(torch_dtype, torch.dtype):\n        torch_dtype = str(torch_dtype).split('.')[1]\n\n    if torch_dtype is None:\n        _dtype = 'float16' if dtype == 'auto' else dtype\n        logger.warning('Model config does not have `torch_dtype`,'\n                       f' use: {_dtype}')\n        
torch_dtype = _dtype\n        # update hf_config as well\n        setattr(config.hf_config, 'torch_dtype', torch_dtype)\n    else:\n        if torch_dtype == 'bfloat16' and not is_bf16_supported(device_type):\n            torch_dtype = 'float16'\n        # change to user specified data type if it is not 'auto'\n        if dtype == 'auto':\n            torch_dtype = torch_dtype if torch_dtype in ['float16', 'bfloat16'] else 'float16'\n        else:\n            torch_dtype = dtype\n    config.dtype = eval(f'torch.{torch_dtype}')\n    return config\n\n\n@dataclass\nclass BackendConfig:\n    \"\"\"Backend config.\"\"\"\n    eager_mode: bool = True\n    device_type: str = 'cuda'\n\n\n@dataclass\nclass SchedulerConfig:\n    \"\"\"Config of scheduler.\"\"\"\n\n    max_batches: int\n    max_session_len: int\n    max_request_output_len: int = 512\n    eviction_type: str = 'recompute'\n    prefill_interval: int = 16\n    max_active_adapters: int = 64\n\n\n@dataclass\nclass CacheConfig:\n    \"\"\"Config of key value cache.\"\"\"\n\n    max_batches: int\n    block_size: int\n    num_cpu_blocks: int\n    num_gpu_blocks: int\n    window_size: int = -1\n    cache_max_entry_count: float = 0.8\n    max_prefill_token_num: int = 4096\n    enable_prefix_caching: bool = False\n    quant_policy: Literal[0, 4, 8] = 0\n    device_type: str = 'cuda'\n    num_state_caches: int = None\n    states_shapes: List[Tuple] = field(default_factory=list)\n\n    # reserved blocks for dummy inputs, init to 0 for unit test.\n    num_reserved_gpu_blocks: int = 0\n\n    # For PD Disaggregation\n    role: EngineRole = EngineRole.Hybrid\n    migration_backend: MigrationBackend = MigrationBackend.DLSlime\n\n    def __post_init__(self):\n        \"\"\"Post init.\"\"\"\n        if self.window_size > 1 and self.enable_prefix_caching:\n            logger.warning('Prefix caching is not available for window attention.')\n            self.enable_prefix_caching = False\n\n\nclass TPMode(enum.Enum):\n    \"\"\"TP 
Mode.\"\"\"\n    DEFAULT = enum.auto()\n    DP_TP = enum.auto()\n\n\n@dataclass\nclass DistConfig:\n    dp: int = 1\n    ep: int = 1\n    dp_rank: int = 0\n    enable_microbatch: bool = False\n    enable_eplb: bool = False\n    world_size: int = 1\n\n    # tp\n    tp: int = 1  # default tp, equal to attn_tp\n    attn_tp: int = None  # tp for attention\n    mlp_tp: int = None  # tp for mlp\n    moe_tp: int = None  # tp for moe\n\n    # tp mode\n    mlp_tp_mode: TPMode = TPMode.DEFAULT\n    moe_tp_mode: TPMode = TPMode.DEFAULT\n\n    def __post_init__(self):\n        \"\"\"Post init.\"\"\"\n        assert self.dp_rank < self.dp\n        assert self.dp >= 1\n\n        dp = self.dp\n        tp = self.tp\n        ep = self.ep\n\n        # per-layer tp settings are ignored when dp == 1\n        if dp == 1:\n            self.mlp_tp = None\n            self.attn_tp = None\n            self.moe_tp = None\n\n        # mlp and moe tp\n        self.mlp_tp = self.mlp_tp or tp\n        self.moe_tp = self.moe_tp or (1 if ep > 1 else self.mlp_tp)\n\n        # world_size\n        world_size = ep if ep > 1 else max(self.mlp_tp, self.moe_tp)\n        self.world_size = world_size\n        assert (world_size >= dp and world_size % dp == 0), (f'world_size {world_size}, dp {dp}')\n        assert (world_size >= ep and world_size % ep == 0), (f'world_size {world_size}, ep {ep}')\n        assert (world_size >= self.mlp_tp\n                and world_size % self.mlp_tp == 0), (f'world_size {world_size}, mlp_tp {self.mlp_tp}')\n        assert (world_size >= self.moe_tp\n                and world_size % self.moe_tp == 0), (f'world_size {world_size}, moe_tp {self.moe_tp}')\n\n        # attn tp\n        self.attn_tp = self.attn_tp or self.world_size // dp\n        self.tp = self.attn_tp\n        if self.mlp_tp > 1:\n            assert (self.mlp_tp >= self.attn_tp\n                    and self.mlp_tp % self.attn_tp == 0), (f'mlp_tp {self.mlp_tp}, attn_tp {self.attn_tp}')\n        if self.moe_tp > 1:\n            
assert (self.moe_tp >= self.attn_tp\n                    and self.moe_tp % self.attn_tp == 0), (f'moe_tp {self.moe_tp}, attn_tp {self.attn_tp}')\n        assert (world_size >= self.attn_tp\n                and world_size % self.attn_tp == 0), (f'world_size {world_size}, attn_tp {self.attn_tp}')\n\n        # tp mode\n        self.mlp_tp_mode = TPMode.DEFAULT if (self.mlp_tp in [1, self.attn_tp]) else TPMode.DP_TP\n        self.moe_tp_mode = TPMode.DEFAULT if (self.moe_tp in [1, self.attn_tp]) else TPMode.DP_TP\n\n    def get_tp_by_layer(self, layer_type: str):\n        \"\"\"Get tp by layer type.\"\"\"\n        if layer_type == 'attn':\n            return self.attn_tp, TPMode.DEFAULT\n        elif layer_type == 'mlp':\n            return self.mlp_tp, self.mlp_tp_mode\n        elif layer_type == 'moe':\n            return self.moe_tp, self.moe_tp_mode\n        elif layer_type is None:\n            # for some layer that we don't need tp\n            return 1, TPMode.DEFAULT\n        else:\n            raise ValueError(f'Unknown layer type: {layer_type}')\n\n    @classmethod\n    def from_engine_config(cls, engine_config: PytorchEngineConfig):\n        \"\"\"From engine config.\"\"\"\n        dist_config = cls(\n            dp=engine_config.dp,\n            ep=engine_config.ep,\n            dp_rank=engine_config.dp_rank,\n            enable_microbatch=engine_config.enable_microbatch,\n            enable_eplb=engine_config.enable_eplb,\n            tp=engine_config.tp,\n            attn_tp=engine_config.attn_tp_size,\n            mlp_tp=engine_config.mlp_tp_size,\n            moe_tp=engine_config.moe_tp_size,\n        )\n        return dist_config\n\n\ndef _override_hf_config_dict(hf_config: dict, key: str, hf_overrides):\n    \"\"\"Override hf_config dict.\"\"\"\n    from transformers import PretrainedConfig\n    if key not in hf_config:\n        # copy if key not in hf_config\n        hf_config[key] = hf_overrides\n        return\n\n    hf_config_val = 
hf_config[key]\n    is_dict = isinstance(hf_config_val, dict)\n    is_cfg = isinstance(hf_config_val, PretrainedConfig)\n    if not isinstance(hf_overrides, dict) or not (is_dict or is_cfg):\n        # if one of them is not dict, just override\n        hf_config[key] = hf_overrides\n        return\n\n    for key, value in hf_overrides.items():\n        _override_hf_config(hf_config_val, key, value)\n\n\ndef _overide_hf_config_cfg(hf_config: Any, key: str, hf_overrides):\n    \"\"\"Override an attribute of a transformers config.\"\"\"\n    from transformers import PretrainedConfig\n    if getattr(hf_config, key, None) is None:\n        # set the attribute directly if it is missing\n        hf_config.update({key: hf_overrides})\n        return\n\n    hf_config_val = getattr(hf_config, key)\n    is_dict = isinstance(hf_config_val, dict)\n    is_cfg = isinstance(hf_config_val, PretrainedConfig)\n    if not isinstance(hf_overrides, dict) or not (is_dict or is_cfg):\n        # if one of them is not dict, just override\n        hf_config.update({key: hf_overrides})\n        return\n\n    for key, value in hf_overrides.items():\n        _override_hf_config(hf_config_val, key, value)\n\n\ndef _override_hf_config(hf_config: Any, key: str, hf_overrides):\n    \"\"\"Override HF config.\"\"\"\n    if isinstance(hf_config, dict):\n        _override_hf_config_dict(hf_config, key, hf_overrides)\n    else:\n        _overide_hf_config_cfg(hf_config, key, hf_overrides)\n\n\ndef override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]):\n    \"\"\"Override HF config.\"\"\"\n    for k, v in hf_overrides.items():\n        _override_hf_config(hf_config, k, v)\n\n\ndef _default_check_env(device: str):\n    pass\n\n\ndef _patch_quantization_config(hf_config: Any, model_format: str = None):\n    \"\"\"Patch quantization config.\"\"\"\n    if model_format is None:\n        return hf_config\n\n    # skip the quantized llm and vlm models\n    if hasattr(hf_config, 'quantization_config') or \\\n        (hasattr(hf_config, 'llm_config') and hasattr(hf_config.llm_config, 
'quantization_config')) \\\n            or (hasattr(hf_config, 'text_config') and hasattr(hf_config.text_config, 'quantization_config')):\n        logger.warning('Cannot perform weight quantization on an already quantized model.')\n        return hf_config\n\n    if model_format == 'fp8':\n        logger.debug('Patch quantization config for fp8.')\n        from lmdeploy.pytorch.envs import scale_fmt\n        quantization_config = dict(quant_method='fp8', fmt='e4m3', weight_block_size=[128, 128], scale_fmt=scale_fmt)\n    else:\n        raise RuntimeError(f'Unsupported weight quantization method: {model_format}')\n\n    hf_config.quantization_config = quantization_config\n    # for vlm models\n    if hasattr(hf_config, 'text_config'):\n        hf_config.text_config.quantization_config = quantization_config\n    elif hasattr(hf_config, 'llm_config'):\n        hf_config.llm_config.quantization_config = quantization_config\n\n    return hf_config\n\n\n@dataclass\nclass ModelConfig:\n    \"\"\"Config of model.\"\"\"\n\n    hidden_size: int\n    num_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    bos_token_id: int\n    eos_token_id: List[int]\n    head_dim: int\n    k_head_dim: int = None\n    v_head_dim: int = None\n    sliding_window: int = -1\n    dtype: torch.dtype = torch.float16\n    vocab_size: int = 40000\n    hf_config: Any = None\n    llm_config: Any = None\n    cogvlm_style: bool = False\n    custom_module_map: Dict[str, setattr] = None\n\n    # flash mla\n    use_flash_mla: bool = False\n    use_mla_fp8_cache: bool = False\n    mla_index_topk: Optional[int] = None\n\n    # dllm\n    model_paradigm: str = 'ar'\n    dllm_mask_token: int = 0\n    dllm_block_length: int = None\n\n    # Added for deepseekv3.2 nsa index\n    # caches would be added after kv cache\n    cache_shapes: List[Tuple[List[int], torch.dtype]] = field(default_factory=list)\n    # added for qwen3_next\n    # could be used for any SSM model.\n    states_shapes: 
List[Tuple[Tuple[int], torch.dtype]] = field(default_factory=list)\n\n    # check env for model-device combination\n    check_env_func: Callable = _default_check_env\n\n    # fp32 lm head\n    fp32_lm_head: bool = False\n    tie_word_embeddings: bool = False\n\n    # quant config\n    quant_config: 'QuantizationConfig' = None\n\n    def get_head_size(self):\n        \"\"\"Get head size.\"\"\"\n        return self.head_dim\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        pretrained_model_name_or_path: str,\n        trust_remote_code: bool = True,\n        dtype: str = 'auto',\n        dist_config: DistConfig = None,\n        hf_overrides: Dict[str, Any] = None,\n        is_draft_model: bool = False,\n        spec_method: str = None,\n        model_format: str = None,\n        device_type: str = 'auto',\n    ):\n        \"\"\"Instantiate one of the configuration classes of the library from a\n        pretrained model configuration.\n\n        Args:\n            pretrained_model_name_or_path (str): the pretrained model path\n            trust_remote_code (bool):  Whether or not to allow for custom\n                models defined on the Hub in their own modeling files.\n            dtype (str): user specified data type for model weights and\n                activations. 
Refer to `PyTorchEngineConfig` for details\n            hf_overrides (Dict[str, Any]): overrides for the HF config.\n        \"\"\"\n        from transformers import AutoConfig\n\n        from lmdeploy.pytorch.transformers import config_from_pretrained\n        hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)\n        if getattr(hf_config, 'model_type', None) in ['phi3']:\n            # phi3 + trust_remote_code leads to error when tp.\n            hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)\n\n        # update quantization config\n        hf_config = _patch_quantization_config(hf_config, model_format=model_format)\n\n        model_config = cls.from_hf_config(\n            hf_config,\n            pretrained_model_name_or_path,\n            dtype=dtype,\n            dist_config=dist_config,\n            is_draft_model=is_draft_model,\n            spec_method=spec_method,\n            device_type=device_type,\n        )\n        fp32_lm_head = False\n        if hf_overrides is not None:\n            logger.warning(f'Overriding HF config with {hf_overrides}')\n            fp32_lm_head = hf_overrides.pop('fp32_lm_head', False)\n            override_hf_config(model_config.hf_config, hf_overrides)\n\n        # for fp32 head\n        model_config.fp32_lm_head = fp32_lm_head\n        model_config.tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)\n\n        # for serialization of transformers modules\n        maybe_register_config_serialize_by_value(trust_remote_code)\n\n        # add quant_config\n        model_config.quant_config = QuantizationConfig.from_config(hf_config)\n        return model_config\n\n    @classmethod\n    def from_hf_config(\n        cls,\n        hf_config: Any,\n        model_path: str = None,\n        dtype: str = 'auto',\n        dist_config: DistConfig = None,\n        is_draft_model: bool = False,\n        spec_method: str = None,\n        
device_type: str = 'auto',\n    ):\n        \"\"\"From huggingface config.\"\"\"\n        from lmdeploy.pytorch.configurations import AutoModelConfigBuilder\n        if dist_config is None:\n            dist_config = DistConfig()\n        tp = dist_config.attn_tp\n\n        model_config = AutoModelConfigBuilder.build(hf_config,\n                                                    model_path,\n                                                    tp=tp,\n                                                    is_draft_model=is_draft_model,\n                                                    spec_method=spec_method)\n\n        if model_config.k_head_dim is None:\n            assert model_config.head_dim is not None\n            model_config.k_head_dim = model_config.head_dim\n        if model_config.v_head_dim is None:\n            assert model_config.head_dim is not None\n            model_config.v_head_dim = model_config.head_dim\n\n        # check for tp\n        assert model_config.num_attention_heads % tp == 0\n        if model_config.num_key_value_heads >= tp:\n            assert model_config.num_key_value_heads % tp == 0\n        else:\n            assert tp % model_config.num_key_value_heads == 0\n\n        # should after setting `hf_config` and `model_arch` attributes\n        model_config = _update_torch_dtype(model_config, dtype, device_type=device_type)\n\n        # update eos_token_id to list\n        if isinstance(model_config.eos_token_id, int):\n            model_config.eos_token_id = [model_config.eos_token_id]\n\n        return model_config\n\n\nclass UnmaskingStrategy(enum.Enum):\n    \"\"\"Unmasking Strategy.\"\"\"\n\n    # unmasking from left to right\n    SEQUENTIAL = enum.auto()\n    # unmasking with confidence threshold\n    LOW_CONFIDENCE_DYNAMIC = enum.auto()\n    # unmasking with topk in a block\n    LOW_CONFIDENCE_STATIC = enum.auto()\n\n    @classmethod\n    def from_str(cls, strategy: str):\n        \"\"\"From string.\"\"\"\n        strategy 
= strategy.lower()\n        if strategy == 'sequential':\n            return cls.SEQUENTIAL\n        elif strategy == 'low_confidence_dynamic':\n            return cls.LOW_CONFIDENCE_DYNAMIC\n        elif strategy == 'low_confidence_static':\n            return cls.LOW_CONFIDENCE_STATIC\n        else:\n            raise ValueError(f'Unknown unmasking strategy: {strategy}')\n\n\n@dataclass\nclass DLLMConfig:\n    block_length: int = 1\n    unmasking_strategy: UnmaskingStrategy = UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC\n    denoising_steps: int = None\n    confidence_threshold: float = 0.85\n\n\n@dataclass\nclass MiscConfig:\n    prefill_interval: int = 16\n    custom_module_map: str = None\n    empty_init: bool = False\n    model_format: str = None\n    hf_overrides: Dict[str, Any] = None\n    disable_vision_encoder: bool = False\n    logprobs_mode: str = None\n    dllm_config: DLLMConfig = None\n    enable_return_routed_experts: bool = False\n    enable_chunked_prefill: bool = False\n\n    @classmethod\n    def from_engine_config(cls, engine_config: PytorchEngineConfig):\n        \"\"\"From engine config.\"\"\"\n        dllm_unmasking_strategy = UnmaskingStrategy.from_str(engine_config.dllm_unmasking_strategy)\n        dllm_config = DLLMConfig(block_length=engine_config.dllm_block_length,\n                                 unmasking_strategy=dllm_unmasking_strategy,\n                                 denoising_steps=engine_config.dllm_denoising_steps,\n                                 confidence_threshold=engine_config.dllm_confidence_threshold)\n        misc_config = cls(\n            custom_module_map=engine_config.custom_module_map,\n            empty_init=engine_config.empty_init,\n            prefill_interval=engine_config.prefill_interval,\n            model_format=engine_config.model_format,\n            hf_overrides=engine_config.hf_overrides,\n            disable_vision_encoder=engine_config.disable_vision_encoder,\n            
logprobs_mode=engine_config.logprobs_mode,\n            dllm_config=dllm_config,\n            enable_return_routed_experts=engine_config.enable_return_routed_experts,\n            enable_chunked_prefill=False,\n        )\n        return misc_config\n\n\n@dataclass\nclass SpecDecodeConfig:\n    model: str\n    method: str\n    cache_config: CacheConfig = None\n    num_speculative_tokens: int = 1\n    model_config: ModelConfig = None\n\n    @classmethod\n    def from_config(\n        cls,\n        method: str,\n        num_speculative_tokens: int,\n        model: str,\n        target_cache_cfg: CacheConfig,\n        target_model: str = None,\n        dtype: str = 'auto',\n    ):\n        model = model or target_model\n        model_config = ModelConfig.from_pretrained(model,\n                                                   trust_remote_code=True,\n                                                   dtype=dtype,\n                                                   is_draft_model=True,\n                                                   spec_method=method)\n        cache_config = None\n        # include medusa\n        no_caches = ['medusa']\n        if method not in no_caches:\n            cache_config = CacheConfig(max_batches=target_cache_cfg.max_batches,\n                                       block_size=target_cache_cfg.block_size,\n                                       num_cpu_blocks=target_cache_cfg.num_cpu_blocks,\n                                       num_gpu_blocks=target_cache_cfg.num_gpu_blocks,\n                                       cache_max_entry_count=target_cache_cfg.cache_max_entry_count,\n                                       max_prefill_token_num=target_cache_cfg.max_prefill_token_num,\n                                       device_type=target_cache_cfg.device_type,\n                                       migration_backend=target_cache_cfg.migration_backend)\n        obj = cls(\n            model=model,\n            method=method,\n            
cache_config=cache_config,\n            model_config=model_config,\n            num_speculative_tokens=num_speculative_tokens,\n        )\n        return obj\n\n\n@dataclass\nclass QuantizationConfig:\n    quant_method: str = None\n    quant_dtype: torch.dtype = None\n    scale_fmt: str = None\n    bits: int = None\n    group_size: int = None\n    weight_block_size: Tuple[int] = None\n    activation_scheme: str = None\n    ignored_layers: List[str] = field(default_factory=list)\n    hf_quant_config: Dict[str, Any] = field(default_factory=dict)\n\n    @classmethod\n    def from_config(cls, hf_config: Any):\n        quant_config = getattr(hf_config, 'quantization_config', None)\n\n        if quant_config is None:\n            if hasattr(hf_config, 'llm_config') and hasattr(hf_config.llm_config, 'quantization_config'):\n                quant_config = hf_config.llm_config.quantization_config\n            elif hasattr(hf_config, 'text_config') and hasattr(hf_config.text_config, 'quantization_config'):\n                quant_config = hf_config.text_config.quantization_config\n\n        # no quant config found in hf config\n        if quant_config is None:\n            return cls()\n\n        quant_method = quant_config['quant_method']\n        quant_dtype = quant_config.get('quant_dtype', None)\n        scale_fmt = quant_config.get('scale_fmt', None)\n        weight_block_size = quant_config.get('weight_block_size', None)\n        activation_scheme = quant_config.get('activation_scheme', None)\n\n        bits = None\n        group_size = None\n\n        if quant_method == 'awq':\n            bits = quant_config.get('bits', 4)\n            group_size = quant_config.get('group_size', 128)\n        elif quant_method == 'smooth_quant':\n            if quant_dtype is None:\n                quant_dtype = 'int8'\n        elif quant_method == 'fp8':\n            fmt = quant_config.get('fmt', 'e4m3')\n            if fmt == 'e4m3':\n                quant_dtype = 'float8_e4m3fn'\n  
          elif fmt == 'e5m2':\n                quant_dtype = 'float8_e5m2'\n            else:\n                raise TypeError(f'Unsupported fp8 fmt: {fmt}')\n        else:\n            raise TypeError(f'Unsupported quant method: {quant_method}')\n\n        if quant_dtype is not None:\n            # resolve the dtype name (e.g. 'int8', 'float8_e4m3fn') on the torch module without eval\n            quant_dtype = getattr(torch, quant_dtype)\n\n        ignored_layers = quant_config.get('ignored_layers', [])\n        if not ignored_layers:\n            ignored_layers = quant_config.get('modules_to_not_convert', [])\n\n        return cls(\n            quant_method=quant_method,\n            quant_dtype=quant_dtype,\n            scale_fmt=scale_fmt,\n            bits=bits,\n            group_size=group_size,\n            weight_block_size=weight_block_size,\n            activation_scheme=activation_scheme,\n            ignored_layers=ignored_layers,\n            hf_quant_config=quant_config,\n        )\n\n    def get_quant_method(self, prefix: str = ''):\n        \"\"\"Get quant method for module.\"\"\"\n        if not prefix or not self.ignored_layers:\n            return self.quant_method\n\n        is_ignore = any(prefix in layer_name for layer_name in self.ignored_layers)\n        quant_method = None if is_ignore else self.quant_method\n        return quant_method\n\n    def get(self, key, default=None):\n        \"\"\"Get extra key from hf quant config.\"\"\"\n        return self.hf_quant_config.get(key, default)\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport importlib\nimport pkgutil\n\nfrom .builder import AutoModelConfigBuilder\n\n__all__ = []\n\n# load all submodule\nfor loader, module_name, is_pkg in pkgutil.walk_packages(__path__):\n    __all__.append(module_name)\n    _module = importlib.import_module('{}.{}'.format(__name__, module_name))\n    globals()[module_name] = _module\n\n__all__ += ['AutoModelConfigBuilder']\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/builder.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass AutoModelConfigBuilder(ABC):\n\n    _sub_classes = list()\n\n    def __init_subclass__(cls) -> None:\n        super().__init_subclass__()\n        AutoModelConfigBuilder.register_builder(cls)\n\n    @classmethod\n    def register_builder(cls, sub_cls):\n        \"\"\"Register builder.\"\"\"\n        if sub_cls not in AutoModelConfigBuilder._sub_classes:\n            AutoModelConfigBuilder._sub_classes.append(sub_cls)\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"Return True if this builder can handle ``hf_config``.\"\"\"\n        raise NotImplementedError(f'`condition` of {cls.__name__} not implemented.')\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build a ModelConfig with the first registered builder whose condition matches.\"\"\"\n        from .default import DefaultModelConfigBuilder\n\n        if cls != AutoModelConfigBuilder:\n            raise NotImplementedError(f'`build` of {cls.__name__} not implemented.')\n\n        valid_builder = DefaultModelConfigBuilder\n        for builder in cls._sub_classes:\n            if builder == valid_builder:\n                continue\n\n            if builder.condition(hf_config):\n                valid_builder = builder\n                break\n\n        logger.debug(f'build model config with {valid_builder.__name__}')\n\n        cfg = valid_builder.build(hf_config, model_path, **kwargs)\n        if cfg.hf_config is None:\n            cfg.hf_config = hf_config\n        if cfg.llm_config is None:\n            cfg.llm_config = hf_config\n\n        return cfg\n\n    @classmethod\n    def update_num_kv_heads(cls, hf_config, tp, num_key_value_heads):\n        \"\"\"Update num kv heads.\"\"\"\n        # in tp mode, replicate kv heads so that every rank holds at least one\n        if tp > 1 and tp > num_key_value_heads:\n            assert tp % num_key_value_heads == 0, (f'tp={tp} must be a multiple of '\n                                                   f'num_key_value_heads={num_key_value_heads}')\n            n_replicate = tp // num_key_value_heads\n            hf_config.num_replicate_key_value_heads = n_replicate\n            num_key_value_heads = tp\n\n        hf_config.num_key_value_heads = num_key_value_heads\n        return num_key_value_heads\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/chatglm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.config import ModelConfig\n\nfrom .builder import AutoModelConfigBuilder\n\n\nclass ChatGLMModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type == 'chatglm'\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        head_dim = hf_config.hidden_size // hf_config.num_attention_heads\n        bos_token_id = getattr(hf_config, 'bos_token_id', None)\n        if bos_token_id is None:\n            bos_token_id = hf_config.pad_token_id\n\n        if hf_config.multi_query_attention:\n            num_key_value_heads = hf_config.multi_query_group_num\n        else:\n            num_key_value_heads = hf_config.num_attention_heads\n\n        tp = kwargs.get('tp', 1)\n        # update num_kv_heads for tp mode\n        num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, num_key_value_heads)\n\n        cfg = ModelConfig(hidden_size=hf_config.hidden_size,\n                          num_layers=hf_config.num_layers,\n                          num_attention_heads=hf_config.num_attention_heads,\n                          num_key_value_heads=num_key_value_heads,\n                          bos_token_id=bos_token_id,\n                          eos_token_id=hf_config.eos_token_id,\n                          head_dim=head_dim,\n                          vocab_size=hf_config.padded_vocab_size)\n        # glm-4v\n        if hasattr(hf_config, 'vision_config'):\n            cfg.cogvlm_style = True\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/cogvlm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass CogVLMModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        model_arch = hf_config.architectures[0] if hf_config.architectures else None\n        return model_arch == 'CogVLMForCausalLM'\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        from lmdeploy.utils import is_bf16_supported\n        if getattr(hf_config, 'num_multi_query_heads', None):\n            hf_config.num_key_value_heads = hf_config.num_multi_query_heads\n        else:\n            hf_config.num_key_value_heads = hf_config.num_attention_heads\n\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n        cfg.cogvlm_style = True\n        torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16'\n        hf_config.torch_dtype = torch_dtype\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/deepseek_v2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.config import ModelConfig\n\nfrom .builder import AutoModelConfigBuilder\nfrom .utils import flash_mla_available\n\n\nclass DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['deepseek_v3', 'deepseek_v2', 'kimi_k2']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim)\n        k_head_dim = head_dim\n        v_head_dim = 0\n        num_attention_heads = hf_config.num_attention_heads\n        # multi query attn\n        num_key_value_heads = 1\n        tp = kwargs.get('tp', 1)\n        # update num_kv_heads for tp mode\n        num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, num_key_value_heads)\n        hf_config.use_flash_mla = flash_mla_available()\n        num_layers = hf_config.num_hidden_layers\n        model_paradigm = 'ar'\n\n        if spec_method is not None:\n            assert spec_method == 'deepseek_mtp'\n\n        # draft model cfg\n        if is_draft_model:\n            num_layers = hf_config.num_nextn_predict_layers\n            hf_config.architectures[0] = 'DeepseekMTPModel'\n            # remove for correct mapping when building the patched model\n            if hasattr(hf_config, 'auto_map'):\n                del hf_config.auto_map\n\n        if is_draft_model or spec_method is not None:\n            model_paradigm = 'ar_spec'\n\n        bos_token_id = getattr(hf_config, 'bos_token_id', None)\n        config = ModelConfig(\n            hidden_size=hf_config.hidden_size,\n            num_layers=num_layers,\n            num_attention_heads=num_attention_heads,\n            num_key_value_heads=num_key_value_heads,\n            
bos_token_id=bos_token_id,\n            eos_token_id=hf_config.eos_token_id,\n            head_dim=head_dim,\n            k_head_dim=k_head_dim,\n            v_head_dim=v_head_dim,\n            vocab_size=hf_config.vocab_size,\n            use_flash_mla=hf_config.use_flash_mla,\n            model_paradigm=model_paradigm,\n        )\n        return config\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/deepseek_v32.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom .deepseek_v2 import DeepseekV2ModelConfigBuilder\n\n\ndef _check_env_v32(device: str = 'cuda'):\n    \"\"\"Environment check.\"\"\"\n    if device != 'cuda':\n        return\n\n    # check cuda\n    try:\n        import fast_hadamard_transform  # noqa: F401\n    except ImportError:\n        raise ImportError('DeepSeek-V3.2 requires <fast_hadamard_transform>.')\n\n    try:\n        import flash_mla  # noqa: F401\n    except ImportError:\n        raise ImportError('DeepSeek-V3.2 requires <flash_mla>.')\n\n    if not hasattr(flash_mla, 'flash_mla_sparse_fwd'):\n        raise RuntimeError('Latest flash_mla is required: https://github.com/deepseek-ai/FlashMLA.')\n\n\nclass DeepseekV32ModelConfigBuilder(DeepseekV2ModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['deepseek_v32', 'glm_moe_dsa']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        config = DeepseekV2ModelConfigBuilder.build(hf_config, model_path=model_path, **kwargs)\n\n        assert hf_config.use_flash_mla, 'DeepSeek-V3.2 requires flash_mla to be available.'\n        index_k_shape = ([hf_config.index_head_dim], torch.float8_e4m3fn)\n        index_k_scale_shape = ([1], torch.float32)\n        config.cache_shapes = [index_k_shape, index_k_scale_shape]\n        config.use_mla_fp8_cache = True\n        config.mla_index_topk = hf_config.index_topk\n        config.check_env_func = _check_env_v32\n        return config\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/deepseek_vl2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass DeepseekVLV2ModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['deepseek_vl_v2']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build deepseek-vl2.\"\"\"\n\n        if hf_config.language_config.use_mla:\n            from .deepseek_v2 import DeepseekV2ModelConfigBuilder\n            cfg = DeepseekV2ModelConfigBuilder.build(hf_config.language_config, model_path, **kwargs)\n            cfg.hf_config = hf_config\n        else:\n            # deepseek-vl2-tiny uses MHA, rather than MLA\n            # in this case, we use DefaultModelConfigBuilder\n            cfg = DefaultModelConfigBuilder.build(hf_config.language_config, model_path, **kwargs)\n            cfg.hf_config = hf_config\n\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/default.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.config import ModelConfig\n\nfrom .builder import AutoModelConfigBuilder\n\n\nclass DefaultModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return True\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        head_dim = getattr(hf_config, 'head_dim', None)\n        head_dim = head_dim or hf_config.hidden_size // hf_config.num_attention_heads\n\n        # head_dim should not be None\n        hf_config.head_dim = head_dim\n        num_attention_heads = hf_config.num_attention_heads\n        num_key_value_heads = getattr(hf_config, 'num_key_value_heads', num_attention_heads)\n        use_sliding_window = getattr(hf_config, 'use_sliding_window', True)\n        sliding_window = -1\n        if use_sliding_window:\n            sliding_window = getattr(hf_config, 'sliding_window', sliding_window) or -1\n        tp = kwargs.get('tp', 1)\n        # update num_kv_heads for tp mode\n        num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, num_key_value_heads)\n\n        return ModelConfig(\n            hidden_size=hf_config.hidden_size,\n            num_layers=hf_config.num_hidden_layers,\n            num_attention_heads=hf_config.num_attention_heads,\n            num_key_value_heads=num_key_value_heads,\n            bos_token_id=hf_config.bos_token_id,\n            eos_token_id=hf_config.eos_token_id,\n            sliding_window=sliding_window,\n            head_dim=head_dim,\n            k_head_dim=head_dim,\n            v_head_dim=head_dim,\n            vocab_size=hf_config.vocab_size,\n            llm_config=hf_config,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/gemma.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass GemmaModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['gemma', 'gemma2', 'gemma3_text']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build gemma.\"\"\"\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n        cfg.head_dim = hf_config.head_dim\n        return cfg\n\n\nclass GemmaVLModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        model_arch = hf_config.architectures[0] if hf_config.architectures else None\n        return model_arch == 'Gemma3ForConditionalGeneration'\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build gemma3 vision-language model.\"\"\"\n        hf_config.text_config.architectures = ['Gemma3ForCausalLM']\n        cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)\n        # gemma 3 does not use sliding-window attention on every layer, so no global window is set\n        cfg.sliding_window = -1\n        cfg.hf_config = hf_config\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/glm4.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .deepseek_v2 import DeepseekV2ModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass Glm4MoeLiteModelConfigBuilder(DeepseekV2ModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['glm4_moe_lite']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        # set default attrs\n        if not hasattr(hf_config, 'scoring_func'):\n            hf_config.scoring_func = 'sigmoid'\n        if not hasattr(hf_config, 'moe_layer_freq'):\n            hf_config.moe_layer_freq = 1\n        return super().build(hf_config,\n                             model_path=model_path,\n                             is_draft_model=is_draft_model,\n                             spec_method=spec_method,\n                             **kwargs)\n\n\nclass Glm4MoeModelConfigBuilder(DefaultModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['glm4_moe']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n\n        num_layers = hf_config.num_hidden_layers\n        model_paradigm = 'ar'\n\n        if spec_method is not None:\n            assert spec_method == 'deepseek_mtp'\n\n        # draft model cfg\n        if is_draft_model:\n            num_layers = hf_config.num_nextn_predict_layers\n            hf_config.architectures[0] = 'Glm4MoeMTPModel'\n            # remove for correct mapping when building the patched model\n            if hasattr(hf_config, 'auto_map'):\n                del hf_config.auto_map\n\n        if is_draft_model or spec_method is not None:\n            model_paradigm = 
'ar_spec'\n\n        cfg = super().build(hf_config,\n                            model_path=model_path,\n                            is_draft_model=is_draft_model,\n                            spec_method=spec_method,\n                            **kwargs)\n        cfg.model_paradigm = model_paradigm\n        cfg.num_layers = num_layers\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/gpt_oss.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass GptOSSModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['gpt_oss']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build gpt-oss.\"\"\"\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n        # gpt-oss does not use sliding-window attention on every layer, so no global window is set\n        cfg.sliding_window = -1\n        cfg.hf_config = hf_config\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/interns1_pro.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass InterS1ProModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['interns1_pro', 'interns1_1']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        if hasattr(hf_config, 'quantization_config') and not hasattr(hf_config.text_config, 'quantization_config'):\n            setattr(hf_config.text_config, 'quantization_config', hf_config.quantization_config)\n        cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)\n        setattr(hf_config, 'dtype', hf_config.text_config.dtype)\n        cfg.hf_config = hf_config\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/internvl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass InternVLModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.architectures[0] == 'InternVLChatModel'\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build InternVL.\"\"\"\n        cfg = DefaultModelConfigBuilder.build(hf_config.llm_config, model_path, **kwargs)\n        cfg.hf_config = hf_config\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/internvl3_hf.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass InternVL3ModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.architectures[0] in ['InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build config.\"\"\"\n        # hack quantization_config\n        if hasattr(hf_config, 'quantization_config') and not hasattr(hf_config.text_config, 'quantization_config'):\n            setattr(hf_config.text_config, 'quantization_config', hf_config.quantization_config)\n\n        # fix transformers>5\n        if hasattr(hf_config.text_config, 'tie_word_embeddings'):\n            hf_config.tie_word_embeddings = hf_config.text_config.tie_word_embeddings\n\n        cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)\n        cfg.hf_config = hf_config\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/llama.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass LlamaModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.architectures[0] in ['LlamaForCausalLM']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs):\n        \"\"\"Build llama.\"\"\"\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n\n        if is_draft_model:\n            # update draft model arch\n            assert spec_method is not None\n            hf_config.architectures[0] = spec_method.capitalize() + hf_config.architectures[0]\n            cfg.vocab_size = getattr(hf_config, 'draft_vocab_size', hf_config.vocab_size)\n            cfg.model_paradigm = 'ar_spec'\n        elif spec_method is not None:\n            # add aux_hidden_state_layers for eagle3\n            if spec_method == 'eagle3':\n                num_layers = cfg.num_layers\n                hf_config.aux_hidden_state_layers = (2, num_layers // 2, num_layers - 3)\n            cfg.model_paradigm = 'ar_spec'\n        cfg.hf_config = hf_config\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/llama4.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass Llama4ModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['llama4']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build llama4.\"\"\"\n        cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)\n        cfg.hf_config = hf_config\n\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/llava_hf.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.config import ModelConfig\n\nfrom .builder import AutoModelConfigBuilder\n\n\nclass LlavaHfModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.architectures[0] in ['LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"Build llava hf.\"\"\"\n        text_config = hf_config.text_config\n        hidden_size = getattr(text_config, 'hidden_size', 4096)\n        num_attention_heads = getattr(text_config, 'num_attention_heads', 32)\n        num_key_value_heads = getattr(text_config, 'num_key_value_heads', 32)\n        num_hidden_layers = getattr(text_config, 'num_hidden_layers', 32)\n        bos_token_id = getattr(text_config, 'bos_token_id', 1)\n        eos_token_id = getattr(text_config, 'eos_token_id', 2)\n        head_dim = hidden_size // num_attention_heads\n\n        return ModelConfig(\n            hidden_size=hidden_size,\n            num_layers=num_hidden_layers,\n            num_attention_heads=num_attention_heads,\n            num_key_value_heads=num_key_value_heads,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            head_dim=head_dim,\n            vocab_size=text_config.vocab_size,\n            hf_config=hf_config,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/minicpm3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass MiniCPM3ModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.architectures[0] in ['MiniCPM3ForCausalLM']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        head_dim = (hf_config.qk_nope_head_dim + hf_config.qk_rope_head_dim)\n\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n        cfg.head_dim = head_dim\n        cfg.k_head_dim = head_dim\n        cfg.v_head_dim = head_dim\n\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/qwen.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass QwenModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type == 'qwen'\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        from lmdeploy.utils import is_bf16_supported\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n        if cfg.bos_token_id is None:\n            cfg.bos_token_id = 151644\n        if cfg.eos_token_id is None:\n            cfg.eos_token_id = 151645\n\n        torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16'\n        if hf_config.bf16 and is_bf16_supported():\n            torch_dtype = 'bfloat16'\n        elif hf_config.fp16:\n            torch_dtype = 'float16'\n        hf_config.torch_dtype = torch_dtype\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/qwen3_5.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom lmdeploy.utils import is_bf16_supported\n\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\nfrom .qwen3_next import _check_env_qwen3_next\n\n\nclass Qwen3_5ModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['qwen3_5', 'qwen3_5_moe']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):\n        \"\"\"build.\"\"\"\n        text_config = hf_config.text_config\n        # propagate quantization_config from top-level hf_config into text_config\n        quantization_config = getattr(hf_config, 'quantization_config', None)\n        if quantization_config is not None and not hasattr(text_config, 'quantization_config'):\n            text_config.quantization_config = quantization_config\n        cfg = DefaultModelConfigBuilder.build(text_config, model_path, tp=tp, **kwargs)\n\n        # update num layers\n        num_layers = cfg.num_layers\n        layer_types = text_config.layer_types\n        num_delta_layers = sum([1 for lt in layer_types if lt == 'linear_attention'])\n        num_full_layers = num_layers - num_delta_layers\n        cfg.num_layers = num_full_layers\n\n        # set state shapes\n        head_k_dim = text_config.linear_key_head_dim\n        head_v_dim = text_config.linear_value_head_dim\n        num_v_heads = text_config.linear_num_value_heads // tp\n        num_k_heads = text_config.linear_num_key_heads // tp\n        key_dim = head_k_dim * num_k_heads\n        value_dim = head_v_dim * num_v_heads\n        conv_dim = key_dim * 2 + value_dim\n        conv_kernel_size = text_config.linear_conv_kernel_dim\n\n        conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)\n        recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, 
head_v_dim)\n        if is_bf16_supported():\n            dtype = torch.bfloat16\n        else:\n            dtype = torch.float16\n        cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]\n        cfg.check_env_func = _check_env_qwen3_next\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/qwen3_next.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\ndef _check_env_qwen3_next(device: str):\n    \"\"\"Check env for qwen3 next.\"\"\"\n    if device != 'cuda':\n        return\n\n    try:\n        import fla  # noqa: F401\n    except ImportError:\n        raise ImportError('Qwen3-Next cuda support requires https://github.com/fla-org/flash-linear-attention.')\n\n\nclass Qwen3NextModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type == 'qwen3_next'\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs):\n        \"\"\"build.\"\"\"\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, tp=tp, **kwargs)\n\n        # update num layers\n        num_layers = cfg.num_layers\n        num_full_layers = num_layers // hf_config.full_attention_interval\n        num_delta_layers = num_full_layers * (hf_config.full_attention_interval - 1)\n        cfg.num_layers = num_full_layers\n\n        # set state shapes\n        head_k_dim = hf_config.linear_key_head_dim\n        head_v_dim = hf_config.linear_value_head_dim\n        num_v_heads = hf_config.linear_num_value_heads // tp\n        num_k_heads = hf_config.linear_num_key_heads // tp\n        key_dim = head_k_dim * num_k_heads\n        value_dim = head_v_dim * num_v_heads\n        conv_dim = key_dim * 2 + value_dim\n        conv_kernel_size = hf_config.linear_conv_kernel_dim\n\n        conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size)\n        recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)\n        dtype = torch.bfloat16\n        cfg.states_shapes = [(conv_state_shape, dtype), (recurrent_state_shape, dtype)]\n        cfg.check_env_func = _check_env_qwen3_next\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/qwen3_vl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .builder import AutoModelConfigBuilder\nfrom .default import DefaultModelConfigBuilder\n\n\nclass Qwen3VLModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['qwen2_vl', 'qwen2_5_vl', 'qwen3_vl', 'qwen3_vl_moe']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        if not hasattr(hf_config, 'text_config'):\n            # for transformers <= 5\n            return DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n\n        if hasattr(hf_config, 'quantization_config') and not hasattr(hf_config.text_config, 'quantization_config'):\n            setattr(hf_config.text_config, 'quantization_config', hf_config.quantization_config)\n        cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)\n        setattr(hf_config, 'dtype', hf_config.text_config.dtype)\n        cfg.hf_config = hf_config\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/sdar.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .default import AutoModelConfigBuilder, DefaultModelConfigBuilder\n\n\nclass SDARModelConfigBuilder(AutoModelConfigBuilder):\n\n    @classmethod\n    def condition(cls, hf_config):\n        \"\"\"config.\"\"\"\n        return hf_config.model_type in ['sdar', 'sdar_moe']\n\n    @classmethod\n    def build(cls, hf_config, model_path: str = None, **kwargs):\n        \"\"\"build.\"\"\"\n        cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)\n        cfg.dllm_mask_token = 151669\n        cfg.model_paradigm = 'dllm'\n        return cfg\n"
  },
  {
    "path": "lmdeploy/pytorch/configurations/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\ndef flash_mla_available():\n    \"\"\"Check if flash mla is available.\"\"\"\n    # use flash_mla by default if it is installed\n    use_flash_mla = False\n    try:\n        # In some torch_npu versions, device_properties has no 'major' attribute;\n        # in other torch_npu versions, the value of major is None.\n        device_properties = torch.cuda.get_device_properties(0)\n        major = getattr(device_properties, 'major', None)\n        if isinstance(major, int) and major >= 9:\n            import flash_mla  # noqa\n            use_flash_mla = True\n    except ImportError:\n        logger.warning('For higher performance, please install flash_mla https://github.com/deepseek-ai/FlashMLA')\n    return use_flash_mla\n\n\ndef flash_attn_v3_available():\n    \"\"\"Check if flash attn v3 is available.\"\"\"\n    use_fa3 = False\n    try:\n        # Currently, flash-attention only supports FA3 on sm90a with CUDA >= 12.3\n        if (torch.cuda.get_device_capability()[0] == 9) and (torch.version.cuda >= '12.3'):\n            import flash_attn_interface  # noqa: F401\n            assert torch.ops.flash_attn_3 is not None\n            use_fa3 = True\n    except Exception:\n        logger.warning('For higher performance, please install FlashAttention-3 '\n                       'https://github.com/Dao-AILab/flash-attention')\n    return use_fa3\n"
  },
  {
    "path": "lmdeploy/pytorch/consts.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# dllm\nDLLM_MASKED = 0\nDLLM_UNMASKED = 1\nDLLM_CACHED = 2\n"
  },
  {
    "path": "lmdeploy/pytorch/devices/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .device_manager import DefaultContext, DeviceContext, get_device_manager\n\n__all__ = ['DeviceContext', 'DefaultContext', 'get_device_manager']\n"
  },
  {
    "path": "lmdeploy/pytorch/devices/device_manager.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass\nfrom typing import Callable\n\nfrom lmdeploy.pytorch.utils import CtxMgrBase, singleton\n\n\n@dataclass\nclass DeviceContext:\n    device_type: str = 'cuda'\n\n\nDefaultContext = DeviceContext()\n\n\n@singleton\nclass DeviceManager(CtxMgrBase[DeviceContext]):\n\n    def __init__(self):\n        super().__init__(DefaultContext)\n        self._context_callback: dict[int, Callable] = dict()\n        self._next_cb_handle = 0\n\n    def register_context_callback(self, callback: Callable):\n        \"\"\"Register callback.\"\"\"\n        handle = self._next_cb_handle\n        self._context_callback[handle] = callback\n        self._next_cb_handle += 1\n        return handle\n\n    def unregister_context_callback(self, handle: int):\n        \"\"\"Unregister callback.\"\"\"\n        self._context_callback.pop(handle, None)\n\n\ndef get_device_manager():\n    \"\"\"Get device manager.\"\"\"\n    return DeviceManager()\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/README.md",
    "content": "# LMDeploy-DistServe\n\n## Key Components\n\n1. **Router Service**: Coordinates requests between the prefill and decode engines.\n2. **Migration Manager**: Facilitates high-performance memory sharing for KV cache migration.\n\n## Installation\n\n```\n# Inference Engine\npip install \"lmdeploy[all]>=0.7.0\"\n\n# Transfer Engine\npip install \"dlslime>=0.0.2\"\n```\n\n## Quick Start\n\nA prefill-decode (PD) disaggregated deployment of internlm2_5-7b-chat is shown below:\n\n### 1. Launch Router Service\n\n```shell\nlmdeploy serve proxy --server-name 0.0.0.0 --server-port 8000 --routing-strategy \"min_expected_latency\" --serving-strategy DistServe --log-level INFO\n```\n\nLMDeploy-DistServe supports both NVLink and RDMA for transferring the KV cache from the Prefill Engine to the Decode Engine. RDMA is the default mode. Set `--migration-protocol NVLink` to use NVLink transport.\n\n### 2. Configure Endpoints\n\nFirst, deploy your prefill and decode engines.\n\n```shell\n# Prefill Engine\nCUDA_VISIBLE_DEVICES=0 lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --role Prefill --proxy-url http://0.0.0.0:8000 --backend pytorch\n# Decode Engine\nCUDA_VISIBLE_DEVICES=1 lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23334 --role Decode --proxy-url http://0.0.0.0:8000 --backend pytorch\n```\n\nCurrently, only the **PyTorch backend** supports PD disaggregation.\n\n## API Usage\n\n```shell\n# API Invoke\ncurl -X POST \"http://localhost:8000/v1/completions\" \\\n-H \"Content-Type: application/json\" \\\n-d '{\"model\": \"internlm/internlm2_5-7b-chat\", \"temperature\":0, \"prompt\": \"Shanghai is a city that \", \"max_tokens\": 16, \"stream\": false}'\n# Output\n{\n  \"id\":\"2\",\n  \"object\":\"text_completion\",\n  \"created\":1743662400,\n  \"model\":\"internlm/internlm2_5-7b-chat\",\n  \"choices\":[\n    {\n      \"index\":0,\n      \"text\":\" is very famous for its skyscrapers. 
It is also a city\",\"logprobs\":null,\"finish_reason\":\"length\"\n    }\n  ],\n  \"usage\": {\n    \"prompt_tokens\":7,\"total_tokens\":23,\"completion_tokens\":16\n  }\n}\n```\n\n## Troubleshooting\n\n### RDMA Connection Failed:\n\nMake sure ibverbs is correctly installed:\n\n```\n# on Ubuntu\nsudo apt install libibverbs-dev\n# on CentOS\nsudo yum install ibverbs-devel\n```\n\n```bash\nibstat        # Verify IB device status\nibv_devinfo   # Check device capabilities\n```\n\n### Check GPUDirect RDMA:\n\nCurrently, LMDeploy-DistServe uses GPUDirect RDMA to perform KV transfer. Make sure the GPUDirect RDMA driver is loaded into the kernel.\n\n```bash\nlsmod | grep nv_peer_mem\n# GPUDirect RDMA info will be printed if GPUDirect RDMA is correctly loaded.\n```\n\n### Connection Pool\n\nCurrently, if the Proxy disconnects, the connection pool must be warmed up again. A future enhancement could involve a dedicated connection pool management server (e.g., using Raft-based tools like etcd, as mentioned in Mooncake) to improve connection discovery and avoid repeated warmups.\n\n### Proxy\n\nDo not add an engine node to **different proxies**; this is not supported and is not considered a valid usage at present.\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/backend/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.logger import get_logger\n\nlogger = get_logger('lmdeploy')\n\ntry:\n    logger.debug('Registering DLSlime Backend')\n    from .dlslime import DLSlimeBackend\nexcept ImportError:\n    logger.debug('Disable DLSlime Backend')\n\ntry:\n    logger.debug('Registering Mooncake Backend')\n    from .mooncake import MooncakeBackend\nexcept ImportError:\n    logger.warning('Disable Mooncake Backend')\n\n__all__ = ['DLSlimeBackend', 'MooncakeBackend']\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/backend/backend.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom mmengine.registry import Registry\n\nMIGRATION_BACKENDS = Registry('migration_backend', locations=['lmdeploy.pytorch.disagg.backend.backend'])\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/backend/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import abstractmethod\n\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,\n                                                   MigrationProtocol)\nfrom lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment\n\n\nclass MigrationBackendImpl:\n\n    @abstractmethod\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        raise NotImplementedError\n\n    @abstractmethod\n    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):\n        raise NotImplementedError\n\n    @abstractmethod\n    def endpoint_info(self, remote_engine_id: str, protocol: MigrationProtocol):\n        raise NotImplementedError\n\n    @abstractmethod\n    def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo):\n        raise NotImplementedError\n\n    @abstractmethod\n    def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):\n        raise NotImplementedError\n\n    @abstractmethod\n    def store(self, assignment: MigrationAssignment, async_op: bool = False):\n        raise NotImplementedError\n\n    @abstractmethod\n    def load(self, assignment: MigrationAssignment, async_op: bool = False):\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/backend/dlslime.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport json\nimport os\nfrom typing import Dict\n\nfrom dlslime import RDMAEndpoint, available_nic\n\nfrom lmdeploy.logger import get_logger\nfrom lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS\nfrom lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl\nfrom lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,\n                                                   MigrationProtocol)\nfrom lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment\n\nlogger = get_logger('lmdeploy')\n\nLMDEPLOY_USE_ASYNC_MIGRATION = os.environ.get('LMDEPLOY_USE_ASYNC_MIGRATION', None)\n\n\nclass DLSlimeMigrationManagement:\n\n    def __init__(self, init_request: DistServeInitRequest):\n        self.rank = init_request.rank\n        self.local_engine_config: DistServeEngineConfig = (init_request.local_engine_config)\n        self.remote_engine_config: DistServeEngineConfig = (init_request.remote_engine_config)\n        self.endpoint: Dict[MigrationProtocol, RDMAEndpoint] = {}\n        if init_request.protocol == MigrationProtocol.RDMA:\n            nics = available_nic()\n            device_name = nics[self.rank % len(nics)]\n            logger.info(f'use device {device_name} for kv migration')\n            self.endpoint[MigrationProtocol.RDMA] = RDMAEndpoint(\n                device_name=device_name,\n                ib_port=1,\n                link_type=init_request.rdma_config.link_type.name,\n            )\n        elif init_request.protocol == MigrationProtocol.NVLINK:\n            try:\n                from dlslime import NVLinkEndpoint\n            except ImportError:\n                logger.warning('Notice: DLSlime not compiled from source with NVLink. 
Fallback to RDMAEndpoint.')\n                NVLinkEndpoint = RDMAEndpoint\n            self.endpoint[MigrationProtocol.NVLINK] = NVLinkEndpoint()\n\n    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):\n        self.endpoint[register_mr_request.protocol].register_memory_region(\n            register_mr_request.mr_key,\n            register_mr_request.addr,\n            register_mr_request.offset,\n            register_mr_request.length,\n        )\n\n    def connect(self, kvtransfer_endpoint_info: DistServeKVTransferEndpointInfo):\n        self.endpoint[kvtransfer_endpoint_info.protocol].connect(json.loads(kvtransfer_endpoint_info.endpoint_info))\n\n    async def p2p_migrate(self, assignment: MigrationAssignment):\n        batch = [(\n            assign.mr_key,\n            assign.mr_key,\n            assign.target_offset,\n            assign.source_offset,\n            assign.length,\n        ) for assign in assignment.batch]\n\n        future = self.endpoint[assignment.protocol].read(batch)\n        if LMDEPLOY_USE_ASYNC_MIGRATION:\n            loop = asyncio.get_running_loop()\n            return await loop.run_in_executor(None, future.wait)\n        else:\n            return future.wait()\n\n\n@MIGRATION_BACKENDS.register_module(MigrationBackend.DLSlime.name)\nclass DLSlimeBackend(MigrationBackendImpl):\n    \"\"\"DLSlime Transfer Engine.\"\"\"\n\n    def __init__(self):\n        self.links: Dict[str, DLSlimeMigrationManagement] = {}\n\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        self.links[init_request.remote_engine_id] = DLSlimeMigrationManagement(init_request)\n\n    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):\n        self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request)\n\n    def endpoint_info(self, remote_engine_id: str, protocol: MigrationProtocol):\n        return 
self.links[remote_engine_id].endpoint[protocol].endpoint_info()\n\n    def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo):\n        self.links[remote_engine_id].connect(conn_req)\n\n    async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):\n        await self.links[assignment.remote_engine_id].p2p_migrate(assignment)\n\n    def store(self, assignment: MigrationAssignment, async_op: bool = False):\n        raise NotImplementedError\n\n    def load(self, assignment: MigrationAssignment, async_op: bool = False):\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/backend/mooncake.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport json\nimport os\nimport socket\nimport subprocess\nfrom typing import Dict\n\nfrom lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS\nfrom lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl\nfrom lmdeploy.pytorch.disagg.config import MigrationBackend, MooncakeEngineConfig\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,\n                                                   MigrationProtocol)\nfrom lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\nLMDEPLOY_USE_ASYNC_MIGRATION = os.environ.get('LMDEPLOY_USE_ASYNC_MIGRATION', None)\n\n\ndef get_rdma_nics():\n    \"\"\"Get all available RDMA network interface cards on the current machine.\n\n    Returns:\n        list: List of RDMA NICs, e.g. ['erdma_0', 'erdma_1']\n    \"\"\"\n    rdma_nics = []\n\n    try:\n        result = subprocess.run(['ibv_devices'], stdout=subprocess.PIPE, text=True)\n        if result.returncode == 0:\n            # Parse ibv_devices output\n            # Sample output:\n            # device                 node GUID\n            # ------              ----------------\n            lines = result.stdout.strip().split('\\n')\n            for line in lines[2:]:  # Skip header lines\n                if line.strip():\n                    device_name = line.split()[0].strip()\n                    rdma_nics.append(device_name)\n    except Exception as e:\n        logger.error(f'Error executing ibv_devices command: {e}')\n\n    return rdma_nics\n\n\ndef get_local_ip_by_remote() -> str:\n    # try ipv4\n    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)\n    try:\n        s.connect(('8.8.8.8', 80))  # Doesn't need to be reachable\n        return s.getsockname()[0]\n    except Exception:\n        pass\n\n    # try ipv6\n    
try:\n        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)\n        # Google's public DNS server, see\n        # https://developers.google.com/speed/public-dns/docs/using#addresses\n        s.connect(('2001:4860:4860::8888', 80))  # Doesn't need to be reachable\n        return s.getsockname()[0]\n    except Exception:\n        raise ValueError('Can not get local ip')\n\n\nclass MooncakeMigrationManagement:\n    \"\"\"Manages migration for a single connection in Mooncake backend.\"\"\"\n\n    def __init__(self, init_request: DistServeInitRequest):\n        try:\n            from mooncake.engine import TransferEngine\n        except ImportError as e:\n            raise ImportError('Please install mooncake by following the instructions at '\n                              'https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md '\n                              'to run LMDeploy with MooncakeBackend.') from e\n\n        self.rank = init_request.rank\n        self.local_engine_config: MooncakeEngineConfig = init_request.local_engine_config\n        self.remote_engine_config: MooncakeEngineConfig = init_request.remote_engine_config\n        self.local_engine_id = init_request.local_engine_id\n        self.remote_engine_id = init_request.remote_engine_id\n\n        self.engine = TransferEngine()\n        self.hostname = get_local_ip_by_remote()\n\n        # Get all RDMA information once during initialization\n        self.ibv_devices = get_rdma_nics()\n\n        self.local_kv_table: Dict[str, Dict] = {}\n        self.remote_kv_table: Dict[str, Dict] = {}\n        self.remote_url: str = ''  # Store remote URL for this connection\n\n        # Initialize the p2p connection\n        self._initialize_p2p(init_request)\n\n        self.port: int = self.engine.get_rpc_port()\n\n    def _initialize_p2p(self, init_request: DistServeInitRequest):\n        \"\"\"Initialize p2p connection for this specific link.\"\"\"\n        # TODO: Support more types of 
metadata_server\n        # e.g. \"etcd://192.168.0.137:2379\"\n        metadata_server = 'P2PHANDSHAKE'\n\n        # Default protocol (Currently only RDMA is supported)\n        protocol = 'rdma'\n\n        # Get the device name from request\n        if not self.ibv_devices:\n            raise RuntimeError('No RDMA devices available')\n\n        device_name = self.ibv_devices[self.rank % len(self.ibv_devices)]\n\n        # Initialize the engine\n        result = self.engine.initialize(self.hostname, metadata_server, protocol, device_name)\n        if result != 0:\n            raise RuntimeError(f'Failed to initialize Mooncake engine: {result}')\n\n        logger.info(f'Mooncake engine initialized for remote_engine_id {self.remote_engine_id} '\n                    f'with hostname {self.hostname}, RPC port: {self.engine.get_rpc_port()}')\n\n    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):\n        \"\"\"Register memory region for this connection.\"\"\"\n        # Transmit buffer address to int\n        buffer_addr = register_mr_request.addr\n        buffer_length = register_mr_request.length\n\n        # Register memory region with the engine\n        result = self.engine.register_memory(buffer_addr, buffer_length)\n        if result != 0:\n            raise RuntimeError(f'Failed to register memory region: {result}')\n\n        mr_key = str(register_mr_request.mr_key)\n        self.local_kv_table[mr_key] = {\n            'addr': buffer_addr,\n            'length': buffer_length,\n            'offset': register_mr_request.offset\n        }\n\n        logger.info(f'Registered memory region with mr_key {mr_key}, '\n                    f'addr: {buffer_addr}, length: {buffer_length} for remote_engine_id {self.remote_engine_id}')\n\n    @property\n    def endpoint_info(self) -> Dict:\n        \"\"\"Get endpoint information for this connection.\"\"\"\n\n        mr_info = {}\n        for mr_key, buffer_info in 
self.local_kv_table.items():\n            mr_info[mr_key] = {\n                'addr': buffer_info['addr'],\n                'length': buffer_info['length'],\n                'offset': buffer_info['offset']\n            }\n\n        endpoint_info = {'mr_info': mr_info, 'session_id': f'{self.hostname}:{self.port}'}\n\n        logger.info(f'Generated endpoint info for remote engine {self.remote_engine_id}: '\n                    f\"session_id={endpoint_info['session_id']}, \"\n                    f'mr_count={len(mr_info)}')\n\n        return endpoint_info\n\n    def connect(self, connect_request: DistServeKVTransferEndpointInfo):\n        \"\"\"Connect to the remote engine.\"\"\"\n        remote_endpoint_info = json.loads(connect_request.endpoint_info)\n\n        self.remote_url = remote_endpoint_info['session_id']\n        self.remote_kv_table = remote_endpoint_info['mr_info']\n\n        logger.info(f'Received remote buffer info: {len(self.remote_kv_table)} regions')\n        for mr_key, buffer_info in self.remote_kv_table.items():\n            logger.debug(f\"Remote buffer mr_key {mr_key}: addr=0x{buffer_info['addr']:x}, \"\n                         f\"length={buffer_info['length']}\")\n\n        logger.info(f'Connecting to remote engine {self.remote_engine_id} at {self.remote_url}')\n\n    async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):\n        \"\"\"Migrate data to the remote engine.\"\"\"\n        if not LMDEPLOY_USE_ASYNC_MIGRATION:\n            # For synchronous migration, call the method directly\n            self._migrate(assignment)\n        else:\n            # For asynchronous migration, run the blocking transfer in the default executor.\n            # _migrate raises RuntimeError on failure, so awaiting the executor future suffices.\n            import asyncio\n            loop = asyncio.get_running_loop()\n            await loop.run_in_executor(None, self._migrate, assignment)\n\n    def _migrate(self, assignment: MigrationAssignment):\n        \"\"\"Migrate data to the remote engine synchronously.\"\"\"\n        if not self.remote_url:\n            raise RuntimeError(f'No connection established to remote engine {self.remote_engine_id}')\n\n        for i, task in enumerate(assignment.batch):\n            mr_key = str(task.mr_key)\n\n            if mr_key not in self.local_kv_table:\n                raise RuntimeError(f'Memory region with mr_key {mr_key} not registered locally')\n\n            if mr_key not in self.remote_kv_table:\n                raise RuntimeError(f'Remote memory region with mr_key {mr_key} not registered')\n\n            # Get local buffer information\n            local_buffer_info = self.local_kv_table[mr_key]\n            local_addr = local_buffer_info['addr'] + task.source_offset\n\n            # Get remote buffer information\n            remote_buffer_info = self.remote_kv_table[mr_key]\n            remote_addr = remote_buffer_info['addr'] + task.target_offset\n\n            logger.debug(f'Task {i}: Migrating {task.length} bytes')\n            logger.debug(f'  Local Engine: {self.local_engine_id}')\n            logger.debug(f'  Remote Engine: {assignment.remote_engine_id}')\n            logger.debug(f'  MR Key: {mr_key}')\n            logger.debug(f\"  Local:  0x{local_buffer_info['addr']:x} + {task.source_offset} = 0x{local_addr:x}\")\n            logger.debug(f\"  Remote: 0x{remote_buffer_info['addr']:x} + {task.target_offset} = 0x{remote_addr:x}\")\n            logger.debug(f'  Session: {self.remote_url}')\n\n            result = self.engine.transfer_sync_read(\n                self.remote_url,\n                local_addr,\n                remote_addr,\n                task.length,\n            )\n            if result != 0:\n                raise RuntimeError(f'Failed to perform sync transfer: {result}')\n\n\n@MIGRATION_BACKENDS.register_module(MigrationBackend.Mooncake.name)\nclass 
MooncakeBackend(MigrationBackendImpl):\n    \"\"\"Mooncake backend that manages multiple migration connections.\"\"\"\n\n    def __init__(self):\n        self.links: Dict[int, MooncakeMigrationManagement] = {}\n\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        self.links[init_request.remote_engine_id] = MooncakeMigrationManagement(init_request)\n\n    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):\n        self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request)\n\n    def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):\n        return self.links[remote_engine_id].endpoint_info\n\n    def p2p_connect(self, remote_engine_id: str, connect_request: DistServeKVTransferEndpointInfo):\n        self.links[remote_engine_id].connect(connect_request)\n\n    async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):\n        await self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op)\n\n    def store(self, assignment: MigrationAssignment, async_op: bool = False):\n        raise NotImplementedError\n\n    def load(self, assignment: MigrationAssignment, async_op: bool = False):\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/config.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport enum\nfrom typing import Optional\n\nfrom pydantic import BaseModel\n\n\nclass ServingStrategy(enum.Enum):\n    \"\"\"Serving Strategy.\n\n    Attributes:\n        Hybrid: Prefill and Decode workloads are co-located in one engine.\n        DistServe: Prefill and Decode workloads are assigned to different\n            engines. After the prefill phase finishes in the Prefill Engine,\n            the KV cache is migrated from the Prefill Engine to the Decode Engine.\n    \"\"\"\n\n    Hybrid = enum.auto()\n    DistServe = enum.auto()\n\n\nclass EngineRole(enum.Enum):\n    \"\"\"Role of Engine.\n\n    Note: In the implementation of LMDeploy-DistServe, every engine is\n        technically a hybrid engine; its role depends on the kind of requests\n        sent to it. However, for implementation reasons, the role still needs\n        to be specified when starting the engine server:\n            1. To make sure the engine can be correctly discovered by the proxy.\n            2. Because ModelInputs are created differently for hybrid, prefill\n                and decode engines in the DP Engine (DSV3 DP + EP).\n    \"\"\"\n\n    Hybrid = enum.auto()\n    Prefill = enum.auto()\n    Decode = enum.auto()\n\n\nclass MigrationBackend(enum.Enum):\n    \"\"\"Migration Backend.\"\"\"\n\n    DLSlime = enum.auto()\n    Mooncake = enum.auto()\n\n\nclass RDMALinkType(enum.Enum):\n    \"\"\"RDMA Link Type.\"\"\"\n\n    IB = enum.auto()\n    RoCE = enum.auto()\n\n\nclass DistServeRDMAConfig(BaseModel):\n    \"\"\"DistServe RDMA Config.\n\n    Args:\n        with_gdr: defaults to True.\n        link_type: defaults to `RDMALinkType.RoCE`.\n\n    Warning: Only GDR is supported for now.\n    Warning: Technically, both RoCE and IB are supported.\n        However, IB mode is untested because no IB testing\n        environment was available.\n    \"\"\"\n\n    # RDMA with GPUDirect RDMA (GDR)\n    with_gdr: bool = True\n    link_type: RDMALinkType = RDMALinkType.RoCE\n\n\nclass DistServeTCPConfig(BaseModel):\n    \"\"\"TODO: Add TCP Protocol\"\"\"\n\n\nclass DistServeNVLinkConfig(BaseModel):\n    \"\"\"TODO: Add NVLink Protocol\"\"\"\n\n\nclass DistServeEngineConfig(BaseModel):\n    \"\"\"DistServe Engine Config.\n\n    In disaggregated LLM serving, we need the engine info of each\n    PD peer for the following reasons:\n        1. Cache: the stride of a cache block, needed for correct offsets in KV transfer.\n        2. Parallel: Prefill and decode may use different parallel strategies to\n            achieve high SLO attainment or high throughput. In this situation,\n            we need to calculate which prefill-decode worker peers need to connect.\n            For example, if the prefill worker uses pp4 and the decode worker uses tp2pp2,\n            the prefill-decode worker connection peers are (0, 0), (0, 1), (1, 0), (1, 1),\n            (2, 2), (2, 3), (3, 2), (3, 3). By contrast, with (tp4, tp4), the\n            prefill-decode worker connection peers are (0, 0), (1, 1), (2, 2), (3, 3).\n    \"\"\"\n\n    # parallel config\n    # (dp, pp, tp, ep)\n    tp_size: int\n    ep_size: int\n    dp_size: int\n    pp_size: Optional[int]\n\n    # Rank of DP\n    dp_rank: int\n\n    # cache config\n    block_size: int\n    num_cpu_blocks: int\n    num_gpu_blocks: int\n\n\nclass MooncakeEngineConfig(DistServeEngineConfig):\n    \"\"\"Mooncake Transfer Engine Config.\n\n    TODO: Support more specific config for Mooncake.\n    \"\"\"\n    pass\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/conn/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/conn/engine_conn.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport os\nfrom typing import TYPE_CHECKING, Dict, List\nfrom urllib.parse import urlparse\n\nimport zmq\nimport zmq.asyncio\n\nfrom lmdeploy.logger import get_logger\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,\n                                                   DistServeConnectionResponse, DistServeConnectionStatus,\n                                                   DistServeDropConnectionRequest, DistServeEngineEndpointInfo,\n                                                   DistServeInitRequest, DistServeInitResponse,\n                                                   DistServeKVTransferEndpointInfo)\nfrom lmdeploy.pytorch.engine.executor.dist_utils import find_available_port\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.engine.engine import Engine\n\nlogger = get_logger('lmdeploy')\n\n\nclass EngineP2PConnection:\n\n    def __init__(self, engine: 'Engine'):\n        self.engine: Engine = engine\n        self.p2p_conn_ctx: Dict[str, zmq.asyncio.Context] = {}\n        self.p2p_sender: Dict[str, zmq.asyncio.Socket] = {}\n        self.p2p_receiver: Dict[str, zmq.asyncio.Socket] = {}\n\n        self.use_unique_kvtransfer_engine = os.environ.get('LMDEPLOY_USE_UNIQUE_KVTRANSFER_ENGINE', False)\n\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        ctx = zmq.asyncio.Context(2)\n        sender = ctx.socket(zmq.PUSH)\n        sender_port = find_available_port()\n        sender_hostname = urlparse(init_request.local_engine_id).hostname\n        zmq_address = f'tcp://{sender_hostname}:{sender_port}'\n        sender.bind(zmq_address)\n        receiver = ctx.socket(zmq.PULL)\n\n        self.p2p_conn_ctx[init_request.remote_engine_id] = ctx\n        self.p2p_sender[init_request.remote_engine_id] = sender\n        self.p2p_receiver[init_request.remote_engine_id] = receiver\n\n        
kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo] = self.engine.executor.p2p_initialize(\n            init_request)\n\n        return DistServeInitResponse(engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_address),\n                                     kvtransfer_endpoint_info=kvtransfer_endpoint_info,\n                                     status=DistServeConnectionStatus.SUCCESS)\n\n    def p2p_connect(self, conn_request: DistServeConnectionRequest):\n        self.p2p_receiver[conn_request.remote_engine_id].connect(conn_request.remote_engine_endpoint_info.zmq_address)\n        self.engine.executor.p2p_connect(remote_engine_id=conn_request.remote_engine_id,\n                                         conn_request=conn_request.remote_kvtransfer_endpoint_info)\n        event_loop = asyncio.get_event_loop()\n        event_loop.create_task(self.handle_zmq_recv(conn_request.remote_engine_id))\n        return DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS)\n\n    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):\n        # TODO (JimyMa): drop RDMA Connection\n        self.zmq_disconnect(drop_conn_request.remote_engine_id)\n        return {'success': True}\n\n    async def zmq_send(self, remote_engine_id: str, remote_session_id: int):\n        await self.p2p_sender[remote_engine_id].send_pyobj(\n            DistServeCacheFreeRequest(remote_engine_id=remote_engine_id, remote_session_id=remote_session_id))\n\n    async def handle_zmq_recv(self, remote_engine_id: str):\n        while True:\n            req: DistServeCacheFreeRequest = await self.p2p_receiver[remote_engine_id].recv_pyobj()\n            if isinstance(req, DistServeCacheFreeRequest):\n                session_id = req.remote_session_id\n                if session_id in self.engine.scheduler.sessions:\n                    self.engine.scheduler.end_session(session_id=session_id)\n                else:\n                    
logger.error(f'invalid free, {remote_engine_id}, {session_id}')\n            else:\n                raise ValueError(f'Unsupported zmq request {type(req)}')\n\n    def zmq_disconnect(self, remote_engine_id: str):\n        # synchronous on purpose: p2p_drop_connect calls it without awaiting\n        self.p2p_receiver[remote_engine_id].close()\n        self.p2p_sender[remote_engine_id].close()\n        self.p2p_conn_ctx[remote_engine_id].term()\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/conn/protocol.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport enum\nfrom typing import List, Optional\n\nfrom pydantic import BaseModel\n\nfrom lmdeploy.pytorch.disagg.config import (DistServeEngineConfig, DistServeNVLinkConfig, DistServeRDMAConfig,\n                                            DistServeTCPConfig)\n\n\nclass MigrationProtocol(enum.Enum):\n    \"\"\"Migration Transport Protocol.\n\n    Attributes:\n        TCP: TCP/IP.\n        RDMA: IB or RoCEv1/v2.\n        NVLINK: High-bandwidth device-to-device link.\n\n    Warning: Currently, only GPUDirect RDMA is supported in DistServe.\n        The other protocols are reserved and will be implemented in the future.\n    \"\"\"\n\n    TCP = enum.auto()\n    RDMA = enum.auto()\n    NVLINK = enum.auto()\n\n\nclass DistServeConnectionStatus(enum.Enum):\n    # TODO(JimyMa): Add more connection failure handlers\n    SUCCESS = enum.auto()\n    FAIL = enum.auto()\n\n\nclass DistServeInitRequest(BaseModel):\n    local_engine_id: str\n    local_engine_config: DistServeEngineConfig\n\n    remote_engine_id: str\n    remote_engine_config: DistServeEngineConfig\n\n    protocol: MigrationProtocol\n\n    rank: Optional[int] = None\n\n    tcp_config: Optional[DistServeTCPConfig] = None\n    rdma_config: Optional[DistServeRDMAConfig] = None\n    nvlink_config: Optional[DistServeNVLinkConfig] = None\n\n\nclass DistServeEngineEndpointInfo(BaseModel):\n    zmq_address: str\n\n\nclass DistServeKVTransferEndpointInfo(BaseModel):\n    protocol: MigrationProtocol\n    endpoint_info: str\n\n\nclass DistServeInitResponse(BaseModel):\n    status: DistServeConnectionStatus\n    # the control plane initialization feedback\n    engine_endpoint_info: DistServeEngineEndpointInfo\n    # the KVCache Transfer initialization feedback\n    # To ensure generality (where endpoint_info can be initialization information\n    # for different media such as RDMA, NVLink, etc.), we use a string (str) to\n    # store this information.\n    kvtransfer_endpoint_info: 
List[DistServeKVTransferEndpointInfo]\n\n\nclass DistServeConnectionRequest(BaseModel):\n    protocol: MigrationProtocol\n    remote_engine_id: str\n    remote_engine_endpoint_info: DistServeEngineEndpointInfo\n    remote_kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo]\n\n\nclass DistServeConnectionResponse(BaseModel):\n    status: DistServeConnectionStatus\n\n\nclass MigrationRequest(BaseModel):\n    protocol: MigrationProtocol\n\n    remote_engine_id: str\n    remote_session_id: int\n    remote_token_id: int\n    remote_block_ids: List[int]\n\n    is_dummy_prefill: bool = False\n\n\nclass DistServeCacheFreeRequest(BaseModel):\n    remote_engine_id: str\n    remote_session_id: int\n\n\nclass DistServeDropConnectionRequest(BaseModel):\n    engine_id: str\n    remote_engine_id: str\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/conn/proxy_conn.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport enum\nimport os\nfrom collections import defaultdict\nfrom typing import Dict, Set, Tuple\n\nimport aiohttp\nimport requests\n\nfrom lmdeploy.logger import get_logger\nfrom lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,\n                                                   DistServeConnectionResponse, DistServeDropConnectionRequest,\n                                                   DistServeInitRequest, DistServeInitResponse)\nfrom lmdeploy.pytorch.disagg.messages import PDConnectionMessage\n\nlogger = get_logger('lmdeploy')\n\n# aiohttp.ClientTimeout expects a number of seconds, so parse the env var; unset means no total timeout.\nAIOHTTP_TIMEOUT = float(os.environ['AIOHTTP_TIMEOUT']) if os.getenv('AIOHTTP_TIMEOUT') else None\n\n\nclass PDConnectionStatus(enum.Enum):\n    Disconnected = enum.auto()\n    Connected = enum.auto()\n    Connecting = enum.auto()\n\n\nclass PDConnectionState:\n    \"\"\"PDConnectionState.\"\"\"\n\n    def __init__(self, status: PDConnectionStatus, event: asyncio.Event):\n        self.status = status\n        self.event = event\n\n    async def wait(self):\n        await self.event.wait()\n\n    def set_status(self, status: PDConnectionStatus):\n        self.status = status\n\n\ndef get_server_api(url: str, api: str):\n    return f'{url}/{api}'\n\n\nclass PDConnectionPool:\n    \"\"\"Construct links between Prefill and Decode engines for KVCache\n    migration.\n\n    Note: KVCache migration uses peer-to-peer transport.\n    Note: Lazy link construction is supported: a connection is established\n        on the first LLM request. 
As a result, the PD communication group does not need to be\n        constructed when an engine server starts.\n    Note: simple fault tolerance is provided by checkpointing the session_id of\n        each request under migration; `gc` is triggered when the decode\n        instance crashes.\n    TODO (JimyMa): Currently, only engines with the same parallel configuration\n        can be correctly connected.\n    \"\"\"\n\n    # Maximum concurrent connections\n    CONN_SEMAPHORE_SIZE = 2048\n\n    def __init__(self):\n        # all prefill and decode instances\n        # TODO (JimyMa): Maybe encoding instances\n        self.prefill_endpoints: Set[str] = set()\n        self.decode_endpoints: Set[str] = set()\n\n        # Links of PD Connection.\n        self.pool: Dict[Tuple[str, str], PDConnectionState] = {}\n\n        # put migrating sessions in `self.migration_session_shelf` for fault tolerance:\n        # a finished session is popped from `self.migration_session_shelf`;\n        # if a decode instance is disconnected, gc all blocks of its sessions in the prefill instance.\n        self.migration_session_shelf: Dict[Tuple[str, str], Set[int]] = defaultdict(set)\n\n        # conn_perform handler queue\n        self.waiting_conn: asyncio.Queue[Tuple[PDConnectionMessage, asyncio.Event]] = (asyncio.Queue())\n\n        # conn Registry Lock\n        self.conn_lock = asyncio.Lock()\n\n        # maximum connection retries on failure\n        self.max_retry_cnt = 8\n\n        # signaled when a conn request arrives\n        self.conn_req_event = asyncio.Event()\n\n        # conn initialized signal\n        self.initialized = False\n\n    def reg_instance(self, role: EngineRole, endpoint: str):\n        if role == EngineRole.Prefill:\n            self.prefill_endpoints.add(endpoint)\n        elif role == EngineRole.Decode:\n            self.decode_endpoints.add(endpoint)\n        else:\n            raise ValueError(f'Unsupported role: {role}')\n\n    def dereg_instance(self, 
endpoint: str):\n        if endpoint in self.prefill_endpoints:\n            self.prefill_endpoints.remove(endpoint)\n        elif endpoint in self.decode_endpoints:\n            dropped_key = []\n            for conn_key in self.pool.keys():\n                if conn_key[1] == endpoint:\n                    dropped_key.append(conn_key)\n            for k in dropped_key:\n                self.drop(k)\n            # TODO(JimyMa): handle side-effect by kvcache migration\n            self.decode_endpoints.remove(endpoint)\n\n    def shelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int):\n        self.migration_session_shelf[conn_key].add(session_id)\n\n    def unshelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int):\n        self.migration_session_shelf[conn_key].remove(session_id)\n\n    async def connect(self, conn_req: PDConnectionMessage):\n\n        async def get_engine_config(server_endpoint):\n            async with self.conn_sem:\n                async with self.conn_sess.get(\n                        get_server_api(server_endpoint, 'distserve/engine_info'),\n                        timeout=self.aiotimeout,\n                ) as resp:\n                    result = await resp.json()\n                    return DistServeEngineConfig.model_validate_json(result)\n\n        async def p2p_initialize(server_endpoint, init_request: DistServeInitRequest) -> DistServeInitResponse:\n            async with self.conn_sem:\n                async with self.conn_sess.post(\n                        get_server_api(server_endpoint, 'distserve/p2p_initialize'),\n                        json=init_request.model_dump(mode='json'),\n                        timeout=self.aiotimeout,\n                ) as resp:\n                    result = await resp.json()\n                    return DistServeInitResponse.model_validate(result)\n\n        async def p2p_connect(server_endpoint, conn_request: DistServeConnectionRequest) -> 
DistServeConnectionResponse:\n            async with self.conn_sem:\n                async with self.conn_sess.post(\n                        get_server_api(server_endpoint, 'distserve/p2p_connect'),\n                        json=conn_request.model_dump(mode='json'),\n                        timeout=self.aiotimeout,\n                ) as resp:\n                    result = await resp.json()\n                    return DistServeConnectionResponse.model_validate(result)\n\n        async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event):\n            try:\n                link = (conn_req.p_url, conn_req.d_url)\n                logger.debug(f'{link} connecting...')\n                # Step 1. Get Remote Engine Configuration\n                prefill_engine_config = await get_engine_config(conn_req.p_url)\n                decode_engine_config = await get_engine_config(conn_req.d_url)\n\n                # Note: Only Same Parallel Configurations are supported by now\n                assert prefill_engine_config.tp_size == decode_engine_config.tp_size\n\n                # Step 2. 
Construct Initialize Configuration\n                prefill_init_req = DistServeInitRequest(\n                    protocol=conn_req.protocol,\n                    local_engine_id=conn_req.p_url,\n                    local_engine_config=prefill_engine_config,\n                    remote_engine_id=conn_req.d_url,\n                    remote_engine_config=decode_engine_config,\n                    rdma_config=conn_req.rdma_config,\n                    nvlink_config=conn_req.nvlink_config,\n                )\n                decode_init_req = DistServeInitRequest(\n                    protocol=conn_req.protocol,\n                    local_engine_id=conn_req.d_url,\n                    local_engine_config=decode_engine_config,\n                    remote_engine_id=conn_req.p_url,\n                    remote_engine_config=prefill_engine_config,\n                    rdma_config=conn_req.rdma_config,\n                    nvlink_config=conn_req.nvlink_config,\n                )\n\n                prefill_init_resp = await p2p_initialize(conn_req.p_url, prefill_init_req)\n                decode_init_resp = await p2p_initialize(conn_req.d_url, decode_init_req)\n\n                # Step 3. 
Connection\n                prefill_endpoint_conn_reqs = DistServeConnectionRequest(\n                    protocol=conn_req.protocol,\n                    remote_engine_id=conn_req.d_url,\n                    remote_engine_endpoint_info=decode_init_resp.engine_endpoint_info,\n                    remote_kvtransfer_endpoint_info=decode_init_resp.kvtransfer_endpoint_info)\n                decode_endpoint_conn_reqs = DistServeConnectionRequest(\n                    protocol=conn_req.protocol,\n                    remote_engine_id=conn_req.p_url,\n                    remote_engine_endpoint_info=prefill_init_resp.engine_endpoint_info,\n                    remote_kvtransfer_endpoint_info=prefill_init_resp.kvtransfer_endpoint_info)\n                await p2p_connect(conn_req.p_url, prefill_endpoint_conn_reqs)\n                await p2p_connect(conn_req.d_url, decode_endpoint_conn_reqs)\n                self.pool[link].set_status(PDConnectionStatus.Connected)\n                logger.debug(f'{(conn_req.p_url, conn_req.d_url)} connected')\n            except Exception as e:\n                self.pool[link].set_status(PDConnectionStatus.Disconnected)\n                logger.error(f'pd connection error: {e}')\n            conn_event.set()\n\n        async def wait_for_conn(conn_req: PDConnectionMessage, conn_event: asyncio.Event):\n            await self.pool[(conn_req.p_url, conn_req.d_url)].event.wait()\n            conn_event.set()\n\n        async def _perform_conn():\n            logger.debug('perform_conn start')\n            while True:\n                if self.waiting_conn.empty():\n                    await self.conn_req_event.wait()\n\n                self.conn_req_event.clear()\n\n                while not self.waiting_conn.empty():\n                    conn_req, conn_event = self.waiting_conn.get_nowait()\n                    link = (conn_req.p_url, conn_req.d_url)\n                    if link not in self.pool:\n                        self.pool[link] = 
PDConnectionState(\n                            PDConnectionStatus.Disconnected,\n                            conn_event,\n                        )\n                    if self.pool[link].status == PDConnectionStatus.Connecting:\n                        asyncio.create_task(wait_for_conn(conn_req, conn_event))\n                    elif self.pool[link].status == PDConnectionStatus.Disconnected:\n                        self.pool[link].set_status(PDConnectionStatus.Connecting)\n                        asyncio.create_task(conn_worker(conn_req, conn_event))\n\n        if not self.initialized:\n            loop = asyncio.get_event_loop()\n            loop.create_task(_perform_conn())\n            self.conn_sem = asyncio.Semaphore(self.CONN_SEMAPHORE_SIZE)\n            self.conn_sess = aiohttp.ClientSession(\n                connector=aiohttp.TCPConnector(limit_per_host=256),\n                timeout=aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT),\n            )\n            self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT)\n            self.initialized = True\n\n        self.reg_instance(EngineRole.Prefill, conn_req.p_url)\n        self.reg_instance(EngineRole.Decode, conn_req.d_url)\n\n        cnt = 0\n        while cnt < self.max_retry_cnt:\n            if self.is_connected(conn_req.p_url, conn_req.d_url):\n                return\n            if cnt > 0:\n                logger.warning(f'Connection failure, retry cnt: {cnt}')\n            conn_event = asyncio.Event()\n            self.waiting_conn.put_nowait((conn_req, conn_event))\n            self.conn_req_event.set()\n            await conn_event.wait()\n            cnt += 1\n        async with self.conn_lock:\n            self.pool[conn_req.p_url, conn_req.d_url].set_status(PDConnectionStatus.Disconnected)\n        raise TimeoutError('PDConnection Failure')\n\n    def is_connected(self, p_url: str, d_url: str):\n        link = self.pool.get((p_url, d_url), None)\n        if not link:\n            return 
False\n        return link.status == PDConnectionStatus.Connected\n\n    def drop(self, pd_key: Tuple[str, str]):\n        left = pd_key[0]\n        right = pd_key[1]\n\n        def cache_free(server_endpoint, cache_free_request: DistServeCacheFreeRequest) -> Dict:\n            try:\n                requests.post(get_server_api(server_endpoint, 'distserve/free_cache'),\n                              json=cache_free_request.model_dump(mode='json'))\n            except Exception as e:\n                logger.warning(f'error cache block free {server_endpoint, cache_free_request}. ErrorMsg: {str(e)}')\n\n        def drop_connect(server_endpoint: str, p2p_disconnect_request: DistServeDropConnectionRequest):\n            try:\n                requests.post(get_server_api(server_endpoint, 'distserve/p2p_drop_connect'),\n                              json=p2p_disconnect_request.model_dump(mode='json'))\n            except Exception as e:\n                logger.warning(f'error drop connect {server_endpoint, p2p_disconnect_request}. ErrorMsg: {str(e)}')\n\n        # trigger gc\n        logger.warning('cache block gc triggered.')\n        try:\n            for session_id in self.migration_session_shelf[(left, right)]:\n                cache_free(left, DistServeCacheFreeRequest(remote_engine_id=left, remote_session_id=session_id))\n        except Exception as e:\n            logger.warning(f'gc error, ErrorMsg: {str(e)}')\n\n        # trigger p2p disconnect\n        logger.warning('drop connection triggered.')\n        try:\n            drop_connect(left, DistServeDropConnectionRequest(engine_id=left, remote_engine_id=right))\n            drop_connect(right, DistServeDropConnectionRequest(engine_id=right, remote_engine_id=left))\n        except Exception as e:\n            logger.warning(f'p2p disconnect error, ErrorMsg: {str(e)}')\n\n        self.pool.pop((left, right), None)\n"
  },
  {
    "path": "lmdeploy/pytorch/disagg/messages.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List, Optional, Tuple\n\nfrom pydantic import BaseModel\n\nfrom lmdeploy.pytorch.disagg.config import DistServeNVLinkConfig, DistServeRDMAConfig, DistServeTCPConfig\nfrom lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol\n\n\nclass MigrationExecutionBatch(BaseModel):\n    \"\"\"Input of the Migration.\"\"\"\n\n    protocol: MigrationProtocol\n    requests: List[Tuple[str, List[Tuple[int, int]]]] = []\n\n\nclass AssignmentInstruct(BaseModel):\n    \"\"\"Assignment Batch.\"\"\"\n    mr_key: int\n    target_offset: int\n    source_offset: int\n    length: int\n\n\nclass MigrationAssignment(BaseModel):\n    \"\"\"Migration Assignment.\"\"\"\n    protocol: MigrationProtocol\n    remote_engine_id: str\n    batch: List[AssignmentInstruct]\n\n\nclass PDConnectionMessage(BaseModel):\n    p_url: str\n    d_url: str\n    protocol: MigrationProtocol = MigrationProtocol.RDMA\n    tcp_config: Optional[DistServeTCPConfig] = None\n    rdma_config: Optional[DistServeRDMAConfig] = None\n    nvlink_config: Optional[DistServeNVLinkConfig] = None\n\n\nclass DistServeRegisterMRMessage(BaseModel):\n    protocol: MigrationProtocol\n\n    remote_engine_id: str\n    mr_key: int\n    addr: int\n    offset: int\n    length: int\n"
  },
  {
    "path": "lmdeploy/pytorch/distributed.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass\nfrom datetime import timedelta\nfrom typing import List, Optional\n\nimport torch\nfrom torch import distributed as dist\nfrom torch.distributed import ProcessGroup, ReduceOp, Work  # noqa: F401\n\nfrom lmdeploy.pytorch.utils import CtxMgrBase, singleton\n\nfrom .config import DistConfig, TPMode\n\n\n@dataclass\nclass DistGroup:\n    \"\"\"Distributed group.\"\"\"\n    rank: int = 0\n    cpu_group: dist.ProcessGroup = None\n    gpu_group: dist.ProcessGroup = None\n    cpu_groups: List[dist.ProcessGroup] = None\n    gpu_groups: List[dist.ProcessGroup] = None\n    gpu_gather_group: dist.ProcessGroup = None\n\n    def close(self):\n        \"\"\"Close groups.\"\"\"\n        if not dist.is_initialized():\n            return\n        if self.cpu_groups is not None:\n            for group in self.cpu_groups:\n                dist.destroy_process_group(group)\n            self.cpu_groups = None\n        if self.gpu_groups is not None:\n            for group in self.gpu_groups:\n                dist.destroy_process_group(group)\n            self.gpu_groups = None\n\n\ndef _build_tp_group_impl(tp: int,\n                         rank: int,\n                         world_size: int,\n                         timeout: timedelta,\n                         cpu_backend: str = 'gloo',\n                         ccl_backend: str = 'nccl',\n                         attn_tp: int = 1,\n                         tp_mode: TPMode = TPMode.DEFAULT):\n    \"\"\"Build tp group.\"\"\"\n    assert tp > 1\n    tp_rank = rank % tp\n    tp_group_id = rank // tp\n    gather_group_id = (rank - tp_group_id * tp) % attn_tp\n    ranks = range(world_size)\n    tp_gpu_groups = []\n    tp_cpu_groups = []\n    gather_groups = []\n    for start in range(0, world_size, tp):\n        tp_ranks = ranks[start:start + tp]\n        group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=ccl_backend)\n        
tp_gpu_groups.append(group)\n        cpu_group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=cpu_backend)\n        tp_cpu_groups.append(cpu_group)\n\n        # create gather group\n        if tp_mode == TPMode.DP_TP and attn_tp != tp:\n            for g_start in range(start, start + attn_tp):\n                g_ranks = ranks[g_start:(g_start + tp):attn_tp]\n                gather_group = dist.new_group(ranks=g_ranks, timeout=timeout, backend=ccl_backend)\n                gather_groups.append(gather_group)\n    tp_gpu_group = tp_gpu_groups[tp_group_id]\n    tp_cpu_group = tp_cpu_groups[tp_group_id]\n\n    if tp_mode == TPMode.DP_TP:\n        if attn_tp == tp:\n            gather_group = tp_gpu_group\n        else:\n            gather_group = gather_groups[gather_group_id]\n    else:\n        gather_group = None\n    return DistGroup(\n        rank=tp_rank,\n        cpu_group=tp_cpu_group,\n        gpu_group=tp_gpu_group,\n        cpu_groups=tp_cpu_groups,\n        gpu_groups=tp_gpu_groups,\n        gpu_gather_group=gather_group,\n    )\n\n\ndef _build_attn_tp_group(context: 'DistContext',\n                         timeout: timedelta,\n                         cpu_backend: str = 'gloo',\n                         ccl_backend: str = 'nccl'):\n    \"\"\"Build attention tp group.\"\"\"\n    dist_config = context.dist_config\n    tp = dist_config.attn_tp\n    # skip if tp == 1\n    if tp == 1:\n        context.attn_tp_group = DistGroup(rank=0)\n        return\n\n    dist_group = _build_tp_group_impl(\n        tp,\n        context.rank,\n        dist_config.world_size,\n        timeout=timeout,\n        cpu_backend=cpu_backend,\n        ccl_backend=ccl_backend,\n        attn_tp=tp,\n        tp_mode=TPMode.DEFAULT,\n    )\n    context.attn_tp_group = dist_group\n\n\ndef _build_mlp_tp_group(context: 'DistContext',\n                        timeout: timedelta,\n                        cpu_backend: str = 'gloo',\n                        ccl_backend: str = 'nccl'):\n 
   \"\"\"Build mlp tp group.\"\"\"\n    dist_config = context.dist_config\n    tp = dist_config.mlp_tp\n    # skip if tp == 1\n    if tp == 1:\n        context.mlp_tp_group = DistGroup(rank=0)\n        return\n\n    # reuse attn tp group\n    if tp == dist_config.attn_tp:\n        context.mlp_tp_group = context.attn_tp_group\n        return\n\n    dist_group = _build_tp_group_impl(\n        tp,\n        context.rank,\n        dist_config.world_size,\n        timeout=timeout,\n        cpu_backend=cpu_backend,\n        ccl_backend=ccl_backend,\n        attn_tp=dist_config.attn_tp,\n        tp_mode=dist_config.mlp_tp_mode,\n    )\n    context.mlp_tp_group = dist_group\n\n\ndef _build_moe_tp_group(context: 'DistContext',\n                        timeout: timedelta,\n                        cpu_backend: str = 'gloo',\n                        ccl_backend: str = 'nccl'):\n    \"\"\"Build moe tp group.\"\"\"\n    dist_config = context.dist_config\n    tp = dist_config.moe_tp\n    # skip if tp == 1\n    if tp == 1:\n        context.moe_tp_group = DistGroup(rank=0)\n        return\n\n    # reuse attn tp group\n    if tp == dist_config.attn_tp:\n        context.moe_tp_group = context.attn_tp_group\n        return\n\n    # reuse mlp tp group\n    if tp == dist_config.mlp_tp:\n        context.moe_tp_group = context.mlp_tp_group\n        return\n\n    dist_group = _build_tp_group_impl(\n        tp,\n        context.rank,\n        dist_config.world_size,\n        timeout=timeout,\n        cpu_backend=cpu_backend,\n        ccl_backend=ccl_backend,\n        attn_tp=dist_config.attn_tp,\n        tp_mode=dist_config.moe_tp_mode,\n    )\n    context.moe_tp_group = dist_group\n\n\ndef _build_tp_group(context: 'DistContext', timeout: timedelta, cpu_backend: str = 'gloo', ccl_backend: str = 'nccl'):\n    \"\"\"Build tp group.\"\"\"\n    _build_attn_tp_group(context, timeout, cpu_backend, ccl_backend)\n    _build_mlp_tp_group(context, timeout, cpu_backend, ccl_backend)\n    
_build_moe_tp_group(context, timeout, cpu_backend, ccl_backend)\n    context.tp_group = context.attn_tp_group\n\n\n@dataclass\nclass DistContext:\n    rank: int = 0\n    dp_rank: int = 0\n    ep_rank: int = 0\n\n    tp_group: DistGroup = None\n    attn_tp_group: DistGroup = None\n    mlp_tp_group: DistGroup = None\n    moe_tp_group: DistGroup = None\n\n    cpu_group: dist.ProcessGroup = None\n    ep_gpu_group: dist.ProcessGroup = None\n    ep_gpu_groups: List[dist.ProcessGroup] = None\n    dist_config: DistConfig = None\n\n    @classmethod\n    def _build_ep_group(cls, context: 'DistContext', timeout: timedelta, ccl_backend: str = 'nccl'):\n        \"\"\"Build ep group.\"\"\"\n        dist_config = context.dist_config\n        ep = dist_config.ep\n        if ep <= 1:\n            return\n\n        dp_rank = context.dp_rank\n        world_size = dist_config.world_size\n        ep_rank = context.rank % ep\n        ep_group_id = dp_rank // ep\n        ranks = range(world_size)\n        ep_gpu_groups = []\n        for start in range(0, world_size, ep):\n            ep_ranks = ranks[start:start + ep]\n            group = dist.new_group(ranks=ep_ranks, timeout=timeout, backend=ccl_backend)\n            ep_gpu_groups.append(group)\n        ep_gpu_group = ep_gpu_groups[ep_group_id]\n\n        context.ep_rank = ep_rank\n        context.ep_gpu_group = ep_gpu_group\n        context.ep_gpu_groups = ep_gpu_groups\n\n    @classmethod\n    def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = 'nccl'):\n        \"\"\"Build dist context.\"\"\"\n        timeout = timedelta(days=35600)\n        cpu_backend = 'gloo'\n\n        if dist_config is None:\n            dist_config = DistConfig()\n\n        dp_rank = dist_config.dp_rank\n        world_size = dist_config.world_size\n        context = DistContext(rank=rank,\n                              dp_rank=dp_rank,\n                              dist_config=dist_config,\n                              
attn_tp_group=DistGroup(rank=0),\n                              mlp_tp_group=DistGroup(rank=0),\n                              moe_tp_group=DistGroup(rank=0),\n                              tp_group=DistGroup(rank=0))\n        if world_size == 1:\n            return context\n\n        assert dist.is_initialized()\n\n        # cpu group\n        context.cpu_group = dist.new_group(ranks=list(range(world_size)), timeout=timeout, backend=cpu_backend)\n\n        # tp\n        _build_tp_group(context, timeout, cpu_backend=cpu_backend, ccl_backend=ccl_backend)\n\n        # ep\n        cls._build_ep_group(context, timeout, ccl_backend=ccl_backend)\n\n        return context\n\n    def close(self):\n        \"\"\"Close groups.\"\"\"\n        if not dist.is_initialized():\n            return\n        if self.attn_tp_group is not None:\n            self.attn_tp_group.close()\n        if self.mlp_tp_group is not None:\n            self.mlp_tp_group.close()\n        if self.moe_tp_group is not None:\n            self.moe_tp_group.close()\n        if self.ep_gpu_groups is not None:\n            for group in self.ep_gpu_groups:\n                dist.destroy_process_group(group)\n            self.ep_gpu_groups = None\n\n\nDefaultContext = DistContext.build()\n\n\n@singleton\nclass DistManager(CtxMgrBase[DistContext]):\n    \"\"\"Distributed context manager.\"\"\"\n\n    def __init__(self):\n        super().__init__(DefaultContext)\n\n    def current_config(self) -> DistConfig:\n        \"\"\"Get current dist config.\"\"\"\n        return self.current_context().dist_config\n\n\ndef get_dist_manager():\n    \"\"\"Get device manager.\"\"\"\n    return DistManager()\n\n\ndef get_world_rank():\n    \"\"\"Get distributed world size and rank.\"\"\"\n    ctx = get_dist_manager().current_context()\n    world_size = ctx.dist_config.world_size\n    rank = ctx.rank\n\n    return world_size, rank\n\n\ndef get_tp_world_rank(layer_type: Optional[str] = None):\n    ctx = 
get_dist_manager().current_context()\n    if layer_type is None:\n        return ctx.dist_config.tp, ctx.tp_group.rank\n    elif layer_type == 'attn':\n        return ctx.dist_config.attn_tp, ctx.attn_tp_group.rank\n    elif layer_type == 'mlp':\n        return ctx.dist_config.mlp_tp, ctx.mlp_tp_group.rank\n    elif layer_type == 'moe':\n        return ctx.dist_config.moe_tp, ctx.moe_tp_group.rank\n    else:\n        raise RuntimeError(f'Unknown layer type: {layer_type}')\n\n\ndef get_dp_world_rank():\n    ctx = get_dist_manager().current_context()\n    return ctx.dist_config.dp, ctx.dp_rank\n\n\ndef get_ep_world_rank():\n    ctx = get_dist_manager().current_context()\n    return ctx.dist_config.ep, ctx.ep_rank\n\n\ndef _check_group_device(device: str):\n    \"\"\"Check group device.\"\"\"\n    assert (device in ['cpu', 'gpu']), ('Expect process group device in (\"cpu\", \"gpu\"), '\n                                        f'but get {device}.')\n\n\ndef get_process_group(device: str = None):\n    \"\"\"Get process group.\"\"\"\n    return dist.GroupMember.WORLD\n\n\ndef get_dist_group(layer_type: str = 'attn'):\n    \"\"\"Get dist group.\"\"\"\n    ctx = get_dist_manager().current_context()\n    if layer_type == 'attn':\n        tp_group = ctx.attn_tp_group\n    elif layer_type == 'mlp':\n        tp_group = ctx.mlp_tp_group\n    elif layer_type == 'moe':\n        tp_group = ctx.moe_tp_group\n    else:\n        raise RuntimeError(f'Unknown layer type: {layer_type}')\n    return tp_group\n\n\ndef get_tp_group(device: str = 'gpu', layer_type: str = 'attn'):\n    \"\"\"Get tp group.\"\"\"\n    _check_group_device(device)\n    tp_group = get_dist_group(layer_type)\n\n    if tp_group is None:\n        return None\n\n    if device == 'cpu':\n        return tp_group.cpu_group\n    else:\n        return tp_group.gpu_group\n\n\ndef get_group(group_type: str, device: str):\n    \"\"\"Get group.\"\"\"\n    if group_type == 'tp':\n        return get_tp_group(device)\n    elif 
group_type in ['world', 'all']:\n        return get_process_group(device)\n    else:\n        raise RuntimeError(f'Unknown group type: {group_type}')\n\n\ndef all_reduce(tensor, op=ReduceOp.SUM, group='tp', async_op=False):\n    \"\"\"All reduce.\"\"\"\n    if isinstance(group, str):\n        group = get_group(group, 'gpu')\n    return dist.all_reduce(tensor, op, group, async_op)\n\n\ndef broadcast(tensor, src, group='tp', async_op=False):\n    \"\"\"broadcast.\"\"\"\n    if isinstance(group, str):\n        group = get_group(group, 'gpu')\n    return dist.broadcast(tensor, src, group, async_op)\n\n\ndef all_gather_object(object_list, obj, group='tp'):\n    if isinstance(group, str):\n        group = get_group(group, 'cpu')\n    return dist.all_gather_object(object_list, obj, group=group)\n\n\ndef all_gather(tensor_list, tensor, group='tp', async_op=False):\n    if isinstance(group, str):\n        group = get_group(group, 'gpu')\n    return dist.all_gather(tensor_list, tensor, group=group, async_op=async_op)\n\n\ndef all_gather_into_tensor(output_tensor, input_tensor, group='tp', async_op=False):\n    if isinstance(group, str):\n        group = get_group(group, 'gpu')\n    return dist.all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op)\n\n\ndef reduce_scatter(output, input_list, op=ReduceOp.SUM, group='tp', async_op=False):\n    \"\"\"Reduce scatter.\"\"\"\n    if isinstance(group, str):\n        group = get_group(group, 'gpu')\n    return dist.reduce_scatter(output, input_list, op=op, group=group, async_op=async_op)\n\n\ndef gather_by_tp_sizes(x: torch.Tensor,\n                       tp_sizes: List[int],\n                       group: Optional[dist.ProcessGroup] = None,\n                       async_op: bool = False):\n    \"\"\"Gather input.\"\"\"\n    assert all(size >= 0 for size in tp_sizes), f'Invalid tp sizes: {tp_sizes}'\n    shape = (*x.shape[:-2], sum(tp_sizes), *x.shape[-1:])\n    new_x = x.new_empty(shape)\n    split_new_x 
= list(new_x.split(tp_sizes, -2))\n    handle = dist.all_gather(split_new_x, x, group=group, async_op=async_op)\n    if async_op:\n        return new_x, handle\n    return new_x\n\n\ndef reduce_scatter_by_tp_sizes(out: torch.Tensor, rank: int, tp_sizes: List[int], group: dist.ProcessGroup):\n    \"\"\"Reduce scatter.\"\"\"\n    attn_tp = get_dist_manager().current_config().attn_tp\n    outs = list(out.split(tp_sizes, -2))\n    outs = [item for item in outs for _ in range(attn_tp)]\n    out = outs[rank]\n    dist.reduce_scatter(out, outs, group=group)\n    return out\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .engine import Engine\nfrom .engine_instance import EngineInstance\n\n__all__ = ['Engine', 'EngineInstance']\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,\n                                                   DistServeInitRequest)\n\n\nclass EngineBase:\n\n    def close(self) -> None:\n        \"\"\"Close mp engine.\"\"\"\n        raise NotImplementedError('This method is not implemented.')\n\n    def start_loop(self) -> None:\n        \"\"\"Start mp engine loop.\"\"\"\n\n    def end_session(self, session_id: int):\n        \"\"\"End session.\"\"\"\n        raise NotImplementedError('This method is not implemented.')\n\n    def p2p_initialize(self, conn_request: DistServeInitRequest):\n        \"\"\"Init rdma link.\"\"\"\n        raise NotImplementedError('This method is not implemented.')\n\n    def p2p_connect(self, conn_request: DistServeConnectionRequest):\n        \"\"\"rdma_connect.\"\"\"\n        raise NotImplementedError('This method is not implemented.')\n\n    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):\n        \"\"\"Drop connection.\n\n        1. drop engine connection (zmq connection)\n        2. 
TODO(JimyMa) drop RDMA Connection.\n        \"\"\"\n        raise NotImplementedError('This method is not implemented.')\n\n    def create_instance(self, cuda_stream_id=0):\n        \"\"\"Create instance.\"\"\"\n        raise NotImplementedError('This method is not implemented.')\n\n\nclass EngineInstanceBase:\n\n    async def async_end(self, session_id: int):\n        \"\"\"End the given session.\"\"\"\n        raise NotImplementedError('This method is not implemented.')\n\n    async def async_cancel(self, session_id: int):\n        \"\"\"Stop current streaming inference.\"\"\"\n        raise NotImplementedError('This method is not implemented.')\n\n    async def async_stream_infer(self, *args, **kwargs):\n        \"\"\"Send stream inference request.\"\"\"\n        raise NotImplementedError('This method is not implemented.')\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/cache_engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modify from: https://github.com/vllm-project/vllm\nimport json\nimport math\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Literal, Optional, Sequence, Tuple\n\nimport torch\n\nfrom lmdeploy.pytorch.backends import get_backend\nfrom lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS\nfrom lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl\nfrom lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo\nfrom lmdeploy.pytorch.disagg.messages import (AssignmentInstruct, DistServeRegisterMRMessage, MigrationAssignment,\n                                              MigrationExecutionBatch)\nfrom lmdeploy.utils import get_logger\n\nfrom ..config import CacheConfig, ModelConfig\n\nKVCache = Tuple[torch.Tensor, torch.Tensor]\n\nlogger = get_logger('lmdeploy')\n\n\ndef round_up(x: int, alignment: int) -> int:\n    \"\"\"Round up x to the nearest multiple of alignment.\"\"\"\n    return ((x + alignment - 1) // alignment) * alignment\n\n\n@dataclass\nclass CacheDesc:\n    \"\"\"Cache description.\"\"\"\n    shape: List[int]\n    dtype: torch.dtype\n    alignment: int = 256\n\n    def __post_init__(self):\n        self.numel = math.prod(self.shape)\n        self.size = self.numel * self.dtype.itemsize\n        self.aligned_size = round_up(self.size, self.alignment)\n\n\ndef _get_kv_cache_dtype(model_config: ModelConfig):\n    kv_cache_dtype = model_config.dtype\n    if model_config.use_mla_fp8_cache:\n        kv_cache_dtype = torch.float8_e4m3fn\n    return kv_cache_dtype\n\n\n# 512*1 + 4*4 + 64*2 = 656\nMLA_FP8_HEAD_DIM = 656\n\n\nclass CacheEngine:\n    \"\"\"Host and Device memory maintainer.\n\n    Args:\n        cache_config (CacheConfig): config of the cache information.\n        model_config (ModelConfig): config of the model.\n        rank (int): distribution rank, 0 on non-distributed environment.\n        
world_size (int): distribution world size, 1 on non-distributed\n            environment.\n        cache_stream (torch.cuda.Stream): the stream used for cache engine swap,\n            if set to None, it's created in CacheEngine.\n    \"\"\"\n\n    def __init__(\n        self,\n        cache_config: CacheConfig,\n        model_config: ModelConfig,\n        rank: int = 0,\n        tp_rank: int = 0,\n        world_size: int = 1,\n        cache_stream: torch.cuda.Stream = None,\n    ) -> None:\n        self.world_size = world_size\n        self.rank = rank\n        self.tp_rank = tp_rank\n        self.cache_config = cache_config\n        self.model_config = model_config\n\n        self.block_size = cache_config.block_size\n        self.num_layers = model_config.num_layers\n        self.kv_cache_dtype = _get_kv_cache_dtype(self.model_config)\n\n        if self.model_config.use_mla_fp8_cache:\n            cache_config.quant_policy = 0\n\n        if cache_config.quant_policy > 0:\n            if self.cache_config.device_type in ['cuda']:\n                self.kv_cache_dtype = torch.uint8\n            elif self.cache_config.device_type in ['ascend', 'npu']:\n                self.kv_cache_dtype = torch.int8\n            else:\n                raise ValueError(f'unsupported device_type {self.cache_config.device_type}')\n\n        # Initialize the cache.\n        self.local_gpu_cache = self.allocate_gpu_cache()\n        self.local_cpu_cache = self.allocate_cpu_cache()\n\n        self.migration_backend_impl: Optional[MigrationBackendImpl] = None\n\n        # Initialize the stream for caching operations.\n        self.cache_stream = cache_stream or torch.cuda.Stream()\n        assert self.cache_stream != torch.cuda.current_stream()\n        # Initialize the events for stream synchronization.\n        self.events = torch.cuda.Event()\n\n        logger.debug(f'Initialize cache engine with {cache_config.num_gpu_blocks}'\n                     f' gpu blocks and 
{cache_config.num_cpu_blocks} cpu blocks.')\n\n    @property\n    def cpu_cache(self):\n        \"\"\"Cpu cache.\"\"\"\n        return self.local_cpu_cache\n\n    @property\n    def gpu_cache(self):\n        \"\"\"Gpu cache.\"\"\"\n        return self.local_gpu_cache\n\n    @property\n    def num_gpu_blocks(self):\n        \"\"\"Num gpu blocks.\"\"\"\n        return self.cache_config.num_gpu_blocks\n\n    @property\n    def num_cpu_blocks(self):\n        \"\"\"Num cpu blocks.\"\"\"\n        return self.cache_config.num_cpu_blocks\n\n    @classmethod\n    def _get_key_block_shape_impl(cls,\n                                  model_config: ModelConfig,\n                                  block_size: int,\n                                  head_size: int,\n                                  world_size: int = 1,\n                                  quant_policy: Literal[0, 4, 8] = 0):\n        \"\"\"Get single block shape.\"\"\"\n        attn_backend = get_backend()\n        dtype = model_config.dtype\n        num_heads = model_config.num_key_value_heads\n\n        # split heads by tp\n        assert num_heads % world_size == 0, \\\n            f'num_heads: {num_heads}, world_size: {world_size}'\n        num_heads = num_heads // world_size\n\n        # patch for flash mla\n        if model_config.use_mla_fp8_cache:\n            return (block_size, num_heads, MLA_FP8_HEAD_DIM)\n\n        if quant_policy == 4:  # pack head_dim to uint8\n            assert head_size % 2 == 0, \\\n                f'head_size: {head_size}, quant_policy: {quant_policy}'\n            head_size = head_size // 2\n        return attn_backend.get_k_block_shape(block_size, num_heads, head_size, dtype)\n\n    @classmethod\n    def _get_value_block_shape_impl(cls,\n                                    model_config: ModelConfig,\n                                    block_size: int,\n                                    head_size: int,\n                                    world_size: int = 1,\n               
                     quant_policy: Literal[0, 4, 8] = 0):\n        \"\"\"Get single block shape.\"\"\"\n        attn_backend = get_backend()\n        dtype = model_config.dtype\n        num_heads = model_config.num_key_value_heads\n\n        # split heads by tp\n        assert num_heads % world_size == 0, \\\n            f'num_heads: {num_heads}, world_size: {world_size}'\n        num_heads = num_heads // world_size\n\n        # patch for flash mla\n        if model_config.use_mla_fp8_cache:\n            # flash mla shared key and value\n            return (block_size, num_heads, 0)\n\n        if quant_policy == 4:  # pack head_dim to uint8\n            assert head_size % 2 == 0, \\\n                f'head_size: {head_size}, quant_policy: {quant_policy}'\n            head_size = head_size // 2\n\n        return attn_backend.get_v_block_shape(block_size, num_heads, head_size, dtype)\n\n    @classmethod\n    def get_k_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, world_size: int = 1) -> CacheDesc:\n        \"\"\"Get key cache description.\"\"\"\n        head_size = model_config.k_head_dim\n        if head_size is None:\n            head_size = model_config.head_dim\n        shape = cls._get_key_block_shape_impl(\n            model_config,\n            block_size=cache_config.block_size,\n            head_size=head_size,\n            world_size=world_size,\n            quant_policy=cache_config.quant_policy,\n        )\n        shape = list(shape)\n        dtype = _get_kv_cache_dtype(model_config)\n        if cache_config.quant_policy in (4, 8):\n            dtype = torch.uint8\n        return CacheDesc(shape=shape, dtype=dtype)\n\n    @classmethod\n    def get_v_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, world_size: int = 1) -> CacheDesc:\n        \"\"\"Get value cache description.\"\"\"\n        head_size = model_config.v_head_dim\n        if head_size is None:\n            head_size = model_config.head_dim\n        
shape = cls._get_value_block_shape_impl(\n            model_config,\n            block_size=cache_config.block_size,\n            head_size=head_size,\n            world_size=world_size,\n            quant_policy=cache_config.quant_policy,\n        )\n        shape = list(shape)\n        dtype = _get_kv_cache_dtype(model_config)\n        if cache_config.quant_policy in (4, 8):\n            dtype = torch.uint8\n        return CacheDesc(shape=shape, dtype=dtype)\n\n    @classmethod\n    def get_quant_cache_descs(cls, k_cache_desc: CacheDesc, v_cache_desc: CacheDesc, model_config: ModelConfig,\n                              cache_config: CacheConfig):\n        \"\"\"Get quant cache descs.\"\"\"\n        if cache_config.quant_policy == 0:\n            return []\n\n        dtype = model_config.dtype\n        key_scale_zero_shape = k_cache_desc.shape[:-1] + [2]\n        val_scale_zero_shape = v_cache_desc.shape[:-1] + [2]\n        key_scale_zero_desc = CacheDesc(shape=key_scale_zero_shape, dtype=dtype)\n        val_scale_zero_desc = CacheDesc(shape=val_scale_zero_shape, dtype=dtype)\n        return [key_scale_zero_desc, val_scale_zero_desc]\n\n    @classmethod\n    def get_custom_cache_descs(cls, model_config: ModelConfig, cache_config: CacheConfig) -> List[CacheDesc]:\n        \"\"\"Get custom cache descs.\"\"\"\n        if len(model_config.cache_shapes) == 0:\n            return []\n\n        block_size = cache_config.block_size\n\n        descs = []\n        for shape, dtype in model_config.cache_shapes:\n            custom_shape = (block_size, *shape)\n            desc = CacheDesc(shape=custom_shape, dtype=dtype)\n            descs.append(desc)\n        return descs\n\n    @classmethod\n    def allocate_caches(cls, num_blocks: int, model_config: ModelConfig, cache_config: CacheConfig, world_size: int,\n                        device: str):\n        \"\"\"Allocate caches.\"\"\"\n\n        num_layers = model_config.num_layers\n\n        # get all descs\n        
k_cache_desc = cls.get_k_cache_desc(model_config, cache_config, world_size)\n        v_cache_desc = cls.get_v_cache_desc(model_config, cache_config, world_size)\n        quant_cache_descs = cls.get_quant_cache_descs(k_cache_desc, v_cache_desc, model_config, cache_config)\n        custom_cache_descs = cls.get_custom_cache_descs(model_config, cache_config)\n        cache_descs = [k_cache_desc, v_cache_desc] + quant_cache_descs + custom_cache_descs\n\n        # get mempool size\n        mem_pool_size = 0\n        for desc in cache_descs:\n            mem_pool_size += desc.aligned_size\n\n        # create pool\n        mem_pool = torch.zeros((num_layers, num_blocks, mem_pool_size), dtype=torch.uint8, device=device)\n\n        # slice caches\n        caches = []\n        remain_pool = mem_pool\n        for desc in cache_descs:\n            cache = remain_pool[:, :, :desc.size].view(desc.dtype).view((num_layers, num_blocks, *desc.shape))\n            remain_pool = remain_pool[:, :, desc.aligned_size:]\n            caches.append(cache)\n        return mem_pool, caches\n\n    def allocate_gpu_cache(self):\n        \"\"\"Allocate caches on GPU.\"\"\"\n        mem_pool, caches = self.allocate_caches(\n            num_blocks=self.num_gpu_blocks,\n            model_config=self.model_config,\n            cache_config=self.cache_config,\n            world_size=self.world_size,\n            device='cuda',\n        )\n        self.full_gpu_cache = mem_pool\n        self.local_gpu_cache = list(zip(*caches))\n        return self.local_gpu_cache\n\n    def allocate_cpu_cache(self):\n        \"\"\"Allocate caches on Host.\"\"\"\n        mem_pool, caches = self.allocate_caches(\n            num_blocks=self.num_cpu_blocks,\n            model_config=self.model_config,\n            cache_config=self.cache_config,\n            world_size=self.world_size,\n            device='cpu',\n        )\n        self.full_cpu_cache = mem_pool\n        self.local_cpu_cache = list(zip(*caches))\n        
return self.local_cpu_cache\n\n    @staticmethod\n    def get_custom_cache_shape_impl(num_layers: int, num_blocks: int, block_size: int, shape: List[int]):\n        \"\"\"Get single block shape.\"\"\"\n        return (num_layers, num_blocks, block_size, *shape)\n\n    @staticmethod\n    def _allocate_single_custom_cache(shape: Sequence[int], dtype: torch.dtype, device: str):\n        \"\"\"Allocate custom cache.\"\"\"\n        return torch.empty(shape, dtype=dtype, device=device)\n\n    def allocate_custom_cache(self, device: str):\n        \"\"\"Allocate custom caches on GPU.\"\"\"\n        num_layers = self.model_config.num_layers\n        custom_caches = []\n        for shape, dtype in self.model_config.cache_shapes:\n            custom_shape = self.get_custom_cache_shape_impl(\n                num_layers=num_layers,\n                num_blocks=self.num_gpu_blocks,\n                block_size=self.block_size,\n                shape=shape,\n            )\n            custom_cache = self._allocate_single_custom_cache(shape=custom_shape, dtype=dtype, device=device)\n            custom_caches.append(custom_cache)\n        return custom_caches\n\n    @torch.inference_mode()\n    def _swap(self, src: List[torch.Tensor], dst: List[torch.Tensor], src_to_dst: Dict[int, int]):\n        \"\"\"Move caches from src memory to dst memory.\n\n        Args:\n            src (List[KVCache]): Source cache.\n            dst (List[KVCache]): Destination cache.\n            src_to_dst (Dict[int, int]): Map between src and dst.\n        \"\"\"\n        BLOCKS_PER_COPY = 2\n        num_copy = len(src_to_dst)\n        src_idx, dst_idx = list(zip(*src_to_dst.items()))\n        src_idx = torch.tensor(src_idx, device=src[0].device)\n        dst_idx = torch.tensor(dst_idx, device=dst[0].device)\n        with torch.cuda.stream(self.cache_stream):\n            for scache, dcache in zip(src, dst):\n                for idx in range(0, num_copy, BLOCKS_PER_COPY):\n                    sidx = 
src_idx[idx:idx + BLOCKS_PER_COPY]\n                    didx = dst_idx[idx:idx + BLOCKS_PER_COPY]\n                    sdata = scache[:, sidx]\n                    dcache.index_copy_(1, didx, sdata.to(dcache.device))\n            self.events.record(stream=self.cache_stream)\n\n    def swap_in(self, src_to_dst: Dict[int, int]) -> None:\n        \"\"\"Move cache from Host to Device.\n\n        Args:\n            src_to_dst (Dict[int, int]): Map between src and dst.\n        \"\"\"\n        self._swap([self.full_cpu_cache], [self.full_gpu_cache], src_to_dst)\n\n    def swap_out(self, src_to_dst: Dict[int, int]) -> None:\n        \"\"\"Move cache from Device to Host.\n\n        Args:\n            src_to_dst (Dict[int, int]): Map between src and dst.\n        \"\"\"\n        self._swap([self.full_gpu_cache], [self.full_cpu_cache], src_to_dst)\n\n    @classmethod\n    def get_cache_block_size(cls, cache_config: CacheConfig, model_config: ModelConfig, world_size: int = 1) -> int:\n        \"\"\"Get the required memory size of one cache block across all layers.\n\n        Args:\n            cache_config (CacheConfig): The config of the cache.\n            model_config (ModelConfig): The config of the model.\n            world_size (int): The world size used to split the KV heads.\n\n        Return:\n            int: Required memory size in bytes.\n        \"\"\"\n        mem_pool, _ = cls.allocate_caches(\n            num_blocks=1,\n            model_config=model_config,\n            cache_config=cache_config,\n            world_size=world_size,\n            device='meta',\n        )\n\n        return mem_pool.numel() * mem_pool.element_size()\n\n    \"\"\" Methods for PD Disaggregation Begin. 
\"\"\"\n\n    def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistServeKVTransferEndpointInfo:\n        if not self.migration_backend_impl:\n            self.migration_backend_impl = MIGRATION_BACKENDS.module_dict[self.cache_config.migration_backend.name]()\n        migration_init_request.rank = self.rank\n        self.migration_backend_impl.p2p_initialize(migration_init_request)\n        for i, t in enumerate([self.full_gpu_cache]):\n            if t.numel() == 0:\n                continue\n            register_mr_request = DistServeRegisterMRMessage(protocol=migration_init_request.protocol,\n                                                             remote_engine_id=migration_init_request.remote_engine_id,\n                                                             mr_key=i,\n                                                             addr=t.data_ptr(),\n                                                             offset=t.storage_offset(),\n                                                             length=t.numel() * t.itemsize)\n            self.migration_backend_impl.register_memory_region(register_mr_request)\n        return DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol,\n                                               endpoint_info=json.dumps(\n                                                   self.migration_backend_impl.endpoint_info(\n                                                       migration_init_request.remote_engine_id,\n                                                       migration_init_request.protocol)))\n\n    def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]):\n        self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank])\n\n    async def migrate(self, migration_execution_inputs: MigrationExecutionBatch):\n\n        assignment_len = self.full_gpu_cache.element_size() * 
self.full_gpu_cache.size(-1)\n        layer_stride = self.cache_config.num_gpu_blocks * assignment_len\n\n        def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote_layer_stride):\n            return [\n                AssignmentInstruct(mr_key=mr_key,\n                                   target_offset=block_id[0] * assignment_len + layer * remote_layer_stride,\n                                   source_offset=block_id[1] * assignment_len + layer * layer_stride,\n                                   length=assignment_len) for layer in range(self.model_config.num_layers)\n                for block_id in block_ids\n            ]\n\n        assignment_batch: List[AssignmentInstruct] = []  # mr_key, target_offset, source_offset, length\n        for migration_exe_req in migration_execution_inputs.requests:\n            remote_engine_id = migration_exe_req[0]\n            blocks_to_migration = migration_exe_req[1]\n            remote_layer_stride = self.migration_backend_impl.links[\n                remote_engine_id].remote_engine_config.num_gpu_blocks * assignment_len\n\n            for i, t in enumerate([self.full_gpu_cache]):\n                if t.numel() == 0:\n                    continue\n                assignment_batch.extend(\n                    get_assignment_batch(i, blocks_to_migration, assignment_len, layer_stride, remote_layer_stride))\n        await self.migration_backend_impl.p2p_migrate(\n            MigrationAssignment(\n                protocol=migration_execution_inputs.protocol,\n                remote_engine_id=remote_engine_id,\n                batch=assignment_batch,\n            ))\n\n    \"\"\" Methods for PD Disaggregation End. 
\"\"\"\n\n\nclass StateCacheEngine:\n    \"\"\"Cache engine for state cache.\"\"\"\n\n    def __init__(self, cache_config: CacheConfig):\n        self.cache_config = cache_config\n        self.mem_pool, self._state_caches = self.allocate_caches(num_caches=cache_config.num_state_caches,\n                                                                 state_shapes=cache_config.states_shapes,\n                                                                 device='cuda')\n\n    @staticmethod\n    def allocate_caches(num_caches: int, state_shapes: List[Tuple[Tuple[int], torch.dtype]], device: torch.device):\n        \"\"\"Allocate cache implement.\"\"\"\n\n        if len(state_shapes) == 0 or num_caches == 0:\n            return torch.empty((0, 0), dtype=torch.uint8, device=device), []\n\n        cache_descs = [CacheDesc(shape, dtype) for shape, dtype in state_shapes]\n\n        # get mempool size\n        mem_pool_size = 0\n        for desc in cache_descs:\n            mem_pool_size += desc.aligned_size\n\n        # create pool\n        mem_pool = torch.zeros((num_caches, mem_pool_size), dtype=torch.uint8, device=device)\n\n        # slice caches\n        caches = []\n        remain_pool = mem_pool\n        for desc in cache_descs:\n            cache = remain_pool[:, :desc.size].view(desc.dtype).view((num_caches, *desc.shape))\n            remain_pool = remain_pool[:, desc.aligned_size:]\n            caches.append(cache)\n        return mem_pool, caches\n\n    @staticmethod\n    def get_cache_state_size(state_shapes: List[Tuple[Tuple[int], torch.dtype]]) -> int:\n        \"\"\"Get the required cache size of the state cache.\n\n        Args:\n            state_shapes (List[Tuple[Tuple[int], torch.dtype]]): The shapes and dtypes of the states.\n\n        Return:\n            int: Required memory size in bytes.\n        \"\"\"\n        mem_pool, _ = StateCacheEngine.allocate_caches(num_caches=1, state_shapes=state_shapes, device='meta')\n        return mem_pool.numel() 
* mem_pool.element_size()\n\n    @property\n    def state_caches(self):\n        \"\"\"State caches.\"\"\"\n        return self._state_caches\n\n    def init_caches(self, idx: torch.Tensor, mask: torch.Tensor):\n        \"\"\"Initialize state caches.\n\n        idx: indices of caches to be initialized.\n        mask: mask to indicate which idx to be initialized.\n        \"\"\"\n        if idx is None:\n            return\n\n        if len(self._state_caches) <= 0:\n            return\n\n        num_caches = self.cache_config.num_state_caches\n\n        # get mask of all caches so we can perform inplace mask fill\n        cache_masks = torch.zeros((num_caches, ), dtype=torch.bool, device=idx.device)\n        cache_masks.index_copy_(0, idx, mask)\n        reshaped_mask = cache_masks.view((-1, ) + (1, ) * (self.mem_pool.dim() - 1))\n        self.mem_pool.masked_fill_(reshaped_mask, 0)\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/config_builder.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport copy\nimport os\n\nfrom lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig\nfrom lmdeploy.pytorch.config import (BackendConfig, CacheConfig, DistConfig, MiscConfig, SchedulerConfig,\n                                     SpecDecodeConfig)\nfrom lmdeploy.utils import get_logger, get_max_batch_size, get_model\n\n\nclass ConfigBuilder:\n\n    @staticmethod\n    def update_engine_config(engine_config: PytorchEngineConfig):\n        \"\"\"Update pytorch engine config.\"\"\"\n        logger = get_logger('lmdeploy')\n\n        # make sure engine config exists\n        if engine_config is None:\n            engine_config = PytorchEngineConfig()\n        else:\n            engine_config = copy.deepcopy(engine_config)\n\n        if engine_config.max_batch_size is None:\n            engine_config.max_batch_size = get_max_batch_size(engine_config.device_type)\n\n        if engine_config.dllm_block_length is not None:\n            max_prefill_token_num = engine_config.max_prefill_token_num\n            max_batch_size = engine_config.max_batch_size\n            if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num:\n                engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length\n                logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} '\n                               f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size '\n                               f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).')\n\n        if engine_config.dp != 1:\n            if engine_config.tp == 1 and engine_config.ep == 1:\n                logger.warning('Data parallelism is enabled but tensor parallelism and '\n                               'expert parallelism are not enabled. Setting dp=1.')\n                engine_config.dp = 1\n                engine_config.dp_rank = 0\n\n        return engine_config\n\n    @staticmethod\n    def build_scheduler_config(engine_config: PytorchEngineConfig):\n        \"\"\"Build scheduler config.\"\"\"\n        scheduler_config = SchedulerConfig(max_batches=engine_config.max_batch_size,\n                                           max_session_len=engine_config.session_len,\n                                           prefill_interval=engine_config.prefill_interval)\n        return scheduler_config\n\n    @staticmethod\n    def build_cache_config(engine_config: PytorchEngineConfig):\n        \"\"\"Build cache config.\"\"\"\n        cache_config = CacheConfig(\n            max_batches=engine_config.max_batch_size,\n            block_size=engine_config.block_size,\n            num_cpu_blocks=engine_config.num_cpu_blocks,\n            num_gpu_blocks=engine_config.num_gpu_blocks,\n            cache_max_entry_count=engine_config.cache_max_entry_count,\n            max_prefill_token_num=engine_config.max_prefill_token_num,\n            enable_prefix_caching=engine_config.enable_prefix_caching,\n            quant_policy=engine_config.quant_policy,\n            device_type=engine_config.device_type,\n            migration_backend=engine_config.migration_backend,\n            role=engine_config.role,\n            # reserve 1 block for dummy input and padding\n            num_reserved_gpu_blocks=1)\n        return cache_config\n\n    @staticmethod\n    def build_backend_config(engine_config: PytorchEngineConfig):\n        \"\"\"Build backend config.\"\"\"\n        backend_config = BackendConfig(\n            eager_mode=engine_config.eager_mode,\n            device_type=engine_config.device_type,\n        )\n        return backend_config\n\n    @staticmethod\n    def build_dist_config(engine_config: PytorchEngineConfig):\n        \"\"\"Build dist config.\"\"\"\n        dist_config = 
DistConfig.from_engine_config(engine_config=engine_config)\n        return dist_config\n\n    @staticmethod\n    def build_misc_config(engine_config: PytorchEngineConfig):\n        \"\"\"Build misc config.\"\"\"\n        misc_config = MiscConfig.from_engine_config(engine_config)\n        return misc_config\n\n    @staticmethod\n    def build_specdecode_config(target_model, speculative_config: SpeculativeConfig, engine_config: PytorchEngineConfig,\n                                cache_config: CacheConfig):\n        \"\"\"Build spec decode config.\"\"\"\n        specdecode_config = None\n        if speculative_config is not None:\n            draft_model = speculative_config.model\n            if draft_model and not os.path.exists(speculative_config.model):\n                draft_model = get_model(draft_model, engine_config.download_dir, engine_config.revision)\n\n            specdecode_config = SpecDecodeConfig.from_config(\n                method=speculative_config.method,\n                num_speculative_tokens=speculative_config.num_speculative_tokens,\n                model=draft_model,\n                target_model=target_model,\n                target_cache_cfg=cache_config,\n                dtype=engine_config.dtype,\n            )\n        return specdecode_config\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport gc\nimport os\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\n\nfrom lmdeploy.messages import PytorchEngineConfig, RequestMetrics, ResponseType, SpeculativeConfig\nfrom lmdeploy.pytorch.disagg.config import EngineRole\nfrom lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,\n                                                   DistServeInitRequest)\nfrom lmdeploy.utils import get_logger, get_model\n\nfrom ..adapter.adapter import AdapterManager\nfrom ..config import CacheConfig, ModelConfig\nfrom ..messages import SchedulerSequence, UpdateTokenMode\nfrom ..paging import Scheduler\nfrom ..strategies import build_strategy_factory\nfrom .base import EngineBase\nfrom .config_builder import ConfigBuilder\nfrom .engine_checker import EngineChecker\nfrom .executor import build_executor\nfrom .request import Request, RequestManager, RequestType, Response\n\nlogger = get_logger('lmdeploy')\n\nSeqList = List[SchedulerSequence]\n\n\n@dataclass\nclass InferOutput:\n    \"\"\"The output of the model inference.\"\"\"\n\n    session_id: int\n    resp: Response\n    token_ids: Union[np.ndarray, List[int]]\n    meta: Any = None\n    finish: bool = False\n    logits: torch.Tensor = None\n    logprobs: torch.Tensor = None\n\n    # send cache blocks back for migration in Disaggregated LLM Serving\n    # when Prefill Engine is Done.\n    cache_block_ids: List[int] = None\n\n    # for logging\n    req_metrics: RequestMetrics = None\n\n    # expert ids\n    routed_experts: torch.Tensor = None\n\n\ndef _build_seq_meta(cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any):\n    from lmdeploy.pytorch.messages import SequenceMeta\n\n    seq_meta = SequenceMeta(cache_config.block_size, 
strategy=seq_strategy, sampling_strategy=sampling_strategy)\n    return seq_meta\n\n\ndef response_reqs(req_manager: RequestManager,\n                  resp: Response,\n                  resp_type: ResponseType,\n                  data: Any = None,\n                  err_msg: str = ''):\n    \"\"\"response.\"\"\"\n    if resp.type == ResponseType.FINISH:\n        return\n    resp.type = resp_type\n    resp.data = data\n    resp.err_msg = err_msg\n    req_manager.response(resp)\n\n\nclass Engine(EngineBase):\n    \"\"\"The inference engine of lmdeploy pytorch.\n\n    Args:\n        model_path (str): The hugging face model path.\n        engine_config (PytorchEngineConfig): The config of the Engine.\n        trust_remote_code (bool): Trust remote code.\n    \"\"\"\n\n    def __init__(\n        self,\n        model_path: str,\n        engine_config: PytorchEngineConfig = None,\n        trust_remote_code: bool = True,\n        speculative_config: SpeculativeConfig = None,\n    ) -> None:\n        # make sure engine config exist\n        engine_config = ConfigBuilder.update_engine_config(engine_config)\n\n        # frequently gc would cause latency spike\n        # default threshold (700, 10, 10)\n        # WARNING: I don't know if it is a good idea to put gc setting here.\n        gc.set_threshold(10000, 100, 100)\n\n        # dist args\n        self.tp = engine_config.tp\n        self.dp = engine_config.dp\n        self.dp_rank = engine_config.dp_rank\n\n        # download models and adapters\n        if not os.path.exists(model_path):\n            model_path = get_model(model_path, engine_config.download_dir, engine_config.revision)\n\n        adapters = engine_config.adapters\n        if adapters is not None and len(adapters) > 0:\n            adapters = self._download_adapters(adapters, engine_config)\n\n        # check environment\n        checker = EngineChecker(model_path=model_path,\n                                engine_config=engine_config,\n                 
               trust_remote_code=trust_remote_code,\n                                logger=logger)\n        checker.handle()\n\n        # build configs\n        scheduler_config = ConfigBuilder.build_scheduler_config(engine_config)\n        cache_config = ConfigBuilder.build_cache_config(engine_config)\n        backend_config = ConfigBuilder.build_backend_config(engine_config)\n        dist_config = ConfigBuilder.build_dist_config(engine_config)\n        misc_config = ConfigBuilder.build_misc_config(engine_config)\n        # spec decode\n        self.specdecode_config = ConfigBuilder.build_specdecode_config(model_path, speculative_config, engine_config,\n                                                                       cache_config)\n\n        # build model agent\n        self.executor = build_executor(\n            model_path,\n            cache_config=cache_config,\n            backend_config=backend_config,\n            dist_config=dist_config,\n            misc_config=misc_config,\n            adapters=adapters,\n            device_type=engine_config.device_type,\n            distributed_executor_backend=engine_config.distributed_executor_backend,\n            dtype=engine_config.dtype,\n            specdecode_config=self.specdecode_config,\n        )\n        self.executor.init()\n\n        # strategies\n        self.strategy_factory = build_strategy_factory(self.model_config, self.executor.misc_config,\n                                                       self.specdecode_config)\n        self.sampling_strategy = self.strategy_factory.build_sampling_strategy()\n        self.model_agent_strategy = self.strategy_factory.build_model_agent_strategy()\n        self.engine_strategy = self.strategy_factory.build_engine_strategy(cache_config=cache_config,\n                                                                           scheduler_config=scheduler_config)\n        self.seq_strategy = self.strategy_factory.build_sequence_strategy()\n\n        
self.input_processor = self.executor.get_input_processor()\n        cache_config = self.executor.cache_config\n        self.adapter_manager = self._build_adapter_manager(adapters)\n        self.seq_meta = _build_seq_meta(cache_config,\n                                        seq_strategy=self.seq_strategy,\n                                        sampling_strategy=self.sampling_strategy)\n        self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta)\n\n        # engine args\n        self.model_path = model_path\n        self.engine_config = engine_config\n        self.scheduler_config = scheduler_config\n        self.cache_config = cache_config\n        self.backend_config = backend_config\n        self.dist_config = dist_config\n        self.misc_config = self.executor.misc_config\n        self.max_session_len = self._get_max_session_len()\n        self.engine_config.num_cpu_blocks = self.cache_config.num_cpu_blocks\n        self.engine_config.num_gpu_blocks = self.cache_config.num_gpu_blocks\n\n        self.req_manager = self._bind_request_manager()\n\n        # create main thread\n        self.req_manager.set_main_loop_func(self.async_loop)\n        self._loop_main = None\n\n        # for PD Disaggregation\n        # For migrating prefill request to decode engine\n        self.migration_event: asyncio.Event = None\n        # For backpressure prefill request when cache is full\n        self.perfill_watermark_event: asyncio.Event = None\n\n        self.engine_conn = EngineP2PConnection(self)\n\n    @classmethod\n    def from_pretrained(cls,\n                        pretrained_model_name_or_path: str,\n                        engine_config: PytorchEngineConfig = None,\n                        trust_remote_code: bool = True,\n                        speculative_config: SpeculativeConfig = None,\n                        **kwargs):\n        \"\"\"Lmdeploy python inference engine.\n\n        Args:\n            pretrained_model_name_or_path 
(str):\n                It could be one of the following options:\n                    - i) The model_id of a lmdeploy-quantized model hosted\n                      inside a model repo on huggingface.co, such as\n                      \"InternLM/internlm-chat-20b-4bit\",\n                      \"lmdeploy/llama2-chat-70b-4bit\", etc.\n                    - ii) The model_id of a model hosted inside a model repo\n                      on huggingface.co, such as \"InternLM/internlm-chat-7b\",\n                      \"Qwen/Qwen-7B-Chat \", \"baichuan-inc/Baichuan2-7B-Chat\"\n                      and so on.\n            engine_config (PytorchEngineConfig): Pytorch engine config.\n            trust_remote_code (bool): Trust remote code\n        \"\"\"\n        if engine_config is not None and engine_config.enable_mp_engine:\n            from .mp_engine import build_mp_engine\n            backend = engine_config.mp_engine_backend\n            return build_mp_engine(\n                backend=backend,\n                model_path=pretrained_model_name_or_path,\n                engine_config=engine_config,\n                trust_remote_code=trust_remote_code,\n                speculative_config=speculative_config,\n            )\n        if len(kwargs) > 0:\n            logger.debug(f'Get unexpected kwargs: {kwargs}')\n        return cls(\n            model_path=pretrained_model_name_or_path,\n            engine_config=engine_config,\n            trust_remote_code=trust_remote_code,\n            speculative_config=speculative_config,\n        )\n\n    def _download_adapters(self, adapters: Dict[str, str], engine_config: PytorchEngineConfig):\n        \"\"\"Download adapters.\"\"\"\n        download_dir = engine_config.download_dir\n        revision = engine_config.revision\n        new_adapters = dict()\n        for name, path in adapters.items():\n            if os.path.exists(path):\n                new_adapters[name] = path\n                continue\n            new_path = 
get_model(path, download_dir=download_dir, revision=revision)\n            new_adapters[name] = new_path\n\n        return new_adapters\n\n    def _build_adapter_manager(self, adapters):\n        return AdapterManager(adapters)\n\n    def _bind_request_manager(self):\n        \"\"\"Bind request manager.\"\"\"\n        req_manager = RequestManager()\n        req_manager.bind_func(RequestType.ADD_SESSION, self._on_add_session)\n        req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session)\n        req_manager.bind_func(RequestType.END_SESSION, self._on_end_session)\n        req_manager.bind_func(RequestType.ADD_MESSAGE, self._on_add_message)\n        return req_manager\n\n    def _response(self, resp: Response, resp_type: ResponseType, data: Any = None, err_msg: str = ''):\n        \"\"\"response.\"\"\"\n        return response_reqs(self.req_manager, resp, resp_type, data, err_msg)\n\n    def _get_max_session_len(self):\n        \"\"\"Get max session len.\"\"\"\n        session_len = self.scheduler_config.max_session_len\n        num_gpu_blocks = self.cache_config.num_gpu_blocks - self.cache_config.num_reserved_gpu_blocks\n        max_tokens = (num_gpu_blocks * self.cache_config.block_size)\n        window_size = self.cache_config.window_size\n        if window_size > 0 and window_size <= max_tokens:\n            max_tokens = (1 << 63) - 1\n        max_tokens -= self.cache_config.block_size\n        if session_len is None:\n            session_len = max_tokens\n        else:\n            session_len = min(max_tokens, session_len)\n        return session_len\n\n    def _on_add_session(self, reqs: List[Request], **kwargs):\n        \"\"\"On add session callback.\"\"\"\n        for req in reqs:\n            session_id = req.data['session_id']\n            resp = req.data.get('response', True)\n            resp_type = ResponseType.SESSION_REPEAT\n            if session_id not in self.scheduler.sessions:\n                
self.scheduler.add_session(session_id)\n                resp_type = ResponseType.SUCCESS\n            if resp:\n                self._response(req.resp, resp_type)\n\n    def _on_stop_session(self, reqs: List[Request], **kwargs):\n        \"\"\"On stop session callback.\"\"\"\n        for req in reqs:\n            session_id = req.data['session_id']\n            resp = req.data.get('response', True)\n            resp_type = ResponseType.SESSION_NOT_EXIST\n            if session_id in self.scheduler.sessions:\n                self.scheduler.stop_session(session_id)\n                session = self.scheduler.sessions[session_id]\n                for seq in session.sequences.values():\n                    _resp: Response = getattr(seq, 'resp', None)\n                    if _resp is not None:\n                        _resp.type = ResponseType.CANCEL\n                        _resp.is_done = True\n                        self.req_manager.response(_resp)\n                resp_type = ResponseType.SUCCESS\n            if resp:\n                self._response(req.resp, resp_type)\n\n    def _on_end_session(self, reqs: List[Request], **kwargs):\n        \"\"\"On end session callback.\"\"\"\n        for req in reqs:\n            session_id = req.data['session_id']\n            resp = req.data.get('response', True)\n            resp_type = ResponseType.SESSION_NOT_EXIST\n            if session_id in self.scheduler.sessions:\n                msgs = list(self.scheduler.sessions[session_id].sequences.values())\n                if len(msgs) > 0 and msgs[0].preserve_cache:\n                    msgs[0].state.finish()\n                else:\n                    self.end_session(session_id)\n                resp_type = ResponseType.SUCCESS\n            if resp:\n                self._response(req.resp, resp_type)\n\n    def _on_add_message(self, reqs: List[Request], **kwargs):\n        \"\"\"On add message callback.\"\"\"\n        valid_reqs = []\n        for req in reqs:\n            
req_data = req.data\n            session_id = req_data['session_id']\n            if self.scheduler and session_id not in self.scheduler.sessions:\n                self._response(req.resp, ResponseType.SESSION_NOT_EXIST)\n                continue\n            valid_reqs.append(req)\n            if req_data.get('input_multimodals', None) is None:\n                continue\n            elif self.input_processor is None:\n                logger.warning('Do not support Multimodal inputs.')\n                continue\n            input_ids = req_data['token_ids']\n            input_multimodals = req_data['input_multimodals']\n            if len(input_multimodals) == 0:\n                req_data['input_multimodals'] = None\n                continue\n\n            if self.engine_config.disable_vision_encoder:\n                # ignore multimodal inputs\n                req_data['input_multimodals'] = None\n                logger.warning('Vision encoder has not been loaded, multimodal inputs will be ignored.')\n                continue\n\n            result = self.input_processor.preprocess_input(input_ids, input_multimodals)\n\n            input_ids = result.input_ids\n            input_multimodals = result.input_multimodals\n\n            req_data['token_ids'] = input_ids\n            req_data['input_multimodals'] = input_multimodals\n\n        if len(valid_reqs) > 0:\n            self._add_message(valid_reqs)\n\n    def _add_message(self, reqs: List[Request]):\n\n        def __update_max_new_tokens(msg):\n            \"\"\"Update max new tokens.\"\"\"\n            max_session_len = self.max_session_len\n            sampling_param = msg.sampling_param\n            max_new_tokens = sampling_param.max_new_tokens\n            num_all_tokens = msg.num_valid_ids\n            if self.engine_config.role == EngineRole.Prefill:\n                sampling_param.max_new_tokens = 1\n            elif max_new_tokens + num_all_tokens > max_session_len:\n                logger.warning(\n  
                  f'session[{msg.session_id}]: num tokens is larger than max session len {max_session_len}. '\n                    f'Update max_new_tokens={max_session_len - num_all_tokens}.')\n                sampling_param.max_new_tokens = max_session_len - num_all_tokens\n\n        scheduler = self.scheduler\n        for req in reqs:\n            session_id = req.data['session_id']\n            sess = scheduler.sessions.get(session_id, None)\n            if sess is None:\n                self._response(req.resp, ResponseType.SESSION_NOT_EXIST)\n                continue\n            # TODO: support 1 session n sequence\n            sampling_param = req.data['sampling_param']\n            if len(sess.sequences) == 0:\n                migration_request = req.data.get('migration_request')\n                assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.')\n                sess.add_sequence(req.data['token_ids'],\n                                  sampling_param=sampling_param,\n                                  adapter_name=req.data['adapter_name'],\n                                  multimodals=req.data.get('input_multimodals'),\n                                  input_embeddings=req.data.get('input_embeddings', ),\n                                  migration_request=migration_request,\n                                  resp_cache=req.data.get('with_cache'),\n                                  preserve_cache=req.data.get('preserve_cache'))\n                msg = next(iter(sess.sequences.values()))\n                if migration_request:\n                    self.migration_event.set()\n            else:\n                msg = next(iter(sess.sequences.values()))\n                msg.update_token_ids(\n                    req.data['token_ids'],\n                    multimodals=req.data.get('input_multimodals'),\n                    embeddings=req.data.get('input_embeddings'),\n                    mode=UpdateTokenMode.INPUTS,\n                )\n      
          msg.sampling_param = sampling_param\n                msg.state.activate()\n\n            __update_max_new_tokens(msg)\n            msg.resp = req.resp\n\n    @property\n    def model_config(self) -> ModelConfig:\n        \"\"\"Model config.\"\"\"\n        return self.executor.model_config\n\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        return self.engine_conn.p2p_initialize(init_request)\n\n    def p2p_connect(self, conn_request: DistServeConnectionRequest):\n        return self.engine_conn.p2p_connect(conn_request)\n\n    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):\n        return self.engine_conn.p2p_drop_connect(drop_conn_request)\n\n    def _loop_finally(self):\n        \"\"\"Finally process for dist.\"\"\"\n        logger.info('Cleanup executor.')\n        self.migration_event = None\n        self.executor.release()\n\n    def update_params(self, request: Any):\n        \"\"\"Update params.\"\"\"\n        self.executor.update_params(request)\n\n    def sleep(self, level: int = 1):\n        \"\"\"Sleep.\"\"\"\n        self.executor.sleep(level)\n\n    def wakeup(self, tags: Optional[List[str]] = None):\n        \"\"\"Wakeup.\"\"\"\n        self.executor.wakeup(tags)\n\n    async def async_loop(self):\n        engine_loop = None\n        try:\n            from lmdeploy.pytorch.engine.engine_loop import build_engine_loop\n            self._loop_main = asyncio.current_task()\n            event_loop = asyncio.get_event_loop()\n\n            # create engine loop\n            engine_loop = build_engine_loop(self)\n            self.migration_event = engine_loop.migration_event\n\n            # start engine loop\n            engine_loop.start(event_loop)\n            await engine_loop.wait_tasks()\n        except asyncio.CancelledError:\n            logger.info('Engine main loop cancelled.')\n            raise\n        except BaseException:\n            # since AsyncEngine will not wait for 
engine loop\n            # we have to log it here.\n            logger.exception('Engine main loop failed.')\n            raise\n        finally:\n            logger.debug('Engine main loop finally cleanup.')\n            if engine_loop is not None:\n                engine_loop.stop()\n            self._loop_finally()\n\n    def close(self):\n        if self.executor.device_type == 'cuda':\n            # https://discuss.pytorch.org/t/how-to-delete-a-tensor-in-gpu-to-free-up-memory/48879/32\n            # W/O this, repeatedly rebuilding and destroying engines within the same process\n            # will cause more and more reserved CUDA memory.\n            torch._C._cuda_clearCublasWorkspaces()\n        if self._loop_main is not None:\n            self._loop_main.cancel()\n        else:\n            self._loop_finally()\n\n    def start(self):\n        \"\"\"Start engine loop tasks.\"\"\"\n        if self.req_manager.is_loop_alive():\n            return True\n        self.req_manager.create_loop_task()\n        return True\n\n    def stop(self):\n        \"\"\"Stop engine loop tasks.\"\"\"\n        if self._loop_main is not None:\n            self._loop_main.cancel()\n\n    async def wait_tasks(self):\n        \"\"\"Wait async tasks to finish.\"\"\"\n        if self._loop_main is None:\n            logger.warning('No engine main loop to wait for.')\n            return\n\n        try:\n            # await self._loop_main\n            await self.req_manager.wait_tasks()\n        except asyncio.CancelledError:\n            logger.info('Engine main loop cancelled in wait_tasks.')\n            raise\n\n    def create_instance(self, cuda_stream_id=0):\n        \"\"\"Create a pytorch engine instance.\n\n        Args:\n            cuda_stream_id(int): identity of a cuda stream\n        Returns:\n            EngineInstance: an instance of pytorch engine\n        \"\"\"\n        from .engine_instance import EngineInstance\n        return EngineInstance(self)\n\n    def 
start_loop(self):\n        \"\"\"Alias of start, API for AsyncEngine.\"\"\"\n        return self.start()\n\n    def end_session(self, session_id: int):\n        \"\"\"End session.\"\"\"\n        if session_id in self.scheduler.sessions:\n            self.scheduler.end_session(session_id)\n            return True\n        return False\n\n    def get_engine_config(self):\n        return self.engine_config\n\n    def get_schedule_metrics(self):\n        return self.scheduler.schedule_metrics\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/engine_checker.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.messages import PytorchEngineConfig\n\nfrom ..check_env.adapter import AdapterChecker\nfrom ..check_env.base import BaseChecker\nfrom ..check_env.dist import DistChecker\nfrom ..check_env.model import ModelChecker\nfrom ..check_env.torch import TorchChecker\nfrom ..check_env.transformers import TransformersChecker\n\n\nclass EngineChecker(BaseChecker):\n    \"\"\"Check transformers is available.\"\"\"\n\n    def __init__(self,\n                 model_path: str,\n                 engine_config: PytorchEngineConfig,\n                 trust_remote_code: bool = True,\n                 logger=None):\n        super().__init__(logger)\n        logger = self.get_logger()\n\n        self.engine_config = engine_config\n\n        dtype = engine_config.dtype\n        device_type = engine_config.device_type\n\n        # pytorch\n        torch_checker = TorchChecker(logger=logger)\n\n        if device_type == 'cuda':\n            # triton\n            from ..check_env.cuda import CudaChecker\n            from ..check_env.triton import TritonChecker\n            cuda_checker = CudaChecker(model_format=engine_config.model_format, logger=logger)\n            cuda_checker.register_required_checker(torch_checker)\n            triton_checker = TritonChecker(logger=logger)\n            triton_checker.register_required_checker(cuda_checker)\n            self.register_required_checker(triton_checker)\n        else:\n            # deeplink\n            from ..check_env.deeplink import DeeplinkChecker\n            dl_checker = DeeplinkChecker(device_type, logger=logger)\n            self.register_required_checker(dl_checker)\n            self.register_required_checker(torch_checker)\n\n        # transformers\n\n        # model\n        trans_checker = TransformersChecker()\n        model_checker = ModelChecker(model_path=model_path,\n                                     trust_remote_code=trust_remote_code,\n     
                                dtype=dtype,\n                                     device_type=device_type,\n                                     logger=logger)\n        model_checker.register_required_checker(torch_checker)\n        model_checker.register_required_checker(trans_checker)\n        self.register_required_checker(model_checker)\n\n        # adapters\n        adapters = engine_config.adapters\n        if adapters is not None:\n            adapter_paths = list(adapters.values())\n            for adapter in adapter_paths:\n                adapter_checker = AdapterChecker(adapter, logger=logger)\n                self.register_required_checker(adapter_checker)\n\n        # dist\n        dist_checker = DistChecker(engine_config.tp,\n                                   engine_config.dp,\n                                   engine_config.ep,\n                                   engine_config.distributed_executor_backend,\n                                   device_type=engine_config.device_type,\n                                   logger=logger)\n        self.register_required_checker(dist_checker)\n\n    def check(self):\n        \"\"\"check.\"\"\"\n        engine_config = self.engine_config\n\n        if engine_config.thread_safe:\n            self.log_and_exit(\n                mod_name='Engine',\n                message='thread safe mode is no longer supported.\\n'\n                'Read https://github.com/InternLM/lmdeploy/blob/main/docs/en/advance/pytorch_multithread.md for more details.',  # noqa: E501\n            )\n\n        if engine_config.max_batch_size <= 0:\n            self.log_and_exit(mod_name='Engine',\n                              message='max_batch_size should be'\n                              f' greater than 0, but got {engine_config.max_batch_size}')\n\n        num_gpu_blocks = engine_config.num_gpu_blocks\n        if num_gpu_blocks > 0 and num_gpu_blocks < 16:\n            self.log_and_exit(mod_name='Engine',\n                            
  message='num_gpu_blocks should be at least 16, '\n                              f'but got {num_gpu_blocks}. Set num_gpu_blocks to 0 to automatically '\n                              'determine the number of GPU blocks based on the model size and device memory.')\n\n    def _handle_impl(self):\n        return super().handle()\n\n    def handle(self):\n        import multiprocessing as mp\n        from concurrent.futures import ProcessPoolExecutor\n\n        from lmdeploy.pytorch import envs\n        if not envs.enable_check_env:\n            return\n\n        current_proc = mp.current_process()\n        if not current_proc.daemon and self.engine_config.device_type == 'cuda':\n            mp_ctx = mp.get_context('spawn')\n            with ProcessPoolExecutor(mp_context=mp_ctx) as executor:\n                try:\n                    executor.submit(self._handle_impl).result()\n                except SystemExit:\n                    exit(1)\n                except BaseException as e:\n                    self.log_and_exit(e, mod_name='Engine')\n        else:\n            return self._handle_impl()\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/engine_instance.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, List\n\nfrom lmdeploy.messages import EngineOutput, GenerationConfig\nfrom lmdeploy.utils import get_logger\n\nfrom ..messages import SamplingParam\nfrom .base import EngineInstanceBase\nfrom .engine import Engine\nfrom .request import RequestSender, RequestType, Response, ResponseType\n\nlogger = get_logger('lmdeploy')\n\nInputMultiModalType = List[Dict[str, Any]]\n\n\ndef _check_resp(resp: Response, state: ResponseType, warning_msg: str = None):\n    \"\"\"Check if response has state.\"\"\"\n    if isinstance(state, ResponseType):\n        state = [state]\n    ret = resp.type in state\n    if not ret and warning_msg is not None:\n        logger.warning(warning_msg)\n    return ret\n\n\ndef _check_resp_success(resp: Response, warning_msg: str = None):\n    \"\"\"Check if response success.\"\"\"\n    return _check_resp(resp, ResponseType.SUCCESS, warning_msg)\n\n\nasync def async_try_add_session(req_sender: RequestSender, session_id: int):\n    \"\"\"Add new session.\n\n    Args:\n        session_id (int): The session id to add.\n    \"\"\"\n    resp = await req_sender.async_send(RequestType.ADD_SESSION, dict(session_id=session_id))\n    _check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT], (f'Can not add session {session_id} '\n                                                                            f'with error: {resp.type}'))\n\n\nasync def async_cancel(req_sender: RequestSender, session_id: int):\n    \"\"\"Stop current streaming inference.\"\"\"\n    resp = await req_sender.async_send(RequestType.STOP_SESSION, dict(session_id=session_id))\n    _check_resp_success(resp, (f'Failed to cancel session: {session_id}. 
'\n                               f'Error: {resp.type}.'))\n\n\ndef try_add_session(req_sender: RequestSender, session_id: int):\n    \"\"\"Add new session.\n\n    Args:\n        session_id (int): The session id to add.\n    \"\"\"\n    resp = req_sender.send(RequestType.ADD_SESSION, dict(session_id=session_id))\n    _check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT], (f'Can not add session {session_id} '\n                                                                            f'with error: {resp.type}'))\n\n\ndef end(req_sender: RequestSender, session_id: int):\n    \"\"\"End the given session.\"\"\"\n    logger.debug(f'session[{session_id}] try end session.')\n    req_sender.send_async(RequestType.END_SESSION, dict(session_id=session_id, response=False))\n\n\ndef cancel(req_sender: RequestSender, session_id: int):\n    \"\"\"Stop current streaming inference.\"\"\"\n    logger.debug(f'session[{session_id}] try end session.')\n    resp = req_sender.send(RequestType.STOP_SESSION, dict(session_id=session_id))\n    _check_resp_success(resp, (f'Failed to cancel session: {session_id}. 
'\n                               f'Error: {resp.type}.'))\n\n\nclass EngineInstance(EngineInstanceBase):\n    \"\"\"Instance of TurboMind.\n\n    Args:\n        engine (Engine): engine\n    \"\"\"\n\n    def __init__(self, engine: Engine):\n        self.engine = engine\n        self.req_sender = engine.req_manager.build_sender()\n\n        self.max_input_len = self.engine.max_session_len\n        self._enable_transfer_obj_ref = engine.engine_config.enable_transfer_obj_ref and \\\n            engine.engine_config.distributed_executor_backend == 'ray'\n\n    def __del__(self):\n        \"\"\"Destructor.\"\"\"\n        self.engine.req_manager.senders.pop(self.req_sender.sender_id)\n\n    def _get_extra_outputs(self, resp: Response):\n        \"\"\"Get extra outputs.\"\"\"\n        outputs = dict(routed_experts=None)\n        routed_experts = resp.data.get('routed_experts', None) if resp.data else None\n        if routed_experts is not None and resp.type in [ResponseType.FINISH, ResponseType.CANCEL]:\n            if self._enable_transfer_obj_ref:\n                import pybase64\n                import ray\n\n                ref = ray.put(routed_experts)\n                data = ray.cloudpickle.dumps(ref)\n                outputs['routed_experts'] = pybase64.b64encode(data).decode('utf-8')\n            else:\n                outputs['routed_experts'] = routed_experts\n        return outputs\n\n    async def _async_try_add_session(self, session_id: int):\n        \"\"\"Add new session.\n\n        Args:\n            session_id (int): The session id to add.\n        \"\"\"\n        return await async_try_add_session(self.req_sender, session_id)\n\n    def _try_add_session(self, session_id: int):\n        \"\"\"Add new session.\n\n        Args:\n            session_id (int): The session id to add.\n        \"\"\"\n        return try_add_session(self.req_sender, session_id)\n\n    async def async_stream_infer(self,\n                                 session_id: int,\n        
                         input_ids: List[int],\n                                 gen_config: GenerationConfig = None,\n                                 multimodal: InputMultiModalType = None,\n                                 adapter_name: str = None,\n                                 **kwargs):\n        \"\"\"Send stream inference request.\n\n        Args:\n            session_id (int): The session id.\n            input_ids (List[int]): The input token ids.\n            gen_config (GenerationConfig): The sampling parameters.\n            adapter_name (str): The lora adapter name.\n\n        Yields:\n            int: Error flags. 0 if success.\n            List[int]: The streaming output tokens.\n            int: The number of the output tokens.\n        \"\"\"\n        if len(input_ids) > self.max_input_len:\n            yield EngineOutput(ResponseType.INPUT_LENGTH_ERROR, [])\n            return\n        gen_config = gen_config or GenerationConfig()\n        sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)\n        logger.debug(f'session[{session_id}] try add session.')\n        self.req_sender.send_async(RequestType.ADD_SESSION, dict(session_id=session_id, response=False))\n        msg = dict(\n            token_ids=input_ids,\n            session_id=session_id,\n            sampling_param=sampling_param,\n            adapter_name=adapter_name,\n            input_multimodals=multimodal,\n            migration_request=gen_config.migration_request,\n            with_cache=gen_config.with_cache,\n            preserve_cache=gen_config.preserve_cache,\n        )\n        logger.debug(f'session[{session_id}] add message: num_input_ids={len(input_ids)}.')\n        resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)\n        output_offset = 0\n\n        while True:\n            resp = await self.req_sender.async_recv(resp, wait_main=True)\n\n            cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None\n     
       req_metrics = resp.data.get('req_metrics', None) if resp.data else None\n            logprobs = resp.data.pop('logprobs', None) if resp.data else None\n            extra_outputs = self._get_extra_outputs(resp)\n            routed_experts = extra_outputs.get('routed_experts', None)\n\n            if resp.type == ResponseType.SUCCESS:\n                token_ids = resp.data['token_ids']\n                num_ids = len(token_ids) - output_offset\n                logger.debug(f'session[{session_id}] success: num_out_ids={num_ids}.')\n                yield EngineOutput(resp.type,\n                                   token_ids[output_offset:].tolist(),\n                                   cache_block_ids=cache_block_ids,\n                                   req_metrics=req_metrics,\n                                   routed_experts=routed_experts,\n                                   logprobs=logprobs)\n                output_offset = len(token_ids)\n            elif resp.type in (ResponseType.FINISH, ResponseType.CANCEL):\n                resp_data = resp.data\n                if resp_data is None:\n                    # request might be cancelled before any output\n                    token_ids = []\n                    logits = None\n                else:\n                    token_ids = resp_data['token_ids'][output_offset:].tolist()\n                    logits = resp_data.get('logits', None)\n                num_ids = len(token_ids) - output_offset\n                logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')\n                yield EngineOutput(resp.type,\n                                   token_ids,\n                                   logits=logits,\n                                   cache_block_ids=cache_block_ids,\n                                   req_metrics=req_metrics,\n                                   routed_experts=routed_experts,\n                                   logprobs=logprobs)\n                break\n            
else:\n                logger.debug(f'session[{session_id}] failed.')\n                yield EngineOutput(resp.type, [])\n                break\n\n    async def async_infer(self,\n                          session_id: int,\n                          input_ids: List[int] = None,\n                          multimodal: InputMultiModalType = None,\n                          gen_config: GenerationConfig = None,\n                          **kwargs):\n        \"\"\"Send inference request.\n\n        Args:\n            session_id (int): The session id.\n            input_ids (List[int]): The input token ids.\n            gen_config (GenerationConfig): The sampling parameters.\n\n        Returns:\n            int: Error flags. 0 if success.\n            List[int]: The streaming output tokens.\n            int: The number of the output tokens.\n        \"\"\"\n        async for outputs in self.async_stream_infer(session_id,\n                                                     input_ids,\n                                                     multimodal=multimodal,\n                                                     gen_config=gen_config,\n                                                     **kwargs):\n            status = outputs.status\n            if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:\n                return outputs\n\n        return outputs\n\n    def stream_infer(self,\n                     session_id: int,\n                     input_ids: List[int],\n                     multimodal: InputMultiModalType = None,\n                     gen_config: GenerationConfig = None,\n                     adapter_name: str = None,\n                     **kwargs):\n        \"\"\"Send stream inference request.\n\n        Args:\n            session_id (int): The session id.\n            input_ids (List[int]): The input token ids.\n            gen_config (GenerationConfig): The sampling parameters.\n            adapter_name (str): The lora adapter name.\n\n        
Yields:\n            int: Error flags. 0 if success.\n            List[int]: The streaming output tokens.\n            int: The number of the output tokens.\n        \"\"\"\n\n        def __call_async():\n            \"\"\"Call async.\"\"\"\n            coro_gen = self.async_stream_infer(session_id,\n                                               input_ids,\n                                               multimodal=multimodal,\n                                               gen_config=gen_config,\n                                               adapter_name=adapter_name,\n                                               **kwargs)\n            while True:\n                try:\n                    yield self.req_sender.run_until_complete(coro_gen.__anext__())\n                except StopAsyncIteration:\n                    break\n\n        yield from __call_async()\n\n    def infer(self,\n              session_id: int,\n              input_ids: List[int] = None,\n              multimodal: InputMultiModalType = None,\n              gen_config: GenerationConfig = None,\n              **kwargs):\n        \"\"\"Send inference request.\n\n        Args:\n            session_id (int): The session id.\n            input_ids (List[int]): The input token ids.\n            gen_config (GenerationConfig): The sampling parameters.\n\n        Returns:\n            int: Error flags. 
0 if success.\n            List[int]: The streaming output tokens.\n            int: The number of the output tokens.\n        \"\"\"\n        return self.req_sender.run_until_complete(\n            self.async_infer(session_id, input_ids, multimodal=multimodal, gen_config=gen_config, **kwargs))\n\n    async def async_end(self, session_id: int):\n        \"\"\"End the given session.\"\"\"\n        return end(self.req_sender, session_id)\n\n    def end(self, session_id: int):\n        \"\"\"End the given session.\"\"\"\n        return end(self.req_sender, session_id)\n\n    async def async_cancel(self, session_id: int):\n        \"\"\"Stop current streaming inference.\"\"\"\n        return await async_cancel(self.req_sender, session_id)\n\n    def cancel(self, session_id: int):\n        \"\"\"Stop current streaming inference.\"\"\"\n        return cancel(self.req_sender, session_id)\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/engine_loop.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport logging\nimport time\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple\n\nimport numpy as np\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.messages import RequestMetrics\nfrom lmdeploy.pytorch.disagg.config import EngineRole\nfrom lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch\nfrom lmdeploy.pytorch.messages import MessageStatus, UpdateTokenMode\nfrom lmdeploy.pytorch.utils import cancel_async_tasks, wait_for_async_tasks\nfrom lmdeploy.utils import get_logger\n\nfrom .engine import InferOutput, ResponseType, response_reqs\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection\n    from lmdeploy.pytorch.engine.model_agent import BatchedOutputs\n    from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n    from lmdeploy.pytorch.paging import Scheduler\n    from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy\n\n    from .engine import Engine, SeqList\n    from .executor import ExecutorBase\n    from .inputs_maker import InputsMakerAsync\n    from .request import RequestManager\n\nlogger = get_logger('lmdeploy')\n_EMPTY_TOKEN = np.empty((0, ), dtype=np.int64)\n\n\nclass CounterEvent(asyncio.Event):\n\n    def __init__(self):\n        super().__init__()\n        self._counter = 0\n\n    def set(self):\n        if self._counter > 0:\n            self._counter -= 1\n        if self._counter == 0:\n            super().set()\n\n    def clear(self):\n        if self._counter == 0 and super().is_set():\n            super().clear()\n        self._counter += 1\n\n\nclass RunableEventAsync:\n    \"\"\"Awaitable async runable event.\"\"\"\n\n    def __init__(self, scheduler: 'Scheduler'):\n        self.scheduler = scheduler\n        self.event = asyncio.Event()\n\n    async def wait(self):\n        \"\"\"Wait 
event.\"\"\"\n        await self.event.wait()\n\n    def set(self):\n        \"\"\"Set event.\"\"\"\n        if self.scheduler.has_unfinished():\n            self.event.set()\n        else:\n            self.event.clear()\n\n\ndef build_runable_event(scheduler: 'Scheduler'):\n    \"\"\"Build runable event.\"\"\"\n    return RunableEventAsync(scheduler)\n\n\n@dataclass\nclass EngineLoopConfig:\n    \"\"\"Engine loop config.\n\n    This config is added for Dependency Injection\n    \"\"\"\n    role: EngineRole\n    num_speculative_tokens: Optional[int] = None\n    enable_metrics: bool = False\n    enable_transfer_obj_ref: bool = False\n\n    @staticmethod\n    def from_engine(engine: 'Engine'):\n        \"\"\"Create engine loop config from engine.\"\"\"\n        if engine.specdecode_config is None:\n            num_speculative_tokens = None\n        else:\n            num_speculative_tokens = engine.specdecode_config.num_speculative_tokens\n\n        return EngineLoopConfig(\n            role=engine.engine_config.role,\n            num_speculative_tokens=num_speculative_tokens,\n            enable_metrics=engine.engine_config.enable_metrics,\n            enable_transfer_obj_ref=engine.engine_config.enable_transfer_obj_ref,\n        )\n\n\nclass EngineLoop:\n    \"\"\"Engine loop manager should be created in an async context.\"\"\"\n\n    def __init__(self,\n                 req_manager: 'RequestManager',\n                 scheduler: 'Scheduler',\n                 executor: 'ExecutorBase',\n                 seq_strategy: 'SequenceStrategy',\n                 inputs_maker: 'InputsMakerAsync',\n                 config: EngineLoopConfig,\n                 engine_conn: Optional['EngineP2PConnection'] = None):\n        self.req_manager = req_manager\n        self.scheduler = scheduler\n        self.executor = executor\n        self.seq_strategy = seq_strategy\n        self.inputs_maker = inputs_maker\n        self.config = config\n        self.engine_conn = engine_conn\n\n 
       # tasks and control events\n        self.tasks: Set[asyncio.Task] = set()\n        self.stop_event = asyncio.Event()\n        self.resp_queue = asyncio.Queue()\n        self.forward_event = CounterEvent()\n        self.migration_event = asyncio.Event()\n        self.has_runable_event = RunableEventAsync(self.scheduler)\n\n        # check init\n        if self.config.role != EngineRole.Hybrid:\n            assert self.engine_conn is not None, 'Engine connection must be provided for non-hybrid engine role.'\n\n    async def preprocess_loop(self):\n        \"\"\"Preprocess request.\"\"\"\n        while not self.stop_event.is_set():\n            await self.req_manager.step()\n            self.has_runable_event.set()\n\n    @staticmethod\n    def _log_resps(outputs: List[InferOutput]):\n        \"\"\"Log resps.\"\"\"\n        if logger.level <= logging.DEBUG:\n            session_ids = [out.session_id for out in outputs]\n            logger.debug(f'Response sessions: {session_ids}')\n            logger.debug(f'Response: num_outputs={len(outputs)}.')\n\n    def _send_resp(self, out: InferOutput):\n        \"\"\"Send response.\"\"\"\n        # skip cancelled response\n        if out.resp.is_done:\n            return\n        resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)\n        logprobs = None if out.resp.data is None else out.resp.data.get('logprobs', None)\n        response_reqs(self.req_manager,\n                      out.resp,\n                      resp_type,\n                      data=dict(token_ids=out.token_ids,\n                                logits=out.logits,\n                                cache_block_ids=out.cache_block_ids,\n                                req_metrics=out.req_metrics,\n                                routed_experts=out.routed_experts,\n                                logprobs=logprobs))\n\n    @staticmethod\n    def _update_logprobs(step_outputs: List[InferOutput]):\n        for out in step_outputs:\n   
         cur_logprobs = out.logprobs\n            if cur_logprobs is None:\n                continue\n\n            if out.resp.data is None:\n                out.resp.data = dict()\n            out.resp.data.setdefault('logprobs', [])\n\n            # logprobs to dict\n            vals = cur_logprobs[0]\n            indices = cur_logprobs[1]\n            cur_logprobs = dict(zip(indices, vals))\n            logprobs = out.resp.data['logprobs']\n            logprobs.append(cur_logprobs)\n\n    def _send_resps(self, step_outputs: List[InferOutput]):\n        \"\"\"Send response callback.\"\"\"\n        self._log_resps(step_outputs)\n        self._update_logprobs(step_outputs)\n\n        is_done = set()\n        for out in reversed(step_outputs):\n            if out.session_id in is_done:\n                continue\n            is_done.add(out.session_id)\n            self._send_resp(out)\n\n    async def send_response_loop(self):\n        \"\"\"Send response to client.\"\"\"\n        que = self.resp_queue\n        while not self.stop_event.is_set():\n            num_outs = que.qsize()\n            if num_outs > 0:\n                resps = []\n                for _ in range(num_outs):\n                    resps += que.get_nowait().values()\n            else:\n                resps = (await que.get()).values()\n            self._send_resps(resps)\n\n    @record_function('make_infer_outputs')\n    def _make_infer_outputs(\n        self,\n        batched_outputs: 'BatchedOutputs',\n        running: 'SeqList',\n        model_inputs: 'ModelInputs',\n        delta: 'ModelInputsDelta',\n    ):\n        \"\"\"Make infer output.\"\"\"\n\n        def __get_logit(msg, logits: torch.Tensor, seq_length: List[int], idx: int):\n            logit = logits.split(seq_length)[idx]\n            if len(msg.all_logits) > 0:\n                # for chunked long context\n                msg.append_logits(logit)\n                logit = msg.logits\n                msg.all_logits.resize(0)\n\n   
         return logit\n\n        logits = batched_outputs.logits\n        all_routed_experts = batched_outputs.all_routed_experts\n\n        if model_inputs is not None and model_inputs.is_chunk:\n            # chunk long context does not need to update seqs and outputs\n            seq = running[0]\n            seq.append_routed_experts(all_routed_experts)\n            seq.append_logits(logits)\n            return dict()\n\n        new_token_timestamp = batched_outputs.new_token_timestamp\n        logprobs = batched_outputs.logprobs\n\n        if logprobs is not None:\n            logprobs.vals = logprobs.vals.tolist()\n            logprobs.indices = logprobs.indices.tolist()\n\n        seq_length = [seq.num_token_ids for seq in running]\n        is_run = [seq.status == MessageStatus.RUNNING for seq in running]\n        self.seq_strategy.update_running(running=running,\n                                         batched_outputs=batched_outputs,\n                                         model_inputs=model_inputs,\n                                         delta=delta)\n\n        # generate output\n        outputs: Dict[int, InferOutput] = dict()\n        for idx, msg in enumerate(running):\n            if not is_run[idx]:\n                continue\n            token_ids = msg.generated_ids\n            finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED\n            if not finish and len(token_ids) == 0:\n                continue\n            resp_data = msg.resp.data\n            if resp_data is not None and len(resp_data.get('token_ids', [])) == len(token_ids):\n                # no new tokens\n                continue\n            session_id = msg.session_id\n            if msg.resp_cache:\n                cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist()\n            else:\n                cache_block_ids = None\n\n            # logprobs\n            num_logprobs = msg.sampling_param.num_logprobs\n  
          cur_logprobs = None\n            if logprobs is not None and num_logprobs > 0:\n                cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1])\n            # get spec stats info\n            spec_info = None\n            num_draft_tokens = self.config.num_speculative_tokens\n            if num_draft_tokens is not None and model_inputs is None and self.config.enable_metrics:\n                num_accepted_tokens = (batched_outputs.next_token_ids[idx] > -1).sum() - 1\n                spec_info = dict(num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens.item())\n            req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=spec_info)\n            out = InferOutput(session_id=session_id,\n                              resp=msg.resp,\n                              finish=finish,\n                              token_ids=token_ids,\n                              cache_block_ids=cache_block_ids,\n                              req_metrics=req_metrics,\n                              logprobs=cur_logprobs,\n                              routed_experts=msg.routed_experts)\n            outputs[session_id] = out\n\n            if msg.return_logits:\n                logit = __get_logit(msg, logits, seq_length, idx)\n                outputs[session_id].logits = logit\n        return outputs\n\n    async def _main_loop_try_send_next_inputs(self):\n        \"\"\"Try send next inputs.\"\"\"\n        scheduler = self.scheduler\n        if not scheduler.has_unfinished():\n            await self.has_runable_event.wait()\n\n        scheduler.collect_migration_done()\n        return await self.inputs_maker.send_next_inputs()\n\n    async def _main_loop_get_outputs(\n        self,\n        running: 'SeqList',\n        forward_inputs: Dict[str, Any],\n    ):\n        \"\"\"Get outputs and prefetch.\"\"\"\n        model_inputs = forward_inputs['inputs']\n        delta = 
forward_inputs['delta']\n        self.inputs_maker.update_running_seqs(running, model_inputs)\n\n        # try prefetch inputs\n        self.scheduler.collect_migration_done()\n        forward_inputs, next_running = await self.inputs_maker.prefetch_next_inputs()\n\n        # send output\n        out = await self.executor.get_output_async()\n        if out is not None:\n            step_outputs = self._make_infer_outputs(out, running=running, model_inputs=model_inputs, delta=delta)\n            self.resp_queue.put_nowait(step_outputs)\n\n        return forward_inputs, next_running\n\n    async def main_loop(self):\n        \"\"\"Main loop of the engine.\n\n        Each engine instance would communicate with the engine by queue.\n        \"\"\"\n        has_runable_event = self.has_runable_event\n        scheduler = self.scheduler\n        forward_inputs = None\n        next_running = None\n\n        async def __no_running_warning():\n            # TODO (JimyMa): add watermark check event instead of async sleep.\n            # self.perfill_watermark_event.wait()\n            logger.warning(f'no next prefill running request, Maybe cache is full, '\n                           f'free gpu cache blocks: {scheduler.block_manager.get_num_free_gpu_blocks()}, '\n                           f'total gpu cache blocks: {scheduler.block_manager.num_gpu_blocks}')\n            await asyncio.sleep(0.1)\n\n        while not self.stop_event.is_set():\n            if next_running is None:\n                forward_inputs, next_running = await self._main_loop_try_send_next_inputs()\n                if next_running is None:\n                    await __no_running_warning()\n                    continue\n\n            scheduler.activate_seqs(next_running)\n            forward_inputs, next_running = await self._main_loop_get_outputs(\n                running=next_running,\n                forward_inputs=forward_inputs,\n            )\n            self.inputs_maker.deactivate_evict_seqs()\n    
        has_runable_event.set()\n\n    def update_running_migration(self, running: 'SeqList', next_token_ids: np.ndarray, stopped: torch.Tensor,\n                                 model_metas: List[Dict[str, Any]]):\n        \"\"\"Update scheduler.\"\"\"\n        if model_metas is None:\n            model_metas = [None] * len(running)\n        for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas):\n            if msg.status != MessageStatus.MIGRATION_RUNNING:\n                continue\n            update_token = token\n\n            # fill token\n            msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL)\n            if stop:\n                update_token = _EMPTY_TOKEN\n                msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL)\n                msg.state.finish()\n\n    async def _migration_loop_migrate(self, migration_ready: 'SeqList'):\n        \"\"\"Migration loop migrate.\"\"\"\n        for msg in migration_ready:\n            # skip dummy prefill migration\n            if msg.migration_request.is_dummy_prefill:\n                continue\n\n            migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = []\n            migration_request = msg.migration_request\n            prefill_block_ids = migration_request.remote_block_ids\n            decode_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg))\n\n            assert len(prefill_block_ids) == len(decode_block_ids), (\n                f'#prefill block ids ({len(prefill_block_ids)}) must equal to '\n                f'#decode block ids ({len(decode_block_ids)})'\n                f'all id length: {msg.num_token_ids}')\n            migration_execution_requests.append((\n                migration_request.remote_engine_id,\n                list(zip(prefill_block_ids, decode_block_ids)),\n            ))\n            migration_inputs = 
MigrationExecutionBatch(protocol=migration_request.protocol,\n                                                       requests=migration_execution_requests)\n            logger.info(f'migrating session: {msg.session_id} begin')\n            await self.executor.migrate(migration_inputs)\n            logger.info(f'migrating session: {msg.session_id} done')\n            await self.engine_conn.zmq_send(remote_engine_id=migration_request.remote_engine_id,\n                                            remote_session_id=migration_request.remote_session_id)\n\n    async def _migration_loop_get_outputs(self, migration_ready: 'SeqList'):\n        \"\"\"Migration loop get outputs.\"\"\"\n        outputs: Dict[int, InferOutput] = dict()\n        for _, msg in enumerate(migration_ready):\n            session_id = msg.session_id\n            msg.resp.type = ResponseType.SUCCESS\n            token_ids = [msg.migration_request.remote_token_id]\n            # MUST be a wall-clock time\n            new_token_timestamp = time.time()\n            req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)\n            out = InferOutput(\n                session_id=session_id,\n                resp=msg.resp,\n                finish=False,\n                token_ids=np.array(token_ids),\n                req_metrics=req_metrics,\n            )\n            outputs[session_id] = out\n            self.update_running_migration([msg], np.array([token_ids]), [False], [None])\n        self.resp_queue.put_nowait(outputs)\n\n    async def _migration_loop_process_ready(self, migration_ready: 'SeqList'):\n        \"\"\"Process migration ready.\"\"\"\n        await self._migration_loop_migrate(migration_ready)\n\n        # generate output\n        with self.scheduler.seqs_migration_activation(migration_ready):\n            await self._migration_loop_get_outputs(migration_ready)\n        self.has_runable_event.set()\n\n    async def migration_loop(self):\n        \"\"\"Async loop 
migration.\"\"\"\n        while not self.stop_event.is_set():\n            migration_ready = self.scheduler._schedule_migration()\n            if not migration_ready and not self.scheduler.has_migration_waiting():\n                await self.migration_event.wait()\n            elif migration_ready:\n                self.migration_event.clear()\n                await self._migration_loop_process_ready(migration_ready)\n            else:\n                # release coroutine for decoding\n                await asyncio.sleep(.5)\n\n    def start(self, event_loop: asyncio.AbstractEventLoop):\n        \"\"\"Create async tasks.\"\"\"\n        # start executor\n        logger.info('Starting executor.')\n        self.executor.start(self.forward_event)\n        # start owned loops\n        self.tasks.add(event_loop.create_task(self.executor.wait_tasks(), name='MainLoopWaitExecutor'))\n        logger.info('Starting async task MainLoopPreprocessMessage.')\n        self.tasks.add(event_loop.create_task(self.preprocess_loop(), name='MainLoopPreprocessMessage'))\n        logger.info('Starting async task MainLoopResponse.')\n        self.tasks.add(event_loop.create_task(self.send_response_loop(), name='MainLoopSendResponse'))\n        logger.info('Starting async task MainLoop.')\n        self.tasks.add(event_loop.create_task(self.main_loop(), name='MainLoopMain'))\n        if self.config.role != EngineRole.Hybrid:\n            logger.info('Starting async task MigrationLoop.')\n            self.tasks.add(event_loop.create_task(self.migration_loop(), name='MainLoopMigration'))\n\n        for task in self.tasks:\n            task.add_done_callback(self.tasks.discard)\n\n    async def wait_tasks(self):\n        \"\"\"Wait for all tasks to finish.\"\"\"\n        if not self.tasks:\n            return\n\n        # copy the tasks so callback of tasks would not update it\n        tasks = self.tasks.copy()\n        try:\n            await wait_for_async_tasks(tasks)\n        except 
asyncio.CancelledError:\n            logger.info('EngineLoop wait_tasks cancelled.')\n            raise\n        except BaseException:\n            logger.error('EngineLoop wait_tasks failed.')\n            raise\n        finally:\n            logger.debug('EngineLoop wait_tasks cleanup.')\n            # Make sure task finished/cancelled here.\n            # Error might happen if executor release before executor wait_tasks finish.\n            await cancel_async_tasks(tasks)\n\n    def stop(self):\n        \"\"\"Stop all loops.\"\"\"\n        if self.stop_event.is_set():\n            # Already stopped, avoid calling executor.stop() multiple times\n            return\n        self.executor.stop()\n        self.stop_event.set()\n        self.cancel()\n\n    def cancel(self):\n        \"\"\"Cancel all loops.\"\"\"\n        for task in self.tasks:\n            if not task.done():\n                task.cancel()\n        self.tasks.clear()\n\n\ndef build_engine_loop(engine: 'Engine'):\n    \"\"\"Build engine loop.\"\"\"\n    from .inputs_maker import build_inputs_maker\n\n    config = EngineLoopConfig.from_engine(engine)\n    inputs_maker = build_inputs_maker(engine)\n    return EngineLoop(\n        req_manager=engine.req_manager,\n        scheduler=engine.scheduler,\n        executor=engine.executor,\n        seq_strategy=engine.seq_strategy,\n        inputs_maker=inputs_maker,\n        config=config,\n        engine_conn=engine.engine_conn,\n    )\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/executor/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom logging import Logger\nfrom typing import Dict\n\nfrom lmdeploy.pytorch import envs\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig\nfrom lmdeploy.utils import get_logger\n\nfrom .base import ExecutorBase\n\n\ndef get_distributed_executor_backend(world_size: int, dp: int, device_type: str, logger: Logger = None):\n    \"\"\"Get distributed executor backend.\"\"\"\n    from lmdeploy.pytorch.backends import get_backend\n\n    def _log_info(message: str):\n        if logger is not None:\n            logger.info(message)\n\n    def _log_and_set_backend(message: str, executor_backend: str):\n        \"\"\"Log and set backend.\"\"\"\n        message += f' distributed_executor_backend={executor_backend}.'\n        _log_info(message)\n        return executor_backend\n\n    executor_backend = envs.executor_backend\n    if executor_backend is not None:\n        return _log_and_set_backend('found environment LMDEPLOY_EXECUTOR_BACKEND.', executor_backend)\n\n    if world_size == 1:\n        return 'uni'\n\n    if dp > 1:\n        executor_backend = 'ray'\n        return _log_and_set_backend(f'dp={dp}.', 'ray')\n\n    backend = get_backend(device_type)\n    if not backend.support_ray():\n        return _log_and_set_backend(f'device={device_type} does not support ray.', 'mp')\n    else:\n        return 'ray'\n\n    # TODO: fix mp hanging, do not delete the comment.\n    # device_count = backend.device_count()\n    # if device_count is None:\n    #     return _log_and_set_backend(f'device={device_type} can not get device_count.', 'mp')\n\n    # if device_count < world_size:\n    #     executor_backend = 'ray'\n    #     return _log_and_set_backend(f'local device_count({device_count})<world_size({world_size}),', 'ray')\n    # else:\n    #     executor_backend = 'mp'\n    #     return _log_and_set_backend(f'local 
device_count({device_count})>=world_size({world_size}),', 'mp')\n\n\ndef build_executor(\n    model_path: str,\n    cache_config: CacheConfig,\n    backend_config: BackendConfig,\n    dist_config: DistConfig,\n    misc_config: MiscConfig,\n    adapters: Dict[str, str] = None,\n    device_type: str = 'cuda',\n    distributed_executor_backend: str = None,\n    dtype: str = 'auto',\n    specdecode_config: SpecDecodeConfig = None,\n) -> ExecutorBase:\n    \"\"\"Build model agent executor.\"\"\"\n    logger = get_logger('lmdeploy')\n    dp = dist_config.dp\n    world_size = dist_config.world_size\n\n    model_config = ModelConfig.from_pretrained(\n        model_path,\n        trust_remote_code=True,\n        dtype=dtype,\n        hf_overrides=misc_config.hf_overrides,\n        dist_config=dist_config,\n        is_draft_model=False,\n        spec_method=None if specdecode_config is None else specdecode_config.method,\n        model_format=misc_config.model_format,\n        device_type=device_type,\n    )\n\n    if distributed_executor_backend is None:\n        distributed_executor_backend = get_distributed_executor_backend(world_size, dp, device_type, logger)\n\n    if dp > 1:\n        assert distributed_executor_backend == 'ray', (\n            'dp>1 requires distributed_executor_backend=\"ray\", ',\n            f'get distributed_executor_backend=\"{distributed_executor_backend}\"')\n\n    if misc_config.empty_init:\n        assert distributed_executor_backend == 'ray', (\n            'empty_init requires distributed_executor_backend=\"ray\", ',\n            f'get distributed_executor_backend=\"{distributed_executor_backend}\"')\n\n    if distributed_executor_backend is not None:\n        logger.info(f'Build <{distributed_executor_backend}> executor.')\n    if distributed_executor_backend == 'uni':\n        assert world_size == 1, 'uni executor only support world_size==1.'\n        from .uni_executor import UniExecutor\n        return UniExecutor(\n            
model_path=model_path,\n            model_config=model_config,\n            cache_config=cache_config,\n            backend_config=backend_config,\n            misc_config=misc_config,\n            adapters=adapters,\n            device_type=device_type,\n            specdecode_config=specdecode_config,\n        )\n    elif distributed_executor_backend == 'mp':\n        from .mp_executor import MPExecutor\n        logger.warning('MPExecutor will be deprecated in future releases, please use RayExecutor instead.')\n        return MPExecutor(\n            model_path=model_path,\n            model_config=model_config,\n            cache_config=cache_config,\n            backend_config=backend_config,\n            dist_config=dist_config,\n            misc_config=misc_config,\n            adapters=adapters,\n            device_type=device_type,\n            specdecode_config=specdecode_config,\n        )\n    elif distributed_executor_backend == 'ray':\n        from .ray_executor import RayExecutor\n        return RayExecutor(\n            model_path=model_path,\n            model_config=model_config,\n            cache_config=cache_config,\n            backend_config=backend_config,\n            dist_config=dist_config,\n            misc_config=misc_config,\n            adapters=adapters,\n            device_type=device_type,\n            dtype=dtype,\n            specdecode_config=specdecode_config,\n        )\n    else:\n        raise RuntimeError(f'Unsupported distributed_executor_backend: {distributed_executor_backend}.')\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/executor/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Inspired by vLLM: https://github.com/vllm-project/vllm\nimport asyncio\nimport contextlib\nfrom typing import Any, Dict, List, Optional\n\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig\nfrom lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo\nfrom lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch\nfrom lmdeploy.pytorch.engine.cache_engine import CacheEngine\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass ExecutorBase:\n    \"\"\"Executor base class.\"\"\"\n\n    def __init__(self,\n                 model_path: str,\n                 model_config: ModelConfig,\n                 cache_config: CacheConfig,\n                 backend_config: BackendConfig,\n                 dist_config: DistConfig,\n                 misc_config: MiscConfig,\n                 adapters: Dict[str, str] = None,\n                 specdecode_config: SpecDecodeConfig = None,\n                 device_type: str = 'cuda'):\n        \"\"\"Initialize Executor.\"\"\"\n        cache_config.window_size = model_config.sliding_window\n        if cache_config.window_size is not None and cache_config.window_size > 0:\n            # do not support sliding window prefix caching\n            logger.warning('Sliding window prefix caching is not supported.')\n            cache_config.enable_prefix_caching = False\n        self.model_config = model_config\n        self.cache_config = cache_config\n        self.backend_config = backend_config\n        self.dist_config = dist_config\n        self.misc_config = misc_config\n        self.dp = dist_config.dp\n        self.world_size = dist_config.world_size\n        self.device_type = device_type\n        self.specdecode_config = specdecode_config\n\n    def download_models(self):\n        \"\"\"Download model.\"\"\"\n        raise 
NotImplementedError('Not Implemented.')\n\n    def build_model(self):\n        \"\"\"Build model.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def gather_free_mem(self):\n        \"\"\"Gather available memory.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):\n        \"\"\"Set all cache config.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):\n        \"\"\"Set all model config.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def build_graph_runner(self):\n        \"\"\"Build graph runner.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def build_cache_engine(self):\n        \"\"\"Build cache engine.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def warmup(self):\n        \"\"\"warmup.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    async def sleep(self, level: int = 1):\n        \"\"\"Sleep.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def wakeup(self, tags: Optional[List[str]] = None):\n        \"\"\"Wakeup.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def update_params(self, request: Any):\n        \"\"\"Update params.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def get_input_processor(self):\n        \"\"\"Get input processor.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def start(self, forward_event: asyncio.Event):\n        \"\"\"Start engine loop.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    async def wait_tasks(self):\n        \"\"\"Wait tasks.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def stop(self):\n        \"\"\"Stop engine loop.\"\"\"\n        raise NotImplementedError('Not 
Implemented.')\n\n    def release(self):\n        \"\"\"Release resources.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    async def forward_async(self, inputs):\n        \"\"\"Start forward.\"\"\"\n        raise NotImplementedError('Not Implemented')\n\n    async def get_output_async(self):\n        \"\"\"Get output async.\"\"\"\n        raise NotImplementedError('Not Implemented')\n\n    \"\"\" PD Disaggregation API Begin \"\"\"\n\n    def p2p_initialize(self, remote_engine_config: DistServeInitRequest):\n        \"\"\"Init rdma link.\"\"\"\n        raise NotImplementedError('Not implemented')\n\n    def p2p_connect(self, conn_request: List[DistServeKVTransferEndpointInfo]):\n        \"\"\"rdma_connect.\"\"\"\n        raise NotImplementedError('Not Implemented')\n\n    async def migrate(self, batch: MigrationExecutionBatch):\n        \"\"\"KV Cache Migration.\"\"\"\n        raise NotImplementedError('Not Implemented')\n\n    \"\"\" PD Disaggregation API End \"\"\"\n\n    def _get_runtime_size(self, num_free_gpu_mem: int, cache_block_size: int, vocal_size: int):\n        \"\"\"Find best prefill num.\"\"\"\n        cache_max_entry_count = self.cache_config.cache_max_entry_count\n        max_prefill_token_num = self.cache_config.max_prefill_token_num\n        max_batches = self.cache_config.max_batches\n        runtime_cache_size = 0\n        while max_prefill_token_num > 0:\n            # estimate runtime mem size\n            runtime_cache_size = int((max_prefill_token_num + max_batches * 2) * vocal_size * 2)\n            num_available = (num_free_gpu_mem - runtime_cache_size) * cache_max_entry_count\n            if cache_block_size == 0 or int(num_available) // cache_block_size >= 16:\n                break\n            max_prefill_token_num = max_prefill_token_num // 2\n        return runtime_cache_size, max_prefill_token_num\n\n    def _adjust_block_size(self):\n        \"\"\"Adjust block_size.\"\"\"\n        if 
self.model_config.use_flash_mla is True:\n            if self.cache_config.block_size != 64:\n                raise ValueError('Please set block_size to 64 for flash_mla.')\n            return\n        # TODO: support kernel with both large head dim and large block size.\n        if self.model_config.k_head_dim >= 512 and self.cache_config.block_size > 32:\n            self.cache_config.block_size = 32\n            logger.warning(\n                f'Update `block_size={self.cache_config.block_size}` for large `head_dim={self.model_config.k_head_dim}`.'  # noqa\n            )\n\n    def _get_state_cache_mem(self):\n        \"\"\"Get state cache mem usage.\"\"\"\n        cache_config = self.cache_config\n        if len(cache_config.states_shapes) == 0:\n            return 0\n\n        from lmdeploy.pytorch.engine.cache_engine import StateCacheEngine\n\n        num_state_caches = cache_config.num_state_caches\n        if num_state_caches is None:\n            # add more caches for eviction\n            # TODO: Share memory between state cache and pageable cache\n            num_state_caches = int(cache_config.max_batches + 8)\n            cache_config.num_state_caches = num_state_caches\n\n        mems = StateCacheEngine.get_cache_state_size(cache_config.states_shapes)\n        mems *= num_state_caches\n\n        if cache_config.enable_prefix_caching:\n            cache_config.enable_prefix_caching = False\n            logger.warning('Prefix caching is not supported for state space models.')\n\n        return mems\n\n    def update_configs(self):\n        \"\"\"Update cache config.\"\"\"\n        self._adjust_block_size()\n        # spec\n        if self.specdecode_config and self.specdecode_config.cache_config:\n            self.specdecode_config.cache_config.block_size = self.cache_config.block_size\n        cache_config = self.cache_config\n        model_config = self.model_config\n        cache_config.states_shapes = model_config.states_shapes\n\n        # get 
free mems\n        free_mems = self.gather_free_mem()\n        free_mem = min(free_mems)\n        logger.debug(f'minimum free gpu memory: {free_mem >> 20} mb')\n\n        # get state cache size\n        state_cache_mem = self._get_state_cache_mem()\n        free_mem = free_mem - state_cache_mem\n        assert free_mem > 0, 'Not enough GPU memory for state cache. Please reduce max_batch_size.'\n\n        vocab_size = self.model_config.vocab_size\n        tp = self.dist_config.attn_tp\n        cache_block_size = CacheEngine.get_cache_block_size(cache_config, model_config, tp)\n        spec_cache_config = None\n        spec_model_config = None\n        spec_cache_block_size = 0\n        if self.specdecode_config:\n            spec_model_config = self.specdecode_config.model_config\n            if spec_cache_config := self.specdecode_config.cache_config:\n                spec_cache_block_size = CacheEngine.get_cache_block_size(spec_cache_config, spec_model_config, 1)\n\n        runtime_mem, max_prefill_token_num = self._get_runtime_size(free_mem, cache_block_size + spec_cache_block_size,\n                                                                    vocab_size)\n        if cache_config.max_prefill_token_num != max_prefill_token_num:\n            if max_prefill_token_num <= 0:\n                raise RuntimeError('Not enough GPU memory for runtime.')\n            cache_config.max_prefill_token_num = max_prefill_token_num\n            logger.warning(f'Not enough memory. 
Update max_prefill_token_num={max_prefill_token_num}')\n\n        if spec_cache_config is not None:\n            spec_cache_config.max_prefill_token_num = max_prefill_token_num\n\n        free_mem -= runtime_mem\n        logger.debug(f'estimated max runtime memory: {runtime_mem >> 20} mb')\n        available_mem = free_mem * cache_config.cache_max_entry_count\n\n        if cache_config.num_gpu_blocks == 0:\n            cache_config.num_gpu_blocks = int(available_mem / cache_block_size)\n            if cache_config.num_gpu_blocks <= 0:\n                raise RuntimeError('Not enough GPU memory for KV cache.')\n            if spec_cache_config is not None:\n                spec_cache_config.num_gpu_blocks = cache_config.num_gpu_blocks\n\n        self.set_cache_config(cache_config, spec_cache_config)\n        self.set_model_config(model_config, spec_model_config)\n\n    def init(self):\n        \"\"\"Initialize the executor.\"\"\"\n        logger.info('Building Model.')\n        self.build_model()\n        logger.info('Updating configs.')\n        self.update_configs()\n        logger.info('Building GraphRunner and warmup ops, please wait.')\n        self.build_graph_runner()\n        logger.info(f'Building CacheEngine with config: \\n{self.cache_config}.')\n        if self.specdecode_config:\n            if spec_cache_config := self.specdecode_config.cache_config:\n                logger.info(f'Building Spec CacheEngine with config: \\n{spec_cache_config}.')\n        self.build_cache_engine()\n        logger.info('Warming up model.')\n        self.warmup()\n\n    @contextlib.contextmanager\n    def remote_log(self, msg: str):\n        \"\"\"Send log for debugging.\n\n        Do not use it in production.\n        \"\"\"\n        # Different executors may have different log sending logic.\n        yield\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/executor/base_worker.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport gc\nfrom typing import Any, Dict, List, Optional\n\nfrom lmdeploy.pytorch.backends.selector import get_backend\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig\nfrom lmdeploy.pytorch.devices import DeviceContext\nfrom lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo\nfrom lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch\nfrom lmdeploy.pytorch.distributed import DistContext\nfrom lmdeploy.pytorch.engine.model_agent import build_model_agent\nfrom lmdeploy.utils import get_logger\n\nfrom .dist_utils import init_process_group, setup_master_addr\n\nlogger = get_logger('lmdeploy')\n\n\nclass WorkerWrapperBase:\n    \"\"\"Worker wrapper.\"\"\"\n\n    def __init__(\n        self,\n        model_path: str,\n        cache_config: CacheConfig,\n        backend_config: BackendConfig,\n        model_config: ModelConfig,\n        dist_config: DistConfig,\n        misc_config: MiscConfig,\n        adapters: Dict[str, str] = None,\n        device_type: str = 'cuda',\n        log_level: int = 30,\n        specdecode_config: SpecDecodeConfig = None,\n    ):\n        self.model_path = model_path\n        self.model_config = model_config\n        self.cache_config = cache_config\n        self.backend_config = backend_config\n        self.dist_config = dist_config\n        self.misc_config = misc_config\n        self.adapters = adapters\n        self.device_type = device_type\n        self.log_level = log_level\n        self.dp = dist_config.dp\n        self.tp = dist_config.tp\n        self.world_size = dist_config.world_size\n        self.device_type = device_type\n        self.specdecode_config = specdecode_config\n        logger.setLevel(log_level)\n        self.out_que: asyncio.Queue = None\n\n        # frequently gc would cause latency spike\n        # default 
threshold (700, 10, 10)\n        gc.set_threshold(10000, 100, 100)\n\n    def init_process_group(self, rank: int, master_addr: str = None, master_port: str = None):\n        \"\"\"Initialize process group.\"\"\"\n        self.rank = rank\n        if self.world_size > 1:\n            if master_addr is not None and master_port is not None:\n                setup_master_addr(master_addr, master_port)\n\n            init_process_group(rank, self.world_size)\n\n        ccl_backend = get_backend(self.device_type).ccl_backend()\n        self.dist_ctx = DistContext.build(self.rank, self.dist_config, ccl_backend)\n\n    def pack_output(self, output: Dict):\n        \"\"\"Pack output.\"\"\"\n        return output\n\n    async def get_outputs(self):\n        \"\"\"Get outputs.\"\"\"\n        return await self.get_output_async()\n\n    def build_model(self):\n        \"\"\"Build model.\"\"\"\n        self.device_ctx = DeviceContext(device_type=self.device_type)\n\n        self.model_agent = build_model_agent(\n            model_path=self.model_path,\n            model_config=self.model_config,\n            cache_config=self.cache_config,\n            backend_config=self.backend_config,\n            misc_config=self.misc_config,\n            device_ctx=self.device_ctx,\n            dist_ctx=self.dist_ctx,\n            adapters=self.adapters,\n            specdecode_config=self.specdecode_config,\n        )\n        self.model_agent.build_model()\n\n    def get_free_mem(self):\n        \"\"\"Gather free mem.\"\"\"\n        return self.model_agent.get_free_mem()\n\n    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):\n        \"\"\"Set all cache config.\"\"\"\n        self.model_agent.set_cache_config(cache_config, spec_cache_config)\n\n    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):\n        \"\"\"Set all model config.\"\"\"\n        self.model_agent.set_model_config(model_config, 
spec_model_config)\n\n    def build_graph_runner(self):\n        \"\"\"Build graph runner.\"\"\"\n        self.model_agent.build_graph_runner()\n\n    def build_cache_engine(self):\n        \"\"\"Build cache engine.\"\"\"\n        self.model_agent.build_cache_engine()\n\n    def update_params(self, request: Any):\n        \"\"\"Update params.\"\"\"\n        self.model_agent.update_params(request)\n\n    def warmup(self):\n        \"\"\"warmup.\"\"\"\n        self.model_agent.warmup()\n\n    async def sleep(self, level: int = 1):\n        \"\"\"Sleep.\"\"\"\n        await self.model_agent.sleep(level)\n\n    def wakeup(self, tags: Optional[List[str]] = None):\n        \"\"\"Wakeup.\"\"\"\n        self.model_agent.wakeup(tags)\n\n    def get_input_processor(self):\n        \"\"\"Get input processor.\"\"\"\n        return self.model_agent.get_input_processor()\n\n    def start(self):\n        \"\"\"Start engine loop.\"\"\"\n        self.model_agent.start()\n        self.out_que = asyncio.Queue()\n\n    async def wait_tasks(self):\n        \"\"\"Wait tasks.\"\"\"\n        try:\n            await self.model_agent.wait_tasks()\n        except asyncio.CancelledError:\n            logger.debug('WorkerWrapper wait_tasks cancelled.')\n            raise\n        except BaseException:\n            # we want to keep logs in both ray logs and engine logs\n            msg = f'WorkerWrapper rank[{self.rank}] wait_tasks failed.'\n            logger.exception(msg)\n            raise\n\n    def stop(self):\n        \"\"\"Stop engine loop.\"\"\"\n        self.model_agent.stop()\n\n    async def stop_async(self):\n        await self.model_agent.stop_async()\n\n    async def forward_async(self, inputs):\n        \"\"\"Start forward.\"\"\"\n        self.model_agent.set_forward_inputs(inputs)\n\n    async def get_output_async(self):\n        \"\"\"Get output async.\"\"\"\n        ret = await self.model_agent.get_output_async()\n        ret = self.pack_output(ret)\n        return ret\n\n    
def release(self):\n        \"\"\"Release resources.\"\"\"\n        self.model_agent.release()\n\n    \"\"\" PD Disaggregation API Begin \"\"\"\n\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        return self.model_agent.cache_engine.p2p_initialize(init_request)\n\n    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):\n        return self.model_agent.cache_engine.p2p_connect(remote_engine_id, conn_request)\n\n    async def migrate(self, inputs: MigrationExecutionBatch):\n        return await self.model_agent.cache_engine.migrate(inputs)\n\n    \"\"\" PD Disaggregation API End \"\"\"\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/executor/dist_utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nimport socket\nfrom datetime import timedelta\n\nimport torch.distributed as dist\n\nfrom lmdeploy.pytorch.backends.selector import get_backend\n\n\ndef find_available_port() -> int:\n    \"\"\"Find available port.\"\"\"\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.bind(('0.0.0.0', 0))\n        s.listen(1)\n        port = s.getsockname()[1]\n        return port\n\n\ndef setup_master_addr(addr: str, port: str):\n    \"\"\"Setup master addr.\"\"\"\n    from lmdeploy.utils import get_logger\n    logger = get_logger('lmdeploy')\n\n    if not isinstance(port, str):\n        port = str(port)\n    os.environ['MASTER_ADDR'] = addr\n    os.environ['MASTER_PORT'] = port\n    logger.info(f'MASTER_ADDR={addr}, MASTER_PORT={port}')\n\n\ndef init_dist_environ(rank: int, world_size: int):\n    \"\"\"Set RANK and WORLD_SIZE environment variables.\"\"\"\n    os.environ['RANK'] = str(rank)\n    os.environ['WORLD_SIZE'] = str(world_size)\n\n\ndef init_process_group(rank: int, world_size: int):\n    \"\"\"Init process group.\"\"\"\n    DIST_TIMEOUT = timedelta(days=35600)\n    init_dist_environ(rank, world_size)\n    os.environ.pop('TORCHELASTIC_USE_AGENT_STORE', None)\n\n    ccl_backend = get_backend().ccl_backend()\n    dist.init_process_group(backend=ccl_backend, rank=rank, world_size=world_size, timeout=DIST_TIMEOUT)\n    assert dist.is_initialized()\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/executor/mp_executor.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Modified from vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/v1/executor/multiproc_executor.py\nimport asyncio\nimport multiprocessing.shared_memory as shared_memory\nimport os\nimport pickle\nimport signal\nimport struct\nfrom contextlib import asynccontextmanager, contextmanager\nfrom multiprocessing.context import SpawnContext\nfrom typing import Any, Dict, List, Tuple\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom lmdeploy.pytorch.backends.selector import init_backend\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig\nfrom lmdeploy.utils import get_logger, try_import_deeplink\n\nfrom .base import ExecutorBase\nfrom .base_worker import WorkerWrapperBase\nfrom .dist_utils import find_available_port, setup_master_addr\n\nlogger = get_logger('lmdeploy')\n\n# 1 MiB payload per shared block\nSHARED_BLOCK_SIZE = 1 << 20\n# number of shared blocks (ring buffer slots)\nNUM_SHARED_BLOCK = 32\n# header size: struct 'II' packs (data_size, receiver_mask)\nHEAD_SIZE = 8\n# block size including the header\nSHARED_BLOCK_REAL_SIZE = SHARED_BLOCK_SIZE + HEAD_SIZE\n\n\ndef get_num_packages(data_size):\n    \"\"\"Get num packages.\"\"\"\n    return (data_size + SHARED_BLOCK_SIZE - 1) // SHARED_BLOCK_SIZE\n\n\nclass Notifier:\n\n    def __init__(self, num_receiver: int, mp_ctx: SpawnContext):\n        self.events = [mp_ctx.Event() for _ in range(NUM_SHARED_BLOCK)]\n        self.bar = mp_ctx.Barrier(num_receiver + 1)\n        self._event_id = 0\n\n    def _update_event_id(self):\n        self._event_id = (self._event_id + 1) % NUM_SHARED_BLOCK\n\n    def set(self):\n        self.events[self._event_id].set()\n        if self._event_id == NUM_SHARED_BLOCK - 1:\n            self.bar.wait()\n            [event.clear() for event in self.events]\n            self.bar.wait()\n        self._update_event_id()\n\n    async def set_async(self):\n        # not safe if multiple requests may be sent concurrently\n        
event_loop = asyncio.get_event_loop()\n        self.events[self._event_id].set()\n        if self._event_id == NUM_SHARED_BLOCK - 1:\n            await event_loop.run_in_executor(None, self.bar.wait)\n            [event.clear() for event in self.events]\n            self.bar.wait()\n        self._update_event_id()\n\n    @contextmanager\n    def wait(self):\n        self.events[self._event_id].wait()\n        yield\n        if self._event_id == NUM_SHARED_BLOCK - 1:\n            self.bar.wait()\n            self.bar.wait()\n        self._update_event_id()\n\n    @asynccontextmanager\n    async def wait_async(self):\n        event_loop = asyncio.get_event_loop()\n        await event_loop.run_in_executor(None, self.events[self._event_id].wait)\n        yield\n        if self._event_id == NUM_SHARED_BLOCK - 1:\n            self.bar.wait()\n            self.bar.wait()\n        self._update_event_id()\n\n    def close(self):\n        for event in self.events:\n            event.set()\n        self.bar.abort()\n\n\nclass SharedBuffer:\n    \"\"\"Shared buffer.\"\"\"\n\n    def __init__(self, proc_id: int, notifier: Notifier, name: str = None):\n        self.proc_id = proc_id\n        self.notifier = notifier\n        self.is_create = name is None\n        if self.is_create:\n            # ring buffer of NUM_SHARED_BLOCK blocks\n            self.shm = shared_memory.SharedMemory(create=True, size=SHARED_BLOCK_REAL_SIZE * NUM_SHARED_BLOCK)\n        else:\n            self.shm = shared_memory.SharedMemory(name=name)\n        self._buf_id = 0\n\n        if proc_id >= 0:\n            self.proc_mask = 1 << proc_id\n        else:\n            self.proc_mask = 0\n\n        self.is_closed = False\n\n    @contextmanager\n    def acquire_buf(self):\n        buf = self.shm.buf\n        assert buf is not None\n        buf_start = self._buf_id * SHARED_BLOCK_REAL_SIZE\n        out_buf = buf[buf_start:buf_start + SHARED_BLOCK_REAL_SIZE]\n        yield out_buf\n        self._buf_id = (self._buf_id + 1) % 
NUM_SHARED_BLOCK\n\n    def name(self):\n        return self.shm.name\n\n    def pack_data(self, data, receiver_mask):\n        \"\"\"Pack data.\"\"\"\n        dumped_data = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)\n        data_size = len(dumped_data)\n\n        num_packs = get_num_packages(data_size)\n        head = struct.pack('II', data_size, receiver_mask)\n\n        for _ in range(num_packs):\n            with self.acquire_buf() as buf:\n                pac_size = min(len(dumped_data), SHARED_BLOCK_SIZE)\n                packed_data = head + dumped_data[:pac_size]\n                buf[:HEAD_SIZE + pac_size] = packed_data\n                dumped_data = dumped_data[pac_size:]\n                yield buf\n\n    def send(self, data, receiver_mask: int = 0xff):\n        \"\"\"Pack data.\"\"\"\n        for _ in self.pack_data(data, receiver_mask):\n            self.notifier.set()\n\n    async def send_async(self, data, receiver_mask: int = 0xff):\n        \"\"\"Async pack data.\"\"\"\n        for _ in self.pack_data(data, receiver_mask):\n            await self.notifier.set_async()\n\n    def _receive_step0(self):\n        \"\"\"step0.\"\"\"\n        with self.acquire_buf() as buf:\n            head = buf[:HEAD_SIZE]\n            data_size, receiver_mask = struct.unpack('II', head)\n            is_receiver = ((receiver_mask & self.proc_mask) > 0)\n\n            pac_size = min(data_size, SHARED_BLOCK_SIZE)\n            remain_size = data_size - pac_size\n\n            dumped_data = b''\n            if is_receiver:\n                dumped_data += buf[HEAD_SIZE:HEAD_SIZE + pac_size]\n\n        return dumped_data, is_receiver, remain_size\n\n    def _receive_step1(self, dumped_data, is_receiver, remain_size):\n        \"\"\"step1.\"\"\"\n        while remain_size > 0:\n            with self.notifier.wait(), self.acquire_buf() as buf:\n                pac_size = min(remain_size, SHARED_BLOCK_SIZE)\n                remain_size -= pac_size\n                if 
not is_receiver:\n                    continue\n                dumped_data += buf[HEAD_SIZE:HEAD_SIZE + pac_size]\n\n        if not is_receiver:\n            return None\n        data = pickle.loads(dumped_data)\n        return data\n\n    def receive(self):\n        \"\"\"Unpack data.\"\"\"\n        with self.notifier.wait():\n            dumped_data, is_receiver, remain_size = self._receive_step0()\n        return self._receive_step1(dumped_data, is_receiver, remain_size)\n\n    async def receive_async(self):\n        \"\"\"Async receive data.\"\"\"\n        async with self.notifier.wait_async():\n            dumped_data, is_receiver, remain_size = self._receive_step0()\n        return self._receive_step1(dumped_data, is_receiver, remain_size)\n\n    def close(self):\n        if self.is_closed:\n            return\n        self.shm.close()\n        if self.is_create:\n            self.shm.unlink()\n        self.notifier.close()\n        self.is_closed = True\n\n\nclass MPExecutor(ExecutorBase):\n    \"\"\"Single node multi device Executor powered by multiprocess.\"\"\"\n\n    @classmethod\n    def setup_master_addr(cls):\n        \"\"\"Setup master addr.\"\"\"\n        port = find_available_port()\n        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')\n        os.environ.setdefault('MASTER_PORT', str(port))\n        addr = os.environ['MASTER_ADDR']\n        port = os.environ['MASTER_PORT']\n        setup_master_addr(addr, port)\n\n    def __init__(self,\n                 model_path: str,\n                 model_config: ModelConfig,\n                 cache_config: CacheConfig,\n                 backend_config: BackendConfig,\n                 dist_config: DistConfig,\n                 misc_config: MiscConfig,\n                 adapters: Dict[str, str] = None,\n                 specdecode_config: SpecDecodeConfig = None,\n                 device_type: str = 'cuda'):\n        \"\"\"Initialize Executor.\"\"\"\n        super().__init__(model_path=model_path,\n    
                     model_config=model_config,\n                         cache_config=cache_config,\n                         backend_config=backend_config,\n                         dist_config=dist_config,\n                         misc_config=misc_config,\n                         specdecode_config=specdecode_config,\n                         adapters=adapters,\n                         device_type=device_type)\n\n        # initialize processes.\n        self.setup_master_addr()\n        mp_ctx = mp.get_context('spawn')\n        self.mp_ctx = mp_ctx\n        self.comm_notifier = Notifier(self.world_size, mp_ctx)\n        self.comm_buf = SharedBuffer(-1, notifier=self.comm_notifier)\n        self.comm_buf_name = self.comm_buf.name()\n\n        logger.info('Creating processes.')\n        self.procs: List[ExecutorProc] = []\n        self.ret_bufs: List[SharedBuffer] = []\n        for proc_id in range(self.world_size):\n            proc = ExecutorProc(proc_id=proc_id, mp_ctx=mp_ctx)\n\n            ret_notifier = Notifier(1, mp_ctx)\n            ret_buf = SharedBuffer(0, notifier=ret_notifier)\n            self.ret_bufs.append(ret_buf)\n            proc.start(proc_id=proc_id,\n                       comm_notifier=self.comm_notifier,\n                       comm_buf_name=self.comm_buf_name,\n                       ret_notifier=ret_notifier,\n                       ret_buf_name=ret_buf.name(),\n                       model_path=model_path,\n                       model_config=model_config,\n                       cache_config=cache_config,\n                       backend_config=backend_config,\n                       dist_config=dist_config,\n                       misc_config=misc_config,\n                       specdecode_config=specdecode_config,\n                       adapters=adapters,\n                       device_type=device_type,\n                       log_level=logger.level)\n            self.procs.append(proc)\n\n        self._prefetch_task: asyncio.Task 
= None\n        self.remote_outs: asyncio.Queue = None\n\n        def signal_handler(signum, frame):\n            logger.error('Received custom termination signal from subprocess, exiting...')\n            self.stop()\n            self.release()\n            os._exit(1)\n\n        signal.signal(signal.SIGUSR1, signal_handler)\n\n    def collective_rpc(self,\n                       method: str,\n                       args: Tuple[Any] = None,\n                       kwargs: Dict[str, Any] = None,\n                       receiver_mask: int = 0xff,\n                       return_mask: int = 0xff):\n        \"\"\"Collective rpc.\"\"\"\n        if args is None:\n            args = list()\n        if kwargs is None:\n            kwargs = dict()\n        return_mask &= receiver_mask\n        self.comm_buf.send(\n            dict(\n                method=method,\n                args=args,\n                kwargs=kwargs,\n                return_mask=return_mask,\n            ),\n            receiver_mask=receiver_mask,\n        )\n\n        if return_mask:\n            outputs = [None] * len(self.ret_bufs)\n            for proc_id, ret_buf in enumerate(self.ret_bufs):\n                if bool(return_mask & (1 << proc_id)):\n                    outputs[proc_id] = ret_buf.receive()\n            return outputs\n\n    async def collective_rpc_async(self,\n                                   method: str,\n                                   args: Tuple[Any] = None,\n                                   kwargs: Dict[str, Any] = None,\n                                   receiver_mask: int = 0xff,\n                                   return_mask: int = 0xff):\n        \"\"\"Async collective rpc.\"\"\"\n        if args is None:\n            args = list()\n        if kwargs is None:\n            kwargs = dict()\n        return_mask &= receiver_mask\n        self.comm_buf.send(\n            dict(\n                method=method,\n                args=args,\n                kwargs=kwargs,\n                
return_mask=return_mask,\n            ),\n            receiver_mask=receiver_mask,\n        )\n\n        if return_mask:\n            outputs = [None] * len(self.ret_bufs)\n            for proc_id, ret_buf in enumerate(self.ret_bufs):\n                if bool(return_mask & (1 << proc_id)):\n                    outputs[proc_id] = await ret_buf.receive_async()\n            return outputs\n\n    def download_models(self):\n        \"\"\"Download model.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def build_model(self):\n        \"\"\"Build model.\"\"\"\n        self.collective_rpc('build_model')\n\n    def gather_free_mem(self):\n        \"\"\"Gather available memory.\"\"\"\n        ret = self.collective_rpc('get_free_mem')\n        return ret\n\n    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):\n        \"\"\"Set all cache config.\"\"\"\n        self.collective_rpc('set_cache_config', args=(cache_config, spec_cache_config))\n\n    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):\n        \"\"\"Set all model config.\"\"\"\n        self.collective_rpc('set_model_config', args=(model_config, spec_model_config))\n\n    def build_graph_runner(self):\n        \"\"\"Build graph runner.\"\"\"\n        self.collective_rpc('build_graph_runner')\n\n    def build_cache_engine(self):\n        \"\"\"Build cache engine.\"\"\"\n        self.collective_rpc('build_cache_engine')\n\n    def warmup(self):\n        \"\"\"Warmup.\"\"\"\n        self.collective_rpc('warmup')\n\n    async def _prefetch_outputs(self):\n        while True:\n            out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0]\n            self.remote_outs.put_nowait(out)\n\n    def start(self, forward_event: asyncio.Event):\n        \"\"\"Start engine loop.\"\"\"\n        self.collective_rpc('start')\n\n        self.remote_outs = asyncio.Queue()\n      
  event_loop = asyncio.get_event_loop()\n        self._prefetch_task = event_loop.create_task(self._prefetch_outputs())\n\n    async def wait_tasks(self):\n        \"\"\"Wait tasks.\"\"\"\n        # a simple wait is enough since MPExecutor will be deprecated soon.\n        await self._prefetch_task\n\n    async def forward_async(self, inputs):\n        \"\"\"Start forward.\"\"\"\n        await self.collective_rpc_async('forward_async', args=(inputs, ), return_mask=0)\n\n    async def get_output_async(self):\n        \"\"\"Get output async.\"\"\"\n        return await self.remote_outs.get()\n\n    def get_input_processor(self):\n        \"\"\"Get input processor.\"\"\"\n        return self.collective_rpc('get_input_processor', receiver_mask=1, return_mask=1)[0]\n\n    def stop(self):\n        \"\"\"Stop engine loop.\"\"\"\n        if self._prefetch_task is not None:\n            self._prefetch_task.cancel()\n\n    def release(self):\n        \"\"\"Release resources.\"\"\"\n        for proc in self.procs:\n            proc.close()\n\n        for proc in self.procs:\n            proc.join()\n\n        self.comm_buf.close()\n        for ret_buf in self.ret_bufs:\n            ret_buf.close()\n\n\nclass MPWorkerWrapper(WorkerWrapperBase):\n    \"\"\"Mp worker wrapper.\"\"\"\n\n    def __init__(\n        self,\n        model_path: str,\n        cache_config: CacheConfig,\n        backend_config: BackendConfig,\n        model_config: ModelConfig,\n        dist_config: DistConfig,\n        misc_config: MiscConfig,\n        specdecode_config: SpecDecodeConfig = None,\n        adapters: Dict[str, str] = None,\n        device_type: str = 'cuda',\n        log_level: int = 30,\n    ):\n        super().__init__(\n            model_path=model_path,\n            cache_config=cache_config,\n            backend_config=backend_config,\n            model_config=model_config,\n            dist_config=dist_config,\n            misc_config=misc_config,\n            
specdecode_config=specdecode_config,\n            adapters=adapters,\n            device_type=device_type,\n            log_level=log_level,\n        )\n\n\nclass ExecutorProc:\n\n    def __init__(self, proc_id: int, mp_ctx: SpawnContext):\n        \"\"\"Executor proc.\"\"\"\n        self.proc_id = proc_id\n        self.mp_ctx = mp_ctx\n        self._proc = None\n\n    def start(self, **kwargs):\n        \"\"\"Start proc.\"\"\"\n        assert self._proc is None\n        proc = self.mp_ctx.Process(target=self._main_loop,\n                                   kwargs=kwargs,\n                                   name=f'ExecutorProc-{self.proc_id}',\n                                   daemon=True)\n        proc.start()\n        self._proc = proc\n\n    def close(self):\n        \"\"\"Stop proc.\"\"\"\n        if self._proc is None:\n            return\n        if not self._proc.is_alive():\n            return\n        self._proc.terminate()\n\n    def join(self):\n        if self._proc is None:\n            return\n        self._proc.join()\n\n    def _main_loop(\n        self,\n        proc_id: int,\n        comm_notifier: Any,\n        comm_buf_name: str,\n        ret_notifier: Any,\n        ret_buf_name: str,\n        model_path: str,\n        model_config: ModelConfig,\n        cache_config: CacheConfig,\n        backend_config: BackendConfig,\n        dist_config: DistConfig,\n        misc_config: MiscConfig,\n        specdecode_config: SpecDecodeConfig = None,\n        adapters: Dict[str, str] = None,\n        device_type: str = 'cuda',\n        log_level: int = 30,\n    ):\n        \"\"\"Main loop.\"\"\"\n        init_backend(device_type)\n        torch.cuda.set_device(proc_id)\n\n        # catch signal\n        def handle_sigterm(signum, frame):\n            logger.debug(f'Proc[{proc_id}] terminated.')\n            exit(0)\n\n        signal.signal(signal.SIGTERM, handle_sigterm)\n\n        worker = MPWorkerWrapper(model_path,\n                                 
cache_config=cache_config,\n                                 backend_config=backend_config,\n                                 model_config=model_config,\n                                 dist_config=dist_config,\n                                 misc_config=misc_config,\n                                 specdecode_config=specdecode_config,\n                                 adapters=adapters,\n                                 device_type=device_type,\n                                 log_level=log_level)\n        try_import_deeplink(device_type)\n        worker.init_process_group(proc_id)\n        comm_buf = SharedBuffer(proc_id, notifier=comm_notifier, name=comm_buf_name)\n        ret_buf = SharedBuffer(-1, notifier=ret_notifier, name=ret_buf_name)\n        event_loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(event_loop)\n        destroy_pg = worker.world_size > 1\n        try:\n            event_loop.run_until_complete(\n                self._main_loop_impl(proc_id, comm_buf=comm_buf, ret_buf=ret_buf, worker=worker))\n        except asyncio.CancelledError:\n            logger.warning(f'Proc[{proc_id}] main loop cancelled.')\n            destroy_pg = False\n            os.kill(os.getppid(), signal.SIGUSR1)\n        except SystemExit:\n            # terminated by executor\n            logger.debug(f'Proc[{proc_id}] system exit.')\n        except KeyboardInterrupt:\n            logger.debug(f'Proc[{proc_id}] keyboard interrupt.')\n            exit(0)\n        except BaseException:\n            logger.exception(f'Proc[{proc_id}] failed')\n            os.kill(os.getppid(), signal.SIGUSR1)\n        finally:\n            logger.debug(f'Proc[{proc_id}] cleanup.')\n            worker.stop()\n            worker.release()\n            comm_buf.close()\n            ret_buf.close()\n            if dist.is_initialized() and destroy_pg:\n                dist.destroy_process_group()\n\n    @staticmethod\n    async def _task_wrapper(func, args: List, kwargs: Dict, 
need_return: bool, ret_buf: SharedBuffer):\n        ret = await func(*args, **kwargs)\n        if need_return:\n            await ret_buf.send_async(ret)\n\n    async def _main_loop_impl(self, proc_id: int, comm_buf: SharedBuffer, ret_buf: SharedBuffer,\n                              worker: MPWorkerWrapper):\n        \"\"\"Main loop implementation: receive commands and dispatch them to the worker.\"\"\"\n        proc_mask = 1 << proc_id\n        event_loop = asyncio.get_event_loop()\n        while True:\n            command = await comm_buf.receive_async()\n            if command is None:\n                continue\n            method = command['method']\n            # return_mask is a bitmask over proc ids; the default True (== 1) selects proc 0 only.\n            return_mask = command.get('return_mask', True)\n            args = command.get('args', list())\n            kwargs = command.get('kwargs', dict())\n            need_return = bool(proc_mask & return_mask)\n\n            func = getattr(worker, method, None)\n            assert func is not None, f'method: <{method}> does not exist.'\n            call_async = asyncio.iscoroutinefunction(func)\n\n            logger.debug(f'proc[{proc_id}] call method: <{method}>.')\n            if call_async:\n                # run coroutine methods as background tasks so the loop keeps receiving commands\n                event_loop.create_task(self._task_wrapper(func, args, kwargs, need_return, ret_buf))\n            else:\n                ret = func(*args, **kwargs)\n                if need_return:\n                    ret_buf.send(ret)\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/executor/ray_executor.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport contextlib\nimport json\nimport os\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport ray\nimport ray.exceptions\nimport torch\nfrom ray.util.placement_group import PlacementGroup\nfrom ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy\n\nfrom lmdeploy.pytorch import envs as _envs\nfrom lmdeploy.pytorch.backends.selector import init_backend\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig\nfrom lmdeploy.pytorch.devices import DeviceContext, get_device_manager\nfrom lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo\nfrom lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch\nfrom lmdeploy.pytorch.ray import RayContext, get_device_str\nfrom lmdeploy.pytorch.utils import wait_for_async_tasks\nfrom lmdeploy.utils import get_logger, try_import_deeplink\n\nfrom .base import ExecutorBase\nfrom .base_worker import WorkerWrapperBase\nfrom .dist_utils import find_available_port\n\nlogger = get_logger('lmdeploy')\n\n\ndef _get_master_addr():\n    \"\"\"Get master addr.\"\"\"\n    addr = _envs.dist_master_addr\n    if addr is not None:\n        return addr\n    gcs_addr = ray.get_runtime_context().gcs_address\n    master_addr = gcs_addr.split(':')[0]\n    return master_addr\n\n\ndef _get_master_port():\n    \"\"\"Get master port.\"\"\"\n    port = _envs.dist_master_port\n    if port is not None:\n        return port\n    return find_available_port()\n\n\ndef get_ascend_device_rank_mapping(master_addr):\n    rank_table_file = _envs.ascend_rank_table_file\n    if not rank_table_file:\n        raise ValueError('ASCEND_RANK_TABLE_FILE_PATH is not set')\n    with open(rank_table_file, 'r') as f:\n        rank_table = json.load(f)\n    try:\n        assert master_addr == rank_table['server_list'][0]['server_id'], 'Master address does not 
match rank table'\n        rank_mapping: Dict[int, int] = {}\n        worker_ip_by_rank: Dict[int, str] = {}\n        for server in rank_table['server_list']:\n            node_ip = server['server_id']\n            for idx, device in enumerate(server['device']):\n                # Prefer explicit device_id if present; fall back to enumeration order.\n                local_rank = int(device.get('device_id', idx))\n                global_rank = int(device['rank_id'])\n                rank_mapping[global_rank] = local_rank\n                worker_ip_by_rank[global_rank] = node_ip\n\n        if len(worker_ip_by_rank) == 0:\n            raise ValueError('Rank table contains no devices.')\n\n        ranks = sorted(worker_ip_by_rank.keys())\n        if ranks[0] != 0 or ranks[-1] != len(ranks) - 1:\n            raise ValueError(f'Rank ids are not contiguous starting from 0: {ranks[:8]}...{ranks[-8:]}')\n        worker_ips = [worker_ip_by_rank[r] for r in range(len(ranks))]\n    except Exception:\n        logger.error(f'Parse rank table file({rank_table_file}) failed.')\n        raise\n\n    envs = {\n        'ASCEND_RANK_TABLE_FILE_PATH': rank_table_file,\n    }\n    return rank_mapping, worker_ips, envs\n\n\ndef _update_env_cuda_alloc_conf(env_vars: Dict):\n    \"\"\"Update runtime env for CUDA alloc conf.\"\"\"\n    cuda_alloc_conf = os.getenv('PYTORCH_CUDA_ALLOC_CONF', None)\n    if cuda_alloc_conf is None:\n        return\n\n    # check and update conf, skip expandable_segments\n    cuda_alloc_conf = cuda_alloc_conf.split(',')\n    new_cuda_alloc_conf = []\n    for conf in cuda_alloc_conf:\n        if 'expandable_segments' in conf:\n            if 'True' in conf:\n                logger.warning('\"expandable_segments:True\" is not supported.')\n            continue\n        new_cuda_alloc_conf.append(conf)\n    if len(new_cuda_alloc_conf) == 0:\n        new_cuda_alloc_conf = ['expandable_segments:False']\n    cuda_alloc_conf = ','.join(new_cuda_alloc_conf)\n\n    # 
update env_vars\n    env_vars['PYTORCH_CUDA_ALLOC_CONF'] = cuda_alloc_conf\n\n\ndef _update_runtime_envs(runtime_env: Dict):\n    \"\"\"Update runtime envs.\"\"\"\n    new_envs = _envs.get_all_envs()\n    env_vars: Dict = runtime_env.get('env_vars', {})\n    env_vars.update(new_envs)\n    _update_env_cuda_alloc_conf(env_vars)\n    runtime_env['env_vars'] = env_vars\n    return runtime_env\n\n\ndef _update_runtime_env_nsys(runtime_env: Dict):\n    \"\"\"Update runtime env for nsys.\"\"\"\n    nsight_env = {\n        't': 'cuda,cudnn,cublas,nvtx',\n        'o': \"'worker_process_%p'\",\n        'stop-on-exit': 'true',\n    }\n    prefix_path = _envs.ray_nsys_output_prefix\n    if prefix_path is not None:\n        nsight_env['o'] = f'{prefix_path}%p'\n    runtime_env['nsight'] = nsight_env\n    return runtime_env\n\n\nclass RemoteLogger:\n    \"\"\"Remote logger.\"\"\"\n\n    def __init__(self):\n        self._records = dict()\n        self._next_handle = 0\n\n    def start(self, msg: str):\n        \"\"\"Start remote log.\"\"\"\n        record = torch.profiler.record_function(msg)\n        record.__enter__()\n        handle = self._next_handle\n        self._records[handle] = record\n        self._next_handle += 1\n        return handle\n\n    def end(self, handle: int):\n        \"\"\"End remote log.\"\"\"\n        record = self._records.pop(handle, None)\n        if record is not None:\n            record.__exit__(None, None, None)\n\n\nclass RayWorkerWrapper(WorkerWrapperBase):\n    \"\"\"Worker wrapper.\"\"\"\n\n    def __init__(\n        self,\n        model_path: str,\n        cache_config: CacheConfig,\n        backend_config: BackendConfig,\n        model_config: ModelConfig,\n        dist_config: DistConfig,\n        misc_config: MiscConfig,\n        adapters: Dict[str, str] = None,\n        device_type: str = 'cuda',\n        dtype: str = 'auto',\n        log_level: int = 30,\n        specdecode_config: SpecDecodeConfig = None,\n    ):\n        
init_backend(device_type)\n        try_import_deeplink(device_type)\n\n        super().__init__(\n            model_path=model_path,\n            cache_config=cache_config,\n            backend_config=backend_config,\n            model_config=model_config,\n            dist_config=dist_config,\n            misc_config=misc_config,\n            adapters=adapters,\n            device_type=device_type,\n            log_level=log_level,\n            specdecode_config=specdecode_config,\n        )\n        self.node_ip = ray.util.get_node_ip_address()\n        self._remote_logger = RemoteLogger()\n\n    def set_device(self, local_rank):\n        \"\"\"Set worker device by local rank.\"\"\"\n        torch.cuda.set_device(local_rank)\n\n    def set_env(self, envs: Dict[str, str]):\n        \"\"\"Set environment variables in the worker process.\"\"\"\n        for key, value in envs.items():\n            os.environ[key] = value\n\n    def get_node_ip(self):\n        \"\"\"Get worker ip.\"\"\"\n        return self.node_ip\n\n    def warmup_dist(self):\n        # A non-default CUDA_VISIBLE_DEVICES might lead to a slow first all_reduce,\n        # so a tiny all_reduce is issued up front. The root cause is not yet understood.\n        logger.debug('Warmup all_reduce.')\n        import torch\n\n        from lmdeploy.pytorch.distributed import all_reduce, get_dist_manager\n        with get_dist_manager().context(self.dist_ctx):\n            group = self.dist_ctx.tp_group.gpu_group\n            tmp = torch.empty((1, ), device='cuda')\n            all_reduce(tmp, group=group)\n\n    def pack_output(self, output: Dict):\n        \"\"\"Pack output.\"\"\"\n        return output.to_numpy()\n\n    def remote_log_start(self, msg: str):\n        \"\"\"Remote log start.\"\"\"\n        return self._remote_logger.start(msg)\n\n    def remote_log_end(self, handle: int):\n        \"\"\"Remote log end.\"\"\"\n        return self._remote_logger.end(handle)\n\n    def exit(self):\n        \"\"\"Exit actor.\"\"\"\n        ray.actor.exit_actor()\n\n\nclass RayExecutor(ExecutorBase):\n    \"\"\"Ray executor.\"\"\"\n\n    def __init__(\n        
self,\n        model_path: str,\n        model_config: ModelConfig,\n        cache_config: CacheConfig,\n        backend_config: BackendConfig,\n        dist_config: DistConfig,\n        misc_config: MiscConfig,\n        adapters: Dict[str, str] = None,\n        device_type: str = 'cuda',\n        dtype: str = 'auto',\n        specdecode_config: SpecDecodeConfig = None,\n    ):\n        \"\"\"Initialize Executor.\"\"\"\n        super().__init__(\n            model_path=model_path,\n            model_config=model_config,\n            cache_config=cache_config,\n            backend_config=backend_config,\n            dist_config=dist_config,\n            misc_config=misc_config,\n            adapters=adapters,\n            device_type=device_type,\n            specdecode_config=specdecode_config,\n        )\n\n        device_ctx = DeviceContext(device_type)\n        with get_device_manager().context(device_ctx):\n            logger.info('Init ray cluster.')\n            attn_tp = dist_config.attn_tp\n            self.ray_ctx = RayContext(attn_tp, dp=dist_config.dp, device_type=device_type)\n            placement_group = self.ray_ctx.get_placement_group()\n            self.placement_group = placement_group\n\n            if self.dp == 1:\n                self.master_addr = _get_master_addr()\n                self.master_port = _get_master_port()\n            else:\n                self.master_addr = _envs.dp_master_addr\n                self.master_port = _envs.dp_master_port\n                if self.master_addr is None or self.master_port is None:\n                    raise RuntimeError('DP > 1 requires \"LMDEPLOY_DP_MASTER_ADDR\" and \"LMDEPLOY_DP_MASTER_PORT\".')\n\n            # create workerwrapper actors\n            worker_kwargs = dict(\n                model_path=model_path,\n                cache_config=cache_config,\n                model_config=model_config,\n                backend_config=backend_config,\n                dist_config=dist_config,\n         
       misc_config=misc_config,\n                adapters=adapters,\n                device_type=device_type,\n                dtype=dtype,\n                log_level=logger.level,\n                specdecode_config=specdecode_config,\n            )\n\n            logger.info('Init ray workers.')\n            self.workers = self._init_workers_ray(placement_group, worker_kwargs)\n            self.dag = None\n            self._prefetch_task: asyncio.Task = None\n            self.remote_outs: asyncio.Queue = None\n\n            logger.info('Init distributed environment by device.')\n            self.rank_offset = dist_config.dp_rank * attn_tp\n            self._init_distributed_environment_by_device(device_type)\n\n            logger.info('Init distributed process group.')\n            ray.get([\n                worker.init_process_group.remote(rank + self.rank_offset, self.master_addr, self.master_port)\n                for rank, worker in enumerate(self.workers)\n            ])\n\n            if self.dist_config.world_size > 1:\n                logger.info('Warming up distributed environment, this might take a long time, please wait...')\n                ray.get([worker.warmup_dist.remote() for worker in self.workers])\n\n    def collective_rpc(self,\n                       method: str,\n                       args: Tuple[Any] = None,\n                       kwargs: Dict[str, Any] = None,\n                       timeout: float = None):\n        \"\"\"Call a method on all workers and block until all results are ready.\"\"\"\n        if args is None:\n            args = list()\n        if kwargs is None:\n            kwargs = dict()\n        return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers], timeout=timeout)\n\n    def build_model(self):\n        \"\"\"Build model.\"\"\"\n        self.collective_rpc('build_model')\n\n    def gather_free_mem(self):\n        \"\"\"Gather available memory.\"\"\"\n        return self.collective_rpc('get_free_mem')\n\n    def set_cache_config(self,
 cache_config: CacheConfig, spec_cache_config: CacheConfig = None):\n        \"\"\"Set all cache config.\"\"\"\n        self.collective_rpc('set_cache_config', (cache_config, spec_cache_config))\n\n    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):\n        \"\"\"Set all model config.\"\"\"\n        self.collective_rpc('set_model_config', (model_config, spec_model_config))\n\n    def build_graph_runner(self):\n        \"\"\"Build graph runner.\"\"\"\n        self.collective_rpc('build_graph_runner')\n\n    def build_cache_engine(self):\n        \"\"\"Build cache engine.\"\"\"\n        self.collective_rpc('build_cache_engine')\n\n    def update_params(self, request: Any):\n        \"\"\"Update params.\"\"\"\n        self.collective_rpc('update_params', (request, ))\n\n    def warmup(self):\n        \"\"\"Warmup.\"\"\"\n        self.collective_rpc('warmup')\n\n    def sleep(self, level: int = 1):\n        \"\"\"Sleep.\"\"\"\n        self.collective_rpc('sleep', (level, ))\n\n    def wakeup(self, tags: Optional[List[str]] = None):\n        \"\"\"Wakeup.\"\"\"\n        if tags is None or 'kv_cache' in tags:\n            self.update_configs()\n        self.collective_rpc('wakeup', (tags, ))\n\n    def get_input_processor(self):\n        \"\"\"Get input processor.\"\"\"\n        return ray.get(self.workers[0].get_input_processor.remote())\n\n    def _prefetch_task_callback(self, task: asyncio.Task):\n        try:\n            task.result()\n        except asyncio.CancelledError:\n            logger.debug(f'{task.get_name()} cancelled.')\n        except KeyboardInterrupt:\n            logger.debug(f'{task.get_name()} KeyboardInterrupt.')\n        except BaseException:\n            logger.debug(f'{task.get_name()} task failed.')\n\n    def start(self, forward_event: asyncio.Event):\n        \"\"\"Start engine loop.\"\"\"\n        self.forward_event = forward_event\n        self.collective_rpc('start')\n\n        
self.remote_outs = asyncio.Queue()\n        logger.info('Starting async task RayPrefetchOutput loop.')\n\n    async def wait_tasks(self):\n        \"\"\"Wait tasks.\"\"\"\n        dp_rank = self.dist_config.dp_rank\n        tasks_to_cancel = set()\n        event_loop = asyncio.get_event_loop()\n\n        async def _wait_single_worker(worker):\n            try:\n                task = worker.wait_tasks.remote()\n                tasks_to_cancel.add(task)\n                await task\n            except ray.exceptions.ActorDiedError:\n                # It is safe to ignore wait_tasks on a dead actor.\n                logger.info('RayExecutor worker was killed before wait_tasks finished.')\n\n        tasks = [\n            event_loop.create_task(_wait_single_worker(worker), name=f'WorkerWaitTasks_{idx}')\n            for idx, worker in enumerate(self.workers)\n        ]\n        if self._prefetch_task is not None:\n            tasks.append(self._prefetch_task)\n        try:\n            await wait_for_async_tasks(tasks)\n        except asyncio.CancelledError:\n            logger.info(f'RayExecutor DP[{dp_rank}] wait_tasks cancelled.')\n            raise\n        except BaseException:\n            logger.error(f'RayExecutor DP[{dp_rank}] wait_tasks failed.')\n            raise\n        finally:\n            logger.debug(f'RayExecutor DP[{dp_rank}] wait_tasks cleanup.')\n            for task in tasks_to_cancel:\n                try:\n                    ray.cancel(task)\n                except ray.exceptions.ActorDiedError:\n                    logger.debug('RayExecutor worker was killed before its task could be cancelled.')\n                except Exception as e:\n                    logger.error(f'RayExecutor DP[{dp_rank}] Cancel wait_tasks failed: {e}')\n\n    def stop(self):\n        \"\"\"Stop engine loop.\"\"\"\n        # TODO: For dp > 1 we currently rely on external teardown (e.g. Ray actor\n        # destruction) instead of explicitly stopping worker loops here.
 Implementing\n        # coordinated shutdown across multiple dp ranks is non-trivial, especially\n        # when some ranks may have already failed. The explicit stop_async RPC is\n        # therefore only issued when dp == 1.\n        if self.dp == 1:\n            try:\n                # No timeout here: a timeout might prevent the profile from being dumped.\n                # We assume this call will not hang.\n                self.collective_rpc('stop_async')\n            except ray.exceptions.ActorDiedError:\n                logger.info('RayExecutor worker was killed before stop_async finished.')\n            logger.debug('RayExecutor workers stopped.')\n        if self._prefetch_task is not None:\n            self._prefetch_task.cancel()\n\n    def release(self):\n        \"\"\"Release.\"\"\"\n        if _envs.ray_timeline_enable:\n            ray.timeline(_envs.ray_timeline_output_path)\n\n        if self.dp == 1:\n            try:\n                self.collective_rpc('release', timeout=5.0)\n                logger.debug('RayExecutor workers released.')\n            except ray.exceptions.ActorDiedError:\n                logger.info('RayExecutor worker was killed before release finished.')\n                for worker in self.workers:\n                    ray.kill(worker)\n            except ray.exceptions.GetTimeoutError:\n                logger.info('Ray release timeout, killing workers.')\n                for worker in self.workers:\n                    ray.kill(worker)\n        else:\n            for worker in self.workers:\n                ray.kill(worker)\n\n        self.ray_ctx.shutdown()\n\n    def _compile_dag(self):\n        \"\"\"Compile dag.\"\"\"\n        from ray.dag.input_node import InputNode\n        from ray.dag.output_node import MultiOutputNode\n        with InputNode() as input_data:\n            outputs = [worker.forward_async.bind(input_data) for worker in self.workers]\n            output = MultiOutputNode(outputs)\n\n        return output\n\n    async def forward_async(self, inputs):\n        \"\"\"Start 
forward.\"\"\"\n\n        if self.dag is None:\n            self.dag = self._compile_dag()\n            self._prev_inputs = None\n            self._prev_out = None\n\n        if self._prev_out is not None:\n            try:\n                ray.get(self._prev_out)\n            except SystemExit:\n                logger.error('Ray worker exited.')\n                raise\n            finally:\n                # free ray.put inputs\n                try:\n                    ray._private.internal_api.free(self._prev_inputs)\n                except Exception as e:\n                    logger.warning(f'Free input ref failed: {e}')\n\n        self._prev_inputs = ray.put(inputs)\n        # make sure in order\n        self._prev_out = self.dag.execute(self._prev_inputs)\n\n    async def get_output_async(self):\n        \"\"\"Get output async.\"\"\"\n        ret = await self.workers[0].get_outputs.remote()\n        ret = ret.to_tensor()\n        return ret\n\n    @contextlib.contextmanager\n    def remote_log(self, msg: str):\n        \"\"\"Send log for debugging.\n\n        Do not use it in production.\n        \"\"\"\n        handle_ref = self.workers[0].remote_log_start.remote(msg)\n        yield\n        handle = ray.get(handle_ref)\n        ray.get(self.workers[0].remote_log_end.remote(handle))\n\n    def _sort_workers(self, driver_ip: str, workers: List[RayWorkerWrapper]):\n        \"\"\"Sort workers by ip.\"\"\"\n        worker_ips = ray.get([worker.get_node_ip.remote() for worker in workers])\n\n        ip_counts: Dict[str, int] = {}\n        for ip in worker_ips:\n            ip_counts[ip] = ip_counts.get(ip, 0) + 1\n\n        worker_ip_map = list(zip(workers, worker_ips))\n\n        def sort_by_driver_then_worker_ip(item):\n            \"\"\"Sort the workers based on 3 properties:\n\n            1. If the worker is on the same node as the driver (vllm engine),\n                it should be placed first.\n            2. 
Then, if the worker is on a node with fewer workers, it should\n                be placed first.\n            3. Finally, if the work is on a node with smaller IP address, it\n                should be placed first.\n            \"\"\"\n            ip = item[1]\n            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)\n\n        # After sorting, the workers on the same node will be\n        # close to each other, and the workers on the driver\n        # node will be placed first.\n        sorted_worker_ip_map = sorted(worker_ip_map, key=sort_by_driver_then_worker_ip)\n        workers = [item[0] for item in sorted_worker_ip_map]\n        return workers\n\n    def _sort_workers_by_ip(self, ips, workers: List[RayWorkerWrapper]):\n        worker_ips = ray.get([worker.get_node_ip.remote() for worker in workers])\n\n        if len(ips) != len(workers):\n            raise ValueError(f'The length of the ips list does not match the workers, '\n                             f'ips length: {len(ips)}, workers length: {len(workers)}')\n\n        # Check if all elements in ips are present in worker_ips and vice versa (ignoring order)\n        if set(ips) != set(worker_ips):\n            raise ValueError(f'The IP addresses in the ips list do not match the worker IPs. 
'\n                             f'ips: {ips}, worker_ips: {worker_ips}')\n\n        worker_ip_map = list(zip(workers, worker_ips))\n        ip_priority = {ip: idx for idx, ip in enumerate(ips)}\n\n        def get_priority(ip):\n            return ip_priority.get(ip)\n\n        sorted_worker_ip_map = sorted(worker_ip_map, key=lambda x: get_priority(x[1]))\n        sorted_workers = [item[0] for item in sorted_worker_ip_map]\n        return sorted_workers\n\n    def _valid_bundle_id(self, bundle_id: int):\n        \"\"\"Check whether a bundle may host a worker.\n\n        Bundles are filtered only for an external (not owned) placement group when ray_external_pg_bundles is set.\n        \"\"\"\n        if (not self.ray_ctx.owned_pg and _envs.ray_external_pg_bundles\n                and bundle_id not in _envs.ray_external_pg_bundles):\n            return False\n        return True\n\n    def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict):\n        \"\"\"Init ray workers.\"\"\"\n        device_str = get_device_str()\n        bundle_indices = []\n        for bundle_id, bundle in enumerate(placement_group.bundle_specs):\n            if bundle.get(device_str, 0) and self._valid_bundle_id(bundle_id):\n                bundle_indices.append(bundle_id)\n        attn_tp = self.dist_config.attn_tp\n        bundle_indices = bundle_indices[:attn_tp]\n\n        workers = list()\n        for bundle_id in bundle_indices:\n            scheduling_strategy = PlacementGroupSchedulingStrategy(\n                placement_group=placement_group,\n                placement_group_capture_child_tasks=True,\n                placement_group_bundle_index=bundle_id,\n            )\n\n            if device_str == 'GPU':\n                runtime_env = dict()\n                runtime_env = _update_runtime_envs(runtime_env)\n                if _envs.ray_nsys_enable:\n                    runtime_env = _update_runtime_env_nsys(runtime_env)\n                worker = ray.remote(\n                    num_cpus=0,\n                    num_gpus=0.01,\n                    
scheduling_strategy=scheduling_strategy,\n                    runtime_env=runtime_env,\n                )(RayWorkerWrapper).remote(**worker_kwargs)\n            else:\n                worker = ray.remote(\n                    num_cpus=0,\n                    num_gpus=0,\n                    resources={device_str: 0.01},\n                    scheduling_strategy=scheduling_strategy,\n                )(RayWorkerWrapper).remote(**worker_kwargs)\n            workers.append(worker)\n        return workers\n\n    def _init_distributed_environment_by_device(self, device_str: str):\n        \"\"\"Init distributed environment.\"\"\"\n        driver_ip = _get_master_addr()\n        if device_str == 'cuda':\n            self.workers = self._sort_workers(driver_ip, self.workers)\n\n        elif device_str == 'ascend':\n            self._init_ascend_distributed_environment(driver_ip)\n        elif device_str in ['camb', 'maca']:\n            self.workers = self._sort_workers(driver_ip, self.workers)\n            ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])\n        else:\n            raise ValueError(f'Unsupported device type: {device_str}')\n\n    def _init_ascend_distributed_environment(self, driver_ip):\n        \"\"\"Init ascend distributed environment.\"\"\"\n        rank_table_file = _envs.ascend_rank_table_file\n        set_rt_visable_devices_by_ray = _envs.ascend_set_rt_visable_devices_by_ray\n\n        if rank_table_file:\n            # if rank table file is set, use it to get rank mapping, multiple nodes\n            rank_mapping, worker_ips, envs = get_ascend_device_rank_mapping(driver_ip)\n            rank_start = self.rank_offset\n            rank_end = rank_start + len(self.workers)\n            if rank_end > len(worker_ips):\n                raise ValueError(\n                    'Rank table world_size is smaller than required ranks for current dp_rank. 
'\n                    f'rank_table_world_size={len(worker_ips)}, required_rank_range=[{rank_start}, {rank_end})')\n\n            # In dp mode each process only owns a slice of global ranks.\n            expected_worker_ips = worker_ips[rank_start:rank_end]\n            self.workers = self._sort_workers_by_ip(expected_worker_ips, self.workers)\n\n            ray.get(\n                [worker.set_device.remote(rank_mapping[rank_start + idx]) for idx, worker in enumerate(self.workers)])\n            ray.get([worker.set_env.remote(envs) for worker in self.workers])\n        elif not set_rt_visable_devices_by_ray:\n            # no rank table file: treat as a single node with multiple devices\n            # and simply set each device by its index\n            self.workers = self._sort_workers(driver_ip, self.workers)\n            ray.get([worker.set_device.remote(idx + self.rank_offset) for idx, worker in enumerate(self.workers)])\n        else:\n            self.workers = self._sort_workers(driver_ip, self.workers)\n\n    \"\"\" PD Disaggregation API Begin \"\"\"\n\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        \"\"\"Init rdma link.\"\"\"\n        return self.collective_rpc('p2p_initialize', (init_request, ))\n\n    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):\n        \"\"\"Rdma connect.\"\"\"\n        return self.collective_rpc('p2p_connect', (\n            remote_engine_id,\n            conn_request,\n        ))\n\n    async def migrate(self, batch: MigrationExecutionBatch):\n        \"\"\"KV Cache Migration.\"\"\"\n        jobs = (worker.migrate.remote(batch) for worker in self.workers)\n        return await asyncio.gather(*jobs)\n\n    \"\"\" PD Disaggregation API End \"\"\"\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/executor/uni_executor.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nfrom typing import Dict, List\n\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig\nfrom lmdeploy.pytorch.devices import DeviceContext\nfrom lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo\nfrom lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch\nfrom lmdeploy.pytorch.engine.model_agent import build_model_agent\nfrom lmdeploy.utils import get_logger\n\nfrom .base import ExecutorBase\n\nlogger = get_logger('lmdeploy')\n\n\nclass UniExecutor(ExecutorBase):\n    \"\"\"Single node single device Executor.\"\"\"\n\n    def __init__(\n        self,\n        model_path: str,\n        model_config: ModelConfig,\n        cache_config: CacheConfig,\n        backend_config: BackendConfig,\n        misc_config: MiscConfig,\n        adapters: Dict[str, str] = None,\n        device_type: str = 'cuda',\n        specdecode_config: SpecDecodeConfig = None,\n    ):\n        \"\"\"Initialize Executor.\"\"\"\n        super().__init__(model_path=model_path,\n                         model_config=model_config,\n                         cache_config=cache_config,\n                         backend_config=backend_config,\n                         dist_config=DistConfig(),\n                         misc_config=misc_config,\n                         adapters=adapters,\n                         device_type=device_type,\n                         specdecode_config=specdecode_config)\n\n        self.device_ctx = DeviceContext(device_type=device_type)\n        self.model_agent = build_model_agent(\n            model_path=model_path,\n            model_config=model_config,\n            cache_config=cache_config,\n            backend_config=backend_config,\n            misc_config=misc_config,\n            device_ctx=self.device_ctx,\n            adapters=adapters,\n            
specdecode_config=specdecode_config,\n        )\n\n    def download_models(self):\n        \"\"\"Download model.\"\"\"\n        raise NotImplementedError('Not Implemented.')\n\n    def build_model(self):\n        \"\"\"Build model.\"\"\"\n        self.model_agent.build_model()\n\n    def gather_free_mem(self):\n        \"\"\"Gather available memory.\"\"\"\n        return [self.model_agent.get_free_mem()]\n\n    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):\n        \"\"\"Set all cache config.\"\"\"\n        self.model_agent.set_cache_config(cache_config, spec_cache_config)\n\n    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig):\n        \"\"\"Set all model config.\"\"\"\n        self.model_agent.set_model_config(model_config, spec_model_config)\n\n    def build_graph_runner(self):\n        \"\"\"Build graph runner.\"\"\"\n        self.model_agent.build_graph_runner()\n\n    def build_cache_engine(self):\n        \"\"\"Build cache engine.\"\"\"\n        self.model_agent.build_cache_engine()\n\n    def warmup(self):\n        \"\"\"Warm up model.\"\"\"\n        self.model_agent.warmup()\n\n    def start(self, forward_event: asyncio.Event):\n        \"\"\"Start engine loop.\"\"\"\n        self.model_agent.start(forward_event)\n\n    async def wait_tasks(self):\n        \"\"\"Wait tasks.\"\"\"\n        await self.model_agent.wait_tasks()\n\n    def stop(self):\n        \"\"\"Stop engine loop.\"\"\"\n        self.model_agent.stop()\n\n    def release(self):\n        \"\"\"Release resources.\"\"\"\n        self.model_agent.release()\n\n    async def forward_async(self, inputs):\n        \"\"\"Start forward.\"\"\"\n        self.model_agent.set_forward_inputs(inputs)\n        # switch to task: ModelAgent._async_loop_inputs_preprocess\n        await asyncio.sleep(0)\n\n    async def get_output_async(self, dp_rank: int = 0):\n        \"\"\"Get output async.\"\"\"\n        assert dp_rank == 0\n        return await self.model_agent.get_output_async()\n\n    def get_input_processor(self):\n        \"\"\"Get input processor.\"\"\"\n        return self.model_agent.get_input_processor()\n\n    \"\"\" PD Disaggregation API Begin \"\"\"\n\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        \"\"\"Init rdma link.\n\n        Note: returns a list so the result is composable with multi-process executors such as ray.\n        \"\"\"\n        return [self.model_agent.cache_engine.p2p_initialize(init_request)]\n\n    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):\n        \"\"\"Rdma connect.\"\"\"\n        self.model_agent.cache_engine.p2p_connect(remote_engine_id, conn_request)\n\n    async def migrate(self, batch: MigrationExecutionBatch):\n        \"\"\"KV Cache Migration.\"\"\"\n        return await self.model_agent.cache_engine.migrate(batch)\n\n    \"\"\" PD Disaggregation API End \"\"\"\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/guided_process.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport logging\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport torch\nimport xgrammar as xgr\nfrom transformers import PreTrainedTokenizerBase\n\nlogger = logging.getLogger('lmdeploy')\n\n\nclass GuidedDecodingManager:\n    processors = {}\n\n    def __init__(self, tokenizer: PreTrainedTokenizerBase, vocab_size: Optional[int]):\n        if vocab_size is None:\n            vocab_size = tokenizer.vocab_size\n\n        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)\n        self.compiler = xgr.GrammarCompiler(tokenizer_info)\n        self.vocab_size = vocab_size\n\n    def get_processors(self, session_ctx: List[Dict[str, Any]],\n                       response_formats: Tuple[Dict]) -> Dict[int, xgr.GrammarMatcher]:\n        processors = {}\n        for i, _format in enumerate(response_formats):\n            if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':\n                schema_type = _format['type']\n                if schema_type == 'json_schema':\n                    schema = _format['json_schema']\n                    if isinstance(schema, Dict):\n                        for key in ['json_schema', 'schema']:\n                            if key in schema:\n                                schema = json.dumps(schema[key], ensure_ascii=False)\n                                # schema is now a str; stop so `key in schema` is not a substring test\n                                break\n\n                    if not isinstance(schema, str):\n                        raise ValueError(f'Cannot parse schema {schema}. 
The schema must be '\n                                         'either a dictionary or a string that contains the'\n                                         ' JSON Schema specification')\n                elif schema_type == 'regex_schema':\n                    schema = _format.get('regex_schema', '')\n                elif schema_type == 'json_object':\n                    schema = '{\"type\" : \"object\", \"additionalProperties\": true}'\n                else:\n                    raise ValueError(f'unsupported format type: {schema_type}')\n\n                session_id = session_ctx[i]['session_id']\n                seq_id = session_ctx[i]['seq_id']\n\n                processors[i] = self.get_processor(session_id, seq_id, schema, schema_type)\n\n        return processors\n\n    def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) -> xgr.GrammarMatcher:\n        if session_id in self.processors:\n            session_dict = self.processors[session_id]\n            if seq_id in session_dict:\n                processor = session_dict[seq_id]\n                return processor\n\n        if type == 'json_schema':\n            if isinstance(schema, str):\n                schema = json.loads(schema)\n\n            assert isinstance(schema, dict)\n            compiled = self.compiler.compile_json_schema(schema)\n        elif type == 'regex_schema':\n            compiled = self.compiler.compile_regex(schema)\n        elif type == 'json_object':\n            compiled = self.compiler.compile_json_schema(schema)\n        else:\n            assert False, f'Do not support schema type {type}'\n\n        processor = xgr.GrammarMatcher(compiled, terminate_without_stop_token=True)\n        self.processors.setdefault(session_id, {})[seq_id] = processor\n        logger.info(f'create guided processor for session_id={session_id}, seq_id={seq_id}, and '\n                    f'total_processors={len(self.processors)}')\n        return processor\n\n    def 
remove_processor(self, session_id: int):\n        if session_id in self.processors:\n            del self.processors[session_id]\n            logger.info(\n                f'delete guided processor for session_id={session_id}, and total_processors={len(self.processors)}')\n\n    def allocate_batched_bitmap(self, batch_size: int) -> torch.Tensor:\n        return xgr.allocate_token_bitmask(batch_size, self.vocab_size)\n\n    def fill_bitmap(self, processor: xgr.GrammarMatcher, guided_bitmask: torch.Tensor, index: int) -> None:\n        processor.fill_next_token_bitmask(guided_bitmask, index)\n\n    def accept_token(self, processor: xgr.GrammarMatcher, token: int) -> None:\n        processor.accept_token(token)\n\n    def apply_batched_bitmap(self, logits: torch.Tensor, guided_bitmask: torch.Tensor) -> None:\n        device = logits.device\n        dtype = logits.dtype\n\n        if device.type in {'cpu', 'cuda'}:\n            xgr.apply_token_bitmask_inplace(logits, guided_bitmask.to(device))\n        else:\n            cpu_logits = logits.cpu().float()\n            cpu_mask = guided_bitmask.cpu()\n            xgr.apply_token_bitmask_inplace(cpu_logits, cpu_mask)\n            logits.copy_(cpu_logits.to(device, dtype))\n\n    def clear(self) -> None:\n        self.processors.clear()\n        logger.info(f'clear guided processors, total_processors={len(self.processors)}')\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/input_process.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalInputs\n\nTypeModelMetas = Dict[str, Any]\n\nInputMultiModalType = List[Dict[str, Any]]\n\n\n@dataclass\nclass PreprocessInputResult:\n    \"\"\"Results of preprocess input.\"\"\"\n    input_ids: List[int]\n    input_multimodals: Optional[MultiModalInputs] = None\n    model_metas: Optional[TypeModelMetas] = None\n\n\nclass BaseModelInputProcessor(ABC):\n    \"\"\"Processor of model inputs.\"\"\"\n\n    @abstractmethod\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_mms: InputMultiModalType = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Preprocess input.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n\nclass DefaultModelInputProcessor(BaseModelInputProcessor):\n    \"\"\"Default model input processor.\"\"\"\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_mms: MultiModalInputs = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Preprocess input.\"\"\"\n        return PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=input_mms,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/inputs_maker.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport logging\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, List, Optional\n\nimport numpy as np\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.disagg.config import EngineRole\nfrom lmdeploy.pytorch.messages import MessageStatus\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, VisionModelInputs\nfrom lmdeploy.utils import get_logger\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.adapter.adapter import AdapterManager\n    from lmdeploy.pytorch.messages import SchedulerSequence\n    from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs\n    from lmdeploy.pytorch.paging import Scheduler\n    from lmdeploy.pytorch.strategies.base.engine import EngineStrategy\n    from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy\n    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy\n\n    from .engine import Engine, SeqList\n    from .executor import ExecutorBase\n\nlogger = get_logger('lmdeploy')\n\n\ndef _tensorlize_block_offsets(block_offsets, dtype=torch.int32):\n    \"\"\"Tensorlize block_offsets.\"\"\"\n    # copy on numpy is faster than torch.nn.utils.rnn.pad_sequence\n    batch_size = len(block_offsets)\n    max_len = max([len(off) for off in block_offsets])\n    out = np.zeros((batch_size, max_len), dtype=block_offsets[0].dtype)\n\n    for idx, off in enumerate(block_offsets):\n        off_len = len(off)\n        out[idx, :off_len] = off\n    return torch.as_tensor(out, dtype=dtype)\n\n\n@dataclass\nclass InputsMakerConfig:\n    \"\"\"Input maker config.\n\n    This config is added for Dependency Injection\n    \"\"\"\n    max_batches: int\n    max_prefill_token_num: int\n    role: EngineRole\n    is_ssm: bool = False\n    dp: int = 1\n    spec_decoding: bool = False\n    enable_chunked_prefill: bool = False\n\n    @staticmethod\n  
  def from_engine(engine: 'Engine'):\n        cache_config = engine.cache_config\n        return InputsMakerConfig(\n            spec_decoding=engine.specdecode_config is not None,\n            max_batches=cache_config.max_batches,\n            max_prefill_token_num=cache_config.max_prefill_token_num,\n            role=cache_config.role,\n            is_ssm=len(cache_config.states_shapes) > 0,\n            dp=engine.dist_config.dp,\n            enable_chunked_prefill=engine.misc_config.enable_chunked_prefill,\n        )\n\n\nclass LongContextChunker:\n    \"\"\"Long context chunker.\"\"\"\n\n    def __init__(self, max_prefill_token_num: int):\n        self.max_prefill_token_num = max_prefill_token_num\n\n        # long prefill seq\n        self.clear()\n\n    def enabled(self):\n        \"\"\"Is enabled.\"\"\"\n        return self.seq is not None\n\n    def is_long_context(self, seq: 'SchedulerSequence'):\n        \"\"\"Is long context.\"\"\"\n        return seq.num_token_ids > self.max_prefill_token_num\n\n    def set_seq(self, seq: 'SchedulerSequence'):\n        \"\"\"Set seq.\"\"\"\n        self.seq = seq\n        self.next_step = seq.num_history_ids\n\n        # fill multimodals\n        # if image size exceeds max_prefill_token_num, enlarge it\n        max_prefill_num = self.max_prefill_token_num\n        mm = seq.get_input_multimodals()\n        self.multimodals = defaultdict(list)\n        for key, value in mm.items():\n            # sorted by start\n            value = sorted(value, key=lambda x: x.start)\n            self.multimodals[key] = value\n            max_mm_size = max([v.end - v.start for v in value], default=0)\n            max_prefill_num = max(max_prefill_num, max_mm_size)\n\n        self.max_prefill_num = max_prefill_num\n\n    def multimodal_iter(self):\n        \"\"\"Multimodal iterator.\"\"\"\n        multimodal_data = []\n        for modal_type, modal_datas in self.multimodals.items():\n            if len(modal_datas) == 0:\n               
 continue\n            multimodal_data += [(modal_type, data) for data in modal_datas]\n\n        multimodal_data = sorted(multimodal_data, key=lambda x: x[1].start)\n        for modal_type, data in multimodal_data:\n            yield modal_type, data\n\n    def next_chunk_size(self):\n        \"\"\"Get chunk size.\"\"\"\n        seq = self.seq\n        if seq is None:\n            return 0, None\n\n        llm_chunk_size = min(seq.num_token_ids, self.max_prefill_num)\n\n        if len(self.multimodals) == 0:\n            # no vlm inputs found\n            return llm_chunk_size, None\n\n        start = seq.num_history_ids\n        end = start + llm_chunk_size\n        out_multimodals: 'MultiModalInputs' = defaultdict(list)\n        for modal_type, mm in self.multimodal_iter():\n            assert mm.start >= start, 'multimodal data should be sorted by start'\n            if mm.start >= end:\n                # | start ... end ... mm.start ... mm.end |\n                # if start is beyond threshold, stop\n                break\n\n            if mm.end > end:\n                # | start ... mm.start ... end ... mm.end |\n                # assume multimodals not overlap\n                end = mm.start\n                break\n\n            # | start ... mm.start ... mm.end ... 
end |\n            out_multimodals[modal_type].append(mm)\n\n        return end - start, out_multimodals\n\n    def is_last_chunk(self):\n        \"\"\"Is last chunk.\"\"\"\n        if self.seq is None:\n            return True\n        return self.seq.num_token_ids <= self.max_prefill_num\n\n    def clear(self):\n        \"\"\"Clear.\"\"\"\n        self.seq: 'SchedulerSequence' = None\n        self.multimodals: MultiModalInputs = defaultdict(list)\n        self.next_step: int = 0\n        self.max_prefill_num: int = self.max_prefill_token_num\n\n    def update_step(self, inputs: ModelInputs):\n        \"\"\"Step chunker.\"\"\"\n        if self.seq is None:\n            return\n        if self.is_last_chunk():\n            # last chunk should be treated as normal prefill\n            return\n        assert inputs.is_chunk\n        chunk_size = inputs.max_q_seqlen\n        self.next_step += chunk_size\n        self.seq.set_step(self.next_step)\n\n        # remove used multimodals\n        for mms in self.multimodals.values():\n            while len(mms) > 0 and mms[0].end <= self.next_step:\n                mms.pop(0)\n        self.multimodals = dict((k, v) for k, v in self.multimodals.items() if len(v) > 0)\n\n    def check_enable(self):\n        if not self.enabled():\n            return\n        if self.seq.status != MessageStatus.RUNNING:\n            self.clear()\n\n\nclass InputsMakerAsync:\n\n    def __init__(\n        self,\n        executor: 'ExecutorBase',\n        scheduler: 'Scheduler',\n        adapter_manager: 'AdapterManager',\n        engine_strategy: 'EngineStrategy',\n        sampling_strategy: 'SamplingStrategy',\n        model_agent_strategy: 'ModelAgentStrategy',\n        config: InputsMakerConfig,\n    ):\n        self.executor = executor\n        self.scheduler = scheduler\n        self.adapter_manager = adapter_manager\n        self.config = config\n        self.spec_decoding = config.spec_decoding\n\n        # strategies\n        
self.engine_strategy = engine_strategy\n        self.sampling_strategy = sampling_strategy\n        self.model_agent_strategy = model_agent_strategy\n\n        self._init_do_prefill(config)\n\n        # record for next forward.\n        self.next_is_prefill = True\n        self.forward_inputs = None\n\n        # running seqs\n        # mark the seqs that have been sent to executor\n        self.running_seqs: List['SchedulerSequence'] = []\n        self.to_evict_seqs: List['SchedulerSequence'] = []\n\n        # long context chunker\n        self.long_context_chunker = LongContextChunker(config.max_prefill_token_num)\n\n    def _init_do_prefill(self, config: InputsMakerConfig):\n        if config.role == EngineRole.Prefill:\n            self.do_prefill = self.do_prefill_pnode\n        elif config.enable_chunked_prefill:\n            self.do_prefill = self.do_prefill_chunked\n        else:\n            self.do_prefill = self.do_prefill_default\n\n    def _create_vision_model_inputs(self, messages: 'SeqList', model_inputs: ModelInputs):\n        \"\"\"Create vision model inputs.\"\"\"\n        batch_size = len(messages)\n\n        def __get_vlm_embeddings():\n            \"\"\"Get vlm input embeddings and indexings.\"\"\"\n            max_q_seq_length = model_inputs.seq_length.max().item()\n            input_embeddings = [[\n                emb.embeddings if isinstance(emb.embeddings, torch.Tensor) else torch.as_tensor(emb.embeddings)\n                for emb in msg.input_embeddings\n            ] for msg in messages]\n            input_embedding_ranges = [\n                torch.tensor([[emb.start, emb.end] for emb in msg.input_embeddings]) for msg in messages\n            ]\n            input_embedding_indexing = torch.zeros((batch_size, max_q_seq_length), dtype=torch.bool)\n            for msg_id, msg in enumerate(messages):\n                num_history_ids = msg.num_history_ids\n                for emb in msg.input_embeddings:\n                    # make slice 
index relative to embeddings\n                    emb_start = emb.start - num_history_ids\n                    emb_end = emb.end - num_history_ids\n                    input_embedding_indexing[msg_id][emb_start:emb_end] = True\n            return (input_embeddings, input_embedding_indexing, input_embedding_ranges)\n\n        def __has_values(input_multimodals):\n            for input_mm in input_multimodals:\n                for val in input_mm.values():\n                    if len(val) > 0:\n                        return True\n            return False\n\n        has_embedding = any([len(msg.history_embeddings) > 0 for msg in messages])\n        if has_embedding:\n            has_embedding = any([len(msg.input_embeddings) > 0 for msg in messages])\n\n        has_multimodal = any([not msg.history_multimodals.empty() for msg in messages])\n        input_multimodals = None\n        if has_multimodal:\n            input_multimodals = [msg.get_input_multimodals() for msg in messages]\n            has_multimodal = __has_values(input_multimodals)\n            if not has_multimodal:\n                # no multimodal inputs\n                input_multimodals = None\n\n        if not has_embedding and not has_multimodal:\n            # no vision inputs\n            return None\n\n        if has_embedding:\n            # for inputs with embeddings\n            (input_embeddings, input_embedding_indexing, input_embedding_ranges) = __get_vlm_embeddings()\n        else:\n            input_embeddings = None\n            input_embedding_indexing = None\n            input_embedding_ranges = None\n\n        history_lengths = model_inputs.history_lengths\n        vision_embedding_inputs = VisionModelInputs(history_lengths=history_lengths,\n                                                    input_embeddings=input_embeddings,\n                                                    input_embedding_indexing=input_embedding_indexing,\n                                                    
input_embedding_ranges=input_embedding_ranges,\n                                                    input_multimodals=input_multimodals)\n        return vision_embedding_inputs\n\n    @property\n    def torch_int_dtype(self):\n        \"\"\"Return int32 for cuda, int64 for others.\"\"\"\n        if self.executor.device_type == 'cuda':\n            return torch.int32\n        return torch.int64\n\n    def _set_adapter_ids(self, model_inputs: ModelInputs, messages: 'SeqList'):\n        \"\"\"Set adapter ids to model inputs.\"\"\"\n        if self.adapter_manager.num_adapters() <= 1:\n            return\n        adapter_names = [msg.adapter_name for msg in messages]\n        local_adapter_ids = self.adapter_manager.get_adapter_ids(adapter_names)\n        local_adapter_ids = model_inputs.seq_length.new_tensor(local_adapter_ids)\n        model_inputs.local_adapter_ids = local_adapter_ids\n\n    @torch.inference_mode()\n    @record_function('create_model_inputs')\n    def create_model_inputs(self, messages: 'SeqList', is_prefill: bool):\n        \"\"\"Create model inputs from messages.\n\n        Args:\n            messages (SeqList): The input messages.\n        \"\"\"\n        batch_size = len(messages)\n        # history lengths\n        history_lengths = torch.tensor([msg.num_history_ids for msg in messages])\n\n        # input ids\n        token_ids = [msg.token_ids for msg in messages]\n\n        input_ids = torch.as_tensor(np.concatenate(token_ids))[None]\n\n        # seqlens\n        is_decoding = not is_prefill\n        if not is_decoding:\n            seq_length = [len(tokens) for tokens in token_ids]\n            seq_length = torch.tensor(seq_length, dtype=torch.long)\n            max_q_seqlen = seq_length.max().item()\n        else:\n            max_q_seqlen = len(token_ids[0])\n            seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long)\n        kv_seqlens = seq_length + history_lengths\n        max_kv_seqlen = 
kv_seqlens.max().item()\n        sum_kv_seqlen = kv_seqlens.sum().item()\n\n        # block offsets\n        block_offsets = self.scheduler.get_block_tables(messages)\n        block_offsets = _tensorlize_block_offsets(block_offsets, dtype=self.torch_int_dtype)\n\n        # num_ignored_history\n        num_ignored_history = torch.tensor([msg.num_ignored_history for msg in messages])\n\n        # model_metas\n        model_metas = [msg.model_meta for msg in messages]\n\n        # create model inputs for all required fields\n        model_inputs = ModelInputs(\n            input_ids=input_ids,\n            seq_length=seq_length,\n            history_lengths=history_lengths,\n            block_offsets=block_offsets,\n            is_decoding=is_decoding,\n            num_ignored_history=num_ignored_history,\n            max_q_seqlen=max_q_seqlen,\n            max_kv_seqlen=max_kv_seqlen,\n            sum_kv_seqlen=sum_kv_seqlen,\n            model_metas=model_metas,\n        )\n\n        # adapters\n        self._set_adapter_ids(model_inputs, messages)\n\n        # vision inputs\n        vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs)\n        model_inputs.vision_inputs = vision_model_inputs\n\n        # ssm\n        if self.config.is_ssm:\n            state_offsets = torch.tensor([msg.logical_state for msg in messages])\n            model_inputs.state_offsets = state_offsets\n\n        return model_inputs\n\n    @torch.inference_mode()\n    @record_function('create_model_inputs_long_context')\n    def create_model_inputs_long_context(self,\n                                         seq: 'SchedulerSequence',\n                                         chunk_size: int,\n                                         multimodals: Optional['MultiModalInputs'] = None):\n        \"\"\"Create model inputs for long context messages.\"\"\"\n        token_ids = seq.token_ids[:chunk_size]\n        input_ids = torch.as_tensor(token_ids)[None]\n        
q_seqlens = torch.tensor([chunk_size])\n        history_lens = torch.tensor([seq.num_history_ids])\n\n        # block offsets\n        block_offsets = self.scheduler.get_block_tables([seq])\n        block_offsets = torch.as_tensor(block_offsets[0], dtype=self.torch_int_dtype)[None]\n\n        # num_ignored_history\n        num_ignored_history = torch.tensor([seq.num_ignored_history])\n\n        # model_metas\n        model_metas = [seq.model_meta]\n\n        kv_seqlens = q_seqlens + history_lens\n        max_kv_seqlen = kv_seqlens.item()\n        sum_kv_seqlen = max_kv_seqlen\n\n        model_inputs = ModelInputs(\n            input_ids=input_ids,\n            seq_length=q_seqlens,\n            history_lengths=history_lens,\n            block_offsets=block_offsets,\n            is_decoding=False,\n            num_ignored_history=num_ignored_history,\n            max_q_seqlen=q_seqlens.item(),\n            max_kv_seqlen=max_kv_seqlen,\n            sum_kv_seqlen=sum_kv_seqlen,\n            model_metas=model_metas,\n            is_chunk=True,\n        )\n\n        # adapters\n        self._set_adapter_ids(model_inputs, [seq])\n\n        # vision inputs\n        if multimodals is not None and len(multimodals) > 0:\n            vision_model_inputs = VisionModelInputs(\n                history_lengths=model_inputs.history_lengths,\n                input_multimodals=[multimodals],\n            )\n            model_inputs.vision_inputs = vision_model_inputs\n\n        # ssm\n        if self.config.is_ssm:\n            model_inputs.state_offsets = torch.tensor([seq.logical_state])\n\n        return model_inputs\n\n    @torch.inference_mode()\n    @record_function('create_model_inputs_delta')\n    def create_model_inputs_delta(self):\n        \"\"\"Create model inputs delta from messages.\"\"\"\n        batch_size = len(self.running_seqs)\n        assert batch_size > 0\n        num_decode_tokens = self.engine_strategy.get_num_decode_tokens()\n        max_q_seqlen = 
num_decode_tokens\n        prealloc_size = self.engine_strategy.get_prealloc_size(True)\n        valid_mask = self.scheduler.schedule_running(self.running_seqs,\n                                                     num_decode_tokens=num_decode_tokens,\n                                                     prealloc_size=prealloc_size)\n\n        valid_mask = np.array(valid_mask)\n        indices_cpu = np.arange(0, batch_size)[valid_mask]\n        valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]\n        invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]\n        if len(valid_seqs) == 0:\n            return None, valid_seqs, invalid_seqs\n\n        # block offsets\n        block_offsets = self.scheduler.get_block_tables(valid_seqs)\n        block_offsets = _tensorlize_block_offsets(block_offsets, dtype=self.torch_int_dtype)\n\n        # sliding window\n        if self.scheduler.cache_config.window_size > 0:\n            num_ignored_history = torch.tensor([msg.num_ignored_history for msg in valid_seqs])\n        else:\n            num_ignored_history = torch.zeros(len(valid_seqs), dtype=torch.long)\n\n        kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs]\n        sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen\n        max_kv_seqlen = max(kv_seqlens) + max_q_seqlen\n\n        output = ModelInputsDelta(\n            indices=None,\n            block_offsets=block_offsets,\n            indice_cpu=indices_cpu,\n            max_q_seqlen=max_q_seqlen,\n            max_kv_seqlen=max_kv_seqlen,\n            sum_kv_seqlen=sum_kv_seqlen,\n            num_ignored_history=num_ignored_history,\n        )\n\n        return output, valid_seqs, invalid_seqs\n\n    def create_model_inputs_delta_valid_only(self):\n        \"\"\"Create model inputs delta for valid running seqs only.\n\n        Only check validation, no resources will be scheduled.\n        
\"\"\"\n        from lmdeploy.pytorch.messages import MessageStatus\n        batch_size = len(self.running_seqs)\n\n        valid_mask = [seq.status == MessageStatus.RUNNING for seq in self.running_seqs]\n        if all(valid_mask):\n            return None, self.running_seqs, []\n\n        valid_mask = np.array(valid_mask, dtype=bool)\n        indices_cpu = np.arange(0, batch_size)[valid_mask]\n        valid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in indices_cpu]\n        invalid_seqs: List['SchedulerSequence'] = [self.running_seqs[i] for i in range(batch_size) if not valid_mask[i]]\n\n        num_decode_tokens = self.engine_strategy.get_num_decode_tokens()\n        max_q_seqlen = num_decode_tokens\n        kv_seqlens = [seq.num_all_ids + max_q_seqlen for seq in valid_seqs]\n        if len(kv_seqlens) == 0:\n            sum_kv_seqlen = 0\n            max_kv_seqlen = 0\n        else:\n            sum_kv_seqlen = sum(kv_seqlens) + batch_size * max_q_seqlen\n            max_kv_seqlen = max(kv_seqlens) + max_q_seqlen\n\n        output = ModelInputsDelta(\n            indices=None,\n            block_offsets=None,\n            indice_cpu=indices_cpu,\n            max_q_seqlen=max_q_seqlen,\n            max_kv_seqlen=max_kv_seqlen,\n            sum_kv_seqlen=sum_kv_seqlen,\n            num_ignored_history=None,\n        )\n\n        return output, valid_seqs, invalid_seqs\n\n    def update_running_seqs(self, running: 'SeqList', inputs: Optional[ModelInputs]):\n        \"\"\"Update running seqs.\"\"\"\n        if self.config.role == EngineRole.Prefill:\n            # p node will not update running seqs\n            return\n\n        is_decoding = inputs is None\n        if self.long_context_chunker.enabled() and not is_decoding:\n            # long context chunk does not need to update running seqs\n            self.long_context_chunker.update_step(inputs)\n            return\n\n        if is_decoding:\n            self.running_seqs = running\n      
  else:\n            self.running_seqs += running\n\n    def deactivate_evict_seqs(self):\n        \"\"\"Deactivate and evict seqs.\"\"\"\n        scheduler = self.scheduler\n        to_evict_seqs = self.to_evict_seqs\n        if len(to_evict_seqs) == 0:\n            return\n        # deactivate seqs(running -> ready)\n        scheduler.deactivate_seqs(to_evict_seqs)\n        # ready to waiting\n        scheduler.evict_seqs(to_evict_seqs)\n        self.to_evict_seqs.clear()\n\n    @torch.inference_mode()\n    @record_function('make_forward_inputs')\n    def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False):\n        \"\"\"Make forward inputs for ModelAgent._async_step_background()\"\"\"\n\n        def __need_logits(seqs: 'SeqList'):\n            \"\"\"Need logits.\"\"\"\n            if self.spec_decoding:\n                return True\n            return any(seq.return_logits for seq in seqs)\n\n        def __need_routed_experts(seqs: 'SeqList'):\n            \"\"\"Need routed experts.\"\"\"\n            return any(seq.return_routed_experts for seq in seqs)\n\n        def __create_model_inputs(seqs):\n            \"\"\"Create model inputs.\"\"\"\n            inputs = self.create_model_inputs(seqs, True)\n            delta, valid_seqs, _ = self.create_model_inputs_delta_valid_only()\n            self.running_seqs = valid_seqs\n            extra_inputs = self.model_agent_strategy.make_extra_inputs(seqs, inputs)\n            return inputs, delta, extra_inputs\n\n        def __create_inputs_chunk(running: 'SeqList'):\n            chunk_size, multimodals = self.long_context_chunker.next_chunk_size()\n            inputs = self.create_model_inputs_long_context(running[0], chunk_size, multimodals)\n            extra_inputs = self.model_agent_strategy.make_extra_inputs(running, inputs)\n            return inputs, extra_inputs\n\n        def __create_inputs_long_context_chunk():\n            seq = self.long_context_chunker.seq\n            running = 
[seq]\n            if self.long_context_chunker.is_last_chunk():\n                inputs, delta, extra_inputs = __create_model_inputs(running)\n                self.long_context_chunker.clear()\n            else:\n                inputs, extra_inputs = __create_inputs_chunk(running)\n                delta = None\n            inputs.is_first_chunk = False\n            return running, inputs, delta, extra_inputs\n\n        def __create_inputs_prefill():\n            if self.config.role == EngineRole.Prefill:\n                prealloc_size = 0\n            else:\n                prealloc_size = self.engine_strategy.get_prealloc_size(True)\n            scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size)\n            running = scheduler_output.running\n            swap_in_map = scheduler_output.swap_in_map\n            swap_out_map = scheduler_output.swap_out_map\n\n            inputs = None\n            delta = None\n            extra_inputs = None\n            if len(running) == 1 and self.long_context_chunker.is_long_context(running[0]):\n                # set long context chunker\n                self.long_context_chunker.set_seq(running[0])\n                inputs, extra_inputs = __create_inputs_chunk(running)\n            elif len(running) > 0:\n                # create inputs\n                inputs, delta, extra_inputs = __create_model_inputs(running)\n            return running, inputs, delta, extra_inputs, swap_in_map, swap_out_map\n\n        scheduler = self.scheduler\n        logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}')\n\n        inputs = None\n        delta = None\n        swap_in_map = {}\n        swap_out_map = {}\n\n        self.long_context_chunker.check_enable()\n        if self.long_context_chunker.enabled():\n            # long context chunking\n            running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()\n        elif prefill:\n            # 
prefill\n            (\n                running,\n                inputs,\n                delta,\n                extra_inputs,\n                swap_in_map,\n                swap_out_map,\n            ) = __create_inputs_prefill()\n\n        # try decoding\n        if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill:\n            prefill = False\n            delta, running, invalid_seqs = self.create_model_inputs_delta()\n            self.to_evict_seqs = invalid_seqs\n            extra_inputs = None\n\n        # skip if enable empty\n        if inputs is None and delta is None:\n            return None\n\n        sampling_inputs = self.sampling_strategy.make_sampling_inputs(running)\n        if inputs is not None:\n            stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)\n        else:\n            stopping_criteria = None\n\n        return_logits = __need_logits(running)\n        return_routed_experts = __need_routed_experts(running)\n\n        return dict(\n            running=running,\n            inputs=inputs,\n            delta=delta,\n            swap_in_map=swap_in_map,\n            swap_out_map=swap_out_map,\n            sampling_inputs=sampling_inputs,\n            stopping_criteria=stopping_criteria,\n            return_logits=return_logits,\n            extra_inputs=extra_inputs,\n            return_routed_experts=return_routed_experts,\n        )\n\n    def do_prefill_pnode(self):\n        return True\n\n    def do_prefill_default(self):\n        # decoding if no waiting\n        scheduler = self.scheduler\n\n        # decode if no sequence is waiting\n        if not scheduler.has_waiting():\n            return False\n\n        # prefill if too many tokens are waiting\n        waiting = scheduler.waiting\n        token_count = 0\n        for seq in waiting:\n            token_count += seq.num_token_ids\n            if token_count >= self.config.max_prefill_token_num:\n                return 
True\n\n        # prefill if there are not enough ready and running seqs\n        num_ready = scheduler.num_ready()\n        num_running = scheduler.num_running()\n        max_batches = self.config.max_batches\n        if num_ready + num_running < max_batches * 0.5:\n            return True\n\n        # decoding\n        return False\n\n    def do_prefill_chunked(self):\n        \"\"\"Chunked prefill strategy.\n\n        Both dp=1 and dp>1 are supported.\n        \"\"\"\n        scheduler = self.scheduler\n        return not scheduler.has_ready()\n\n    async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool = False):\n        forward_inputs = self._make_forward_inputs(prefill, enable_empty)\n        if forward_inputs is None:\n            return None, None\n        next_running = forward_inputs.pop('running')\n        inputs = forward_inputs['inputs']\n        if logger.level <= logging.DEBUG and inputs is not None:\n            logger.debug(f'Sending forward inputs: {inputs.log_info()}')\n            session_ids = [seq.session_id for seq in next_running]\n            logger.debug(f'Forward session_ids: {session_ids}')\n        await self.executor.forward_async(forward_inputs)\n        self.forward_inputs = forward_inputs\n        return forward_inputs, next_running\n\n    async def send_next_inputs(self):\n        prefill = self.do_prefill()\n        return await self._send_next_inputs_impl(prefill)\n\n    async def prefetch_next_inputs(self):\n        prefill = self.do_prefill()\n        # send next forward\n        logger.debug('Prefetching next forward inputs.')\n        return await self._send_next_inputs_impl(prefill, True)\n\n\ndef build_inputs_maker(engine: 'Engine'):\n    \"\"\"Build inputs maker.\"\"\"\n    config = InputsMakerConfig.from_engine(engine)\n    return InputsMakerAsync(\n        executor=engine.executor,\n        scheduler=engine.scheduler,\n        adapter_manager=engine.adapter_manager,\n        engine_strategy=engine.engine_strategy,\n     
   sampling_strategy=engine.sampling_strategy,\n        model_agent_strategy=engine.model_agent_strategy,\n        config=config,\n    )\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/logits_process.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nfrom dataclasses import dataclass, fields\nfrom functools import lru_cache\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom lmdeploy.messages import LogitsProcessor\nfrom lmdeploy.pytorch import envs\n\nfrom ..messages import SchedulerSequence\nfrom .guided_process import GuidedDecodingManager\n\n\ndef _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor):\n    \"\"\"Process temperature.\"\"\"\n    temperature = temperature.to(scores.dtype)\n    scores.div_(temperature[:, None])\n    return scores\n\n\ndef _process_bad_words_(scores: torch.Tensor,\n                        bad_words: torch.Tensor,\n                        mask: torch.Tensor,\n                        filter_value: float = -float('inf')):\n    \"\"\"Apply bad-word filtering to token scores.\n\n    This function updates ``scores`` in place by setting the scores of\n    \"bad\" token indices to ``filter_value``.\n    Args:\n        scores (torch.Tensor): A tensor of shape ``[batch_size, vocab_size]``\n            containing the logits or scores for each token in the vocabulary.\n        bad_words (torch.Tensor): A tensor of shape\n            ``[batch_size, num_bad_words]`` containing token indices that\n            should be suppressed. Invalid or masked positions may contain\n            negative values; these entries are ignored and not used as\n            indices into ``scores``.\n        mask (torch.Tensor): A boolean tensor with the same shape as\n            ``bad_words``. Positions with ``True`` indicate that the\n            corresponding entry in ``bad_words`` is a valid bad-word index\n            that should be filtered. Positions with ``False`` are treated as\n            invalid/masked and are not applied to ``scores``.\n        filter_value (float, optional): The value to assign to the scores of\n            bad-word tokens. 
Defaults to ``-float('inf')``.\n    Returns:\n        torch.Tensor: The ``scores`` tensor after bad-word filtering has\n        been applied.\n    \"\"\"\n    # invalid bad words may be negative; map them to index 0 so gather/scatter indices stay valid\n    valid_bad_words = bad_words.where(mask, 0)\n    filtered_scores = scores.gather(1, valid_bad_words)\n    filtered_scores[mask] = filter_value\n    scores.scatter_(1, valid_bad_words, filtered_scores)\n    return scores\n\n\ndef _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor):\n    \"\"\"Process repetition penalty.\"\"\"\n    score = torch.gather(scores, 1, input_ids)\n    penalty = penalty.to(score.dtype)\n    score = torch.where(score < 0, score * penalty[:, None], score / penalty[:, None])\n    scores.scatter_(1, input_ids, score)\n    return scores\n\n\ndef _filter_topk_sorted_(scores: torch.Tensor, topk: torch.LongTensor, filter_value: float = -float('inf')):\n    \"\"\"Filter topk on sorted scores.\"\"\"\n    num_tokens = scores.size(1)\n    token_idx = torch.arange(num_tokens, device=scores.device)\n    mask = token_idx[None, :] >= topk[:, None]\n    scores.masked_fill_(mask, filter_value)\n    return scores\n\n\ndef _filter_topp_sorted_(scores: torch.Tensor, topp: torch.Tensor, filter_value: float = -float('inf')):\n    \"\"\"Filter topp on sorted scores.\"\"\"\n    softmax_scores = scores.softmax(-1)\n    cum_scores = softmax_scores.cumsum(1) - softmax_scores\n    mask = cum_scores > topp[:, None]\n    mask[:, 0] = False  # keep at least one\n    scores.masked_fill_(mask, filter_value)\n    return scores\n\n\ndef _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value: float = -float('inf')):\n    \"\"\"Filter minp on sorted scores.\"\"\"\n    softmax_scores = scores.softmax(-1)\n    top_probs, _ = softmax_scores.max(dim=-1, keepdim=True)\n    scaled_min_p = minp.unsqueeze(dim=1) * top_probs\n    mask = softmax_scores < scaled_min_p\n    
scores.masked_fill_(mask, filter_value)\n    return scores\n\n\n@lru_cache\ndef _ngram_one(dtype: torch.dtype, device: torch.device, fill: int = 1):\n    return torch.ones(fill, dtype=dtype, device=device)\n\n\ndef ngram(\n    token_ids: torch.Tensor,\n    n: torch.Tensor | None,\n    threshold: torch.Tensor,\n    max_n: int,\n    max_window_size: int,\n):\n    \"\"\"Compute n-gram matches between sliding windows and a target sequence.\n\n    For each batch, performs cosine similarity checking between:\n      - All sliding windows of length `max_n` from the full sequence\n      - The last `max_n` tokens of the sequence (target window)\n\n    A match is counted when both:\n      1. Cosine similarity ≈ 1 (normalized vectors match)\n      2. Vector lengths match (preventing zero/normalization artifacts)\n\n    Parameters\n    ----------\n    token_ids : torch.Tensor\n        Input token IDs of shape (batch_size, seq_len).\n        Values are typically ≥0 (0 may represent padding/special tokens).\n    n : torch.Tensor\n        Effective n-gram length for each batch element, shape (batch_size,).\n        When `same_n=False`, positions beyond `n` in the last `max_n` tokens are masked.\n    threshold : torch.Tensor\n        Minimum number of matching windows required for validity, shape (batch_size,).\n    max_n : int\n        Maximum n-gram length (window size for matching).\n    max_window_size: int\n        Maximum window size for matching.\n\n    Returns\n    -------\n    matched_mask : torch.Tensor\n        Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating\n        which sliding windows match the target n-gram.\n    found : torch.Tensor\n        Boolean tensor of shape (batch_size,) indicating whether each batch\n        element has at least `threshold` matches.\n    \"\"\"\n\n    batch_size, seq_len = token_ids.size()\n    if seq_len < max_n:\n        # Not enough tokens to form a single n-gram\n        matched_mask = torch.zeros((batch_size, 0), 
dtype=torch.bool, device=token_ids.device)\n        found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device)\n        return matched_mask, found\n    # token_ids may be 0; add 2 before log2 so every value is >= 1 and the norm below is never 0\n    token_ids = (token_ids + 2).to(torch.float32).log2()\n\n    # Trim to max_window_size\n    if seq_len >= max_window_size:\n        token_ids = token_ids[:, -max_window_size:]\n    max_window_size = token_ids.size(1)\n\n    # normalize ids\n    # n is None when all batch elements share the same n. See lmdeploy/pytorch/strategies/ar/sampling.py for details\n    same_n = n is None\n    norm = token_ids[:, -max_n:]\n    if not same_n:\n        # fill 0 for n < max_n\n        mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1))\n        norm = norm * mask.to(torch.float32)\n    norm = norm.norm(2, dim=-1, keepdim=True)\n    normed_ids = token_ids / norm\n\n    # concatenate p1 and p2 so the cosine similarity and the vector length are checked in one conv1d\n    normed_n_ids = normed_ids[:, -max_n:]\n    normed_ids_p2 = normed_ids * normed_ids\n    ones_ids = torch.ones_like(normed_n_ids)\n    if not same_n:\n        # fill 0 for n < max_n\n        normed_n_ids = normed_n_ids * mask.to(torch.float32)\n        ones_ids = ones_ids * mask.to(torch.float32)\n    normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0)\n    normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0)\n\n    # check cos distance & check vector length\n    match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0]\n    match_norm, match_ones = match_norm.chunk(2, dim=0)\n\n    # both match results should be close to 1\n    one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device, fill=1)\n    matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor)\n\n    # threshold\n    count = matched_mask.sum(-1)\n    found = (count >= threshold) & (threshold > 0)\n\n    return matched_mask, 
found\n\n\ndef _filter_repetition_ngram_(\n    scores: torch.Tensor,\n    stop_words: torch.Tensor,\n    generated_ids: torch.Tensor,\n    n: torch.Tensor | None,\n    threshold: torch.Tensor,\n    max_n: int,\n    max_ngram_window_size: int,\n):\n    \"\"\"Filter ngram.\n\n    if generated ngram found, set all scores -inf, and set stop words to 0. We assume that stop words always exist.\n    \"\"\"\n    if stop_words is None or stop_words.numel() == 0:\n        return scores\n    # use first stop words\n    _, found = ngram(generated_ids, n, threshold, max_n, max_ngram_window_size)\n    stop_words = stop_words[:, 0]\n    # fill all scores -inf\n    scores.masked_fill_(found[:, None], -float('inf'))\n    # set stop words to 0\n    stop_scores = scores.gather(1, stop_words[:, None])\n    stop_scores.masked_fill_(found[:, None], 0)\n    scores.scatter_(1, stop_words[:, None], stop_scores)\n    return scores\n\n\ndef _multinomial_sampling(scores: torch.Tensor,\n                          seeds: torch.LongTensor,\n                          offsets: torch.LongTensor,\n                          indices: torch.LongTensor = None):\n    \"\"\"sampling.\"\"\"\n    from lmdeploy.pytorch.nn.multinomial_sampling import multinomial_sampling\n    return multinomial_sampling(scores, seeds, offsets, indices)\n\n\nSeqList = list[SchedulerSequence]\n\n\n@dataclass\nclass SamplingInputsDelta:\n    num_ignore_eos: torch.Tensor = None\n    random_offsets: torch.Tensor = None\n    all_ids: None | torch.Tensor = None\n\n\n@dataclass\nclass SamplingInputs:\n    temperature: torch.Tensor = None\n    bad_words: torch.LongTensor = None\n    bad_mask: torch.BoolTensor = None\n    stop_words: torch.LongTensor = None\n    stop_mask: torch.BoolTensor = None\n    repetition_penalty: torch.Tensor = None\n    top_k: torch.LongTensor = None\n    top_p: torch.Tensor = None\n    min_p: torch.Tensor = None\n    random_seeds: torch.Tensor = None\n    random_offsets: torch.Tensor = None\n    max_top_k: int 
= 1\n    min_top_p: float = 1.0\n    response_formats: list[str, ...] = ()\n    logits_processors: list[list[LogitsProcessor]] = None\n    max_num_logprobs: None | int = None\n    all_ids: None | torch.Tensor = None\n    num_ignore_eos: torch.Tensor = None\n    batch_size: int = 0\n    session_ctx: None | list[dict[str, Any]] = None\n    session_to_cleanup: None | list[int] = None\n    # for repetition_penalty and ngram\n    generated_ids: torch.Tensor | None = None\n    generated_ids_cpu: np.ndarray | None = None\n\n    # n gram\n    repetition_ngram_size: torch.Tensor | None = None\n    repetition_ngram_threshold: torch.Tensor | None = None\n    max_repetition_ngram_size: int = 0\n\n    def to_device(self, device: str, non_blocking: bool = False):\n        \"\"\"To device.\"\"\"\n        out_dict = dict()\n        if self.generated_ids is None and self.generated_ids_cpu is not None:\n            self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if isinstance(v, torch.Tensor):\n                v = v.to(device, non_blocking=non_blocking)\n            out_dict[k] = v\n\n        return SamplingInputs(**out_dict)\n\n    def get_delta(self) -> SamplingInputsDelta:\n        \"\"\"Get delta.\"\"\"\n        delta = SamplingInputsDelta()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if isinstance(v, torch.Tensor):\n                setattr(delta, k, v)\n        return delta\n\n    def update_delta(self, delta: SamplingInputsDelta):\n        \"\"\"Update from delta.\"\"\"\n        for f in fields(delta):\n            k = f.name\n            v = getattr(delta, k)\n            if v is not None:\n                setattr(self, k, v)\n\n\ndef _apply_custom_logits_processors(batched_logits_processors, all_ids, logits):\n    \"\"\"Apply custom logits processors.\"\"\"\n    for seq_id, processors in 
enumerate(batched_logits_processors):\n        if processors is not None:\n            for processor in processors:\n                logits[seq_id] = processor(all_ids[seq_id], logits[seq_id])\n    return logits\n\n\ndef _torch_topk(x: torch.Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True):\n    if k == 1:\n        # torch.topk would not fallback to torch.max/torch.min automatically\n        if largest:\n            return torch.max(x, dim=dim, keepdim=True)\n        else:\n            return torch.min(x, dim=dim, keepdim=True)\n    else:\n        return torch.topk(x, k, dim=dim, largest=largest, sorted=sorted)\n\n\nclass FusedLogitsProcessor:\n    \"\"\"Custom logits processor.\"\"\"\n\n    def __init__(\n        self,\n        sampling_inputs: SamplingInputs,\n        logprobs_mode: None | str = None,\n        guided_decoding_manager: None | GuidedDecodingManager = None,\n    ):\n        self.sampling_inputs: SamplingInputs = sampling_inputs\n        self.logprobs_mode = logprobs_mode\n        self.guided_decoding_manager = guided_decoding_manager\n        if sampling_inputs.session_to_cleanup:\n            self.cleanup_sessions(sampling_inputs.session_to_cleanup)\n\n        if self.guided_decoding_manager:\n            self.guided_processors = self.guided_decoding_manager.get_processors(sampling_inputs.session_ctx,\n                                                                                 sampling_inputs.response_formats)\n        else:\n            self.guided_processors = {}\n\n    async def _wait_stream_once(self):\n        \"\"\"Wait stream once.\"\"\"\n        stream = torch.cuda.current_stream()\n        if not stream.query():\n            await asyncio.sleep(0)\n\n    async def __call__(self, scores: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Args:\n            scores (torch.Tensor):\n                Prediction scores of a language modeling head.\n                These can be logits for each vocabulary when 
not using\n                beam search or log softmax for each vocabulary token\n                when using beam search\n\n\n        Return:\n            torch.Tensor: The processed prediction scores.\n\n        \"\"\"\n\n        num_logprobs = self.sampling_inputs.max_num_logprobs\n        # get raw logprobs\n        if num_logprobs < 0:\n            logprobs = None\n        else:\n            if self.logprobs_mode == 'raw_logits':\n                logprobs = scores.clone()\n            elif self.logprobs_mode == 'raw_logprobs':\n                logprobs = scores.log_softmax(dim=-1)\n            else:\n                logprobs = None\n\n        sampling_inputs = self.sampling_inputs\n        all_ids = sampling_inputs.all_ids\n        custom_logits_processors = self.sampling_inputs.logits_processors\n        if self.guided_decoding_manager and self.guided_processors:\n            if not hasattr(self, 'guided_bitmask'):\n                self.guided_bitmask = self.guided_decoding_manager.allocate_batched_bitmap(len(scores))\n\n            assert self.guided_bitmask is not None\n            guided_bitmask = self.guided_bitmask\n\n            await self._wait_stream_once()\n            for i, processor in self.guided_processors.items():\n                self.guided_decoding_manager.fill_bitmap(processor, guided_bitmask, i)\n\n            self.guided_decoding_manager.apply_batched_bitmap(scores, guided_bitmask)\n\n        if any(custom_logits_processors):\n            await self._wait_stream_once()\n            scores = _apply_custom_logits_processors(custom_logits_processors, all_ids, scores)\n\n        repetition_penalty = sampling_inputs.repetition_penalty\n        if repetition_penalty is not None:\n            generated_ids = sampling_inputs.generated_ids\n            scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty)\n\n        if sampling_inputs.max_repetition_ngram_size > 0:\n            generated_ids = 
sampling_inputs.generated_ids\n            assert generated_ids is not None\n            assert sampling_inputs.repetition_ngram_threshold is not None\n            max_repetition_ngram_window_size = envs.repetition_window_size\n            scores = _filter_repetition_ngram_(\n                scores,\n                sampling_inputs.stop_words,\n                generated_ids,\n                sampling_inputs.repetition_ngram_size,\n                sampling_inputs.repetition_ngram_threshold,\n                sampling_inputs.max_repetition_ngram_size,\n                max_repetition_ngram_window_size,\n            )\n\n        temperature = sampling_inputs.temperature\n        if temperature is not None:\n            scores = _process_temperature_(scores, temperature)\n\n        bad_words = sampling_inputs.bad_words\n        if bad_words is not None:\n            bad_mask = sampling_inputs.bad_mask\n            scores = _process_bad_words_(scores, bad_words, bad_mask)\n\n        stop_words = sampling_inputs.stop_words\n        if stop_words is not None:\n            ignore_eos = sampling_inputs.num_ignore_eos > 0\n            stop_mask = sampling_inputs.stop_mask\n            stop_mask = torch.where(ignore_eos[:, None], stop_mask, False)\n            scores = _process_bad_words_(scores, stop_words, stop_mask)\n\n        return scores, logprobs\n\n    @torch.inference_mode()\n    def sampling(self, logits: torch.Tensor):\n        \"\"\"sampling.\"\"\"\n        sampling_inputs = self.sampling_inputs\n\n        def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):\n            \"\"\"Random sampling.\"\"\"\n            max_topk = sampling_inputs.max_top_k\n            top_k = sampling_inputs.top_k\n            if max_topk <= 0:\n                max_topk = scores.size(1)\n                if top_k is not None:\n                    top_k = torch.masked_fill(top_k, top_k <= 0, max_topk)\n\n            if top_k is not None:\n                scores = 
_filter_topk_sorted_(scores, top_k)\n\n            top_p = sampling_inputs.top_p\n            if top_p is not None:\n                scores = _filter_topp_sorted_(scores, top_p)\n\n            min_p = sampling_inputs.min_p\n            if min_p is not None:\n                scores = _filter_minp_sorted_(scores, min_p)\n\n            softmax_scores = scores.softmax(1)\n\n            seeds = sampling_inputs.random_seeds\n            offsets = sampling_inputs.random_offsets\n            return _multinomial_sampling(softmax_scores, seeds, offsets, indices)\n\n        if sampling_inputs.max_top_k == 1:\n            result = logits.argmax(-1)\n        else:\n            # sorting all logits is slow and only the top-k logits are needed, so fully sort only when max_top_k <= 0\n            max_topk = sampling_inputs.max_top_k\n            if max_topk <= 0:\n                scores, indices = logits.sort(1, descending=True)\n            else:\n                scores, indices = _torch_topk(logits, max_topk, dim=1)\n            result = __random_sampling(scores, indices)\n\n        if self.guided_decoding_manager and self.guided_processors:\n            for i, processor in self.guided_processors.items():\n                self.guided_decoding_manager.accept_token(processor, result[i])\n\n        return result\n\n    @torch.inference_mode()\n    def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor):\n        \"\"\"Compute logprobs.\"\"\"\n        if raw_logprobs is None:\n            return None\n\n        indices = token_ids.unsqueeze(-1)\n        logprobs = raw_logprobs.gather(-1, indices)\n        num_logprobs = self.sampling_inputs.max_num_logprobs\n        if num_logprobs > 0:\n            topk_logprobs, topk_indices = _torch_topk(raw_logprobs, num_logprobs, dim=-1)\n            logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)\n            indices = torch.cat([indices, topk_indices], dim=-1)\n\n        return logprobs, indices.to(torch.int32)\n\n    def cleanup_sessions(self, 
session_ids: list[int]):\n        if self.guided_decoding_manager:\n            for session_id in session_ids:\n                self.guided_decoding_manager.remove_processor(session_id)\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/model_agent/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict\n\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, MiscConfig, ModelConfig, SpecDecodeConfig\nfrom lmdeploy.pytorch.devices import DeviceContext, get_device_manager\nfrom lmdeploy.pytorch.distributed import DistContext, get_dist_manager\n\nfrom .agent import BaseModelAgent, BatchedOutputs  # noqa: F401\n\n\ndef build_model_agent(\n    model_path: str,\n    model_config: ModelConfig,\n    cache_config: CacheConfig,\n    backend_config: BackendConfig,\n    misc_config: MiscConfig,\n    dist_ctx: DistContext = None,\n    device_ctx: DeviceContext = None,\n    adapters: Dict[str, str] = None,\n    specdecode_config: SpecDecodeConfig = None,\n):\n    \"\"\"Create model agent.\n\n    Args:\n        model_path (str): the path of the input model\n        model_config (ModelConfig): config of the model\n        cache_config (CacheConfig): config of kv cache\n        backend_config (BackendConfig): config of backend devices\n        misc_config (MiscConfig): miscellaneous config\n        dist_ctx (DistContext): distributed context, defaults to the current dist context\n        device_ctx (DeviceContext): device context, defaults to the current device context\n        adapters (Dict[str, str]): lora adapters\n        specdecode_config (SpecDecodeConfig): config of speculative decoding\n    \"\"\"\n\n    if device_ctx is None:\n        device_mgr = get_device_manager()\n        device_ctx = device_mgr.current_context()\n    if dist_ctx is None:\n        dist_mgr = get_dist_manager()\n        dist_ctx = dist_mgr.current_context()\n\n    model_agent = BaseModelAgent(\n        model_path,\n        model_config=model_config,\n        cache_config=cache_config,\n        backend_config=backend_config,\n        misc_config=misc_config,\n        adapters=adapters,\n        dist_ctx=dist_ctx,\n        device_ctx=device_ctx,\n        specdecode_config=specdecode_config,\n    )\n    return model_agent\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/model_agent/agent.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport time\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, field, fields\nfrom multiprocessing.reduction import ForkingPickler\nfrom os import getenv\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\nimport pybase64\nimport torch\nimport torch.distributed as dist\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.backends import get_backend\nfrom lmdeploy.pytorch.config import BackendConfig, CacheConfig, MiscConfig, ModelConfig, SpecDecodeConfig\nfrom lmdeploy.pytorch.devices import DeviceContext, get_device_manager\nfrom lmdeploy.pytorch.disagg.config import EngineRole\nfrom lmdeploy.pytorch.distributed import DistContext, get_dist_manager\nfrom lmdeploy.pytorch.engine.cache_engine import CacheEngine, StateCacheEngine\nfrom lmdeploy.pytorch.engine.guided_process import GuidedDecodingManager\nfrom lmdeploy.pytorch.engine.logits_process import FusedLogitsProcessor, SamplingInputs, SamplingInputsDelta\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, step_ctx_manager\nfrom lmdeploy.pytorch.models.patch import BuildModelContext, add_adapters, build_patched_model, update_custom_module_map\nfrom lmdeploy.pytorch.spec_decode import build_spec_agent\nfrom lmdeploy.pytorch.strategies import build_strategy_factory\nfrom lmdeploy.pytorch.strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria\nfrom lmdeploy.pytorch.utils import get_gpu_memory, monkey_patch_hf_modules_cache, wait_for_async_tasks\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import ModelWeightLoader, load_model_weights\nfrom lmdeploy.serve.openai.protocol import UpdateParamsRequest\nfrom lmdeploy.tokenizer import Tokenizer\nfrom lmdeploy.utils import FlattenedTensorBucket, FlattenedTensorMetadata, get_logger\n\nfrom .inputs_maker import build_inputs_maker\nfrom .profiler import AgentProfiler\n\nlogger = 
get_logger('lmdeploy')\n\n\n@dataclass\nclass SleepWakeupState:\n    to_sleep: asyncio.Event = field(default_factory=asyncio.Event)\n    to_wakeup: asyncio.Event = field(default_factory=asyncio.Event)\n    is_sleeping: bool = False\n\n\n@dataclass\nclass BatchedLogProbs:\n    vals: torch.Tensor\n    indices: torch.Tensor\n\n    def to_cpu(self):\n        \"\"\"To cpu.\"\"\"\n        return BatchedLogProbs(vals=self.vals.cpu(), indices=self.indices.cpu())\n\n    def to_numpy(self):\n        \"\"\"To numpy.\"\"\"\n        if self.vals.dtype == torch.bfloat16:\n            np_vals = self.vals\n        else:\n            np_vals = self.vals.detach().numpy()\n        return BatchedLogProbs(vals=np_vals, indices=self.indices.detach().numpy())\n\n    def to_tensor(self):\n        \"\"\"To tensor.\"\"\"\n        if isinstance(self.vals, torch.Tensor):\n            vals = self.vals\n        else:\n            vals = torch.from_numpy(self.vals)\n        return BatchedLogProbs(vals=vals, indices=torch.from_numpy(self.indices))\n\n\n@dataclass\nclass BatchedOutputs:\n    next_token_ids: torch.Tensor\n    stopped: torch.Tensor\n    stop_pos: Optional[torch.Tensor] = None\n    logits: Optional[torch.Tensor] = None\n    model_metas: List[Dict[str, Any]] = None\n    logprobs: Optional[BatchedLogProbs] = None\n    new_token_timestamp: int = 0\n    extra_outputs: Optional[ExtraOutputs] = None\n    all_routed_experts: Optional[torch.Tensor] = None\n\n    def to_cpu(self):\n        \"\"\"To cpu.\"\"\"\n        out = dict()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if isinstance(v, torch.Tensor):\n                v = v.cpu()\n            elif hasattr(v, 'to_cpu'):\n                v = v.to_cpu()\n            out[k] = v\n        return BatchedOutputs(**out)\n\n    def to_numpy(self):\n        \"\"\"To numpy.\"\"\"\n        out = dict()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n  
          if isinstance(v, torch.Tensor) and v.dtype != torch.bfloat16:\n                v = v.detach().numpy()\n            elif hasattr(v, 'to_numpy'):\n                v = v.to_numpy()\n            out[k] = v\n        return BatchedOutputs(**out)\n\n    def to_tensor(self):\n        \"\"\"To tensor.\"\"\"\n        out = dict()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if isinstance(v, np.ndarray):\n                v = torch.from_numpy(v)\n            elif hasattr(v, 'to_tensor'):\n                v = v.to_tensor()\n            out[k] = v\n        return BatchedOutputs(**out)\n\n\ndef msg_with_rank(rank: int, msg: str):\n    \"\"\"Return message with rank.\"\"\"\n    return f'rank[{rank}] - {msg}'\n\n\ndef cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):\n    \"\"\"Perform cache swapping.\"\"\"\n    issued_cache_op = False\n    swap_in_map = swap_in_map or dict()\n    swap_out_map = swap_out_map or dict()\n    if len(swap_in_map) > 0:\n        cache_engine.swap_in(swap_in_map)\n        issued_cache_op = True\n    if len(swap_out_map) > 0:\n        cache_engine.swap_out(swap_out_map)\n        issued_cache_op = True\n\n    if issued_cache_op:\n        cache_engine.events.wait()\n\n\n@torch.inference_mode()\ndef model_forward(\n    model: torch.nn.Module,\n    inputs: ModelInputs,\n    cache_engine: CacheEngine,\n    state_cache_engine: StateCacheEngine,\n    stream: torch.cuda.Stream = None,\n):\n    \"\"\"Perform model forward.\"\"\"\n    stream = stream or torch.cuda.current_stream()\n    with torch.cuda.stream(stream), step_ctx_manager(model.ctx_mgr):\n        # forward\n        ctx_mgr = model.ctx_mgr\n        context = ctx_mgr.build_context(\n            inputs=inputs,\n            model_config=cache_engine.model_config,\n            cache_config=cache_engine.cache_config,\n            kv_caches=cache_engine.gpu_cache,\n            
state_caches=state_cache_engine.state_caches,\n            kv_quant_policy=cache_engine.cache_config.quant_policy,\n        )\n\n        with ctx_mgr.context(context):\n            model_metas = model.update_model_metas(\n                past_key_values=cache_engine.gpu_cache,\n                context=context,\n            )\n            input_dict = model.prepare_inputs_for_generation(\n                past_key_values=cache_engine.gpu_cache,\n                context=context,\n            )\n            output = model(**input_dict)\n            if not isinstance(output, Dict):\n                output = dict(hidden_states=output)\n            # InternVL-3.5-Flash will change the seqlen, model_metas during forward\n            if getattr(context, 'is_model_meta_updated', False):\n                model_metas = context.model_metas\n            output['model_metas'] = model_metas\n            output['seq_length'] = context.q_seqlens[:len(inputs.seq_length)]\n            # for draft model reuse\n            output['position_ids'] = context.position_ids\n            return output\n\n\ndef _try_to_cuda(val, non_blocking: bool = False):\n    if val is None:\n        return val\n    elif isinstance(val, torch.Tensor):\n        return val.cuda(non_blocking=non_blocking)\n    elif hasattr(val, 'to_device'):\n        return val.to_device('cuda', non_blocking=non_blocking)\n    else:\n        raise RuntimeError(f'Can not cast {type(val)} to cuda.')\n\n\nclass DistGatherScalar:\n    \"\"\"Distribute value gather.\"\"\"\n\n    def __init__(self, val, size: int, device: str = 'cpu', group: dist.ProcessGroup = None):\n        self.val = val\n        self.device = device\n        self.group = group\n\n        self.all_vals = torch.tensor([val] * size, device=device)\n        self.worker = dist.all_gather_into_tensor(self.all_vals,\n                                                  self.all_vals.new_tensor([val]),\n                                                  group=group,\n       
                                           async_op=True)\n\n    async def async_wait(self, timeout: float = 0.001):\n        while not self.worker.is_completed():\n            await asyncio.sleep(timeout)\n        self.worker.wait()\n        return self.all_vals\n\n\nSwapMap = Dict[int, int]\n\n\n@dataclass\nclass StepInputs:\n    \"\"\"Step inputs.\"\"\"\n    model_inputs: ModelInputs = None\n    extra_inputs: ExtraInputs = None\n    stopping_criteria: StoppingCriteria = None\n    sampling_delta: SamplingInputsDelta = None\n\n    @record_function('StepInputs.merge')\n    def merge(\n        self,\n        inputs: ModelInputs,\n        extra_inputs: ExtraInputs,\n        stopping_criteria: StoppingCriteria,\n        sampling_delta: SamplingInputsDelta,\n        next_token_ids: torch.Tensor,\n        model_metas,\n        extra_outputs: ExtraOutputs,\n        model_agent: 'BaseModelAgent',\n    ):\n        \"\"\"Merge prefill inputs.\"\"\"\n        inputs, extra_inputs = model_agent.agent_strategy.update_prefill_for_next_step(\n            inputs,\n            extra_inputs,\n            next_token_ids,\n            model_metas,\n            extra_outputs,\n        )\n        stopping_criteria = stopping_criteria.clone()\n        sampling_delta = model_agent.sampling_strategy.step_sampling_delta(sampling_delta,\n                                                                           next_token_ids,\n                                                                           extra_inputs=extra_inputs)\n        if self.model_inputs is None:\n            self.model_inputs = inputs\n            self.extra_inputs = extra_inputs\n            self.stopping_criteria = stopping_criteria\n            self.sampling_delta = sampling_delta\n        else:\n            self.model_inputs = model_agent.inputs_strategy.merge(self.model_inputs, inputs)\n            self.extra_inputs = self.extra_inputs.merge(extra_inputs)\n            self.stopping_criteria = 
self.stopping_criteria.merge(stopping_criteria)\n            self.sampling_delta = model_agent.sampling_strategy.merge_sampling_delta(\n                self.sampling_delta, sampling_delta)\n\n    def update_delta(\n        self,\n        delta: ModelInputsDelta,\n        model_agent: 'BaseModelAgent',\n    ):\n        \"\"\"Get inputs from delta.\"\"\"\n        self.model_inputs = model_agent.inputs_strategy.update_inputs(self.model_inputs, delta)\n        self.extra_inputs = model_agent.agent_strategy.update_extra_inputs(self.extra_inputs, delta)\n        self.stopping_criteria = self.stopping_criteria.update(delta)\n        self.sampling_delta = model_agent.sampling_strategy.update_sampling_delta(self.sampling_delta, delta)\n\n    @record_function('StepInputs.step')\n    def step(\n        self,\n        model_inputs: ModelInputs,\n        extra_inputs: ExtraInputs,\n        stopping_criteria: StoppingCriteria,\n        sampling_delta: SamplingInputsDelta,\n        next_token_ids: torch.Tensor,\n        model_metas,\n        extra_outputs: ExtraOutputs,\n        model_agent: 'BaseModelAgent',\n    ):\n        \"\"\"Update inputs.\"\"\"\n        # dp might change is_decoding of decoding inputs\n        model_inputs.is_decoding = True\n        (\n            self.model_inputs,\n            self.extra_inputs,\n        ) = model_agent.agent_strategy.update_decoding_for_next_step(\n            model_inputs,\n            next_token_ids=next_token_ids,\n            model_metas=model_metas,\n            extra_inputs=extra_inputs,\n            extra_outputs=extra_outputs,\n        )\n        self.stopping_criteria = stopping_criteria.clone()\n        self.sampling_delta = model_agent.sampling_strategy.step_sampling_delta(sampling_delta,\n                                                                                next_token_ids,\n                                                                                extra_inputs=extra_inputs)\n\n\nclass BaseModelAgent:\n    
\"\"\"Base model agent.\n\n    load model on local gpu\n\n    Args:\n        model_path (str): The hugging face model path.\n        model_config (ModelConfig): The config of the model.\n        cache_config (CacheConfig): The config of the cache info.\n        trust_remote_code (bool): Trust remote code\n    \"\"\"\n\n    def __init__(\n        self,\n        model_path: str,\n        model_config: ModelConfig,\n        cache_config: CacheConfig,\n        backend_config: BackendConfig,\n        misc_config: MiscConfig,\n        dist_ctx: DistContext,\n        device_ctx: DeviceContext,\n        adapters: Dict[str, str] = None,\n        specdecode_config: SpecDecodeConfig = None,\n    ):\n\n        self.model_config = model_config\n        self.cache_config = cache_config\n        # use raw tokenizer\n        if dist_ctx.dist_config.world_size > 1:\n            monkey_patch_hf_modules_cache()\n        self.tokenizer = Tokenizer(model_path).model.model\n\n        # asyncio\n        self._pre_in_que = None\n        self._in_que = None\n        self._out_que = None\n        self._background_task = None\n        self._preprocess_task = None\n        self.tasks = set()\n\n        # cuda stream\n        self.stream = torch.cuda.Stream()\n        self.out_stream = torch.cuda.Stream()\n        self.cache_stream = torch.cuda.Stream()\n\n        self.dist_ctx = dist_ctx\n        self.device_ctx = device_ctx\n\n        device = 'cuda'\n        self.backend_config = backend_config\n        self.misc_config = misc_config\n        self.dist_config = dist_ctx.dist_config\n        rank = dist_ctx.rank\n\n        self.model_path = model_path\n        self.adapters = adapters\n        self.device = device\n        self.rank = rank\n\n        tp = self.dist_config.tp\n        world_size = self.dist_config.world_size\n        self.tp = tp\n        self.world_size = world_size\n        self.need_output = rank % self.dist_config.attn_tp == 0\n\n        self.patched_model = None\n        
self.cache_engine = None\n        self.state_cache_engine = None\n        self.profiler: AgentProfiler = None\n        try:\n            self.guided_decoding_manager = GuidedDecodingManager(self.tokenizer, model_config.vocab_size)\n        except ValueError as e:\n            logger.warning(f'Failed to create GuidedManager for tokenizer {type(self.tokenizer)}: {e}')\n            self.guided_decoding_manager = None\n\n        # microbatch\n        self.enable_microbatch = self.dist_config.enable_microbatch\n        self.enable_microbatch_prefill_batchsize_threshold = \\\n            int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2))\n        self.enable_microbatch_prefill_token_threshold = \\\n            int(getenv('ENABLE_MICROBATCH_PREFILL_TOKEN_THRESHOLD', 2))\n        self.enable_microbatch_decode_batchsize_threshold = \\\n            int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2))\n\n        # strategy\n        self.strategy_factory = build_strategy_factory(model_config, misc_config, specdecode_config=specdecode_config)\n        self.inputs_strategy = self.strategy_factory.build_model_inputs_strategy()\n        self.agent_strategy = self.strategy_factory.build_model_agent_strategy()\n        self.sampling_strategy = self.strategy_factory.build_sampling_strategy()\n\n        # spec decoding\n        self.spec_agent = build_spec_agent(specdecode_config,\n                                           backend_config,\n                                           dist_ctx,\n                                           self.inputs_strategy,\n                                           self.agent_strategy,\n                                           device=device)\n        # sleep wakeup state\n        self.state: SleepWakeupState = SleepWakeupState()\n\n        # decoding inputs\n        self.step_inputs = StepInputs()\n\n        # long context\n        self._prev_chunk_output: Dict = None\n\n    @contextmanager\n    def all_context(self):\n      
  device_mgr = get_device_manager()\n        dist_mgr = get_dist_manager()\n        with device_mgr.context(self.device_ctx), dist_mgr.context(self.dist_ctx), torch.inference_mode():\n            yield\n\n    def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None):\n        \"\"\"Set all cache config.\"\"\"\n        self.cache_config = cache_config\n        self.spec_agent.set_cache_config(spec_cache_config)\n\n    def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None):\n        \"\"\"Set model config.\"\"\"\n        self.model_config = model_config\n        self.spec_agent.set_model_config(spec_model_config)\n\n    def get_free_mem(self):\n        \"\"\"Gather available memory.\"\"\"\n        with self.all_context():\n            torch.cuda.empty_cache()\n            gpu_mem_physical_free, _ = get_gpu_memory()\n            return gpu_mem_physical_free\n\n    def warmup(self):\n        \"\"\"warmup.\"\"\"\n        from lmdeploy.pytorch.envs import skip_warmup\n        if skip_warmup:\n            return\n\n        with self.all_context(), torch.cuda.stream(self.stream):\n            max_batches = self.cache_config.max_batches\n            world_size = self.dist_config.world_size\n\n            num_tokens = max_batches\n            dp = self.dist_config.dp\n\n            if dp > 1:\n                # make sure warmup started together\n                group = self.dist_ctx.cpu_group\n                dist.barrier(group=group)\n\n            # warmup prefill\n            inputs = self.inputs_strategy.make_dummy(max_batches,\n                                                     is_decoding=False,\n                                                     device='cuda',\n                                                     vocab_size=self.model_config.vocab_size)\n            if dp > 1:\n                num_tokens = inputs.input_ids.numel()\n                inputs.build_dp_meta([num_tokens] * 
world_size)\n            logger.debug('Warmup prefill start.')\n            self._forward_impl(inputs)\n            torch.cuda.synchronize()\n            logger.debug('Warmup prefill done.')\n\n            # warmup decoding(with cuda graph)\n            capture_batch_sizes = self.patched_model.get_capture_batch_sizes()\n            capture_batch_sizes = sorted(capture_batch_sizes, reverse=True)\n            if self.cache_config.role == EngineRole.Prefill:\n                # do not warmup decoding for prefill engine\n                capture_batch_sizes = []\n            for num_tokens in capture_batch_sizes:\n                inputs = self.inputs_strategy.make_dummy(num_tokens,\n                                                         is_decoding=True,\n                                                         device='cuda',\n                                                         vocab_size=self.model_config.vocab_size)\n                if dp > 1:\n                    num_tokens = inputs.input_ids.numel()\n                    inputs.build_dp_meta([num_tokens] * world_size)\n                logger.debug(f'Warmup decoding num_tokens={num_tokens} start.')\n                self._forward_impl(inputs)\n                torch.cuda.synchronize()\n                logger.debug(f'Warmup decoding num_tokens={num_tokens} done.')\n\n            # warmup draft model\n            self.spec_agent.warmup(max_batches, self.model_config)\n\n    def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor):\n        \"\"\"Slice outputs.\"\"\"\n        return self.agent_strategy.slice_outputs(inputs, seq_length)\n\n    def _postprocess_forward_output(self, output: dict, inputs: ModelInputs):\n        \"\"\"Post process forward output.\"\"\"\n        hidden_states = output['hidden_states']\n        seq_length = output.get('seq_length', inputs.seq_length)\n        hidden_states = self._slice_outs(hidden_states[0], seq_length)[None]\n        output['hidden_states'] = 
hidden_states\n        return output\n\n    async def _async_model_forward(\n        self,\n        inputs: ModelInputs,\n        return_logits: bool,\n    ):\n        \"\"\"Model forward.\"\"\"\n        origin_inputs = inputs\n        ret = await self.async_forward(inputs)\n\n        if not return_logits:\n            ret = self._postprocess_forward_output(ret, origin_inputs)\n\n        hidden_states, ret = self.spec_agent.update_main_model_outputs(ret, origin_inputs)\n\n        logits = self.get_logits(hidden_states)\n        ret['logits'] = logits\n        return ret\n\n    async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: SamplingInputs):\n        \"\"\"Sampling logits.\"\"\"\n\n        # record function does not support async function\n        # so we can not decorate it on async_sampling_logits\n        with record_function('sampling_logits'):\n            logits_processor = FusedLogitsProcessor(\n                sampling_inputs,\n                logprobs_mode=self.misc_config.logprobs_mode,\n                guided_decoding_manager=self.guided_decoding_manager,\n            )\n            origin_logits = logits\n            logits, raw_logprobs = await logits_processor(origin_logits)\n            next_token_ids = logits_processor.sampling(logits)\n            logprobs = logits_processor.compute_logprobs(raw_logprobs, next_token_ids)\n            if logprobs is not None:\n                logprobs = BatchedLogProbs(\n                    vals=logprobs[0],\n                    indices=logprobs[1],\n                )\n\n        return next_token_ids, logprobs\n\n    def _push_output(self, output: BatchedOutputs):\n        \"\"\"Push output.\"\"\"\n        event = torch.cuda.Event()\n        event.record()\n        self._out_que.put_nowait((output, event))\n\n    @contextmanager\n    def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, enable: bool = True):\n        if not enable:\n            yield\n   
         return\n\n        dist_ctx = self.dist_ctx\n        with self.agent_strategy.broadcast_next_token(next_token_ids, extra_inputs, dist_ctx) as handle:\n            yield handle\n\n    @record_function('prepare_dp')\n    async def _prepare_dp_v1(self, inputs: ModelInputs):\n        \"\"\"Prepare dp.\n\n        If all inputs are dummy inputs, skip forward. If any of the inputs is prefill, then do prefill. Set padding\n        batch size for decoding.\n        \"\"\"\n        world_size = self.dist_config.world_size\n        is_decoding = inputs.is_decoding\n        num_tokens = inputs.input_ids.numel()\n        is_dummy = inputs.is_dummy\n\n        # gather dp forward metadata\n        batch_size = inputs.seq_length.numel()\n        is_sleeping = self.state.is_sleeping\n        dp_forward_meta = [int(is_decoding), int(is_dummy), num_tokens, int(is_sleeping)]\n        # check enable_microbatch\n        if self.enable_microbatch:\n            tokens_num = inputs.input_ids.numel()\n            if is_decoding:\n                enable_microbatch = batch_size >= \\\n                    self.enable_microbatch_decode_batchsize_threshold\n            else:\n                enable_microbatch = batch_size >= \\\n                    self.enable_microbatch_prefill_batchsize_threshold and \\\n                    tokens_num >= self.enable_microbatch_prefill_token_threshold\n            dp_forward_meta.append(int(enable_microbatch))\n        group = self.dist_ctx.cpu_group\n        device = 'cpu'\n        gathered_meta = DistGatherScalar(dp_forward_meta, world_size, device=device, group=group)\n        gathered_meta = (await gathered_meta.async_wait()).cpu()\n\n        # check is_decoding\n        # if any one of the rank is prefill, then all ranks are prefill\n        is_decoding = gathered_meta[:, 0].all().item()\n        inputs.is_decoding = is_decoding\n\n        # check if all inputs are dummy inputs\n        is_all_dummy = gathered_meta[:, 1].all().item()\n        
is_all_sleeping = gathered_meta[:, 3].all().item()\n        if is_all_dummy:\n            return None, is_all_sleeping\n\n        # pad batch size for decoding\n        all_num_tokens = gathered_meta[:, 2].tolist()\n        if is_decoding:\n            max_num_tokens = max(all_num_tokens)\n            meta = self.patched_model.get_meta()\n            meta.padding_batch_size = max_num_tokens\n            logger.debug(f'max_num_tokens={max_num_tokens}')\n\n        # update if enable_microbatch\n        if self.enable_microbatch:\n            inputs.enable_microbatch = gathered_meta[:, 4].all().item()\n\n        # update dp meta\n        inputs.build_dp_meta(all_num_tokens)\n        inputs = self.patched_model.update_inputs(inputs)\n        return inputs, is_all_sleeping\n\n    def _get_inputs_from_delta(\n        self,\n        delta: ModelInputsDelta,\n        sampling_inputs: SamplingInputs,\n    ):\n        \"\"\"Get inputs from delta.\"\"\"\n        self.step_inputs.update_delta(delta, self)\n        inputs = self.step_inputs.model_inputs\n        extra_inputs = self.step_inputs.extra_inputs\n        stopping_criteria = self.step_inputs.stopping_criteria\n        sampling_inputs.update_delta(self.step_inputs.sampling_delta)\n        return inputs, extra_inputs, stopping_criteria, sampling_inputs\n\n    def _prepare_inputs_prefill(\n        self,\n        inputs: ModelInputs,\n        delta: ModelInputsDelta,\n    ):\n        \"\"\"Prepare prefill inputs.\"\"\"\n\n        if delta is not None:\n            # update decoding inputs with delta\n            # for second round chat\n            self.step_inputs.update_delta(delta, self)\n\n        if inputs.is_first_chunk:\n            self._prev_chunk_output = None\n\n        # check long context\n        if self._prev_chunk_output is not None:\n            # update model metas\n            model_metas = self._prev_chunk_output.get('model_metas')\n            inputs.model_metas = model_metas\n\n            if not 
inputs.is_chunk:\n                # remove _prev_chunk_output\n                self._prev_chunk_output = None\n\n        return inputs\n\n    async def _step_postprocess_with_output(self,\n                                            last_logits: torch.Tensor,\n                                            logits: torch.Tensor,\n                                            inputs: ModelInputs,\n                                            sampling_inputs: SamplingInputs,\n                                            stopping_criteria: StoppingCriteria,\n                                            model_metas: Any,\n                                            need_broadcast_next: bool,\n                                            return_logits: bool = False,\n                                            all_routed_experts: Any = None,\n                                            extra_inputs: ExtraInputs = None):\n        \"\"\"Step postprocess with output.\"\"\"\n        rank = self.rank\n        logger.debug(f'<ForwardTask> rank[{rank}]: Sampling.')\n        # sampling\n        next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs)\n\n        # post sampling\n        next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,\n                                                                         extra_inputs)\n\n        # spec decoding\n        output_token_ids = next_token_ids\n        if self.spec_agent.is_enabled():\n            extra_inputs = await self.spec_agent.async_model_forward(next_token_ids, inputs, extra_inputs,\n                                                                     sampling_inputs)\n            next_token_ids = extra_inputs.next_token_ids\n            output_token_ids = extra_inputs.output_token_ids\n            logits = None\n\n        with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):\n            
logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids')\n\n            # stopping criteria\n            stopped, stop_pos, stopping_criteria = stopping_criteria.step(\n                next_token_ids,\n                sampling_inputs.stop_words,\n                inputs=inputs,\n                extra_inputs=extra_inputs,\n            )\n\n            # send output\n            logger.debug(f'<ForwardTask> rank[{rank}]: Output')\n            extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)\n\n        self._push_output(\n            BatchedOutputs(next_token_ids=output_token_ids,\n                           logits=logits if return_logits else None,\n                           stopped=stopped,\n                           stop_pos=stop_pos,\n                           model_metas=model_metas,\n                           logprobs=logprobs,\n                           all_routed_experts=all_routed_experts,\n                           extra_outputs=extra_outputs))\n\n        return inputs, extra_inputs, stopping_criteria, extra_outputs, next_token_ids\n\n    async def _step_postprocess_without_output(\n        self,\n        inputs: ModelInputs,\n        last_logits: torch.Tensor,\n        extra_inputs: ExtraInputs,\n        need_broadcast_next: bool,\n    ):\n        rank = self.rank\n        # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`,\n        # as it can trigger recompilation on different ranks when using torch.compile.\n        next_token_ids, extra_inputs = self.agent_strategy.make_dummy_next_token(inputs, last_logits, extra_inputs)\n\n        # broadcast next token for TP > 1\n        with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):\n            logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids')\n\n        extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)\n\n        return inputs, next_token_ids, extra_inputs, extra_outputs\n\n    async 
def _async_step(\n        self,\n        inputs: ModelInputs,\n        delta: ModelInputsDelta = None,\n        swap_in_map: Dict = None,\n        swap_out_map: Dict = None,\n        sampling_inputs: SamplingInputs = None,\n        stopping_criteria: StoppingCriteria = None,\n        return_logits: bool = False,\n        return_routed_experts: bool = False,\n        extra_inputs: ExtraInputs = None,\n    ):\n        \"\"\"Asyc forward task.\"\"\"\n\n        @record_function('update_decoding_for_next_step')\n        def __update_inputs(\n            inputs,\n            next_token_ids,\n            model_metas,\n            extra_inputs,\n            extra_outputs,\n            stopping_criteria,\n            sampling_delta: SamplingInputsDelta = None,\n        ):\n            \"\"\"Update inputs.\"\"\"\n            # dp might change is_decoding of decoding inputs\n            self.step_inputs.step(\n                inputs,\n                extra_inputs,\n                stopping_criteria,\n                sampling_delta,\n                next_token_ids,\n                model_metas,\n                extra_outputs,\n                model_agent=self,\n            )\n\n        dist_ctx = get_dist_manager().current_context()\n        dist_config = dist_ctx.dist_config\n        rank = self.rank\n        tp = dist_config.attn_tp\n        need_broadcast_next = (tp > 1)\n        dp = dist_config.dp\n        need_update_inputs = False\n\n        if inputs is None:\n            # decoding step, update prev_inputs with delta\n            need_update_inputs = True\n            assert delta is not None\n            (\n                inputs,\n                extra_inputs,\n                stopping_criteria,\n                sampling_inputs,\n            ) = self._get_inputs_from_delta(\n                delta,\n                sampling_inputs,\n            )\n        elif not inputs.is_dummy:\n            # prefill step\n            inputs = self._prepare_inputs_prefill(\n       
         inputs,\n                delta,\n            )\n\n        # dp might change is_decoding in inputs\n        is_decoding = inputs.is_decoding\n        if dp > 1:\n            # update inputs for dp\n            inputs, is_all_sleeping = await self._prepare_dp_v1(inputs)\n            # skip dummy forward.\n            if inputs is None:\n                if is_all_sleeping:\n                    self.state.to_sleep.set()\n                    await self.state.to_wakeup.wait()\n                    self.state.to_wakeup.clear()\n                    # sync after wakeup\n                    dist.barrier()\n                logger.debug(f'<ForwardTask> rank[{rank}]: all inputs are dummy, skip forward.')\n                await asyncio.sleep(0.01)\n                return\n\n        if not is_decoding:\n            # init state cache for first time prefill\n            # I don't know if this is necessary...\n            self.state_cache_engine.init_caches(inputs.state_offsets, inputs.history_lengths == 0)\n\n        # swap caches\n        cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)\n\n        # inference\n        logger.debug(f'<ForwardTask> rank[{rank}]: model forward. 
'\n                     f'batch_size={inputs.seq_length.size(0)} '\n                     f'num_tokens={inputs.input_ids.size(-1)} '\n                     f'is_decoding={inputs.is_decoding}')\n        output = await self._async_model_forward(\n            inputs,\n            return_logits=return_logits,\n        )\n        # restore is_decoding\n        inputs.is_decoding = is_decoding\n\n        if inputs.is_dummy:\n            # skip dummy forward output\n            return\n\n        logits = output['logits'][0]  # [bs, seq, prob] -> [seq, prob]\n        seq_length = output.get('seq_length', inputs.seq_length)\n        last_logits = self._slice_outs(logits, seq_length)  # [bs, 1, prob] -> [bs, prob]\n        extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, inputs, output)\n        model_metas = output.get('model_metas')\n\n        if self.need_output:\n            logger.debug(f'<ForwardTask> rank[{rank}]: Sampling.')\n            # for router replay\n            if return_routed_experts:\n                all_routed_experts = output.get('all_routed_experts', None)\n            else:\n                all_routed_experts = None\n\n            (\n                inputs,\n                extra_inputs,\n                stopping_criteria,\n                extra_outputs,\n                next_token_ids,\n            ) = await self._step_postprocess_with_output(\n                last_logits,\n                logits,\n                inputs,\n                sampling_inputs,\n                stopping_criteria,\n                model_metas,\n                need_broadcast_next,\n                return_logits=return_logits,\n                all_routed_experts=all_routed_experts,\n                extra_inputs=extra_inputs,\n            )\n        else:\n            (\n                inputs,\n                next_token_ids,\n                extra_inputs,\n                extra_outputs,\n            ) = await self._step_postprocess_without_output(\n        
        inputs,\n                last_logits,\n                extra_inputs,\n                need_broadcast_next,\n            )\n\n        sampling_delta = sampling_inputs.get_delta()\n        if need_update_inputs:\n            __update_inputs(inputs,\n                            next_token_ids,\n                            model_metas,\n                            extra_inputs,\n                            extra_outputs,\n                            stopping_criteria,\n                            sampling_delta=sampling_delta)\n        elif inputs.is_chunk:\n            # _prev_chunk_output is used to update model metas\n            self._prev_chunk_output = output\n        elif self.cache_config.role != EngineRole.Prefill:\n            self.step_inputs.merge(\n                inputs,\n                extra_inputs,\n                stopping_criteria,\n                sampling_delta,\n                next_token_ids,\n                model_metas,\n                extra_outputs,\n                model_agent=self,\n            )\n\n    async def _async_loop_background(self, forward_event: asyncio.Event = None):\n        \"\"\"Async loop background.\"\"\"\n        with self.all_context(), torch.cuda.stream(self.stream), torch.inference_mode():\n\n            # for dp\n            input_maker = build_inputs_maker(self)\n\n            while True:\n                forward_inputs = await input_maker.get()\n\n                await self._async_step(**forward_inputs, )\n                if forward_event is not None:\n                    forward_event.set()\n\n                input_maker.step()\n\n    async def _async_loop_inputs_preprocess(self, forward_event: asyncio.Event = None):\n        \"\"\"Async loop inputs preprocess.\"\"\"\n        non_blocking = True\n        keys = ['inputs', 'delta', 'sampling_inputs', 'stopping_criteria', 'extra_inputs']\n        while True:\n            forward_inputs = await self._pre_in_que.get()\n            forward_inputs_cuda = {}\n      
      forward_inputs_cuda.update(forward_inputs)\n            logger.debug('preprocessing forward inputs.')\n            with torch.cuda.stream(self.out_stream), torch.inference_mode(), record_function('inputs_H2D'):\n                for k in keys:\n                    if k not in forward_inputs_cuda:\n                        continue\n                    forward_inputs_cuda[k] = _try_to_cuda(forward_inputs_cuda[k], non_blocking=non_blocking)\n                self.out_stream.synchronize()\n            logger.debug('preprocessing forward inputs done.')\n            self._in_que.put_nowait(forward_inputs_cuda)\n            if forward_event is not None:\n                forward_event.clear()\n\n    def start(self, forward_event: asyncio.Event = None):\n        \"\"\"Start event loop.\"\"\"\n        event_loop = asyncio.get_event_loop()\n        self._pre_in_que = asyncio.Queue()\n        self._in_que = asyncio.Queue()\n        self._out_que = asyncio.Queue()\n\n        # forward task\n        logger.debug('Create task ModelAgentLoop.')\n        self._background_task = event_loop.create_task(self._async_loop_background(forward_event),\n                                                       name='ModelAgentLoop')\n        self.tasks.add(self._background_task)\n        self._background_task.add_done_callback(self.tasks.discard)\n\n        # preprocess inputs task\n        logger.debug('Create task ModelAgentPreprocess.')\n        self._preprocess_task = event_loop.create_task(self._async_loop_inputs_preprocess(forward_event),\n                                                       name='ModelAgentPreprocess')\n        self.tasks.add(self._preprocess_task)\n        self._preprocess_task.add_done_callback(self.tasks.discard)\n\n        # profiler\n        self.profiler = AgentProfiler(self.dist_ctx, self.stream)\n        self.profiler.create_task()\n\n    async def wait_tasks(self):\n        \"\"\"Wait tasks.\"\"\"\n        if len(self.tasks) == 0:\n            return\n    
    try:\n            await wait_for_async_tasks(self.tasks)\n        except asyncio.CancelledError:\n            logger.debug(f'ModelAgent rank[{self.rank}] wait_tasks cancelled.')\n            raise\n        except BaseException as e:\n            raise e from None\n        finally:\n            logger.debug(f'ModelAgent rank[{self.rank}] wait_tasks cleanup.')\n\n    def stop(self):\n        \"\"\"Stop task.\"\"\"\n        if self.dist_config.dp > 1:\n            return\n\n        if self.profiler is not None:\n            self.profiler.dump()\n\n        for task in self.tasks:\n            if not task.done():\n                task.cancel()\n\n        if self.guided_decoding_manager:\n            self.guided_decoding_manager.clear()\n\n    async def stop_async(self):\n        \"\"\"Stop task.\"\"\"\n        if self.dist_config.dp > 1:\n            return\n\n        if self.profiler is not None:\n            # dirty hack for profiler\n            while not self.stream.query():\n                logger.debug('Profiler waiting for stream finish.')\n                await asyncio.sleep(1)\n            self.profiler.dump()\n\n        for task in self.tasks:\n            if not task.done():\n                task.cancel()\n\n        try:\n            await asyncio.gather(*self.tasks, return_exceptions=True)\n        except asyncio.CancelledError:\n            logger.debug(f'ModelAgent rank[{self.rank}] stop_async cancelled.')\n\n        if self.guided_decoding_manager:\n            self.guided_decoding_manager.clear()\n\n    def set_forward_inputs(self, inputs):\n        \"\"\"Set forward inputs.\"\"\"\n        assert self._pre_in_que is not None, ('Please start the background task before forward.')\n        self._pre_in_que.put_nowait(inputs)\n\n    async def get_output_async(self):\n        \"\"\"Async get output.\"\"\"\n        assert self._out_que is not None, ('Please start the background task before forward.')\n        out = await self._out_que.get()\n        if out is 
None:\n            return dict()\n\n        out, event = out\n        while not event.query():\n            await asyncio.sleep(0.001)\n        with torch.cuda.stream(self.out_stream), torch.inference_mode(), record_function('outputs_D2H'):\n            event.wait()\n            out = out.to_cpu()\n            out.new_token_timestamp = time.time()\n        return out\n\n    def _build_model(self):\n        \"\"\"Build patched model.\"\"\"\n        model_path = self.model_path\n        adapters = self.adapters\n        device = self.device\n        rank = self.rank\n        custom_module_map = self.model_config.custom_module_map\n        if custom_module_map is not None:\n            update_custom_module_map(custom_module_map)\n        logger.debug(msg_with_rank(rank, 'build model.'))\n        # for router replay\n        enable_return_routed_experts = self.misc_config.enable_return_routed_experts and self.need_output\n\n        build_model_ctx = BuildModelContext(\n            disable_vision_encoder=self.misc_config.disable_vision_encoder,\n            dllm_config=self.misc_config.dllm_config,\n            strategy_factory=self.strategy_factory,\n            enable_return_routed_experts=enable_return_routed_experts,\n            quant_config=self.model_config.quant_config,\n            fp32_lm_head=self.model_config.fp32_lm_head,\n            tie_word_embeddings=self.model_config.tie_word_embeddings,\n        )\n        patched_model = build_patched_model(self.model_config, device=device, build_model_ctx=build_model_ctx)\n        logger.debug(msg_with_rank(rank, 'loading weights.'))\n        if not self.misc_config.empty_init:\n            load_model_weights(patched_model, model_path, device=device)\n        if adapters is not None:\n            logger.debug(msg_with_rank(rank, 'loading adapters.'))\n            add_adapters(patched_model, adapters, dtype=self.model_config.dtype, device=device)\n        self.patched_model = patched_model\n        
self.build_model_ctx = build_model_ctx\n\n    def build_model(self):\n        \"\"\"Build model api.\"\"\"\n        with self.all_context():\n            self._build_model()\n            self.spec_agent.build_model(self.misc_config.empty_init,\n                                        self.patched_model,\n                                        build_model_ctx=self.build_model_ctx)\n\n    def build_graph_runner(self):\n        \"\"\"Build graph runner.\"\"\"\n        with self.all_context():\n            backend = get_backend()\n            self.patched_model = backend.build_graph_runner(self.patched_model,\n                                                            model_config=self.model_config,\n                                                            cache_config=self.cache_config,\n                                                            backend_config=self.backend_config,\n                                                            device=self.device)\n            self.spec_agent.build_graph_runner()\n\n    def build_cache_engine(self):\n        \"\"\"Build cache engine.\"\"\"\n        with self.all_context():\n            dist_ctx = get_dist_manager().current_context()\n            dist_cfg = self.dist_config\n            tp = dist_cfg.attn_tp\n\n            self.cache_engine = CacheEngine(self.cache_config,\n                                            self.model_config,\n                                            rank=self.rank,\n                                            tp_rank=dist_ctx.attn_tp_group.rank,\n                                            world_size=tp,\n                                            cache_stream=self.cache_stream)\n            self.state_cache_engine = StateCacheEngine(self.cache_config)\n\n            self.spec_agent.build_cache_engine(self.cache_stream)\n\n    def _forward_impl(self, inputs: ModelInputs):\n        output = model_forward(\n            self.patched_model,\n            inputs,\n            
self.cache_engine,\n            state_cache_engine=self.state_cache_engine,\n            stream=self.stream,\n        )\n        return output\n\n    async def async_forward(self, inputs: ModelInputs):\n        \"\"\"Model forward.\n\n        Args:\n            inputs (ModelInputs): The model inputs for this forward step.\n        \"\"\"\n        output = self._forward_impl(inputs)\n        await asyncio.sleep(0)\n        return output\n\n    @record_function('get_logits')\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Get logits of model output.\"\"\"\n        return self.patched_model.get_logits(hidden_states)\n\n    def get_input_processor(self):\n        \"\"\"Get input processor.\"\"\"\n        return self.patched_model.get_input_processor()\n\n    def reset_graph_runner(self):\n        \"\"\"Reset graph runner to prevent tp hanging.\"\"\"\n        if hasattr(self.patched_model, 'reset'):\n            self.patched_model.reset()\n\n        self.spec_agent.reset_graph_runner()\n\n    @torch.inference_mode()\n    def update_params(self, request: UpdateParamsRequest):\n        \"\"\"Update params.\"\"\"\n\n        # modified from https://github.com/vllm-project/vllm/blob/v0.8.5/examples/offline_inference/rlhf_utils.py#L82\n        def _construct(item):\n            func, args = item\n            args = list(args)\n            args[6] = torch.cuda.current_device()  # device id.\n            # clone() seems necessary, otherwise the producer cannot release the memory\n            return func(*args).clone()\n\n        with self.all_context():\n            serialized_data = request.serialized_named_tensors\n            if isinstance(serialized_data, list):\n                serialized_data = serialized_data[self.dist_ctx.tp_group.rank]\n            model = self.patched_model.get_model()\n            weights = 
ForkingPickler.loads(pybase64.b64decode(serialized_data))\n            if request.load_format == 'flattened_bucket':\n                metadata: List[FlattenedTensorMetadata] = weights['metadata']\n                if metadata:\n                    flattened_tensor: torch.Tensor = _construct(weights['flattened_tensor'])\n                    bucket = FlattenedTensorBucket(flattened_tensor=flattened_tensor, metadata=metadata)\n                    weights = bucket.reconstruct_tensors()\n                else:\n                    # empty data\n                    weights = []\n            else:\n                weights = [(k, _construct(v)) for k, v in weights]\n\n            weights = ModelWeightLoader._rename_weights_iterator(weights, model)\n            model.load_weights(weights)\n\n            if request.finished:\n                for _, mod in model.named_modules():\n                    if not hasattr(mod, 'update_weights'):\n                        continue\n                    mod.update_weights()\n\n            torch.cuda.empty_cache()\n\n    @torch.inference_mode()\n    async def sleep(self, level: int = 1):\n        \"\"\"Sleep.\"\"\"\n        self.state.is_sleeping = True\n        if self.dist_config.dp > 1:\n            await self.state.to_sleep.wait()\n        self.cache_engine = None\n        self.reset_graph_runner()\n        device = 'cpu' if level == 1 else 'meta'\n        self.patched_model.get_model().to(device=device, non_blocking=True)\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n        self.state.to_sleep.clear()\n\n    @torch.inference_mode()\n    def wakeup(self, tags: Optional[List[str]] = None):\n        \"\"\"Wakeup.\"\"\"\n        if tags is None:\n            tags = ['weights', 'kv_cache']\n        if 'weights' in tags:\n            device = next(self.patched_model.get_model().parameters()).device\n            assert device.type in ['cpu', 'meta']\n            if device.type == 'cpu':\n                
self.patched_model.get_model().to(torch.cuda.current_device())\n            else:\n                # user should update weights after wakeup\n                old_empty_init = self.misc_config.empty_init\n                self.misc_config.empty_init = True\n                self.build_model()\n                self.build_graph_runner()\n                self.misc_config.empty_init = old_empty_init\n\n        if 'kv_cache' in tags:\n            self.build_cache_engine()\n            # wake up signal\n            self.state.is_sleeping = False\n            if self.dist_config.dp > 1:\n                self.state.to_wakeup.set()\n\n    def release(self):\n        \"\"\"release.\"\"\"\n        self.reset_graph_runner()\n        self.patched_model = None\n        self.cache_engine = None\n        torch.cuda.empty_cache()\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/model_agent/inputs_maker.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nfrom typing import TYPE_CHECKING\n\nimport torch\nimport torch.distributed as dist\n\nfrom lmdeploy.pytorch.disagg.config import EngineRole\n\nif TYPE_CHECKING:\n    from .agent import BaseModelAgent\n\n\nclass DefaultForwardInputsMaker:\n    \"\"\"Default forward inputs maker.\"\"\"\n\n    def __init__(self, model_agent: 'BaseModelAgent'):\n        self._in_que = model_agent._in_que\n\n    async def get(self):\n        \"\"\"get.\"\"\"\n        return await self._in_que.get()\n\n    def step(self):\n        \"\"\"step.\"\"\"\n        # No-op for default maker\n        pass\n\n\nclass DPForwardInputsMaker:\n    \"\"\"Dp forward inputs maker.\"\"\"\n\n    def __init__(self, model_agent: 'BaseModelAgent'):\n        self.model_agent = model_agent\n        self.dist_ctx = model_agent.dist_ctx\n        self.model_config = model_agent.model_config\n        self.cache_config = model_agent.cache_config\n        self.inputs_strategy = model_agent.inputs_strategy\n        self.device = model_agent.device\n        self._in_que = model_agent._in_que\n\n        # maker metas\n        self._ready_event = torch.cuda.Event()\n        self._ready_event.record()\n\n    def _make_dummy_forward_inputs(self):\n        \"\"\"Make dummy forward inputs.\"\"\"\n        is_decoding = self.cache_config.role != EngineRole.Prefill\n        dist_config = self.dist_ctx.dist_config\n        batch_size = 2 if dist_config.enable_microbatch else 1\n        batch_size = min(self.cache_config.max_batches, batch_size)\n        model_inputs = self.inputs_strategy.make_dummy(batch_size,\n                                                       is_decoding,\n                                                       device=self.device,\n                                                       vocab_size=self.model_config.vocab_size)\n        forward_inputs = dict(inputs=model_inputs, )\n        return forward_inputs\n\n    async def 
_gather_has_inputs(self, has_inputs: bool = False):\n        \"\"\"Check whether any rank in the attn tp group has inputs.\"\"\"\n        attn_tp_group = self.dist_ctx.attn_tp_group\n        attn_tp = self.dist_ctx.dist_config.attn_tp\n        if attn_tp == 1:\n            return has_inputs\n\n        group = attn_tp_group.cpu_group\n        has_inputs = torch.tensor((int(has_inputs), ))\n        handle = dist.all_reduce(has_inputs, op=dist.ReduceOp.SUM, group=group, async_op=True)\n        future = handle.get_future()\n        while not future.done():\n            await asyncio.sleep(0)\n        future.wait()\n        return (has_inputs > 0).item()\n\n    async def _get_inputs(self):\n        # get local forward inputs\n        try:\n            forward_inputs = self._in_que.get_nowait()\n        except asyncio.QueueEmpty:\n            forward_inputs = None\n\n        # sync has_inputs across the attn tp group\n        has_inputs = await self._gather_has_inputs(forward_inputs is not None)\n        if has_inputs and forward_inputs is None:\n            forward_inputs = await self._in_que.get()\n\n        return forward_inputs\n\n    async def get(self):\n        \"\"\"get.\"\"\"\n        # wait until there are inputs or the previous forward has finished\n        while self._in_que.qsize() == 0 and not self._ready_event.query():\n            await asyncio.sleep(0.001)\n\n        # try to get inputs\n        forward_inputs = await self._get_inputs()\n\n        # make dummy inputs\n        if forward_inputs is None:\n            forward_inputs = self._make_dummy_forward_inputs()\n\n        return forward_inputs\n\n    def step(self):\n        \"\"\"step.\"\"\"\n        self._ready_event.wait()\n        self._ready_event = torch.cuda.Event()\n        self._ready_event.record()\n\n\ndef build_inputs_maker(model_agent: 'BaseModelAgent'):\n    \"\"\"Build inputs maker.\"\"\"\n    dist_config = model_agent.dist_ctx.dist_config\n    if dist_config.dp > 1:\n        return DPForwardInputsMaker(model_agent)\n    else:\n        return 
DefaultForwardInputsMaker(model_agent)\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/model_agent/profiler.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\n\nimport torch\nfrom torch.profiler import ProfilerActivity, profile\n\nfrom lmdeploy.pytorch.distributed import DistContext\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass AgentProfiler:\n\n    def __init__(self, dist_ctx: DistContext, stream: torch.Stream):\n        from lmdeploy.pytorch import envs\n        self.rank = dist_ctx.rank\n        self.dp_rank = dist_ctx.dp_rank\n        self.dp = dist_ctx.dist_config.dp\n        self.stream = stream\n        self.profiler = None\n        self.name = f'rank[{self.rank}]'\n\n        self.delay = envs.torch_profile_delay\n        self.duration = envs.torch_profile_duration\n\n        self.profiler = self._build_profiler()\n        self.prefix = envs.torch_profile_output_prefix\n        self._task = None\n        self._started = False\n        if self.dp > 1 and self.duration <= 0 and self.profiler is not None:\n            logger.warning('Profiling with duration <= 0 is not supported for dp > 1.')\n            self.profiler = None\n\n    def _build_profiler(self):\n        from lmdeploy.pytorch import envs\n        activities = []\n        if envs.torch_profile_cpu:\n            activities.append(ProfilerActivity.CPU)\n        if envs.torch_profile_cuda:\n            activities.append(ProfilerActivity.CUDA)\n        if len(activities) > 0:\n            logger.warning(f'Profiler started on {self.name}. '
\n                           'Please note that profiling might harm performance.')\n            profiler = profile(activities=activities)\n            return profiler\n        else:\n            return None\n\n    def dump(self):\n        \"\"\"Dump profile result.\"\"\"\n        if self.profiler is None:\n            return\n\n        if not self._started:\n            logger.warning(f'Profiler {self.name} not started, skipping dump.')\n            return\n\n        try:\n            self.profiler.stop()\n            rank = self.rank\n            dump_path = f'{self.prefix}{rank}.json'\n            self.profiler.export_chrome_trace(dump_path)\n            logger.warning(f'Profiler {self.name} dumped to {dump_path}.')\n        except Exception as e:\n            logger.error(f'Failed to dump profile result of {self.name}: {e}')\n        finally:\n            self.profiler = None\n\n    async def profile_task(self):\n        \"\"\"Profile task.\"\"\"\n        if self.profiler is None:\n            return\n\n        # start profiler with delay\n        await asyncio.sleep(self.delay)\n        self.profiler.start()\n        self._started = True\n\n        if self.duration <= 0:\n            return\n\n        # dump profiler\n        await asyncio.sleep(self.duration)\n        self.dump()\n\n    def create_task(self):\n        \"\"\"Create task.\"\"\"\n        event_loop = asyncio.get_event_loop()\n        self._task = event_loop.create_task(self.profile_task())\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/mp_engine/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.messages import PytorchEngineConfig\n\n\ndef build_mp_engine(backend: str, model_path: str, engine_config: PytorchEngineConfig = None, **kwargs):\n    \"\"\"Build mp engine.\"\"\"\n    if backend == 'mp':\n        from .zmq_engine import ZMQMPEngine\n        return ZMQMPEngine(model_path, engine_config=engine_config, **kwargs)\n    elif backend == 'ray':\n        from .ray_engine import RayMPEngine\n        return RayMPEngine(model_path, engine_config=engine_config, **kwargs)\n    else:\n        raise ValueError(f'Unsupported backend: {backend}')\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/mp_engine/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nfrom collections import defaultdict\nfrom dataclasses import dataclass, field\nfrom typing import Any, List, Optional\n\nfrom lmdeploy.messages import ResponseType\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,\n                                                   DistServeInitRequest)\nfrom lmdeploy.utils import get_logger\n\nfrom ..base import EngineBase, EngineInstanceBase\n\nlogger = get_logger('lmdeploy')\n\n\n@dataclass\nclass SessionState:\n    is_exists: asyncio.Event = field(default_factory=asyncio.Event)\n\n\nclass MPEngine(EngineBase):\n\n    def __init__(self) -> None:\n        \"\"\"Initialize mp engine.\"\"\"\n        self.session_states = defaultdict(SessionState)\n        self.engine_config = self._collective_rpc('get_engine_config')\n\n    def _collective_rpc(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        raise NotImplementedError('This method has not been implemented yet.')\n\n    async def _collective_rpc_async(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        raise NotImplementedError('This method has not been implemented yet.')\n\n    async def _collective_rpc_streaming_async(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        raise NotImplementedError('This method has not been implemented yet.')\n\n    def close(self) -> None:\n        \"\"\"Close mp engine.\"\"\"\n        raise NotImplementedError('This method has not been implemented yet.')\n\n    def start_loop(self) -> None:\n        \"\"\"Start mp engine loop.\"\"\"\n        raise NotImplementedError('This method has not been implemented yet.')\n\n    def end_session(self, session_id: int):\n        \"\"\"End session.\"\"\"\n        return self._collective_rpc('end_session', session_id)\n\n    def sleep(self, level: int):\n        \"\"\"sleep.\"\"\"\n     
   return self._collective_rpc('sleep', level)\n\n    def wakeup(self, tags: Optional[List[str]] = None):\n        \"\"\"Wakeup.\"\"\"\n        return self._collective_rpc('wakeup', tags)\n\n    def update_params(self, request: Any):\n        \"\"\"Update params.\"\"\"\n        return self._collective_rpc('update_params', request)\n\n    def get_schedule_metrics(self):\n        \"\"\"Get schedule metrics.\"\"\"\n        return self._collective_rpc('get_schedule_metrics')\n\n    def p2p_initialize(self, conn_request: DistServeInitRequest):\n        \"\"\"Init rdma link.\"\"\"\n        return self._collective_rpc('p2p_initialize', conn_request)\n\n    def p2p_connect(self, conn_request: DistServeConnectionRequest):\n        \"\"\"rdma_connect.\"\"\"\n        return self._collective_rpc('p2p_connect', conn_request)\n\n    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):\n        \"\"\"Drop connection.\n\n        1. drop engine connection (zmq connection)\n        2. 
TODO(JimyMa) drop RDMA Connection.\n        \"\"\"\n        return self._collective_rpc('p2p_drop_connect', drop_conn_request)\n\n    def create_instance(self, cuda_stream_id=0):\n        \"\"\"Create instance.\"\"\"\n        return MPEngineInstance(self)\n\n\nclass MPEngineInstance(EngineInstanceBase):\n    \"\"\"MP Engine Instance.\"\"\"\n\n    def __init__(self, engine: MPEngine):\n        self.engine = engine\n        self.session_states = engine.session_states\n\n    async def async_end(self, session_id: int):\n        \"\"\"End the given session.\"\"\"\n        if session_id not in self.session_states:\n            logger.warning(f'Session {session_id} not found when end session.')\n            return ResponseType.SESSION_NOT_EXIST\n        await self.session_states[session_id].is_exists.wait()\n        ret = await self.engine._collective_rpc_async('instance_async_end', session_id)\n        self.session_states.pop(session_id)\n        return ret\n\n    async def async_cancel(self, session_id: int):\n        \"\"\"Stop current streaming inference.\"\"\"\n        if session_id not in self.session_states:\n            logger.warning(f'Session {session_id} not found when cancel session.')\n            return ResponseType.SESSION_NOT_EXIST\n        await self.session_states[session_id].is_exists.wait()\n        return await self.engine._collective_rpc_async('instance_async_cancel', session_id)\n\n    async def async_stream_infer(self, session_id: int, *args, **kwargs):\n        \"\"\"Send stream inference request.\"\"\"\n        state = self.session_states[session_id]\n        kwargs['session_id'] = session_id\n        kwargs['notify_add_msg'] = True\n        generator = self.engine._collective_rpc_streaming_async('instance_async_stream_infer', *args, **kwargs)\n        # session should have been added\n        state.is_exists.set()\n\n        async for result in generator:\n            yield result\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/mp_engine/base_worker.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nfrom contextlib import asynccontextmanager\nfrom typing import TYPE_CHECKING, Any, List, Optional\n\nfrom lmdeploy.messages import EngineOutput\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,\n                                                   DistServeInitRequest)\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.engine.engine import Engine\n\n\nclass EngineInstancePool:\n    \"\"\"Engine Instance Pool.\"\"\"\n\n    def __init__(self, engine):\n        from lmdeploy.pytorch.engine import Engine\n        self.engine: Engine = engine\n        # enlarge `num_instance`, otherwise a sequence cannot be stopped in time\n        self.num_instance = self.engine.engine_config.max_batch_size * 2\n        self.pool = None\n\n    def create_instance_pool(self, num_instance: int):\n        \"\"\"Create instance pool.\"\"\"\n        pool = asyncio.Queue(maxsize=num_instance)\n        for _ in range(num_instance):\n            instance = self.engine.create_instance()\n            pool.put_nowait(instance)\n        return pool\n\n    @asynccontextmanager\n    async def instance(self):\n        \"\"\"Get an instance from the pool.\"\"\"\n        # lazily create the pool\n        if self.pool is None:\n            self.pool = self.create_instance_pool(self.num_instance)\n        instance = await self.pool.get()\n        try:\n            yield instance\n        finally:\n            self.pool.put_nowait(instance)\n\n    async def async_end(self, session_id: int):\n        \"\"\"End the given session.\"\"\"\n        async with self.instance() as instance:\n            return await instance.async_end(session_id)\n\n    async def async_cancel(self, session_id: int):\n        \"\"\"Stop current streaming inference.\"\"\"\n        async with self.instance() as instance:\n         
   return await instance.async_cancel(session_id)\n\n    async def async_stream_infer(self, *args, **kwargs):\n        \"\"\"Send stream inference request.\"\"\"\n        async with self.instance() as instance:\n            async for result in instance.async_stream_infer(*args, **kwargs):\n                yield result\n\n\nclass EngineWorkerBase:\n    \"\"\"Base class for engine worker.\"\"\"\n\n    def __init__(self, engine: 'Engine'):\n        engine.start_loop()\n        self.engine = engine\n        self.instance_pool = EngineInstancePool(engine)\n\n    def end_session(self, session_id: int):\n        \"\"\"End session.\"\"\"\n        return self.engine.end_session(session_id)\n\n    def get_engine_config(self):\n        \"\"\"Get engine config.\"\"\"\n        return self.engine.get_engine_config()\n\n    def get_schedule_metrics(self):\n        \"\"\"Get schedule metrics.\"\"\"\n        return self.engine.get_schedule_metrics()\n\n    def p2p_initialize(self, conn_request: DistServeInitRequest):\n        \"\"\"Init rdma link.\"\"\"\n        return self.engine.p2p_initialize(conn_request)\n\n    def p2p_connect(self, conn_request: DistServeConnectionRequest):\n        \"\"\"rdma_connect.\"\"\"\n        return self.engine.p2p_connect(conn_request)\n\n    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):\n        \"\"\"Drop connection.\n\n        1. drop engine connection (zmq connection)\n        2. 
TODO(JimyMa) drop RDMA Connection.\n        \"\"\"\n        return self.engine.p2p_drop_connect(drop_conn_request)\n\n    def sleep(self, level: int = 1):\n        \"\"\"sleep.\"\"\"\n        return self.engine.sleep(level)\n\n    def wakeup(self, tags: Optional[List[str]] = None):\n        \"\"\"Wakeup.\"\"\"\n        return self.engine.wakeup(tags)\n\n    def update_params(self, request: Any):\n        \"\"\"Update params.\"\"\"\n        return self.engine.update_params(request)\n\n    def close(self) -> None:\n        \"\"\"Close engine worker.\"\"\"\n        self.engine.close()\n\n    async def instance_async_end(self, session_id: int):\n        \"\"\"End the given session.\"\"\"\n        return await self.instance_pool.async_end(session_id)\n\n    async def instance_async_cancel(self, session_id: int):\n        \"\"\"Stop current streaming inference.\"\"\"\n        return await self.instance_pool.async_cancel(session_id)\n\n    async def instance_async_stream_infer(self, *args, **kwargs):\n        \"\"\"Send stream inference request.\"\"\"\n        async for result in self.instance_pool.async_stream_infer(*args, **kwargs):\n            yield result\n\n\nclass EngineOutputGather:\n    \"\"\"Helper class to gather incremental engine output.\"\"\"\n\n    def __init__(self):\n        self._output = dict()\n\n    def get(self, stream_id):\n        if stream_id not in self._output:\n            self._output[stream_id] = EngineOutput(status=None, token_ids=[], logprobs=[])\n        return self._output[stream_id]\n\n    def add(self, stream_id, result):\n        if not isinstance(result, EngineOutput):\n            return\n        output = self.get(stream_id)\n        output.token_ids.extend(result.token_ids or [])\n        output.logprobs.extend(result.logprobs or [])\n\n    def pop(self, stream_id, result):\n        if not isinstance(result, EngineOutput):\n            return result\n        output = self._output.pop(stream_id)\n        result.token_ids = 
output.token_ids or []\n        result.logprobs = output.logprobs or None\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/mp_engine/ray_engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nfrom typing import Dict\n\nimport ray\nfrom ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy\n\nfrom lmdeploy.messages import PytorchEngineConfig\nfrom lmdeploy.pytorch import envs as _envs\nfrom lmdeploy.pytorch.ray import RayContext, get_device_str, get_resource_kwargs\nfrom lmdeploy.utils import get_logger\n\nfrom .base import MPEngine\nfrom .base_worker import EngineOutputGather, EngineWorkerBase\n\nlogger = get_logger('lmdeploy')\n\n\nclass RayEngineWorker(EngineWorkerBase):\n\n    def __init__(self,\n                 model_path: str,\n                 engine_config: PytorchEngineConfig = None,\n                 log_level: int = 30,\n                 **kwargs) -> None:\n        \"\"\"Initialize Ray engine worker.\"\"\"\n        from lmdeploy.pytorch.engine.engine import Engine\n        logger.setLevel(log_level)\n        # create engine\n        if engine_config is not None:\n            engine_config.enable_mp_engine = False\n        engine = Engine.from_pretrained(model_path, engine_config=engine_config, **kwargs)\n        super().__init__(engine)\n\n        self._stream_id = 0\n        self._stream_aiter = dict()\n        self._stream_task = dict()\n        self._engine_output_gather = EngineOutputGather()\n\n    async def _stream_task_wrapper(self, stream_id: int, init_event: asyncio.Event, func: str, *args, **kwargs):\n        \"\"\"Run a stream task and publish each result to its stream slot.\"\"\"\n        method = getattr(self, func)\n        event = self._stream_aiter[stream_id][0]\n        # bind `result` so `finally` still works if the generator fails before its first yield\n        result = None\n        try:\n            generator = method(*args, **kwargs)\n            init_event.set()\n            async for result in generator:\n                self._engine_output_gather.add(stream_id, result)\n                self._stream_aiter[stream_id][1] = (result, False)\n                event.set()\n        finally:\n            self._stream_aiter[stream_id][1] = (result, True)\n            event.set()\n            init_event.set()\n\n    async def create_stream_task(self, func, *args, **kwargs):\n        \"\"\"Create a stream task.\"\"\"\n        stream_id = self._stream_id\n        self._stream_id += 1\n        event_loop = asyncio.get_event_loop()\n        self._stream_aiter[stream_id] = [asyncio.Event(), None]\n        init_event = asyncio.Event()\n        task = event_loop.create_task(self._stream_task_wrapper(stream_id, init_event, func, *args, **kwargs))\n        self._stream_task[stream_id] = task\n        await init_event.wait()\n\n        return stream_id\n\n    async def get_stream_task_result(self, stream_id: int):\n        \"\"\"Get the result of a stream task.\"\"\"\n        assert stream_id in self._stream_aiter, f'Stream id {stream_id} not found.'\n        stopped = False\n\n        event = self._stream_aiter[stream_id][0]\n        await event.wait()\n        result, stopped = self._stream_aiter[stream_id][1]\n        event.clear()\n\n        result = self._engine_output_gather.pop(stream_id, result)\n\n        if stopped:\n            self._stream_aiter.pop(stream_id, None)\n            self._stream_task.pop(stream_id, None)\n        return result, stopped\n\n\ndef _update_runtime_envs(runtime_env: Dict):\n    \"\"\"Update runtime envs.\"\"\"\n    new_envs = _envs.get_all_envs()\n    env_vars: Dict = runtime_env.get('env_vars', {})\n    env_vars.update(new_envs)\n    runtime_env['env_vars'] = env_vars\n    return runtime_env\n\n\nclass RayMPEngine(MPEngine):\n\n    def __init__(self, model_path: str, engine_config: PytorchEngineConfig = None, **kwargs) -> None:\n        \"\"\"Initialize mp engine.\"\"\"\n        self.ray_ctx = self._init_ray(engine_config)\n        placement_group = self.ray_ctx.get_placement_group()\n        self.placement_group = placement_group\n\n        self.worker = self._create_worker(model_path, engine_config, log_level=logger.level, **kwargs)\n        super().__init__()\n\n    def _init_ray(self, engine_config: PytorchEngineConfig = None):\n        \"\"\"Initialize Ray.\"\"\"\n        if engine_config is None:\n            engine_config = PytorchEngineConfig()\n\n        device_type = engine_config.device_type if engine_config else 'cuda'\n        dp = engine_config.dp if engine_config else 1\n        world_size = engine_config.tp if dp <= 1 else 1\n\n        ray_ctx = RayContext(world_size, dp=dp, device_type=device_type)\n        return ray_ctx\n\n    def _create_worker(self, model_path: str, engine_config: PytorchEngineConfig = None, **kwargs):\n        \"\"\"Create a Ray worker.\"\"\"\n        bundle_id = 0 if len(_envs.ray_external_pg_bundles) == 0 else _envs.ray_external_pg_bundles[0]\n        scheduling_strategy = PlacementGroupSchedulingStrategy(\n            placement_group=self.placement_group,\n            placement_group_capture_child_tasks=True,\n            placement_group_bundle_index=bundle_id,\n        )\n\n        runtime_env = dict()\n        _update_runtime_envs(runtime_env)\n        # engine_config may be None; fall back to the same default device as `_init_ray`\n        device_type = engine_config.device_type if engine_config is not None else 'cuda'\n        device_str = get_device_str(device_type)\n        resource_kwargs = get_resource_kwargs(device_str=device_str, resource_used=0.01)\n        worker = ray.remote(\n            num_cpus=0,\n            **resource_kwargs,\n            scheduling_strategy=scheduling_strategy,\n            runtime_env=runtime_env,\n        )(RayEngineWorker).remote(model_path, engine_config, **kwargs)\n\n        return worker\n\n    def _collective_rpc(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        method = getattr(self.worker, func)\n        return ray.get(method.remote(*args, **kwargs))\n\n    async def _collective_rpc_async(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        method = getattr(self.worker, func)\n        return await method.remote(*args, **kwargs)\n\n    async def _collective_rpc_streaming_async(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        # ray generators would try to cache every result, which adds too much overhead here.\n        stream_id = await self._collective_rpc_async('create_stream_task', func, *args, **kwargs)\n\n        stopped = False\n        while not stopped:\n            result, stopped = await self._collective_rpc_async('get_stream_task_result', stream_id)\n            yield result\n\n    def close(self) -> None:\n        \"\"\"Close mp engine.\"\"\"\n        logger.info('Closing mp engine.')\n        self._collective_rpc('close')\n        self.ray_ctx.shutdown()\n\n    def start_loop(self) -> None:\n        \"\"\"Start mp engine loop.\"\"\"\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/mp_engine/zmq_engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport atexit\nimport signal\nfrom typing import TYPE_CHECKING\n\nimport torch.multiprocessing as mp\n\nfrom lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig\nfrom lmdeploy.utils import get_logger\n\nfrom .base import MPEngine\n\nlogger = get_logger('lmdeploy')\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.engine.engine import Engine\n\n\ndef cancel_async_tasks(loop: asyncio.AbstractEventLoop):\n    \"\"\"Cancel async tasks.\"\"\"\n    tasks = asyncio.all_tasks(loop=loop)\n    for task in tasks:\n        if not task.done():\n            task.cancel()\n    loop.run_until_complete(loop.shutdown_asyncgens())\n    loop.close()\n\n\nclass ZMQMPEngine(MPEngine):\n\n    def __init__(self,\n                 model_path: str,\n                 engine_config: PytorchEngineConfig = None,\n                 speculative_config: SpeculativeConfig = None,\n                 **kwargs) -> None:\n        \"\"\"Initialize mp engine.\"\"\"\n        from .zmq_rpc import AsyncRPCClient\n        self.shared_dict = None\n        self.port = None\n        self.proc = None\n        self._start_mp_proc(model_path, engine_config, speculative_config=speculative_config, **kwargs)\n\n        self.rpc_client = AsyncRPCClient(port=self.port)\n\n        super().__init__()\n        atexit.register(self.close)\n\n    def _start_mp_proc(\n        self,\n        model_path: str,\n        engine_config: PytorchEngineConfig = None,\n        speculative_config: SpeculativeConfig = None,\n        **kwargs,\n    ):\n        \"\"\"Start mp proc.\"\"\"\n        logger.debug('Starting engine multi-process.')\n        with mp.Manager() as manager:\n            self.shared_dict = manager.dict()\n            condition = manager.Condition()\n            self.mp_ctx = mp.get_context('spawn')\n            log_level = logger.level\n            target_kwargs = dict(\n                model_path=model_path,\n                
engine_config=engine_config,\n                log_level=log_level,\n                speculative_config=speculative_config,\n            )\n            target_kwargs.update(kwargs)\n            self.proc = self.mp_ctx.Process(\n                target=self._mp_proc,\n                args=(self.shared_dict, condition),\n                kwargs=target_kwargs,\n                name='mp_engine_proc',\n            )\n            self.proc.start()\n            logger.debug('Receiving rpc server port from mp process.')\n            with condition:\n                if 'rpc_server_port' not in self.shared_dict:\n                    condition.wait()\n            self.port = self.shared_dict['rpc_server_port']\n\n    @staticmethod\n    def _mp_proc(\n        shared_dict: dict,\n        condition: mp.Condition,\n        model_path: str,\n        engine_config: PytorchEngineConfig = None,\n        log_level: str = 'WARNING',\n        speculative_config: SpeculativeConfig = None,\n        **kwargs,\n    ):\n        \"\"\"Mp process function.\"\"\"\n        from lmdeploy.pytorch.engine import Engine\n\n        from .zmq_rpc import AsyncRPCServer\n\n        logger.setLevel(log_level)\n\n        # create an async rpc server\n        server = AsyncRPCServer()\n        with condition:\n            shared_dict['rpc_server_port'] = server.port\n            condition.notify()\n\n        # create engine\n        if engine_config is not None:\n            engine_config.enable_mp_engine = False\n        engine = Engine.from_pretrained(\n            model_path,\n            engine_config=engine_config,\n            speculative_config=speculative_config,\n            **kwargs,\n        )\n\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n\n        try:\n            loop.run_until_complete(ZMQMPEngine._mp_proc_async(server, engine))\n        except KeyboardInterrupt:\n            logger.info('Received KeyboardInterrupt, stopping mp process.')\n\n    @staticmethod\n 
   async def _mp_proc_async(server, engine: 'Engine'):\n        \"\"\"Mp process function.\"\"\"\n        import inspect\n\n        from .base_worker import EngineWorkerBase\n\n        loop = asyncio.get_running_loop()\n        current_task = asyncio.current_task()\n\n        async def shutdown(loop, signame):\n            logger.info(f'MP process received signal {signame}, stopping server.')\n            if current_task is not None:\n                current_task.cancel()\n\n        for signame in {'SIGINT', 'SIGTERM'}:\n            sig = getattr(signal, signame)\n            loop.add_signal_handler(sig, lambda signame=signame: asyncio.create_task(shutdown(loop, signame)))\n\n        worker = EngineWorkerBase(engine)\n\n        for name, value in inspect.getmembers(EngineWorkerBase):\n            if not name.startswith('_') and inspect.isfunction(value):\n                method = getattr(worker, name)\n                server.register_method(name, method)\n\n        try:\n            # run server\n            await server.run()\n        except asyncio.CancelledError:\n            logger.info('RPC Server stopping due to cancellation.')\n        except Exception as e:\n            logger.error(f'RPC Server stopped with exception: {e}')\n        finally:\n            server.stop()\n            engine.close()\n            try:\n                await engine.wait_tasks()\n            except asyncio.CancelledError:\n                logger.info('Engine wait_tasks cancelled during shutdown.')\n            except Exception as e:\n                logger.debug(f'Engine wait_tasks failed during shutdown: {e}')\n\n    def _collective_rpc(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        return self.rpc_client.call(func, *args, **kwargs)\n\n    async def _collective_rpc_async(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        return await self.rpc_client.async_call(func, *args, **kwargs)\n\n    async def 
_collective_rpc_streaming_async(self, func, *args, **kwargs):\n        \"\"\"Collective rpc call.\"\"\"\n        async for out in self.rpc_client.async_stream_call(func, *args, **kwargs):\n            yield out\n\n    def close(self) -> None:\n        \"\"\"Close mp engine.\"\"\"\n        if self.proc is None:\n            return\n        logger.info('Closing mp engine.')\n        self.rpc_client.stop()\n        self.proc.terminate()\n        self.proc.join(10)\n        if not self.proc.is_alive():\n            self.proc.close()\n        else:\n            logger.warning('MP process did not terminate in time, force killing.')\n            self.proc.kill()\n        self.proc = None\n\n    def start_loop(self) -> None:\n        \"\"\"Start mp engine loop.\"\"\"\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport inspect\nimport pickle\nfrom typing import Callable, Dict\nfrom uuid import uuid4\n\nimport zmq\nimport zmq.asyncio\nfrom zmq.asyncio import Context\n\nfrom lmdeploy.utils import get_logger\n\nfrom .base_worker import EngineOutputGather\n\nlogger = get_logger('lmdeploy')\n\n\ndef _task_callback(task: asyncio.Task) -> None:\n    \"\"\"Log the task's exception, if any, once it finishes.\"\"\"\n    task_name = task.get_name()\n    try:\n        task.result()\n    except asyncio.CancelledError:\n        logger.debug(f'Task <{task_name}> cancelled.')\n    except Exception:\n        logger.exception(f'Task <{task_name}> failed')\n    finally:\n        if not task.done():\n            task.cancel()\n\n\nclass AsyncRPCServer:\n\n    def __init__(self):\n        # Warning: DO NOT expose the rpc server to an external network.\n        # Requests are unpickled, so unauthorized access may lead to arbitrary code execution.\n        address = 'tcp://localhost'\n        self.context = zmq.Context()\n        self.socket = self.context.socket(zmq.ROUTER)\n        self.port = self.socket.bind_to_random_port(address)\n        self.methods: Dict[str, Callable] = {}\n        self.running = False\n\n        # streaming\n        self.stream_output = dict()\n        self._stream_idx = 0\n        self._engine_output_gather = EngineOutputGather()\n\n        self.tasks = set()\n\n    def get_port(self):\n        return self.port\n\n    def _get_next_stream_id(self):\n        \"\"\"Get next stream id.\"\"\"\n        self._stream_idx += 1\n        return self._stream_idx\n\n    def register_method(self, name: str, func: Callable):\n        \"\"\"Register method.\"\"\"\n        if asyncio.iscoroutinefunction(func):\n            func_type = 'async'\n        elif inspect.isasyncgenfunction(func):\n            func_type = 'async_streaming'\n        else:\n            func_type = 'default'\n        self.methods[name] = (func_type, func)\n\n    def send_multipart(self, client_id: bytes, data: Dict):\n        \"\"\"Pickle `data` and send it to the client as a multipart message.\"\"\"\n        try:\n            self.socket.send_multipart([client_id, pickle.dumps(data)])\n        except zmq.ZMQError as e:\n            logger.error(f'Failed to send message to client[{client_id}]: {e}')\n\n    def call_method_default(self, client_id, method: Callable, request: Dict):\n        \"\"\"Call a synchronous method and send its response.\"\"\"\n        request_id = request.get('request_id')\n        args = request.get('args', [])\n        kwargs = request.get('kwargs', {})\n        try:\n            result = method(*args, **kwargs)\n            response = dict(success=True, request_id=request_id, result=result)\n        except Exception as e:\n            response = dict(success=False, request_id=request_id, error=str(e))\n        self.send_multipart(client_id, response)\n\n    async def _method_async_task(self, client_id, request_id, method: Callable, args: tuple, kwargs: Dict):\n        \"\"\"Call method in a task.\"\"\"\n        try:\n            result = await method(*args, **kwargs)\n            response = dict(success=True, request_id=request_id, result=result)\n        except Exception as e:\n            response = dict(success=False, request_id=request_id, error=str(e))\n        self.send_multipart(client_id, response)\n\n    async def _method_async_streaming_task(self, stream_id: int, request_id: int, client_id: int, method: Callable,\n                                           args: tuple, kwargs: Dict):\n        \"\"\"Call method in a task for streaming.\"\"\"\n\n        def __send_resp():\n            # reply with the stream id so the client can poll for outputs\n            response = dict(success=True, request_id=request_id, result=stream_id)\n            self.send_multipart(client_id, response)\n\n        stream_out = dict(\n            event=asyncio.Event(),\n            result=None,\n            stopped=False,\n        )\n        
self.stream_output[stream_id] = stream_out\n        __send_resp()\n        try:\n            generator = method(*args, **kwargs)\n            async for result in generator:\n                self._engine_output_gather.add(stream_id, result)\n                stream_out['result'] = result\n                stream_out['event'].set()\n        except Exception as e:\n            stream_out['error'] = e\n            stream_out['event'].set()\n        finally:\n            stream_out['stopped'] = True\n\n    async def get_stream_output(self, stream_id: int):\n        \"\"\"Get streaming output.\"\"\"\n        if stream_id not in self.stream_output:\n            raise ValueError(f'Stream ID {stream_id} not found')\n        stream_out = self.stream_output[stream_id]\n        event = stream_out['event']\n        await event.wait()\n        event.clear()\n        result = stream_out['result']\n        stopped = stream_out['stopped']\n        result = self._engine_output_gather.pop(stream_id, result)\n        if stopped:\n            self.stream_output.pop(stream_id)\n        if 'error' in stream_out:\n            raise stream_out['error']\n        return result, stopped\n\n    async def call_method_async(self, client_id, method: Callable, request: Dict):\n        \"\"\"Call method async.\"\"\"\n        request_id = request.get('request_id')\n        method_name = request.get('method')\n        args = request.get('args', [])\n        kwargs = request.get('kwargs', {})\n        event_loop = asyncio.get_event_loop()\n        name = f'{method_name}_{client_id}'\n        if request.get('streaming', False):\n            # if method is a streaming method, use a different task\n            stream_id = self._get_next_stream_id()\n            task = event_loop.create_task(self._method_async_streaming_task(stream_id, request_id, client_id, method,\n                                                                            args, kwargs),\n                                          
name=name)\n            self.tasks.add(task)\n            task.add_done_callback(self.tasks.discard)\n        else:\n            task = event_loop.create_task(self._method_async_task(client_id, request_id, method, args, kwargs),\n                                          name=name)\n            self.tasks.add(task)\n            task.add_done_callback(self.tasks.discard)\n\n    async def call_and_response(self):\n        \"\"\"Receive one request and dispatch it to the registered method.\"\"\"\n        # receive message: [client_id, request_data]; DEALER clients send no empty delimiter frame\n        client_id, request_data = self.socket.recv_multipart()\n        request = pickle.loads(request_data)\n\n        method_name = request.get('method')\n        logger.debug(f'call method: {method_name}')\n        if method_name not in self.methods:\n            request_id = request.get('request_id')\n            response = dict(success=False, request_id=request_id, error=f'Method {method_name} not found')\n            self.send_multipart(client_id, response)\n        else:\n            method_type, method = self.methods[method_name]\n            if method_type in ('async', 'async_streaming'):\n                await self.call_method_async(client_id, method, request)\n            else:\n                self.call_method_default(client_id, method, request)\n\n    async def run(self):\n        logger.info('Starting AsyncRPCServer...')\n        self.running = True\n        poller = zmq.asyncio.Poller()\n        poller.register(self.socket, zmq.POLLIN)\n\n        self.register_method('_asyncrpcserver_get_stream_output', self.get_stream_output)\n        try:\n            events = await poller.poll(timeout=10)\n            while self.running:\n                while self.socket in dict(events):\n                    await self.call_and_response()\n                    events = await poller.poll(timeout=0)\n                events = await poller.poll(timeout=10)\n\n        except zmq.ZMQError:\n            logger.exception('ZMQRPCServer error')\n        except Exception:\n 
           logger.exception('AsyncRPCServer error')\n        finally:\n            logger.info('Stopping AsyncRPCServer...')\n            self.socket.close()\n            self.context.term()\n            self.running = False\n\n    def stop(self):\n        self.running = False\n        for task in self.tasks:\n            task.cancel()\n\n\nclass AsyncRPCClient:\n\n    def __init__(self, port: int = 5555):\n        logger.info(f'Connecting to AsyncRPCServer on port {port}...')\n        address = f'tcp://localhost:{port}'\n\n        socket_type = zmq.DEALER\n\n        # sync socket\n        self.sync_ctx = zmq.Context()\n        self.sync_socket = self.sync_ctx.socket(socket_type)\n        self.sync_socket.connect(address)\n        self.sync_poller = zmq.Poller()\n        self.sync_poller.register(self.sync_socket, zmq.POLLIN)\n\n        # async socket\n        self.async_ctx = Context.instance()\n        self.async_socket = self.async_ctx.socket(socket_type)\n        self.async_socket.connect(address)\n\n        self.pending = {}\n        self._listen_task = None\n        self.running = False\n\n    def _set_reply_default(self, request_id: int, reply: Dict):\n        \"\"\"Default reply handler for sync socket.\"\"\"\n        logger.debug(f'recv reply request_id: {request_id}')\n        future: asyncio.Future = self.pending.pop(request_id)\n        try:\n            if reply['success']:\n                future.set_result(reply['result'])\n            else:\n                future.set_exception(Exception(reply['error']))\n        except Exception as e:\n            logger.debug(f'Set future failed with exception: {e}')\n\n    def _set_reply(self, reply: Dict):\n        request_id = reply['request_id']\n        self._set_reply_default(request_id, reply)\n\n    def _poll_recv(self, timeout: float = 3):\n        \"\"\"Poll and receive message.\"\"\"\n        # socket.recv would block the process, use poll to avoid hanging\n        while True:\n            sockets = 
dict(self.sync_poller.poll(timeout=timeout * 1000))\n            if self.sync_socket in sockets:\n                return self.sync_socket.recv()\n\n    def _try_start_listen(self):\n        \"\"\"Try to start listening on async socket.\"\"\"\n        if self._listen_task is None or self._listen_task.done():\n            logger.debug('Starting async listen task...')\n            self._listen_task = asyncio.create_task(self.listen(), name='AsyncRPCClient.listen')\n            self._listen_task.add_done_callback(_task_callback)\n\n    def call(self, method, *args, **kwargs):\n        request_id = str(uuid4())\n        logger.debug(f'call method: {method}, request_id: {request_id}')\n        data = pickle.dumps(dict(request_id=request_id, method=method, args=args, kwargs=kwargs))\n        self.sync_socket.send(data)\n\n        reply = self._poll_recv()\n        reply = pickle.loads(reply)\n        while reply['request_id'] != request_id:\n            self._set_reply(reply)\n            reply = self._poll_recv()\n            reply = pickle.loads(reply)\n\n        logger.debug(f'recv reply request_id: {request_id}')\n        if reply['success']:\n            return reply['result']\n        else:\n            raise Exception(reply['error'])\n\n    async def _async_call_impl(self, method, streaming, *args, **kwargs):\n        self._try_start_listen()\n        request_id = str(uuid4())\n        future = asyncio.Future()\n        self.pending[request_id] = future\n\n        logger.debug(f'call method: {method}, request_id: {request_id}')\n        data = pickle.dumps(dict(request_id=request_id, method=method, args=args, kwargs=kwargs, streaming=streaming))\n        await self.async_socket.send(data)\n\n        return await future\n\n    async def async_call(self, method, *args, **kwargs):\n        \"\"\"Async call.\"\"\"\n        return await self._async_call_impl(method, False, *args, **kwargs)\n\n    async def async_stream_call(self, method, *args, **kwargs):\n        
\"\"\"Streaming call.\"\"\"\n        stream_id = await self._async_call_impl(method, True, *args, **kwargs)\n\n        stopped = False\n        while not stopped:\n            output, stopped = await self.async_call('_asyncrpcserver_get_stream_output', stream_id)\n            yield output\n\n    async def listen(self):\n        self._listen_task = asyncio.current_task()\n        self.running = True\n        try:\n            while self.running:\n                reply = await self.async_socket.recv()\n                reply = pickle.loads(reply)\n                self._set_reply(reply)\n        except zmq.ZMQError:\n            logger.exception('AsyncRPCClient listen error')\n        finally:\n            self.running = False\n            self.close_sockets()\n\n    def stop(self):\n        \"\"\"Stop the client.\"\"\"\n        self.running = False\n        if self._listen_task is not None:\n            self._listen_task.cancel()\n        self.close_sockets()\n\n    def close_sockets(self):\n        \"\"\"Close sockets.\"\"\"\n        self.async_socket.close()\n        self.sync_socket.close()\n        self.async_ctx.term()\n        self.sync_ctx.term()\n"
  },
  {
    "path": "lmdeploy/pytorch/engine/request.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport enum\nimport logging\nfrom dataclasses import dataclass, field\nfrom typing import Any, Awaitable, Callable, Coroutine, Dict, List\n\nfrom lmdeploy.messages import RequestMetrics, ResponseType\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass RequestType(enum.Enum):\n    \"\"\"Request type.\"\"\"\n\n    ADD_SESSION = enum.auto()\n    ADD_MESSAGE = enum.auto()\n    STOP_SESSION = enum.auto()\n    END_SESSION = enum.auto()\n    STOP_ENGINE = enum.auto()\n    RESUME_ENGINE = enum.auto()\n\n\n@dataclass\nclass Response:\n    \"\"\"Response.\"\"\"\n\n    type: ResponseType\n    sender_id: int\n    event: asyncio.Event\n    data: Any = None\n    err_msg: str = ''\n    is_done: bool = False\n    req_metrics: RequestMetrics = None\n\n\n@dataclass\nclass Request:\n    \"\"\"Request.\"\"\"\n\n    type: RequestType\n    sender_id: int\n    data: Any = None\n    resp: Response = None\n\n\nReqList = List[Request]\n\n\ndef _run_until_complete(future: Awaitable):\n    \"\"\"Run until complete.\"\"\"\n    try:\n        event_loop = asyncio.get_event_loop()\n    except Exception:\n        logger.warning('Cannot find an event loop in the current thread.'\n                       ' Creating a new event loop.')\n        event_loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(event_loop)\n    return event_loop.run_until_complete(future)\n\n\n@dataclass\nclass RequestSender:\n    \"\"\"Request sender.\n\n    Args:\n        sender_id (int): The id of the sender\n        manager (RequestManager): The manager that owns the request queue\n    \"\"\"\n    sender_id: int\n    manager: 'RequestManager'\n    resp_dict: Dict[int, List[Response]] = field(default_factory=dict)\n\n    @classmethod\n    def new(cls, sender_id: int, manager: 'RequestManager'):\n        \"\"\"Create a new sender.\"\"\"\n        obj = cls(sender_id=sender_id, manager=manager)\n        return obj\n\n    @property\n    def req_que(self):\n        \"\"\"Request queue.\"\"\"\n        return self.manager.requests\n\n    @property\n    def event_loop(self):\n        \"\"\"Get event loop.\"\"\"\n        return self.manager.event_loop\n\n    def is_loop_alive(self):\n        \"\"\"Check whether the engine loop is alive.\"\"\"\n        return self.manager.is_loop_alive()\n\n    def run_until_complete(self, future: Awaitable):\n        \"\"\"Run until complete.\"\"\"\n        return self.manager.run_until_complete(future)\n\n    def _req_put(self, reqs: Any):\n        \"\"\"Put requests into req_que without blocking.\"\"\"\n        self.req_que.put_nowait(reqs)\n\n    def _gather_request(self, req_types: List[RequestType], data: List[Any]):\n        \"\"\"Gather requests.\"\"\"\n        if self.manager._loop_task is None:\n            self.manager.create_loop_task()\n        assert len(req_types) == len(data)\n\n        reqs = []\n        resps = []\n        for rtype, rdata in zip(req_types, data):\n            event = asyncio.Event()\n            resp = Response(type=ResponseType.INTERNAL_ENGINE_ERROR,\n                            sender_id=self.sender_id,\n                            event=event,\n                            data=None,\n                            err_msg=None)\n            req = Request(type=rtype, sender_id=self.sender_id, data=rdata, resp=resp)\n            resps.append(resp)\n            reqs.append(req)\n        return resps, reqs\n\n    def batched_send_async(self, req_types: List[RequestType], data: List[Any]):\n        \"\"\"Send a batch of requests asynchronously.\"\"\"\n        resps, reqs = self._gather_request(req_types, data)\n        self._req_put(reqs)\n        return resps\n\n    def send_async(self, req_type: RequestType, data: Any):\n        \"\"\"Send a request asynchronously.\"\"\"\n        return self.batched_send_async(req_types=[req_type], data=[data])[0]\n\n    async def async_recv(self, resp: Response, wait_main: bool = False) -> Response:\n        \"\"\"Receive the response of the given request asynchronously.\"\"\"\n        if wait_main:\n            await self.manager.prepare_send()\n        event = resp.event\n        while not event.is_set():\n            try:\n                await asyncio.wait_for(event.wait(), 1)\n            except asyncio.TimeoutError:\n                if self.is_loop_alive():\n                    continue\n                logger.debug('Engine main loop failed.')\n                resp.type = ResponseType.ENGINE_STOP_ERROR\n                break\n        event.clear()\n        return resp\n\n    def recv(self, resp: Response) -> Response:\n        \"\"\"Receive the response of the given request.\"\"\"\n        coro = self.async_recv(resp)\n        return self.run_until_complete(coro)\n\n    async def async_send(self, req_type: RequestType, data: Any):\n        \"\"\"Send a request and wait for its response asynchronously.\"\"\"\n        resp = self.send_async(req_type, data)\n        return await self.async_recv(resp)\n\n    def send(self, req_type: RequestType, data: Any) -> Response:\n        \"\"\"Send a request and wait for its response synchronously.\"\"\"\n        resp = self.send_async(req_type, data)\n        return self.recv(resp)\n\n\nclass RequestManager:\n    \"\"\"Request manager.\"\"\"\n\n    def __init__(self):\n        self.senders: Dict[int, RequestSender] = dict()\n        self.callbacks: Dict[RequestType, Callable] = dict()\n        self.request_priority: List[RequestType] = [\n            RequestType.STOP_ENGINE, RequestType.ADD_SESSION, RequestType.STOP_SESSION, RequestType.END_SESSION,\n            RequestType.ADD_MESSAGE\n        ]\n        self.requests: asyncio.Queue = None\n        self._loop_task: asyncio.Future = None\n        self._loop_coro: Callable = None\n        self._next_sender_id = 0\n\n        # sender speed limiter\n        self._condition: asyncio.Condition = None\n        self._sender_wait_task: asyncio.Task = None\n        self._send_count = 0\n        self._send_event = None\n\n    async def prepare_send(self):\n        if self._condition is None:\n            return\n\n        self._send_count += 1\n        
self._send_event.set()\n        async with self._condition:\n            await self._condition.wait()\n        self._send_count -= 1\n        if self._send_count == 0:\n            self._send_event.clear()\n\n    async def sender_wait_loop(self):\n        \"\"\"Wait for loop to be created.\"\"\"\n        self._condition = asyncio.Condition()\n        self._send_count = 0\n        self._send_event = asyncio.Event()\n\n        try:\n            while True:\n                await self._send_event.wait()\n                # notify one sender to control send speed\n                async with self._condition:\n                    self._condition.notify()\n                await asyncio.sleep(0.0001)\n        finally:\n            # notify all senders to exit\n            async with self._condition:\n                self._condition.notify_all()\n            self._condition = None\n            self._send_event = None\n\n    def create_loop_task(self):\n        \"\"\"Create coro task.\"\"\"\n        if self._loop_task is not None:\n            logger.debug('loop task has been created.')\n            return self._loop_task\n        logger.debug('creating engine loop task.')\n        event_loop = asyncio.get_event_loop()\n        assert self._loop_coro is not None, ('Please set loop task with manager.start_loop')\n        loop_unshielded = event_loop.create_task(self._loop_coro(), name='EngineMainLoop')\n        self._loop_task = loop_unshielded\n        self._sender_wait_task = event_loop.create_task(self.sender_wait_loop(), name='SenderWaitLoop')\n        self.requests = asyncio.Queue()\n        return self._loop_task\n\n    async def wait_tasks(self):\n        \"\"\"Wait for loop task and sender wait task to finish.\"\"\"\n        if self._loop_task is None:\n            return\n\n        try:\n            await self._loop_task\n        except asyncio.CancelledError:\n            logger.info('Engine main loop task has been cancelled.')\n            raise\n        finally:\n  
          if self._sender_wait_task is not None:\n                self._sender_wait_task.cancel()\n                try:\n                    await self._sender_wait_task\n                except Exception:\n                    logger.debug('Sender wait task has been cancelled.')\n\n    @property\n    def event_loop(self):\n        \"\"\"Get event loop.\"\"\"\n        if self._loop_task is None:\n            return None\n        else:\n            return self._loop_task.get_loop()\n\n    def set_main_loop_func(self, loop: Callable[[Coroutine], asyncio.Task]):\n        \"\"\"Start main loop.\"\"\"\n        self._loop_coro = loop\n\n    def stop_loop(self):\n        if self.is_loop_alive():\n            self._loop_task.cancel()\n        self._loop_task = None\n        if self._sender_wait_task is not None:\n            self._sender_wait_task.cancel()\n            self._sender_wait_task = None\n\n    def is_loop_alive(self):\n        \"\"\"Check if main loop is alive.\"\"\"\n\n        if self._loop_task is None:\n            logger.debug('loop task has not been created.')\n            return False\n        if self._loop_task.get_loop() != asyncio.get_event_loop():\n            logger.warning('Current event loop is different from'\n                           ' the one bound to loop task!')\n            return False\n        return not self._loop_task.done()\n\n    def build_sender(self):\n        \"\"\"Create a new sender.\"\"\"\n        sender_id = self._next_sender_id\n        self._next_sender_id += 1\n        new_sender = RequestSender.new(sender_id, self)\n        self.senders[sender_id] = new_sender\n        return new_sender\n\n    def has_requests(self):\n        \"\"\"Has unprocessed request.\"\"\"\n        if self.requests is None:\n            return False\n        return not self.requests.empty()\n\n    async def get_all_requests(self) -> Dict[RequestType, List[Request]]:\n        \"\"\"Get all requests in current queue.\"\"\"\n        num_reqs = 
self.requests.qsize()\n        reqs: ReqList = []\n\n        def __proc_reqs(elem):\n            \"\"\"Proc reqs.\"\"\"\n            nonlocal reqs\n            if isinstance(elem, Request):\n                elem = [elem]\n            reqs += elem\n\n        if num_reqs == 0:\n            elem = await self.requests.get()\n            __proc_reqs(elem)\n            num_reqs = self.requests.qsize()\n\n        for _ in range(num_reqs):\n            elem = self.requests.get_nowait()\n            __proc_reqs(elem)\n\n        # gather requests\n        reqs_by_type: Dict[RequestType, List[Request]] = dict((t, []) for t in RequestType)\n        for req in reqs:\n            reqs_by_type[req.type].append(req)\n        return reqs_by_type\n\n    def bind_func(self, req_type: RequestType, callback: Callable):\n        \"\"\"Bind handler for given request type.\"\"\"\n        self.callbacks[req_type] = callback\n\n    def set_request_priority(self, priority: List[RequestType]):\n        \"\"\"Set the priority of request type.\"\"\"\n        self.request_priority = priority\n\n    def response(self, resp: Response):\n        \"\"\"Send response.\"\"\"\n        resp.event.set()\n\n    def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs):\n        \"\"\"Process reqs with given req type.\"\"\"\n        # get callback\n        func = self.callbacks.get(req_type, None)\n        if func is not None:\n            func(reqs, **kwargs)\n        else:\n            # TODO: send error message\n            for req in reqs:\n                resp = req.resp\n                resp.type = ResponseType.HANDLER_NOT_EXIST\n                resp.err_msg = (f'callback for {req_type}'\n                                ' not exists.')\n                self.response(resp)\n\n    async def step(self, **kwargs):\n        \"\"\"Handle requests.\n\n        Should only be called in loop task.\n        \"\"\"\n\n        def _log_reqs(reqs: ReqList):\n            num_reqs = len(reqs)\n       
     if num_reqs == 0:\n                return\n            logger_level = logger.level\n            if logger_level <= logging.DEBUG:\n                sender_id = [req.sender_id for req in reqs]\n                logger.debug(f'Receive {req_type.name} Request: senders: {sender_id}')\n            elif logger_level <= logging.INFO:\n                logger.info(f'Receive {req_type.name} Request: {num_reqs}')\n\n        reqs_by_type = await self.get_all_requests()\n\n        # handle requests\n        for req_type in self.request_priority:\n            reqs: ReqList = reqs_by_type.get(req_type, [])\n            if not reqs:\n                continue\n\n            _log_reqs(reqs)\n            self.process_request(req_type, reqs, **kwargs)\n\n    def run_until_complete(self, future: Awaitable):\n        \"\"\"Run untile complete.\"\"\"\n        return _run_until_complete(future)\n"
  },
  {
    "path": "lmdeploy/pytorch/envs.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport contextlib\nimport os\nfrom typing import Union\n\n\ndef env_to_bool(\n    env_var: str,\n    default: bool = False,\n    *,\n    true_values: Union[set, list] = {'true', '1', 'yes', 'on'},\n    false_values: Union[set, list] = {'false', '0', 'no', 'off'},\n):\n    \"\"\"Env to bool.\"\"\"\n    value = os.getenv(env_var)\n    if value is None:\n        return default\n    value = value.lower().strip()\n    if value in true_values:\n        return True\n    elif value in false_values:\n        return False\n    else:\n        raise ValueError(f\"Cannot convert environment variable '{env_var}={value}' to boolean. \"\n                         f'Allowed true values: {true_values}, false values: {false_values}')\n\n\ndef env_to_int(\n    env_var: str,\n    default: int = 0,\n):\n    \"\"\"Env to int.\"\"\"\n    value = os.getenv(env_var)\n    if value is None:\n        return default\n    try:\n        value = int(value)\n    except Exception:\n        value = default\n    return value\n\n\ndef env_to_list_int(\n    env_var: str,\n    default: list[int] = None,\n):\n    \"\"\"Env to list of int.\"\"\"\n    default_ = default if default is not None else []\n    value = os.getenv(env_var)\n    if value is None:\n        return default_\n    try:\n        value = [int(x) for x in value.split(',')]\n    except Exception:\n        value = default_\n    return value\n\n\ndef env_to_float(\n    env_var: str,\n    default: float = 0,\n):\n    \"\"\"Env to float.\"\"\"\n    value = os.getenv(env_var)\n    if value is None:\n        return default\n    try:\n        value = float(value)\n    except Exception:\n        value = default\n    return value\n\n\n_ENVS = dict()\n\n\n@contextlib.contextmanager\ndef set_envs():\n    _origin_get_env = os.getenv\n\n    def _patched_get_env(\n        env_var: str,\n        default: Union[str, None] = None,\n    ):\n        \"\"\"Patched get_env.\"\"\"\n        if env_var 
in os.environ:\n            _ENVS[env_var] = os.environ[env_var]\n\n        return _origin_get_env(env_var, default)\n\n    os.getenv = _patched_get_env\n    yield\n    os.getenv = _origin_get_env\n\n\nwith set_envs():\n    # loader\n    random_load_weight = env_to_bool('LMDEPLOY_RANDOM_LOAD_WEIGHT', True)\n\n    # profile\n    ray_nsys_enable = env_to_bool('LMDEPLOY_RAY_NSYS_ENABLE', False)\n    ray_nsys_output_prefix = os.getenv('LMDEPLOY_RAY_NSYS_OUT_PREFIX', None)\n\n    # ascend\n    ascend_set_rt_visable_devices_by_ray = env_to_bool('ASCEND_SET_RT_VISIBLE_DEVICES_BY_RAY', False)\n    ascend_rank_table_file = os.getenv('ASCEND_RANK_TABLE_FILE_PATH')\n\n    # dp\n    dp_master_addr = os.getenv('LMDEPLOY_DP_MASTER_ADDR', None)\n    dp_master_port = os.getenv('LMDEPLOY_DP_MASTER_PORT', None)\n\n    # executor\n    executor_backend = os.getenv('LMDEPLOY_EXECUTOR_BACKEND', None)\n\n    # torch profiler\n    torch_profile_cpu = env_to_bool('LMDEPLOY_PROFILE_CPU', False)\n    torch_profile_cuda = env_to_bool('LMDEPLOY_PROFILE_CUDA', False)\n    torch_profile_delay = env_to_int('LMDEPLOY_PROFILE_DELAY', 0)\n    torch_profile_duration = env_to_int('LMDEPLOY_PROFILE_DURATION', -1)\n    torch_profile_output_prefix = os.getenv('LMDEPLOY_PROFILE_OUT_PREFIX', 'lmdeploy_profile_')\n\n    # ray timeline\n    ray_timeline_enable = env_to_bool('LMDEPLOY_RAY_TIMELINE_ENABLE', False)\n    ray_timeline_output_path = os.getenv('LMDEPLOY_RAY_TIMELINE_OUT_PATH', 'ray_timeline.json')\n\n    # ray external placement group bundles\n    # only used when lmdeploy is initialized inside a Ray Actor with pg allocated\n    ray_external_pg_bundles = env_to_list_int('LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES', [])\n\n    # dist\n    dist_master_addr = os.getenv('LMDEPLOY_DIST_MASTER_ADDR', None)\n    dist_master_port = os.getenv('LMDEPLOY_DIST_MASTER_PORT', None)\n\n    # logging\n    log_file = os.getenv('LMDEPLOY_LOG_FILE', None)\n\n    # check env\n    enable_check_env = 
env_to_bool('LMDEPLOY_ENABLE_CHECK_ENV', True)\n\n    # dlblas\n    # Not used here; reading them with os.getenv records any that are set in _ENVS so they are passed to ray workers.\n    # If Ray is launched from outside, the workers may otherwise fail to access these environment variables.\n    os.getenv('DEEPEP_MAX_TOKENS_PER_RANK', None)\n    os.getenv('DEEPEP_ENABLE_MNNVL', None)\n    os.getenv('DEEPEP_MODE', 'auto')\n\n    # deepep\n    deep_ep_buffer_num_sms = env_to_int('DEEPEP_BUFFER_NUM_SMS', 20)\n\n    # deepgemm\n    os.getenv('DG_JIT_DEBUG', '0')\n    os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', '0')\n\n    # model agent\n    skip_warmup = env_to_bool('LMDEPLOY_SKIP_WARMUP', False)\n\n    # model format\n    scale_fmt = os.getenv('LMDEPLOY_SCALE_FMT', None)\n\n    # repetition check\n    repetition_window_size = env_to_int('LMDEPLOY_REPETITION_WINDOW_SIZE', 1024)\n\n\ndef get_all_envs():\n    \"\"\"Get all environment variables.\"\"\"\n    return _ENVS\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8,\n                                  rms_norm_dynamic_quant)\n\n__all__ = [\n    'matmul_kernel_dynamic_quant',\n    'per_channel_quant',\n    'per_token_quant_int8',\n    'rms_norm_dynamic_quant',\n]\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom ..default.w8a8_kernels import per_channel_quant\nfrom .apply_rotary_pos_emb import apply_rotary_pos_emb\nfrom .fill_kv_cache import fill_kv_cache\nfrom .flashattention import flash_attn_varlen_func\nfrom .flatten_kv_cache import flatten_kv_cache\nfrom .fused_moe import fused_moe\nfrom .multinomial_sampling import multinomial_sampling\nfrom .pagedattention import flash_attn_with_kvcache\nfrom .rms_norm import rms_norm\nfrom .w8a8_fused_moe import fused_moe_w8a8\nfrom .w8a8_triton_kernels import matmul_kernel_dynamic_quant, per_token_quant_int8, rms_norm_dynamic_quant\n\n__all__ = [\n    'apply_rotary_pos_emb',\n    'fused_moe',\n    'flash_attn_with_kvcache',\n    'fill_kv_cache',\n    'multinomial_sampling',\n    'rms_norm',\n    'matmul_kernel_dynamic_quant',\n    'per_channel_quant',\n    'per_token_quant_int8',\n    'rms_norm_dynamic_quant',\n    'flash_attn_varlen_func',\n    'flatten_kv_cache',\n    'fused_moe_w8a8',\n]\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/activation.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\nfrom .utils import get_device_props\n\nTRITON_VERSION = version.parse(triton.__version__)\n\nif TRITON_VERSION >= version.parse('3.0.0'):\n    fast_expf = tl.math.exp\nelse:\n    fast_expf = tl.math.fast_expf\n\n\n@triton.jit\ndef _silu_and_mul_kernel(\n    gateup_ptr,\n    out_ptr,\n    N: tl.constexpr,\n    M,\n    stride_gum: tl.constexpr,\n    stride_gun: tl.constexpr,\n    stride_om: tl.constexpr,\n    stride_on: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    \"\"\"Silu and mul kernel.\"\"\"\n    n_block_id = tl.program_id(0)\n    m_id_start = tl.program_id(1)\n    m_id_stride = tl.num_programs(1)\n\n    up_ptr = gateup_ptr + N * stride_gun\n    offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    if N % BLOCK_SIZE_N == 0:\n        mask = None\n    else:\n        mask = offs_n < N\n\n    gate_ptrs = gateup_ptr + m_id_start * stride_gum + offs_n * stride_gun\n    up_ptrs = up_ptr + m_id_start * stride_gum + offs_n * stride_gun\n    out_ptrs = out_ptr + m_id_start * stride_om + offs_n * stride_on\n\n    for _ in tl.range(m_id_start, M, m_id_stride):\n        gate = tl.load(gate_ptrs, mask=mask)\n        up = tl.load(up_ptrs, mask=mask)\n        # exp expects fp32\n        gate = gate.to(tl.float32)\n\n        gate = gate / (1 + fast_expf(-gate))\n        gate = gate.to(gateup_ptr.dtype.element_ty)\n        out = gate * up\n\n        tl.store(out_ptrs, out, mask=mask)\n\n        gate_ptrs += m_id_stride * stride_gum\n        up_ptrs += m_id_stride * stride_gum\n        out_ptrs += m_id_stride * stride_om\n\n\ndef silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None):\n    \"\"\"Silu and mul.\"\"\"\n    assert gate_up.dim() == 2\n\n    M = gate_up.size(0)\n    N = gate_up.size(-1) // 2\n    if out is None:\n        out_shape = (M, N)\n
        out = gate_up.new_empty(out_shape)\n\n    BLOCK_SIZE_N = triton.next_power_of_2(N)\n    BLOCK_SIZE_N = min(BLOCK_SIZE_N, 512)\n    num_warps = 4\n    num_stages = 1\n\n    props = get_device_props(gate_up.device.index)\n    num_sm = props['multi_processor_count']\n    warps_per_sm = props['warps_per_sm']\n    grid_size0 = triton.cdiv(N, BLOCK_SIZE_N)\n    grid_size1 = min(M, num_sm * warps_per_sm // num_warps)\n    assert grid_size0 < 65536 and grid_size1 < 65536\n    grid = (grid_size0, grid_size1)\n    _silu_and_mul_kernel[grid](gate_up,\n                               out,\n                               N,\n                               M,\n                               stride_gum=gate_up.stride(0),\n                               stride_gun=gate_up.stride(1),\n                               stride_om=out.stride(0),\n                               stride_on=out.stride(1),\n                               BLOCK_SIZE_N=BLOCK_SIZE_N,\n                               num_warps=num_warps,\n                               num_stages=num_stages)\n\n    return out\n\n\n@triton.jit\ndef _silu_and_mul_moe_ep_kernel(\n    gateup_ptr,\n    out_ptr,\n    mask_ptr,\n    N: tl.constexpr,\n    M: tl.constexpr,\n    stride_gue: tl.constexpr,\n    stride_gum: tl.constexpr,\n    stride_gun: tl.constexpr,\n    stride_oe: tl.constexpr,\n    stride_om: tl.constexpr,\n    stride_on: tl.constexpr,\n    stride_m: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    \"\"\"Silu and mul kernel for moe with expert parallelism.\"\"\"\n    n_block_id = tl.program_id(0)\n    e_id = tl.program_id(1)\n    m_id_start = tl.program_id(2)\n    m_id_stride = tl.num_programs(2)\n\n    offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    if N % BLOCK_SIZE_N == 0:\n        mask = None\n    else:\n        mask = offs_n < N\n\n    mask_m = tl.load(mask_ptr + e_id * stride_m)\n    mask_m = tl.minimum(mask_m, M)\n    if mask_m < m_id_start:\n        return\n
    gate_ptrs = gateup_ptr + e_id * stride_gue + m_id_start * stride_gum + offs_n * stride_gun\n    up_ptrs = gate_ptrs + N * stride_gun\n    out_ptrs = out_ptr + e_id * stride_oe + m_id_start * stride_om + offs_n * stride_on\n\n    for _ in tl.range(m_id_start, mask_m, m_id_stride):\n        gate = tl.load(gate_ptrs, mask=mask)\n        up = tl.load(up_ptrs, mask=mask)\n        # exp expects fp32\n        gate = gate.to(tl.float32)\n        gate = gate / (1 + fast_expf(-gate))\n        gate = gate.to(gateup_ptr.dtype.element_ty)\n        out = gate * up\n\n        tl.store(out_ptrs, out, mask=mask)\n\n        gate_ptrs += m_id_stride * stride_gum\n        up_ptrs += m_id_stride * stride_gum\n        out_ptrs += m_id_stride * stride_om\n\n\ndef silu_and_mul_moe_ep(gate_up: torch.Tensor, mask_m: torch.Tensor, out: torch.Tensor = None):\n    \"\"\"Silu and mul for moe with expert parallelism.\"\"\"\n    # gate_up: [num_experts, batch_size, 2*hidden_size]\n    assert gate_up.dim() == 3\n    assert mask_m.dim() == 1\n    assert mask_m.size(0) == gate_up.size(0)\n\n    stride_m = mask_m.stride(0)\n    assert gate_up.size(0) % stride_m == 0\n\n    E = gate_up.size(0)\n    M = gate_up.size(1)\n    N = gate_up.size(-1) // 2\n    if out is None:\n        out_shape = (E, M, N)\n        out = gate_up.new_empty(out_shape)\n\n    BLOCK_SIZE_N = triton.next_power_of_2(N)\n    BLOCK_SIZE_N = min(BLOCK_SIZE_N, 512)\n    num_warps = 4\n    num_stages = 1\n\n    props = get_device_props(gate_up.device.index)\n    num_sm = props['multi_processor_count']\n    warps_per_sm = props['warps_per_sm']\n    ctas_per_sm = warps_per_sm // num_warps\n    ctas_per_device = num_sm * ctas_per_sm\n    grid_size0 = triton.cdiv(N, BLOCK_SIZE_N)\n    grid_size1 = min(M, triton.cdiv(ctas_per_device, grid_size0 * E))\n    grid = (grid_size0, E, grid_size1)\n    _silu_and_mul_moe_ep_kernel[grid](gate_up,\n                                      out,\n                                      mask_m,\n                                      N,\n
                                      M,\n                                      stride_gue=gate_up.stride(0),\n                                      stride_gum=gate_up.stride(1),\n                                      stride_gun=gate_up.stride(2),\n                                      stride_oe=out.stride(0),\n                                      stride_om=out.stride(1),\n                                      stride_on=out.stride(2),\n                                      stride_m=mask_m.stride(0),\n                                      BLOCK_SIZE_N=BLOCK_SIZE_N,\n                                      num_warps=num_warps,\n                                      num_stages=num_stages)\n\n    return out\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n\n@triton.jit\ndef _apply_rotary_impl(x_l, x_h, cos_l, cos_h, sin_l, sin_h):\n    \"\"\"Apply rotary positional embedding implementation.\"\"\"\n    # x_l, x_h: [BLOCK, BLOCK_N]\n    # cos_l, cos_h, sin_l, sin_h: [BLOCK, BLOCK_N]\n\n    # qe_l = q_l * cos_l - q_h * sin_l\n    # qe_h = q_h * cos_h + q_l * sin_h\n\n    # triton 3.4 would do fma 3 times to perform the above computation,\n    # which causes higher numerical error. So we manually expand the\n    # computation to avoid fma.\n    x_l_new0 = x_l * cos_l + 0\n    x_l_new1 = x_h * sin_l + 0\n    x_h_new0 = x_h * cos_h + 0\n    x_h_new1 = x_l * sin_h + 0\n    return x_l_new0 - x_l_new1, x_h_new0 + x_h_new1\n\n\n@triton.jit(do_not_specialize=('seq_len', ))\ndef apply_rotary_pos_emb_qk_kernel(\n    Q,\n    K,\n    COS,\n    SIN,\n    Q_EMB,\n    K_EMB,\n    seq_len,\n    stride_qs: tl.constexpr,\n    stride_qh: tl.constexpr,\n    stride_qd: tl.constexpr,\n    stride_ks: tl.constexpr,\n    stride_kh: tl.constexpr,\n    stride_kd: tl.constexpr,\n    stride_qes: tl.constexpr,\n    stride_qeh: tl.constexpr,\n    stride_qed: tl.constexpr,\n    stride_kes: tl.constexpr,\n    stride_keh: tl.constexpr,\n    stride_ked: tl.constexpr,\n    half_size: tl.constexpr,\n    BLOCK: tl.constexpr,\n    BLOCK_QH: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    \"\"\"Apply rotary on key AND query kernel.\"\"\"\n    seq_block_id = tl.program_id(1)\n    head_id = tl.program_id(0)\n\n    pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK)\n    pos_mask = pos_offset < seq_len\n    pos_offset = tl.max_contiguous(tl.multiple_of(pos_offset % seq_len, BLOCK), BLOCK)\n\n    feat_size = half_size * 2\n    feat_offset_l = tl.arange(0, BLOCK_N)\n    feat_mask = feat_offset_l < half_size\n    feat_offset_l = feat_offset_l % half_size\n    feat_offset_h = half_size + feat_offset_l\n    
seq_mask = pos_mask[:, None] & feat_mask[None, :]\n    cs_offset_l = pos_offset[:, None] * feat_size + feat_offset_l[None, :]\n    cs_offset_h = pos_offset[:, None] * feat_size + feat_offset_h[None, :]\n    q_elem_type = Q.dtype.element_ty\n    cos_l = tl.load(COS + cs_offset_l).to(q_elem_type)\n    cos_h = tl.load(COS + cs_offset_h).to(q_elem_type)\n    sin_l = tl.load(SIN + cs_offset_l).to(q_elem_type)\n    sin_h = tl.load(SIN + cs_offset_h).to(q_elem_type)\n\n    if head_id < BLOCK_QH:\n        q_ptr = Q + pos_offset * stride_qs\n        qe_ptr = Q_EMB + pos_offset * stride_qes\n        ql_ptrs = q_ptr[:, None] + feat_offset_l[None, :] * stride_qd\n        qh_ptrs = q_ptr[:, None] + feat_offset_h[None, :] * stride_qd\n        qel_ptrs = qe_ptr[:, None] + feat_offset_l[None, :] * stride_qed\n        qeh_ptrs = qe_ptr[:, None] + feat_offset_h[None, :] * stride_qed\n        ql_ptrs += head_id * stride_qh\n        qh_ptrs += head_id * stride_qh\n        qel_ptrs += head_id * stride_qeh\n        qeh_ptrs += head_id * stride_qeh\n\n        q_l = tl.load(ql_ptrs)\n        q_h = tl.load(qh_ptrs)\n\n        qe_l, qe_h = _apply_rotary_impl(q_l, q_h, cos_l, cos_h, sin_l, sin_h)\n\n        tl.store(qel_ptrs, qe_l, mask=seq_mask)\n        tl.store(qeh_ptrs, qe_h, mask=seq_mask)\n    else:\n        head_id = head_id - BLOCK_QH\n        k_ptr = K + pos_offset * stride_ks\n        ke_ptr = K_EMB + pos_offset * stride_kes\n        kl_ptrs = k_ptr[:, None] + feat_offset_l[None, :] * stride_kd\n        kh_ptrs = k_ptr[:, None] + feat_offset_h[None, :] * stride_kd\n        kel_ptrs = ke_ptr[:, None] + feat_offset_l[None, :] * stride_ked\n        keh_ptrs = ke_ptr[:, None] + feat_offset_h[None, :] * stride_ked\n        kl_ptrs += head_id * stride_kh\n        kh_ptrs += head_id * stride_kh\n        kel_ptrs += head_id * stride_keh\n        keh_ptrs += head_id * stride_keh\n        k_l = tl.load(kl_ptrs)\n        k_h = tl.load(kh_ptrs)\n\n        ke_l, ke_h = _apply_rotary_impl(k_l, 
k_h, cos_l, cos_h, sin_l, sin_h)\n\n        tl.store(kel_ptrs, ke_l, mask=seq_mask)\n        tl.store(keh_ptrs, ke_h, mask=seq_mask)\n\n\ndef apply_rotary_pos_emb(q: Tensor,\n                         k: Tensor,\n                         cos: Tensor,\n                         sin: Tensor,\n                         q_embed: Tensor = None,\n                         k_embed: Tensor = None):\n    \"\"\"Apply rotary positional embedding on query and key.\n\n    Args:\n        q (Tensor): Query state.\n        k (Tensor): Key state.\n        cos (Tensor): cosine matrix (seq_len, dim).\n        sin (Tensor): sine matrix (seq_len, dim).\n        q_embed (Tensor): Output query; can be the same tensor as q.\n        k_embed (Tensor): Output key; can be the same tensor as k.\n\n    Returns:\n        Tuple[Tensor, Tensor]: Embedded query and key.\n    \"\"\"\n    if cos.device != q.device:\n        cos = cos.to(device=q.device)\n    if sin.device != q.device:\n        sin = sin.to(device=q.device)\n\n    if q_embed is None:\n        q_embed = torch.empty_like(q)\n    if k_embed is None:\n        k_embed = torch.empty_like(k)\n\n    seq_len = cos.numel() // cos.size(-1)\n\n    if q.size(-1) == cos.size(-1):\n        half_size = q.size(-1) // 2\n    elif q.size(-1) > cos.size(-1):\n        # only do rope with rope_dim size\n        half_size = cos.size(-1) // 2\n    else:\n        raise ValueError('head_dim < rope_dim is not supported, '\n                         f'but got head_dim={q.size(-1)}, '\n                         f'rope_dim={cos.size(-1)}')\n    BLOCK_N = triton.next_power_of_2(half_size)\n    num_heads_q = q.size(-2)\n    num_heads_k = k.size(-2)\n    num_warps = 2\n    num_stages = 1\n\n    # compute best BLOCK size\n    num_threads = num_warps * 32\n    elem_size = q.dtype.itemsize\n    elem_per_ldgv4 = 16 // elem_size\n    BLOCK = num_threads * elem_per_ldgv4 // BLOCK_N\n    BLOCK = max(1, BLOCK)\n\n    grid = (\n        num_heads_q + num_heads_k,\n        triton.cdiv(seq_len, BLOCK),\n    )\n
    apply_rotary_pos_emb_qk_kernel[grid](q,\n                                         k,\n                                         cos,\n                                         sin,\n                                         q_embed,\n                                         k_embed,\n                                         seq_len=seq_len,\n                                         stride_qs=q.stride(-3),\n                                         stride_qh=q.stride(-2),\n                                         stride_qd=q.stride(-1),\n                                         stride_ks=k.stride(-3),\n                                         stride_kh=k.stride(-2),\n                                         stride_kd=k.stride(-1),\n                                         stride_qes=q_embed.stride(-3),\n                                         stride_qeh=q_embed.stride(-2),\n                                         stride_qed=q_embed.stride(-1),\n                                         stride_kes=k_embed.stride(-3),\n                                         stride_keh=k_embed.stride(-2),\n                                         stride_ked=k_embed.stride(-1),\n                                         half_size=half_size,\n                                         BLOCK=BLOCK,\n                                         BLOCK_QH=num_heads_q,\n                                         BLOCK_N=BLOCK_N,\n                                         num_warps=num_warps,\n                                         num_stages=num_stages)\n\n    return q_embed, k_embed\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/awq_kernels.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nfrom triton import language as tl\n\n\ndef get_cuda_autotune_config():\n    return [\n        triton.Config({\n            'BLOCK_SIZE_N': 64,\n            'GROUP_SIZE_M': 8,\n        }, num_stages=3, num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_N': 128,\n            'GROUP_SIZE_M': 8,\n        }, num_stages=3, num_warps=4),\n    ]\n\n\n@triton.jit\ndef _dequant_s4_to_f16x2(weight, shift: tl.constexpr, is_top: tl.constexpr):\n\n    immLut: tl.constexpr = (0xf0 & 0xcc) | 0xaa\n    BOTTOM_MASK: tl.constexpr = 0x000f000f\n    TOP_MASK: tl.constexpr = 0x00f000f0\n    I4s_TO_F16s_MAGIC_NUM: tl.constexpr = 0x64006400\n    FP16_TOP_MAGIC_NUM: tl.constexpr = 0x64006400\n    ONE_SIXTEENTH: tl.constexpr = 0x2c002c00\n    NEG_64: tl.constexpr = 0xd400d400\n\n    if shift:\n        weight = weight >> 8\n\n    if is_top:\n        return tl.inline_asm_elementwise(\"\"\"{\n        .reg .b32 tmp;\n        lop3.b32 tmp, $2, $3, $4, $5;\n        fma.rn.f16x2 tmp, tmp, $6, $7;\n        mov.b32 {$0, $1}, tmp;\n    }\"\"\",\n                                         '=h,=h,r,n,n,n,r,r',\n                                         args=[weight, TOP_MASK, I4s_TO_F16s_MAGIC_NUM, immLut, ONE_SIXTEENTH, NEG_64],\n                                         dtype=(tl.float16, tl.float16),\n                                         is_pure=True,\n                                         pack=1)\n    else:\n        return tl.inline_asm_elementwise(\"\"\"{\n        .reg .b32 tmp;\n        lop3.b32 tmp, $2, $3, $4, $5;\n        sub.f16x2 tmp, tmp, $6;\n        mov.b32 {$0, $1}, tmp;\n    }\"\"\",\n                                         '=h,=h,r,n,n,n,r',\n                                         args=[weight, BOTTOM_MASK, I4s_TO_F16s_MAGIC_NUM, immLut, FP16_TOP_MAGIC_NUM],\n                                         dtype=(tl.float16, tl.float16),\n                                         
is_pure=True,\n                                         pack=1)\n\n\n@triton.jit\ndef _unpack_weight(weight):\n    \"\"\"Unpack weight.\"\"\"\n    # broadcast and shift\n    width: tl.constexpr = 8\n    BLOCK_SIZE_K: tl.constexpr = weight.shape[0]\n    BLOCK_SIZE_QN: tl.constexpr = weight.shape[1]\n    BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_QN * width\n\n    w0, w1 = _dequant_s4_to_f16x2(weight, False, False)\n    w2, w3 = _dequant_s4_to_f16x2(weight, False, True)\n    w4, w5 = _dequant_s4_to_f16x2(weight, True, False)\n    w6, w7 = _dequant_s4_to_f16x2(weight, True, True)\n\n    w04 = tl.join(w0, w4)\n    w15 = tl.join(w1, w5)\n    w26 = tl.join(w2, w6)\n    w37 = tl.join(w3, w7)\n    w0246 = tl.join(w04, w26)\n    w1357 = tl.join(w15, w37)\n    weight = tl.join(w0246, w1357)\n\n    return weight.reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)\n\n\n@triton.autotune(\n    configs=get_cuda_autotune_config(),\n    key=['N', 'K'],\n    reset_to_zero=['c_ptr'],\n)\n@triton.jit\ndef awq_linear_kernel(\n        a_ptr,\n        qw_ptr,\n        s_ptr,\n        qz_ptr,\n        c_ptr,\n        M,\n        N: tl.constexpr,\n        K: tl.constexpr,\n        stride_am,\n        stride_ak: tl.constexpr,  #\n        stride_wk: tl.constexpr,\n        stride_wn: tl.constexpr,  #\n        stride_sk: tl.constexpr,\n        stride_sn: tl.constexpr,  #\n        stride_zk: tl.constexpr,\n        stride_zn: tl.constexpr,  #\n        stride_cm,\n        stride_cn: tl.constexpr,\n        # Meta-parameters\n        SPLIT_K: tl.constexpr,\n        NUM_STAGES: tl.constexpr,\n        BLOCK_SIZE_M: tl.constexpr,\n        BLOCK_SIZE_N: tl.constexpr,\n        BLOCK_SIZE_K: tl.constexpr,  #\n        GROUP_SIZE_M: tl.constexpr,  #\n):\n    \"\"\"Kernel for computing the matmul C = A x B.\n\n    A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n    \"\"\"\n\n    # -----------------------------------------------------------\n    # Map program ids `pid` to the block of C it should compute.\n    # 
This is done in a grouped ordering to promote L2 data reuse.\n    # See the `L2 Cache Optimizations` section of the Triton matrix multiplication tutorial for details.\n    kid = tl.program_id(axis=1)\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    # ----------------------------------------------------------\n    # Create pointers for the first blocks of A and B.\n    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    BLOCK_SIZE_QN: tl.constexpr = BLOCK_SIZE_N // 8\n    offs_wn = pid_n * BLOCK_SIZE_QN + tl.arange(0, BLOCK_SIZE_QN)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    qw_ptrs = qw_ptr + (offs_k[:, None] * stride_wk + offs_wn[None, :] * stride_wn)\n    s_ptrs = s_ptr + offs_bn * stride_sn\n    qz_ptrs = qz_ptr + offs_wn * stride_zn\n\n    # -----------------------------------------------------------\n    # Iterate to compute a block of the C matrix.\n    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block\n    # of fp32 values for higher accuracy.\n    # `accumulator` will be converted back to fp16 after the loop.\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n    k_start = kid\n    k_last = K // BLOCK_SIZE_K\n\n    # prefetch\n    a_ptrs += k_start * BLOCK_SIZE_K * stride_ak\n    qw_ptrs += k_start * BLOCK_SIZE_K * stride_wk\n    s_ptrs += k_start * stride_sk\n    qz_ptrs += k_start * stride_zk\n    qw = tl.load(qw_ptrs)\n    qz = tl.load(qz_ptrs)[None, :]\n    s = tl.load(s_ptrs)[None, :]\n    qw_ptrs 
+= SPLIT_K * BLOCK_SIZE_K * stride_wk\n    s_ptrs += SPLIT_K * stride_sk\n    qz_ptrs += SPLIT_K * stride_zk\n\n    for k in tl.range(k_start, k_last, SPLIT_K, num_stages=NUM_STAGES):\n\n        # unpack b\n        z = _unpack_weight(qz)\n        w = _unpack_weight(qw)\n        b = (w - z) * s\n\n        # load a\n        a = tl.load(a_ptrs)\n\n        # load next q\n        mask = k + SPLIT_K < k_last\n        qz = tl.load(qz_ptrs, mask=mask)[None, :]\n        s = tl.load(s_ptrs, mask=mask)[None, :]\n        qw = tl.load(qw_ptrs, mask=mask)\n\n        # We accumulate along the K dimension.\n        accumulator = tl.dot(a, b, acc=accumulator)\n\n        # Advance the ptrs to the next K block.\n        a_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_ak\n        qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk\n        s_ptrs += SPLIT_K * stride_sk\n        qz_ptrs += SPLIT_K * stride_zk\n\n    c = accumulator.to(tl.float16)\n\n    # -----------------------------------------------------------\n    # Write back the block of the output matrix C with masks.\n    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n\n    if SPLIT_K > 1:\n        tl.atomic_add(c_ptrs, c, mask=c_mask, sem='relaxed', scope='gpu')\n    else:\n        tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef awq_linear(x, qweight, scales, qzeros):\n    \"\"\"Awq linear.\"\"\"\n    M = x.size(0)\n    K = qweight.size(0)\n    N = scales.size(1)\n    group_size = K // scales.size(0)\n    SPLIT_K = max(1, K // 4096)\n\n    def grid(META):\n        \"\"\"grid.\"\"\"\n        return (\n            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n            SPLIT_K,\n        )\n\n    if SPLIT_K > 1:\n        out = scales.new_zeros(M, N)\n    else:\n        out = scales.new_empty(M, 
N)\n\n    props = torch.cuda.get_device_properties(x.device)\n    if props.major == 9:\n        num_stages = 2\n    elif props.major == 8 and props.minor in [6, 9]:\n        num_stages = 2\n    else:\n        num_stages = 3\n\n    BLOCK_SIZE_M = triton.next_power_of_2(M)\n    BLOCK_SIZE_M = max(16, min(128, BLOCK_SIZE_M))\n    awq_linear_kernel[grid](\n        # Pointers to matrices\n        x,\n        qweight,\n        scales,\n        qzeros,\n        out,\n        # Matrix dimensions\n        M,\n        N,\n        K,\n        stride_am=x.stride(0),\n        stride_ak=x.stride(1),  #\n        stride_wk=qweight.stride(0),\n        stride_wn=qweight.stride(1),  #\n        stride_sk=scales.stride(0),\n        stride_sn=scales.stride(1),  #\n        stride_zk=qzeros.stride(0),\n        stride_zn=qzeros.stride(1),  #\n        stride_cm=out.stride(0),\n        stride_cn=out.stride(1),\n        # Meta-parameters\n        BLOCK_SIZE_M=BLOCK_SIZE_M,\n        BLOCK_SIZE_K=group_size,\n        SPLIT_K=SPLIT_K,\n        NUM_STAGES=num_stages,\n    )\n\n    return out\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/bitonic_topk.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton.language import core\nfrom triton.language.standard import _log2\n\ntry:\n    # For Triton >= 3.6.0, core.get_int_dtype must be wrapped with\n    # triton.runtime.jit.constexpr_function to be usable as a constexpr helper\n    # inside @triton.jit kernels. This try/except keeps compatibility with\n    # older Triton versions where constexpr_function is not available.\n    get_int_dtype = triton.runtime.jit.constexpr_function(core.get_int_dtype)\nexcept Exception:\n    # fallback to original function if constexpr_function is not available (Triton < 3.6.0)\n    get_int_dtype = core.get_int_dtype\n\n\n@triton.jit\ndef _indicator(n_dims: core.constexpr, j: core.constexpr):\n    ar = core.arange(0, 2)\n    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)\n    return ar\n\n\n@triton.jit\ndef _flip_along_middle(x, n_dims, i):\n    idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)\n    ix = x.to(idtype, bitcast=True)\n    iy = ix ^ tl.xor_sum(ix, n_dims - 1 - i, True)\n    y = iy.to(x.dtype, bitcast=True)\n    return y\n\n\n@triton.jit\ndef _compare_and_swap(x, ids, flip, i: core.constexpr):\n    # compare-and-swap on the ith *innermost* dimension\n    n_dims: core.constexpr = _log2(x.numel)\n\n    # determines whether we are in the right (rather than left) position along the axis:\n    is_right = _indicator(n_dims, i)\n\n    # flip along middle dimension (the bitwise XORs will be optimised away):\n    y = _flip_along_middle(x, n_dims, i)\n    ids_y = _flip_along_middle(ids, n_dims, i)\n\n    # conditional swap:\n    mask = (x > y) != (flip ^ is_right)\n    ret_x = core.where(mask, y, x)\n    ret_ids = core.where(mask, ids_y, ids)\n    return ret_x, ret_ids\n\n\n@triton.jit\ndef _bitonic_merge_hypercube(x, ids, stage: core.constexpr, order: core.constexpr):\n    \"\"\"order_type 0 == ascending order_type 
1 == descending order_type 2 ==\n    alternating.\"\"\"\n    # flip denotes whether to re-arrange sub-sequences of elements in ascending or\n    # descending order.\n    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage\n    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with\n    # a stride of 2) at this stage\n    if order == 2:\n        flip = _indicator(_log2(x.numel), stage)\n    else:\n        flip = order\n    # perform `stage` rounds of `compare-and-swap`\n    for i in core.static_range(stage):\n        x, ids = _compare_and_swap(x, ids, flip, stage - 1 - i)\n    return x, ids\n\n\n@triton.jit\ndef _bitonic_merge(x, ids, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):\n    \"\"\"order_type 0 == ascending order_type 1 == descending order_type 2 ==\n    alternating.\"\"\"\n    h = core.reshape(x, [2] * _log2(x.numel))\n    h_ids = core.reshape(ids, [2] * _log2(x.numel))\n    h, h_ids = _bitonic_merge_hypercube(h, h_ids, stage, order)\n    x = core.reshape(h, x.shape)\n    ids = core.reshape(h_ids, ids.shape)\n    return x, ids\n\n\n@triton.jit\ndef argsort(x, ids, dim: tl.constexpr = None, descending: tl.constexpr = core.CONSTEXPR_0):\n    # handle default dimension or check that it is the most minor dim\n    _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim\n    tl.static_assert(_dim == len(x.shape) - 1, 'only minor dimension is currently supported')\n    # iteratively run bitonic merge-sort steps\n    n_dims: tl.constexpr = _log2(x.shape[_dim])\n\n    for i in tl.static_range(1, n_dims + 1):\n        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)\n    return x, ids\n\n\n@triton.jit\ndef _bitonic_topk_kernel0(score_ptr,\n                          seqlen_ptr,\n                          out_ptr,\n                          ids_ptr,\n                          stride_m,\n                          K: tl.constexpr,\n                      
    fill: tl.constexpr,\n                          descending: tl.constexpr = core.CONSTEXPR_0,\n                          sorted: tl.constexpr = True):\n    \"\"\"kernel0.\"\"\"\n    batch_id = tl.program_id(0).to(tl.int64)\n    block_id = tl.program_id(1).to(tl.int64)\n\n    seqlen = tl.load(seqlen_ptr + batch_id)\n\n    if block_id * K >= seqlen:\n        return\n\n    offs_k = tl.arange(0, K)\n    origin_ids = block_id * K + offs_k\n    # the number of scores must fit in int32 (bitonic_topk asserts max_kv_len < 2**31)\n    origin_ids = origin_ids.to(tl.int32)\n    mask = (origin_ids < seqlen)\n    score_ptrs = score_ptr + batch_id * stride_m + origin_ids\n    scores = tl.load(score_ptrs, mask=mask, other=-1e6)\n    ids = tl.where(mask, origin_ids, fill)\n    ids = origin_ids\n\n    if sorted or (seqlen > K):\n        scores, ids = argsort(scores, ids, 0, descending)\n\n    tl.store(out_ptr + batch_id * stride_m + origin_ids, scores, mask=mask)\n    tl.store(ids_ptr + batch_id * stride_m + origin_ids, ids, mask=mask)\n\n\n@triton.jit\ndef _concate(a, b):\n    \"\"\"concate.\"\"\"\n    c = tl.join(a, b)  # [k, 2]\n    c = c.trans()  # [2, k]\n    # `tl.ravel` is buggy when triton<=3.2.0, so reshape explicitly\n    c = tl.reshape(c, (a.numel + b.numel, ))\n    return c\n\n\n@triton.jit\ndef _split(a, k):\n    \"\"\"split.\"\"\"\n    a = a.reshape(2, k)\n    a = a.trans()\n    return tl.split(a)\n\n\n@triton.jit\ndef _bitonic_topk_kernel1(score_ptr,\n                          ids_ptr,\n                          seqlen_ptr,\n                          out_ptr,\n                          stride_m,\n                          K: tl.constexpr,\n                          fill: tl.constexpr,\n                          threshold: tl.constexpr,\n                          descending: tl.constexpr = core.CONSTEXPR_0):\n    \"\"\"kernel1.\"\"\"\n    batch_id = tl.program_id(0).to(tl.int64)\n\n    seqlen = tl.load(seqlen_ptr + batch_id)\n    offs_k = tl.arange(0, K)\n    score_ptrs = score_ptr + batch_id * stride_m + 
offs_k\n    ids_ptrs = ids_ptr + batch_id * stride_m + offs_k\n\n    # initialize\n    pos = offs_k\n    mask = pos < seqlen\n    scores = tl.load(score_ptrs, mask=mask, other=threshold)\n    ids = tl.load(ids_ptrs, mask=mask, other=fill)\n\n    pos = 2 * K - 1 - offs_k\n    score_ptrs = score_ptr + batch_id * stride_m + pos\n    ids_ptrs = ids_ptr + batch_id * stride_m + pos\n\n    stage: tl.constexpr = _log2(2 * K)\n    for k in tl.range(K, seqlen, K, num_stages=3):\n        mask = pos < seqlen\n        new_scores = tl.load(score_ptrs, mask=mask, other=threshold)\n        new_ids = tl.load(ids_ptrs, mask=mask, other=fill)\n\n        merged_scores = _concate(scores, new_scores)\n        merged_ids = _concate(ids, new_ids)\n\n        merged_scores, merged_ids = _bitonic_merge(merged_scores, merged_ids, stage, descending, stage)\n\n        scores, _ = _split(merged_scores, K)\n        ids, _ = _split(merged_ids, K)\n        score_ptrs += K\n        ids_ptrs += K\n        pos += K\n\n    out_ptrs = out_ptr + batch_id * K + offs_k\n    ids = tl.where(scores <= threshold, fill, ids)\n    tl.store(out_ptrs, ids)\n\n\ndef bitonic_topk(scores: torch.Tensor,\n                 q_seqlens: torch.Tensor,\n                 kv_seqlens: torch.Tensor,\n                 k: int,\n                 fill: int = -1,\n                 descending: bool = True,\n                 sorted: bool = True,\n                 threshold: float = -1e6):\n    \"\"\"Bitonic topk.\"\"\"\n    num_tokens = scores.size(0)\n    max_kv_len = scores.size(-1)\n    assert max_kv_len < (1 << 31)\n\n    if num_tokens != kv_seqlens.size(0):\n        repeat_kv_seqlens = torch.repeat_interleave(kv_seqlens, q_seqlens, output_size=num_tokens)\n    else:\n        repeat_kv_seqlens = kv_seqlens\n    tmp_scores = torch.empty_like(scores)\n    tmp_ids = torch.empty_like(scores, dtype=torch.int32)\n    num_warps = triton.cdiv(k, 4096)\n    grid = (num_tokens, triton.cdiv(max_kv_len, k))\n    
_bitonic_topk_kernel0[grid](scores,\n                                repeat_kv_seqlens,\n                                tmp_scores,\n                                tmp_ids,\n                                stride_m=scores.stride(0),\n                                K=k,\n                                fill=fill,\n                                descending=1 if descending else 0,\n                                sorted=sorted,\n                                num_warps=num_warps)\n\n    out = kv_seqlens.new_empty((num_tokens, k), dtype=torch.int32)\n    _bitonic_topk_kernel1[(num_tokens, )](tmp_scores,\n                                          tmp_ids,\n                                          repeat_kv_seqlens,\n                                          out,\n                                          stride_m=tmp_scores.stride(0),\n                                          K=k,\n                                          fill=fill,\n                                          descending=1 if descending else 0,\n                                          threshold=threshold,\n                                          num_warps=num_warps * 2)\n    return out\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modify from: https://github.com/vllm-project/vllm\nfrom typing import Callable\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .activation import silu_and_mul\nfrom .blocked_gemm_fp8 import quant_fp8\nfrom .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize, moe_reduce\n\n\ndef get_cuda_autotune_config():\n    return [\n        triton.Config({\n            'BLOCK_SIZE_M': 128,\n            'BLOCK_SIZE_N': 128,\n        }, num_stages=3, num_warps=4),\n    ]\n\n\n@triton.autotune(\n    configs=get_cuda_autotune_config(),\n    key=['N', 'K', 'M_NP2'],\n)\n@triton.jit\ndef fused_moe_blocked_f8_kernel(\n    A,\n    A_scale,\n    B,\n    B_scale,\n    bias,\n    C,\n    SortedIdx,\n    ExpStart,\n    ExpEnd,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    group_ak: tl.constexpr,\n    group_bk: tl.constexpr,\n    group_bn: tl.constexpr,\n    stride_am: tl.constexpr,\n    stride_ak: tl.constexpr,\n    stride_asm,\n    stride_ask: tl.constexpr,\n    stride_be: tl.constexpr,\n    stride_bn: tl.constexpr,\n    stride_bk: tl.constexpr,\n    stride_bse: tl.constexpr,\n    stride_bsk: tl.constexpr,\n    stride_bsn: tl.constexpr,\n    stride_bie: tl.constexpr,\n    stride_bin: tl.constexpr,\n    stride_cm,\n    stride_cn: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n    M_NP2: tl.constexpr,\n    top_k: tl.constexpr,\n    expert_offset: tl.constexpr,\n    reindex_a: tl.constexpr,\n    reindex_c: tl.constexpr,\n):\n    \"\"\"Fused moe kernel.\"\"\"\n    exp_id = tl.program_id(1)\n    pid = tl.program_id(0)\n\n    exp_start = tl.load(ExpStart + exp_id + expert_offset)\n    exp_end = tl.load(ExpEnd + exp_id + expert_offset)\n    M = exp_end - exp_start\n    if M <= 0:\n        return\n\n    num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n\n    
if GROUP_SIZE_M == 1:\n        pid_m = pid % num_pid_m\n        pid_n = pid // num_pid_m\n    else:\n        num_pid_in_group = GROUP_SIZE_M * num_pid_n\n        group_id = pid // num_pid_in_group\n        first_pid_m = group_id * GROUP_SIZE_M\n        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n        pid_m = first_pid_m + (pid % group_size_m)\n        pid_n = (pid % num_pid_in_group) // group_size_m\n\n    if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:\n        return\n\n    offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    mask_sid = offs_sid < exp_end\n    sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)\n\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    if reindex_a:\n        offs_am = sid // top_k\n    else:\n        offs_am = offs_sid\n    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n\n    # cast to int64: with many experts (e.g. 160 in DeepSeek), stride_be * exp_id would overflow int32\n    exp_id = exp_id.to(tl.int64)\n    exp_off = stride_be * exp_id\n    b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n    offs_bsn = pid_n * BLOCK_SIZE_N // group_bn\n    as_ptrs = A_scale + offs_am * stride_asm\n    bs_ptrs = B_scale + stride_bse * exp_id + offs_bsn * stride_bsn\n\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n    # initialize acc_ratio and acc_scale\n    a_scale = tl.load(as_ptrs, mask=mask_sid, other=1.0)\n    b_scale = tl.load(bs_ptrs)\n    acc_scale0 = a_scale * b_scale\n\n    k_start = BLOCK_SIZE_K\n    offs_ksa = k_start // group_ak\n    offs_ksb = k_start // group_bk\n    a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid & (k_start < K), other=1.0)\n    b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)\n    acc_scale1 = 
tl.maximum(a_scale * b_scale, 1e-12)\n    acc_ratio = acc_scale0 / acc_scale1\n    acc_scale = acc_scale1\n\n    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n        # load scales\n        k_start = (k + 2) * BLOCK_SIZE_K\n        offs_ksa = k_start // group_ak\n        offs_ksb = k_start // group_bk\n        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=mask_sid & (k_start < K), other=1.0)\n        b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)\n\n        # load ab\n        a = tl.load(a_ptrs, mask=mask_sid[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n\n        # mma\n        accumulator = tl.dot(a, b, acc=accumulator)\n        accumulator *= acc_ratio[:, None]\n\n        # update scales and ratio\n        new_acc_scale = tl.maximum(a_scale * b_scale, 1e-12)\n        acc_ratio = acc_scale / new_acc_scale\n        acc_scale = new_acc_scale\n\n        a_ptrs += BLOCK_SIZE_K * stride_ak\n        b_ptrs += BLOCK_SIZE_K * stride_bk\n\n    c = accumulator * (acc_ratio * acc_scale)[:, None]\n\n    if bias is not None:\n        bias_ptrs = bias + exp_id * stride_bie + offs_bn * stride_bin\n        bias_val = tl.load(bias_ptrs).to(accumulator.dtype)\n        c += bias_val[None]\n\n    c = c.to(C.dtype.element_ty)\n\n    if reindex_c:\n        offs_cm = sid\n    else:\n        offs_cm = offs_sid\n    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]\n    tl.store(c_ptrs, c, mask=mask_sid[:, None])\n\n\ndef fused_moe_blocked_fp8_kernel_launcher(\n    A: torch.Tensor,\n    A_scale: torch.Tensor,\n    B: torch.Tensor,\n    B_scale: torch.Tensor,\n    C: torch.Tensor,\n    sorted_idx: torch.Tensor,\n    exp_start: torch.Tensor,\n    exp_end: torch.Tensor,\n    bias: torch.Tensor = None,\n    top_k: int = 1,\n    num_tokens: int = None,\n    expert_offset: int = 0,\n    reindex_a: bool = True,\n    reindex_c: 
bool = True,\n):\n    \"\"\"Fused moe kernel launcher.\"\"\"\n\n    if num_tokens is None:\n        num_tokens = A.size(0)\n    M_NP2 = triton.next_power_of_2(num_tokens)\n    M_NP2 = max(64, M_NP2)\n    E, N, K = B.shape\n\n    assert A.dim() == 2\n    assert A_scale.dim() == 2\n    assert B.dim() == 3\n    assert B_scale.dim() == 3\n\n    assert K % A_scale.size(1) == 0\n    assert K % B_scale.size(2) == 0\n    assert N % B_scale.size(1) == 0\n\n    group_ak = K // A_scale.size(1)\n    group_bk = K // B_scale.size(2)\n    group_bn = N // B_scale.size(1)\n\n    def _grid_fn(META):\n        grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), E)\n        return grid\n\n    A = A.flatten(0, -2)\n    C = C.flatten(0, -2)\n    enable_bias = bias is not None\n\n    BLOCK_SIZE_K = group_bk\n    GROUP_SIZE_M = 1\n    grid = _grid_fn\n    fused_moe_blocked_f8_kernel[grid](\n        A,\n        A_scale,\n        B,\n        B_scale,\n        bias,\n        C,\n        sorted_idx,\n        exp_start,\n        exp_end,\n        N=N,\n        K=K,\n        group_ak=group_ak,\n        group_bk=group_bk,\n        group_bn=group_bn,\n        stride_am=A.stride(0),\n        stride_ak=A.stride(1),\n        stride_asm=A_scale.stride(0),\n        stride_ask=A_scale.stride(1),\n        stride_be=B.stride(0),\n        stride_bn=B.stride(1),\n        stride_bk=B.stride(2),\n        stride_bse=B_scale.stride(0),\n        stride_bsn=B_scale.stride(1),\n        stride_bsk=B_scale.stride(2),\n        stride_cm=C.stride(0),\n        stride_cn=C.stride(1),\n        stride_bie=bias.stride(0) if enable_bias else 0,\n        stride_bin=bias.stride(1) if enable_bias else 0,\n        top_k=top_k,\n        expert_offset=expert_offset,\n        reindex_a=reindex_a,\n        reindex_c=reindex_c,\n        M_NP2=M_NP2,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        GROUP_SIZE_M=GROUP_SIZE_M,\n    )\n\n\ndef fused_moe_blocked_fp8(input: torch.Tensor,\n                  
        input_scale: torch.Tensor,\n                          w1: torch.Tensor,\n                          w1_scale: torch.Tensor,\n                          w2: torch.Tensor,\n                          w2_scale: torch.Tensor,\n                          topk_weights: torch.Tensor,\n                          topk_ids: torch.Tensor,\n                          topk: int,\n                          w1_bias: torch.Tensor = None,\n                          w2_bias: torch.Tensor = None,\n                          out_dtype: torch.dtype = torch.float16,\n                          expert_offset: int = 0,\n                          num_experts: int = None,\n                          renormalize: bool = False,\n                          act_func: Callable = None) -> torch.Tensor:\n    \"\"\"Fused moe.\"\"\"\n    device = input.device\n    M = input.size(0)\n    E, N, _ = w1.shape\n    if num_experts is None:\n        num_experts = E\n    full_exp = num_experts == E\n    group_size = input.size(-1) // input_scale.size(-1)\n\n    topk_weights = _renormalize(topk_weights, renormalize)\n    sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)\n\n    intermediate_cache1 = _make_intermediate((M, topk, N), dtype=out_dtype, device=device, zeros=not full_exp)\n    # gate and up\n    fused_moe_blocked_fp8_kernel_launcher(\n        input,\n        input_scale,\n        w1,\n        w1_scale,\n        intermediate_cache1,\n        sorted_idx=sorted_idx,\n        exp_start=exp_start,\n        exp_end=exp_end,\n        bias=w1_bias,\n        top_k=topk,\n        num_tokens=M,\n        expert_offset=expert_offset,\n        reindex_a=True,\n        reindex_c=False,\n    )\n\n    # activate\n    intermediate_cache1 = intermediate_cache1.flatten(0, -2)\n    if act_func is None:\n        gate_cache = silu_and_mul(intermediate_cache1)\n    else:\n        gate_cache = act_func(intermediate_cache1)\n    del intermediate_cache1\n    gate_cache, gate_scale = quant_fp8(gate_cache, 
group_size, dtype=input.dtype)\n\n    intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]), dtype=out_dtype, device=device, zeros=not full_exp)\n    # down\n    fused_moe_blocked_fp8_kernel_launcher(\n        gate_cache,\n        gate_scale,\n        w2,\n        w2_scale,\n        intermediate_cache2,\n        sorted_idx=sorted_idx,\n        exp_start=exp_start,\n        exp_end=exp_end,\n        bias=w2_bias,\n        top_k=1,\n        num_tokens=M,\n        expert_offset=expert_offset,\n        reindex_a=False,\n        reindex_c=True,\n    )\n\n    ret = moe_reduce(intermediate_cache2, topk_weights)\n    return ret\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\nfrom lmdeploy.utils import get_logger\n\nfrom .utils import get_device_props\n\nlogger = get_logger('lmdeploy')\n\n\n@triton.jit\ndef fast_log2_ceil(x):\n    bits_x = tl.cast(x, tl.uint32, bitcast=True)\n    exp_x = (bits_x >> 23) & 0xFF\n    man_bits = bits_x & ((1 << 23) - 1)\n    tmp = exp_x - 127 + tl.where(man_bits != 0, 1, 0)\n    return tl.cast(tmp, tl.int32)\n\n\n@triton.jit\ndef fast_pow2(x):\n    bits_x = (x + 127) << 23\n    return tl.cast(bits_x, tl.float32, bitcast=True)\n\n\n@triton.jit\ndef fast_round_scale(amax, fp8_max_inv):\n    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))\n\n\n@triton.jit(do_not_specialize=['M', 'M_out'])\ndef _quant_fp8_kernel(\n    a_ptr,\n    out_ptr,\n    scale_ptr,\n    M,\n    M_out,\n    K: tl.constexpr,\n    num_groups_per_cta: tl.constexpr,\n    fp8_min: tl.constexpr,\n    fp8_max: tl.constexpr,\n    stride_am,\n    stride_ak: tl.constexpr,\n    stride_om,\n    stride_ok: tl.constexpr,\n    stride_sm,\n    stride_sg,\n    ROUND_SCALE: tl.constexpr,\n    GROUP_SIZE: tl.constexpr,\n    NUM_STAGES: tl.constexpr,\n):\n    \"\"\"Quant fp8 kernel.\"\"\"\n    group_id = tl.program_id(0) * num_groups_per_cta\n    m_id_start = tl.program_id(1)\n    m_id_stride = tl.num_programs(1)\n\n    GROUP_SIZE_CTA: tl.constexpr = GROUP_SIZE * num_groups_per_cta\n    g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE_CTA)\n    g_offs = tl.max_contiguous(tl.multiple_of(g_offs, GROUP_SIZE), GROUP_SIZE)\n    gs_offs = group_id + tl.arange(0, num_groups_per_cta)\n    rfp8_max = 1 / fp8_max\n\n    m_id = m_id_start\n    a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak\n    o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok\n    s_ptr = scale_ptr + m_id * stride_sm + gs_offs * stride_sg\n    if K % GROUP_SIZE_CTA == 0:\n        mask_n = True\n 
       mask_s = True\n        mask_o = True\n    else:\n        mask_n = g_offs < K\n        mask_o = g_offs < K\n        mask_s = gs_offs < tl.cdiv(K, GROUP_SIZE)\n\n    for m_id in tl.range(m_id_start, M_out, m_id_stride, num_stages=NUM_STAGES):\n        a = tl.load(a_ptrs, mask=mask_n & (m_id < M), other=0)\n        a = a.reshape(num_groups_per_cta, GROUP_SIZE)\n        a_max = tl.max(tl.abs(a), axis=1)\n        a_max = tl.maximum(a_max, 1e-6).to(tl.float32)\n        if ROUND_SCALE == 1:\n            scale = fast_round_scale(a_max, rfp8_max)\n            rscale = 1 / scale\n        else:\n            scale = a_max * rfp8_max\n            rscale = fp8_max / a_max  # triton does not support rcp\n        out = a.to(tl.float32) * rscale[:, None]\n\n        out = tl.clamp(out, fp8_min, fp8_max)\n        out = out.to(out_ptr.dtype.element_ty)\n        out = out.reshape(GROUP_SIZE * num_groups_per_cta)\n        tl.store(o_ptrs, out, mask=mask_o)\n        tl.store(s_ptr, scale, mask=mask_s)\n\n        a_ptrs += m_id_stride * stride_am\n        o_ptrs += m_id_stride * stride_om\n        s_ptr += m_id_stride * stride_sm\n\n\ndef _quant_fp8_launcher(A: Tensor, group_size: int, out: Tensor, scales: Tensor, scale_fmt: Optional[str] = None):\n    \"\"\"Quant online.\"\"\"\n    assert scale_fmt in (None, 'ue8m0')\n    round_scale = 1 if scale_fmt == 'ue8m0' else 0\n    M, K = A.shape\n    M_out = out.size(0)\n\n    dtype = out.dtype\n    finfo = torch.finfo(dtype)\n    fmin = finfo.min\n    fmax = finfo.max\n\n    num_warps = 2\n    # every cp/ldg instruct can load 128bit=16byte data\n    # each warp can read 512 byte data\n    elem_size = A.element_size()\n    num_groups_per_warp = 512 // (group_size * elem_size)\n    num_groups_per_cta = num_groups_per_warp * num_warps\n    grid_size0 = triton.cdiv(K, group_size * num_groups_per_cta)\n    props = get_device_props(A.device.index)\n    num_sm = props['multi_processor_count']\n    warps_per_sm = props['warps_per_sm']\n    
blocks_per_sm = props['blocks_per_sm']\n    max_ctas = num_sm * min(blocks_per_sm, warps_per_sm // num_warps)\n    grid_size1 = min(M_out, max_ctas // grid_size0)\n    assert grid_size1 < 65536\n    num_stages = min(4, max(1, triton.cdiv(M_out, grid_size1)))\n    grid = (grid_size0, grid_size1)\n    _quant_fp8_kernel[grid](\n        A,\n        out,\n        scales,\n        M,\n        M_out,\n        K,\n        num_groups_per_cta=num_groups_per_cta,\n        fp8_min=fmin,\n        fp8_max=fmax,\n        stride_am=A.stride(0),\n        stride_ak=A.stride(1),\n        stride_om=out.stride(0),\n        stride_ok=out.stride(1),\n        stride_sm=scales.stride(0),\n        stride_sg=scales.stride(1),\n        ROUND_SCALE=round_scale,\n        GROUP_SIZE=group_size,\n        NUM_STAGES=num_stages,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n\n    return out, scales\n\n\ndef quant_fp8(A: Tensor,\n              group_size: int,\n              dtype: torch.dtype = torch.float8_e4m3fn,\n              trans_scale: bool = False,\n              scale_fmt: Optional[str] = None):\n    \"\"\"Quant fp8.\"\"\"\n    assert A.dim() == 2\n    M, K = A.shape\n    assert K % group_size == 0\n    num_groups = K // group_size\n    out = torch.empty_like(A, dtype=dtype)\n    if trans_scale:\n        scales = A.new_empty(num_groups, M, dtype=torch.float32).T\n    else:\n        scales = A.new_empty(M, num_groups, dtype=torch.float32)\n    return _quant_fp8_launcher(A, group_size, out, scales, scale_fmt=scale_fmt)\n\n\ndef quant_fp8_tma(A: Tensor,\n                  group_size: int,\n                  dtype: torch.dtype = torch.float8_e4m3fn,\n                  scale_fmt: Optional[str] = None):\n    \"\"\"Quant fp8 tma.\"\"\"\n    from lmdeploy.pytorch.third_party.deep_gemm import ceil_div, get_m_alignment_for_contiguous_layout\n    assert A.dim() == 2\n    M, K = A.shape\n    assert K % group_size == 0\n    num_groups = K // group_size\n    alignment = 
get_m_alignment_for_contiguous_layout()\n    aligned_M = ceil_div(M, alignment) * alignment\n    out = A.new_empty(aligned_M, K, dtype=dtype)\n    scales = A.new_empty(num_groups, aligned_M, dtype=torch.float32).T\n    return _quant_fp8_launcher(A, group_size, out, scales, scale_fmt=scale_fmt)\n\n\ndef _gemm_fp8_tma_pre_hook(nargs):\n    BLOCK_M = nargs['BLOCK_M']\n    BLOCK_N = nargs['BLOCK_N']\n    BLOCK_K = nargs['BLOCK_K']\n    nargs['desc_a'].block_shape = (BLOCK_M, BLOCK_K)\n    nargs['desc_b'].block_shape = (BLOCK_N, BLOCK_K)\n\n\n@triton.autotune(configs=[\n    triton.Config({\n        'BLOCK_M': 128,\n        'BLOCK_N': 128,\n    }, num_stages=3, num_warps=8, pre_hook=_gemm_fp8_tma_pre_hook),\n    triton.Config({\n        'BLOCK_M': 128,\n        'BLOCK_N': 64,\n    }, num_stages=3, num_warps=4, pre_hook=_gemm_fp8_tma_pre_hook)\n],\n                 key=['N', 'K'])\n@triton.jit\ndef _gemm_fp8_tma_kernel(\n    desc_a,\n    a_scale_ptr,\n    desc_b,\n    b_scale_ptr,\n    C,\n    M,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    group_ak: tl.constexpr,\n    group_bk: tl.constexpr,\n    group_bn: tl.constexpr,\n    stride_asm: tl.constexpr,\n    stride_ask,\n    stride_bsk: tl.constexpr,\n    stride_bsn: tl.constexpr,\n    stride_cm,\n    stride_cn: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    GROUP_M: tl.constexpr,\n):\n    \"\"\"Gemm fp8 kernel.\"\"\"\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(M, BLOCK_M)\n    num_pid_n = tl.cdiv(N, BLOCK_N)\n    num_pid_in_group = GROUP_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)\n    pid_m = first_pid_m + (pid % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n\n    offs_bsn = pid_n * BLOCK_N // group_bn\n    as_ptrs = a_scale_ptr + offs_am * stride_asm\n    
bs_ptrs = b_scale_ptr + offs_bsn * stride_bsn\n\n    acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)\n    acc_ratio = 1 / acc_scale\n    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n    off_m = pid_m * BLOCK_M\n    off_n = pid_n * BLOCK_N\n    off_k = 0\n    for k in range(0, tl.cdiv(K, BLOCK_K)):\n        # load scales\n        k_start = (k + 1) * BLOCK_K\n        offs_ksa = k_start // group_ak\n        offs_ksb = k_start // group_bk\n        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=k_start < K, other=1.0)\n        b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)\n\n        # load ab\n        a = desc_a.load([off_m, off_k])\n        b = desc_b.load([off_n, off_k]).T\n\n        # mma\n        accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])\n\n        # update scales and ratio\n        new_acc_scale = a_scale * b_scale\n        acc_ratio = acc_scale / new_acc_scale\n        acc_scale = new_acc_scale\n\n        off_k += BLOCK_K\n    c = accumulator * (acc_ratio * acc_scale)[:, None]\n\n    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n    tl.store(c_ptrs, c, mask=c_mask)\n\n\n@triton.autotune(configs=[\n    triton.Config({\n        'BLOCK_M': 64,\n        'BLOCK_N': 128,\n    }, num_stages=3, num_warps=4),\n    triton.Config({\n        'BLOCK_M': 128,\n        'BLOCK_N': 64,\n    }, num_stages=3, num_warps=4)\n],\n                 key=['N', 'K'])\n@triton.jit\ndef _gemm_fp8_kernel(\n    A,\n    a_scale_ptr,\n    B,\n    b_scale_ptr,\n    C,\n    M,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    group_ak: tl.constexpr,\n    group_bk: tl.constexpr,\n    group_bn: tl.constexpr,\n    stride_am,\n    stride_ak: tl.constexpr,\n    stride_asm: tl.constexpr,\n    stride_ask,\n    stride_bk: 
tl.constexpr,\n    stride_bn: tl.constexpr,\n    stride_bsk: tl.constexpr,\n    stride_bsn: tl.constexpr,\n    stride_cm,\n    stride_cn: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    GROUP_M: tl.constexpr,\n):\n    \"\"\"Gemm fp8 kernel.\"\"\"\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(M, BLOCK_M)\n    num_pid_n = tl.cdiv(N, BLOCK_N)\n    num_pid_in_group = GROUP_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)\n    pid_m = first_pid_m + (pid % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N\n    offs_k = tl.arange(0, BLOCK_K)\n    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n    offs_bsn = pid_n * BLOCK_N // group_bn\n    as_ptrs = a_scale_ptr + offs_am * stride_asm\n    bs_ptrs = b_scale_ptr + offs_bsn * stride_bsn\n\n    acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)\n    acc_ratio = 1 / acc_scale\n    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(K, BLOCK_K)):\n        # load scales\n        k_start = (k + 1) * BLOCK_K\n        offs_ksa = k_start // group_ak\n        offs_ksb = k_start // group_bk\n        a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, mask=k_start < K, other=1.0)\n        b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, mask=k_start < K, other=1.0)\n\n        # load ab\n        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)\n\n        # mma\n        accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])\n\n        # update scales and ratio\n        
new_acc_scale = a_scale * b_scale\n        acc_ratio = acc_scale / new_acc_scale\n        acc_scale = new_acc_scale\n\n        a_ptrs += BLOCK_K * stride_ak\n        b_ptrs += BLOCK_K * stride_bk\n    c = accumulator * (acc_ratio * acc_scale)[:, None]\n\n    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n    tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef blocked_gemm_fp8(A: Tensor,\n                     A_scale: Tensor,\n                     B: Tensor,\n                     B_scale: torch.Tensor,\n                     out_dtype: torch.dtype = torch.float16):\n    \"\"\"Gemm fp8.\"\"\"\n\n    def grid(META):\n        return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )\n\n    assert A.dim() == 2\n    assert A_scale.dim() == 2\n    assert B.dim() == 2\n    assert B_scale.dim() == 2\n\n    M, K = A.shape\n    _, N = B.shape\n\n    group_ak = triton.cdiv(K, A_scale.size(1))\n    group_bk = triton.cdiv(K, B_scale.size(0))\n    group_bn = triton.cdiv(N, B_scale.size(1))\n\n    C = A.new_empty(M, N, dtype=out_dtype)\n\n    BLOCK_K = max(group_ak, group_bk)\n\n    from .utils import supports_tma\n\n    run_tma = supports_tma()\n    run_tma = run_tma and A.is_contiguous() and B.T.is_contiguous()\n\n    # run_tma = False\n    if run_tma:\n        from .utils import TensorDescriptor\n\n        dummy_block = (1, 1)\n        desc_a = TensorDescriptor.from_tensor(A, block_shape=dummy_block)\n        desc_b = TensorDescriptor.from_tensor(B.T, block_shape=dummy_block)\n\n        def _grid_tma(META):\n            \"\"\"Grid tma.\"\"\"\n            BLOCK_M = META['BLOCK_M']\n            BLOCK_N = META['BLOCK_N']\n            return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )\n\n        _gemm_fp8_tma_kernel[_grid_tma](\n            desc_a,\n            
A_scale,\n            desc_b,\n            B_scale,\n            C,\n            M=M,\n            N=N,\n            K=K,\n            group_ak=group_ak,\n            group_bk=group_bk,\n            group_bn=group_bn,\n            stride_asm=A_scale.stride(0),\n            stride_ask=A_scale.stride(1),\n            stride_bsk=B_scale.stride(0),\n            stride_bsn=B_scale.stride(1),\n            stride_cm=C.stride(0),\n            stride_cn=C.stride(1),\n            BLOCK_K=BLOCK_K,\n            GROUP_M=8,\n        )\n    else:\n        _gemm_fp8_kernel[grid](\n            A,\n            A_scale,\n            B,\n            B_scale,\n            C,\n            M=M,\n            N=N,\n            K=K,\n            group_ak=group_ak,\n            group_bk=group_bk,\n            group_bn=group_bn,\n            stride_am=A.stride(0),\n            stride_ak=A.stride(1),\n            stride_asm=A_scale.stride(0),\n            stride_ask=A_scale.stride(1),\n            stride_bk=B.stride(0),\n            stride_bn=B.stride(1),\n            stride_bsk=B_scale.stride(0),\n            stride_bsn=B_scale.stride(1),\n            stride_cm=C.stride(0),\n            stride_cn=C.stride(1),\n            BLOCK_K=BLOCK_K,\n            GROUP_M=8,\n        )\n\n    return C\n\n\ndef deep_gemm_fp8(A: Tensor,\n                  A_scale: Tensor,\n                  B: Tensor,\n                  B_scale: torch.Tensor,\n                  out_dtype: torch.dtype = torch.bfloat16):\n    \"\"\"Deepgemm fp8.\"\"\"\n    from lmdeploy.pytorch.third_party.deep_gemm import fp8_gemm_nt\n    M, _ = A.shape\n    N, _ = B.shape\n    assert out_dtype == torch.bfloat16, 'DeepGemm requires bf16 output.'\n    C = A.new_empty(M, N, dtype=out_dtype)\n    fp8_gemm_nt((A, A_scale), (B, B_scale), C, None)\n    return C\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/causal_conv1d.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport tilelang\nimport tilelang.language as T\nimport torch\n\n# The kernels below are modified from: https://github.com/Dao-AILab/causal-conv1d\n\n\n@tilelang.jit(pass_configs={\n    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,\n}, )\ndef causal_conv1d_fwd(hidden_size, width, has_bias, activation, dtype, stride_x, num_warps, ChunkSizeL=64):\n    \"\"\"TileLang kernel for causal convolution forward pass.\n\n    Each CTA computes a (ChunkSizeC, ChunkSizeL) output tile; each thread computes\n    l_per_thread consecutive output positions of a single channel.\n    \"\"\"\n    num_threads = num_warps * 32\n    num_bits = T.DataType(dtype).bits\n    num_bytes = num_bits // 8\n    # elems_per_row <= num_threads\n    elems_per_row = 128 // num_bytes\n    ChunkSizeC = elems_per_row\n    silu_activation = activation in ['silu', 'swish']\n\n    l_per_thread = min(ChunkSizeC * ChunkSizeL // num_threads, ChunkSizeL)\n    assert num_threads * l_per_thread == ChunkSizeC * ChunkSizeL\n    thrs_per_row = ChunkSizeL // l_per_thread\n    assert thrs_per_row * l_per_thread == ChunkSizeL\n    sum_seqlen = T.dynamic('sum_seqlen')\n\n    @T.prim_func\n    def causal_conv1d_fwd_main(\n        X: T.StridedTensor([hidden_size, sum_seqlen], dtype=dtype, strides=(1, stride_x)),\n        W: T.Tensor([hidden_size, width], dtype=dtype),\n        seq_idx: T.Tensor([sum_seqlen], dtype=T.int32),\n        Bias: T.Tensor([hidden_size], dtype=dtype) = None,\n        Init_states: T.Tensor([hidden_size, width - 1], dtype=dtype) = None,\n        Out: T.StridedTensor([hidden_size, sum_seqlen], dtype=dtype, strides=(1, hidden_size)) = None,\n        Final_States: T.Tensor([hidden_size, width - 1], dtype=dtype) = None,\n    ):\n        # Process sum_seqlen output positions across all threads and blocks\n        # each CTA processes one (ChunkSizeC, ChunkSizeL) output tile\n        with T.Kernel(T.ceildiv(hidden_size, ChunkSizeC), T.ceildiv(sum_seqlen, ChunkSizeL),\n                      threads=num_threads) as (bc, 
bl):\n\n            x_smem = T.alloc_shared((ChunkSizeL + width - 1, ChunkSizeC), dtype)\n\n            # load x(copy can not be used on strided tensor)\n            for lidx, cidx in T.Parallel(ChunkSizeL, ChunkSizeC):\n                glidx = bl * ChunkSizeL + lidx\n                gcidx = bc * ChunkSizeC + cidx\n                x_smem[lidx + width - 1, cidx] = T.if_then_else(glidx >= 0 and glidx < sum_seqlen, X[gcidx, glidx],\n                                                                T.cast(0.0, dtype))\n            for lidx, cidx in T.Parallel(width, ChunkSizeC):\n                glidx = bl * ChunkSizeL + lidx - width + 1\n                gcidx = bc * ChunkSizeC + cidx\n                x_smem[lidx, cidx] = T.if_then_else(glidx >= 0 and glidx < sum_seqlen, X[gcidx, glidx],\n                                                    T.cast(0.0, dtype))\n\n            x_local = T.alloc_local((width - 1 + l_per_thread, ), T.float32)\n            seq_idx_local = T.alloc_local((width - 1 + l_per_thread, ), seq_idx.dtype)\n            w_local = T.alloc_local((width, ), T.float32)\n            if has_bias:\n                bias_var = T.alloc_var(T.float32)\n            else:\n                bias_var = 0.0\n            T.clear(w_local)\n\n            tid = T.get_thread_binding(0)\n            row_idx = tid // thrs_per_row\n            col_idx = tid % thrs_per_row\n\n            # load w/b\n            if bc * ChunkSizeC + row_idx < hidden_size:\n                for widx in T.unroll(width):\n                    w_local[widx] = W[bc * ChunkSizeC + row_idx, widx]\n                if has_bias:\n                    bias_var = Bias[bc * ChunkSizeC + row_idx]\n\n            # load x\n            # load seq_idx\n            for i in T.unroll(l_per_thread + width - 1):\n                x_local[i] = x_smem[col_idx * l_per_thread + i, row_idx]\n\n            # load seq_idx\n            for i in T.unroll(l_per_thread + width - 1):\n                gi = bl * ChunkSizeL + col_idx * 
l_per_thread + i - (width - 1)\n                seq_idx_local[i] = T.if_then_else(gi >= 0 and gi < sum_seqlen, seq_idx[gi], -1)\n\n            out_vals = T.alloc_local((l_per_thread, ), T.float32)\n            T.clear(out_vals)\n            for i in T.unroll(l_per_thread):\n                out_vals[i] = bias_var\n                seq_idx_cur = seq_idx_local[i + width - 1]\n                if seq_idx_cur < 0:\n                    out_vals[i] = 0.0\n                    continue\n                for w in T.unroll(width):\n                    out_vals[i] += T.if_then_else(seq_idx_local[i + w] == seq_idx_cur, w_local[w] * x_local[i + w], 0.0)\n                if silu_activation:\n                    out_vals[i] = T.sigmoid(out_vals[i]) * out_vals[i]\n\n            for i in T.unroll(l_per_thread):\n                x_smem[col_idx * l_per_thread + i, row_idx] = out_vals[i]\n\n            for lidx, cidx in T.Parallel(ChunkSizeL, ChunkSizeC):\n                glidx = bl * ChunkSizeL + lidx\n                gcidx = bc * ChunkSizeC + cidx\n                Out[gcidx, glidx] = T.if_then_else(glidx >= 0 and glidx < sum_seqlen, x_smem[lidx, cidx],\n                                                   T.cast(0.0, dtype))\n\n    return causal_conv1d_fwd_main\n\n\ndef causal_conv1d_fn(\n    x,\n    weight,\n    bias=None,\n    seq_idx=None,\n    initial_states=None,\n    return_final_states=False,\n    final_states_out=None,\n    activation=None,\n):\n    \"\"\"Causal 1D convolution function using TileLang kernel.\n\n    Args:\n        x: Input tensor of shape [batch_size, hidden_size, sequence_length]\n           Note: batch_size must be 1\n        weight: Convolution weights of shape [hidden_size, kernel_size]\n        bias: Optional bias of shape [hidden_size]\n        seq_idx: Sequence indices of shape [sequence_length] to handle multiple sequences\n        initial_states: Initial states for sequence start [hidden_size, kernel_size-1]\n        return_final_states: Whether to return 
final states\n        final_states_out: Output tensor for final states\n        activation: Activation function name ('silu', 'swish', or None)\n\n    Returns:\n        output: Convolution result of shape [batch_size, hidden_size, sequence_length]\n        (and final_states if return_final_states=True)\n    \"\"\"\n    assert x.dim() == 3, 'x should be in shape of [batch_size, hidden_size, sum_seqlen]'\n    assert x.size(0) == 1, 'batch_size should be 1 for continuous batching'\n    assert x.stride(1) == 1, 'x should be in channel last format'\n    assert weight.dim() == 2, 'weight should be in shape of [hidden_size, kernel_size]'\n    assert seq_idx is not None, 'seq_idx is required for causal_conv1d_fn'\n    assert activation in ['silu', 'swish', None]\n    assert not return_final_states, 'return_final_states=True is not supported in this version'\n\n    _, hidden_size, _ = x.shape\n    kernel_size = weight.shape[1]\n    dtype = x.dtype\n\n    # Reshape to 2D format for kernel: [hidden_size, sum_seqlen]\n    x_2d = x.squeeze(0)  # [hidden_size, sum_seqlen]\n    seq_idx_1d = seq_idx.squeeze(0) if seq_idx.dim() > 1 else seq_idx  # [sum_seqlen]\n\n    # Allocate the output channel-last (hidden_size contiguous) to match the kernel's Out strides\n    out = x_2d.new_empty(x_2d.size(1), hidden_size)\n    out = out.T\n\n    # Create and call the TileLang kernel\n    num_warps = 4  # Tunable parameter\n    kernel = causal_conv1d_fwd(hidden_size, kernel_size, bias is not None, activation, dtype, x.stride(2), num_warps)\n\n    kernel(\n        x_2d,\n        weight,\n        seq_idx_1d,\n        bias,\n        initial_states,\n        out,\n        None,\n    )\n\n    # Reshape back to original format: [1, hidden_size, sum_seqlen]\n    out = out.unsqueeze(0)\n\n    return out\n\n\n@tilelang.jit(pass_configs={\n    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,\n}, )\ndef causal_conv1d_update_fwd(hidden_size: int, seqlen: int, state_len: int, width: int, has_bias: bool,\n         
                    activation: str | None, dtype, conv_stride: tuple[int, int, int], num_warps: int):\n    \"\"\"TileLang kernel for the causal convolution state update (decoding).\n\n    Each thread processes one channel of one batch entry, stepping through all seqlen positions.\n    \"\"\"\n    num_threads = num_warps * 32\n    silu_activation = activation in ['silu', 'swish']\n\n    advance_len = seqlen\n    batch = T.dynamic('batch')\n    conv_batch = T.dynamic('conv_batch')\n    conv_batch_stride = T.dynamic('conv_batch_stride')\n    update_idx = -(width - 1)\n    update_idx = update_idx if update_idx >= 0 else update_idx + state_len\n\n    @T.prim_func\n    def causal_conv1d_update_main(\n        X: T.Tensor((batch, hidden_size, seqlen), dtype=dtype),\n        Conv_State: T.StridedTensor((conv_batch, hidden_size, state_len),\n                                    dtype=dtype,\n                                    strides=(conv_batch_stride, conv_stride[1], conv_stride[2])),\n        W: T.Tensor((hidden_size, width), dtype=dtype),\n        Bias: T.Tensor((hidden_size, ), dtype=dtype) = None,\n        Out: T.Tensor((batch, hidden_size, seqlen), dtype=dtype) = None,\n        Conv_state_indices: T.Tensor((batch, ), dtype=T.int32) = None,\n    ):\n        with T.Kernel(batch, T.ceildiv(hidden_size, num_threads), threads=num_threads) as (bi, bc):\n            tidx = T.get_thread_binding(0)\n            batch_id = bi\n            channel_id = bc * num_threads + tidx\n\n            # load conv state index\n            conv_state_batch_coord = T.if_then_else(Conv_state_indices is not None, Conv_state_indices[batch_id],\n                                                    T.cast(batch_id, T.int32))\n\n            # skip padding tokens\n            # tilelang does not support return in branch,\n            # so an explicit if/else is used to skip the computation for padding tokens\n            if conv_state_batch_coord < 0:\n                for i in T.unroll(seqlen, unroll_factor=2):\n                
    Out[batch_id, channel_id, i] = 0.0\n            else:\n                # load bias and weight\n                bias_val = T.if_then_else(has_bias, T.cast(Bias[channel_id], T.float32), 0.0)\n                weight_vals = T.alloc_local((width, ), T.float32)\n                for i in T.unroll(width):\n                    weight_vals[i] = W[channel_id, i]\n\n                # fill conv states and read x_vals\n                x_vals = T.alloc_local((width, ), T.float32)\n                for i in T.unroll(state_len - advance_len - (width - 1), unroll_factor=2):\n                    Conv_State[conv_state_batch_coord, channel_id, i] = Conv_State[conv_state_batch_coord, channel_id,\n                                                                                   i + advance_len]\n                for i in T.unroll(width - 1):\n                    state_val = Conv_State[conv_state_batch_coord, channel_id, state_len - (width - 1) + i]\n                    if i < advance_len + (width - 1) and state_len - advance_len - (width - 1) + i >= 0:\n                        Conv_State[conv_state_batch_coord, channel_id,\n                                   state_len - advance_len - (width - 1) + i] = state_val\n                    x_vals[i] = state_val\n\n                # compute output\n                for i in T.unroll(seqlen, unroll_factor=2):\n                    x_val = X[batch_id, channel_id, i]\n                    if i < advance_len and state_len - advance_len + i >= 0:\n                        Conv_State[conv_state_batch_coord, channel_id, state_len - advance_len + i] = x_val\n                    x_vals[width - 1] = x_val\n                    out_val = T.alloc_var(T.float32)\n                    out_val = bias_val\n                    for j in T.unroll(width):\n                        out_val += weight_vals[j] * x_vals[j]\n                    if silu_activation:\n                        out_val = T.sigmoid(out_val) * out_val\n                    Out[batch_id, channel_id, 
i] = out_val\n                    # shift x_vals\n                    for j in T.unroll(width - 1):\n                        x_vals[j] = x_vals[j + 1]\n\n    return causal_conv1d_update_main\n\n\n# TODO: support cache_seqlens\n# TODO: support complex layout\ndef causal_conv1d_update(x,\n                         conv_state,\n                         weight,\n                         bias=None,\n                         activation=None,\n                         cache_seqlens=None,\n                         conv_state_indices=None):\n    \"\"\"Tilelang implementation of causal_conv1d_update.\"\"\"\n    assert x.dim() in (2, 3)\n    assert conv_state.dim() == 3\n    assert weight.dim() == 2\n    assert activation in ['silu', 'swish', None]\n    assert cache_seqlens is None, 'cache_seqlens is not supported in this version'\n    if conv_state_indices is not None:\n        assert conv_state_indices.dim() == 1 and conv_state_indices.is_contiguous()\n        assert conv_state_indices.dtype == torch.int32\n\n    unsqueeze = x.dim() == 2\n    if unsqueeze:\n        x = x.unsqueeze(-1)\n\n    has_bias = bias is not None\n    width = weight.size(-1)\n    _, hidden_size, seqlen = x.shape\n    state_len = conv_state.size(-1)\n\n    out = x.new_empty(x.shape)\n\n    num_warps = 2\n    kernel = causal_conv1d_update_fwd(hidden_size=hidden_size,\n                                      seqlen=seqlen,\n                                      state_len=state_len,\n                                      width=width,\n                                      has_bias=has_bias,\n                                      activation=activation,\n                                      dtype=x.dtype,\n                                      conv_stride=conv_state.stride(),\n                                      num_warps=num_warps)\n\n    kernel(x, conv_state, weight, bias, out, conv_state_indices)\n\n    if unsqueeze:\n        out = out.squeeze(-1)\n\n    return out\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/ds_index.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .utils import get_device_props\n\n\n@triton.jit\ndef _fp8_index_kernel(\n    q_ptr,\n    q_s_ptr,\n    k_cache_ptr,\n    k_s_cache_ptr,\n    cu_seqlen_q_ptr,\n    k_seqlen_ptr,\n    block_offset_ptr,\n    out_ptr,\n    stride_qm: tl.constexpr,\n    stride_qh: tl.constexpr,\n    stride_qd: tl.constexpr,\n    stride_qsm: tl.constexpr,\n    stride_qsh: tl.constexpr,\n    stride_kb: tl.constexpr,\n    stride_kn: tl.constexpr,\n    stride_kd: tl.constexpr,\n    stride_ksb: tl.constexpr,\n    stride_ksn: tl.constexpr,\n    stride_boff0,\n    stride_boff1: tl.constexpr,\n    stride_om,\n    stride_on: tl.constexpr,\n    max_q_seqlen,\n    causal: tl.constexpr,\n    BLOCK_H: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n    NUM_SPLIT: tl.constexpr,\n):\n    \"\"\"Fp8 index kernel.\"\"\"\n    m_id = tl.program_id(0).to(tl.int64)\n    split_id = tl.program_id(1).to(tl.int64)\n\n    assert stride_qd == 1\n    assert stride_kd == 1\n\n    batch_id = m_id // max_q_seqlen\n    q_id = m_id % max_q_seqlen\n    q_start = tl.load(cu_seqlen_q_ptr + batch_id)\n    q_seqlen = tl.load(cu_seqlen_q_ptr + batch_id + 1) - q_start\n    if q_id >= q_seqlen:\n        return\n\n    k_seqlen = tl.load(k_seqlen_ptr + batch_id)\n    if k_seqlen <= 0:\n        return\n\n    q_pos = q_start + q_id\n    offs_h = tl.arange(0, BLOCK_H)\n    offs_d = tl.arange(0, BLOCK_D)\n    offs_n = tl.arange(0, BLOCK_N)\n\n    q_ptrs = q_ptr + q_pos * stride_qm + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd\n    q_s_ptrs = q_s_ptr + q_pos * stride_qsm + offs_h * stride_qsh\n    q = tl.load(q_ptrs)\n    q_s = tl.load(q_s_ptrs)\n\n    k_ptrs = k_cache_ptr + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd\n    k_s_ptrs = k_s_cache_ptr + offs_n * stride_ksn\n    o_ptrs = out_ptr + q_pos * stride_om + offs_n * stride_on + split_id * BLOCK_N * 
stride_on\n    boff_ptr = block_offset_ptr + batch_id * stride_boff0 + split_id * stride_boff1\n\n    causal_pos = k_seqlen - q_seqlen + q_id\n    num_blocks = tl.cdiv(k_seqlen, BLOCK_N)\n    for boff_id in tl.range(split_id, num_blocks, NUM_SPLIT, num_stages=3):\n        boff = tl.load(boff_ptr)\n\n        k = tl.load(k_ptrs + boff * stride_kb)\n        k_s = tl.load(k_s_ptrs + boff * stride_ksb)\n\n        logits = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32)\n        logits = tl.dot(q, k, acc=logits)\n        logits = tl.maximum(logits, 0) * q_s[:, None]\n        logits_sum = tl.sum(logits, axis=0) * k_s\n\n        if causal:\n            mask_off = boff_id * BLOCK_N + offs_n\n            mask = mask_off <= causal_pos\n            logits_sum = tl.where(mask, logits_sum, float('-inf'))\n\n        tl.store(o_ptrs, logits_sum, mask=offs_n + boff_id * BLOCK_N < k_seqlen)\n        boff_ptr += NUM_SPLIT * stride_boff1\n        o_ptrs += NUM_SPLIT * BLOCK_N * stride_on\n\n\ndef fp8_index(q: torch.Tensor,\n              q_s: torch.Tensor,\n              k_cache: torch.Tensor,\n              k_s_cache: torch.Tensor,\n              cu_seqlen_q: torch.Tensor,\n              k_seqlens: torch.Tensor,\n              block_offset: torch.Tensor,\n              max_q_seqlen: int = None,\n              max_k_seqlen: int = None,\n              causal: bool = False):\n    \"\"\"Fp8 index.\n\n    q: (cum_seqlen, num_heads, head_dim)\n    q_s: (cum_seqlen, num_heads)\n    k_cache: (num_blocks, block_size, head_dim)\n    k_s_cache: (num_blocks, block_size)\n    cu_seqlen_q: (batch_size + 1,)\n    k_seqlens: (batch_size,)\n    block_offset: (batch_size, num_blocks)\n    \"\"\"\n    assert q.dim() == 3\n    assert k_cache.dim() == 3\n    assert q_s.dim() == 2\n    assert k_s_cache.dim() == 2\n    cum_seqlen, num_heads, head_dim = q.shape\n    block_size = k_cache.size(1)\n    batch_size = k_seqlens.numel()\n    is_decoding = batch_size == cum_seqlen\n    if max_k_seqlen is None:\n       
 max_num_blocks = k_cache.size(0)\n        max_k_seqlen = max_num_blocks * block_size\n\n    # max q seqlen\n    if is_decoding:\n        if max_q_seqlen is None:\n            max_q_seqlen = 1\n        assert max_q_seqlen == 1\n    elif max_q_seqlen is None:\n        max_q_seqlen = cum_seqlen\n\n    assert q.stride(-1) == 1 and k_cache.stride(-1) == 1\n\n    out = q.new_empty((cum_seqlen, max_k_seqlen), dtype=torch.float32)\n\n    num_warps = 4\n    device_idx = q.device.index\n    props = get_device_props(device_idx)\n    num_sm = props['multi_processor_count']\n    # estimated occupancy 12.5%\n    warps_per_sm = props['warps_per_sm'] // 8\n    assert warps_per_sm >= num_warps\n    cta_per_sm = warps_per_sm // num_warps\n    cta_per_device = num_sm * cta_per_sm\n    # we better have a tensor to indicate batch id of each q\n    M = max_q_seqlen * batch_size\n    NUM_SPLIT = max(1, triton.cdiv(cta_per_device, M))\n    grid = (M, NUM_SPLIT)\n\n    _fp8_index_kernel[grid](q,\n                            q_s,\n                            k_cache,\n                            k_s_cache,\n                            cu_seqlen_q,\n                            k_seqlens,\n                            block_offset,\n                            out,\n                            *q.stride(),\n                            *q_s.stride(),\n                            *k_cache.stride(),\n                            *k_s_cache.stride(),\n                            *block_offset.stride(),\n                            *out.stride(),\n                            max_q_seqlen=max_q_seqlen,\n                            causal=causal,\n                            BLOCK_H=num_heads,\n                            BLOCK_N=block_size,\n                            BLOCK_D=head_dim,\n                            NUM_SPLIT=NUM_SPLIT,\n                            num_warps=num_warps)\n    return out\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Literal, Optional\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n\n@triton.jit\ndef _quant_int8(val):\n    val_min = tl.min(val, 1)\n    val_max = tl.max(val, 1)\n    scales = (val_max - val_min) / 255\n    zeros = -val_min / scales\n    q_val = (val / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)\n    return q_val, scales, zeros\n\n\n@triton.jit\ndef _quant_int4(val1, val2):\n    val1 = val1.to(tl.float32)\n    val2 = val2.to(tl.float32)\n    val_min = tl.min(tl.minimum(val1, val2), 1)\n    val_max = tl.max(tl.maximum(val1, val2), 1)\n    scales = (val_max - val_min) / 15\n    zeros = -val_min / scales\n    q_val1 = (val1 / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)\n    q_val2 = (val2 / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)\n    q_val = q_val1 + q_val2 * 16\n    return q_val, scales, zeros\n\n\n@triton.jit\ndef _fill_kv_cache_kernel(\n    KStates,\n    VStates,\n    KCaches,\n    VCaches,\n    QStartLoc,\n    QSeqLens,\n    KVSeqLens,\n    BlockOffsets,\n    is_decoding: tl.constexpr,\n    head_dim: tl.constexpr,\n    head_dim_v: tl.constexpr,\n    stride_kss,\n    stride_ksh,\n    stride_ksd,\n    stride_vss,\n    stride_vsh,\n    stride_vsd,\n    stride_kcn: tl.constexpr,\n    stride_kcb: tl.constexpr,\n    stride_kch: tl.constexpr,\n    stride_kcd: tl.constexpr,\n    stride_vcn: tl.constexpr,\n    stride_vcb: tl.constexpr,\n    stride_vch: tl.constexpr,\n    stride_vcd: tl.constexpr,\n    stride_boff,\n    BLOCK: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n):\n    \"\"\"Fill kv cache kernel.\"\"\"\n    batch_id = tl.program_id(2)\n    head_id = tl.program_id(0)\n    block_id = tl.program_id(1)\n\n    q_startloc = tl.load(QStartLoc + batch_id)\n    q_seqlen = tl.load(QSeqLens + batch_id)\n    kv_seqlen = tl.load(KVSeqLens + batch_id)\n    history_seqlen = kv_seqlen - 
q_seqlen\n\n    kv_block_id = history_seqlen // BLOCK + block_id\n\n    if kv_seqlen <= 0:\n        return\n\n    if kv_block_id * BLOCK >= kv_seqlen:\n        return\n\n    if is_decoding:\n        page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32)\n        kv_mask = tl.full((1, ), 1, dtype=tl.int1)\n        q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)\n    else:\n        page_offs = tl.arange(0, BLOCK)\n        kv_offs = kv_block_id * BLOCK + page_offs\n        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)\n        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen\n        q_offs = token_off + page_offs\n\n    block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)\n\n    d_off = tl.arange(0, BLOCK_D)\n    mask_ks = kv_mask[:, None]\n    mask_kc = mask_ks & (d_off[None, :] < head_dim)\n    d_off = d_off % head_dim\n\n    ks_ptr = KStates + head_id * stride_ksh\n    ks_ptrs = ks_ptr + q_offs[:, None] * stride_kss + d_off[None, :] * stride_ksd\n    kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch\n    kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[None, :] * stride_kcd\n\n    if BLOCK_DV > 0:\n        dv_off = tl.arange(0, BLOCK_DV)\n        mask_vs = kv_mask[:, None]\n        mask_vc = mask_vs & (dv_off[None, :] < head_dim_v)\n        dv_off = dv_off % head_dim_v\n        vs_ptr = VStates + head_id * stride_vsh\n        vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[None, :] * stride_vsd\n        vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch\n        vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[None, :] * stride_vcd\n\n    k = tl.load(ks_ptrs, mask=mask_ks)\n    if BLOCK_DV > 0:\n        v = tl.load(vs_ptrs, mask=mask_vs)\n    tl.store(kc_ptrs, k, mask=mask_kc)\n    if BLOCK_DV > 0:\n        tl.store(vc_ptrs, v, mask=mask_vc)\n\n\n@triton.jit\ndef _fill_page_quant_int8(\n    state_ptr,\n    cache_ptr,\n    scales_zeros_ptr,\n  
  block_off,\n    head_id,\n    page_offs,\n    q_offs,\n    kv_mask,\n    head_dim: tl.constexpr,\n    stride_ss,\n    stride_sh,\n    stride_sd,\n    stride_cn: tl.constexpr,\n    stride_cb: tl.constexpr,\n    stride_ch: tl.constexpr,\n    stride_cd: tl.constexpr,\n    stride_szn: tl.constexpr,\n    stride_szb: tl.constexpr,\n    stride_szh: tl.constexpr,\n    stride_szd: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n):\n    \"\"\"Fill page int8.\"\"\"\n    d_off = tl.arange(0, BLOCK_D)\n    mask_kc = kv_mask[:, None] & (d_off[None, :] < head_dim)\n    d_off = d_off % head_dim\n    state_ptr = state_ptr + head_id * stride_sh\n    state_ptrs = state_ptr + q_offs[:, None] * stride_ss + d_off[None, :] * stride_sd\n    cache_ptr = cache_ptr + block_off * stride_cn + head_id * stride_ch\n    cache_ptrs = cache_ptr + page_offs[:, None] * stride_cb + d_off[None, :] * stride_cd\n    scales_zeros_ptr = scales_zeros_ptr + block_off * stride_szn + head_id * stride_szh\n    scales_ptrs = scales_zeros_ptr + page_offs[:, None] * stride_szb\n    zeros_ptrs = scales_ptrs + stride_szd\n\n    state = tl.load(state_ptrs, mask=kv_mask[:, None])\n    state, scales, zeros = _quant_int8(state)\n\n    tl.store(cache_ptrs, state, mask=mask_kc)\n    tl.store(scales_ptrs, scales[:, None], mask=kv_mask[:, None])\n    tl.store(zeros_ptrs, zeros[:, None], mask=kv_mask[:, None])\n\n\n@triton.jit\ndef _fill_page_quant_int4(\n    state_ptr,\n    cache_ptr,\n    scales_zeros_ptr,\n    block_off,\n    head_id,\n    page_offs,\n    q_offs,\n    kv_mask,\n    head_dim: tl.constexpr,\n    stride_ss,\n    stride_sh,\n    stride_sd,\n    stride_cn: tl.constexpr,\n    stride_cb: tl.constexpr,\n    stride_ch: tl.constexpr,\n    stride_cd: tl.constexpr,\n    stride_szn: tl.constexpr,\n    stride_szb: tl.constexpr,\n    stride_szh: tl.constexpr,\n    stride_szd: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n):\n    \"\"\"Fill page int4.\"\"\"\n    d_off = tl.arange(0, BLOCK_D)\n    mask_kc = kv_mask[:, None] & 
(d_off[None, :] < head_dim)\n    state_ptr = state_ptr + head_id * stride_sh\n    state0_ptrs = state_ptr + q_offs[:, None] * stride_ss + d_off[None, :] * stride_sd\n    state1_ptrs = state0_ptrs + head_dim * stride_sd\n    cache_ptr = cache_ptr + block_off * stride_cn + head_id * stride_ch\n    cache_ptrs = cache_ptr + page_offs[:, None] * stride_cb + d_off[None, :] * stride_cd\n    scales_zeros_ptr = scales_zeros_ptr + block_off * stride_szn + head_id * stride_szh\n    scales_ptrs = scales_zeros_ptr + page_offs[:, None] * stride_szb\n    zeros_ptrs = scales_ptrs + stride_szd\n\n    state0 = tl.load(state0_ptrs, mask=mask_kc)\n    state1 = tl.load(state1_ptrs, mask=mask_kc)\n    state, scales, zeros = _quant_int4(state0, state1)\n\n    tl.store(cache_ptrs, state, mask=mask_kc)\n    tl.store(scales_ptrs, scales[:, None], mask=kv_mask[:, None])\n    tl.store(zeros_ptrs, zeros[:, None], mask=kv_mask[:, None])\n\n\n@triton.jit\ndef _fill_page_quant(state_ptr, cache_ptr, scales_zeros_ptr, block_off, head_id, page_offs, q_offs, kv_mask,\n                     head_dim: tl.constexpr, stride_ss, stride_sh, stride_sd, stride_cn: tl.constexpr,\n                     stride_cb: tl.constexpr, stride_ch: tl.constexpr, stride_cd: tl.constexpr,\n                     stride_szn: tl.constexpr, stride_szb: tl.constexpr, stride_szh: tl.constexpr,\n                     stride_szd: tl.constexpr, BLOCK_D: tl.constexpr, quant_policy: tl.constexpr):\n    \"\"\"Fill page.\"\"\"\n    if quant_policy == 8:\n        return _fill_page_quant_int8(state_ptr,\n                                     cache_ptr,\n                                     scales_zeros_ptr,\n                                     block_off,\n                                     head_id,\n                                     page_offs,\n                                     q_offs,\n                                     kv_mask,\n                                     head_dim=head_dim,\n                                     
stride_ss=stride_ss,\n                                     stride_sh=stride_sh,\n                                     stride_sd=stride_sd,\n                                     stride_cn=stride_cn,\n                                     stride_cb=stride_cb,\n                                     stride_ch=stride_ch,\n                                     stride_cd=stride_cd,\n                                     stride_szn=stride_szn,\n                                     stride_szb=stride_szb,\n                                     stride_szh=stride_szh,\n                                     stride_szd=stride_szd,\n                                     BLOCK_D=BLOCK_D)\n    elif quant_policy == 4:\n        return _fill_page_quant_int4(state_ptr,\n                                     cache_ptr,\n                                     scales_zeros_ptr,\n                                     block_off,\n                                     head_id,\n                                     page_offs,\n                                     q_offs,\n                                     kv_mask,\n                                     head_dim=head_dim,\n                                     stride_ss=stride_ss,\n                                     stride_sh=stride_sh,\n                                     stride_sd=stride_sd,\n                                     stride_cn=stride_cn,\n                                     stride_cb=stride_cb,\n                                     stride_ch=stride_ch,\n                                     stride_cd=stride_cd,\n                                     stride_szn=stride_szn,\n                                     stride_szb=stride_szb,\n                                     stride_szh=stride_szh,\n                                     stride_szd=stride_szd,\n                                     BLOCK_D=BLOCK_D)\n    else:\n        tl.static_assert(False, 'Unsupported quant policy')\n\n\n@triton.jit\ndef _fill_kv_cache_quant_kernel(\n    
KStates,\n    VStates,\n    KCaches,\n    VCaches,\n    KScalesZeros,\n    VScalesZeros,\n    QStartLoc,\n    QSeqLens,\n    KVSeqLens,\n    BlockOffsets,\n    is_decoding: tl.constexpr,\n    head_dim: tl.constexpr,\n    head_dim_v: tl.constexpr,\n    stride_kss,\n    stride_ksh,\n    stride_ksd,\n    stride_vss,\n    stride_vsh,\n    stride_vsd,\n    stride_kcn: tl.constexpr,\n    stride_kcb: tl.constexpr,\n    stride_kch: tl.constexpr,\n    stride_kcd: tl.constexpr,\n    stride_vcn: tl.constexpr,\n    stride_vcb: tl.constexpr,\n    stride_vch: tl.constexpr,\n    stride_vcd: tl.constexpr,\n    stride_kszn: tl.constexpr,\n    stride_kszb: tl.constexpr,\n    stride_kszh: tl.constexpr,\n    stride_kszd: tl.constexpr,\n    stride_vszn: tl.constexpr,\n    stride_vszb: tl.constexpr,\n    stride_vszh: tl.constexpr,\n    stride_vszd: tl.constexpr,\n    quant_policy: tl.constexpr,\n    stride_boff,\n    BLOCK: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n):\n    \"\"\"Fill kv cache kernel with int4/int8 quantization fused.\n\n    Args:\n        stride_xss: stride of sequence length dim of key or value states\n        stride_xsh: stride of head_num dim of key or value states\n        stride_xsd: stride of head_size dim of key or value states\n        stride_xcn: stride of page num dim of key or value caches\n        stride_xcb: stride of block size dim of key or value caches\n        stride_xch: stride of head_num dim of key or value caches\n        stride_xcd: stride of head_size dim of key or value caches\n        stride_xsz{n,b,h,d}: the same four strides for the scales/zeros caches\n    \"\"\"\n    batch_id = tl.program_id(2)\n    head_id = tl.program_id(0)\n    block_id = tl.program_id(1)\n\n    q_startloc = tl.load(QStartLoc + batch_id)\n    q_seqlen = tl.load(QSeqLens + batch_id)\n    kv_seqlen = tl.load(KVSeqLens + batch_id)\n    history_seqlen = kv_seqlen - q_seqlen\n\n    kv_block_id = history_seqlen // BLOCK + block_id\n\n    if kv_seqlen <= 0:\n        return\n\n    if kv_block_id * BLOCK >= kv_seqlen:\n        return\n\n    if is_decoding:\n        page_offs = tl.full((1, ), history_seqlen % BLOCK, 
dtype=tl.int32)\n        kv_mask = tl.full((1, ), 1, dtype=tl.int1)\n        q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)\n    else:\n        page_offs = tl.arange(0, BLOCK)\n        kv_offs = kv_block_id * BLOCK + page_offs\n        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)\n        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen\n        q_offs = token_off + page_offs\n\n    block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)\n\n    _fill_page_quant(KStates,\n                     KCaches,\n                     KScalesZeros,\n                     block_off,\n                     head_id,\n                     page_offs,\n                     q_offs,\n                     kv_mask,\n                     head_dim=head_dim,\n                     stride_ss=stride_kss,\n                     stride_sh=stride_ksh,\n                     stride_sd=stride_ksd,\n                     stride_cn=stride_kcn,\n                     stride_cb=stride_kcb,\n                     stride_ch=stride_kch,\n                     stride_cd=stride_kcd,\n                     stride_szn=stride_kszn,\n                     stride_szb=stride_kszb,\n                     stride_szh=stride_kszh,\n                     stride_szd=stride_kszd,\n                     BLOCK_D=BLOCK_D,\n                     quant_policy=quant_policy)\n\n    if BLOCK_DV > 0:\n        _fill_page_quant(VStates,\n                         VCaches,\n                         VScalesZeros,\n                         block_off,\n                         head_id,\n                         page_offs,\n                         q_offs,\n                         kv_mask,\n                         head_dim=head_dim_v,\n                         stride_ss=stride_vss,\n                         stride_sh=stride_vsh,\n                         stride_sd=stride_vsd,\n                         stride_cn=stride_vcn,\n                         stride_cb=stride_vcb,\n                         
stride_ch=stride_vch,\n                         stride_cd=stride_vcd,\n                         stride_szn=stride_vszn,\n                         stride_szb=stride_vszb,\n                         stride_szh=stride_vszh,\n                         stride_szd=stride_vszd,\n                         BLOCK_D=BLOCK_DV,\n                         quant_policy=quant_policy)\n\n\ndef fill_kv_cache(k_states: Tensor,\n                  v_states: Optional[Tensor],\n                  k_caches: Tensor,\n                  v_caches: Optional[Tensor],\n                  q_start_loc: Tensor,\n                  q_seq_length: Tensor,\n                  kv_seq_length: Tensor,\n                  max_q_seq_length: int,\n                  block_offsets: Tensor,\n                  k_scales_zeros: Optional[Tensor] = None,\n                  v_scales_zeros: Optional[Tensor] = None,\n                  quant_policy: Literal[0, 4, 8] = 0,\n                  kv_layout: str = 'bshd'):\n    \"\"\"Fill key/value state to cache for paged attention, optionally with\n    int4 (``quant_policy=4``) or int8 (``quant_policy=8``) quantization.\"\"\"\n    if kv_layout == 'bshd':\n        b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)\n    elif kv_layout == 'bhsd':\n        b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)\n    else:\n        raise RuntimeError('Unsupported layout.')\n    if v_states is None:\n        v_states = k_states[..., :0]\n    if v_caches is None:\n        v_caches = k_caches[..., :0]\n\n    block_offsets = block_offsets.contiguous()\n    batch_size = block_offsets.size(0)\n    block_size = k_caches.size(s_dim)\n    num_heads = k_caches.size(h_dim)\n    head_dim = k_caches.size(d_dim)\n    head_dim_v = v_caches.size(d_dim)\n    if v_states.size(-1) == 0:\n        head_dim_v = 0\n    if max_q_seq_length == 1:\n        max_num_blocks = 1\n    else:\n        max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1\n\n    BLOCK = block_size\n    BLOCK_D = triton.next_power_of_2(head_dim)\n    BLOCK_DV = triton.next_power_of_2(head_dim_v)\n    if k_caches.data_ptr() == v_caches.data_ptr() and head_dim_v 
<= head_dim:\n        BLOCK_DV = 0\n    grid = (num_heads, max_num_blocks, batch_size)\n    is_decoding = max_num_blocks == 1\n    if quant_policy == 0:\n        _fill_kv_cache_kernel[grid](\n            k_states,\n            v_states,\n            k_caches,\n            v_caches,\n            q_start_loc,\n            q_seq_length,\n            kv_seq_length,\n            block_offsets,\n            is_decoding=is_decoding,\n            head_dim=head_dim,\n            head_dim_v=head_dim_v,\n            stride_kss=k_states.stride(-3),\n            stride_ksh=k_states.stride(-2),\n            stride_ksd=k_states.stride(-1),\n            stride_vss=v_states.stride(-3),\n            stride_vsh=v_states.stride(-2),\n            stride_vsd=v_states.stride(-1),\n            stride_kcn=k_caches.stride(b_dim),\n            stride_kcb=k_caches.stride(s_dim),\n            stride_kch=k_caches.stride(h_dim),\n            stride_kcd=k_caches.stride(d_dim),\n            stride_vcn=v_caches.stride(b_dim),\n            stride_vcb=v_caches.stride(s_dim),\n            stride_vch=v_caches.stride(h_dim),\n            stride_vcd=v_caches.stride(d_dim),\n            stride_boff=block_offsets.stride(0),\n            BLOCK=BLOCK,\n            BLOCK_D=BLOCK_D,\n            BLOCK_DV=BLOCK_DV,\n            num_warps=4,\n            num_stages=3,\n        )\n    else:\n        _fill_kv_cache_quant_kernel[grid](\n            k_states,\n            v_states,\n            k_caches,\n            v_caches,\n            k_scales_zeros,\n            v_scales_zeros,\n            q_start_loc,\n            q_seq_length,\n            kv_seq_length,\n            block_offsets,\n            is_decoding=is_decoding,\n            head_dim=head_dim,\n            head_dim_v=head_dim_v,\n            stride_kss=k_states.stride(-3),\n            stride_ksh=k_states.stride(-2),\n            stride_ksd=k_states.stride(-1),\n            stride_vss=v_states.stride(-3),\n            
stride_vsh=v_states.stride(-2),\n            stride_vsd=v_states.stride(-1),\n            stride_kcn=k_caches.stride(b_dim),\n            stride_kcb=k_caches.stride(s_dim),\n            stride_kch=k_caches.stride(h_dim),\n            stride_kcd=k_caches.stride(d_dim),\n            stride_vcn=v_caches.stride(b_dim),\n            stride_vcb=v_caches.stride(s_dim),\n            stride_vch=v_caches.stride(h_dim),\n            stride_vcd=v_caches.stride(d_dim),\n            stride_kszn=k_scales_zeros.stride(b_dim),\n            stride_kszb=k_scales_zeros.stride(s_dim),\n            stride_kszh=k_scales_zeros.stride(h_dim),\n            stride_kszd=k_scales_zeros.stride(d_dim),\n            stride_vszn=v_scales_zeros.stride(b_dim),\n            stride_vszb=v_scales_zeros.stride(s_dim),\n            stride_vszh=v_scales_zeros.stride(h_dim),\n            stride_vszd=v_scales_zeros.stride(d_dim),\n            quant_policy=quant_policy,\n            stride_boff=block_offsets.stride(0),\n            BLOCK=BLOCK,\n            BLOCK_D=BLOCK_D,\n            BLOCK_DV=BLOCK_DV,\n            num_warps=4,\n            num_stages=1,\n        )\n\n\n@triton.jit\ndef fast_log2_ceil(x):\n    bits_x = tl.cast(x, tl.uint32, bitcast=True)\n    exp_x = (bits_x >> 23) & 0xFF\n    man_bits = bits_x & ((1 << 23) - 1)\n    tmp = exp_x - 127 + tl.where(man_bits != 0, 1, 0)\n    return tl.cast(tmp, tl.int32)\n\n\n@triton.jit\ndef fast_pow2(x):\n    bits_x = (x + 127) << 23\n    return tl.cast(bits_x, tl.float32, bitcast=True)\n\n\n@triton.jit\ndef fast_round_scale(amax, fp8_max_inv):\n    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))\n\n\n@triton.jit\ndef _quant_blocked_fp8(x,\n                       fp8_min: tl.constexpr,\n                       fp8_max: tl.constexpr,\n                       dtype: tl.constexpr,\n                       GROUP_SIZE: tl.constexpr = 128,\n                       ROUND_SCALE: tl.constexpr = 0):\n    x = x.to(tl.float32)\n    M: tl.constexpr = x.shape[0]\n    N: 
tl.constexpr = x.shape[1]\n    rfp8_max: tl.constexpr = 1 / fp8_max\n    x = x.reshape(M, N // GROUP_SIZE, GROUP_SIZE)\n    amax = tl.maximum(tl.max(tl.abs(x), axis=2, keep_dims=True), 1e-6)\n    if ROUND_SCALE == 1:\n        scale = fast_round_scale(amax, rfp8_max)\n    else:\n        scale = amax * rfp8_max\n    out = x / scale\n\n    out = tl.clamp(out, fp8_min, fp8_max)\n    out = out.to(dtype)\n    out = out.reshape(M, N)\n    scale = scale.reshape(M, N // GROUP_SIZE)\n    return out, scale\n\n\n@triton.jit\ndef _fill_kv_cache_blocked_fp8_kernel(\n    KStates,\n    VStates,\n    KCaches,\n    VCaches,\n    KSCaches,\n    VSCaches,\n    cu_seqlen_q_ptr,\n    KVSeqLens,\n    BlockOffsets,\n    fp8_min: tl.constexpr,\n    fp8_max: tl.constexpr,\n    is_decoding: tl.constexpr,\n    head_dim: tl.constexpr,\n    head_dim_v: tl.constexpr,\n    stride_kss,\n    stride_ksh,\n    stride_ksd,\n    stride_vss,\n    stride_vsh,\n    stride_vsd,\n    stride_kcn: tl.constexpr,\n    stride_kcb: tl.constexpr,\n    stride_kch: tl.constexpr,\n    stride_kcd: tl.constexpr,\n    stride_vcn: tl.constexpr,\n    stride_vcb: tl.constexpr,\n    stride_vch: tl.constexpr,\n    stride_vcd: tl.constexpr,\n    stride_kscn: tl.constexpr,\n    stride_kscb: tl.constexpr,\n    stride_ksch: tl.constexpr,\n    stride_kscd: tl.constexpr,\n    stride_vscn: tl.constexpr,\n    stride_vscb: tl.constexpr,\n    stride_vsch: tl.constexpr,\n    stride_vscd: tl.constexpr,\n    stride_boff,\n    ROUND_SCALE: tl.constexpr,\n    GROUP_SIZE: tl.constexpr,\n    BLOCK: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n):\n    \"\"\"Fill kv cache kernel.\"\"\"\n    batch_id = tl.program_id(2)\n    head_id = tl.program_id(0)\n    block_id = tl.program_id(1)\n\n    q_startloc = tl.load(cu_seqlen_q_ptr + batch_id)\n    q_seqlen = tl.load(cu_seqlen_q_ptr + batch_id + 1) - q_startloc\n    kv_seqlen = tl.load(KVSeqLens + batch_id)\n    history_seqlen = kv_seqlen - q_seqlen\n\n    kv_block_id = 
history_seqlen // BLOCK + block_id\n\n    if kv_seqlen <= 0:\n        return\n\n    if kv_block_id * BLOCK >= kv_seqlen:\n        return\n\n    if is_decoding:\n        page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32)\n        kv_mask = tl.full((1, ), 1, dtype=tl.int1)\n        q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)\n    else:\n        page_offs = tl.arange(0, BLOCK)\n        kv_offs = kv_block_id * BLOCK + page_offs\n        kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)\n        token_off = q_startloc + kv_block_id * BLOCK - history_seqlen\n        q_offs = token_off + page_offs\n\n    block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)\n\n    d_off = tl.arange(0, BLOCK_D)\n    mask_ks = kv_mask[:, None]\n    mask_kc = mask_ks & (d_off[None, :] < head_dim)\n    d_off = d_off % head_dim\n\n    BLOCK_DS: tl.constexpr = (BLOCK_D + GROUP_SIZE - 1) // GROUP_SIZE\n    ds_off = tl.arange(0, BLOCK_DS)\n\n    ks_ptr = KStates + head_id * stride_ksh\n    ks_ptrs = ks_ptr + q_offs[:, None] * stride_kss + d_off[None, :] * stride_ksd\n    kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch\n    kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[None, :] * stride_kcd\n    ksc_ptr = KSCaches + block_off * stride_kscn + head_id * stride_ksch\n    ksc_ptrs = ksc_ptr + page_offs[:, None] * stride_kscb + ds_off[None, :] * stride_kscd\n\n    if BLOCK_DV > 0:\n        dv_off = tl.arange(0, BLOCK_DV)\n        mask_vs = kv_mask[:, None]\n        mask_vc = mask_vs & (dv_off[None, :] < head_dim_v)\n\n        BLOCK_DVS: tl.constexpr = (BLOCK_DV + GROUP_SIZE - 1) // GROUP_SIZE\n        dvs_off = tl.arange(0, BLOCK_DVS)\n\n        dv_off = dv_off % head_dim_v\n        vs_ptr = VStates + head_id * stride_vsh\n        vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[None, :] * stride_vsd\n        vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch\n        vc_ptrs = vc_ptr + 
page_offs[:, None] * stride_vcb + dv_off[None, :] * stride_vcd\n        vsc_ptr = VSCaches + block_off * stride_vscn + head_id * stride_vsch\n        vsc_ptrs = vsc_ptr + page_offs[:, None] * stride_vscb + dvs_off[None, :] * stride_vscd\n\n    k = tl.load(ks_ptrs, mask=mask_ks)\n    if BLOCK_DV > 0:\n        v = tl.load(vs_ptrs, mask=mask_vs)\n    kc, kcs = _quant_blocked_fp8(k, fp8_min, fp8_max, KCaches.dtype.element_ty, GROUP_SIZE, ROUND_SCALE)\n    tl.store(kc_ptrs, kc, mask=mask_kc)\n    tl.store(ksc_ptrs, kcs, mask=kv_mask[:, None] & (ds_off[None, :] < tl.cdiv(head_dim, GROUP_SIZE)))\n    if BLOCK_DV > 0:\n        vc, vcs = _quant_blocked_fp8(v, fp8_min, fp8_max, VCaches.dtype.element_ty, GROUP_SIZE, ROUND_SCALE)\n        tl.store(vc_ptrs, vc, mask=mask_vc)\n        # the value-scale mask must use dvs_off (BLOCK_DVS wide), not ds_off (BLOCK_DS wide)\n        tl.store(vsc_ptrs, vcs, mask=kv_mask[:, None] & (dvs_off[None, :] < tl.cdiv(head_dim_v, GROUP_SIZE)))\n\n\ndef fill_kv_cache_blocked_fp8(k_states: Tensor,\n                              v_states: Optional[Tensor],\n                              k_caches: Tensor,\n                              v_caches: Optional[Tensor],\n                              ks_caches: Tensor,\n                              vs_caches: Optional[Tensor],\n                              cu_seqlen_q: Tensor,\n                              kv_seqlens: Tensor,\n                              max_q_seqlen: int,\n                              block_offsets: Tensor,\n                              group_size: int = 128,\n                              kv_layout: str = 'bshd',\n                              scale_fmt: Optional[str] = None):\n    \"\"\"Fill key/value state to cache for paged attention with fp8 quantization.\n\n    Args:\n        k_states (Tensor): Key states of shape\n            (seq_length, num_heads, head_dim).\n        v_states (Optional[Tensor]): Value states of shape\n            (seq_length, num_heads, head_dim_v). 
If None, no value states\n            are processed.\n        k_caches (Tensor): 4D k cache, shape depends on ``kv_layout``.\n        v_caches (Optional[Tensor]): 4D v cache, shape depends on\n            ``kv_layout``. If None, no value caches are processed.\n        ks_caches (Tensor): 4D k scale cache, shape depends on\n            ``kv_layout``.\n        vs_caches (Optional[Tensor]): 4D v scale cache, shape depends on\n            ``kv_layout``. If None, no value scale caches are processed.\n        cu_seqlen_q (Tensor): Cumulative sequence lengths of queries,\n            shape (batch_size + 1, ).\n        kv_seqlens (Tensor): Sequence lengths of key/values (history plus\n            new tokens), shape (batch_size, ).\n        max_q_seqlen (int): Maximum sequence length of queries.\n        block_offsets (Tensor): Cache block (page) indices of each sequence,\n            shape (batch_size, max_blocks_per_seq).\n        group_size (int, optional): Number of consecutive head_dim elements\n            that share one fp8 scale. Default is 128.\n        kv_layout (str, optional): Layout of key/value caches. Valid values\n            are ``'bshd'`` and ``'bhsd'``. Default is ``'bshd'``.\n        scale_fmt (str, optional): Format of the fp8 scaling factors. Valid\n            values are ``None`` and ``'ue8m0'``. 
When set to ``'ue8m0'``,\n            each per-group scale is rounded up to the nearest power of two\n            (exactly representable in UE8M0); when ``None``, the scale is\n            stored unrounded as ``amax / fp8_max``.\n    \"\"\"\n    assert scale_fmt in (None, 'ue8m0'), f'Unsupported scale format: {scale_fmt}.'\n\n    if kv_layout == 'bshd':\n        b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)\n    elif kv_layout == 'bhsd':\n        b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)\n    else:\n        raise RuntimeError('Unsupported layout.')\n\n    if v_states is None:\n        v_states = k_states[..., :0]\n    if v_caches is None:\n        v_caches = k_caches[..., :0]\n    if vs_caches is None:\n        vs_caches = ks_caches[..., :0]\n\n    block_offsets = block_offsets.contiguous()\n    batch_size = block_offsets.size(0)\n    block_size = k_caches.size(s_dim)\n    num_heads = k_caches.size(h_dim)\n    head_dim = k_caches.size(d_dim)\n    head_dim_v = v_states.size(-1)\n    if max_q_seqlen == 1:\n        max_num_blocks = 1\n    else:\n        max_num_blocks = triton.cdiv(max_q_seqlen, block_size) + 1\n\n    BLOCK = block_size\n    BLOCK_D = triton.next_power_of_2(head_dim)\n    BLOCK_DV = triton.next_power_of_2(head_dim_v)\n    if k_caches.data_ptr() == v_caches.data_ptr() and head_dim_v <= head_dim:\n        BLOCK_DV = 0\n\n    dtype = k_caches.dtype\n    finfo = torch.finfo(dtype)\n    fmin = finfo.min\n    fmax = finfo.max\n\n    grid = (num_heads, max_num_blocks, batch_size)\n    ROUND_SCALE = 1 if scale_fmt == 'ue8m0' else 0\n    is_decoding = max_q_seqlen == 1\n    _fill_kv_cache_blocked_fp8_kernel[grid](\n        k_states,\n        v_states,\n        k_caches,\n        v_caches,\n        ks_caches,\n        vs_caches,\n        cu_seqlen_q,\n        kv_seqlens,\n        block_offsets,\n        fp8_min=fmin,\n        fp8_max=fmax,\n        is_decoding=is_decoding,\n        head_dim=head_dim,\n        head_dim_v=head_dim_v,\n        stride_kss=k_states.stride(-3),\n        
stride_ksh=k_states.stride(-2),\n        stride_ksd=k_states.stride(-1),\n        stride_vss=v_states.stride(-3),\n        stride_vsh=v_states.stride(-2),\n        stride_vsd=v_states.stride(-1),\n        stride_kcn=k_caches.stride(b_dim),\n        stride_kcb=k_caches.stride(s_dim),\n        stride_kch=k_caches.stride(h_dim),\n        stride_kcd=k_caches.stride(d_dim),\n        stride_vcn=v_caches.stride(b_dim),\n        stride_vcb=v_caches.stride(s_dim),\n        stride_vch=v_caches.stride(h_dim),\n        stride_vcd=v_caches.stride(d_dim),\n        stride_kscn=ks_caches.stride(b_dim),\n        stride_kscb=ks_caches.stride(s_dim),\n        stride_ksch=ks_caches.stride(h_dim),\n        stride_kscd=ks_caches.stride(d_dim),\n        stride_vscn=vs_caches.stride(b_dim),\n        stride_vscb=vs_caches.stride(s_dim),\n        stride_vsch=vs_caches.stride(h_dim),\n        stride_vscd=vs_caches.stride(d_dim),\n        stride_boff=block_offsets.stride(0),\n        ROUND_SCALE=ROUND_SCALE,\n        GROUP_SIZE=group_size,\n        BLOCK=BLOCK,\n        BLOCK_D=BLOCK_D,\n        BLOCK_DV=BLOCK_DV,\n        num_warps=4,\n    )\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/flashattention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport math\nfrom typing import Sequence\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\nfrom torch import Tensor\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\nTRITON_VERSION = version.parse(triton.__version__)\nVERSION_300 = version.parse('3.0.0')\nVERSION_320 = version.parse('3.2.0')\nassert TRITON_VERSION >= VERSION_300\n\n# TODO: fast op might not work on non-nv device\ntanh = tl.extra.cuda.libdevice.tanh\ntl_log2 = tl.log2\ntl_exp2 = tl.exp2\n\n\ndef _get_block_d(head_dim_k, head_dim_v):\n    \"\"\"Get block d.\"\"\"\n    BLOCK_DK = triton.next_power_of_2(head_dim_k)\n    BLOCK_DK1 = 0\n    if BLOCK_DK != head_dim_k:\n        BLOCK_DK = BLOCK_DK // 2\n        BLOCK_DK1 = max(16, triton.next_power_of_2(head_dim_k - BLOCK_DK))\n    BLOCK_DV = triton.next_power_of_2(head_dim_v)\n    return BLOCK_DK, BLOCK_DK1, BLOCK_DV\n\n\n@triton.jit\ndef softcapping(qk, logit_softcapping: tl.constexpr):\n    \"\"\"Soft capping.\"\"\"\n    if logit_softcapping > 0.0:\n        qk = qk / logit_softcapping\n        qk = tanh(qk)\n        qk = qk * logit_softcapping\n    return qk\n\n\n@triton.jit\ndef _load_kv(ptrs, boundary_check: tl.constexpr):\n    \"\"\"Load kv.\"\"\"\n    if boundary_check is not None:\n        return tl.load(ptrs, boundary_check=boundary_check, padding_option='zero')\n    else:\n        return tl.load(ptrs)\n\n\n@triton.jit\ndef _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, alibi_slope,\n                       global_offs_m, history_mask, kv_min_loc, causal_mask: tl.constexpr, window_size: tl.constexpr,\n                       logit_softcapping: tl.constexpr, k_bound: tl.constexpr, v_bound: tl.constexpr,\n                       shared_kv: tl.constexpr, block_sparse_size: tl.constexpr, BLOCK_N: tl.constexpr,\n                       BLOCK_DK1: tl.constexpr):\n    
k_ptrs = tl.advance(k_ptrs, (0, loop_start))\n    v_ptrs = tl.advance(v_ptrs, (loop_start, 0))\n    if BLOCK_DK1:\n        k1_ptrs = tl.advance(k1_ptrs, (0, loop_start))\n\n    offs_n = tl.arange(0, BLOCK_N)\n    for start_n in range(loop_start, loop_end, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n\n        k = _load_kv(k_ptrs, boundary_check=k_bound)\n        qk = tl.dot(q, k)\n\n        if BLOCK_DK1 != 0:\n            k1 = _load_kv(k1_ptrs, boundary_check=k_bound)\n            qk += tl.dot(q1, k1)\n\n        if causal_mask:\n            qk *= sm_scale\n            qk = softcapping(qk, logit_softcapping)\n            qk = qk * tl_log2(math.e)\n            if block_sparse_size > 1:\n                offs_mask = (start_n + offs_n) // block_sparse_size * block_sparse_size\n                qk_mask = (history_mask[:, None]) >= offs_mask[None, :]\n            else:\n                qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])\n            if window_size > 0:\n                qk_mask = qk_mask & ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])\n            qk = tl.where(\n                qk_mask,\n                qk,\n                float(-1e30),\n            )\n            m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n            qk -= m_i_new[:, None]\n        elif window_size > 0:\n            qk *= sm_scale\n            qk = softcapping(qk, logit_softcapping)\n            qk = qk * tl_log2(math.e)\n            qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])\n            qk = tl.where(\n                qk_mask,\n                qk,\n                float(-1e30),\n            )\n            m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n            qk -= m_i_new[:, None]\n        elif logit_softcapping > 0:\n            qk *= sm_scale\n            qk = softcapping(qk, logit_softcapping)\n            qk = qk * tl_log2(math.e)\n            m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n            qk -= m_i_new[:, None]\n      
  else:\n            qk_scale = sm_scale * tl_log2(math.e)\n            m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)\n            qk = qk * qk_scale - m_i_new[:, None]\n\n        if alibi_slope is not None:\n            relative_pos = start_n + offs_n[None, :] - global_offs_m[:, None]\n            bias = -tl.abs(relative_pos).to(tl.float32) * alibi_slope * tl_log2(math.e)\n            qk += bias\n\n        # -- compute p, m_i and l_i\n        p = tl_exp2(qk)\n        alpha = tl_exp2(m_i - m_i_new)\n        l_i = alpha * l_i + tl.sum(p, 1)\n        # -- update output accumulator --\n        # scale acc\n        acc = acc * alpha[:, None]\n\n        # update acc\n        if shared_kv:\n            v = tl.trans(k)\n        else:\n            v = _load_kv(v_ptrs, boundary_check=v_bound)\n        p = p.to(v.dtype)\n        acc += tl.dot(p, v)\n        # update m_i and l_i\n        m_i = m_i_new\n\n        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_N))\n        v_ptrs = tl.advance(v_ptrs, (BLOCK_N, 0))\n        if BLOCK_DK1:\n            k1_ptrs = tl.advance(k1_ptrs, (0, BLOCK_N))\n\n    return acc, l_i, m_i\n\n\n# # FOR DEBUG, DON'T REMOVE\n# import itertools\n# configs = [\n#     triton.Config({\n#         'BLOCK_M': BM,\n#         'BLOCK_N': BN\n#     }, num_stages=s, num_warps=w)\n#     for BM, BN, s, w in itertools.product([64, 128], [32, 64], [3, 4], [4])\n# ]\n\n\n# @triton.autotune(list(configs),\n#                  key=['head_dim_k', 'head_dim_v'])\n@triton.jit\ndef _flash_prefill_fwd_kernel(\n    q_ptr,\n    k_ptr,\n    v_ptr,\n    o_ptr,\n    cu_seqlens_q_ptr,\n    cu_seqlens_k_ptr,\n    q_start_loc_ptr,\n    q_seqlens_ptr,\n    kv_start_loc_ptr,\n    kv_seqlens_ptr,\n    sinks,\n    alibi_slopes_ptr,\n    sm_scale,\n    stride_qs: tl.constexpr,\n    stride_qh: tl.constexpr,\n    stride_qd: tl.constexpr,\n    stride_ks: tl.constexpr,\n    stride_kh,\n    stride_kd: tl.constexpr,\n    stride_vs: tl.constexpr,\n    stride_vh,\n    stride_vd: tl.constexpr,\n    
stride_os: tl.constexpr,\n    stride_oh: tl.constexpr,\n    stride_od: tl.constexpr,\n    kv_group_num,\n    head_dim_k: tl.constexpr,\n    head_dim_v: tl.constexpr,\n    causal: tl.constexpr,\n    window_size: tl.constexpr,\n    logit_softcapping: tl.constexpr,\n    shared_kv: tl.constexpr,\n    block_sparse_size: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_DK: tl.constexpr,\n    BLOCK_DK1: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n):\n    \"\"\"Flash attention kernel.\"\"\"\n    start_m = tl.program_id(0)\n    head_id = tl.program_id(1)\n    batch_id = tl.program_id(2)\n\n    if cu_seqlens_q_ptr is not None:\n        q_start_loc = tl.load(cu_seqlens_q_ptr + batch_id).to(tl.int32)\n        q_seqlen = tl.load(cu_seqlens_q_ptr + batch_id + 1).to(tl.int32) - q_start_loc\n    else:\n        q_start_loc = tl.load(q_start_loc_ptr + batch_id).to(tl.int32)\n        q_seqlen = tl.load(q_seqlens_ptr + batch_id).to(tl.int32)\n\n    if cu_seqlens_k_ptr is not None:\n        kv_start_loc = tl.load(cu_seqlens_k_ptr + batch_id).to(tl.int32)\n        kv_seqlen = tl.load(cu_seqlens_k_ptr + batch_id + 1).to(tl.int32) - kv_start_loc\n    else:\n        kv_start_loc = tl.load(kv_start_loc_ptr + batch_id).to(tl.int32)\n        kv_seqlen = tl.load(kv_seqlens_ptr + batch_id).to(tl.int32)\n\n    if BLOCK_M * start_m >= q_seqlen:\n        return\n\n    kv_head_id = head_id // kv_group_num\n    history_len = kv_seqlen - q_seqlen\n\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n\n    loop_start = 0\n    kv_min_loc = tl.zeros([BLOCK_M], dtype=tl.int32)\n    if window_size > 0:\n        start_block_id = tl.maximum(history_len + start_m * BLOCK_M - window_size, 0) // BLOCK_N\n        kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0)\n        loop_start = start_block_id * BLOCK_N\n\n    offs_dk = tl.arange(0, BLOCK_DK)\n    mask_dk = offs_dk < head_dim_k\n    offs_dk = tl.multiple_of(tl.max_contiguous(offs_dk % head_dim_k, BLOCK_DK), 
BLOCK_DK)\n    off_q = ((q_start_loc + offs_m[:, None]) * stride_qs + head_id * stride_qh + offs_dk[None, :] * stride_qd)\n    q_ptrs = q_ptr + off_q\n    q = tl.load(q_ptrs, mask=((offs_m[:, None] < q_seqlen) & mask_dk[None, :]))\n\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh,\n        shape=(head_dim_k, kv_seqlen),\n        strides=(stride_kd, stride_ks),\n        offsets=(0, 0),\n        block_shape=(BLOCK_DK, BLOCK_N),\n        order=(0, 1),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + kv_start_loc * stride_vs + kv_head_id * stride_vh,\n        shape=(kv_seqlen, head_dim_v),\n        strides=(stride_vs, stride_vd),\n        offsets=(0, 0),\n        block_shape=(BLOCK_N, BLOCK_DV),\n        order=(1, 0),\n    )\n\n    # for alibi\n    if alibi_slopes_ptr is not None:\n        alibi_slope = tl.load(alibi_slopes_ptr + head_id)\n    else:\n        alibi_slope = None\n    global_offs_m = history_len + offs_m\n\n    if BLOCK_DK + BLOCK_DK1 == head_dim_k:\n        k_bound0: tl.constexpr = None\n        k_bound1: tl.constexpr = (1, )\n    else:\n        k_bound0: tl.constexpr = (1, )\n        k_bound1: tl.constexpr = (0, 1)\n    if head_dim_v == BLOCK_DV:\n        v_bound0: tl.constexpr = None\n        v_bound1: tl.constexpr = (0, )\n    else:\n        v_bound0: tl.constexpr = (1, )\n        v_bound1: tl.constexpr = (0, 1)\n\n    if BLOCK_DK1 != 0:\n        offs_dk1 = BLOCK_DK + tl.arange(0, BLOCK_DK1)\n        mask_dk1 = offs_dk1 < head_dim_k\n        offs_dk1 = tl.multiple_of(tl.max_contiguous(offs_dk1 % head_dim_k, BLOCK_DK1), BLOCK_DK1)\n        offs_q1 = ((q_start_loc + offs_m[:, None]) * stride_qs + head_id * stride_qh + offs_dk1[None, :] * stride_qd)\n        q1_ptrs = q_ptr + offs_q1\n        q1 = tl.load(q1_ptrs, mask=((offs_m[:, None] < q_seqlen) & mask_dk1[None, :]))\n        k1_ptrs = tl.make_block_ptr(\n            base=k_ptr + kv_start_loc * stride_ks + kv_head_id * stride_kh,\n  
          shape=(head_dim_k, kv_seqlen),\n            strides=(stride_kd, stride_ks),\n            offsets=(BLOCK_DK, 0),\n            block_shape=(BLOCK_DK1, BLOCK_N),\n            order=(0, 1),\n        )\n    else:\n        q1 = q\n        k1_ptrs = k_ptrs\n\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')\n    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)\n\n    if causal:\n        history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n        loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N\n    else:\n        history_mask = tl.full([BLOCK_M], kv_seqlen - 1, dtype=tl.int32)\n        loop_end = kv_seqlen // BLOCK_N * BLOCK_N\n\n    acc, l_i, m_i = _prefill_fwd_inner(acc,\n                                       l_i,\n                                       m_i,\n                                       q,\n                                       k_ptrs,\n                                       v_ptrs,\n                                       q1,\n                                       k1_ptrs,\n                                       loop_start,\n                                       loop_end,\n                                       sm_scale,\n                                       alibi_slope,\n                                       global_offs_m,\n                                       history_mask,\n                                       kv_min_loc,\n                                       causal_mask=False,\n                                       window_size=window_size,\n                                       logit_softcapping=logit_softcapping,\n                                       k_bound=k_bound0,\n                                       v_bound=v_bound0,\n                                       shared_kv=shared_kv,\n                                       block_sparse_size=block_sparse_size,\n                                       BLOCK_N=BLOCK_N,\n   
                                    BLOCK_DK1=BLOCK_DK1)\n\n    loop_start = loop_end\n    if causal:\n        loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N)\n    else:\n        loop_end = kv_seqlen\n    acc, l_i, m_i = _prefill_fwd_inner(acc,\n                                       l_i,\n                                       m_i,\n                                       q,\n                                       k_ptrs,\n                                       v_ptrs,\n                                       q1,\n                                       k1_ptrs,\n                                       loop_start,\n                                       loop_end,\n                                       sm_scale,\n                                       alibi_slope,\n                                       global_offs_m,\n                                       history_mask,\n                                       kv_min_loc,\n                                       causal_mask=True,\n                                       window_size=window_size,\n                                       logit_softcapping=logit_softcapping,\n                                       k_bound=k_bound1,\n                                       v_bound=v_bound1,\n                                       shared_kv=shared_kv,\n                                       block_sparse_size=block_sparse_size,\n                                       BLOCK_N=BLOCK_N,\n                                       BLOCK_DK1=BLOCK_DK1)\n    # epilogue\n    if sinks is not None:\n        sink = tl.load(sinks + head_id).to(l_i.dtype)\n        l_i = l_i + tl.exp2(sink * tl_log2(math.e) - m_i)\n\n    m_i += tl.math.log2(l_i)\n    acc = acc / l_i[:, None]\n\n    # initialize pointers to output\n    offs_dv = tl.arange(0, BLOCK_DV)\n    mask_dv = offs_dv < head_dim_v\n    off_o = ((q_start_loc + offs_m[:, None]) * stride_os + head_id * stride_oh + offs_dv[None, :] * stride_od)\n    out_ptrs = o_ptr + off_o\n   
 tl.store(out_ptrs, acc, mask=(offs_m[:, None] < q_seqlen) & mask_dv[None, :])\n\n\n_nv_cap = None\n\n\ndef _kernel_meta_sm7x(BLOCK_DK):\n    num_warps = 4\n    num_stages = min(4, max(2, 768 // BLOCK_DK))\n    BLOCK_M = max(16, 8192 // BLOCK_DK)\n    BLOCK_N = 32\n    return BLOCK_M, BLOCK_N, num_warps, num_stages\n\n\ndef _kernel_meta_sm8x(BLOCK_DK: int, shared_kv: bool):\n    num_warps = 8\n    min_m = 64 if shared_kv else 16\n    BLOCK_M = max(min_m, 16384 // BLOCK_DK)\n    BLOCK_M = min(128, BLOCK_M)\n    BLOCK_N = BLOCK_M\n    num_stages = 3 if BLOCK_DK <= 128 else 2\n\n    return BLOCK_M, BLOCK_N, num_warps, num_stages\n\n\ndef _kernel_meta_sm86(BLOCK_DK: int, shared_kv: bool):\n    \"\"\"Sm86 has different smem size with sm80.\"\"\"\n    num_warps = 4\n    if BLOCK_DK <= 128:\n        BLOCK_M = 128\n        BLOCK_N = 64\n        num_stages = 3\n    elif BLOCK_DK <= 256:\n        BLOCK_M = 64\n        BLOCK_N = 32\n        num_stages = 2\n    else:\n        BLOCK_M = 32\n        BLOCK_N = 32\n        num_stages = 2\n\n    return BLOCK_M, BLOCK_N, num_warps, num_stages\n\n\ndef _kernel_meta_sm9x(BLOCK_DK: int, shared_kv: bool):\n\n    num_warps = 8\n    BLOCK_M = 128 if BLOCK_DK <= 256 else 64\n    if not shared_kv and BLOCK_DK >= 512:\n        BLOCK_M = 32\n\n    # fix crash on triton<3.2.0\n    if BLOCK_DK >= 512 and TRITON_VERSION < VERSION_320:\n        BLOCK_M = 32\n        num_warps = 4\n\n    BLOCK_N = 128 if BLOCK_DK <= 128 else 64\n\n    num_stages = 3 if BLOCK_DK <= 128 else 2\n    return BLOCK_M, BLOCK_N, num_warps, num_stages\n\n\ndef _kernel_meta_sm12x(BLOCK_DK: int, shared_kv: bool):\n    # Blackwell (sm_120, cc 12.x) + B200/B100 variants\n    if BLOCK_DK <= 128:\n        BLOCK_M = 128\n        BLOCK_N = 128 if shared_kv else 64\n        num_warps = 8\n        num_stages = 3\n    elif BLOCK_DK <= 256:\n        BLOCK_M = 64\n        BLOCK_N = 128 if shared_kv else 64\n        num_warps = 8\n        num_stages = 3\n    elif BLOCK_DK <= 512:\n      
  BLOCK_M = 64 if shared_kv else 32\n        BLOCK_N = 64\n        num_warps = 4\n        num_stages = 2\n    else:\n        BLOCK_M = 32\n        BLOCK_N = 32 if not shared_kv else 64\n        num_warps = 4\n        num_stages = 2\n\n    return BLOCK_M, BLOCK_N, num_warps, num_stages\n\n\ndef _kernel_meta_rocm(BLOCK_DK: int, shared_kv: bool):\n    BLOCK_N = 32\n    BLOCK_M = 32 if BLOCK_DK > 128 else 64\n    num_warps = 4\n    num_stages = 1\n    return BLOCK_M, BLOCK_N, num_warps, num_stages\n\n\ndef flash_attn_varlen_func(\n    q: Tensor,\n    k: Tensor,\n    v: Tensor,\n    cu_seqlens_q: Tensor = None,\n    cu_seqlens_k: Tensor = None,\n    max_seqlen_q: int = None,\n    max_seqlen_k: int = None,  # unused; kept for compatibility with the FA interface\n    softmax_scale: float = None,\n    causal: bool = False,\n    window_size: int = (-1, -1),\n    softcap: float = 0.0,\n    # old seqlens\n    q_start_loc: Tensor = None,\n    q_seqlens: Tensor = None,\n    kv_start_loc: Tensor = None,\n    kv_seqlens: Tensor = None,\n    # args not in fa\n    alibi_slopes: Tensor = None,\n    sinks: Tensor = None,\n    block_sparse_size: int = 1,\n    kv_layout: str = 'hsd',\n):\n    \"\"\"Varlen flash attention forward.\n\n    Supports sliding window and softcapping.\n    \"\"\"\n\n    global _nv_cap\n    if _nv_cap is None:\n        _nv_cap = torch.cuda.get_device_capability()\n\n    def grid(args):\n        return (triton.cdiv(max_seqlen_q, args['BLOCK_M']), num_heads, batch)\n\n    if kv_layout == 'shd':\n        s_dim, h_dim, d_dim = (0, 1, 2)\n    elif kv_layout == 'hsd':\n        s_dim, h_dim, d_dim = (1, 0, 2)\n    else:\n        raise RuntimeError('Unsupported layout.')\n\n    if max_seqlen_q is None:\n        max_seqlen_q = q.size(0)\n\n    if window_size is None:\n        window_size = -1\n    elif isinstance(window_size, Sequence):\n        window_size = window_size[0]\n\n    if softcap is None:\n        softcap = -1.0\n\n    head_dim_q = q.size(-1)\n    head_dim_k = 
k.size(d_dim)\n    head_dim_v = v.size(d_dim)\n\n    o = q.new_empty(*q.size()[:-1], head_dim_v)\n    assert head_dim_q == head_dim_k and head_dim_v == o.size(-1)\n\n    if softmax_scale is None:\n        softmax_scale = 1.0 / (head_dim_q**0.5)\n\n    if cu_seqlens_k is None:\n        assert kv_start_loc is not None and kv_seqlens is not None\n    if cu_seqlens_q is None:\n        assert q_start_loc is not None and q_seqlens is not None\n        batch = q_seqlens.size(0)\n    else:\n        batch = cu_seqlens_q.size(0) - 1\n    num_heads = q.size(-2)\n    num_kv_heads = k.size(h_dim)\n    kv_group_num = num_heads // num_kv_heads\n\n    if sinks is not None:\n        assert sinks.is_contiguous()\n        assert sinks.numel() == num_heads\n\n    BLOCK_DK, BLOCK_DK1, BLOCK_DV = _get_block_d(head_dim_k, head_dim_v)\n\n    shared_kv = k.data_ptr() == v.data_ptr() and BLOCK_DK == BLOCK_DV\n\n    num_warps = 4\n    hip_mode = getattr(torch.version, 'hip', None) is not None\n    if hip_mode:\n        BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_rocm(BLOCK_DK, shared_kv)\n    else:\n        if _nv_cap[0] < 8:\n            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm7x(BLOCK_DK)\n        elif _nv_cap[0] < 9:\n            if _nv_cap[1] in [6, 9]:\n                BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm86(BLOCK_DK, shared_kv)\n            else:\n                BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DK, shared_kv)\n        elif _nv_cap[0] < 10:\n            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DK, shared_kv)\n        else:\n            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm12x(BLOCK_DK, shared_kv)\n\n    BLOCK_M = min(128, BLOCK_M)\n    _flash_prefill_fwd_kernel[grid](\n        q,\n        k,\n        v,\n        o,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        q_start_loc,\n        q_seqlens,\n        kv_start_loc,\n        kv_seqlens,\n        sinks,\n 
       alibi_slopes,\n        sm_scale=softmax_scale,\n        stride_qs=q.stride(0),\n        stride_qh=q.stride(1),\n        stride_qd=q.stride(2),\n        stride_ks=k.stride(s_dim),\n        stride_kh=k.stride(h_dim),\n        stride_kd=k.stride(d_dim),\n        stride_vs=v.stride(s_dim),\n        stride_vh=v.stride(h_dim),\n        stride_vd=v.stride(d_dim),\n        stride_os=o.stride(0),\n        stride_oh=o.stride(1),\n        stride_od=o.stride(2),\n        kv_group_num=kv_group_num,\n        head_dim_k=head_dim_k,\n        head_dim_v=head_dim_v,\n        causal=causal,\n        window_size=window_size,\n        logit_softcapping=softcap,\n        shared_kv=shared_kv,\n        block_sparse_size=block_sparse_size,\n        BLOCK_DK=BLOCK_DK,\n        BLOCK_DK1=BLOCK_DK1,\n        BLOCK_DV=BLOCK_DV,\n        BLOCK_M=BLOCK_M,\n        BLOCK_N=BLOCK_N,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n\n    return o\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Literal\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\n\n@triton.jit\ndef _flatten_kv_cache(\n    kc_ptr,\n    vc_ptr,\n    ko_ptr,\n    vo_ptr,\n    start_loc_ptr,\n    seqlens_ptr,\n    block_offsets_ptr,\n    stride_kcb: tl.constexpr,\n    stride_kcs: tl.constexpr,\n    stride_kch: tl.constexpr,\n    stride_kcd: tl.constexpr,\n    stride_vcb: tl.constexpr,\n    stride_vcs: tl.constexpr,\n    stride_vch: tl.constexpr,\n    stride_vcd: tl.constexpr,\n    stride_koh,\n    stride_kos: tl.constexpr,\n    stride_kod: tl.constexpr,\n    stride_voh,\n    stride_vos: tl.constexpr,\n    stride_vod: tl.constexpr,\n    stride_boff,\n    OUT_SIZE,\n    HEAD_DIM_K: tl.constexpr,\n    HEAD_DIM_V: tl.constexpr,\n    BLOCK_BS: tl.constexpr,\n    BLOCK_DK: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n):\n    \"\"\"Flatten kv cache.\"\"\"\n    page_id = tl.program_id(0)\n    batch_id = tl.program_id(1)\n    head_id = tl.program_id(2)\n\n    num_batches = tl.num_programs(1)\n\n    seqlen = tl.load(seqlens_ptr + batch_id)\n    start_loc = tl.load(start_loc_ptr + batch_id)\n    # fill last block to prevent attention nan\n    if batch_id == num_batches - 1:\n        seqlen = (OUT_SIZE - start_loc).to(seqlen.dtype)\n    if page_id * BLOCK_BS >= seqlen:\n        return\n\n    start_loc = tl.load(start_loc_ptr + batch_id)\n    b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id)\n\n    offs_bs = tl.arange(0, BLOCK_BS)\n    offs_dk = tl.arange(0, BLOCK_DK) % HEAD_DIM_K\n    offs_dv = tl.arange(0, BLOCK_DV) % HEAD_DIM_V\n    offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS)\n    mask_bs = offs_obs < seqlen\n    mask_dk = tl.arange(0, BLOCK_DK) < HEAD_DIM_K\n    mask_dv = tl.arange(0, BLOCK_DV) < HEAD_DIM_V\n\n    kc_ptrs = (kc_ptr + b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch +\n               offs_dk[None, :] * 
stride_kcd)\n    vc_ptrs = (vc_ptr + b_off * stride_vcb + offs_bs[:, None] * stride_vcs + head_id * stride_vch +\n               offs_dv[None, :] * stride_vcd)\n    ko_ptrs = (ko_ptr + head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos +\n               offs_dk[None, :] * stride_kod)\n    vo_ptrs = (vo_ptr + head_id * stride_voh + (start_loc + offs_obs[:, None]) * stride_vos +\n               offs_dv[None, :] * stride_vod)\n\n    kc = tl.load(kc_ptrs)\n    tl.store(ko_ptrs, kc, mask=mask_bs[:, None] & mask_dk[None, :])\n    if HEAD_DIM_V > 0:\n        vc = tl.load(vc_ptrs)\n        tl.store(vo_ptrs, vc, mask=mask_bs[:, None] & mask_dv[None, :])\n\n\n@triton.jit\ndef _dequant_int4(val, HEAD_DIM: tl.constexpr, BLOCK: tl.constexpr):\n    \"\"\"Dequant int4.\"\"\"\n    offs = tl.arange(0, BLOCK) // (HEAD_DIM // 2)\n    shift = (offs % 2) * 4\n    return (val >> shift) & 0xf\n\n\n@triton.jit\ndef _flatten_kv_cache_quant(\n    kc_ptr,\n    vc_ptr,\n    ko_ptr,\n    vo_ptr,\n    ksz_ptr,\n    vsz_ptr,\n    start_loc_ptr,\n    seqlens_ptr,\n    block_offsets_ptr,\n    stride_kcb: tl.constexpr,\n    stride_kcs: tl.constexpr,\n    stride_kch: tl.constexpr,\n    stride_kcd: tl.constexpr,\n    stride_vcb: tl.constexpr,\n    stride_vcs: tl.constexpr,\n    stride_vch: tl.constexpr,\n    stride_vcd: tl.constexpr,\n    stride_kszb: tl.constexpr,\n    stride_kszs: tl.constexpr,\n    stride_kszh: tl.constexpr,\n    stride_kszd: tl.constexpr,\n    stride_vszb: tl.constexpr,\n    stride_vszs: tl.constexpr,\n    stride_vszh: tl.constexpr,\n    stride_vszd: tl.constexpr,\n    stride_koh,\n    stride_kos: tl.constexpr,\n    stride_kod: tl.constexpr,\n    stride_voh,\n    stride_vos: tl.constexpr,\n    stride_vod: tl.constexpr,\n    stride_boff,\n    quant_policy: tl.constexpr,\n    OUT_SIZE,\n    HEAD_DIM_K: tl.constexpr,\n    HEAD_DIM_V: tl.constexpr,\n    BLOCK_BS: tl.constexpr,\n    BLOCK_DK: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n):\n    \"\"\"Flatten kv 
cache.\"\"\"\n    page_id = tl.program_id(0)\n    batch_id = tl.program_id(1)\n    head_id = tl.program_id(2)\n\n    num_batches = tl.num_programs(1)\n\n    seqlen = tl.load(seqlens_ptr + batch_id)\n    start_loc = tl.load(start_loc_ptr + batch_id)\n    if batch_id == num_batches - 1:\n        seqlen = OUT_SIZE - start_loc\n    if page_id * BLOCK_BS >= seqlen:\n        return\n\n    b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id)\n\n    offs_bs = tl.arange(0, BLOCK_BS)\n    if quant_policy == 4:\n        HALF_HDK: tl.constexpr = HEAD_DIM_K // 2\n        HALF_HDV: tl.constexpr = HEAD_DIM_V // 2\n        offs_dk = tl.arange(0, BLOCK_DK) % HALF_HDK\n        offs_dv = tl.arange(0, BLOCK_DV) % HALF_HDV\n    else:\n        offs_dk = tl.arange(0, BLOCK_DK) % HEAD_DIM_K\n        offs_dv = tl.arange(0, BLOCK_DV) % HEAD_DIM_V\n    offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS)\n    mask_bs = offs_obs < seqlen\n\n    offs_dok = tl.arange(0, BLOCK_DK)\n    offs_dov = tl.arange(0, BLOCK_DV)\n    mask_dok = offs_dok < HEAD_DIM_K\n    mask_dov = offs_dov < HEAD_DIM_V\n\n    kc_ptrs = (kc_ptr + b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch +\n               offs_dk[None, :] * stride_kcd)\n    vc_ptrs = (vc_ptr + b_off * stride_vcb + offs_bs[:, None] * stride_vcs + head_id * stride_vch +\n               offs_dv[None, :] * stride_vcd)\n    ksz_ptrs = (ksz_ptr + b_off * stride_kszb + offs_bs * stride_kszs + head_id * stride_kszh)\n    vsz_ptrs = (vsz_ptr + b_off * stride_vszb + offs_bs * stride_vszs + head_id * stride_vszh)\n    ko_ptrs = (ko_ptr + head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos +\n               offs_dok[None, :] * stride_kod)\n    vo_ptrs = (vo_ptr + head_id * stride_voh + (start_loc + offs_obs[:, None]) * stride_vos +\n               offs_dov[None, :] * stride_vod)\n\n    kc = tl.load(kc_ptrs)\n    if quant_policy == 4:\n        kc = _dequant_int4(kc, HEAD_DIM_K, BLOCK_DK)\n    ks = 
tl.load(ksz_ptrs)\n    kz = tl.load(ksz_ptrs + stride_kszd)\n    ksz = ks * kz\n    kq = (kc * ks[:, None] - ksz[:, None]).to(ko_ptr.dtype.element_ty)\n    tl.store(ko_ptrs, kq, mask=mask_bs[:, None] & mask_dok[None, :])\n    vc = tl.load(vc_ptrs)\n    if quant_policy == 4:\n        vc = _dequant_int4(vc, HEAD_DIM_V, BLOCK_DV)\n    vs = tl.load(vsz_ptrs)\n    vz = tl.load(vsz_ptrs + stride_vszd)\n    vsz = vs * vz\n    vq = (vc * vs[:, None] - vsz[:, None]).to(vo_ptr.dtype.element_ty)\n    tl.store(vo_ptrs, vq, mask=mask_bs[:, None] & mask_dov[None, :])\n\n\ndef flatten_kv_cache(k_caches: Tensor,\n                     v_caches: Tensor,\n                     seqlens: Tensor,\n                     block_offsets: Tensor,\n                     start_loc: Tensor = None,\n                     out_size: int = None,\n                     out_dtype: torch.dtype = None,\n                     k_scales_zeros: Tensor = None,\n                     v_scales_zeros: Tensor = None,\n                     quant_policy: Literal[0, 4, 8] = 0,\n                     kv_layout: str = 'bshd',\n                     flatten_kv_layout: str = 'hsd'):\n    \"\"\"Recover a paged kv cache into a contiguous (flattened) kv cache.\"\"\"\n    if kv_layout == 'bshd':\n        b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)\n    elif kv_layout == 'bhsd':\n        b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)\n    else:\n        raise RuntimeError('Unsupported layout.')\n\n    if out_dtype is None:\n        out_dtype = k_caches.dtype\n\n    if out_size is None or out_size <= 0:\n        out_size = k_caches.size(b_dim) * k_caches.size(s_dim)\n\n    if start_loc is None:\n        start_loc = seqlens.cumsum(0) - seqlens\n\n    batch_size, num_blocks = block_offsets.size()\n    num_heads = k_caches.size(h_dim)\n    k_head_dim = k_caches.size(d_dim)\n    v_head_dim = v_caches.size(d_dim)\n    if quant_policy == 4:\n        k_head_dim *= 2\n        v_head_dim *= 2\n    BLOCK_DK = triton.next_power_of_2(k_head_dim)\n    BLOCK_DV = 
triton.next_power_of_2(v_head_dim)\n    BLOCK_BS = k_caches.size(s_dim)\n    shared_kv = k_caches.data_ptr() == v_caches.data_ptr() and v_head_dim < k_head_dim\n    if flatten_kv_layout == 'hsd':\n        k_states = k_caches.new_empty(num_heads, out_size, k_head_dim, dtype=out_dtype)\n        if quant_policy == 0 and shared_kv:\n            v_states = k_states[..., :v_head_dim]\n            v_head_dim = 0\n        else:\n            v_states = v_caches.new_empty(num_heads, out_size, v_head_dim, dtype=out_dtype)\n        stride_koh = k_states.stride(0)\n        stride_kos = k_states.stride(1)\n        stride_voh = v_states.stride(0)\n        stride_vos = v_states.stride(1)\n    elif flatten_kv_layout == 'shd':\n        k_states = k_caches.new_empty(out_size, num_heads, k_head_dim, dtype=out_dtype)\n        if quant_policy == 0 and shared_kv:\n            v_states = k_states[..., :v_head_dim]\n            v_head_dim = 0\n        else:\n            v_states = v_caches.new_empty(out_size, num_heads, v_head_dim, dtype=out_dtype)\n        stride_koh = k_states.stride(1)\n        stride_kos = k_states.stride(0)\n        stride_voh = v_states.stride(1)\n        stride_vos = v_states.stride(0)\n    else:\n        raise RuntimeError('Unsupported layout.')\n\n    grid = (num_blocks, batch_size, num_heads)\n    if quant_policy == 0:\n        _flatten_kv_cache[grid](\n            k_caches,\n            v_caches,\n            k_states,\n            v_states,\n            start_loc,\n            seqlens,\n            block_offsets,\n            stride_kcb=k_caches.stride(b_dim),\n            stride_kcs=k_caches.stride(s_dim),\n            stride_kch=k_caches.stride(h_dim),\n            stride_kcd=k_caches.stride(d_dim),\n            stride_vcb=v_caches.stride(b_dim),\n            stride_vcs=v_caches.stride(s_dim),\n            stride_vch=v_caches.stride(h_dim),\n            stride_vcd=v_caches.stride(d_dim),\n            stride_koh=stride_koh,\n            
stride_kos=stride_kos,\n            stride_kod=k_states.stride(2),\n            stride_voh=stride_voh,\n            stride_vos=stride_vos,\n            stride_vod=v_states.stride(2),\n            stride_boff=block_offsets.stride(0),\n            OUT_SIZE=out_size,\n            HEAD_DIM_K=k_head_dim,\n            HEAD_DIM_V=v_head_dim,\n            BLOCK_BS=BLOCK_BS,\n            BLOCK_DK=BLOCK_DK,\n            BLOCK_DV=BLOCK_DV,\n        )\n    else:\n        _flatten_kv_cache_quant[grid](\n            k_caches,\n            v_caches,\n            k_states,\n            v_states,\n            k_scales_zeros,\n            v_scales_zeros,\n            start_loc,\n            seqlens,\n            block_offsets,\n            stride_kcb=k_caches.stride(b_dim),\n            stride_kcs=k_caches.stride(s_dim),\n            stride_kch=k_caches.stride(h_dim),\n            stride_kcd=k_caches.stride(d_dim),\n            stride_vcb=v_caches.stride(b_dim),\n            stride_vcs=v_caches.stride(s_dim),\n            stride_vch=v_caches.stride(h_dim),\n            stride_vcd=v_caches.stride(d_dim),\n            stride_kszb=k_scales_zeros.stride(b_dim),\n            stride_kszs=k_scales_zeros.stride(s_dim),\n            stride_kszh=k_scales_zeros.stride(h_dim),\n            stride_kszd=k_scales_zeros.stride(d_dim),\n            stride_vszb=v_scales_zeros.stride(b_dim),\n            stride_vszs=v_scales_zeros.stride(s_dim),\n            stride_vszh=v_scales_zeros.stride(h_dim),\n            stride_vszd=v_scales_zeros.stride(d_dim),\n            stride_koh=stride_koh,\n            stride_kos=stride_kos,\n            stride_kod=k_states.stride(2),\n            stride_voh=stride_voh,\n            stride_vos=stride_vos,\n            stride_vod=v_states.stride(2),\n            stride_boff=block_offsets.stride(0),\n            quant_policy=quant_policy,\n            OUT_SIZE=out_size,\n            HEAD_DIM_K=k_head_dim,\n            HEAD_DIM_V=v_head_dim,\n            
BLOCK_BS=BLOCK_BS,\n            BLOCK_DK=BLOCK_DK,\n            BLOCK_DV=BLOCK_DV,\n        )\n\n    return k_states, v_states\n\n\n@triton.jit\ndef dequant_fp8(x, scale, GROUP_SIZE: tl.constexpr):\n    \"\"\"Dequant fp8.\"\"\"\n    M: tl.constexpr = x.shape[0]\n    N: tl.constexpr = x.shape[1]\n    x = x.to(scale.dtype)\n    x = x.reshape(M, N // GROUP_SIZE, GROUP_SIZE)\n    scale = scale.reshape(M, N // GROUP_SIZE, 1)\n    x = x * scale\n    x = x.reshape(M, N)\n    return x\n\n\n@triton.jit\ndef flatten_kv_cache_mla_fp8_kernel(\n    kc_nope_ptr,\n    kc_scale_ptr,\n    kc_pe_ptr,\n    ko_ptr,\n    start_loc_ptr,\n    seqlens_ptr,\n    block_offsets_ptr,\n    stride_kcb: tl.constexpr,\n    stride_kcs: tl.constexpr,\n    stride_kch: tl.constexpr,\n    stride_kcd: tl.constexpr,\n    stride_kcsb: tl.constexpr,\n    stride_kcss: tl.constexpr,\n    stride_kcsh: tl.constexpr,\n    stride_kcsd: tl.constexpr,\n    stride_kcpb: tl.constexpr,\n    stride_kcps: tl.constexpr,\n    stride_kcph: tl.constexpr,\n    stride_kcpd: tl.constexpr,\n    stride_koh,\n    stride_kos: tl.constexpr,\n    stride_kod: tl.constexpr,\n    stride_boff,\n    OUT_SIZE,\n    BLOCK_BS: tl.constexpr,\n    BLOCK_NOPE: tl.constexpr,\n    BLOCK_PE: tl.constexpr,\n    GROUP_SIZE: tl.constexpr,\n):\n    \"\"\"Mla fp8 flatten kv cache kernel.\"\"\"\n    page_id = tl.program_id(0)\n    batch_id = tl.program_id(1)\n    head_id = tl.program_id(2)\n    num_batches = tl.num_programs(1)\n\n    seqlen = tl.load(seqlens_ptr + batch_id)\n    start_loc = tl.load(start_loc_ptr + batch_id)\n    # fill last block to prevent attention nan\n    if batch_id == num_batches - 1:\n        seqlen = OUT_SIZE - start_loc\n    if page_id * BLOCK_BS >= seqlen:\n        return\n\n    b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id)\n\n    BLOCK_SCALE: tl.constexpr = BLOCK_NOPE // GROUP_SIZE\n    offs_bs = tl.arange(0, BLOCK_BS)\n    offs_dnope = tl.arange(0, BLOCK_NOPE)\n    offs_scale = tl.arange(0, 
BLOCK_SCALE)\n    offs_dpe = tl.arange(0, BLOCK_PE)\n    offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS)\n    mask_bs = offs_obs < seqlen\n\n    offs_kc = b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch\n    kc_nope_ptrs = (kc_nope_ptr + offs_kc + offs_dnope[None, :] * stride_kcd)\n\n    offs_kc_scale = b_off * stride_kcsb + offs_bs[:, None] * stride_kcss + head_id * stride_kcsh\n    kc_scale_ptrs = (kc_scale_ptr + offs_kc_scale + offs_scale[None, :] * stride_kcsd)\n\n    offs_kc_pe = b_off * stride_kcpb + offs_bs[:, None] * stride_kcps + head_id * stride_kcph\n    kc_pe_ptrs = (kc_pe_ptr + offs_kc_pe + offs_dpe[None, :] * stride_kcpd)\n\n    offs_ko = head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos\n    ko_nope_ptrs = (ko_ptr + offs_ko + offs_dnope[None, :] * stride_kod)\n    ko_pe_ptrs = (ko_ptr + offs_ko + (BLOCK_NOPE + offs_dpe[None, :]) * stride_kod)\n\n    # nope\n    kc_nope = tl.load(kc_nope_ptrs)\n    kc_scale = tl.load(kc_scale_ptrs)\n    ko_nope = dequant_fp8(kc_nope, kc_scale, GROUP_SIZE)\n    ko_nope = ko_nope.to(ko_ptr.dtype.element_ty)\n    tl.store(ko_nope_ptrs, ko_nope, mask=mask_bs[:, None])\n\n    # pe\n    kc_pe = tl.load(kc_pe_ptrs)\n    tl.store(ko_pe_ptrs, kc_pe, mask=mask_bs[:, None])\n\n\ndef flatten_kv_cache_mla_fp8(k_caches: Tensor,\n                             seqlens: Tensor,\n                             block_offsets: Tensor,\n                             start_loc: Tensor = None,\n                             out_size: int = None,\n                             out_dtype: torch.dtype = None,\n                             flatten_kv_layout: str = 'hsd'):\n    \"\"\"This kernel is designed to support mla fp8.\"\"\"\n    assert k_caches.dim() == 4\n\n    b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)\n\n    if out_dtype is None:\n        out_dtype = torch.bfloat16\n\n    if out_size is None or out_size <= 0:\n        out_size = k_caches.size(b_dim) * k_caches.size(s_dim)\n\n    # TODO: 
DIRTY magic number\n    k_caches_nope = k_caches[..., :512]\n    k_caches_scale = k_caches[..., 512:512 + 16].view(torch.float32)\n    k_caches_pe = k_caches[..., 512 + 16:].view(out_dtype)\n\n    if start_loc is None:\n        start_loc = seqlens.cumsum(0) - seqlens\n\n    batch_size, num_blocks = block_offsets.size()\n    num_heads = k_caches.size(h_dim)\n    k_head_dim = 576\n    BLOCK_NOPE = 512\n    BLOCK_PE = 64\n    BLOCK_BS = k_caches.size(s_dim)\n    if flatten_kv_layout == 'hsd':\n        k_states = k_caches.new_empty(num_heads, out_size, k_head_dim, dtype=out_dtype)\n        stride_koh = k_states.stride(0)\n        stride_kos = k_states.stride(1)\n    elif flatten_kv_layout == 'shd':\n        k_states = k_caches.new_empty(out_size, num_heads, k_head_dim, dtype=out_dtype)\n        stride_koh = k_states.stride(1)\n        stride_kos = k_states.stride(0)\n    else:\n        raise RuntimeError(f'Unsupported layout: {flatten_kv_layout}.')\n\n    grid = (num_blocks, batch_size, num_heads)\n    flatten_kv_cache_mla_fp8_kernel[grid](\n        k_caches_nope,\n        k_caches_scale,\n        k_caches_pe,\n        k_states,\n        start_loc,\n        seqlens,\n        block_offsets,\n        stride_kcb=k_caches_nope.stride(b_dim),\n        stride_kcs=k_caches_nope.stride(s_dim),\n        stride_kch=k_caches_nope.stride(h_dim),\n        stride_kcd=k_caches_nope.stride(d_dim),\n        stride_kcsb=k_caches_scale.stride(b_dim),\n        stride_kcss=k_caches_scale.stride(s_dim),\n        stride_kcsh=k_caches_scale.stride(h_dim),\n        stride_kcsd=k_caches_scale.stride(d_dim),\n        stride_kcpb=k_caches_pe.stride(b_dim),\n        stride_kcps=k_caches_pe.stride(s_dim),\n        stride_kcph=k_caches_pe.stride(h_dim),\n        stride_kcpd=k_caches_pe.stride(d_dim),\n        stride_koh=stride_koh,\n        stride_kos=stride_kos,\n        stride_kod=k_states.stride(2),\n        stride_boff=block_offsets.stride(0),\n        OUT_SIZE=out_size,\n        
BLOCK_BS=BLOCK_BS,\n        BLOCK_NOPE=BLOCK_NOPE,\n        BLOCK_PE=BLOCK_PE,\n        GROUP_SIZE=128,\n    )\n\n    return k_states\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/fused_lora.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\n\n\ndef get_autotune_config():\n    \"\"\"Get autotune config.\"\"\"\n    return [\n        triton.Config({\n            'BLOCK_SIZE_M': 32,\n            'BLOCK_SIZE_N': 128,\n            'BLOCK_SIZE_K': 128\n        }, num_stages=4, num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_M': 16,\n            'BLOCK_SIZE_N': 256,\n            'BLOCK_SIZE_K': 128\n        }, num_stages=4, num_warps=4),\n    ]\n\n\n@triton.jit\ndef _atomic_store(ptrs, val, mask):\n    \"\"\"Atomic store values.\"\"\"\n    dtype = ptrs.dtype.element_ty\n    if (dtype == torch.float16) | (dtype == torch.float32):\n        tl.atomic_add(ptrs, val, mask=mask, sem='relaxed')\n    else:\n        # bfloat16 does not support atomic add\n        origin = tl.load(ptrs, mask=mask)\n        val = val.to(origin.dtype)\n        val += origin\n        tl.store(ptrs, val, mask=mask)\n\n\n@triton.autotune(\n    configs=get_autotune_config(),\n    key=['N', 'K'],\n    restore_value=['c_ptr'],\n)\n@triton.jit\ndef _fused_lora_kernel(\n    a_ptr,\n    lora_a_ptr,\n    lora_b_ptr,\n    c_ptr,\n    scaling_ptr,\n    rank_start_ptr,\n    ranks_ptr,\n    seq_start_ptr,\n    seq_lens_ptr,\n    adapter_ids_ptr,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    stride_am,\n    stride_ak: tl.constexpr,\n    stride_lar: tl.constexpr,\n    stride_lak: tl.constexpr,\n    stride_lbr: tl.constexpr,\n    stride_lbn: tl.constexpr,\n    stride_cm,\n    stride_cn: tl.constexpr,\n    BLOCK_SIZE_R: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    CUM: tl.constexpr,\n):\n    \"\"\"Fused lora kernel.\"\"\"\n    pid = tl.program_id(axis=0)\n    bid = tl.program_id(axis=1)\n\n    M = tl.load(seq_lens_ptr + bid)\n    if M <= 0:\n        return\n\n    seq_start = tl.load(seq_start_ptr + bid)\n    adapter_id = 
tl.load(adapter_ids_ptr + bid)\n    rank_start = tl.load(rank_start_ptr + adapter_id)\n    rank = tl.load(ranks_ptr + adapter_id)\n\n    pid_m = pid\n\n    if pid_m * BLOCK_SIZE_M >= M:\n        return\n\n    offs_m = (seq_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))\n    offs_n = tl.arange(0, BLOCK_SIZE_N)\n\n    mask_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) < M\n    offs_cm = offs_m\n    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_n[None, :]\n\n    if rank == 0:\n        if not CUM:\n            for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):\n                mask_cn = (offs_n < N - n * BLOCK_SIZE_N)\n                c_mask = mask_cm[:, None] * mask_cn[None, :]\n                tl.store(c_ptrs, 0.0, mask=c_mask)\n                c_ptrs += stride_cn * BLOCK_SIZE_N\n    else:\n\n        offs_am = (seq_start + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M)\n        offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank\n        offs_k = tl.arange(0, BLOCK_SIZE_K)\n        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n        la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak + offs_r[None, :] * stride_lar)\n\n        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32)\n        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n            # Load the next block of A and B\n            # If it is out of bounds, set it to 0.\n            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n            la = tl.load(la_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n            # We accumulate along the K dimension.\n            accumulator = tl.dot(a, la, acc=accumulator)\n            # Advance the ptrs to the next K block.\n            a_ptrs += BLOCK_SIZE_K * stride_ak\n            la_ptrs += BLOCK_SIZE_K * stride_lak\n        ar = accumulator.to(lora_b_ptr.dtype.element_ty)\n\n        scaling = tl.load(scaling_ptr + 
adapter_id).to(ar.dtype)\n        ar *= scaling\n        ar = tl.where(tl.arange(0, BLOCK_SIZE_R)[None, :] < rank, ar, tl.zeros_like(ar))\n        lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr + offs_n[None, :] * stride_lbn)\n\n        for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):\n            lb = tl.load(lb_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N)\n            c = tl.dot(ar, lb)\n\n            mask_cn = (offs_n < N - n * BLOCK_SIZE_N)\n            c_mask = mask_cm[:, None] * mask_cn[None, :]\n            if CUM:\n                _atomic_store(c_ptrs, c, mask=c_mask)\n            else:\n                tl.store(c_ptrs, c, mask=c_mask)\n            c_ptrs += stride_cn * BLOCK_SIZE_N\n            lb_ptrs += stride_lbn * BLOCK_SIZE_N\n\n\ndef fused_lora(input: torch.Tensor,\n               lora_a: torch.Tensor,\n               lora_b: torch.Tensor,\n               scaling: torch.LongTensor,\n               rank_start: torch.LongTensor,\n               ranks: torch.LongTensor,\n               seq_start: torch.LongTensor,\n               seq_lens: torch.LongTensor,\n               adapter_ids: torch.LongTensor,\n               max_rank: int,\n               max_seqlen: int,\n               output: torch.Tensor = None,\n               cum: bool = False):\n    \"\"\"Fused lora.\"\"\"\n\n    def grid(META):\n        ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M'])), batch_size)\n        return ret\n\n    assert input.dim() == 2\n    batch_size = seq_lens.numel()\n    M, K = input.shape\n    N = lora_b.size(1)\n\n    if output is None:\n        output = input.new_empty((M, N))\n        cum = False\n    else:\n        assert output.size(0) == M\n        assert output.size(1) == N\n\n    BLOCK_SIZE_R = max(16, max_rank)\n    _fused_lora_kernel[grid](\n        input,\n        lora_a,\n        lora_b,\n        output,\n        scaling,\n        rank_start,\n        ranks,\n        seq_start,\n        seq_lens,\n        adapter_ids,\n        N,\n        
K,\n        stride_am=input.stride(0),\n        stride_ak=input.stride(1),\n        stride_lar=lora_a.stride(0),\n        stride_lak=lora_a.stride(1),\n        stride_lbr=lora_b.stride(0),\n        stride_lbn=lora_b.stride(1),\n        stride_cm=output.stride(0),\n        stride_cn=output.stride(1),\n        BLOCK_SIZE_R=BLOCK_SIZE_R,\n        CUM=cum,\n    )\n\n    return output\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/fused_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modify from: https://github.com/vllm-project/vllm\nfrom typing import Callable\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .activation import silu_and_mul\n\n\ndef get_cuda_autotune_config():\n    return [\n        triton.Config({\n            'BLOCK_SIZE_M': 128,\n            'BLOCK_SIZE_N': 256,\n            'BLOCK_SIZE_K': 64,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=3,\n                      num_warps=8),\n        triton.Config({\n            'BLOCK_SIZE_M': 64,\n            'BLOCK_SIZE_N': 256,\n            'BLOCK_SIZE_K': 32,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n                      num_warps=4),\n        # SM8\n        triton.Config({\n            'BLOCK_SIZE_M': 128,\n            'BLOCK_SIZE_N': 128,\n            'BLOCK_SIZE_K': 32,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n                      num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_M': 64,\n            'BLOCK_SIZE_N': 256,\n            'BLOCK_SIZE_K': 32,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n                      num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_M': 64,\n            'BLOCK_SIZE_N': 128,\n            'BLOCK_SIZE_K': 64,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n                      num_warps=4),\n        # SM7-\n        triton.Config({\n            'BLOCK_SIZE_M': 64,\n            'BLOCK_SIZE_N': 128,\n            'BLOCK_SIZE_K': 32,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n                      num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_M': 128,\n            'BLOCK_SIZE_N': 32,\n            'BLOCK_SIZE_K': 32,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n    
                  num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_M': 64,\n            'BLOCK_SIZE_N': 32,\n            'BLOCK_SIZE_K': 32,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=5,\n                      num_warps=2),\n    ]\n\n\ndef _config_prune_func(config: list, *args, **kwargs):\n    \"\"\"Fused moe config prune.\"\"\"\n    device_cap = torch.cuda.get_device_capability()\n    num_sm9x = 2\n    cum_num_sm8x = 5\n\n    if device_cap[0] >= 9:\n        return config[:num_sm9x]\n    elif device_cap[0] >= 8:\n        return config[num_sm9x:cum_num_sm8x]\n    else:\n        return config[cum_num_sm8x:]\n\n\n@triton.autotune(\n    configs=get_cuda_autotune_config(),\n    key=['N', 'K', 'tune_hint'],\n    prune_configs_by=dict(early_config_prune=_config_prune_func),\n)\n@triton.jit\ndef fused_moe_kernel(\n    A,\n    B,\n    bias,\n    C,\n    SortedIdx,\n    ExpStart,\n    ExpEnd,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    stride_am: tl.constexpr,\n    stride_ak: tl.constexpr,\n    stride_be: tl.constexpr,\n    stride_bn: tl.constexpr,\n    stride_bk: tl.constexpr,\n    stride_cm: tl.constexpr,\n    stride_cn: tl.constexpr,\n    stride_bie: tl.constexpr,\n    stride_bin: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n    M_NP2: tl.constexpr,\n    tune_hint: tl.constexpr,\n    top_k: tl.constexpr,\n    expert_offset: tl.constexpr,\n    reindex_a: tl.constexpr,\n    reindex_c: tl.constexpr,\n):\n    \"\"\"Fused moe kernel.\"\"\"\n    exp_id = tl.program_id(1)\n    pid = tl.program_id(0)\n\n    exp_start = tl.load(ExpStart + exp_id + expert_offset)\n    exp_end = tl.load(ExpEnd + exp_id + expert_offset)\n    M = exp_end - exp_start\n    if M <= 0:\n        return\n\n    num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n\n    if GROUP_SIZE_M == 1:\n        pid_m = pid % 
num_pid_m\n        pid_n = pid // num_pid_m\n    else:\n        num_pid_in_group = GROUP_SIZE_M * num_pid_n\n        group_id = pid // num_pid_in_group\n        first_pid_m = group_id * GROUP_SIZE_M\n        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n        pid_m = first_pid_m + (pid % group_size_m)\n        pid_n = (pid % num_pid_in_group) // group_size_m\n\n    if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:\n        return\n\n    offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    mask_sid = offs_sid < exp_end\n    sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)\n\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    if reindex_a:\n        offs_am = sid // top_k\n    else:\n        offs_am = offs_sid\n    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n\n    # deepseek has 160 experts, exp index would overflow int32\n    exp_off = stride_be * exp_id.to(tl.int64)\n    b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n        a = tl.load(a_ptrs, mask=mask_sid[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n        accumulator = tl.dot(a, b, acc=accumulator)\n        a_ptrs += BLOCK_SIZE_K * stride_ak\n        b_ptrs += BLOCK_SIZE_K * stride_bk\n\n    if bias is not None:\n        bias_ptrs = bias + exp_id * stride_bie + offs_bn * stride_bin\n        bias_val = tl.load(bias_ptrs).to(accumulator.dtype)\n        accumulator += bias_val[None]\n\n    c = accumulator.to(A.dtype.element_ty)\n\n    if reindex_c:\n        offs_cm = sid\n    else:\n        offs_cm = 
offs_sid\n    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]\n    tl.store(c_ptrs, c, mask=mask_sid[:, None])\n\n\ndef fused_moe_kernel_launcher(\n    A: torch.Tensor,\n    B: torch.Tensor,\n    C: torch.Tensor,\n    sorted_idx: torch.Tensor,\n    exp_start: torch.Tensor,\n    exp_end: torch.Tensor,\n    bias: torch.Tensor = None,\n    top_k: int = 1,\n    num_tokens: int = None,\n    expert_offset: int = 0,\n    reindex_a: bool = True,\n    reindex_c: bool = True,\n):\n    \"\"\"Fused moe kernel launcher.\"\"\"\n\n    if num_tokens is None:\n        num_tokens = A.size(0)\n    M_NP2 = triton.next_power_of_2(num_tokens)\n    M_NP2 = max(64, M_NP2)\n    E, N, K = B.shape\n    tune_hint = min(2, triton.cdiv(M_NP2, 512))\n\n    def _grid_fn(META):\n        grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), E)\n        return grid\n\n    A = A.flatten(0, -2)\n    C = C.flatten(0, -2)\n    enable_bias = bias is not None\n\n    grid = _grid_fn\n    fused_moe_kernel[grid](\n        A,\n        B,\n        bias,\n        C,\n        sorted_idx,\n        exp_start,\n        exp_end,\n        N=N,\n        K=K,\n        stride_am=A.stride(0),\n        stride_ak=A.stride(1),\n        stride_be=B.stride(0),\n        stride_bn=B.stride(1),\n        stride_bk=B.stride(2),\n        stride_cm=C.stride(0),\n        stride_cn=C.stride(1),\n        stride_bie=bias.stride(0) if enable_bias else 0,\n        stride_bin=bias.stride(1) if enable_bias else 0,\n        tune_hint=tune_hint,\n        top_k=top_k,\n        expert_offset=expert_offset,\n        reindex_a=reindex_a,\n        reindex_c=reindex_c,\n        M_NP2=M_NP2,\n    )\n\n\n@triton.jit\ndef _get_exp_mask_kernel(\n    a_ptr,\n    o_mask_ptr,\n    o_k_ptr,\n    stride_a_token: tl.constexpr,\n    stride_a_exp: tl.constexpr,\n    stride_o_exp,\n    stride_o_token: tl.constexpr,\n    topk: tl.constexpr,\n    num_experts: tl.constexpr,\n    BLOCK_NA: tl.constexpr,\n  
  BLOCK_NO: tl.constexpr,\n):\n    token_id = tl.program_id(0)\n\n    offs_n = tl.arange(0, BLOCK_NA)\n    mask_n = offs_n < topk\n    a_ptrs = a_ptr + token_id * stride_a_token + offs_n * stride_a_exp\n    a = tl.load(a_ptrs, mask=mask_n)\n\n    # fill zeros\n    offs_no = tl.arange(0, BLOCK_NO)\n    mask_no = offs_no < num_experts\n    o_ptrs = o_mask_ptr + token_id * stride_o_token + offs_no * stride_o_exp\n    tl.store(o_ptrs, 0, mask=mask_no)\n\n    # fill a\n    o_ptrs = o_mask_ptr + token_id * stride_o_token + a * stride_o_exp\n    tl.store(o_ptrs, 1, mask=mask_n)\n\n    # fill kid\n    ok_ptrs = o_k_ptr + token_id * stride_o_token + a * stride_o_exp\n    tl.store(ok_ptrs, offs_n, mask=mask_n)\n\n\ndef _get_exp_mask(topk_ids: torch.Tensor, num_experts: int):\n    \"\"\"Get exp mask.\"\"\"\n    assert topk_ids.dim() == 2\n    M, topk = topk_ids.shape\n    assert topk <= num_experts\n\n    out_mask = topk_ids.new_empty((num_experts, M))\n    out_k = topk_ids.new_empty((num_experts, M))\n    BLOCK_NA = triton.next_power_of_2(topk)\n    BLOCK_NO = triton.next_power_of_2(num_experts)\n\n    grid = (M, )\n    _get_exp_mask_kernel[grid](\n        topk_ids,\n        out_mask,\n        out_k,\n        stride_a_token=topk_ids.stride(0),\n        stride_a_exp=topk_ids.stride(1),\n        stride_o_exp=out_mask.stride(0),\n        stride_o_token=out_mask.stride(1),\n        topk=topk,\n        num_experts=num_experts,\n        BLOCK_NA=BLOCK_NA,\n        BLOCK_NO=BLOCK_NO,\n        num_warps=1,\n    )\n    return out_mask, out_k\n\n\n@triton.jit\ndef _get_start_end_kernel(\n    exp_cum_ptr,\n    exp_topk_ptr,\n    exp_out_ptr,\n    start_ptr,\n    end_ptr,\n    stride_cum_exp,\n    stride_cum_token: tl.constexpr,\n    stride_out: tl.constexpr,\n    num_tokens,\n    num_experts: tl.constexpr,\n    topk: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    \"\"\"Get start end kernel.\"\"\"\n    token_start = tl.program_id(0)\n\n    offs_exp = tl.arange(0, BLOCK_N)\n    
off_cum = offs_exp * stride_cum_exp + token_start * stride_cum_token\n    cum_ptrs = exp_cum_ptr + off_cum\n    val_k_ptrs = exp_topk_ptr + off_cum\n\n    mask_exp = offs_exp < num_experts\n\n    # get prev and cur cum\n    token_id = token_start\n    prev_cum_mask = mask_exp\n    if token_start == 0:\n        prev_cum_mask = mask_exp & (tl.arange(0, BLOCK_N) > 0)\n    prev_cum = tl.load(cum_ptrs - stride_cum_token, mask=prev_cum_mask, other=0)\n    cur_cum = tl.load(cum_ptrs, mask=mask_exp)\n\n    # store sorted idx\n    mask_out = mask_exp & (cur_cum > prev_cum)\n    val_k = tl.load(val_k_ptrs, mask=mask_exp)\n    val = token_id * topk + val_k\n    out_ptrs = exp_out_ptr + prev_cum * stride_out\n    tl.store(out_ptrs, val, mask=mask_out)\n\n    # fill start\n    if token_id == 0:\n        cur_start_ptrs = start_ptr + offs_exp\n        tl.store(cur_start_ptrs, prev_cum, mask=mask_exp)\n\n    # fill end\n    if token_id == num_tokens - 1:\n        cur_end_ptrs = end_ptr + offs_exp\n        tl.store(cur_end_ptrs, cur_cum, mask=mask_exp)\n\n\ndef get_start_end(exp_cum: torch.Tensor, exp_topk: torch.Tensor, topk: int):\n    \"\"\"Get start end.\"\"\"\n    num_experts, num_tokens = exp_cum.shape\n\n    start_end = exp_cum.new_empty(2, num_experts)\n    exp_start = start_end[0, :]\n    exp_end = start_end[1, :]\n\n    out = exp_cum.new_empty((num_tokens * topk))\n\n    num_warps = 1\n\n    BLOCK_N = triton.next_power_of_2(num_experts)\n    grid = (num_tokens, )\n\n    _get_start_end_kernel[grid](\n        exp_cum,\n        exp_topk,\n        out,\n        exp_start,\n        exp_end,\n        stride_cum_exp=exp_cum.stride(0),\n        stride_cum_token=exp_cum.stride(1),\n        stride_out=out.stride(0),\n        num_tokens=num_tokens,\n        num_experts=num_experts,\n        topk=topk,\n        BLOCK_N=BLOCK_N,\n        num_warps=num_warps,\n    )\n    return out, exp_start, exp_end\n\n\ndef _get_sorted_idx(topk_ids: torch.Tensor, num_experts: int):\n    \"\"\"Get 
sorted idx.\"\"\"\n    assert topk_ids.dim() == 2\n    _, topk = topk_ids.shape\n\n    # get expert mask   (num_experts, num_tokens)\n    exp_mask, exp_topk = _get_exp_mask(topk_ids, num_experts)\n    # get cumsum   (num_experts, num_tokens)\n    exp_cum = exp_mask.flatten().cumsum(0).view_as(exp_mask)\n\n    # get sort idx and start/end\n    sorted_idx, start, end = get_start_end(exp_cum, exp_topk, topk)\n\n    return sorted_idx, start, end\n\n\ndef _renormalize(topk_weights: torch.Tensor, renormalize: bool):\n    if renormalize:\n        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)\n    if not topk_weights.is_contiguous():\n        topk_weights = topk_weights.contiguous()\n    return topk_weights\n\n\ndef _make_intermediate(shape: tuple, dtype: torch.dtype, device: torch.device, zeros: bool):\n    \"\"\"Make intermediate.\"\"\"\n    if zeros:\n        return torch.zeros(shape, dtype=dtype, device=device)\n    else:\n        return torch.empty(shape, dtype=dtype, device=device)\n\n\n@triton.jit\ndef _moe_reduce_kernel(\n    hidden_states_ptr,\n    weights_ptr,\n    out_ptr,\n    stride_hm,\n    stride_hk: tl.constexpr,\n    stride_hn: tl.constexpr,\n    stride_wm,\n    stride_wk: tl.constexpr,\n    stride_om,\n    stride_on: tl.constexpr,\n    fp32_acc: tl.constexpr,\n    K: tl.constexpr,\n    N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    pid = tl.program_id(0)\n    num_n_split = tl.cdiv(N, BLOCK_N)\n    mid = pid // num_n_split\n    nid = pid % num_n_split\n\n    offs_k = tl.arange(0, BLOCK_K)\n    offs_n = nid * BLOCK_N + tl.arange(0, BLOCK_N)\n    weights_ptrs = weights_ptr + mid * stride_wm + offs_k * stride_wk\n    h_ptrs = hidden_states_ptr + mid * stride_hm + offs_k[:, None] * stride_hk + offs_n[None, :] * stride_hn\n    o_ptrs = out_ptr + mid * stride_om + offs_n * stride_on\n\n    mask_k = offs_k < K\n    mask_n = offs_n < N  # dummy load to get N\n    mask_h = mask_k[:, None] & mask_n[None, 
:]\n\n    h = tl.load(h_ptrs, mask=mask_h, other=0.0)\n    w = tl.load(weights_ptrs, mask=mask_k, other=0.0)\n\n    if fp32_acc:\n        h = h.to(tl.float32)\n        w = w.to(tl.float32)\n    else:\n        w = w.to(h.dtype)\n\n    wh = h * w[:, None]\n    o = wh.sum(axis=0)\n    tl.store(o_ptrs, o, mask=mask_n)\n\n\ndef moe_reduce(hidden_states: torch.Tensor, topk_weights: torch.Tensor, fp32_acc: bool = False) -> torch.Tensor:\n    \"\"\"Moe reduce.\"\"\"\n    assert hidden_states.dim() == 3\n    assert topk_weights.dim() == 2\n    assert hidden_states.size(0) == topk_weights.size(0)\n    assert hidden_states.size(1) == topk_weights.size(1)\n    M, K, N = hidden_states.shape\n\n    out = hidden_states.new_empty((M, N))\n\n    BLOCK_K = triton.next_power_of_2(K)\n    num_warps = 1\n    BLOCK_N = triton.cdiv(num_warps * 512, hidden_states.element_size())\n    grid = (M * triton.cdiv(N, BLOCK_N), )\n\n    _moe_reduce_kernel[grid](\n        hidden_states,\n        topk_weights,\n        out,\n        hidden_states.stride(0),\n        hidden_states.stride(1),\n        hidden_states.stride(2),\n        topk_weights.stride(0),\n        topk_weights.stride(1),\n        out.stride(0),\n        out.stride(1),\n        fp32_acc,\n        K,\n        N,\n        BLOCK_K,\n        BLOCK_N,\n        num_warps=num_warps,\n    )\n\n    return out\n\n\ndef fused_moe(hidden_states: torch.Tensor,\n              w1: torch.Tensor,\n              w2: torch.Tensor,\n              topk_weights: torch.Tensor,\n              topk_ids: torch.Tensor,\n              topk: int,\n              w1_bias: torch.Tensor = None,\n              w2_bias: torch.Tensor = None,\n              expert_offset: int = 0,\n              num_experts: int = None,\n              renormalize: bool = False,\n              act_func: Callable = None) -> torch.Tensor:\n    \"\"\"Fused moe.\"\"\"\n    M = hidden_states.size(0)\n    E, N, _ = w1.shape\n    if num_experts is None:\n        num_experts = E\n    full_exp 
= num_experts == E\n\n    topk_weights = _renormalize(topk_weights, renormalize)\n    sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)\n\n    intermediate_cache1 = _make_intermediate((M, topk, N),\n                                             dtype=hidden_states.dtype,\n                                             device=hidden_states.device,\n                                             zeros=not full_exp)\n    # gate and up\n    fused_moe_kernel_launcher(\n        hidden_states,\n        w1,\n        intermediate_cache1,\n        sorted_idx=sorted_idx,\n        exp_start=exp_start,\n        exp_end=exp_end,\n        bias=w1_bias,\n        top_k=topk,\n        num_tokens=M,\n        expert_offset=expert_offset,\n        reindex_a=True,\n        reindex_c=False,\n    )\n\n    # activate\n    unflat_size = intermediate_cache1.shape[:-1]\n    intermediate_cache1 = intermediate_cache1.flatten(0, -2)\n\n    if act_func is None:\n        gate_cache = silu_and_mul(intermediate_cache1)\n    else:\n        gate_cache = act_func(intermediate_cache1)\n    gate_cache = gate_cache.unflatten(0, unflat_size)\n\n    intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),\n                                             dtype=hidden_states.dtype,\n                                             device=hidden_states.device,\n                                             zeros=not full_exp)\n    # down\n    fused_moe_kernel_launcher(\n        gate_cache,\n        w2,\n        intermediate_cache2,\n        sorted_idx=sorted_idx,\n        exp_start=exp_start,\n        exp_end=exp_end,\n        bias=w2_bias,\n        top_k=1,\n        num_tokens=M,\n        expert_offset=expert_offset,\n        reindex_a=False,\n        reindex_c=True,\n    )\n\n    ret = moe_reduce(intermediate_cache2, topk_weights)\n    return ret\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/fused_moe_ep.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modify from dlblas: https://github.com/DeepLink-org/DLBlas\nfrom typing import List, Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .activation import silu_and_mul\n\n\n@triton.jit\ndef _fwd_kernel_ep_scatter_step1(\n    num_recv_tokens_per_expert,\n    expert_start_loc,\n    m_indices,\n    num_experts: tl.constexpr,\n    BLOCK_E: tl.constexpr,\n    BLOCK_EXPERT_NUM: tl.constexpr,\n):\n    cur_expert = tl.program_id(0)\n    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)\n    tokens_per_expert = tl.load(\n        num_recv_tokens_per_expert + offset_cumsum,\n        mask=offset_cumsum < num_experts,\n        other=0,\n    )\n    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert\n    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)\n    cur_expert_start = tl.load(expert_start_loc + cur_expert)\n    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)\n    m_indices_start_ptr = m_indices + cur_expert_start\n    off_expert = tl.arange(0, BLOCK_E)\n    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):\n        tl.store(\n            m_indices_start_ptr + start_m + off_expert,\n            cur_expert,\n        )\n\n\n@triton.jit\ndef _fwd_kernel_ep_scatter_step2(\n    total_token_num,\n    expert_start_loc,\n    recv_x,\n    recv_x_stride0,\n    recv_x_stride1,\n    recv_topk,\n    recv_topk_stride0,\n    recv_topk_stride1,\n    output_tensor,\n    output_tensor_stride0,\n    output_tensor_stride1,\n    output_index,\n    output_index_stride0,\n    output_index_stride1,\n    topk_num: tl.constexpr,\n    HIDDEN_SIZE: tl.constexpr,\n    HIDDEN_SIZE_PAD: tl.constexpr,\n):\n    start_token_id = tl.program_id(0)\n    grid_num = tl.num_programs(0)\n    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)\n    mask = offset_in < HIDDEN_SIZE\n    for token_id in range(start_token_id, total_token_num, 
grid_num):\n        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)\n        for topk_index in tl.range(0, topk_num, 1, num_stages=4):\n            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)\n            if expert_id >= 0:\n                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)\n                dest_token_index = dest_token_index.to(tl.int64)\n                tl.store(output_index + token_id * output_index_stride0 + topk_index, dest_token_index)\n                output_tensor_ptr = output_tensor + dest_token_index * output_tensor_stride0\n                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)\n\n\n# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py\ndef ep_scatter(\n    recv_x: torch.Tensor,\n    recv_topk: torch.Tensor,\n    num_recv_tokens_per_expert: torch.Tensor,\n    expert_start_loc: torch.Tensor,\n    output_tensor: torch.Tensor,\n    m_indices: torch.Tensor,\n    output_index: torch.Tensor,\n):\n    BLOCK_E = 128  # token num of per expert is aligned to 128\n    num_warps = 8\n    num_experts = num_recv_tokens_per_expert.shape[0]\n    hidden_size = recv_x.shape[1]\n    grid = num_experts\n    assert m_indices.shape[0] % BLOCK_E == 0\n    _fwd_kernel_ep_scatter_step1[(grid, )](\n        num_recv_tokens_per_expert,\n        expert_start_loc,\n        m_indices,\n        num_experts=num_experts,\n        num_warps=num_warps,\n        BLOCK_E=BLOCK_E,\n        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),\n    )\n    grid = min(recv_topk.shape[0], 1024 * 8)\n    _fwd_kernel_ep_scatter_step2[(grid, )](\n        recv_topk.shape[0],\n        expert_start_loc,\n        recv_x,\n        recv_x.stride(0),\n        recv_x.stride(1),\n        recv_topk,\n        recv_topk.stride(0),\n        recv_topk.stride(1),\n        output_tensor,\n        output_tensor.stride(0),\n        
output_tensor.stride(1),\n        output_index,\n        output_index.stride(0),\n        output_index.stride(1),\n        topk_num=recv_topk.shape[1],\n        num_warps=num_warps,\n        HIDDEN_SIZE=hidden_size,\n        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),\n    )\n    return\n\n\n@triton.jit\ndef _fwd_kernel_ep_gather(\n    total_token_num,\n    input_tensor,\n    input_tensor_stride0,\n    input_tensor_stride1,\n    recv_topk_ids,\n    recv_topk_ids_stride0,\n    recv_topk_ids_stride1,\n    recv_topk_weight,\n    recv_topk_weight_stride0,\n    recv_topk_weight_stride1,\n    input_index,\n    input_index_stride0,\n    input_index_stride1,\n    output_tensor,\n    output_tensor_stride0,\n    output_tensor_stride1,\n    topk_num: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n):\n    cur_block = tl.program_id(0)\n    start_cur_token = tl.program_id(1)\n    grid_num = tl.num_programs(1)\n    # align with xtuner rl\n    compute_dtype = output_tensor.dtype.element_ty\n    # compute_dtype = tl.float32\n\n    for cur_token in range(start_cur_token, total_token_num, grid_num):\n        off_d = tl.arange(0, BLOCK_D)\n        accumulator = tl.zeros([BLOCK_D], dtype=compute_dtype)\n        for topk_index in range(0, topk_num):\n            expert_id = tl.load(recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index)\n            if expert_id >= 0:\n                source_token_index = tl.load(input_index + cur_token * input_index_stride0 + topk_index)\n                acc_weight = tl.load(recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index)\n                tmp = tl.load(input_tensor + source_token_index * input_tensor_stride0 + cur_block * BLOCK_D + off_d)\n                accumulator += tmp.to(compute_dtype) * acc_weight.to(compute_dtype)\n        tl.store(\n            output_tensor + cur_token * output_tensor_stride0 + cur_block * BLOCK_D + off_d,\n            accumulator.to(output_tensor.dtype.element_ty),\n        
)\n\n\n@torch.no_grad()\ndef ep_gather(\n    input_tensor: torch.Tensor,\n    recv_topk_ids: torch.Tensor,\n    recv_topk_weight: torch.Tensor,\n    input_index: torch.Tensor,\n    output_tensor: torch.Tensor,\n):\n    BLOCK_D = 1024  # block size of quantization\n    num_warps = 2\n    num_tokens = output_tensor.shape[0]\n    hidden_size = input_tensor.shape[1]\n    assert hidden_size % BLOCK_D == 0\n    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))\n    _fwd_kernel_ep_gather[grid](\n        num_tokens,\n        input_tensor,\n        input_tensor.stride(0),\n        input_tensor.stride(1),\n        recv_topk_ids,\n        recv_topk_ids.stride(0),\n        recv_topk_ids.stride(1),\n        recv_topk_weight,\n        recv_topk_weight.stride(0),\n        recv_topk_weight.stride(1),\n        input_index,\n        input_index.stride(0),\n        input_index.stride(1),\n        output_tensor,\n        output_tensor.stride(0),\n        output_tensor.stride(1),\n        topk_num=recv_topk_ids.shape[1],\n        num_warps=num_warps,\n        BLOCK_D=BLOCK_D,\n    )\n    return\n\n\ndef _deepgemm_grouped_bf16_nt_contiguous(\n    x: torch.Tensor,\n    w: torch.Tensor,\n    out: torch.Tensor,\n    m_indices: torch.Tensor,\n):\n    from lmdeploy.pytorch.third_party import deep_gemm\n    return deep_gemm.m_grouped_bf16_gemm_nt_contiguous(x, w, out, m_indices)\n\n\ndef fused_moe_v3(\n    hidden_states: torch.Tensor,\n    topk_idx,\n    topk_weights,\n    w13_weight: torch.Tensor,\n    w2_weight: torch.Tensor,\n    num_recv_tokens_per_expert: Optional[List[int]],\n):\n    if num_recv_tokens_per_expert is None:\n        return hidden_states\n    all_tokens = sum(num_recv_tokens_per_expert)\n    if all_tokens <= 0:\n        return hidden_states\n    M, K = hidden_states.size()\n    N = w13_weight.size(1)\n    gather_out = torch.empty_like(hidden_states)\n    input_tensor = hidden_states.new_empty((all_tokens, K))\n    m_indices = 
hidden_states.new_empty(all_tokens, dtype=torch.int32)\n    output_index = torch.empty_like(topk_idx)\n    num_recv_tokens_per_expert_gpu = torch.tensor(\n        num_recv_tokens_per_expert,\n        dtype=torch.int32,\n        pin_memory=True,\n        device='cpu',\n    ).cuda(non_blocking=True)\n    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)\n    ep_scatter(\n        hidden_states,\n        topk_idx,\n        num_recv_tokens_per_expert_gpu,\n        expert_start_loc,\n        input_tensor,\n        m_indices,\n        output_index,\n    )\n    del hidden_states\n    gateup_output = gather_out.new_empty((all_tokens, N))\n    _deepgemm_grouped_bf16_nt_contiguous(input_tensor, w13_weight, gateup_output, m_indices)\n    down_input = gateup_output.new_empty((\n        all_tokens,\n        N // 2,\n    ))\n    down_input = silu_and_mul(gateup_output.view(-1, N), down_input)\n    down_output = gather_out.new_empty((all_tokens, K))\n    _deepgemm_grouped_bf16_nt_contiguous(\n        down_input,\n        w2_weight,\n        down_output,\n        m_indices,\n    )\n    ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)\n    return gather_out\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/fused_noaux_tc.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_warps=1, num_stages=1),\n        triton.Config({}, num_warps=1, num_stages=2),\n        triton.Config({}, num_warps=1, num_stages=3),\n        triton.Config({}, num_warps=1, num_stages=4),\n        triton.Config({}, num_warps=2, num_stages=1),\n        triton.Config({}, num_warps=2, num_stages=2),\n        triton.Config({}, num_warps=2, num_stages=3),\n        triton.Config({}, num_warps=2, num_stages=4),\n        triton.Config({}, num_warps=4, num_stages=1),\n        triton.Config({}, num_warps=4, num_stages=2),\n        triton.Config({}, num_warps=4, num_stages=3),\n        triton.Config({}, num_warps=4, num_stages=4),\n        triton.Config({}, num_warps=8, num_stages=1),\n        triton.Config({}, num_warps=8, num_stages=2),\n        triton.Config({}, num_warps=8, num_stages=3),\n        triton.Config({}, num_warps=8, num_stages=4),\n    ],\n    key=['num_experts', 'n_group'],\n)\n@triton.jit\ndef _noaux_routing_kernel(\n    logits_ptr,\n    bias_ptr,\n    scores_ptr,\n    tmp_scores_ptr,\n    batch_size,\n    num_experts: tl.constexpr,\n    n_group: tl.constexpr,\n    group_size: tl.constexpr,\n    topk_group: tl.constexpr,\n    # The following arguments are not used inside the kernel but kept for signature compatibility\n    renormalize: tl.constexpr,\n    routed_scaling_factor,\n    logits_stride_0,\n    logits_stride_1,\n    bias_stride_0,\n    scores_stride_0,\n    scores_stride_1,\n    tmp_scores_stride_0,\n    tmp_scores_stride_1,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid = tl.program_id(0)\n    if pid >= batch_size:\n        return\n    idx = tl.arange(0, BLOCK_SIZE)\n    mask = idx < num_experts  # always true if BLOCK_SIZE == num_experts, but kept for safety\n    # 1. 
Load logits and bias\n    logits = tl.load(logits_ptr + pid * logits_stride_0 + idx * logits_stride_1, mask=mask, other=0.0)\n    bias = tl.load(bias_ptr + idx * bias_stride_0, mask=mask, other=0.0)\n    # 2. Compute scores (sigmoid) and bias‑adjusted scores\n    scores = tl.sigmoid(logits)  # original scores\n    scores_fc = scores + bias  # bias‑adjusted scores\n    # 3. Compute group scores: sum of top‑2 scores_fc per group\n    # Reshape to (n_group, group_size) – requires BLOCK_SIZE == num_experts\n    scores_fc_2d = tl.reshape(scores_fc, (n_group, group_size))\n    # Max and argmax per group\n    max_val = tl.max(scores_fc_2d, axis=1)\n    max_idx = tl.argmax(scores_fc_2d, axis=1)  # index within group (0..group_size-1)\n    # Second max per group: mask out the max element\n    col_range = tl.arange(0, group_size)\n    mask_max = col_range[None, :] == max_idx[:, None]\n    scores_fc_masked = tl.where(mask_max, -float('inf'), scores_fc_2d)\n    second_max = tl.max(scores_fc_masked, axis=1)\n    group_scores = max_val + second_max\n    # 4. Select top‑k groups and build selected_mask\n    selected_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int1)\n    group_scores_copy = group_scores\n    for _ in range(topk_group):\n        max_idx_g = tl.argmax(group_scores_copy, axis=0)  # group index\n        # mark experts in this group\n        group_start = max_idx_g * group_size\n        group_end = group_start + group_size\n        group_mask = (idx >= group_start) & (idx < group_end) & mask\n        selected_mask = selected_mask | group_mask\n        # remove this group\n        g_idx = tl.arange(0, n_group)\n        g_mask = g_idx == max_idx_g\n        group_scores_copy = tl.where(g_mask, -float('inf'), group_scores_copy)\n    # 5. Build masked scores (tmp_scores) – experts in selected groups keep scores_fc, others 0\n    tmp_scores = tl.where(selected_mask, scores_fc, 0.0)\n    # 6. 
Store outputs\n    off_scores = pid * scores_stride_0 + idx * scores_stride_1\n    tl.store(scores_ptr + off_scores, scores, mask=mask)\n    off_tmp = pid * tmp_scores_stride_0 + idx * tmp_scores_stride_1\n    tl.store(tmp_scores_ptr + off_tmp, tmp_scores, mask=mask)\n\n\n# ---------------------------------------------------------------------------\n# Python wrapper\n# ---------------------------------------------------------------------------\n\n\ndef fused_noaux_tc_routing(\n    logits: torch.Tensor,\n    bias: torch.Tensor,\n    num_experts: int = 256,\n    n_group: int = 8,\n    topk_group: int = 4,\n    top_k: int = 8,\n    renormalize: bool = True,\n    routed_scaling_factor: float = 2.5,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    batch_size = logits.shape[0]\n    group_size = num_experts // n_group\n    assert num_experts % n_group == 0, 'num_experts must be divisible by n_group'\n    # Convert to float32 and ensure contiguous\n    logits = logits.float().contiguous()\n    bias = bias.float().contiguous()\n    # Output tensors from the kernel\n    scores = torch.empty(batch_size, num_experts, device=logits.device, dtype=torch.float32)\n    tmp_scores = torch.empty(batch_size, num_experts, device=logits.device, dtype=torch.float32)\n    # The kernel reshapes the expert axis to (n_group, group_size), so BLOCK_SIZE must equal\n    # num_experts exactly; rounding it up to a multiple of 32 would break that reshape. 
Instead, num_experts\n    # itself is required to be a multiple of 32 for good performance.\n    BLOCK_SIZE = num_experts\n    assert BLOCK_SIZE % 32 == 0, 'num_experts must be a multiple of 32 for optimal performance'\n    # Kernel launch\n    grid = (batch_size, )\n    _noaux_routing_kernel[grid](\n        logits,\n        bias,\n        scores,\n        tmp_scores,\n        batch_size,\n        num_experts=num_experts,\n        n_group=n_group,\n        group_size=group_size,\n        topk_group=topk_group,\n        renormalize=int(renormalize),  # not used inside kernel\n        routed_scaling_factor=routed_scaling_factor,\n        logits_stride_0=logits.stride(0),\n        logits_stride_1=logits.stride(1),\n        bias_stride_0=bias.stride(0),\n        scores_stride_0=scores.stride(0),\n        scores_stride_1=scores.stride(1),\n        tmp_scores_stride_0=tmp_scores.stride(0),\n        tmp_scores_stride_1=tmp_scores.stride(1),\n        BLOCK_SIZE=BLOCK_SIZE,\n    )\n    # Final expert selection using PyTorch's topk (guarantees exact match)\n    _, topk_idx = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)\n    topk_weight = scores.gather(1, topk_idx)\n    if renormalize:\n        topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)\n    topk_weight = topk_weight * routed_scaling_factor\n    return topk_weight, topk_idx\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/gated_delta_rule.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Sequence\n\nimport tilelang\nimport tilelang.language as T\nimport torch\nfrom tvm import tir\n\nBufferLikeType = tir.Buffer | tir.BufferRegion | tir.BufferLoad\n\n\n@T.macro\ndef normalize_qk(k_local: T.Buffer, q_local: T.Buffer, k_per_thr: int) -> None:\n    k_sum = T.alloc_var(T.float32)\n    q_sum = T.alloc_var(T.float32)\n    k_sum = 0\n    q_sum = 0\n    for i in T.Unroll(k_per_thr):\n        k_sum += k_local[i] * k_local[i]\n        q_sum += q_local[i] * q_local[i]\n    k_sum = T.warp_reduce_sum(k_sum)\n    q_sum = T.warp_reduce_sum(q_sum)\n    k_norm = T.rsqrt(k_sum + 1e-6)\n    q_norm = T.rsqrt(q_sum + 1e-6)\n    for i in T.Unroll(k_per_thr):\n        k_local[i] = k_local[i] * k_norm\n        q_local[i] = q_local[i] * q_norm\n\n\n@tilelang.jit(pass_configs={\n    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,\n}, )\ndef fused_recurrent_gated_delta_rule_fwd(SEQLEN,\n                                         H,\n                                         K,\n                                         HV,\n                                         V,\n                                         NUM_STATE,\n                                         q_stride: Sequence[int],\n                                         k_stride: Sequence[int],\n                                         v_stride: Sequence[int],\n                                         state_stride: Sequence[int],\n                                         scale,\n                                         dtype,\n                                         state_dtype,\n                                         g_dtype=None,\n                                         beta_dtype=None,\n                                         use_g: bool = False,\n                                         use_beta: bool = False,\n                                         use_qk_l2norm_in_kernel: bool = False,\n                                         
output_final_state: bool = False,\n                                         use_state_indices: bool = False,\n                                         is_circular_buffer: bool = False,\n                                         num_warps: int = 1):\n\n    num_threads = num_warps * 32\n    state_num_bits = T.DataType(state_dtype).bits\n    data_num_bits = T.DataType(dtype).bits\n    state_vec_width = 128 // state_num_bits\n    data_vec_width = 128 // data_num_bits\n    warp_size = 32\n    k_per_thr = T.ceildiv(K, warp_size)\n    v_per_warp = max(state_vec_width, data_vec_width, 8)\n    # Target v_per_cta >= V to minimize grid_V blocks.\n    # More waves means fewer blocks but more sequential wave iterations.\n    target_v_per_cta = max(V, v_per_warp * num_warps * 2)\n    num_waves = T.ceildiv(target_v_per_cta, v_per_warp * num_warps)\n    v_per_cta = v_per_warp * num_warps * num_waves\n\n    B = T.dynamic('B')\n    N = B if not use_state_indices else T.dynamic('N')\n\n    # dtype\n    if g_dtype is None:\n        g_dtype = dtype\n    if beta_dtype is None:\n        beta_dtype = dtype\n\n    @T.prim_func\n    def fused_recurrent_gated_delta_rule_main(\n        Query: T.StridedTensor([B, SEQLEN, H, K], dtype=dtype, strides=q_stride),\n        Key: T.StridedTensor([B, SEQLEN, H, K], dtype=dtype, strides=k_stride),\n        Value: T.StridedTensor([B, SEQLEN, HV, V], dtype=dtype, strides=v_stride),\n        Out: T.Tensor([B, SEQLEN, HV, V], dtype=dtype),\n        G: T.Tensor([B, SEQLEN, HV], dtype=g_dtype),\n        Beta: T.Tensor([B, SEQLEN, HV], dtype=beta_dtype),\n        State: T.StridedTensor([N, NUM_STATE, HV, K, V], dtype=state_dtype, strides=state_stride),\n        StateIndices: T.Tensor([B], dtype=torch.int64) = None,\n        CacheSeqlens: T.Tensor([B], dtype=torch.int32) = None,\n    ):\n        with T.Kernel(T.ceildiv(V, v_per_cta), B * HV, threads=num_threads) as (v_start, bhv_idx):\n            tidx = T.get_thread_binding(0)\n            b_id = bhv_idx // 
HV\n            hv_id = bhv_idx % HV\n            h_id = hv_id // (HV // H)\n            warp_id = tidx // warp_size\n            lane_id = tidx % warp_size\n            k_off = lane_id * k_per_thr\n\n            # state_idx\n            if use_state_indices:\n                state_id = StateIndices[b_id]\n            else:\n                state_id = b_id\n\n            if is_circular_buffer:\n                state_seq_id = CacheSeqlens[b_id] % NUM_STATE\n                state_update_id = T.alloc_var(T.int32)\n                state_update_id = (state_seq_id + 1) % NUM_STATE\n            else:\n                state_seq_id = 0\n                state_update_id = 0\n\n            # load states\n            h_smem = T.alloc_shared([K, v_per_cta], state_dtype)\n            T.annotate_layout({h_smem: tilelang.layout.make_swizzled_layout(h_smem)})\n            for i, j in T.Parallel(K, v_per_cta):\n                v_idx = v_start * v_per_cta + j\n                if v_idx < V:\n                    h_smem[i, j] = State[state_id, state_seq_id, hv_id, i, v_idx]\n                else:\n                    h_smem[i, j] = 0.0\n\n            # since H is more heavy than qkv, we would put wave loop outside\n            for wave_id in range(num_waves):\n                # load states local\n\n                v_warp_off = wave_id * num_warps * v_per_warp + warp_id * v_per_warp\n                v_off = v_start * v_per_cta + v_warp_off\n                h_local = T.alloc_local([k_per_thr, v_per_warp], T.float32)\n                if is_circular_buffer:\n                    state_update_id = (state_seq_id + 1) % NUM_STATE\n                for j in T.Unroll(k_per_thr):\n                    k_idx = k_off + j\n                    for vg in T.Unroll(v_per_warp // state_vec_width):\n                        for i in T.Vectorized(state_vec_width):\n                            idx = vg * state_vec_width + i\n                            h_local[j, idx] = h_smem[k_idx, v_warp_off + idx]\n\n        
        for seq_id in range(SEQLEN):\n                    # load q, k, g, beta\n                    q_local = T.alloc_local([k_per_thr], T.float32)\n                    k_local = T.alloc_local([k_per_thr], T.float32)\n                    for i in T.Vectorized(k_per_thr):\n                        k_idx = (k_off + i) % K\n                        q_local[i] = Query[b_id, seq_id, h_id, k_idx]\n                    for i in T.Vectorized(k_per_thr):\n                        k_idx = (k_off + i) % K\n                        k_local[i] = Key[b_id, seq_id, h_id, k_idx]\n\n                    # normalize\n                    if use_qk_l2norm_in_kernel:\n                        normalize_qk(k_local, q_local, k_per_thr)\n\n                    for i in T.Vectorized(k_per_thr):\n                        q_local[i] = q_local[i] * scale\n\n                    # load g, beta\n                    if use_g:\n                        g = T.cast(G[b_id, seq_id, hv_id], T.float32)\n                    else:\n                        g = 0.0\n                    g_exp = T.exp(g)\n                    if use_beta:\n                        beta = T.cast(Beta[b_id, seq_id, hv_id], T.float32)\n                    else:\n                        beta = 1.0\n\n                    # load v\n                    v_local = T.alloc_local([v_per_warp], dtype)\n                    for vg in T.Unroll(v_per_warp // data_vec_width):\n                        for i in T.Vectorized(data_vec_width):\n                            idx = vg * data_vec_width + i\n                            v_idx = (v_off + idx) % V\n                            v_local[idx] = Value[b_id, seq_id, hv_id, v_idx]\n\n                    # update states\n                    for i in T.Unroll(v_per_warp):\n                        hk = T.alloc_var(T.float32)\n                        hk = 0\n                        for j in T.Unroll(k_per_thr):\n                            h_local[j, i] = h_local[j, i] * g_exp\n                            hk += 
h_local[j, i] * k_local[j]\n                        hk = T.warp_reduce_sum(hk)\n                        v = (v_local[i] - hk) * beta\n                        for j in T.Unroll(k_per_thr):\n                            h_local[j, i] = h_local[j, i] + k_local[j] * v\n\n                    # store states\n                    if output_final_state and state_id >= 0:\n                        if is_circular_buffer:\n                            for j in T.Unroll(k_per_thr):\n                                if (k_off + j) < K:\n                                    for vg in T.Unroll(v_per_warp // state_vec_width):\n                                        for i in T.Vectorized(state_vec_width):\n                                            idx = vg * state_vec_width + i\n                                            if v_off + idx < V:\n                                                State[state_id, state_update_id, hv_id, k_off + j,\n                                                      v_off + idx] = h_local[j, idx]\n                            state_update_id = (state_update_id + 1) % NUM_STATE\n\n                    # compute output\n                    o_local = T.alloc_local([v_per_warp], dtype)\n                    for i in T.Unroll(v_per_warp):\n                        # o = q * h\n                        o = T.alloc_var(T.float32)\n                        o = 0.0\n                        for j in T.Unroll(k_per_thr):\n                            o += q_local[j] * h_local[j, i]\n                        o = T.warp_reduce_sum(o)\n                        o_local[i] = o\n\n                    if lane_id == 0 and state_id >= 0:\n                        for vg in T.Unroll(v_per_warp // data_vec_width):\n                            for i in T.Vectorized(data_vec_width):\n                                idx = vg * data_vec_width + i\n                                v_idx = (v_off + idx)\n                                if v_idx < V:\n                                    Out[b_id, 
seq_id, hv_id, v_idx] = o_local[idx]\n\n                # write h_local back to h_smem for coalesced global store\n                if output_final_state and state_id >= 0 and not is_circular_buffer:\n                    for j in T.Unroll(k_per_thr):\n                        k_idx = k_off + j\n                        for vg in T.Unroll(v_per_warp // state_vec_width):\n                            for i in T.Vectorized(state_vec_width):\n                                idx = vg * state_vec_width + i\n                                h_smem[k_idx, v_warp_off + idx] = h_local[j, idx]\n\n            # coalesced state writeback via shared memory\n            if output_final_state and state_id >= 0 and not is_circular_buffer:\n                for i, j in T.Parallel(K, v_per_cta):\n                    v_idx = v_start * v_per_cta + j\n                    if v_idx < V:\n                        State[state_id, state_update_id, hv_id, i, v_idx] = h_smem[i, j]\n\n    return fused_recurrent_gated_delta_rule_main\n\n\ndef fused_recurrent_gated_delta_rule(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    g: torch.Tensor | None = None,\n    beta: torch.Tensor | None = None,\n    scale: float | None = None,\n    initial_state: torch.Tensor | None = None,\n    output_final_state: bool = False,\n    use_qk_l2norm_in_kernel: bool = False,\n    state_indices: torch.Tensor | None = None,\n    cache_seqlens: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor | None]:\n    \"\"\"Fused recurrent gated delta rule.\n\n    Args:\n        q: [B, T, H, K]\n        k: [B, T, H, K]\n        v: [B, T, HV, V]\n        g: [B, T, HV], optional\n        beta: [B, T, HV], optional\n        scale: float, optional\n        initial_state: [N, HV, K, V], optional; if state_indices is not provided, N=B\n        output_final_state: whether to write the updated state back into initial_state in place and return it\n        use_qk_l2norm_in_kernel: whether to apply L2 normalization to q and k in the kernel\n        state_indices: [B], optional, the indices to update in the recurrent 
state (requires initial_state)\n        cache_seqlens: [B], optional, the cached sequence lengths for each batch element\n    Returns:\n        o: [B, T, HV, V]\n        final_state: [N, HV, K, V] if output_final_state else None\n    \"\"\"\n    # T is imported as tilelang.language, use seqlen instead\n    _, seqlen, H, K, V = *k.shape, v.shape[-1]\n    HV = v.shape[2]\n    if scale is None:\n        scale = 1 / (q.shape[-1]**0.5)\n    g_dtype = torch.float32\n    beta_dtype = torch.float32\n    if g is not None:\n        assert g.is_contiguous()\n        g_dtype = g.dtype\n    if beta is not None:\n        assert beta.is_contiguous()\n        beta_dtype = beta.dtype\n    if state_indices is not None:\n        assert state_indices.is_contiguous()\n        assert initial_state is not None, 'initial_state is required when state_indices is provided'\n        assert state_indices.shape == (q.shape[0], )\n\n    o = torch.empty_like(v)\n    final_state = initial_state\n    state_dtype = q.dtype\n    if final_state is not None:\n        state_dim = final_state.dim()\n        # expand dim\n        if state_dim == 4:\n            final_state = final_state.unsqueeze(1)\n        state_stride = final_state.stride()\n        state_dtype = final_state.dtype\n\n        # set and check num states\n        num_states = final_state.shape[1]\n    else:\n        state_dim = 4\n        state_stride = (0, 0, 0, 0, 0)\n        num_states = 1\n\n    num_warps = 4\n    kernel = fused_recurrent_gated_delta_rule_fwd(seqlen,\n                                                  H,\n                                                  K,\n                                                  HV,\n                                                  V,\n                                                  NUM_STATE=num_states,\n                                                  q_stride=q.stride(),\n                                                  k_stride=k.stride(),\n                                                  
v_stride=v.stride(),\n                                                  state_stride=state_stride,\n                                                  scale=scale,\n                                                  dtype=q.dtype,\n                                                  state_dtype=state_dtype,\n                                                  g_dtype=g_dtype,\n                                                  beta_dtype=beta_dtype,\n                                                  use_g=g is not None,\n                                                  use_beta=beta is not None,\n                                                  use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,\n                                                  output_final_state=output_final_state,\n                                                  use_state_indices=state_indices is not None,\n                                                  is_circular_buffer=cache_seqlens is not None,\n                                                  num_warps=num_warps)\n\n    kernel(q, k, v, o, g, beta, final_state, state_indices, cache_seqlens)\n\n    if not output_final_state:\n        final_state = None\n    elif final_state is not None and state_dim == 4:\n        final_state = final_state.squeeze(1)\n    return o, final_state\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/multinomial_sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, stride_sb, stride_st, stride_ib, stride_it,\n                                 num_tokens, BLOCK_N: tl.constexpr):\n    \"\"\"Kernel.\"\"\"\n    batch_id = tl.program_id(0)\n    n_off = tl.arange(0, BLOCK_N)\n\n    # sampling random seed\n    seed = tl.load(Seeds + batch_id)\n    offset = tl.load(Offsets + batch_id).to(tl.int32)\n    samp = tl.rand(seed, offset)\n\n    # initialize\n    acc = 0.0\n    score_ptr = Scores + batch_id * stride_sb + n_off * stride_st\n    indice_ptr = Indices + batch_id * stride_ib\n    output = tl.load(indice_ptr)\n\n    found_mask = False\n    for b_idx in tl.range(0, num_tokens, BLOCK_N):\n        # triton does not have break statement, use mask to skip computation\n        if not found_mask:\n            s_off = b_idx + n_off\n            s_mask = (s_off < num_tokens)\n            scores = tl.load(score_ptr, mask=s_mask, other=0.0).to(tl.float32)\n            c_scores = tl.cumsum(scores, 0)\n            cum_scores = acc + c_scores\n            acc += tl.max(c_scores, 0)\n\n            pre_cum_scores = cum_scores - scores\n            valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)\n            found_mask = tl.sum(valid_mask, 0) > 0\n\n            if found_mask:\n                valid_pos = tl.argmax(valid_mask.to(tl.int32), 0)\n                indice = tl.load(indice_ptr + valid_pos * stride_it)\n                output = indice\n        score_ptr += stride_st * BLOCK_N\n        indice_ptr += stride_it * BLOCK_N\n\n    tl.store(Outputs + batch_id, output)\n\n\ndef multinomial_sampling(scores: torch.Tensor,\n                         seeds: torch.LongTensor,\n                         offsets: torch.LongTensor,\n                         indices: torch.Tensor = None):\n    \"\"\"Multinomial sampling.\n\n    Note 
that this kernel assumes the input scores are already sorted in descending order.\n\n    scores: [batch_size, num_tokens], sorted softmax scores\n    seeds: [batch_size]\n    offsets: [batch_size]\n    indices: [batch_size, num_tokens], original token indices before sorting\n    \"\"\"\n    assert scores.dim() == 2\n    batch_size, num_tokens = scores.size()\n    device = scores.device\n\n    if num_tokens == 1:\n        return torch.zeros_like(scores, dtype=torch.long)\n\n    if indices is None:\n        indices = torch.arange(num_tokens, device=device)\n        indices = indices.expand_as(scores)\n\n    assert indices.dim() == 2\n    assert indices.size() == scores.size()\n\n    outputs = indices[:, 0].clone()\n\n    BLOCK_N = 128\n\n    grid = [batch_size]\n    _multinomial_sampling_kernel[grid](scores,\n                                       seeds,\n                                       offsets,\n                                       indices,\n                                       outputs,\n                                       stride_sb=scores.stride(0),\n                                       stride_st=scores.stride(1),\n                                       stride_ib=indices.stride(0),\n                                       stride_it=indices.stride(1),\n                                       num_tokens=num_tokens,\n                                       BLOCK_N=BLOCK_N,\n                                       num_warps=1)\n\n    return outputs\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/pagedattention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from: https://github.com/ModelTC/lightllm\nimport math\nfrom typing import Literal, Sequence\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom packaging import version\nfrom torch import Tensor\n\nfrom lmdeploy.utils import get_logger\n\nfrom .utils import get_device_props\n\nlogger = get_logger('lmdeploy')\n\nTRITON_VERSION = version.parse(triton.__version__)\nVERSION_300 = version.parse('3.0.0')\n\nassert TRITON_VERSION >= version.parse('2.2.0')\n\n# TODO: fast ops might not work on non-NVIDIA devices\nif TRITON_VERSION >= VERSION_300:\n    tanh = tl.extra.cuda.libdevice.tanh\n    fast_dividef = tl.extra.cuda.libdevice.fast_dividef\n    tl_log2 = tl.log2\n    tl_exp2 = tl.exp2\nelse:\n    tanh = tl.math.tanh\n    fast_dividef = tl.math.fast_dividef\n    tl_log2 = tl.math.log2\n    tl_exp2 = tl.math.exp2\n\n\n@triton.jit\ndef _fwd_grouped_split_kernel(\n    q_ptr,\n    k_ptr,\n    v_ptr,\n    sm_scale: tl.constexpr,\n    cache_seqlens_ptr,\n    page_table_ptr,\n    acc_out_ptr,\n    alibi_slopes_ptr,\n    stride_qbs: tl.constexpr,\n    stride_qh: tl.constexpr,\n    stride_qd: tl.constexpr,\n    stride_kp: tl.constexpr,\n    stride_kbs: tl.constexpr,\n    stride_kh: tl.constexpr,\n    stride_kd: tl.constexpr,\n    stride_vp: tl.constexpr,\n    stride_vbs: tl.constexpr,\n    stride_vh: tl.constexpr,\n    stride_vd: tl.constexpr,\n    stride_ok: tl.constexpr,\n    stride_obs: tl.constexpr,\n    stride_oh: tl.constexpr,\n    stride_od: tl.constexpr,\n    stride_boffb,\n    kv_group_num: tl.constexpr,\n    seq_len: tl.constexpr,\n    window_size: tl.constexpr,\n    head_size: tl.constexpr,\n    head_size_v: tl.constexpr,\n    num_heads_q: tl.constexpr,\n    logit_softcapping: tl.constexpr,\n    shared_kv: tl.constexpr,\n    SPLIT_K: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_H: tl.constexpr,\n    BLOCK_DMODEL1: 
tl.constexpr,\n):\n    \"\"\"First step kernel of split k attention.\"\"\"\n    cur_batch = tl.program_id(2)\n    tile_id = tl.program_id(0)\n    split_k_id = tl.program_id(1)\n\n    HEADS_PER_REQ: tl.constexpr = kv_group_num * seq_len\n    TILES_PER_GROUP: tl.constexpr = tl.cdiv(HEADS_PER_REQ, BLOCK_H)\n    subtile_id = tile_id % TILES_PER_GROUP\n    cur_kv_head = tile_id // TILES_PER_GROUP\n    offs_h = subtile_id * BLOCK_H + tl.arange(0, BLOCK_H)\n    cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num\n    cur_token = cur_batch * seq_len + offs_h // kv_group_num\n\n    mask_h = cur_head < cur_kv_head * kv_group_num + kv_group_num\n    mask_h = mask_h & (cur_token < cur_batch * seq_len + seq_len)\n    mask_h = mask_h & (cur_head < num_heads_q)\n\n    q_seqlen = 1\n    kv_seqlen = tl.load(cache_seqlens_ptr + cur_batch)\n    if kv_seqlen <= 0:\n        return\n    history_len = kv_seqlen - q_seqlen\n    if alibi_slopes_ptr is not None:\n        alibi_slopes = tl.load(alibi_slopes_ptr + cur_head, mask=mask_h, other=1.0) * tl_log2(math.e)\n    else:\n        alibi_slopes = None\n\n    # initialize offsets\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    mask_d = offs_d < head_size\n    offs_d = offs_d % head_size\n    offs_dv = tl.arange(0, BLOCK_DV)\n    mask_dv = offs_dv < head_size_v\n    offs_dv = offs_dv % head_size_v\n    off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs)\n    off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs)\n\n    off_q = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd)\n    q = tl.load(q_ptr + off_q, mask=mask_h[:, None] & mask_d[None, :], other=0)\n\n    k_ptrs = k_ptr + off_k\n    v_ptrs = v_ptr + off_v\n\n    if BLOCK_DMODEL1 != 0:\n        offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1)\n        mask_d1 = offs_d1 < head_size\n        offs_d1 = offs_d1 % 
head_size\n        off_q1 = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd)\n        q1 = tl.load(q_ptr + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0)\n        off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs)\n        k1_ptrs = k_ptr + off_k1\n\n    block_offset_ptrs = page_table_ptr + cur_batch * stride_boffb\n\n    # initialize pointer to m and l\n    m_i = tl.zeros([BLOCK_H], dtype=tl.float32) - float('inf')\n    l_i = tl.zeros([BLOCK_H], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)\n\n    num_total_blocks = tl.cdiv(kv_seqlen, BLOCK_N)\n    BLOCK_PER_CTA = tl.cdiv(num_total_blocks, SPLIT_K)\n    kv_len_per_prog = BLOCK_PER_CTA * BLOCK_N\n    loop_start = kv_len_per_prog * split_k_id\n    loop_end = tl.minimum(loop_start + kv_len_per_prog, kv_seqlen)\n\n    # load block offset\n    # dirty\n    start_block_id = loop_start // BLOCK_N\n    if window_size > 0:\n        start_block_id = tl.maximum(history_len - window_size, loop_start) // BLOCK_N\n        kv_min_loc = tl.maximum(history_len - window_size, 0)\n\n    loop_start = start_block_id * BLOCK_N\n    block_offset_ptrs += start_block_id\n    for start_n in range(loop_start, loop_end, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        b_offset = tl.load(block_offset_ptrs)\n        block_offset_ptrs += 1\n\n        # -- compute qk ----\n        k = tl.load(k_ptrs + b_offset * stride_kp)\n        if BLOCK_DMODEL1 != 0:\n            k1 = tl.load(k1_ptrs + b_offset * stride_kp)\n\n        if shared_kv:\n            v = k.trans(1, 0)\n        else:\n            v = tl.load(v_ptrs + b_offset * stride_vp)\n\n        qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32)\n        qk += tl.dot(q, k)\n        if BLOCK_DMODEL1 != 0:\n            qk += tl.dot(q1, k1)\n        qk *= sm_scale\n        if logit_softcapping > 0.0:\n            qk = qk / 
logit_softcapping\n            qk = tanh(qk)\n            qk = qk * logit_softcapping\n        qk = qk * tl_log2(math.e)\n        # NOTE: inf - inf = nan, and NaN leads to errors\n        if start_n + BLOCK_N > history_len or window_size > 0:\n            qk_mask = history_len >= (start_n + offs_n)\n            if window_size > 0:\n                qk_mask = qk_mask & ((start_n + offs_n) >= kv_min_loc)\n            qk = tl.where(\n                qk_mask[None, :],\n                qk,\n                -float('inf'),\n            )\n\n        if alibi_slopes_ptr is not None:\n            relative_pos = kv_seqlen - start_n - offs_n[None, :]\n            bias = -tl.abs(relative_pos).to(tl.float32) * alibi_slopes[:, None]\n            qk += bias\n\n        # -- compute p, m_i and l_i\n        m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n        p = tl_exp2(qk - m_i_new[:, None])\n        alpha = tl_exp2(m_i - m_i_new)\n        l_i_new = alpha * l_i + tl.sum(p, 1)\n\n        # -- update output accumulator --\n        # scale acc\n        acc = acc * alpha[:, None]\n\n        # update acc\n        p, v = _convert_pv(p, v)\n        acc += tl.dot(p, v)\n        # update m_i and l_i\n        l_i = l_i_new\n        m_i = m_i_new\n\n    # initialize pointers to output\n    if loop_end > loop_start:\n        off_acc = (cur_token[:, None] * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh +\n                   offs_dv[None, :] * stride_od)\n        tl.store(acc_out_ptr + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :])\n\n    off_meta = (cur_token * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v)\n    tl.store(acc_out_ptr + off_meta, m_i, mask=mask_h)\n    tl.store(acc_out_ptr + off_meta + 1, l_i, mask=mask_h)\n\n\n@triton.jit\ndef _fwd_grouped_split_quant_kernel(\n    q_ptr,\n    k_ptr,\n    v_ptr,\n    KScalesZeros,\n    VScalesZeros,\n    sm_scale,\n    cache_seqlens_ptr,\n    page_table_ptr,\n    acc_out_ptr,\n    
alibi_slopes_ptr,\n    stride_qbs: tl.constexpr,\n    stride_qh: tl.constexpr,\n    stride_qd: tl.constexpr,\n    stride_kp: tl.constexpr,\n    stride_kbs: tl.constexpr,\n    stride_kh: tl.constexpr,\n    stride_kd: tl.constexpr,\n    stride_vp: tl.constexpr,\n    stride_vbs: tl.constexpr,\n    stride_vh: tl.constexpr,\n    stride_vd: tl.constexpr,\n    stride_kszp: tl.constexpr,\n    stride_kszbs: tl.constexpr,\n    stride_kszh: tl.constexpr,\n    stride_kszd: tl.constexpr,\n    stride_vszp: tl.constexpr,\n    stride_vszbs: tl.constexpr,\n    stride_vszh: tl.constexpr,\n    stride_vszd: tl.constexpr,\n    quant_policy: tl.constexpr,\n    stride_ok: tl.constexpr,\n    stride_obs: tl.constexpr,\n    stride_oh: tl.constexpr,\n    stride_od: tl.constexpr,\n    stride_boffb,\n    kv_group_num: tl.constexpr,\n    window_size: tl.constexpr,\n    head_size: tl.constexpr,\n    head_size_v: tl.constexpr,\n    num_heads_q: tl.constexpr,\n    logit_softcapping: tl.constexpr,\n    SPLIT_K: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_H: tl.constexpr,\n    BLOCK_DMODEL1: tl.constexpr,\n):\n    \"\"\"First step kernel of split k attention.\n\n    Args:\n        stride_xp: stride of page num dim\n        stride_xbs: stride of block size dim\n        stride_h: stride of head num dim\n        stride_d: stride of head size dim\n    \"\"\"\n    cur_batch = tl.program_id(2)\n    cur_kv_head = tl.program_id(0)\n    split_k_id = tl.program_id(1)\n\n    if BLOCK_H < kv_group_num:\n        HEAD_PER_CTA: tl.constexpr = BLOCK_H\n    else:\n        HEAD_PER_CTA: tl.constexpr = kv_group_num\n    cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)\n    mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA\n    mask_h = mask_h & (cur_head < num_heads_q)\n    if BLOCK_H < kv_group_num:\n        cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num\n\n    q_seqlen = 1\n    kv_seqlen = 
tl.load(cache_seqlens_ptr + cur_batch)\n    if kv_seqlen <= 0:\n        return\n    history_len = kv_seqlen - q_seqlen\n    if alibi_slopes_ptr is not None:\n        alibi_slopes = tl.load(alibi_slopes_ptr + cur_head, mask=mask_h, other=1.0) * tl_log2(math.e)\n    else:\n        alibi_slopes = None\n\n    # initialize offsets\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_dsz = tl.arange(0, 1)\n    mask_d = offs_d < head_size\n    offs_d = offs_d % head_size\n    offs_dv = tl.arange(0, BLOCK_DV)\n    mask_dv = offs_dv < head_size_v\n    offs_dv = offs_dv % head_size_v\n    off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs)\n    off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs)\n    off_ksz = (cur_kv_head * stride_kszh + offs_dsz[:, None] * stride_kszd + offs_n[None, :] * stride_kszbs)\n    off_vsz = (cur_kv_head * stride_vszh + offs_dsz[None, :] * stride_vszd + offs_n[:, None] * stride_vszbs)\n\n    off_q = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd)\n    q = tl.load(q_ptr + off_q, mask=mask_h[:, None] & mask_d[None, :], other=0)\n\n    ksz_ptrs = KScalesZeros + off_ksz\n    vsz_ptrs = VScalesZeros + off_vsz\n\n    if BLOCK_DMODEL1 != 0:\n        offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1)\n        mask_d1 = offs_d1 < head_size\n        offs_d1 = offs_d1 % head_size\n        off_q1 = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd)\n        q1 = tl.load(q_ptr + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0)\n        off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs)\n\n    block_offset_ptrs = page_table_ptr + cur_batch * stride_boffb\n\n    # initialize pointer to m and l\n    m_i = tl.zeros([BLOCK_H], dtype=tl.float32) - float('inf')\n    l_i = tl.zeros([BLOCK_H], dtype=tl.float32)\n    if 
quant_policy == 4:\n        if BLOCK_DMODEL1 != 0:\n            offs_d1 = BLOCK_DMODEL // 2 + tl.arange(0, BLOCK_DMODEL1)\n            shift_k1d = (offs_d1 // (head_size // 2) * 4)[:, None]\n            offs_d1 = offs_d1 % (head_size // 2)\n            off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs)\n        offs_d = tl.arange(0, BLOCK_DMODEL) % (head_size // 2)\n        shift_kd = (tl.arange(0, BLOCK_DMODEL) // (head_size // 2) * 4)[:, None]\n        off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs)\n        offs_dv = tl.arange(0, BLOCK_DV * 2) % head_size_v\n        shift_vd = (tl.arange(0, BLOCK_DV * 2) // head_size_v * 4)\n        off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs)\n        acc = tl.zeros([BLOCK_H, BLOCK_DV * 2], dtype=tl.float32)  # v head_dim packed\n        mask_dv = tl.arange(0, BLOCK_DV * 2) < (head_size_v * 2)\n        offs_dv = tl.arange(0, BLOCK_DV * 2) % (head_size_v * 2)\n    else:\n        acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)\n\n    num_total_blocks = tl.cdiv(kv_seqlen, BLOCK_N)\n    BLOCK_PER_CTA = tl.cdiv(num_total_blocks, SPLIT_K)\n    kv_len_per_prog = BLOCK_PER_CTA * BLOCK_N\n    loop_start = kv_len_per_prog * split_k_id\n    loop_end = tl.minimum(loop_start + kv_len_per_prog, kv_seqlen)\n\n    # load block offset\n    # dirty\n    start_block_id = loop_start // BLOCK_N\n    if window_size > 0:\n        start_block_id = tl.maximum(history_len - window_size, loop_start) // BLOCK_N\n        kv_min_loc = tl.maximum(history_len - window_size, 0)\n\n    loop_start = start_block_id * BLOCK_N\n    for start_n in range(loop_start, loop_end, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        b_offset = tl.load(block_offset_ptrs + start_n // BLOCK_N)\n\n        # -- compute qk ----\n        # k = tl.load(k_ptrs + b_offset * stride_kp)\n        k = tl.load(k_ptr + 
off_k + b_offset * stride_kp)\n        if quant_policy == 4:\n            k = (k >> shift_kd) & 0x0F\n        ks = tl.load(ksz_ptrs + b_offset * stride_kszp)\n        kz = tl.load(ksz_ptrs + b_offset * stride_kszp + 1)\n        if BLOCK_DMODEL1 != 0:\n            k1 = tl.load(k_ptr + off_k1 + b_offset * stride_kp)\n            if quant_policy == 4:\n                k1 = (k1 >> shift_k1d) & 0x0F\n            k1 = ((k1 - kz) * ks).to(q.dtype)\n\n        if quant_policy == 4:\n            v = tl.load(v_ptr + off_v + b_offset * stride_vp)\n            v = (v >> shift_vd) & 0x0F\n        else:\n            v = tl.load(v_ptr + off_v + b_offset * stride_vp)\n        vs = tl.load(vsz_ptrs + b_offset * stride_vszp)\n        vz = tl.load(vsz_ptrs + b_offset * stride_vszp + 1)\n\n        k = ((k - kz) * ks).to(q.dtype)\n        v = ((v - vz) * vs).to(q.dtype)\n        qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32)\n        qk += tl.dot(q, k)\n        if BLOCK_DMODEL1 != 0:\n            qk += tl.dot(q1, k1)\n        qk *= sm_scale\n        if logit_softcapping > 0.0:\n            qk = qk / logit_softcapping\n            qk = tanh(qk)\n            qk = qk * logit_softcapping\n        qk = qk * tl_log2(math.e)\n        # NOTE: inf - inf = nan, and nan will leads to error\n        if start_n + BLOCK_N > history_len or window_size > 0:\n            qk_mask = history_len >= (start_n + offs_n)\n            if window_size > 0:\n                qk_mask = qk_mask & ((start_n + offs_n) >= kv_min_loc)\n            qk = tl.where(\n                qk_mask[None, :],\n                qk,\n                -float('inf'),\n            )\n\n        if alibi_slopes_ptr is not None:\n            relative_pos = kv_seqlen - start_n - offs_n[None, :]\n            bias = -tl.abs(relative_pos).to(tl.float32) * alibi_slopes[:, None]\n            qk += bias\n\n        # -- compute p, m_i and l_i\n        m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n        p = tl_exp2(qk - m_i_new[:, None])\n        
alpha = tl_exp2(m_i - m_i_new)\n        l_i_new = alpha * l_i + tl.sum(p, 1)\n\n        # -- update output accumulator --\n        # scale acc\n        acc = acc * alpha[:, None]\n\n        # update acc\n        p, v = _convert_pv(p, v)\n        acc += tl.dot(p, v)\n        # update m_i and l_i\n        l_i = l_i_new\n        m_i = m_i_new\n\n    # initialize pointers to output\n    if loop_end > loop_start:\n        off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh +\n                   offs_dv[None, :] * stride_od)\n        tl.store(acc_out_ptr + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :])\n\n    if quant_policy == 4:\n        off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v * 2)\n    else:\n        off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v)\n    tl.store(acc_out_ptr + off_meta, m_i, mask=mask_h)\n    tl.store(acc_out_ptr + off_meta + 1, l_i, mask=mask_h)\n\n\n@triton.jit\ndef _reduce_split_kernel(\n    acc_ptr,\n    out_ptr,\n    sinks_ptr,\n    stride_ak,\n    stride_abs,\n    stride_ah,\n    stride_ad,\n    stride_obs,\n    stride_oh,\n    stride_od,\n    head_size_v: tl.constexpr,\n    SPLIT_K: tl.constexpr,\n    BLOCK_DV: tl.constexpr,\n):\n    \"\"\"Second step kernel of split k attention.\"\"\"\n    cur_batch = tl.program_id(1)\n    cur_head = tl.program_id(0)\n\n    # initialize offsets\n    offs_dv = tl.arange(0, BLOCK_DV)\n    offs_k = tl.arange(0, SPLIT_K)\n    mask_dv = offs_dv < head_size_v\n\n    offs_acc = (cur_batch * stride_abs + cur_head * stride_ah + offs_k[:, None] * stride_ak +\n                offs_dv[None, :] * stride_ad)\n    offs_mi = (cur_batch * stride_abs + cur_head * stride_ah + stride_ak * offs_k + head_size_v)\n\n    m_k = tl.load(acc_ptr + offs_mi)\n    l_k = tl.load(acc_ptr + offs_mi + 1)\n    acc_k = tl.load(acc_ptr + offs_acc, mask=mask_dv[None, :] & (m_k[:, None] > 
-float('inf')), other=0.0)\n\n    m_max = tl.max(m_k, 0)\n    alpha = tl_exp2(m_k - m_max)\n    acc_k = acc_k * alpha[:, None]\n    l_k = l_k * alpha\n\n    acc = tl.sum(acc_k, 0)\n    l_sum = tl.sum(l_k, 0)\n\n    if sinks_ptr is not None:\n        sink = tl.load(sinks_ptr + cur_head).to(l_sum.dtype)\n        l_sum = l_sum + tl.exp2(sink * tl_log2(math.e) - m_max)\n    acc = acc / l_sum\n\n    out_offs = (cur_batch * stride_obs + cur_head * stride_oh + offs_dv * stride_od)\n    tl.store(out_ptr + out_offs, acc, mask=mask_dv)\n\n\n@triton.jit\ndef _convert_pv(p, v):\n    \"\"\"Convert pv.\"\"\"\n    p = p.to(v.dtype)\n    return p, v\n\n\n_nv_cap = None\n\n\ndef _kernel_meta_default(BLOCK_DMODEL: int, BLOCK_H: int):\n    \"\"\"Kernel meta default.\"\"\"\n    return 4, 2\n\n\ndef _kernel_meta_sm8x(BLOCK_DMODEL: int, BLOCK_H: int):\n    \"\"\"Kernel meta for sm8x.\"\"\"\n    num_stages = 2\n    if BLOCK_DMODEL * BLOCK_H > 8192:\n        num_warps = 8\n    else:\n        num_warps = 4\n    return num_warps, num_stages\n\n\ndef _kernel_meta_sm9x(BLOCK_DMODEL: int, BLOCK_H: int):\n    \"\"\"Kernel meta for sm9x.\"\"\"\n    num_warps = 4\n    if BLOCK_DMODEL * BLOCK_H > 4096:\n        num_stages = 2\n    else:\n        num_stages = 3\n    return num_warps, num_stages\n\n\ndef _get_split_k(device_idx: int, head_grid: int, batch_size: int, num_warps: int):\n    \"\"\"Get split k.\"\"\"\n    props = get_device_props(device_idx)\n    num_sm = props['multi_processor_count']\n    # estimated occupancy 12.5%\n    warps_per_sm = props['warps_per_sm'] // 8\n    cta_per_sm = triton.cdiv(warps_per_sm, num_warps)\n    cta_per_device = num_sm * cta_per_sm\n\n    SPLIT_K = triton.cdiv(cta_per_device // head_grid, triton.next_power_of_2(batch_size))\n    SPLIT_K = 1 << (SPLIT_K.bit_length() - 1)\n    max_split = 1 << (num_sm.bit_length() - 1)\n    SPLIT_K = max(min(SPLIT_K, max_split), 4)\n    return SPLIT_K\n\n\ndef flash_attn_with_kvcache(\n    q: Tensor,\n    k_cache: Tensor,\n    
v_cache: Tensor,\n    cache_seqlens: Tensor,\n    page_table: Tensor,\n    cu_seqlens_q: Tensor = None,  # not used, for align with fa\n    max_seqlen_q: int = None,\n    softmax_scale: float = None,\n    causal: bool = False,  # not used, for align with fa\n    window_size: int = None,\n    softcap: float = None,\n    scheduler_metadata: Tensor = None,  # not used, for align with fa\n    # args not in fa\n    alibi_slopes: Tensor = None,\n    k_scales_zeros: Tensor = None,\n    v_scales_zeros: Tensor = None,\n    quant_policy: Literal[0, 4, 8] = 0,\n    sinks: Tensor = None,\n    kv_layout: str = 'bshd',\n):\n    \"\"\"Paged Attention forward.\n\n    Note that this kernel is decoding-only\n    \"\"\"\n\n    global _nv_cap\n    if _nv_cap is None:\n        _nv_cap = torch.cuda.get_device_capability()\n\n    if kv_layout == 'bshd':\n        b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)\n    elif kv_layout == 'bhsd':\n        b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)\n    else:\n        raise RuntimeError('Unsupported layout.')\n\n    if window_size is None:\n        window_size = -1\n    elif isinstance(window_size, Sequence):\n        window_size = window_size[0]\n\n    if softcap is None:\n        softcap = -1.0\n\n    shared_kv = k_cache.data_ptr() == v_cache.data_ptr()\n\n    def _get_block_d(Lk):\n        \"\"\"Get block d.\"\"\"\n        BLOCK_DMODEL = triton.next_power_of_2(Lk)\n        BLOCK_DMODEL1 = 0\n        if BLOCK_DMODEL != Lk:\n            BLOCK_DMODEL = BLOCK_DMODEL // 2\n            BLOCK_DMODEL1 = max(16, triton.next_power_of_2(Lk - BLOCK_DMODEL))\n        BLOCK_DV = triton.next_power_of_2(Lv)\n        return BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV\n\n    # shape constraints\n    Lq, Lk, Lv = q.shape[-1], k_cache.shape[d_dim], v_cache.shape[d_dim]\n    if quant_policy == 4:\n        assert Lq == Lk * 2\n        o = q.new_empty(q.shape[:-1] + (Lv * 2, ))\n    else:\n        assert Lq == Lk\n        o = q.new_empty(q.shape[:-1] + (Lv, ))\n\n    if 
softmax_scale is None:\n        softmax_scale = 1.0 / (Lq**0.5)\n    batch, head = cache_seqlens.shape[0], q.shape[-2]\n    num_tokens = q.shape[-3]\n    num_kv_heads = k_cache.shape[h_dim]\n    kv_group_num = head // num_kv_heads\n\n    if sinks is not None:\n        assert sinks.is_contiguous()\n        assert sinks.numel() == head\n\n    BLOCK = k_cache.size(s_dim)\n    assert BLOCK >= 16\n    if Lq > 512 and BLOCK > 32:\n        logger.warning(f'`head_dim={Lq}` and `block_size={BLOCK}` '\n                       'might lead to bad performance. '\n                       'Please reduce `block_size`.')\n\n    valid = num_tokens % batch == 0\n    assert valid, 'we only support decoding paged attention.'\n    seq_len = num_tokens // batch\n    if max_seqlen_q is not None:\n        assert max_seqlen_q == seq_len, 'we only support decoding paged attention.'\n\n    BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)\n    HEADS_PER_REQ = kv_group_num * seq_len\n    BLOCK_H = max(16, min(BLOCK, triton.next_power_of_2(HEADS_PER_REQ)))\n    TILES_PER_GROUP = triton.cdiv(HEADS_PER_REQ, BLOCK_H)\n    grid_1 = TILES_PER_GROUP * num_kv_heads\n\n    if _nv_cap[0] < 8:\n        num_warps, num_stages = _kernel_meta_default(BLOCK_DMODEL, BLOCK_H)\n    elif _nv_cap[0] < 9:\n        num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DMODEL, BLOCK_H)\n    else:\n        num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DMODEL, BLOCK_H)\n\n    SPLIT_K = _get_split_k(q.device.index, grid_1, batch, num_warps)\n\n    if quant_policy != 4:\n        acc = q.new_empty(num_tokens, head, SPLIT_K, Lv + 2, dtype=torch.float32)\n    else:\n        acc = q.new_empty(num_tokens, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)\n\n    grid = (\n        grid_1,\n        SPLIT_K,\n        batch,\n    )\n\n    if quant_policy > 0:\n        _fwd_grouped_split_quant_kernel[grid](q,\n                                              k_cache,\n                                              v_cache,\n       
                                       k_scales_zeros,\n                                              v_scales_zeros,\n                                              softmax_scale,\n                                              cache_seqlens,\n                                              page_table,\n                                              acc,\n                                              alibi_slopes,\n                                              stride_qbs=q.stride(-3),\n                                              stride_qh=q.stride(-2),\n                                              stride_qd=q.stride(-1),\n                                              stride_kp=k_cache.stride(b_dim),\n                                              stride_kbs=k_cache.stride(s_dim),\n                                              stride_kh=k_cache.stride(h_dim),\n                                              stride_kd=k_cache.stride(d_dim),\n                                              stride_vp=v_cache.stride(b_dim),\n                                              stride_vbs=v_cache.stride(s_dim),\n                                              stride_vh=v_cache.stride(h_dim),\n                                              stride_vd=v_cache.stride(d_dim),\n                                              stride_kszp=k_scales_zeros.stride(b_dim),\n                                              stride_kszbs=k_scales_zeros.stride(s_dim),\n                                              stride_kszh=k_scales_zeros.stride(h_dim),\n                                              stride_kszd=k_scales_zeros.stride(d_dim),\n                                              stride_vszp=v_scales_zeros.stride(b_dim),\n                                              stride_vszbs=v_scales_zeros.stride(s_dim),\n                                              stride_vszh=v_scales_zeros.stride(h_dim),\n                                              stride_vszd=v_scales_zeros.stride(d_dim),\n               
                               quant_policy=quant_policy,\n                                              stride_ok=acc.stride(-2),\n                                              stride_obs=acc.stride(-4),\n                                              stride_oh=acc.stride(-3),\n                                              stride_od=acc.stride(-1),\n                                              stride_boffb=page_table.stride(0),\n                                              kv_group_num=kv_group_num,\n                                              window_size=window_size,\n                                              head_size=Lq,\n                                              head_size_v=Lv,\n                                              num_heads_q=head,\n                                              logit_softcapping=softcap,\n                                              SPLIT_K=SPLIT_K,\n                                              BLOCK_DMODEL=BLOCK_DMODEL,\n                                              BLOCK_DV=BLOCK_DV,\n                                              BLOCK_N=BLOCK,\n                                              BLOCK_H=BLOCK_H,\n                                              BLOCK_DMODEL1=BLOCK_DMODEL1,\n                                              num_warps=num_warps,\n                                              num_stages=num_stages)\n\n    else:\n        _fwd_grouped_split_kernel[grid](q,\n                                        k_cache,\n                                        v_cache,\n                                        softmax_scale,\n                                        cache_seqlens,\n                                        page_table,\n                                        acc,\n                                        alibi_slopes,\n                                        stride_qbs=q.stride(-3),\n                                        stride_qh=q.stride(-2),\n                                        
stride_qd=q.stride(-1),\n                                        stride_kp=k_cache.stride(b_dim),\n                                        stride_kbs=k_cache.stride(s_dim),\n                                        stride_kh=k_cache.stride(h_dim),\n                                        stride_kd=k_cache.stride(d_dim),\n                                        stride_vp=v_cache.stride(b_dim),\n                                        stride_vbs=v_cache.stride(s_dim),\n                                        stride_vh=v_cache.stride(h_dim),\n                                        stride_vd=v_cache.stride(d_dim),\n                                        stride_ok=acc.stride(-2),\n                                        stride_obs=acc.stride(-4),\n                                        stride_oh=acc.stride(-3),\n                                        stride_od=acc.stride(-1),\n                                        stride_boffb=page_table.stride(0),\n                                        kv_group_num=kv_group_num,\n                                        seq_len=seq_len,\n                                        window_size=window_size,\n                                        head_size=Lk,\n                                        head_size_v=Lv,\n                                        num_heads_q=head,\n                                        logit_softcapping=softcap,\n                                        shared_kv=shared_kv,\n                                        SPLIT_K=SPLIT_K,\n                                        BLOCK_DMODEL=BLOCK_DMODEL,\n                                        BLOCK_DV=BLOCK_DV,\n                                        BLOCK_N=BLOCK,\n                                        BLOCK_H=BLOCK_H,\n                                        BLOCK_DMODEL1=BLOCK_DMODEL1,\n                                        num_warps=num_warps,\n                                        num_stages=num_stages)\n\n    num_warps = 2\n    grid = (head, 
num_tokens)\n    if quant_policy == 4:\n        Lv *= 2\n        BLOCK_DV *= 2\n    _reduce_split_kernel[grid](acc,\n                               o,\n                               sinks,\n                               stride_ak=acc.stride(2),\n                               stride_abs=acc.stride(0),\n                               stride_ah=acc.stride(1),\n                               stride_ad=acc.stride(3),\n                               stride_obs=o.stride(0),\n                               stride_oh=o.stride(1),\n                               stride_od=o.stride(2),\n                               SPLIT_K=SPLIT_K,\n                               head_size_v=Lv,\n                               BLOCK_DV=BLOCK_DV,\n                               num_warps=num_warps,\n                               num_stages=1)\n    return o\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/rms_norm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch import Tensor\n\nfrom .utils import get_device_props\n\n\n@triton.jit\ndef _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):\n    \"\"\"Compute rms norm.\"\"\"\n    xf = x.to(tl.float32)\n\n    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)\n    out = xf * tl.math.rsqrt(var + eps)\n    out = w * out.to(x.dtype)\n    return out\n\n\n@triton.jit\ndef add_rms_norm_kernel(input, weight, residual, output, out_residual, num_feats, num_groups, stride_ib, stride_ih,\n                        stride_id: tl.constexpr, stride_rb, stride_rh, stride_rd: tl.constexpr, stride_ob, stride_oh,\n                        stride_od: tl.constexpr, stride_rob, stride_roh, stride_rod: tl.constexpr,\n                        has_residual: tl.constexpr, eps: tl.constexpr, N_COLS: tl.constexpr, BLOCK_N: tl.constexpr,\n                        NUM_STAGES: tl.constexpr):\n    \"\"\"Rms norm kernel.\"\"\"\n    prog_id = tl.program_id(0)\n    prog_stride = tl.num_programs(0)\n    offsets = tl.arange(0, BLOCK_N)\n    mask = offsets < N_COLS\n\n    w = tl.load(weight + offsets, mask=mask)\n\n    x_ptrs = input + offsets * stride_id\n    res_ptrs = residual + offsets * stride_rd\n    out_res_ptrs = out_residual + offsets * stride_rod\n    out_ptrs = output + offsets * stride_od\n    for idx in tl.range(prog_id, num_feats, prog_stride, num_stages=NUM_STAGES):\n        batch_id = idx // num_groups\n        head_id = idx % num_groups\n        cur_x_ptrs = x_ptrs + batch_id * stride_ib + head_id * stride_ih\n        cur_res_ptrs = res_ptrs + batch_id * stride_rb + head_id * stride_rh\n        cur_out_ptrs = out_ptrs + batch_id * stride_ob + head_id * stride_oh\n        cur_out_res_ptrs = out_res_ptrs + batch_id * stride_rob + head_id * stride_roh\n        x = tl.load(cur_x_ptrs, mask=mask)\n        if has_residual:\n            res = tl.load(cur_res_ptrs, 
mask=mask)\n            x += res\n            tl.store(cur_out_res_ptrs, x, mask=mask)\n        out = _compute_rms_norm(x, w, eps, N_COLS)\n        tl.store(cur_out_ptrs, out, mask=mask)\n\n\ndef _unsqueeze_to_3d(tensor: Tensor) -> Tensor:\n    \"\"\"Unsqueeze tensor to 3d.\"\"\"\n    if tensor.dim() == 3:\n        return tensor\n    elif tensor.dim() == 2:\n        return tensor.unsqueeze(0)\n    elif tensor.dim() == 1:\n        return tensor.unsqueeze(0).unsqueeze(0)\n    else:\n        raise ValueError(f'Unsupported tensor dim {tensor.dim()}')\n\n\ndef _squeeze_to_origin_dim(tensor: Tensor, origin_dim: int) -> Tensor:\n    \"\"\"Squeeze tensor to origin dim.\"\"\"\n    if origin_dim == 3:\n        return tensor\n    elif origin_dim == 2:\n        return tensor.squeeze(0)\n    elif origin_dim == 1:\n        return tensor.squeeze(0).squeeze(0)\n    else:\n        raise ValueError(f'Unsupported origin dim {origin_dim}')\n\n\ndef rms_norm(hidden_states: Tensor,\n             weight: Tensor,\n             eps: float = 1e-6,\n             residual: Tensor = None,\n             out: Tensor = None,\n             out_residual: Tensor = None):\n    \"\"\"Rms norm.\"\"\"\n    assert hidden_states.dim() <= 3\n    assert weight.stride(-1) == 1\n    feat_size = weight.shape[0]\n    assert hidden_states.size(-1) == feat_size\n\n    origin_dim = hidden_states.dim()\n    if out is None:\n        out = torch.empty_like(hidden_states)\n    has_residual = residual is not None\n    if has_residual:\n        if out_residual is None:\n            out_residual = torch.empty_like(residual)\n    else:\n        residual = hidden_states\n        out_residual = out\n\n    shape = hidden_states.shape\n    assert residual.shape == shape\n    assert out.shape == shape\n    assert out_residual.shape == shape\n\n    hidden_states = _unsqueeze_to_3d(hidden_states)\n    residual = _unsqueeze_to_3d(residual)\n    out = _unsqueeze_to_3d(out)\n    out_residual = _unsqueeze_to_3d(out_residual)\n\n    
num_feats = hidden_states.numel() // hidden_states.size(-1)\n\n    BLOCK_N = triton.next_power_of_2(feat_size)\n\n    props = get_device_props(hidden_states.device.index)\n    num_sm = props['multi_processor_count']\n    warps_per_sm = props['warps_per_sm']\n    blocks_per_sm = props['blocks_per_sm']\n    num_warps = min(triton.cdiv(BLOCK_N, 2048), 4)\n    cta_per_sm = min(blocks_per_sm, warps_per_sm // num_warps)\n    cta_per_device = num_sm * cta_per_sm\n    num_stages = 1\n\n    grid = (min(num_feats, cta_per_device), )\n    add_rms_norm_kernel[grid](\n        hidden_states,\n        weight,\n        residual,\n        out,\n        out_residual,\n        num_feats=num_feats,\n        num_groups=hidden_states.size(1),\n        stride_ib=hidden_states.stride(0),\n        stride_ih=hidden_states.stride(1),\n        stride_id=hidden_states.stride(2),\n        stride_rb=residual.stride(0),\n        stride_rh=residual.stride(1),\n        stride_rd=residual.stride(2),\n        stride_ob=out.stride(0),\n        stride_oh=out.stride(1),\n        stride_od=out.stride(2),\n        stride_rob=out_residual.stride(0),\n        stride_roh=out_residual.stride(1),\n        stride_rod=out_residual.stride(2),\n        has_residual=has_residual,\n        eps=eps,\n        N_COLS=feat_size,\n        BLOCK_N=BLOCK_N,\n        NUM_STAGES=num_stages,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n\n    out = _squeeze_to_origin_dim(out, origin_dim)\n    out_residual = _squeeze_to_origin_dim(out_residual, origin_dim)\n    if has_residual:\n        return out, out_residual\n    return out\n\n\nif __name__ == '__main__':\n    import time\n\n    def torch_forward(hidden_states, weight, variance_epsilon=1e-6):\n        \"\"\"Pytorch forward.\"\"\"\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * 
torch.rsqrt(variance + variance_epsilon)\n        return weight * hidden_states.to(input_dtype)\n\n    def test_rms_norm(bsz, ctx_len, feat_len, dtype):\n        \"\"\"Test rms norm.\"\"\"\n        input = torch.empty((bsz, ctx_len, feat_len), dtype=dtype, device='cuda').normal_(mean=0., std=0.5).contiguous()\n        weight = torch.empty((feat_len), dtype=dtype, device='cuda').normal_(mean=0., std=0.5).contiguous()\n        triton_output = rms_norm(hidden_states=input, weight=weight)\n        torch_output = torch_forward(hidden_states=input, weight=weight)\n        assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=0)\n\n        N_REPEATS = 20\n\n        t0 = time.time()\n        for _ in range(N_REPEATS):\n            torch_forward(hidden_states=input, weight=weight)\n\n        t1 = time.time()\n        for _ in range(N_REPEATS):\n            rms_norm(hidden_states=input, weight=weight)\n        t2 = time.time()\n\n        torch_cost = (t1 - t0) / N_REPEATS * 1000\n        triton_cost = (t2 - t1) / N_REPEATS * 1000\n        print('input {} weight {} dtype {}\\n  torch {:.3f} triton {:.3f} (ms)\\n'.format(\n            input.shape, weight.shape, dtype, torch_cost, triton_cost))\n\n    test_rms_norm(1, 8128, 5120, torch.float16)\n    test_rms_norm(1, 8128, 5120, torch.float32)\n    test_rms_norm(1, 992, 128, torch.float16)\n    test_rms_norm(1, 65537, 128, torch.float32)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\n\nimport torch\nimport triton\nfrom packaging import version\n\nWARPS_PER_SM = {\n    (8, 0): 64,\n    (8, 6): 48,\n    (8, 7): 48,\n    (8, 9): 48,\n    (9, 0): 64,\n    (10, 0): 64,\n    (10, 1): 48,\n    (11, 0): 48,\n    (12, 0): 48,\n}\n\nBLOCKS_PER_SM = {\n    (8, 0): 32,\n    (8, 6): 16,\n    (8, 7): 16,\n    (8, 9): 24,\n    (9, 0): 32,\n    (10, 0): 32,\n    (10, 1): 24,\n    (11, 0): 24,\n    (12, 0): 24,\n}\n\nTRITON_VERSION = version.parse(triton.__version__)\n\n\n@functools.lru_cache\ndef get_device_props(device=None):\n    if device is None:\n        device = torch.cuda.current_device()\n\n    props = torch.cuda.get_device_properties(device)\n\n    warps_per_sm = WARPS_PER_SM.get((props.major, props.minor), 32)\n    blocks_per_sm = BLOCKS_PER_SM.get((props.major, props.minor), warps_per_sm // 2)\n    out = dict(\n        multi_processor_count=props.multi_processor_count,\n        warps_per_sm=warps_per_sm,\n        blocks_per_sm=blocks_per_sm,\n    )\n    return out\n\n\ndef is_cuda():\n    return triton.runtime.driver.active.get_current_target().backend == 'cuda'\n\n\n@functools.lru_cache\ndef supports_tma():\n    ret = is_cuda() and torch.cuda.get_device_capability()[0] >= 9\n    if not ret:\n        return False\n\n    VALID_VERSION = version.parse('3.4.0')\n    return TRITON_VERSION >= VALID_VERSION\n\n\nif supports_tma():\n    from triton.tools.tensor_descriptor import TensorDescriptor  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modify from: https://github.com/vllm-project/vllm\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .activation import silu_and_mul\nfrom .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize, moe_reduce\nfrom .w8a8_triton_kernels import per_token_quant_int8\n\n\ndef get_cuda_autotune_config():\n    return [\n        triton.Config({\n            'BLOCK_SIZE_M': 128,\n            'BLOCK_SIZE_N': 128,\n            'BLOCK_SIZE_K': 32,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n                      num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_M': 64,\n            'BLOCK_SIZE_N': 256,\n            'BLOCK_SIZE_K': 32,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n                      num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_M': 64,\n            'BLOCK_SIZE_N': 128,\n            'BLOCK_SIZE_K': 64,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=4,\n                      num_warps=4),\n        triton.Config({\n            'BLOCK_SIZE_M': 128,\n            'BLOCK_SIZE_N': 128,\n            'BLOCK_SIZE_K': 128,\n            'GROUP_SIZE_M': 1,\n        },\n                      num_stages=3,\n                      num_warps=8),\n    ]\n\n\n@triton.autotune(\n    configs=get_cuda_autotune_config(),\n    key=['N', 'K', 'M_NP2'],\n)\n@triton.jit\ndef fused_moe_w8a8_kernel(\n    A,\n    A_scale,\n    B,\n    B_scale,\n    C,\n    SortedIdx,\n    ExpStart,\n    ExpEnd,\n    N: tl.constexpr,\n    K: tl.constexpr,\n    stride_am: tl.constexpr,\n    stride_ak: tl.constexpr,\n    stride_be: tl.constexpr,\n    stride_bn: tl.constexpr,\n    stride_bk: tl.constexpr,\n    stride_bse: tl.constexpr,\n    stride_cm: tl.constexpr,\n    stride_cn: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: 
tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n    M_NP2: tl.constexpr,\n    top_k: tl.constexpr,\n    expert_offset: tl.constexpr,\n    reindex_a: tl.constexpr,\n    reindex_c: tl.constexpr,\n    ACCUMULATOR_DTYPE: tl.constexpr,\n):\n    \"\"\"Fused moe kernel.\"\"\"\n    exp_id = tl.program_id(1)\n    pid = tl.program_id(0)\n\n    exp_start = tl.load(ExpStart + exp_id + expert_offset)\n    exp_end = tl.load(ExpEnd + exp_id + expert_offset)\n    M = exp_end - exp_start\n    if M <= 0:\n        return\n\n    num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n\n    if GROUP_SIZE_M == 1:\n        pid_m = pid % num_pid_m\n        pid_n = pid // num_pid_m\n    else:\n        num_pid_in_group = GROUP_SIZE_M * num_pid_n\n        group_id = pid // num_pid_in_group\n        first_pid_m = group_id * GROUP_SIZE_M\n        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n        pid_m = first_pid_m + (pid % group_size_m)\n        pid_n = (pid % num_pid_in_group) // group_size_m\n\n    if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:\n        return\n\n    offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    mask_sid = offs_sid < exp_end\n    sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)\n\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    if reindex_a:\n        offs_am = sid // top_k\n    else:\n        offs_am = offs_sid\n    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    as_ptrs = A_scale + offs_am\n    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n\n    # deepseek has 160 experts, exp index would overflow int32\n    exp_id = exp_id.to(tl.int64)\n    exp_off = stride_be * exp_id\n    b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n    bs_ptrs = B_scale + exp_id * stride_bse + offs_bn\n\n    accumulator = 
tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)\n\n    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n        a = tl.load(a_ptrs, mask=mask_sid[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n        accumulator = tl.dot(a, b, acc=accumulator, out_dtype=ACCUMULATOR_DTYPE)\n        a_ptrs += BLOCK_SIZE_K * stride_ak\n        b_ptrs += BLOCK_SIZE_K * stride_bk\n\n    ascale = tl.load(as_ptrs, mask=mask_sid)\n    bscale = tl.load(bs_ptrs)\n    c = accumulator.to(ascale.dtype)\n    c = c * ascale[:, None] * bscale[None, :]\n\n    c = c.to(C.dtype.element_ty)\n\n    if reindex_c:\n        offs_cm = sid\n    else:\n        offs_cm = offs_sid\n    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]\n    tl.store(c_ptrs, c, mask=mask_sid[:, None])\n\n\ndef fused_moe_w8a8_kernel_launcher(\n    A: torch.Tensor,\n    A_scale: torch.Tensor,\n    B: torch.Tensor,\n    B_scale: torch.Tensor,\n    C: torch.Tensor,\n    sorted_idx: torch.Tensor,\n    exp_start: torch.Tensor,\n    exp_end: torch.Tensor,\n    top_k: int = 1,\n    num_tokens: int = None,\n    expert_offset: int = 0,\n    reindex_a: bool = True,\n    reindex_c: bool = True,\n):\n    \"\"\"Fused moe kernel launcher.\"\"\"\n\n    if num_tokens is None:\n        num_tokens = A.size(0)\n    M_NP2 = triton.next_power_of_2(num_tokens)\n    M_NP2 = max(64, M_NP2)\n    E, N, K = B.shape\n\n    assert A_scale.is_contiguous()\n    assert B_scale.is_contiguous()\n    accumulator_dtype = tl.float32 if A.is_floating_point() else tl.int32\n\n    def _grid_fn(META):\n        grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), E)\n        return grid\n\n    A = A.flatten(0, -2)\n    C = C.flatten(0, -2)\n\n    grid = _grid_fn\n    fused_moe_w8a8_kernel[grid](\n        A,\n        A_scale,\n        B,\n        B_scale,\n        C,\n        sorted_idx,\n        
exp_start,\n        exp_end,\n        N=N,\n        K=K,\n        stride_am=A.stride(0),\n        stride_ak=A.stride(1),\n        stride_be=B.stride(0),\n        stride_bn=B.stride(1),\n        stride_bk=B.stride(2),\n        stride_bse=B_scale.stride(0),\n        stride_cm=C.stride(0),\n        stride_cn=C.stride(1),\n        top_k=top_k,\n        expert_offset=expert_offset,\n        reindex_a=reindex_a,\n        reindex_c=reindex_c,\n        M_NP2=M_NP2,\n        ACCUMULATOR_DTYPE=accumulator_dtype,\n    )\n\n\ndef fused_moe_w8a8(input: torch.Tensor,\n                   input_scale: torch.Tensor,\n                   w1: torch.Tensor,\n                   w1_scale: torch.Tensor,\n                   w2: torch.Tensor,\n                   w2_scale: torch.Tensor,\n                   topk_weights: torch.Tensor,\n                   topk_ids: torch.Tensor,\n                   topk: int,\n                   out_dtype: torch.dtype = torch.float16,\n                   quant_dtype: torch.dtype = torch.int8,\n                   expert_offset: int = 0,\n                   num_experts: int = None,\n                   renormalize: bool = False) -> torch.Tensor:\n    \"\"\"Fused moe.\"\"\"\n    device = input.device\n    M = input.size(0)\n    E, N, _ = w1.shape\n    if num_experts is None:\n        num_experts = E\n    full_exp = num_experts == E\n\n    topk_weights = _renormalize(topk_weights, renormalize)\n    sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)\n\n    intermediate_cache1 = _make_intermediate((M, topk, N), dtype=out_dtype, device=device, zeros=not full_exp)\n    # gate and up\n    fused_moe_w8a8_kernel_launcher(\n        input,\n        input_scale,\n        w1,\n        w1_scale,\n        intermediate_cache1,\n        sorted_idx=sorted_idx,\n        exp_start=exp_start,\n        exp_end=exp_end,\n        top_k=topk,\n        num_tokens=M,\n        expert_offset=expert_offset,\n        reindex_a=True,\n        reindex_c=False,\n    )\n\n    
# activate\n    unflat_size = intermediate_cache1.shape[:-1]\n    intermediate_cache1 = intermediate_cache1.flatten(0, -2)\n    gate_cache = silu_and_mul(intermediate_cache1)\n    del intermediate_cache1\n    gate_cache = gate_cache.unflatten(0, unflat_size)\n    gate_cache, gate_scale = per_token_quant_int8(gate_cache, 1e-7, quant_dtype=quant_dtype)\n\n    intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]), dtype=out_dtype, device=device, zeros=not full_exp)\n    # down\n    fused_moe_w8a8_kernel_launcher(\n        gate_cache,\n        gate_scale,\n        w2,\n        w2_scale,\n        intermediate_cache2,\n        sorted_idx=sorted_idx,\n        exp_start=exp_start,\n        exp_end=exp_end,\n        top_k=1,\n        num_tokens=M,\n        expert_offset=expert_offset,\n        reindex_a=False,\n        reindex_c=True,\n    )\n\n    ret = moe_reduce(intermediate_cache2, topk_weights)\n    return ret\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nimport torch.nn.functional as F\nimport triton\nimport triton.language as tl\nfrom packaging import version\n\nfrom ..default.w8a8_kernels import per_channel_quant\n\nTRITON_VERSION = version.parse(triton.__version__)\nif TRITON_VERSION >= version.parse('3.0.0'):\n    tl_round = tl.extra.cuda.libdevice.round\nelse:\n    tl_round = tl.math.round\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\n            'BLOCK_M': 128,\n            'BLOCK_N': 256,\n            'BLOCK_K': 128,\n        }, num_stages=3, num_warps=8),\n        triton.Config({\n            'BLOCK_M': 256,\n            'BLOCK_N': 128,\n            'BLOCK_K': 128,\n        }, num_stages=3, num_warps=8)\n    ],\n    key=['N', 'K'],\n)\n@triton.jit(do_not_specialize=['M'])\ndef _linear(\n    A,\n    B,\n    C,\n    M,\n    N,\n    K,\n    stride_am,\n    stride_ak,\n    stride_bk,\n    stride_bn,\n    stride_cm,\n    stride_cn,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n    rms_scale_ptr,\n    linear_scale_ptr,\n    ACCUMULATOR_DTYPE: tl.constexpr,\n):\n    \"\"\"Triton-accelerated function used to perform linear operations (dot\n    product) on input tensors `A` and `B`, and store the result in output\n    tensor `C`.\n\n    The function applies auto-tuning for optimal performance and uses Just-in- Time compilation.\n    \"\"\"\n\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(M, BLOCK_M)\n    num_pid_n = tl.cdiv(N, BLOCK_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + (pid % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n    offs_bn = (pid_n * BLOCK_N + tl.arange(0, 
BLOCK_N)) % N\n    offs_k = tl.arange(0, BLOCK_K)\n    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)\n    for k in range(0, tl.cdiv(K, BLOCK_K)):\n        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None)\n        accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)\n        a_ptrs += BLOCK_K * stride_ak\n        b_ptrs += BLOCK_K * stride_bk\n    c = accumulator.to(tl.float32)\n\n    rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]\n    linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]\n    c = c * rms_scale * linear_scale\n\n    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n    tl.store(c_ptrs, c, mask=c_mask)\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({\n            'BLOCK_M': 128,\n            'BLOCK_N': 256,\n            'BLOCK_K': 128,\n        }, num_stages=3, num_warps=8),\n        triton.Config({\n            'BLOCK_M': 256,\n            'BLOCK_N': 128,\n            'BLOCK_K': 128,\n        }, num_stages=3, num_warps=8)\n    ],\n    key=['N', 'K'],\n)\n@triton.jit(do_not_specialize=['M'])\ndef _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,\n                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,\n                rms_scale_ptr, linear_scale_ptr, ACCUMULATOR_DTYPE: tl.constexpr):\n    \"\"\"Triton-accelerated function used to perform a linear operation (dot\n    product) on input tensors `A` and `B`, with addition of 
residual.\n\n    The result is stored in tensor `C`. The function applies auto-tuning for optimal performance and uses Just-in-Time\n    compilation.\n    \"\"\"\n\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(M, BLOCK_M)\n    num_pid_n = tl.cdiv(N, BLOCK_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + (pid % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M\n    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N\n    offs_k = tl.arange(0, BLOCK_K)\n    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)\n    for k in range(0, tl.cdiv(K, BLOCK_K)):\n        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None)\n        accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)\n        a_ptrs += BLOCK_K * stride_ak\n        b_ptrs += BLOCK_K * stride_bk\n    c = accumulator.to(tl.float32)\n\n    rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]\n    linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]\n    c = c * rms_scale * linear_scale\n    c = c.to(residual_ptr.dtype.element_ty)\n\n    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n    residual_ptrs = (residual_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :])\n    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n    residual = tl.load(residual_ptrs, mask=c_mask, other=0.)\n    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n    
tl.store(c_ptrs, c + residual, mask=c_mask)\n\n\ndef matmul_kernel_dynamic_quant(a, b, rms_scale, linear_scale, residual=None, bias=None, output_dtype=torch.float16):\n    \"\"\"This function performs matrix multiplication with dynamic quantization.\n\n    It takes two input tensors `a` and `b`, scales them with `rms_scale` and `linear_scale`, and optionally adds a\n    `residual` tensor and a `bias`. The output is returned in the specified `output_dtype`.\n    \"\"\"\n\n    assert a.shape[-1] == b.shape[-1]\n    assert b.ndim == 2 and b.is_contiguous()\n    M = a.numel() // a.shape[-1]\n    N, K = b.shape\n    c_shape = a.shape[:-1] + (N, )\n    if residual is not None:\n        assert residual.shape == c_shape\n        assert residual.is_contiguous()\n    c = a.new_empty(c_shape, dtype=output_dtype)\n    accumulator_dtype = tl.float32 if a.is_floating_point() else tl.int32\n\n    def grid(META):\n        return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )\n\n    if residual is not None:\n        _linear_add[grid](a,\n                          b,\n                          c,\n                          residual,\n                          M,\n                          N,\n                          K,\n                          a.stride(-2),\n                          a.stride(-1),\n                          b.stride(1),\n                          b.stride(0),\n                          c.stride(-2),\n                          c.stride(-1),\n                          GROUP_SIZE_M=8,\n                          rms_scale_ptr=rms_scale,\n                          linear_scale_ptr=linear_scale,\n                          ACCUMULATOR_DTYPE=accumulator_dtype)\n    else:\n        _linear[grid](a,\n                      b,\n                      c,\n                      M,\n                      N,\n                      K,\n                      a.stride(-2),\n                      a.stride(-1),\n                      b.stride(1),\n               
       b.stride(0),\n                      c.stride(-2),\n                      c.stride(-1),\n                      GROUP_SIZE_M=8,\n                      rms_scale_ptr=rms_scale,\n                      linear_scale_ptr=linear_scale,\n                      ACCUMULATOR_DTYPE=accumulator_dtype)\n    if bias is not None:\n        c += bias\n\n    return c\n\n\n@triton.jit\ndef _per_token_quant_int8(\n        y_ptr,\n        y_q_ptr,\n        y_s_ptr,\n        y_stride: tl.constexpr,\n        yq_stride: tl.constexpr,\n        N,  # number of columns in X\n        eps: tl.constexpr,  # epsilon to avoid division by zero\n        BLOCK: tl.constexpr,\n        Q_MAX: tl.constexpr,\n        IS_FLOATING_POINT: tl.constexpr,  # True for floating point dtype\n):\n    \"\"\"A Triton-accelerated function to perform per-token quantization on a\n    tensor.\n\n    This function converts the tensor values into signed 8-bit integers.\n    \"\"\"\n    # Map the program id to the row of X and Y it should compute.\n    row = tl.program_id(0)\n    y_ptr += row * y_stride\n    y_q_ptr += row * yq_stride\n    y_s_ptr += row\n\n    cols = tl.arange(0, BLOCK)  # N <= BLOCK\n    mask = cols < N\n\n    y = tl.load(y_ptr + cols, mask=mask, other=0.).to(tl.float32)\n    # Quant\n    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)\n    y_s = _absmax / Q_MAX\n    y_q = y / y_s\n    if not IS_FLOATING_POINT:\n        y_q = tl_round(y_q).to(tl.int8)\n\n    tl.store(y_q_ptr + cols, y_q, mask=mask)\n    tl.store(y_s_ptr, y_s)\n\n\ndef per_token_quant_int8(x, eps, quant_dtype=torch.int8):\n    \"\"\"Function to perform per-token quantization on an input tensor `x`.\n\n    It converts the tensor values into signed 8-bit integers and returns the quantized tensor along with the scaling\n    factor used for quantization.\n    \"\"\"\n    qdtype_info = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)\n    q_max = qdtype_info.max\n    x_q = torch.empty_like(x, 
device=x.device, dtype=quant_dtype)\n    M = x.numel() // x.shape[-1]\n    N = x.shape[-1]\n    x_s = torch.empty(x.shape[:-1] + (1, ), device=x.device, dtype=torch.float32)\n    BLOCK = triton.next_power_of_2(N)\n    # heuristics for number of warps\n    num_warps = min(max(BLOCK // 256, 1), 8)\n\n    if x.dim() > 2:\n        x = x.flatten(0, -2)\n    assert x.stride(-1) == 1\n    # enqueue kernel\n    _per_token_quant_int8[(M, )](x,\n                                 x_q,\n                                 x_s,\n                                 y_stride=x.stride(-2),\n                                 yq_stride=x_q.stride(-2),\n                                 N=N,\n                                 eps=eps,\n                                 BLOCK=BLOCK,\n                                 Q_MAX=q_max,\n                                 IS_FLOATING_POINT=quant_dtype.is_floating_point,\n                                 num_warps=num_warps)\n\n    return x_q, x_s\n\n\n@triton.jit\ndef _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):\n    \"\"\"Compute rms norm.\"\"\"\n    xf = x.to(tl.float32)\n\n    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)\n    out = xf * tl.math.rsqrt(var + eps)\n    out = (w * out).to(x.dtype)\n    return out\n\n\n@triton.jit\ndef rms_norm_quant_kernel(\n    input,\n    weight,\n    output,\n    out_scale,\n    input_row_stride: tl.constexpr,\n    eps: tl.constexpr,\n    N_COLS: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    Q_MIN: tl.constexpr,\n    Q_MAX: tl.constexpr,\n    IS_FLOATING_POINT: tl.constexpr,\n):\n    \"\"\"Rms norm kernel.\"\"\"\n    prog_id = tl.program_id(0)\n    offsets = tl.arange(0, BLOCK_N)\n\n    w = tl.load(weight + offsets, mask=offsets < N_COLS)\n\n    x_ptr = input + prog_id * input_row_stride\n    x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)\n    out = _compute_rms_norm(x, w, eps, N_COLS)\n\n    scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX\n    out_s_ptr = out_scale + prog_id\n    
tl.store(out_s_ptr, scale)\n    out = out / scale\n    if not IS_FLOATING_POINT:\n        out = tl_round(out)\n    out = tl.clamp(out, Q_MIN, Q_MAX)\n    out_ptr = output + prog_id * input_row_stride\n    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)\n\n\n@triton.jit\ndef add_rms_norm_quant_kernel(\n    input,\n    weight,\n    residual,\n    output,\n    out_scale,\n    out_residual,\n    input_row_stride: tl.constexpr,\n    residual_row_stride: tl.constexpr,\n    eps: tl.constexpr,\n    N_COLS: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    Q_MIN: tl.constexpr,\n    Q_MAX: tl.constexpr,\n    IS_FLOATING_POINT: tl.constexpr,\n):\n    \"\"\"Rms norm kernel.\"\"\"\n    prog_id = tl.program_id(0)\n    offsets = tl.arange(0, BLOCK_N)\n\n    w = tl.load(weight + offsets, mask=offsets < N_COLS)\n\n    x_ptr = input + prog_id * input_row_stride\n    x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)\n\n    res_ptr = residual + prog_id * residual_row_stride\n    res = tl.load(res_ptr + offsets, mask=offsets < N_COLS)\n\n    new_x = x + res\n    out_res_ptr = out_residual + prog_id * residual_row_stride\n    tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS)\n\n    out = _compute_rms_norm(new_x, w, eps, N_COLS)\n\n    scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX\n    out_s_ptr = out_scale + prog_id\n    tl.store(out_s_ptr, scale)\n    out = out / scale\n    if not IS_FLOATING_POINT:\n        out = tl_round(out)\n    out = tl.clamp(out, Q_MIN, Q_MAX)\n    out_ptr = output + prog_id * input_row_stride\n    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)\n\n\ndef rms_norm_dynamic_quant(x, w, eps, residual=None, quant_dtype=torch.int8):\n    \"\"\"Performs RMS normalization with dynamic quantization.\n\n    The function reshapes the input tensor `x`, creates an empty tensor `y` with the same shape as `x`, and calculates\n    RMS normalization on the reshaped `x` using a Triton kernel `rms_norm_quant_kernel`.\n    \"\"\"\n    qdtype_info = 
torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)\n    y = torch.empty_like(x, dtype=quant_dtype)\n    scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32)\n\n    feat_size = w.shape[0]\n    seq_len = x.numel() // x.size(-1)\n    input_stride = x.stride(-2)\n    BLOCK_N = triton.next_power_of_2(feat_size)\n    grid = (seq_len, )\n\n    if residual is None:\n        rms_norm_quant_kernel[grid](x,\n                                    w,\n                                    y,\n                                    scale,\n                                    input_row_stride=input_stride,\n                                    eps=eps,\n                                    N_COLS=feat_size,\n                                    BLOCK_N=BLOCK_N,\n                                    Q_MIN=qdtype_info.min,\n                                    Q_MAX=qdtype_info.max,\n                                    IS_FLOATING_POINT=quant_dtype.is_floating_point,\n                                    num_warps=4,\n                                    num_stages=2)\n        return y, scale\n    else:\n        out_residual = torch.empty_like(x)\n        res_stride = residual.stride(-2)\n        add_rms_norm_quant_kernel[grid](x,\n                                        w,\n                                        residual,\n                                        y,\n                                        scale,\n                                        out_residual,\n                                        input_row_stride=input_stride,\n                                        residual_row_stride=res_stride,\n                                        eps=eps,\n                                        N_COLS=feat_size,\n                                        BLOCK_N=BLOCK_N,\n                                        Q_MIN=qdtype_info.min,\n                                        Q_MAX=qdtype_info.max,\n                                        
IS_FLOATING_POINT=quant_dtype.is_floating_point,\n                                        num_warps=4,\n                                        num_stages=2)\n        return y, scale, out_residual\n\n\ndef test_rms_and_linear(x, rms_weight, linear_weight, output_dtype=torch.float16, quant_dtype=torch.int8, eps=1e-5):\n    \"\"\"Test quantized rms norm and quantized linear layer.\"\"\"\n\n    def rms_norm_torch(x, w, eps):\n        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        x = x * torch.rsqrt(variance + eps)\n        return w * x\n\n    def linear_torch(x, b):\n        return F.linear(x, b)\n\n    linear_weight_quant, linear_scale = per_channel_quant(linear_weight, quant_dtype)\n\n    rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps, quant_dtype=quant_dtype)\n    assert rms_out.shape == x.shape and rms_scale.shape[:-1] == x.shape[:-1]\n    linear_out = matmul_kernel_dynamic_quant(rms_out,\n                                             linear_weight_quant,\n                                             rms_scale,\n                                             linear_scale,\n                                             output_dtype=output_dtype)\n\n    rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()\n    linear_out_torch = linear_torch(rms_out_torch, linear_weight)\n    print(f'linear_out.abs().mean() = {linear_out.abs().mean()}')\n    print(f'linear_out_torch.abs().mean() = {linear_out_torch.abs().mean()}')\n    print('perchannel error: ', (linear_out - linear_out_torch).abs().mean())\n    cos = torch.nn.CosineSimilarity(0)\n    print('Output cos', cos(linear_out.flatten().to(torch.float32), linear_out_torch.flatten().to(torch.float32)))\n\n\ndef test_per_token_quant(x, eps, quant_dtype=torch.int8):\n    \"\"\"Test per-token quantization.\"\"\"\n\n    def per_token_quant_int8_torch(x, eps, quant_dtype):\n        qdtype_info = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)\n\n    
    _absmax = torch.clamp(x.abs().max(dim=-1, keepdim=True)[0], min=eps)\n        x_s = _absmax / qdtype_info.max\n        x_q = x / x_s\n        if not quant_dtype.is_floating_point:\n            x_q = x_q.round()\n        x_q = torch.clamp(x_q, min=qdtype_info.min, max=qdtype_info.max)\n        return x_q, x_s\n\n    x_q, x_s = per_token_quant_int8(x, eps, quant_dtype=quant_dtype)\n    x_q_torch, x_s_torch = per_token_quant_int8_torch(x, eps, quant_dtype=quant_dtype)\n    assert x_q.shape == x_q_torch.shape and x_s.shape == x_s_torch.shape\n    cos = torch.nn.CosineSimilarity(0)\n    print('x_q cos', cos(x_q.flatten().to(torch.float32), x_q_torch.flatten().to(torch.float32)))\n    print('x_s cos', cos(x_s.flatten().to(torch.float32), x_s_torch.flatten().to(torch.float32)))\n\n\ndef bench_rms_and_linear(M: int, provider: str, dtype: torch.dtype = torch.float16, eps: float = 1e-5):\n    \"\"\"Benchmark rms and linear.\"\"\"\n\n    def rms_norm_torch(x, w, eps):\n        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)\n        x = x * torch.rsqrt(variance + eps)\n        return w * x\n\n    def linear_torch(x, b):\n        return F.linear(x, b)\n\n    N = 4096\n    K = 4096\n\n    x_shape = (M, K)\n    rms_w_shape = (x_shape[-1], )\n    rms_weight = torch.randn(rms_w_shape, dtype=dtype, device='cuda', requires_grad=True)\n    x = torch.randn(x_shape, dtype=dtype, device='cuda')\n    linear_weight = torch.randn((N, K), dtype=dtype, device='cuda', requires_grad=True)\n\n    if provider == 'torch_fp16':\n        rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()\n\n        def y_fwd():\n            linear_torch(rms_out_torch, linear_weight)\n    else:\n        if provider == 'triton_int8':\n            quant_dtype = torch.int8\n        elif provider == 'triton_fp8_e4m3':\n            quant_dtype = torch.float8_e4m3fn\n        elif provider == 'triton_fp8_e5m2':\n            quant_dtype = torch.float8_e5m2\n\n        linear_weight_quant, linear_scale 
= per_channel_quant(linear_weight, quant_dtype)\n\n        alpha = max(x.max().abs(), x.min().abs())\n        if quant_dtype.is_floating_point:\n            qdtype_info = torch.finfo(quant_dtype)\n        else:\n            qdtype_info = torch.iinfo(quant_dtype)\n        rms_scale = alpha / qdtype_info.max\n        rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps, quant_dtype=quant_dtype)\n\n        def y_fwd():\n\n            matmul_kernel_dynamic_quant(rms_out, linear_weight_quant, rms_scale, linear_scale, output_dtype=dtype)\n\n    quantiles = [0.5, 0.2, 0.8]\n    ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)\n\n    def perf(ms):\n        return 2 * M * N * K * 1e-12 / (ms * 1e-3)\n\n    return perf(ms), perf(max_ms), perf(min_ms)\n\n\nif __name__ == '__main__':\n    torch.manual_seed(0)\n    device_map = torch.cuda.get_device_capability()\n    is_fp8_supported = device_map[0] >= 9\n    dtype = torch.float16\n    # test (bs, seq_len, dim) x (dim, out_dim)\n    x = torch.randn((2, 2048, 4096), dtype=dtype, device='cuda')\n    rms_weight = torch.randn((4096, ), dtype=dtype, device='cuda', requires_grad=True)\n\n    linear_weight = torch.randn((11008, 4096), dtype=dtype, device='cuda', requires_grad=True)\n    test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8)\n    if is_fp8_supported:\n        test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.float8_e4m3fn)\n        test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.float8_e5m2)\n\n    # test (M, K) x (K, N)\n    x = torch.randn((4, 4096), dtype=dtype, device='cuda')\n    rms_weight = torch.randn((4096, ), dtype=dtype, device='cuda', requires_grad=True)\n\n    linear_weight = torch.randn((2048, 4096), dtype=dtype, device='cuda', requires_grad=True)\n    test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8)\n    if is_fp8_supported:\n        test_rms_and_linear(x, rms_weight, 
linear_weight, quant_dtype=torch.float8_e4m3fn)\n        test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.float8_e5m2)\n\n    # test per-token quant\n    x = torch.randn((4, 2048, 4096), dtype=dtype, device='cuda')\n    eps = 1e-7\n    test_per_token_quant(x, eps, quant_dtype=torch.int8)\n    if is_fp8_supported:\n        test_per_token_quant(x, eps, quant_dtype=torch.float8_e4m3fn)\n        test_per_token_quant(x, eps, quant_dtype=torch.float8_e5m2)\n\n    # benchmark triton kernels\n    line_vals = ['triton_int8', 'torch_fp16']\n    line_names = ['triton_int8', 'torch_fp16']\n\n    if is_fp8_supported:\n        line_vals += ['triton_fp8_e4m3', 'triton_fp8_e5m2']\n        line_names += ['triton_fp8_e4m3', 'triton_fp8_e5m2']\n    config = triton.testing.Benchmark(x_names=['M'],\n                                      x_vals=[1, 16, 32, 64, 128, 256] + [512 * i * 2 for i in range(1, 5)],\n                                      line_arg='provider',\n                                      line_vals=line_vals,\n                                      line_names=line_names,\n                                      styles=[('blue', '-'), ('green', '-'), ('orange', '-'), ('black', '-'),\n                                              ('yellow', '-')],\n                                      ylabel='TFLOPS',\n                                      plot_name='bench-triton',\n                                      args={\n                                          'dtype': torch.float16,\n                                      })\n    bench_funch = (triton.testing.perf_report(config))(bench_rms_and_linear)\n    bench_funch.run(print_data=True)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/default/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .multinomial_sampling import multinomial_sampling\nfrom .w8a8_kernels import per_channel_quant\n\n__all__ = [\n    'multinomial_sampling',\n    'per_channel_quant',\n]\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/default/multinomial_sampling.py",
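    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nfrom torch import LongTensor, Tensor\n\n\ndef multinomial_sampling(scores: Tensor, seeds: LongTensor, offsets: LongTensor, indices: Tensor = None):\n    \"\"\"Sample one index per row of `scores` with `torch.multinomial`.\n\n    `seeds` and `offsets` are accepted to match the signature of the\n    device-specific implementations but are not used here. If `indices` is\n    given, the sampled positions are mapped through it.\n    \"\"\"\n    sampled_index = torch.multinomial(scores, num_samples=1, replacement=True)\n    if indices is None:\n        return sampled_index.view(-1)\n    outputs = torch.gather(indices, dim=1, index=sampled_index)\n    return outputs.view(-1)\n"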
  },
  {
    "path": "lmdeploy/pytorch/kernels/default/w8a8_kernels.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\n\ndef per_channel_quant(x: torch.Tensor, dtype: torch.dtype):\n    \"\"\"Quantize the input tensor 'x' channel-wise (one scale per row) to the\n    given data type.\n\n    Args:\n        x (torch.Tensor): The input tensor to be quantized. Must be a\n            2-dimensional tensor.\n        dtype (torch.dtype): The data type to which the quantized tensor should\n            be converted.\n\n    Returns:\n        tuple: A tuple containing two items: the quantized tensor and\n            the per-row scale (shape ``(x.shape[0], 1)``) used for quantization.\n    \"\"\"\n    assert x.ndim == 2\n    x = x.to(torch.float32)\n    x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]\n    qtype_info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)\n    q_max = qtype_info.max\n    q_min = qtype_info.min\n    scale = x_absmax / q_max\n    x_q = x / scale\n    if not dtype.is_floating_point:\n        x_q = torch.round(x_q)\n    x_q = x_q.clamp(q_min, q_max).to(dtype)\n    return x_q, scale\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dispatcher.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport importlib\nimport inspect\nfrom typing import Callable\n\nfrom lmdeploy.utils import get_logger\n\nfrom ..devices import DeviceContext, get_device_manager\n\nlogger = get_logger('lmdeploy')\n\n\ndef _default_api(*args, **kwargs):\n    \"\"\"Default api.\"\"\"\n    ...\n\n\nclass ParamParser:\n\n    def __init__(self, param: inspect.Parameter) -> None:\n        self.param = param\n\n    def name(self):\n        \"\"\"name.\"\"\"\n        return self.param.name\n\n    def func_arg(self):\n        \"\"\"Func arg.\"\"\"\n        param = self.param\n        name = self.name()\n        kind = param.kind\n        ret = name\n        if kind == inspect.Parameter.VAR_POSITIONAL:\n            ret = f'*{name}'\n        elif kind == inspect.Parameter.VAR_KEYWORD:\n            ret = f'**{name}'\n\n        default = param.default\n        if default != inspect._empty:\n            ret = f'{ret}={default}'\n\n        return ret\n\n    def func_input(self):\n        \"\"\"Func input.\"\"\"\n        param = self.param\n        name = self.name()\n        kind = param.kind\n        ret = name\n        if kind == inspect.Parameter.VAR_POSITIONAL:\n            ret = f'*{name}'\n        elif kind == inspect.Parameter.VAR_KEYWORD:\n            ret = f'**{name}'\n        else:\n            ret = f'{name}={name}'\n        return ret\n\n\nclass FunctionDispatcher:\n\n    def __init__(self, func_name: str):\n        self.device_manager = get_device_manager()\n        self.impl_map: dict[str, Callable] = dict()\n        self.func_name = func_name\n        self.dispatched_func = self.load_and_call\n        self.device_manager.register_context_callback(self.device_callback)\n        self.device_map = {'cuda': 'cuda', 'ascend': 'dlinfer', 'npu': 'dlinfer', 'maca': 'dlinfer', 'camb': 'dlinfer'}\n\n    def device_callback(self, context: DeviceContext):\n        \"\"\"Device context callback.\"\"\"\n        
self.dispatched_func = self.load_and_call\n\n    def load_func(self, device: str):\n        \"\"\"Load function.\"\"\"\n        try:\n            mod = importlib.import_module(f'lmdeploy.pytorch.kernels.{device}')\n            func = getattr(mod, self.func_name)\n            self.impl_map[device] = func\n        except Exception:\n            logger.debug(f'Failed to load <{self.func_name}>'\n                         f' for <{device}>, '\n                         'trying to load the default implementation.')\n            mod = importlib.import_module('lmdeploy.pytorch.kernels.default')\n            if not hasattr(mod, self.func_name):\n                raise RuntimeError(f'<{self.func_name}> has neither a default nor a <{device}>'\n                                   ' implementation.')\n            func = getattr(mod, self.func_name)\n            self.impl_map[device] = func\n\n    def load_and_call(self, *args, **kwargs):\n        \"\"\"Load and call.\"\"\"\n        device = self.device_manager.current_context().device_type\n        if device not in self.impl_map:\n            self.load_func(device)\n        self.dispatched_func = self.impl_map[device]\n        return self.dispatched_func(*args, **kwargs)\n\n    def make_caller(self, api: Callable = _default_api, globals=None):\n        \"\"\"Make call function.\"\"\"\n        signature = inspect.signature(api)\n        params = signature.parameters\n\n        param_parsers = [ParamParser(p) for p in params.values()]\n        func_args = [p.func_arg() for p in param_parsers]\n        func_inputs = [p.func_input() for p in param_parsers]\n        func_args = ', '.join(func_args)\n        func_inputs = ', '.join(func_inputs)\n\n        src = f\"\"\"\ndef {self.func_name}({func_args}):\n    return dispatcher.dispatched_func({func_inputs})\n\"\"\"   # noqa: E501\n\n        scope = dict(dispatcher=self, )\n        if globals is not None:\n            scope.update(globals)\n        exec(src, scope)\n        return 
scope[f'{self.func_name}']\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom ..default import multinomial_sampling, per_channel_quant\nfrom .apply_rotary_pos_emb import apply_rotary_pos_emb\nfrom .awq_kernels import awq_linear\nfrom .fill_kv_cache import fill_kv_cache\nfrom .flash_attention import flash_attention_fwd\nfrom .fused_moe import DlinferMoECommType, DlinferMoeMetadata, fused_moe\nfrom .linear import linear\nfrom .moe_gating_topk_softmax import moe_gating_topk_softmax\nfrom .pagedattention import paged_attention_fwd\nfrom .rms_norm import rms_norm\n\n__all__ = [\n    'rms_norm',\n    'apply_rotary_pos_emb',\n    'awq_linear',\n    'fill_kv_cache',\n    'DlinferMoECommType',\n    'DlinferMoeMetadata',\n    'fused_moe',\n    'paged_attention_fwd',\n    'flash_attention_fwd',\n    'linear',\n    'moe_gating_topk_softmax',\n    'multinomial_sampling',\n    'per_channel_quant',\n]\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/activation.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport dlinfer.ops as ext_ops\nfrom torch import Tensor\n\n\ndef silu_and_mul(input_tensor: Tensor, ) -> Tensor:\n    return ext_ops.silu_and_mul(input_tensor)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional, Tuple\n\nimport dlinfer.ops as ext_ops\nfrom torch import Tensor\n\n\ndef apply_rotary_pos_emb(\n    query_states: Tensor,\n    key_states: Tensor,\n    cos: Tensor,\n    sin: Tensor,\n    q_embed: Optional[Tensor],\n    k_embed: Optional[Tensor],\n) -> Tuple[Tensor, Tensor]:\n    query_states_embed, key_states_embed = \\\n        ext_ops.apply_rotary_pos_emb(query_states,\n                                     key_states,\n                                     cos, sin)\n    if q_embed is None:\n        q_embed = query_states_embed.view(query_states.shape)\n    elif q_embed is not query_states:\n        q_embed.copy_(query_states_embed.view(query_states.shape))\n\n    if k_embed is None:\n        k_embed = key_states_embed.view(key_states.shape)\n    elif k_embed is not key_states:\n        k_embed.copy_(key_states_embed.view(key_states.shape))\n\n    return q_embed, k_embed\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/awq_kernels.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional\n\nimport dlinfer.ops as ext_ops\nfrom torch import Tensor\n\n\ndef awq_linear(x: Tensor,\n               qweight: Tensor,\n               scales: Tensor,\n               qzeros: Tensor,\n               bias: Optional[Tensor] = None,\n               all_reduce: bool = False,\n               group_size: int = 0):\n    return ext_ops.weight_quant_matmul(x.squeeze(0),\n                                       qweight,\n                                       scales,\n                                       offset=qzeros,\n                                       bias=bias,\n                                       all_reduce=all_reduce,\n                                       group_size=group_size).unsqueeze(0)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/fill_kv_cache.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional, Sequence\n\nimport dlinfer.ops as ext_ops\nfrom torch import Tensor\n\n\ndef fill_kv_cache(\n    key_states: Tensor,\n    value_states: Tensor,\n    key_caches: Tensor,\n    value_caches: Tensor,\n    kv_start_indices: Tensor,\n    k_scales_zeros: Sequence[Optional[Tensor]],\n    v_scales_zeros: Sequence[Optional[Tensor]],\n    quant_bits: int = 0,\n):\n    \"\"\"Fill key/value state to cache for paged attention.\"\"\"\n    return ext_ops.fill_kv_cache(key_states,\n                                 value_states,\n                                 key_caches,\n                                 value_caches,\n                                 kv_start_indices,\n                                 k_scales_zeros=k_scales_zeros,\n                                 v_scales_zeros=v_scales_zeros,\n                                 quant_bits=quant_bits)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/flash_attention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport dlinfer.ops as ext_ops\nfrom torch import Tensor\n\n\ndef flash_attention_fwd(\n    query_states: Tensor,\n    key_states: Tensor,\n    value_states: Tensor,\n    attn_output: Tensor,\n    q_start_loc: Tensor,\n    q_seqlens: Tensor,\n    kv_start_loc: Tensor,\n    kv_seqlens: Tensor,\n    num_heads: int,\n    num_kv_heads: int,\n    max_q_seqlen: int = None,\n    window_size: int = None,\n    sm_scale: float = None,\n    logit_softcapping: float = None,\n    causal: bool = True,\n):\n    return ext_ops.prefill_attention(\n        query_states,\n        key_states,\n        value_states,\n        None,\n        None,\n        q_start_loc,\n        q_seqlens,\n        kv_seqlens,\n        max_q_seqlen,\n        num_heads,\n        num_kv_heads,\n        attn_mask=[],\n        softmax_scale=sm_scale,\n        attn_output=attn_output,\n    )\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/fused_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport dlinfer.ops as ext_ops\nfrom dlinfer.utils.type_annotation import MoECommType as DlinferMoECommType  # noqa: F401\nfrom dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata\nfrom torch import Tensor\n\n\ndef fused_moe(\n    hidden_states: Tensor,\n    gate_up_weights: Tensor,\n    down_weights: Tensor,\n    topk_weights: Tensor,\n    topk_ids: Tensor,\n    topk: int,\n    renormalize: bool,\n    moe_metadata: DlinferMoeMetadata,\n):\n    \"\"\"Dlinfer fused moe.\"\"\"\n    return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, topk, renormalize,\n                             moe_metadata)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/fused_rotary_emb.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport dlinfer.ops as ext_ops\nimport torch\nfrom torch import Tensor\n\n\ndef fused_rotary_emb(\n    query_states: Tensor,\n    key_states: Tensor,\n    position_ids: torch.LongTensor,\n    inv_freq: Tensor,\n    scaling_factor: float,\n    out_q: Tensor = None,\n    out_k: Tensor = None,\n    context=None,\n):\n    batch, seqlen, head, dim = query_states.shape\n    num_kv_heads = key_states.shape[-2]\n    query_states_reshaped = query_states.view(batch, seqlen, head, dim)\n    key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim)\n    position_ids = position_ids.squeeze(0).unsqueeze(-1)\n    pos_freq = position_ids / scaling_factor * inv_freq\n    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):\n        cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1).repeat(1, 1, 1, 2).to(query_states.dtype))\n        sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1).repeat(1, 1, 1, 2).to(query_states.dtype))\n        if context:\n            setattr(context, 'cos', cos)\n            setattr(context, 'sin', sin)\n    cached_cos = context.cos if context else cos\n    cached_sin = context.sin if context else sin\n    ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped, cached_cos, cached_sin, None, None)\n    if out_q is None:\n        out_q = query_states\n    else:\n        out_q.copy_(query_states)\n    if out_k is None:\n        out_k = key_states\n    else:\n        out_k.copy_(key_states)\n    return out_q, out_k\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/linear.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional\n\nimport dlinfer.ops as ext_ops\nfrom torch import Tensor\n\n\ndef linear(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, all_reduce: bool = False, group: str = ''):\n    return ext_ops.linear(x, weight, bias=bias, all_reduce=all_reduce, group=group)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport dlinfer.ops as ext_ops\nfrom dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata\nfrom torch import Tensor\n\n\ndef moe_gating_topk_softmax(router_logits: Tensor, topk: int,\n                            moe_metadata: DlinferMoeMetadata) -> Tuple[Tensor, Tensor]:\n    routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(router_logits, topk, moe_metadata)\n    return routing_weights, selected_experts\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/pagedattention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional, Sequence\n\nimport dlinfer.ops as ext_ops\nfrom torch import Tensor\n\n\ndef prefill_attention(\n    query_states: Tensor,\n    key_states: Tensor,\n    value_states: Tensor,\n    attn_output: Tensor,\n    key_cache: Tensor,\n    value_cache: Tensor,\n    block_offsets: Tensor,\n    q_start_loc: Tensor,\n    q_seq_len: Tensor,\n    kv_seq_len: Tensor,\n    cu_seq_lens_kv: Tensor,\n    max_q_seq_len: int,\n    max_kv_seq_len: int,\n    block_size: int,\n    num_q_heads: int,\n    num_kv_heads: int,\n    head_size_v: int,\n    attn_mask: Sequence[Optional[Tensor]],\n    softmax_scale: Optional[float],\n    is_unpaged_prefill: Optional[bool],\n    kv_scales: Optional[Tensor],\n    kv_zeros: Optional[Tensor],\n    quant_bits: Optional[int],\n) -> Tensor:\n    if is_unpaged_prefill:\n        return ext_ops.prefill_attention(\n            query_states,\n            key_states,\n            value_states,\n            key_cache,\n            value_cache,\n            q_start_loc,\n            q_seq_len,\n            kv_seq_len,\n            max_q_seq_len,\n            num_q_heads,\n            num_kv_heads,\n            attn_mask,\n            softmax_scale=softmax_scale,\n            attn_output=attn_output,\n        )\n    else:\n        return ext_ops.paged_prefill_attention(\n            query_states,\n            key_states,\n            value_states,\n            key_cache,\n            value_cache,\n            block_offsets,\n            block_size,\n            q_start_loc,\n            q_seq_len,\n            kv_seq_len,\n            cu_seq_lens_kv,\n            max_q_seq_len,\n            max_kv_seq_len,\n            num_q_heads,\n            num_kv_heads,\n            attn_mask,\n            head_size_v=head_size_v,\n            softmax_scale=softmax_scale,\n            attn_output=attn_output,\n            kv_scales=kv_scales,\n            kv_zeros=kv_zeros,\n         
   quant_bits=quant_bits,\n        )\n\n\ndef paged_token_attention(\n    q,\n    k_cache,\n    v_cache,\n    attn_output,\n    kv_seq_len,\n    max_kv_seq_len,\n    block_offsets,\n    block_size,\n    num_q_heads,\n    num_kv_heads,\n    head_size_v,\n    softmax_scale: Optional[float],\n    kv_scales: Optional[Tensor],\n    kv_zeros: Optional[Tensor],\n    quant_bits: Optional[int],\n):\n    return ext_ops.paged_decode_attention(\n        q,\n        k_cache,\n        v_cache,\n        block_offsets,\n        block_size,\n        kv_seq_len,\n        max_kv_seq_len,\n        num_q_heads,\n        num_kv_heads,\n        head_size_v=head_size_v,\n        softmax_scale=softmax_scale,\n        attn_output=attn_output,\n        kv_scales=kv_scales,\n        kv_zeros=kv_zeros,\n        quant_bits=quant_bits,\n    )\n\n\ndef paged_attention_fwd(\n    query_states: Tensor,\n    key_states: Tensor,\n    value_states: Tensor,\n    attn_output: Tensor,\n    key_cache: Tensor,\n    value_cache: Tensor,\n    block_offsets: Tensor,\n    q_start_loc: Tensor,\n    q_seqlens: Tensor,\n    kv_seqlens: Tensor,\n    cu_seq_lens_kv: Tensor,\n    max_q_seq_len: int,\n    max_kv_seq_len: int,\n    is_decoding: bool,\n    block_size: int,\n    num_heads: int,\n    num_kv_heads: int,\n    v_head_size: int,\n    attn_mask: Sequence[Optional[Tensor]] = (),\n    softmax_scale: Optional[float] = None,\n    is_unpaged_prefill: Optional[bool] = None,\n    kv_scales: Optional[Tensor] = None,\n    kv_zeros: Optional[Tensor] = None,\n    quant_bits: Optional[int] = 0,\n):\n    if not is_decoding:\n        return prefill_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_output,\n            key_cache,\n            value_cache,\n            block_offsets,\n            q_start_loc,\n            q_seqlens,\n            kv_seqlens,\n            cu_seq_lens_kv,\n            max_q_seq_len,\n            max_kv_seq_len,\n            block_size,\n  
          num_heads,\n            num_kv_heads,\n            v_head_size,\n            attn_mask,\n            softmax_scale,\n            is_unpaged_prefill,\n            kv_scales=kv_scales,\n            kv_zeros=kv_zeros,\n            quant_bits=quant_bits,\n        )\n    else:\n        return paged_token_attention(\n            query_states,\n            key_cache,\n            value_cache,\n            attn_output,\n            kv_seqlens,\n            max_kv_seq_len,\n            block_offsets,\n            block_size,\n            num_heads,\n            num_kv_heads,\n            v_head_size,\n            softmax_scale=softmax_scale,\n            kv_scales=kv_scales,\n            kv_zeros=kv_zeros,\n            quant_bits=quant_bits,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/rms_norm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport dlinfer.ops as ext_ops\nfrom torch import Tensor\n\n\ndef rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6, residual: Tensor = None, out: Tensor = None):\n    if residual is None:\n        rms_norm_out = ext_ops.rms_norm(hidden_states, weight, epsilon)\n        if out is None:\n            out = rms_norm_out\n        else:\n            out.copy_(rms_norm_out)\n        return out\n    else:\n        return ext_ops.add_rms_norm(hidden_states, residual, weight, epsilon)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport dlinfer.ops as ext_ops\nimport torch\nfrom torch import Tensor\n\n\ndef dynamic_quant(x: Tensor, quant_dtype: torch.dtype, quant_granularity: str = 'PER_TOKEN'):\n    input_quant, input_scale = ext_ops.dynamic_quant(x, quant_dtype, quant_granularity)\n    return input_quant, input_scale\n\n\ndef linear_w8a8(\n    a: Tensor,\n    b: Tensor,\n    rms_scale: float,\n    linear_scale: float,\n    out_dtype: torch.dtype,\n    quant_dtype: torch.dtype,\n    bias=None,\n):\n    \"\"\"This function performs matrix multiplication with dynamic quantization.\n\n    It takes two input tensors `a` and `b`, scales them with `rms_scale` and `linear_scale`, and optionally adds a\n    `bias`. The output is returned in the specified `out_dtype`.\n    \"\"\"\n    return ext_ops.linear_w8a8(a, b, rms_scale, linear_scale, out_dtype, quant_dtype, bias)\n\n\ndef rms_norm_w8a8(\n    hidden_states: Tensor,\n    weight: Tensor,\n    epsilon: float,\n    quant_dtype: torch.dtype = torch.int8,\n    residual: Tensor = None,\n):\n    \"\"\"Rms norm kernel.\"\"\"\n    if residual is None:\n        return ext_ops.rms_norm_w8a8(hidden_states, weight, epsilon, quant_dtype)\n    else:\n        return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight, epsilon, quant_dtype)\n"
  },
  {
    "path": "lmdeploy/pytorch/kernels/w8a8_triton_kernels.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .dispatcher import FunctionDispatcher\n\nper_channel_quant = FunctionDispatcher('per_channel_quant').make_caller()\n\nmatmul_kernel_dynamic_quant = FunctionDispatcher('matmul_kernel_dynamic_quant').make_caller()\n\nper_token_quant_int8 = FunctionDispatcher('per_token_quant_int8').make_caller()\n\nrms_norm_dynamic_quant = FunctionDispatcher('rms_norm_dynamic_quant').make_caller()\n"
  },
  {
    "path": "lmdeploy/pytorch/messages.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport enum\nfrom collections import defaultdict\nfrom dataclasses import dataclass, field\nfrom typing import TYPE_CHECKING, Any, Dict, List\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor\nfrom lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalInputs\nfrom lmdeploy.utils import get_logger\n\nfrom .block import LogicalTokenBlocks\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.paging.scheduler import Scheduler\n    from lmdeploy.pytorch.paging.seq_states.states import StateBase\n    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy\n    from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy\n\nlogger = get_logger('lmdeploy')\n\n# vlm input type from pipeline\nInputEmbeddingType = List[np.ndarray]\nInputEmbeddingRangeType = List[List[int]]\n\n\n@dataclass\nclass InputEmbeddings:\n    \"\"\"InputEmbeddings.\"\"\"\n    embeddings: np.ndarray\n    start: int\n    end: int\n\n    def move_position(self, offset: int = 0):\n        if offset != 0:\n            self.start += offset\n            self.end += offset\n        return self\n\n\n@dataclass\nclass SamplingParam:\n    \"\"\"Sampling parameter.\"\"\"\n    top_p: float = 1.0\n    top_k: int = 1\n    min_p: float = 0.0\n    temperature: float = 0.8\n    repetition_penalty: float = 1.0\n    ignore_eos: bool = False\n    random_seed: int = None\n    stop_words: List[int] = field(default_factory=list)\n    bad_words: List[int] = field(default_factory=list)\n    max_new_tokens: int = 512\n    min_new_tokens: int = 0\n    response_format: None | str = None\n    logits_processors: None | List[LogitsProcessor] = None\n    out_logits: bool = False\n    out_last_hidden_states: bool = False\n    num_logprobs: int = -1\n    return_routed_experts: bool = 
False\n\n    # ngram\n    repetition_ngram_size: int = 0\n    repetition_ngram_threshold: int = 0\n\n    @classmethod\n    def from_gen_config(cls, gen_config: GenerationConfig):\n        \"\"\"From gen config.\"\"\"\n        min_new_tokens = gen_config.min_new_tokens or 0\n\n        stop_words = gen_config.stop_token_ids or []\n        bad_words = gen_config.bad_token_ids or []\n        if gen_config.ignore_eos:\n            # concatenate instead of `+=` so gen_config.bad_token_ids is not mutated in place\n            bad_words = bad_words + stop_words\n            stop_words = []\n\n        top_k = gen_config.top_k\n        top_p = gen_config.top_p\n        min_p = gen_config.min_p\n        temperature = gen_config.temperature\n        repetition_penalty = gen_config.repetition_penalty\n        max_new_tokens = gen_config.max_new_tokens\n        response_format = gen_config.response_format\n\n        output_logits = gen_config.output_logits\n        if output_logits:\n            if (output_logits != 'all' or gen_config.max_new_tokens > 0):\n                output_logits = None\n                logger.warning('Pytorch Engine only supports output_logits=\"all\"'\n                               ' with max_new_tokens=0')\n        if gen_config.output_last_hidden_state is not None:\n            logger.warning('Pytorch Engine does not support output last hidden states.')\n        if top_p < 0 or top_p > 1.0:\n            logger.warning('`top_p` has to be a float in [0, 1],'\n                           f' but is {top_p}')\n            top_p = 1.0\n        if min_p < 0 or min_p > 1.0:\n            logger.warning('`min_p` has to be a float in [0, 1],'\n                           f' but is {min_p}')\n            min_p = 0.0\n        if temperature == 0:\n            logger.warning('`temperature` is 0, set top_k=1.')\n            temperature = 1.0\n            top_k = 1\n        if temperature < 0:\n            logger.warning('`temperature` has to be a strictly'\n                           f' positive value, but is {temperature}')\n            temperature = 1.0\n        if 
repetition_penalty <= 0:\n            logger.warning('`repetition_penalty` has to be a strictly'\n                           f' positive value, but is {repetition_penalty}')\n            repetition_penalty = 1.0\n        if max_new_tokens < 0:\n            logger.warning('`max_new_tokens` has to be a'\n                           f' non-negative value, but is {max_new_tokens}')\n            max_new_tokens = 512\n        if min_new_tokens < 0 or min_new_tokens > max_new_tokens:\n            logger.warning('`min_new_tokens` has to be '\n                           'an int >= 0 and <= `max_new_tokens`,'\n                           f' but is {min_new_tokens}')\n            min_new_tokens = 0\n        logprobs = gen_config.logprobs\n        if logprobs is None:\n            logprobs = -1\n\n        random_seed = gen_config.random_seed\n        if random_seed is None:\n            import random\n            random_seed = random.getrandbits(64)\n        return SamplingParam(\n            top_p=top_p,\n            top_k=top_k,\n            min_p=min_p,\n            temperature=temperature,\n            repetition_penalty=repetition_penalty,\n            ignore_eos=gen_config.ignore_eos,\n            random_seed=random_seed,\n            stop_words=stop_words,\n            bad_words=bad_words,\n            response_format=response_format,\n            max_new_tokens=max_new_tokens,\n            min_new_tokens=min_new_tokens,\n            logits_processors=gen_config.logits_processors,\n            out_logits=(output_logits is not None),\n            num_logprobs=logprobs,\n            return_routed_experts=gen_config.return_routed_experts,\n            repetition_ngram_size=gen_config.repetition_ngram_size,\n            repetition_ngram_threshold=gen_config.repetition_ngram_threshold,\n        )\n\n\nclass MessageStatus(enum.Enum):\n    \"\"\"Status of a sequence.\"\"\"\n\n    WAITING = enum.auto()\n    READY = enum.auto()\n    STOPPED = enum.auto()\n    RUNNING = 
enum.auto()\n\n    # PD Disaggregation\n    # MIGRATION_WAITING: tags unmigrated requests\n    # in both the prefill and decode engines.\n    # MIGRATION_READY: tags migrating requests\n    # in the decode engine.\n    TO_BE_MIGRATED = enum.auto()\n    MIGRATION_WAITING = enum.auto()\n    MIGRATION_READY = enum.auto()\n    MIGRATION_RUNNING = enum.auto()\n    MIGRATION_DONE = enum.auto()\n\n\nSeqMap = Dict[int, 'SchedulerSequence']\n\n\n@dataclass\nclass SequenceMeta:\n    \"\"\"Metadata shared by all sequences.\"\"\"\n    block_size: int\n    strategy: 'SequenceStrategy' = None\n    sampling_strategy: 'SamplingStrategy' = None\n\n\nclass SequenceManager:\n    \"\"\"Sequence manager.\"\"\"\n\n    def __init__(self, seq_meta: SequenceMeta) -> None:\n        self._seq_map: SeqMap = dict()\n        self._status_seq_map: Dict[MessageStatus, SeqMap] = defaultdict(dict)\n\n        self.seq_meta = seq_meta\n        self._seq_count = 0\n\n    def _new_seq_id(self):\n        seq_id = self._seq_count\n        self._seq_count += 1\n        return seq_id\n\n    def get_all_sequences(self):\n        \"\"\"Get all sequences.\"\"\"\n        return self._seq_map.values()\n\n    def get_sequences(self, states: MessageStatus):\n        \"\"\"Get sequences.\"\"\"\n        return self._status_seq_map[states]\n\n    def num_sequences(self, status: MessageStatus):\n        \"\"\"Num sequences.\"\"\"\n        return len(self.get_sequences(status))\n\n    def add_sequence(self, seq: 'SchedulerSequence'):\n        \"\"\"Add sequence.\"\"\"\n        seq_id = seq.seq_id\n        status = seq.status\n        status_map = self._status_seq_map[status]\n        self._seq_map[seq_id] = seq\n        status_map[seq_id] = seq\n\n    def remove_sequence(self, seq: 'SchedulerSequence'):\n        \"\"\"Remove sequence.\"\"\"\n        seq_id = seq.seq_id\n        status = seq.status\n        status_map = self._status_seq_map[status]\n        self._seq_map.pop(seq_id)\n        
status_map.pop(seq_id)\n\n    def update_sequence_status(self, seq: 'SchedulerSequence', new_status: MessageStatus):\n        \"\"\"Update status.\"\"\"\n        old_status = seq.status\n        if new_status == old_status:\n            return\n        seq_id = seq.seq_id\n        old_status_map = self._status_seq_map[old_status]\n        new_status_map = self._status_seq_map[new_status]\n        # may have been removed by async_end\n        if seq_id in old_status_map:\n            old_status_map.pop(seq_id)\n            new_status_map[seq_id] = seq\n\n\ndef _to_ndarray(token_ids) -> np.ndarray:\n    \"\"\"To ndarray.\"\"\"\n    if isinstance(token_ids, Tensor):\n        token_ids = token_ids.numpy()\n    elif not isinstance(token_ids, np.ndarray):\n        token_ids = np.array(token_ids)\n    if token_ids.ndim == 0:\n        token_ids = token_ids[None]\n    return token_ids\n\n\nclass SchedulerSession:\n    \"\"\"Scheduler session.\"\"\"\n\n    def __init__(self, session_id: int, seq_manager: SequenceManager, scheduler: 'Scheduler') -> None:\n        self.session_id = session_id\n        self.seq_meta = seq_manager.seq_meta\n        self.sequences: SeqMap = dict()\n        self.seq_manager = seq_manager\n        self.scheduler = scheduler\n\n    def add_sequence(self,\n                     token_ids: Tensor,\n                     sampling_param: SamplingParam = None,\n                     adapter_name: str = None,\n                     multimodals: MultiModalInputs = None,\n                     input_embeddings: List[InputEmbeddings] = None,\n                     migration_request: None | MigrationRequest = None,\n                     resp_cache: bool = False,\n                     preserve_cache: bool = False) -> 'SchedulerSequence':\n        \"\"\"Add a new message.\"\"\"\n        from lmdeploy.pytorch.paging.seq_states.states import build_seq_state\n\n        if sampling_param is None:\n            sampling_param = SamplingParam()\n\n        seq_id = 
self.seq_manager._new_seq_id()\n        seq = self.seq_meta.strategy.make_sequence(seq_id=seq_id,\n                                                   session=self,\n                                                   sampling_param=sampling_param,\n                                                   adapter_name=adapter_name,\n                                                   migration_request=migration_request,\n                                                   resp_cache=resp_cache,\n                                                   preserve_cache=preserve_cache)\n        seq.update_token_ids(\n            token_ids,\n            multimodals=multimodals,\n            embeddings=input_embeddings,\n            mode=UpdateTokenMode.INPUTS,\n        )\n        self.sequences[seq.seq_id] = seq\n\n        # set status\n        # update seq manager\n        status = MessageStatus.WAITING if migration_request is None else MessageStatus.MIGRATION_WAITING\n        seq.set_state(build_seq_state(self.scheduler, seq, status))\n        self.seq_manager.add_sequence(seq)\n\n        # metrics\n        seq.record_event(EventType.QUEUED)\n\n        return seq\n\n    def remove_sequence(self, seq: 'SchedulerSequence'):\n        \"\"\"Remove sequence.\"\"\"\n        assert seq.seq_id in self.sequences\n        seq.state.free()\n        self.sequences.pop(seq.seq_id)\n        self.seq_manager.remove_sequence(seq)\n\n\ndef _div_up(x, n):\n    \"\"\"Perform div up.\"\"\"\n    return (x + n - 1) // n\n\n\ndef _round_up(x, n):\n    \"\"\"Perform round up.\"\"\"\n    return _div_up(x, n) * n\n\n\nclass HistoryEmbeddings:\n    \"\"\"History embeddings.\"\"\"\n\n    def __init__(self, embeddings: List[InputEmbeddings] = None):\n        self._embeddings: List[InputEmbeddings] = []\n        if embeddings is not None:\n            self._embeddings.extend(embeddings)\n\n    def append(self, embeddings: List[InputEmbeddings]):\n        self._embeddings.extend(embeddings)\n\n    def 
clone(self):\n        ret = HistoryEmbeddings(self._embeddings)\n        return ret\n\n    def copy(self):\n        return self.clone()\n\n    def get_step(self, step: int) -> tuple[int, int, int]:\n        \"\"\"Get step (moved back to an image start if it splits an image), history and remaining image counts.\"\"\"\n        real_step = step\n        num_all_images = len(self._embeddings)\n        history_image_num = 0\n        if num_all_images > 0:\n            history_image_num = sum([1 for emb in self._embeddings if emb.end <= step])\n            if history_image_num < num_all_images:\n                emb = self._embeddings[history_image_num]\n                # step falls in the middle of an image\n                if emb.start < step:\n                    real_step = emb.start\n        num_images = num_all_images - history_image_num\n        return real_step, history_image_num, num_images\n\n    @property\n    def embeddings(self):\n        \"\"\"embeddings.\"\"\"\n        return self._embeddings\n\n    def __len__(self):\n        \"\"\"Get num images.\"\"\"\n        return len(self._embeddings)\n\n    def __getitem__(self, *args, **kwargs):\n        \"\"\"Get values.\"\"\"\n        return self._embeddings.__getitem__(*args, **kwargs)\n\n\nclass _HistoryDataBase:\n    \"\"\"Base class for history data storage.\"\"\"\n    ALLOC_SIZE = 512\n    COPY_ON_RESIZE = False\n\n    def __init__(self, data: np.ndarray = None, dtype: np.dtype = np.int64):\n        self.dtype = dtype\n        self._data = None\n        self._num_real = 0\n\n        if data is None:\n            self._data = self._create_empty_array(dtype)\n        else:\n            self._data = data.astype(dtype) if hasattr(data, 'astype') else data\n            self._num_real = len(data)\n\n    def _create_empty_array(self, dtype):\n        \"\"\"Create empty array.\n\n        Override in subclass for different shapes.\n        \"\"\"\n        return np.empty((self.ALLOC_SIZE, ), dtype=dtype)\n\n    def _get_pad_width(self, reserve_size: int):\n        \"\"\"Get pad width for np.pad.\n\n   
     Override for multi-dimensional arrays.\n        \"\"\"\n        return (0, reserve_size)\n\n    def reserve(self, size: int):\n        \"\"\"Reserve cache.\"\"\"\n        if self._data is None:\n            return\n        num_tokens = len(self._data)\n        if num_tokens >= size:\n            return\n        reserve_size = _round_up(size - num_tokens, self.ALLOC_SIZE)\n        pad_width = self._get_pad_width(reserve_size)\n        self._data = np.pad(self._data, pad_width)\n\n    def get_real(self):\n        \"\"\"Get real data.\"\"\"\n        if self._data is None:\n            return None\n        return self._data[:self._num_real]\n\n    def resize(self, size: int):\n        \"\"\"Set size.\"\"\"\n        assert size <= self._num_real\n        self._num_real = size\n        if self.COPY_ON_RESIZE and self._data is not None:\n            self._data = self._data[:size].copy()\n\n    def append(self, new_data: np.ndarray):\n        \"\"\"Append data.\"\"\"\n        if self._data is None:\n            self._data = new_data.astype(self.dtype)\n            self._num_real = len(new_data)\n            return\n        num_tokens = len(new_data)\n        self.reserve(num_tokens + self._num_real)\n        slice_start = self._num_real\n        slice_end = slice_start + num_tokens\n        self._num_real += num_tokens\n        self._data[slice_start:slice_end] = new_data\n\n    def __setitem__(self, *args, **kwargs):\n        \"\"\"Set values.\"\"\"\n        return self.get_real().__setitem__(*args, **kwargs)\n\n    def __getitem__(self, *args, **kwargs):\n        \"\"\"Get values.\"\"\"\n        return self.get_real().__getitem__(*args, **kwargs)\n\n    def __len__(self):\n        \"\"\"Get length.\"\"\"\n        return self._num_real\n\n    def clone(self):\n        \"\"\"clone.\"\"\"\n        data = None if self._data is None else self.get_real().copy()\n        ret = type(self)(data, dtype=self.dtype)\n        return ret\n\n    def copy(self):\n        
\"\"\"copy.\"\"\"\n        return self.clone()\n\n\nclass HistoryTokenIds(_HistoryDataBase):\n    \"\"\"History token ids.\"\"\"\n    ALLOC_SIZE = 512\n\n    def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = np.int64):\n        super().__init__(token_ids, dtype)\n\n    @property\n    def _token_ids(self):\n        \"\"\"For backward compatibility.\"\"\"\n        return self._data\n\n    @_token_ids.setter\n    def _token_ids(self, value):\n        \"\"\"For backward compatibility.\"\"\"\n        self._data = value\n\n\nclass HistoryRouterExperts(_HistoryDataBase):\n    \"\"\"History router experts.\"\"\"\n    ALLOC_SIZE = 64\n    COPY_ON_RESIZE = True\n\n    def __init__(self, expert_ids: np.ndarray = None, dtype: np.dtype = np.uint16):\n        super().__init__(expert_ids, dtype)\n\n    def _create_empty_array(self, dtype):\n        \"\"\"Create empty array.\n\n        Override in subclass for different shapes.\n        \"\"\"\n        return None\n\n    def _get_pad_width(self, reserve_size: int):\n        \"\"\"Get pad width for multi-dimensional array.\"\"\"\n        return ((0, reserve_size), (0, 0), (0, 0))\n\n\nclass HistoryLogits(_HistoryDataBase):\n    \"\"\"History logits.\"\"\"\n    ALLOC_SIZE = 64\n    COPY_ON_RESIZE = True\n\n    def __init__(self, logits: np.ndarray = None, dtype: np.dtype = np.int16):\n        super().__init__(logits, dtype)\n        self._torch_dtype = None\n\n    def _create_empty_array(self, dtype):\n        \"\"\"Create empty array.\n\n        Override in subclass for different shapes.\n        \"\"\"\n        return None\n\n    def _get_pad_width(self, reserve_size: int):\n        \"\"\"Get pad width for multi-dimensional array.\"\"\"\n        return ((0, reserve_size), (0, 0))\n\n    def set_torch_dtype(self, torch_dtype):\n        \"\"\"Set torch dtype.\"\"\"\n        self._torch_dtype = torch_dtype\n\n    def get_logits(self):\n        \"\"\"Get logits as torch tensor.\"\"\"\n        if self._data is None:\n  
          return None\n        if self._torch_dtype is None:\n            return None\n\n        logits_np = self.get_real()\n        return torch.frombuffer(logits_np, dtype=self._torch_dtype).view(logits_np.shape)\n\n    def clone(self):\n        \"\"\"clone.\"\"\"\n        ret = super().clone()\n        ret.set_torch_dtype(self._torch_dtype)\n        return ret\n\n\nclass HistoryMultiModals:\n    \"\"\"History multimodal inputs.\"\"\"\n\n    def __init__(self, multimodals: MultiModalInputs = None):\n        if multimodals is None:\n            multimodals = dict()\n        self.multimodals = multimodals\n\n    def get_datas(self, start=0, end=-1):\n        \"\"\"Get multimodals from prompt positions [start, end).\"\"\"\n        outs: MultiModalInputs = dict()\n        test_range = range(start, end)\n        for modal_type, modal_datas in self.multimodals.items():\n            data = []\n            for modal_data in modal_datas:\n                if (modal_data.start not in test_range and modal_data.end - 1 not in test_range):\n                    continue\n                data.append(modal_data)\n            if len(data) > 0:\n                outs[modal_type] = data\n        return outs\n\n    def add_inputs(self, input_mms: MultiModalInputs):\n        \"\"\"Add new inputs.\"\"\"\n        for modal_type, vals in input_mms.items():\n            if modal_type in self.multimodals:\n                self.multimodals[modal_type] += vals\n            else:\n                self.multimodals[modal_type] = vals\n\n    def empty(self):\n        \"\"\"Return True if no multimodal inputs are stored.\"\"\"\n        if len(self.multimodals) == 0:\n            return True\n\n        # iterate over the stored lists, not the modal-type keys\n        return all(len(vals) == 0 for vals in self.multimodals.values())\n\n    @staticmethod\n    def update_multimodals(input_mms: MultiModalInputs, prev_len: int):\n        \"\"\"Update multimodals.\"\"\"\n        for vals in input_mms.values():\n            for val in vals:\n                val.start += prev_len\n                val.end += prev_len\n        return input_mms\n\n\nclass UpdateTokenMode(enum.Enum):\n    
\"\"\"Update token mode.\"\"\"\n    INPUTS = enum.auto()\n    PREFILL = enum.auto()\n    DECODE = enum.auto()\n\n\n@dataclass\nclass SchedulerSequence:\n    \"\"\"Scheduler message.\"\"\"\n    seq_id: int\n    session: SchedulerSession\n    history_cache: HistoryTokenIds = field(default_factory=HistoryTokenIds)\n    history_embeddings: HistoryEmbeddings = field(default_factory=HistoryEmbeddings)\n    history_multimodals: HistoryMultiModals = field(default_factory=HistoryMultiModals)\n    num_new_tokens: int = 0\n    sampling_param: SamplingParam = field(default_factory=SamplingParam)\n    logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks)\n    logical_state: int = -1\n    adapter_name: str = None\n    arrive_time: float = 0.0\n    output_start_pos: int = 0\n    meta: Any = None\n    num_ignored_history: int = 0\n    model_meta: Dict[str, Any] = None\n\n    # For Disaggregation\n    migration_request: None | MigrationRequest = None\n    resp_cache: bool = False\n    preserve_cache: bool = False\n\n    # For logging\n    engine_events: List[EngineEvent] = field(default_factory=list)\n\n    # for router replay\n    all_routed_experts: HistoryRouterExperts = field(default_factory=HistoryRouterExperts)\n\n    # logits\n    all_logits: HistoryLogits = field(default_factory=HistoryLogits)\n\n    def __post_init__(self):\n        \"\"\"Post init.\"\"\"\n        self._seq_meta: SequenceMeta = self.session.seq_meta\n        self._num_history_images: int = 0\n        self._num_history_ids: int = 0\n        self._num_token_ids: int = len(self.history_cache)\n\n        # vlm\n        self._num_images: int = len(self.history_embeddings)\n        self._state = None\n\n    @property\n    def block_size(self) -> int:\n        \"\"\"Block size.\"\"\"\n        return self._seq_meta.block_size\n\n    @property\n    def history_image_num(self) -> int:\n        \"\"\"Get history image number.\"\"\"\n        return self._num_history_images\n\n    @property\n   
 def history_image_token_len(self) -> int:\n        \"\"\"Get history image token length.\"\"\"\n        return sum([emb.end - emb.start for emb in self.history_embeddings[:self._num_history_images]])\n\n    @property\n    def session_id(self) -> int:\n        \"\"\"Get session id.\"\"\"\n        return self.session.session_id\n\n    @property\n    def token_ids(self) -> np.ndarray:\n        \"\"\"Token ids.\"\"\"\n        start = self.num_history_ids\n        end = start + self._num_token_ids\n        return self.history_cache[start:end]\n\n    @property\n    def input_embeddings(self) -> List[InputEmbeddings]:\n        \"\"\"Get current embeddings.\"\"\"\n        start = self.history_image_num\n        end = start + self._num_images\n        return self.history_embeddings[start:end]\n\n    @property\n    def history_ids(self) -> np.ndarray:\n        \"\"\"History ids.\"\"\"\n        return self.history_cache[:self.num_history_ids]\n\n    @property\n    def all_ids(self) -> np.ndarray:\n        \"\"\"Full token ids.\"\"\"\n        return self.history_cache[:self.num_all_ids]\n\n    @property\n    def valid_ids(self) -> np.ndarray:\n        \"\"\"Valid token ids.\"\"\"\n        return self.history_cache[:self.num_valid_ids]\n\n    @property\n    def generated_ids(self) -> np.ndarray:\n        end = self.num_valid_ids\n        start = end - self.num_new_tokens\n        return self.history_cache[start:end]\n\n    @property\n    def return_routed_experts(self) -> bool:\n        return self.sampling_param.return_routed_experts\n\n    @property\n    def routed_experts(self) -> np.ndarray:\n        if (not self.return_routed_experts) or self.all_routed_experts is None:\n            return None\n\n        end = max(0, self.num_all_ids - 1)\n        if 0 < end <= len(self.all_routed_experts):\n            return self.all_routed_experts.get_real()[:end]\n        else:\n            return None\n\n    def append_routed_experts(self, routed_experts: Tensor | np.ndarray):\n     
   \"\"\"Append routed experts.\"\"\"\n        if not self.return_routed_experts:\n            return\n        if routed_experts is None:\n            return\n        if isinstance(routed_experts, Tensor):\n            routed_experts = routed_experts.cpu().numpy()\n        self.all_routed_experts.append(routed_experts)\n\n    @property\n    def num_history_ids(self):\n        \"\"\"Num history ids.\"\"\"\n        return self._num_history_ids\n\n    @property\n    def num_token_ids(self):\n        return self._num_token_ids\n\n    @property\n    def num_valid_ids(self):\n        return self._num_history_ids + self._num_token_ids\n\n    @property\n    def num_images(self):\n        return self._num_images\n\n    @property\n    def num_all_ids(self):\n        \"\"\"Num all tokens.\"\"\"\n        return self._num_history_ids + self._num_token_ids\n\n    @property\n    def num_blocks(self):\n        \"\"\"Num blocks.\"\"\"\n        return len(self.logical_blocks)\n\n    @property\n    def state(self) -> 'StateBase':\n        return self._state\n\n    def set_state(self, state: 'StateBase'):\n        \"\"\"Set state.\"\"\"\n        self._state = state\n\n    @property\n    def status(self):\n        return self.state.status\n\n    @property\n    def return_logits(self):\n        return self.sampling_param.out_logits\n\n    @property\n    def logits(self):\n        \"\"\"Get logits.\"\"\"\n        return self.all_logits.get_logits()\n\n    def append_logits(self, logits: Tensor | np.ndarray):\n        \"\"\"Append logits.\"\"\"\n        if not self.return_logits:\n            return\n        if logits is None:\n            return\n        if isinstance(logits, Tensor):\n            self.all_logits.set_torch_dtype(logits.dtype)\n            logits = logits.view(torch.int16).numpy()\n        self.all_logits.append(logits)\n\n    def get_input_multimodals(self):\n        \"\"\"Get input multimodals.\"\"\"\n        start = self.num_history_ids\n        end = 
self.num_all_ids\n        return self.history_multimodals.get_datas(start, end)\n\n    def record_event(\n        self,\n        event_type: EventType,\n        timestamp: None | float = None,\n    ) -> None:\n        self.engine_events.append(EngineEvent.new_event(event_type, timestamp))\n\n    def _update_embeddings(self, embeddings: List[InputEmbeddings]):\n        \"\"\"Update input embeddings.\"\"\"\n        self._num_history_images += self._num_images\n        if embeddings is None:\n            self._num_images = 0\n            return\n        new_embeddings = [emb.move_position(self._num_history_ids) for emb in embeddings]\n        self._num_images = len(new_embeddings)\n        self.history_embeddings.append(new_embeddings)\n\n    def _update_multimodals(self, multimodals: MultiModalInputs):\n        \"\"\"Update input multimodals.\"\"\"\n        if multimodals is None:\n            return\n        multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids)\n        self.history_multimodals.add_inputs(multimodals)\n\n    def update_token_ids(self,\n                         token_ids: Tensor,\n                         multimodals: MultiModalInputs = None,\n                         embeddings: List[InputEmbeddings] = None,\n                         model_meta: Dict[str, Any] = None,\n                         mode: UpdateTokenMode = UpdateTokenMode.INPUTS,\n                         **kwargs):\n        \"\"\"Update token ids, old token ids will be added to history.\"\"\"\n        raise NotImplementedError('NotImplemented')\n\n    def set_step(self, step: int):\n        \"\"\"Set step.\"\"\"\n        raise NotImplementedError('NotImplemented')\n"
  },
  {
    "path": "lmdeploy/pytorch/model_inputs.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass, field, fields\nfrom typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional\n\nimport numpy as np\nimport torch\nimport torch.distributed as torch_dist\nfrom torch.profiler import record_function\n\n# from torch import distributed as dist\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.backends import get_backend\nfrom lmdeploy.pytorch.config import CacheConfig, DLLMConfig, ModelConfig, QuantizationConfig\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.utils import CtxMgrBase, singleton\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.strategies.base import StrategyFactoryBase\n\n\n@dataclass\nclass DPMeta:\n    tp_sizes: List[int] = None\n    moe_tp_sizes: List[int] = None\n\n    @staticmethod\n    def _gather_tp_sizes(tp: int, seqlen: int, num_tokens: List[int], dist_ctx: dist.DistContext, layer_type: str):\n        \"\"\"Gather tp size.\"\"\"\n        attn_tp = dist_ctx.dist_config.attn_tp\n        if tp > 1 and tp != attn_tp:\n            dist_group = dist.get_dist_group(layer_type=layer_type)\n            gather_group = dist_group.gpu_gather_group\n            ranks = torch_dist.get_process_group_ranks(gather_group)\n            tp_sizes = [num_tokens[r] for r in ranks]\n            assert all(size >= 0 for size in tp_sizes), (f'Invalid tp sizes: {tp_sizes}')\n        else:\n            tp_sizes = [seqlen]\n        return tp_sizes\n\n    @classmethod\n    def build(cls, seqlen: int, num_tokens: List[int]):\n        \"\"\"Get dp meta.\"\"\"\n        dist_ctx = dist.get_dist_manager().current_context()\n        dist_config = dist_ctx.dist_config\n\n        mlp_tp = dist_config.mlp_tp\n        tp_sizes = cls._gather_tp_sizes(mlp_tp, seqlen, num_tokens, dist_ctx, layer_type='mlp')\n\n        moe_tp = dist_config.moe_tp\n        if moe_tp == mlp_tp:\n            moe_tp_sizes = tp_sizes\n        
else:\n            moe_tp_sizes = cls._gather_tp_sizes(moe_tp, seqlen, num_tokens, dist_ctx, layer_type='moe')\n\n        return DPMeta(tp_sizes=tp_sizes, moe_tp_sizes=moe_tp_sizes)\n\n    def sync_tp_size(self, tp_size: int):\n        self.tp_sizes = [tp_size] * len(self.tp_sizes)\n        self.moe_tp_sizes = [tp_size] * len(self.moe_tp_sizes)\n\n\n@dataclass\nclass VisionModelInputs:\n    \"\"\"Vision model inputs.\"\"\"\n    history_lengths: torch.LongTensor = None\n    input_embeddings: List[List[torch.Tensor]] = None\n    input_embedding_ranges: List[torch.LongTensor] = None\n    input_embedding_indexing: torch.BoolTensor = None\n    input_multimodals: List[MultiModalData] = None\n\n    def to_device(self, device: str, non_blocking: bool = False):\n        \"\"\"To device.\"\"\"\n        out_dict = dict()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if v is None:\n                continue\n            if isinstance(v, torch.Tensor):\n                v = v.to(device, non_blocking=non_blocking)\n            elif k == 'input_embedding_ranges':\n                v = [e.to(device, non_blocking=non_blocking) for e in v]\n            elif k == 'input_embeddings':\n                v = [[e.to(device, non_blocking=non_blocking) for e in li] for li in v]\n            elif k == 'input_multimodals':\n                new_v = []\n                for mm_datas in v:\n                    new_mm_datas = dict()\n                    for modal_type, data in mm_datas.items():\n                        data = [d.to_device(device, non_blocking=non_blocking) for d in data]\n                        new_mm_datas[modal_type] = data\n                    new_v.append(new_mm_datas)\n                v = new_v\n            out_dict[k] = v\n\n        return VisionModelInputs(**out_dict)\n\n    def get_inputs(self, history_lengths: torch.Tensor, seq_lengths: torch.Tensor):\n        \"\"\"Get vision embedding inputs.\"\"\"\n        
input_embeddings = None\n        input_embedding_indexing = None\n        if self.input_embeddings is not None and len(self.input_embeddings) > 0:\n            input_embedding_li = []\n            for (his_len, seq_len, embeddings, emb_ranges) in zip(history_lengths, seq_lengths, self.input_embeddings,\n                                                                  self.input_embedding_ranges):\n                for emb, (emb_start, emb_end) in zip(embeddings, emb_ranges):\n                    start = max(emb_start, his_len) - emb_start\n                    end = min(emb_end, his_len + seq_len) - emb_start\n                    if 0 <= start < end:\n                        input_embedding_li.append(emb[start:end])\n            # has embeddings\n            if len(input_embedding_li) > 0:\n                input_embeddings = torch.cat(input_embedding_li, dim=0)\n                device = input_embeddings.device\n                starts = history_lengths - self.history_lengths\n                ends = starts + seq_lengths\n                input_embedding_indexing = torch.cat(\n                    [indexing[s:e] for indexing, s, e in zip(self.input_embedding_indexing, starts, ends)], dim=0)\n                index_ranges = torch.arange(input_embedding_indexing.numel(), device=device)\n                input_embedding_indexing = index_ranges[input_embedding_indexing]\n        return input_embeddings, input_embedding_indexing\n\n\n@dataclass\nclass ModelInputsDelta:\n    \"\"\"Delta of ModelInputs.\"\"\"\n    # valid indices\n    indices: Optional[torch.Tensor]\n    # new block offsets\n    block_offsets: torch.Tensor\n    # cpu copy of indices\n    indice_cpu: np.ndarray\n    max_q_seqlen: int\n    max_kv_seqlen: int\n    sum_kv_seqlen: int\n    is_decoding: bool = True\n    # sliding window\n    num_ignored_history: Optional[torch.Tensor] = None\n\n    @property\n    def seq_length(self):\n        \"\"\"Get seq_length.\"\"\"\n        batch_size = 
self.block_offsets.size(0)\n        return torch.full((batch_size, ), self.max_q_seqlen, dtype=torch.long)\n\n    def fill_tensors(self):\n        \"\"\"Fill tensor fields.\"\"\"\n        if self.indices is None:\n            self.indice_cpu = self.indice_cpu.copy()\n            self.indices = torch.as_tensor(self.indice_cpu)\n\n    @torch.inference_mode()\n    def to_device(self, device: str, non_blocking: bool = False):\n        \"\"\"To device.\"\"\"\n        out_dict = dict()\n        self.fill_tensors()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if isinstance(v, torch.Tensor):\n                v = v.to(device, non_blocking=non_blocking)\n            out_dict[k] = v\n\n        return ModelInputsDelta(**out_dict)\n\n    def log_info(self):\n        \"\"\"Get log info.\"\"\"\n        ret = (f'num_tokens={self.indices.numel()}, batch_size={self.indices.numel()}'\n               f', is_decoding={self.is_decoding}')\n        return ret\n\n\n@dataclass\nclass ModelInputs:\n    \"\"\"Input of the model.\"\"\"\n    input_ids: torch.Tensor\n    seq_length: torch.Tensor\n    history_lengths: torch.Tensor\n    block_offsets: torch.Tensor\n    is_decoding: bool\n    num_ignored_history: torch.Tensor\n    max_q_seqlen: int\n    max_kv_seqlen: int\n    sum_kv_seqlen: int\n    local_adapter_ids: torch.Tensor = None\n    vision_inputs: VisionModelInputs = None\n    model_metas: List[Dict[str, Any]] = None\n    dp_meta: 'DPMeta' = None\n    enable_microbatch: bool = False\n    is_dummy: bool = False\n    state_offsets: torch.Tensor = None\n    target_hidden_states: torch.Tensor = None\n    target_position_ids: torch.Tensor = None\n    is_chunk: bool = False\n    is_first_chunk: bool = True\n\n    def step(self, input_ids: torch.Tensor, step_seqlens: torch.Tensor = None):\n        \"\"\"Update input ids.\"\"\"\n        assert self.is_decoding\n        if step_seqlens is None:\n            step_seqlens = 
self.seq_length\n        self.history_lengths += step_seqlens\n        self.max_kv_seqlen += self.max_q_seqlen\n        self.sum_kv_seqlen += self.max_q_seqlen * self.seq_length.numel()\n        if input_ids.dim() == 1:\n            input_ids = input_ids[None, :]\n        self.input_ids = input_ids\n        return self\n\n    @torch.inference_mode()\n    def to_device(self, device: str, non_blocking: bool = False):\n        \"\"\"To device.\"\"\"\n        out_dict = dict()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if isinstance(v, torch.Tensor):\n                v = v.to(device, non_blocking=non_blocking)\n            elif isinstance(v, VisionModelInputs):\n                v = v.to_device(device, non_blocking=non_blocking)\n            out_dict[k] = v\n\n        return ModelInputs(**out_dict)\n\n    def build_dp_meta(self, num_tokens: List[int]):\n        \"\"\"Build dp meta.\"\"\"\n        self.dp_meta = DPMeta.build(self.input_ids.numel(), num_tokens)\n\n    def log_info(self):\n        \"\"\"Get log info.\"\"\"\n        ret = (f'num_tokens={self.input_ids.numel()}, batch_size={self.seq_length.numel()}'\n               f', is_decoding={self.is_decoding}, has_vision={self.vision_inputs is not None}')\n        return ret\n\n\n@dataclass\nclass StepContext:\n    \"\"\"Context of Model.\n\n    patched model might need extra information to perform inference. 
This dataclass provides this information and related tools.\n    \"\"\"\n    input_ids: torch.LongTensor\n    model_config: ModelConfig\n    cache_config: CacheConfig\n    block_offsets: torch.IntTensor\n    position_ids: torch.LongTensor\n    attention_mask: torch.LongTensor\n    q_seqlens: torch.LongTensor\n    kv_seqlens: torch.IntTensor\n    q_start_loc: torch.LongTensor\n    kv_caches: List\n    is_decoding: bool\n    sum_kv_seqlen: int\n    max_kv_seqlen: int = None\n    local_adapter_ids: torch.LongTensor = None\n    input_embeddings: torch.Tensor = None\n    input_embedding_indexing: torch.Tensor = None\n    input_multimodals: List[MultiModalData] = None\n    vision_inputs: VisionModelInputs = None\n    attn_metadata: Any = None\n    kv_quant_policy: Literal[0, 4, 8] = 0\n    model_metas: List[Dict[str, Any]] = None\n    dp_meta: DPMeta = None\n    enable_microbatch: bool = False\n    # for draft model\n    target_hidden_states: torch.Tensor = None\n\n    # states for ssm\n    state_caches: List = None\n    state_offsets: torch.LongTensor = None\n\n    _outputs: Dict = field(default_factory=dict)\n\n    @classmethod\n    def new(\n        cls,\n        inputs: ModelInputs,\n        model_config: ModelConfig,\n        cache_config: CacheConfig,\n        kv_caches: List = None,\n        state_caches: List = None,\n        kv_quant_policy: Literal[0, 4, 8] = 0,\n    ):\n        \"\"\"Build step context.\n\n        Args:\n            inputs (ModelInputs): packaged model inputs.\n            model_config (ModelConfig): the model config.\n            cache_config (CacheConfig): the cache config.\n            kv_caches (List): the kv caches.\n            state_caches (List): the state caches for ssm.\n            kv_quant_policy (int): the kv cache quantization policy, one of 0, 4 or 8.\n        \"\"\"\n        q_seqlens = inputs.seq_length\n        history_seqlens = inputs.history_lengths\n\n        input_multimodals = None\n        if inputs.vision_inputs is not None:\n            input_multimodals = inputs.vision_inputs.input_multimodals\n\n        # for vlm\n        input_embeddings, input_embedding_indexing = None, None\n        if (inputs.vision_inputs is not None and inputs.vision_inputs.input_embeddings is not None):\n     
       input_embeddings, input_embedding_indexing = \\\n                inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens)\n\n        # position ids\n        attention_mask, position_ids = cls.get_mask_and_position_ids(inputs)\n        q_start_loc = q_seqlens.cumsum(0) - q_seqlens\n\n        # seq_len + history_length\n        kv_seqlens = q_seqlens + history_seqlens\n        kv_seqlens -= inputs.num_ignored_history\n\n        ret = StepContext(\n            input_ids=inputs.input_ids,\n            model_config=model_config,\n            cache_config=cache_config,\n            block_offsets=inputs.block_offsets,\n            position_ids=position_ids,\n            input_embeddings=input_embeddings,\n            input_embedding_indexing=input_embedding_indexing,\n            input_multimodals=input_multimodals,\n            attention_mask=attention_mask,\n            q_seqlens=q_seqlens,\n            kv_seqlens=kv_seqlens,\n            q_start_loc=q_start_loc,\n            kv_caches=kv_caches,\n            is_decoding=inputs.is_decoding,\n            sum_kv_seqlen=inputs.sum_kv_seqlen,\n            max_kv_seqlen=inputs.max_kv_seqlen,\n            local_adapter_ids=inputs.local_adapter_ids,\n            vision_inputs=inputs.vision_inputs,\n            kv_quant_policy=kv_quant_policy,\n            model_metas=inputs.model_metas,\n            dp_meta=inputs.dp_meta,\n            enable_microbatch=inputs.enable_microbatch,\n            state_caches=state_caches,\n            state_offsets=inputs.state_offsets,\n            target_hidden_states=inputs.target_hidden_states,\n        )\n\n        ret = get_backend().update_step_context(ret)\n        return ret\n\n    @classmethod\n    def get_mask_and_position_ids(cls, inputs: ModelInputs):\n        \"\"\"Get position ids.\"\"\"\n        q_seqlens = inputs.seq_length\n        history_seqlens = inputs.history_lengths\n        max_q_seqlen = inputs.max_q_seqlen\n        target_position_ids = 
inputs.target_position_ids\n        # decoding\n        if max_q_seqlen == 1:\n            attention_mask = torch.ones_like(q_seqlens)[:, None]\n            if target_position_ids is not None:\n                position_ids = target_position_ids\n            else:\n                position_ids = history_seqlens.unsqueeze(0).clone()\n            return attention_mask, position_ids\n\n        num_tokens = inputs.input_ids.numel()\n        batch_size = inputs.seq_length.numel()\n        device = q_seqlens.device\n\n        # batch with same seqlens\n        if max_q_seqlen * batch_size == num_tokens:\n            attention_mask = None\n            ranges = torch.arange(0, max_q_seqlen, device=device)\n            position_ids = history_seqlens[:, None] + ranges[None, :]\n            position_ids = position_ids.flatten()\n            return attention_mask, position_ids[None]\n\n        # get mask\n        mask_range = torch.arange(max_q_seqlen, device=device)[None, :]\n        attention_mask = (mask_range < q_seqlens[:, None]).long()\n        if target_position_ids is not None:\n            return attention_mask, target_position_ids\n\n        # position_ids\n        indices = attention_mask.long().cumsum(-1) - 1\n        position_ids = indices + history_seqlens.unsqueeze(-1)\n        indices[1:] += q_seqlens.cumsum(0)[:-1, None]\n        position_ids_1d = position_ids.new_empty(num_tokens)\n        position_ids_1d[indices.flatten()] = position_ids.flatten()\n        position_ids = position_ids_1d[None]\n        return attention_mask, position_ids\n\n\n@dataclass\nclass BuildModelContext:\n    \"\"\"Context for building model.\"\"\"\n    disable_vision_encoder: bool = False\n    dllm_config: DLLMConfig = None\n    strategy_factory: 'StrategyFactoryBase' = None\n    enable_return_routed_experts: bool = False\n    quant_config: QuantizationConfig = field(default_factory=QuantizationConfig)\n    fp32_lm_head: bool = False\n    tie_word_embeddings: bool = False\n\n\nclass 
StepContextManager(CtxMgrBase[StepContext]):\n\n    def __init__(self, build_ctx: BuildModelContext = None):\n        super().__init__(None)\n        build_ctx = build_ctx or BuildModelContext()\n        self.build_ctx = build_ctx\n\n    @record_function('build_step_context')\n    def build_context(\n        self,\n        inputs: ModelInputs,\n        model_config: ModelConfig,\n        cache_config: CacheConfig,\n        kv_caches: List = None,\n        state_caches: List = None,\n        kv_quant_policy: Literal[0, 4, 8] = 0,\n    ):\n        \"\"\"Build context.\"\"\"\n        return StepContext.new(\n            inputs,\n            model_config,\n            cache_config,\n            kv_caches,\n            state_caches,\n            kv_quant_policy,\n        )\n\n\n@singleton\nclass StepCtxMgrApi(CtxMgrBase[StepContextManager]):\n    \"\"\"Context manager for StepContextManager.\"\"\"\n\n    def __init__(self):\n        super().__init__(None)\n\n\nset_step_ctx_manager = StepCtxMgrApi().set_context\nget_step_ctx_manager = StepCtxMgrApi().current_context\nstep_ctx_manager = StepCtxMgrApi().context\n"
  },
  {
    "path": "lmdeploy/pytorch/models/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .q_modules import QLinear, QRMSNorm\n\n__all__ = ['QLinear', 'QRMSNorm']\n"
  },
  {
    "path": "lmdeploy/pytorch/models/baichuan.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\ndef _is_baichuan_13b(config: Any):\n    \"\"\"Is baichuan 13b.\"\"\"\n    return config.num_hidden_layers == 40\n\n\nclass BaichuanAttention(nn.Module):\n    \"\"\"Rewrite module of Attention.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = num_heads\n        hidden_size = config.hidden_size\n        head_dim = hidden_size // num_heads\n        self.is_13b = _is_baichuan_13b(config)\n\n        # packed qkv\n        self.W_pack = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            alibi=self.is_13b,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads 
* head_dim,\n                                   hidden_size,\n                                   bias=False,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.W_pack(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.W_pack.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        if not self.is_13b:\n            cos, sin = rotary_pos_emb\n            query_states, key_states = self.apply_rotary_pos_emb(\n                query_states,\n                key_states,\n                cos,\n                sin,\n                inplace=True,\n            )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass MLP(nn.Module):\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 
'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass DecoderLayer(nn.Module):\n    \"\"\"Baichuan decoder layer.\"\"\"\n\n    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = BaichuanAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        
self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"forward.\"\"\"\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass BaichuanModel(nn.Module):\n    \"\"\"Baichuan model.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all 
decode layers\n        self.layers = nn.ModuleList([\n            DecoderLayer(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        self.is_13b = _is_baichuan_13b(config)\n        if not self.is_13b:\n            # build rotary embedding in LlamaModel\n            emb_type = RopeType.LinearScaling\n            rope_dim = config.hidden_size // config.num_attention_heads\n            rope_max_pos_emb = config.max_position_embeddings\n            rope_base = 10000\n            scaling_factor = 1.0\n            self.rotary_emb = build_rotary_embedding(\n                rope_dim,\n                rope_max_pos_emb,\n                rope_base,\n                scaling_factor,\n                emb_type=emb_type,\n            )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        rotary_pos_emb = (None, None)\n        if not self.is_13b:\n            cos, sin = self.rotary_emb(hidden_states, position_ids)\n            cos, sin = cos[0], sin[0]\n            rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                
past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass BaichuanForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Rewrote model of LlamaForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: Any,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build BaichuanModel\n        self.model = BaichuanModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute 
logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if 
self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '.W_pack' in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                elif 'lm_head' in name:\n                    loaded_weight = nn.functional.normalize(loaded_weight)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/chatglm2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding,\n                                 build_rotary_params)\nfrom lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,\n                                        build_qkv_proj, build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixin, vlm_model\n\nLANGUAGE_TOKEN_TYPE = 0\nVISION_TOKEN_TYPE = 1\n\n\nclass SelfAttention(torch.nn.Module):\n    \"\"\"Parallel self-attention layer abstract class.\n\n    Self-attention layer takes input with size [s, b, h] and returns output of the same size.\n    \"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        self.projection_size = config.kv_channels * config.num_attention_heads\n        self.num_attention_heads = config.num_attention_heads\n        self.num_kv_heads = config.num_key_value_heads\n        self.head_size = (self.projection_size // config.num_attention_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        self.query_key_value = build_qkv_proj(config.hidden_size,\n                                         
     num_q_heads=self.num_attention_heads,\n                                              num_kv_heads=self.num_kv_heads,\n                                              head_size=self.head_size,\n                                              bias=config.add_bias_linear or config.add_qkv_bias,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # apply rotary\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            self.num_attention_heads,\n            self.head_size,\n            num_kv_heads=self.num_kv_heads,\n        )\n\n        # o_proj\n        self.dense = build_o_proj(self.projection_size,\n                                  config.hidden_size,\n                                  bias=config.add_bias_linear,\n                                  quant_config=quantization_config,\n                                  dtype=dtype,\n                                  device=device,\n                                  is_tp=True)\n\n    @staticmethod\n    def _extract_rope(states: torch.Tensor):\n        \"\"\"Extract rope.\"\"\"\n        rope = states.chunk(2, -1)[0]\n        rope = rope.unflatten(-1, (-1, 2))\n        rope = rope.transpose(-2, -1).flatten(-2, -1).contiguous()\n        return rope\n\n    @staticmethod\n    def _fill_rope(states: torch.Tensor, rope: torch.Tensor):\n        \"\"\"Fill rope.\"\"\"\n        rope_part = states.chunk(2, -1)[0]\n        rope = rope.unflatten(-1, (2, -1))\n        rope = rope.transpose(-2, -1).flatten(-2, -1)\n        rope_part.copy_(rope)\n        return states\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n     
   past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.query_key_value(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        (query_states, key_states, value_states) = self.query_key_value.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        q_rope = self._extract_rope(query_states)\n        k_rope = self._extract_rope(key_states)\n        q_rope, k_rope = self.apply_rotary_pos_emb(\n            q_rope,\n            k_rope,\n            cos,\n            sin,\n            inplace=True,\n        )\n        query_states = self._fill_rope(query_states, q_rope)\n        key_states = self._fill_rope(key_states, k_rope)\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.dense(attn_output)\n        return attn_output\n\n\nclass MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        self.add_bias = config.add_bias_linear\n        # gate up\n        self.dense_h_to_4h = build_gateup_linear(\n            config.hidden_size,\n            [config.ffn_hidden_size, config.ffn_hidden_size],\n            bias=self.add_bias,\n        
    dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.dense_4h_to_h = build_down_linear(config.ffn_hidden_size,\n                                               config.hidden_size,\n                                               bias=self.add_bias,\n                                               quant_config=quantization_config,\n                                               dtype=dtype,\n                                               device=device,\n                                               is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.dense_h_to_4h(x)\n        act = self.act_fn(gate_up)\n        return self.dense_4h_to_h(act)\n\n\nclass GLMBlock(torch.nn.Module):\n    \"\"\"A single transformer layer.\n\n    Transformer layer takes input with size [s, b, h] and returns an output of the same size.\n    \"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_number: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_number = layer_number\n        self.apply_residual_connection_post_layernorm = \\\n            config.apply_residual_connection_post_layernorm\n        assert not self.apply_residual_connection_post_layernorm\n\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attention = SelfAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.layernorm_epsilon,\n                                       
quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.layernorm_epsilon,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            layernorm_output = self.input_layernorm(hidden_states)\n        else:\n            layernorm_output, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        layernorm_input = self.self_attention(\n            hidden_states=layernorm_output,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        layernorm_output, residual = self.post_attention_layernorm(layernorm_input, residual)\n        mlp_output = self.mlp(layernorm_output)\n\n        outputs = (mlp_output, residual)\n        return outputs\n\n\nclass GLMTransformer(nn.Module):\n    \"\"\"Transformer class.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.num_layers = config.num_layers\n        self.post_layer_norm = config.post_layer_norm\n\n        def build_layer(layer_number):\n            \"\"\"Build layer.\"\"\"\n            return 
GLMBlock(config, layer_number, dtype=dtype, device=device)\n\n        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])\n\n        if self.post_layer_norm:\n            assert config.rmsnorm\n            self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon, dtype=dtype, device=device)\n\n    def _get_layer(self, layer_number: int):\n        \"\"\"Get layer.\"\"\"\n        return self.layers[layer_number]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: List[torch.Tensor],\n        past_key_values: Optional[List[torch.FloatTensor]],\n        attn_metadata: Any,\n    ):\n        \"\"\"forward.\"\"\"\n        residual = None\n        for index in range(self.num_layers):\n            layer = self._get_layer(index)\n            hidden_states, residual = layer(\n                hidden_states,\n                rotary_pos_emb,\n                past_key_value=past_key_values[index],\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        if self.post_layer_norm:\n            hidden_states, _ = self.final_layernorm(hidden_states, residual)\n        return hidden_states\n\n\nclass Embedding(nn.Module):\n    \"\"\"Language model embeddings.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        # Word embeddings (parallel).\n        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, dtype=dtype, device=device)\n        self.fp32_residual_connection = config.fp32_residual_connection\n\n    def forward(self, input_ids):\n        \"\"\"Rewrite to not transpose hidden_states for all models.\"\"\"\n        # Embeddings.\n        embeddings = self.word_embeddings(input_ids)\n        if self.fp32_residual_connection:\n            embeddings = 
embeddings.float()\n        return embeddings\n\n\nclass PatchEmbedding(nn.Module):\n    \"\"\"Vision embedding.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.proj = nn.Conv2d(config.in_channels,\n                              config.hidden_size,\n                              kernel_size=config.patch_size,\n                              stride=config.patch_size,\n                              dtype=dtype,\n                              device=device)\n        self.cls_embedding = nn.Parameter(torch.empty(1, config.hidden_size, dtype=dtype, device=device))\n        self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size, dtype=dtype, device=device)\n\n    def forward(self, images):\n        \"\"\"forward.\"\"\"\n        x = self.proj(images)\n        x = x.flatten(2).transpose(1, 2)\n        cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)\n        x = torch.cat((cls_token, x), dim=1)\n        x += self.position_embedding.weight.unsqueeze(0)\n        return x\n\n\nclass EVA2CLIPAttention(nn.Module):\n    \"\"\"Vision attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        hidden_size = config.hidden_size\n        num_heads = config.num_heads\n        head_dim = config.hidden_size // config.num_heads\n        self.scale = head_dim**-0.5\n\n        # packed qkv\n        self.query_key_value = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_heads,\n            head_size=head_dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # o_proj\n        self.dense = 
build_rowwise_linear(hidden_size,\n                                          hidden_size,\n                                          bias=True,\n                                          quant_config=quantization_config,\n                                          dtype=dtype,\n                                          device=device,\n                                          is_tp=True)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        # qkv proj\n        qkv_states = self.query_key_value(hidden_states)\n        q, k, v = self.query_key_value.split_qkv(qkv_states)\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)\n\n        # o proj\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.flatten(-2, -1)\n        attn_output = self.dense(attn_output)\n        return attn_output\n\n\nclass EVA2CLIPMLP(nn.Module):\n    \"\"\"Vision MLP.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        from transformers.activations import ACT2FN\n\n        # gate up\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.fc1 = build_colwise_linear(\n            config.hidden_size,\n            config.intermediate_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        if config.hidden_act in ['gelu', 'gelu_fast', 'quick_gelu', 'gelu_python']:\n            self.activation_fn = nn.GELU()\n        else:\n            self.activation_fn = ACT2FN[config.hidden_act]\n\n        # down\n        self.fc2 = build_rowwise_linear(config.intermediate_size,\n                                        
config.hidden_size,\n                                        bias=True,\n                                        quant_config=quantization_config,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        x = self.fc1(x)\n        x = self.activation_fn(x)\n        x = self.fc2(x)\n        return x\n\n\nclass EVA2CLIPTransformerLayer(nn.Module):\n    \"\"\"Vision trans layer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=dtype, device=device)\n        self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device)\n        self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device)\n        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,\n                                                     eps=config.layer_norm_eps,\n                                                     dtype=dtype,\n                                                     device=device)\n\n    def forward(self, hidden_states):\n        \"\"\"forward.\"\"\"\n        attention_input = hidden_states\n        attention_output = self.input_layernorm(self.attention(attention_input))\n        hidden_states = attention_input + attention_output\n        mlp_input = hidden_states\n        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))\n        output = mlp_input + mlp_output\n        return output\n\n\nclass EVA2CLIPTransformer(nn.Module):\n    \"\"\"Vision transformer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.layers = nn.ModuleList(\n            
[EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) for _ in range(config.num_hidden_layers)])\n\n    def forward(self, hidden_states):\n        \"\"\"forward.\"\"\"\n        for layer_module in self.layers:\n            hidden_states = layer_module(hidden_states)\n        return hidden_states\n\n\nclass GLU(nn.Module):\n    \"\"\"GLU.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 in_features: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False, dtype=dtype, device=device)\n        self.norm1 = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device)\n        self.act1 = nn.GELU()\n        self.act2 = nn.functional.silu\n        self.dense_h_to_4h = nn.Linear(config.hidden_size,\n                                       config.ffn_hidden_size,\n                                       bias=False,\n                                       dtype=dtype,\n                                       device=device)\n        self.gate_proj = nn.Linear(config.hidden_size, config.ffn_hidden_size, bias=False, dtype=dtype, device=device)\n        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size,\n                                       config.hidden_size,\n                                       bias=False,\n                                       dtype=dtype,\n                                       device=device)\n\n    def forward(self, x):\n        x = self.linear_proj(x)\n        x = self.act1(self.norm1(x))\n        x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)\n        x = self.dense_4h_to_h(x)\n        return x\n\n\n@vlm_model\nclass EVA2CLIPModel(nn.Module):\n    \"\"\"Vision model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        from argparse import 
Namespace\n        vision_config = Namespace(**config.vision_config)\n\n        self.patch_embedding = PatchEmbedding(vision_config, dtype=dtype, device=device)\n        self.transformer = EVA2CLIPTransformer(vision_config, dtype=dtype, device=device)\n        self.linear_proj = GLU(config, in_features=config.hidden_size, dtype=dtype, device=device)\n        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,\n                              out_channels=config.hidden_size,\n                              kernel_size=2,\n                              stride=2,\n                              dtype=dtype,\n                              device=device)\n        self.boi = nn.Parameter(torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))\n        self.eoi = nn.Parameter(torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))\n        self.scaling_factor = vision_config.scaling_factor\n\n    def forward(self, images):\n        \"\"\"forward.\"\"\"\n        x = self.patch_embedding(images)\n        x = self.transformer(x)\n\n        x = x[:, 1:]\n\n        b, s, h = x.shape\n        grid_size = int(s**0.5)\n        x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)\n        x = self.conv(x)\n\n        x = x.flatten(2).transpose(1, 2)\n        x = self.linear_proj(x)\n        boi = self.boi.expand(x.shape[0], -1, -1)\n        eoi = self.eoi.expand(x.shape[0], -1, -1)\n        x = torch.cat((boi, x, eoi), dim=1)\n        x = x / self.scaling_factor\n        return x\n\n\nclass ChatGLMModel(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.embedding = Embedding(config, dtype=dtype, device=device)\n\n        # build rotary embedding\n        emb_type = RopeType.LinearScaling\n        rotary_dim = (config.hidden_size //\n                      config.num_attention_heads if config.kv_channels is None else config.kv_channels)\n    
    rope_max_pos_emb = 1 << 20\n        rope_base = 10000 * getattr(config, 'rope_ratio', 1.0)\n        rope_params = dict(emb_type=emb_type,\n                           dim=rotary_dim // 2,\n                           max_position_embeddings=rope_max_pos_emb,\n                           base=rope_base)\n        update_params = build_rotary_params(config)\n        rope_params.update(update_params)\n        self.rotary_pos_emb = build_rotary_embedding(**rope_params)\n\n        # build encoder\n        self.encoder = GLMTransformer(config, dtype=dtype, device=device)\n\n        # output_layers\n        self.output_layer = build_rowwise_linear(config.hidden_size,\n                                                 config.padded_vocab_size,\n                                                 bias=False,\n                                                 dtype=dtype,\n                                                 device=device)\n\n        self.vision = None\n        if hasattr(config, 'vision_config'):\n            self.vision = EVA2CLIPModel(config, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        images: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            images_features = None\n            if images is not None:\n                images_features = self.vision(images)\n                images_features = images_features.flatten(0, 1)[None]\n            inputs_embeds = self.embedding(input_ids)\n            if images is not None:\n                inputs_embeds.masked_scatter_(image_mask[..., None], images_features)\n\n        hidden_states = inputs_embeds\n\n        
# rotary embedding\n        cos, sin = self.rotary_pos_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        hidden_states = self.encoder(\n            hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n        )\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embedding\n\n\nclass ChatGLMForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):\n    \"\"\"Rewrote model of LlamaForCausalLM.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build Model\n        self.transformer = ChatGLMModel(config, dtype=dtype, device=device)\n\n        self.input_processor = ChatGLMInputProcessor(self.config, dtype)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        images: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.transformer(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            images=images,\n            image_mask=image_mask,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n 
       return self.transformer.output_layer(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.transformer.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        images = None\n        image_mask = None\n        if context.input_multimodals is not None:\n            images = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            images = [data for im_data in images for data in im_data]\n            if len(images) != 0:\n                image_token_id = images[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                images = torch.stack([data.data for data in images])\n            else:\n                images = None\n                image_mask = None\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            images=images,\n            image_mask=image_mask,\n            
inputs_embeds=inputs_embeds,\n        )\n\n    def _get_model_metas(self, context: StepContext):\n        \"\"\"Get model metas.\"\"\"\n        model_metas = context.model_metas\n        if model_metas is None:\n            batch_size = context.q_seqlens.numel()\n            return [dict(num_img_tokens=0)] * batch_size\n        return [dict(num_img_tokens=0) if meta is None else meta for meta in model_metas]\n\n    def update_model_metas(self,\n                           past_key_values: List[List[torch.Tensor]],\n                           inputs_embeds: Optional[torch.Tensor] = None,\n                           context: StepContext = None):\n        \"\"\"Update model meta.\"\"\"\n        model_metas = self._get_model_metas(context)\n        if not hasattr(self.config, 'vision_config'):\n            return model_metas\n\n        input_multimodals = context.input_multimodals\n        if input_multimodals is None:\n            input_imgs = [[] for _ in model_metas]\n        else:\n            input_imgs = []\n            for mm in input_multimodals:\n                if mm is None:\n                    input_imgs.append([])\n                else:\n                    input_imgs.append(mm.get('image', []))\n\n        config = self.config\n        image_size: int = config.vision_config['image_size']\n        patch_size: int = config.vision_config['patch_size']\n        vision_token_num = ((image_size // patch_size // 2) * (image_size // patch_size // 2) + 2)\n        num_pad = vision_token_num - 3\n\n        batched_num_img_tokens = []\n        new_model_metas = []\n        for meta, imgs in zip(model_metas, input_imgs):\n            if meta is None:\n                num_img_tokens = 0\n            else:\n                num_img_tokens = meta.get('num_img_tokens', 0)\n\n            batched_num_img_tokens.append(num_img_tokens)\n\n            num_img_tokens += num_pad * len(imgs)\n            new_model_metas.append(dict(num_img_tokens=num_img_tokens))\n\n        # 
prepare cogvlm position_ids\n        q_seqlens = context.q_seqlens\n        position_ids = context.position_ids\n\n        if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs):\n            num_img_tokens = torch.tensor(batched_num_img_tokens, device=position_ids.device)\n            position_ids -= num_img_tokens[None]\n        else:\n            batched_position_ids = position_ids[0].split(q_seqlens)\n            for pos_ids, num_img_tok, imgs in zip(batched_position_ids, batched_num_img_tokens, input_imgs):\n                pos_ids -= num_img_tok\n                if len(imgs) == 0:\n                    continue\n\n                seq_len = pos_ids.size(0)\n                start = pos_ids[0].cpu().item()\n                new_pos_ids = []\n\n                imgs = sorted(imgs, key=lambda img: img.start)\n                for img in imgs:\n                    img_pad_pos = img.start + 1 - num_img_tok\n                    num_pad = img.end - img.start - 2\n                    new_pos_ids += list(range(start, img_pad_pos))\n                    new_pos_ids += [img_pad_pos] * num_pad\n                    start = img_pad_pos + 1\n                    num_img_tok += num_pad\n\n                remain = seq_len - len(new_pos_ids)\n                new_pos_ids += list(range(start, start + remain))\n\n                new_pos_ids = pos_ids.new_tensor(new_pos_ids)\n                pos_ids[:] = new_pos_ids\n\n            position_ids = torch.cat(batched_position_ids)[None]\n        context.position_ids = position_ids\n\n        return new_model_metas\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'transformer.vision' in name:\n                if '.query_key_value' in name:\n                    param = params_dict[name]\n                    q, k, v = 
param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n                continue\n\n            if 'rotary_pos_emb.inv_freq' in name:\n                continue\n            if ('rotary_pos_emb.cos_cached' in name or 'rotary_pos_emb.sin_cached' in name):\n                continue\n            if (self.config.tie_word_embeddings and 'output_layer.weight' in name):\n                continue\n            if '.query_key_value' in name:\n                param = params_dict[name]\n                q, k, v = param.weight_spliter(loaded_weight)\n                load_weight(param, q, shard_id='q')\n                load_weight(param, k, shard_id='k')\n                load_weight(param, v, shard_id='v')\n            elif '.dense_h_to_4h' in name:\n                param = params_dict[name]\n                gate, up = param.weight_spliter(loaded_weight)\n                load_weight(param, gate, shard_id=0)\n                load_weight(param, up, shard_id=1)\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass ChatGLMInputProcessor(BaseModelInputProcessor):\n    \"\"\"Input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n        if hasattr(config, 'vision_config'):\n            vision_config = config.vision_config\n            self.image_size = vision_config['image_size']\n            self.patch_size = vision_config['patch_size']\n            self.num_patches = (self.image_size // self.patch_size)**2\n 
           self.num_positions = self.num_patches + 1\n            self.vision_token_num = self.num_patches // 4\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            offset = input_mm['offset']\n            num_pad = input_mm['image_tokens']\n            image_token_id = input_mm['image_token_id']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     meta=dict(image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/cogvlm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom argparse import Namespace\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding\nfrom lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixin, vlm_model\n\n\nclass VisionExpertAttention(nn.Module):\n    \"\"\"Rewrite module of VisionExpertAttention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        is_cogvlm2 = hasattr(config, 'num_multi_query_heads')\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        self.hidden_size = hidden_size\n        self.num_kv_heads = num_key_value_heads\n        self.head_dim = head_dim\n\n        # packed qkv\n        self.vision_expert_query_key_value = 
build_qkv_proj(hidden_size,\n                                                            num_q_heads=num_heads,\n                                                            num_kv_heads=num_key_value_heads,\n                                                            head_size=head_dim,\n                                                            bias=is_cogvlm2,\n                                                            quant_config=quantization_config,\n                                                            dtype=dtype,\n                                                            device=device,\n                                                            num_replicate_kv_heads=num_replicate_kv_heads)\n        self.language_expert_query_key_value = build_qkv_proj(hidden_size,\n                                                              num_q_heads=num_heads,\n                                                              num_kv_heads=num_key_value_heads,\n                                                              head_size=head_dim,\n                                                              bias=False,\n                                                              quant_config=quantization_config,\n                                                              dtype=dtype,\n                                                              device=device,\n                                                              num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n        )\n\n        # o_proj\n        self.vision_expert_dense = build_rowwise_linear(hidden_size,\n                                                        hidden_size,\n                                                        bias=False,\n                         
                               quant_config=quantization_config,\n                                                        dtype=dtype,\n                                                        device=device,\n                                                        is_tp=True,\n                                                        all_reduce=False)\n        self.language_expert_dense = build_rowwise_linear(hidden_size,\n                                                          hidden_size,\n                                                          bias=False,\n                                                          quant_config=quantization_config,\n                                                          dtype=dtype,\n                                                          device=device,\n                                                          is_tp=True,\n                                                          all_reduce=False)\n        world_size, _ = get_tp_world_rank()\n        self.world_size = world_size\n        self.all_reduce = world_size > 1\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n        lang_ids: torch.LongTensor = None,\n        vision_ids: torch.LongTensor = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        bsz, seqlen, _ = hidden_states.size()\n        hidden_size = self.hidden_size // self.world_size\n        kv_size = self.num_kv_heads * self.head_dim // self.world_size\n\n        # qkv proj\n        if lang_ids is None and vision_ids is None:\n            qkv_states = self.language_expert_query_key_value(hidden_states)\n        else:\n            qkv_states = hidden_states.new_empty(bsz, seqlen, hidden_size + kv_size * 2)\n            if lang_ids is not None:\n                qkv_states[:, lang_ids] = 
self.language_expert_query_key_value(hidden_states[:, lang_ids])\n            if vision_ids is not None:\n                qkv_states[:, vision_ids] = self.vision_expert_query_key_value(hidden_states[:, vision_ids])\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = \\\n            self.language_expert_query_key_value.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        if lang_ids is None and vision_ids is None:\n            attn_output = self.language_expert_dense(attn_output)\n        else:\n            new_attn_output = torch.empty_like(hidden_states)\n            if lang_ids is not None:\n                new_attn_output[:, lang_ids] = self.language_expert_dense(attn_output[:, lang_ids])\n            if vision_ids is not None:\n                new_attn_output[:, vision_ids] = self.vision_expert_dense(attn_output[:, vision_ids])\n            attn_output = new_attn_output\n\n        if self.all_reduce:\n            dist.all_reduce(attn_output)\n        return attn_output\n\n\nclass MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n  
      super().__init__()\n        assert config.hidden_act == 'silu'\n\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(config.intermediate_size,\n                                              config.hidden_size,\n                                              bias=False,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True,\n                                              all_reduce=False)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass VisionExpertMLP(nn.Module):\n    \"\"\"Vision expert mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.language_mlp = MLP(config, dtype=dtype, device=device)\n        self.vision_mlp = MLP(config, dtype=dtype, device=device)\n        world_size, _ = get_tp_world_rank()\n        self.all_reduce = world_size > 1\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        lang_ids: torch.LongTensor = None,\n        vision_ids: torch.LongTensor = None,\n    ):\n        \"\"\"forward.\"\"\"\n        if lang_ids is None and vision_ids is None:\n            output = 
self.language_mlp(hidden_states)\n        else:\n            output = torch.empty_like(hidden_states)\n            if lang_ids is not None:\n                output[:, lang_ids] = self.language_mlp(hidden_states[:, lang_ids])\n            if vision_ids is not None:\n                output[:, vision_ids] = self.vision_mlp(hidden_states[:, vision_ids])\n        if self.all_reduce:\n            dist.all_reduce(output)\n        return output\n\n\nclass CogVLMDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = VisionExpertAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = VisionExpertMLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n   
     residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n        lang_ids: torch.LongTensor = None,\n        vision_ids: torch.LongTensor = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n            lang_ids=lang_ids,\n            vision_ids=vision_ids,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(\n            hidden_states,\n            lang_ids=lang_ids,\n            vision_ids=vision_ids,\n        )\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass PatchEmbedding(nn.Module):\n    \"\"\"Vision embedding.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.proj = nn.Conv2d(config.in_channels,\n                              config.hidden_size,\n                              kernel_size=config.patch_size,\n                              stride=config.patch_size,\n                              dtype=dtype,\n                              device=device)\n        self.cls_embedding = nn.Parameter(torch.empty(1, config.hidden_size, dtype=dtype, device=device))\n        self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size, dtype=dtype, device=device)\n\n    def forward(self, images):\n        \"\"\"forward.\"\"\"\n        x = self.proj(images)\n        x = x.flatten(2).transpose(1, 2)\n        cls_token = 
self.cls_embedding.expand(x.shape[0], -1, -1)\n        x = torch.cat((cls_token, x), dim=1)\n        x += self.position_embedding.weight.unsqueeze(0)\n        return x\n\n\nclass EVA2CLIPAttention(nn.Module):\n    \"\"\"Vision attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        hidden_size = config.hidden_size\n        num_heads = config.num_heads\n        head_dim = config.hidden_size // config.num_heads\n        self.scale = head_dim**-0.5\n\n        # packed qkv\n        self.query_key_value = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_heads,\n            head_size=head_dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # o_proj\n        self.dense = build_rowwise_linear(hidden_size,\n                                          hidden_size,\n                                          bias=True,\n                                          quant_config=quantization_config,\n                                          dtype=dtype,\n                                          device=device,\n                                          is_tp=True)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        # qkv proj\n        qkv_states = self.query_key_value(hidden_states)\n        q, k, v = self.query_key_value.split_qkv(qkv_states)\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)\n\n        # o proj\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.flatten(-2, -1)\n        attn_output = self.dense(attn_output)\n      
  return attn_output\n\n\nclass EVA2CLIPMLP(nn.Module):\n    \"\"\"Vision MLP.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        from transformers.activations import ACT2FN\n\n        # gate up\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.fc1 = build_colwise_linear(\n            config.hidden_size,\n            config.intermediate_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        if config.hidden_act in ['gelu', 'gelu_fast', 'quick_gelu', 'gelu_python']:\n            self.activation_fn = nn.GELU()\n        else:\n            self.activation_fn = ACT2FN[config.hidden_act]\n\n        # down\n        self.fc2 = build_rowwise_linear(config.intermediate_size,\n                                        config.hidden_size,\n                                        bias=True,\n                                        quant_config=quantization_config,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        x = self.fc1(x)\n        x = self.activation_fn(x)\n        x = self.fc2(x)\n        return x\n\n\nclass EVA2CLIPTransformerLayer(nn.Module):\n    \"\"\"Vision trans layer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=dtype, device=device)\n        self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device)\n        self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device)\n    
    self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,\n                                                     eps=config.layer_norm_eps,\n                                                     dtype=dtype,\n                                                     device=device)\n\n    def forward(self, hidden_states):\n        \"\"\"forward.\"\"\"\n        attention_input = hidden_states\n        attention_output = self.input_layernorm(self.attention(attention_input))\n        hidden_states = attention_input + attention_output\n        mlp_input = hidden_states\n        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))\n        output = mlp_input + mlp_output\n        return output\n\n\nclass EVA2CLIPTransformer(nn.Module):\n    \"\"\"Vision transformer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.layers = nn.ModuleList(\n            [EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) for _ in range(config.num_hidden_layers)])\n\n    def forward(self, hidden_states):\n        \"\"\"forward.\"\"\"\n        for layer_module in self.layers:\n            hidden_states = layer_module(hidden_states)\n        return hidden_states\n\n\nclass GLU(nn.Module):\n    \"\"\"GLU.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 in_features: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False, dtype=dtype, device=device)\n        self.norm1 = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device)\n        self.act1 = nn.GELU()\n        self.act2 = nn.functional.silu\n        self.dense_h_to_4h = nn.Linear(config.hidden_size,\n                                       config.intermediate_size,\n                                       
bias=False,\n                                       dtype=dtype,\n                                       device=device)\n        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=dtype, device=device)\n        self.dense_4h_to_h = nn.Linear(config.intermediate_size,\n                                       config.hidden_size,\n                                       bias=False,\n                                       dtype=dtype,\n                                       device=device)\n\n    def forward(self, x):\n        x = self.linear_proj(x)\n        x = self.act1(self.norm1(x))\n        x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)\n        x = self.dense_4h_to_h(x)\n        return x\n\n\n@vlm_model\nclass EVA2CLIPModel(nn.Module):\n    \"\"\"Vision model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        vision_config = Namespace(**config.vision_config)\n\n        self.patch_embedding = PatchEmbedding(vision_config, dtype=dtype, device=device)\n        self.transformer = EVA2CLIPTransformer(vision_config, dtype=dtype, device=device)\n        self.linear_proj = GLU(config, in_features=vision_config.hidden_size, dtype=dtype, device=device)\n        if vision_config.num_positions == 1226:\n            # cogvlm-chat-hf\n            self.conv = None\n        else:\n            # cogvlm2\n            self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,\n                                  out_channels=vision_config.hidden_size,\n                                  kernel_size=2,\n                                  stride=2,\n                                  dtype=dtype,\n                                  device=device)\n        self.boi = nn.Parameter(torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device))\n        self.eoi = nn.Parameter(torch.empty(1, 1, config.hidden_size, dtype=dtype, 
device=device))\n\n    def forward(self, images):\n        \"\"\"forward.\"\"\"\n        x = self.patch_embedding(images)\n        x = self.transformer(x)\n\n        x = x[:, 1:]\n        # cogvlm2\n        if self.conv is not None:\n            b, s, h = x.shape\n            grid_size = int(s**0.5)\n            x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)\n            x = self.conv(x)\n\n            x = x.flatten(2).transpose(1, 2)\n        x = self.linear_proj(x)\n        boi = self.boi.expand(x.shape[0], -1, -1)\n        eoi = self.eoi.expand(x.shape[0], -1, -1)\n        x = torch.cat((boi, x, eoi), dim=1)\n        return x\n\n\nclass CogVLMModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            CogVLMDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # vision model\n        self.vision = EVA2CLIPModel(config, dtype=dtype, device=device)\n\n        # build rotary embedding\n        emb_type = RopeType.LinearScaling\n        rope_dim = config.hidden_size // config.num_attention_heads\n        rope_max_pos_emb = 2048\n        rope_base = 10000\n        self.rotary_emb = build_rotary_embedding(\n            rope_dim,\n            rope_max_pos_emb,\n 
           rope_base,\n            emb_type=emb_type,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        images: torch.Tensor = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        lang_ids: torch.LongTensor = None,\n        vision_ids: torch.LongTensor = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            if images is not None:\n                images_features = self.vision(images)\n\n            inputs_embeds = self.embed_tokens(input_ids)\n            if vision_ids is not None:\n                inputs_embeds[0, vision_ids] = images_features.flatten(0, 1)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n                lang_ids=lang_ids,\n                vision_ids=vision_ids,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nLANGUAGE_TOKEN_TYPE = 0\nVISION_TOKEN_TYPE = 1\n\n\nclass CogVLMForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin):\n    
\"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # preprocessor\n        self.input_processor = CogVLMInputProcessor(self.config, dtype)\n        # build model\n        self.model = CogVLMModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        images: torch.Tensor = None,\n        inputs_embeds: torch.Tensor = None,\n        lang_ids: torch.LongTensor = None,\n        vision_ids: torch.LongTensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            images=images,\n            inputs_embeds=inputs_embeds,\n            lang_ids=lang_ids,\n            vision_ids=vision_ids,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def 
get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n\n        # position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context)\n        position_ids = context.position_ids\n        lang_ids = None\n        vis_ids = None\n\n        # vision inputs\n        images = None\n        if context.input_multimodals is not None:\n            images = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            images = [data for im_data in images for data in im_data]\n            if len(images) == 0:\n                images = None\n\n        if images is not None:\n            image_token_id = images[0].meta['image_token_id']\n            vis_mask = input_ids[0] == image_token_id\n            images = torch.stack([data.data for data in images])\n\n            # get lang_ids\n            vis_range = torch.arange(0, input_ids.size(-1), device=input_ids.device)\n            vis_ids = vis_range[vis_mask]\n            lang_ids = vis_range[~vis_mask]\n\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            
input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            images=images,\n            inputs_embeds=inputs_embeds,\n            lang_ids=lang_ids,\n            vision_ids=vis_ids,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if '.vision.' 
in name:\n                    continue\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '_expert_query_key_value' in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                elif '.query_key_value' in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n\n    def _get_model_metas(self, context: StepContext):\n        \"\"\"Get model metas.\"\"\"\n        model_metas = context.model_metas\n        if model_metas is None:\n            batch_size = context.q_seqlens.numel()\n            return [dict(num_img_tokens=0)] * batch_size\n        return [dict(num_img_tokens=0) if meta is None else meta for meta in model_metas]\n\n    def update_model_metas(self,\n                           past_key_values: List[List[torch.Tensor]],\n                           inputs_embeds: Optional[torch.Tensor] = None,\n                           context: StepContext = None):\n        \"\"\"Update model meta.\"\"\"\n        model_metas = self._get_model_metas(context)\n        input_multimodals = context.input_multimodals\n        if input_multimodals is None:\n            input_imgs = [[] for _ in model_metas]\n        else:\n  
          input_imgs = []\n            for mm in input_multimodals:\n                if mm is None:\n                    input_imgs.append([])\n                else:\n                    input_imgs.append(mm.get('image', []))\n\n        num_pad = self.input_processor.vision_token_num - 3\n\n        batched_num_img_tokens = []\n        new_model_metas = []\n        for meta, imgs in zip(model_metas, input_imgs):\n            if meta is None:\n                num_img_tokens = 0\n            else:\n                num_img_tokens = meta.get('num_img_tokens', 0)\n\n            batched_num_img_tokens.append(num_img_tokens)\n\n            num_img_tokens += num_pad * len(imgs)\n            new_model_metas.append(dict(num_img_tokens=num_img_tokens))\n\n        # prepare cogvlm position_ids\n        q_seqlens = context.q_seqlens\n        position_ids = context.position_ids\n\n        if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs):\n            num_img_tokens = torch.tensor(batched_num_img_tokens, device=position_ids.device)\n            position_ids -= num_img_tokens[None]\n        else:\n            batched_position_ids = position_ids[0].split(q_seqlens)\n            for pos_ids, num_img_tok, imgs in zip(batched_position_ids, batched_num_img_tokens, input_imgs):\n                pos_ids -= num_img_tok\n                if len(imgs) == 0:\n                    continue\n\n                seq_len = pos_ids.size(0)\n                start = pos_ids[0].cpu().item()\n                new_pos_ids = []\n\n                imgs = sorted(imgs, key=lambda img: img.start)\n                for img in imgs:\n                    img_pad_pos = img.start + 1 - num_img_tok\n                    num_pad = img.end - img.start - 2\n                    new_pos_ids += list(range(start, img_pad_pos))\n                    new_pos_ids += [img_pad_pos] * num_pad\n                    start = img_pad_pos + 1\n                    num_img_tok += num_pad\n\n                remain = 
seq_len - len(new_pos_ids)\n                new_pos_ids += list(range(start, start + remain))\n\n                new_pos_ids = pos_ids.new_tensor(new_pos_ids)\n                pos_ids[:] = new_pos_ids\n\n            position_ids = torch.cat(batched_position_ids)[None]\n        context.position_ids = position_ids\n\n        return new_model_metas\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass CogVLMInputProcessor(BaseModelInputProcessor):\n    \"\"\"Input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n        image_size: int = config.vision_config['image_size']\n        patch_size: int = config.vision_config['patch_size']\n        if config.vision_config['num_positions'] == 1226:\n            # # cogvlm-chat-hf\n            self.vision_token_num = 2 + (image_size // patch_size)**2\n        else:\n            # cogvlm2\n            self.vision_token_num = 2 + (image_size // patch_size // 2)**2\n\n    def preprocess_input(self, input_ids: List[int], input_multimodals=None, **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     
meta=dict(image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/deepseek.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass DeepseekAttention(nn.Module):\n    \"\"\"Rewrite module of MistralAttention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n        )\n\n        # o_proj\n        self.o_proj = 
build_rowwise_linear(num_heads * head_dim,\n                                           hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass DeepseekMoE(nn.Module):\n    \"\"\"Deepseek MoE.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n    
    super().__init__()\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.n_routed_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n        self.renormalize = self.top_k > 1 and self.norm_topk_prob\n\n        self.gate = build_rowwise_linear(\n            self.hidden_dim,\n            self.num_experts,\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n\n        self.softmax_topk = SoftmaxTopK(\n            self.top_k,\n            n_groups=getattr(config, 'router_n_groups', -1),\n        )\n\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=self.renormalize,\n            dtype=dtype,\n            device=device,\n            all_reduce=False,\n        )\n\n        self.shared_experts = None\n        if config.n_shared_experts is not None:\n            intermediate_size = (config.moe_intermediate_size * config.n_shared_experts)\n            self.shared_experts = DeepseekMLP(\n                config=config,\n                intermediate_size=intermediate_size,\n                dtype=dtype,\n                device=device,\n                is_tp=True,\n                all_reduce=False,\n            )\n        world_size, _ = get_tp_world_rank()\n        if world_size > 1:\n            self._all_reduce = True\n        else:\n            self._all_reduce = False\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        router_logits = self.gate(hidden_states)\n\n        topk_weights, topk_ids = self.softmax_topk(router_logits)\n        out_states = self.experts(\n     
       hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        if self.shared_experts is not None:\n            shared_states = self.shared_experts(hidden_states)\n            out_states += shared_states\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n\n        if self._all_reduce:\n            dist.all_reduce(out_states)\n\n        return out_states\n\n\nclass DeepseekMLP(nn.Module):\n    \"\"\"Deepseek mlp.\"\"\"\n\n    def __init__(self,\n                 config: Any,\n                 intermediate_size: int = None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True,\n                 all_reduce: bool = True):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=is_tp,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(\n            intermediate_size,\n            config.hidden_size,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=is_tp,\n            all_reduce=all_reduce,\n        )\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass DeepseekDecoderLayer(nn.Module):\n    \"\"\"Llama decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: 
PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = DeepseekAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = (DeepseekMoE(config, dtype=dtype, device=device) if\n                    (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace\n                     and layer_idx % config.moe_layer_freq == 0) else DeepseekMLP(config, dtype=dtype, device=device))\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        
)\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass DeepseekModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            DeepseekDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # 
decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass DeepseekForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = DeepseekModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, 
return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n        for (param_name, 
weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        num_experts = self.config.n_routed_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            if '.experts' in name:\n                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) 
in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/deepseek_mtp.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding,\n                                 build_rotary_params)\nfrom lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.nn.moe import build_fused_moe\nfrom lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\nfrom lmdeploy.utils import get_logger\n\nfrom .deepseek_v2 import DeepseekV2Attention, DeepseekV2DecoderLayer, MoEGate, yarn_get_mscale\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\n\nlogger = get_logger('lmdeploy')\n\n\nclass DeepseekV2BMM(nn.Module):\n    \"\"\"Wrapped bmm.\"\"\"\n\n    def __init__(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device):\n        super().__init__()\n\n        weight = self.create_weight(batch, in_features, out_features, dtype=dtype, device=device)\n        weight = torch.nn.Parameter(weight, requires_grad=False)\n        self.register_parameter('weight', weight)\n        weight.weight_loader = self.weight_loader\n\n        self.batch = batch\n        self.in_features = in_features\n        self.out_features = out_features\n        self.dtype = dtype\n        self.device = device\n\n    def create_weight(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device):\n        \"\"\"Create weight.\"\"\"\n        return torch.empty((batch, in_features, out_features), dtype=dtype, 
device=device)\n\n    def weight_loader(self, param: nn.Parameter, weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        param.data.copy_(weight)\n\n    def forward(self, x: torch.Tensor, output: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        torch.bmm(x.transpose(0, 1), self.weight, out=output.transpose(0, 1))\n\n\nclass DeepseekV2Attention(DeepseekV2Attention):\n    \"\"\"Deepseekv2 attention.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        nn.Module.__init__(self)\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.q_lora_rank = config.q_lora_rank\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        num_key_value_heads = getattr(config, 'num_key_value_heads', 1)\n        use_flash_mla = getattr(config, 'use_flash_mla', False)\n\n        if self.q_lora_rank is None:\n            self.q_proj = build_colwise_linear(\n                self.hidden_size,\n                self.num_heads * self.q_head_dim,\n                bias=False,\n                dtype=dtype,\n                device=device,\n                is_tp=False,\n                quant_config=quantization_config,\n                dp_disable_tp=True,\n            )\n        else:\n            self.q_a_proj = build_colwise_linear(\n                self.hidden_size,\n                config.q_lora_rank,\n                bias=config.attention_bias,\n                dtype=dtype,\n                device=device,\n                is_tp=False,\n                
quant_config=quantization_config,\n            )\n            self.q_a_layernorm = RMSNorm(config.q_lora_rank,\n                                         1e-6,\n                                         quant_config=quantization_config,\n                                         dtype=dtype,\n                                         device=device)\n            self.q_b_proj = build_colwise_linear(\n                config.q_lora_rank,\n                self.num_heads * self.q_head_dim,\n                bias=False,\n                dtype=dtype,\n                device=device,\n                is_tp=False,\n                quant_config=quantization_config,\n                dp_disable_tp=True,\n            )\n\n        self.kv_a_proj_with_mqa = build_colwise_linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n            quant_config=quantization_config,\n        )\n        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,\n                                      1e-6,\n                                      quant_config=quantization_config,\n                                      dtype=dtype,\n                                      device=device)\n        self.kc = DeepseekV2BMM(self.num_heads,\n                                config.qk_nope_head_dim,\n                                config.kv_lora_rank,\n                                dtype=dtype,\n                                device=device)\n\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        self.softmax_scale = self.q_head_dim**(-0.5)\n\n        rope_scaling = get_rope_parameters(config)\n        if rope_scaling is not None:\n            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)\n            scaling_factor = rope_scaling.get('factor', 1.0)\n            if mscale_all_dim:\n                mscale = 
yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n\n        self.attn_fwd = Attention(self.num_heads,\n                                  config.kv_lora_rank + self.qk_rope_head_dim,\n                                  scale=self.softmax_scale,\n                                  num_kv_heads=num_key_value_heads,\n                                  v_head_size=config.kv_lora_rank,\n                                  num_replicate_kv_heads=num_replicate_kv_heads,\n                                  use_flash_mla=use_flash_mla)\n\n        self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device)\n        self.o_proj = build_o_proj(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n            quant_config=quantization_config,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        num_heads = self.num_heads\n        nope_size = self.kv_lora_rank\n        q_len = hidden_states.size(1)\n\n        # qkv_proj\n        query_states, key_states, value_states, q_pe, k_pe = self._qkv_proj(hidden_states, num_heads=num_heads)\n\n        cos, sin = rotary_pos_emb\n        q_pe, k_pe = self.apply_rotary_pos_emb(\n            q_pe,\n            k_pe,\n            cos,\n            sin,\n            inplace=False,\n        )\n        query_states[..., nope_size:] = q_pe\n        key_states[..., nope_size:] = k_pe\n\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            
past_key_value[0],\n            past_key_value[0][..., :nope_size],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim)\n\n        self.vc(attn_output, attn_bmm_out)\n        attn_output = attn_bmm_out.flatten(-2, -1)[None]\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass DeepseekV2MoE(nn.Module):\n    \"\"\"Deepseek v2 MoE.\"\"\"\n\n    def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.n_routed_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.renormalize = self.top_k > 1 and self.norm_topk_prob\n        self.topk_method = config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        self.gate = MoEGate(config, dtype=dtype, device=device, info=None)\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=False,\n            dtype=dtype,\n            device=device,\n            all_reduce=False,\n            quant_config=quantization_config,\n            layer_idx=layer_idx,\n        )\n        self.shared_experts = None\n        if config.n_shared_experts is not None:\n            intermediate_size = (config.moe_intermediate_size * config.n_shared_experts)\n          
  self.shared_experts = DeepseekV2MLP(\n                config=config,\n                intermediate_size=intermediate_size,\n                dtype=dtype,\n                device=device,\n            )\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        topk_weights, topk_ids = self.gate(hidden_states)\n\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        if self.shared_experts is not None:\n            shared_states = self.shared_experts(hidden_states)\n            out_states += shared_states\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n\n        return out_states\n\n\nclass DeepseekV2MLP(nn.Module):\n    \"\"\"Deepseek v2 mlp.\"\"\"\n\n    def __init__(self,\n                 config: Any,\n                 intermediate_size: int = None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=False,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(\n            intermediate_size,\n            config.hidden_size,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            
is_tp=False,\n            all_reduce=False,\n        )\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass DeepseekV2DecoderLayer(DeepseekV2DecoderLayer):\n    \"\"\"Deepseekv2 decoder layer.\"\"\"\n\n    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):\n        nn.Module.__init__(self)\n        self.layer_idx = layer_idx\n        quantization_config = None\n\n        # build attention layer\n        self.self_attn = DeepseekV2Attention(config, dtype=dtype, device=device)\n\n        # mlp\n        self.mlp = (DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device) if\n                    (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace\n                     and layer_idx % config.moe_layer_freq == 0) else DeepseekV2MLP(config, dtype=dtype, device=device))\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n\n# modify from vllm\n\n\nclass SharedHead(nn.Module):\n    \"\"\"Deepseekv2 shared head.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n        # build lm_head\n        self.head = build_rowwise_linear(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor:\n        return self.norm(hidden_states)\n\n\ndef build_deepseek_rotary_embedding(config: PretrainedConfig):\n    \"\"\"Build deepseek rotary embedding.\"\"\"\n    emb_type = RopeType.LinearScaling\n    rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size //\n                                                                                 config.num_attention_heads)\n    rope_max_pos_emb = config.max_position_embeddings\n    rope_base = get_rope_theta(config)\n\n    rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)\n    update_params = build_rotary_params(config)\n    rope_params.update(update_params)\n    return build_rotary_embedding(**rope_params)\n\n\nclass DeepSeekMultiTokenPredictorLayer(nn.Module):\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        layer_idx: int,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        decoder_layer_cls=DeepseekV2DecoderLayer,\n        build_rotary_embedding_func=build_deepseek_rotary_embedding,\n    ) -> None:\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n        self.eh_proj = build_colwise_linear(\n            config.hidden_size * 2,\n            
config.hidden_size,\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n            quant_config=quantization_config,\n            dp_disable_tp=True,\n        )\n\n        self.shared_head = SharedHead(config=config, dtype=dtype, device=device)\n\n        self.mtp_block = decoder_layer_cls(config, layer_idx=layer_idx, dtype=dtype, device=device)\n\n        self.rotary_emb = build_rotary_embedding_func(config)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        previous_hidden_states: torch.Tensor,\n        past_key_value: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n        spec_step_index: int = 0,\n    ) -> torch.Tensor:\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n        assert inputs_embeds is not None\n\n        # masking inputs at position 0, as not needed by MTP\n        inputs_embeds[position_ids == 0] = 0\n        inputs_embeds = self.enorm(inputs_embeds)\n        previous_hidden_states = self.hnorm(previous_hidden_states)\n\n        hidden_states = self.eh_proj(torch.cat([inputs_embeds, previous_hidden_states], dim=-1))\n\n        # rotary emb\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        hidden_states, residual = self.mtp_block(\n            hidden_states,\n            rotary_pos_emb,\n            past_key_value,\n            attn_metadata=attn_metadata,\n        )\n        hidden_states = residual + hidden_states\n        return hidden_states\n\n\nclass DeepSeekMultiTokenPredictor(nn.Module):\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        decoder_layer_cls=DeepseekV2DecoderLayer,\n        
build_rotary_embedding_func=build_deepseek_rotary_embedding,\n    ):\n        super().__init__()\n        self.config = config\n        self.mtp_start_layer_idx = config.num_hidden_layers\n        self.num_mtp_layers = config.num_nextn_predict_layers\n        # to map the exact layer index from weights\n        self.layers = torch.nn.ModuleDict({\n            str(idx):\n            DeepSeekMultiTokenPredictorLayer(\n                config,\n                idx,\n                dtype=dtype,\n                device=device,\n                decoder_layer_cls=decoder_layer_cls,\n                build_rotary_embedding_func=build_rotary_embedding_func,\n            )\n            for idx in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers)\n        })\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        previous_hidden_states: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n        spec_step_idx: int = 0,\n    ) -> torch.Tensor:\n        current_step_idx = (spec_step_idx % self.num_mtp_layers)\n        layer_idx = self.mtp_start_layer_idx + current_step_idx\n        past_key_value = past_key_values[current_step_idx]\n        return self.layers[str(layer_idx)](\n            input_ids,\n            position_ids,\n            previous_hidden_states,\n            past_key_value,\n            inputs_embeds=inputs_embeds,\n            attn_metadata=attn_metadata,\n            spec_step_index=current_step_idx,\n        )\n\n    def get_logits(\n        self,\n        hidden_states: torch.Tensor,\n        spec_step_idx: int = 0,\n    ) -> torch.Tensor:\n        current_step_idx = (spec_step_idx % self.num_mtp_layers)\n        mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]\n\n        hidden_states = mtp_layer.shared_head(hidden_states)\n        logits = 
mtp_layer.shared_head.head(hidden_states)\n        return logits\n\n\nclass DeepseekMTPModel(nn.Module, CudaGraphMixin):\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        ctx_mgr: StepContextManager,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        decoder_layer_cls=DeepseekV2DecoderLayer,\n        build_rotary_embedding_func=build_deepseek_rotary_embedding,\n    ):\n        super().__init__()\n        self.config = config\n        self.quantization_config = getattr(config, 'quantization_config', None)\n        self.dtype = dtype\n        self.ctx_mgr = ctx_mgr\n        self.model = DeepSeekMultiTokenPredictor(config,\n                                                 dtype=dtype,\n                                                 device=device,\n                                                 decoder_layer_cls=decoder_layer_cls,\n                                                 build_rotary_embedding_func=build_rotary_embedding_func)\n\n        self._load_buffers = dict()\n\n    def get_logits(self, hidden_states: torch.Tensor, spec_step_idx: int = 0):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.model.get_logits(hidden_states, spec_step_idx=spec_step_idx)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        target_hidden_states: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        spec_step_idx: int = 0,\n    ) -> torch.Tensor:\n        hidden_states = self.model(input_ids,\n                                   position_ids,\n                                   target_hidden_states,\n                                   inputs_embeds=inputs_embeds,\n                                   past_key_values=past_key_values,\n                                   attn_metadata=attn_metadata,\n                      
             spec_step_idx=spec_step_idx)\n        return hidden_states\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n        max_tokens = graph_meta.max_tokens\n\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        input_buffers['target_hidden_states'] = input_buffers['input_ids'].new_zeros(1,\n                                                                                     max_tokens,\n                                                                                     self.config.hidden_size,\n                                                                                     dtype=self.dtype)\n\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: torch.Tensor, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n\n        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, input_ids=input_ids, **kwargs)\n\n        num_tokens = input_ids.size(-1)\n        input_buffers = graph_meta.input_buffers\n        target_hidden_states = kwargs.get('target_hidden_states')\n        assert target_hidden_states is not None\n        input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states\n        new_inputs['target_hidden_states'] = input_buffers['target_hidden_states']\n        return new_inputs\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n        target_hidden_states = context.target_hidden_states\n        return dict(\n            input_ids=input_ids,\n            
position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            target_hidden_states=target_hidden_states,\n        )\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                               update_pe_mapping: List):\n        \"\"\"Load weight attention.\"\"\"\n        device = next(iter(params_dict.values())).device\n\n        def __update_pe(weight, head_dim: int, pe_dim_offset: int):\n            # (num_heads, head_dim, input_dim)\n            weight = weight.unflatten(0, (-1, head_dim))\n            # (num_heads, rope_head_dim, input_dim)\n            w_pe = weight[:, pe_dim_offset:]\n            # (num_heads, rope_head_dim//2, 2, input_dim)\n            new_w_pe = w_pe.unflatten(1, (-1, 2))\n            # (num_heads, rope_head_dim, input_dim)\n            new_w_pe = new_w_pe.transpose(1, 2).flatten(1, 2)\n            weight[:, pe_dim_offset:] = new_w_pe\n            weight = weight.flatten(0, 1)\n            return weight\n\n        def __load_kcvc(name: str, weight: torch.Tensor):\n            \"\"\"Load kc and vc from weight.\"\"\"\n            config = self.config\n            v_head_dim = config.v_head_dim\n      
      qk_nope_head_dim = config.qk_nope_head_dim\n            w_kc, w_vc = weight.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split([qk_nope_head_dim, v_head_dim],\n                                                                                        dim=1)\n            w_vc = w_vc.transpose(1, 2).contiguous()\n            kc_param_name = name.replace('.kv_b_proj', '.kc')\n            param_kc = params_dict[kc_param_name]\n            load_weight(param_kc, w_kc)\n            vc_param_name = name.replace('.kv_b_proj', '.vc')\n            param_vc = params_dict[vc_param_name]\n            load_weight(param_vc, w_vc)\n\n        def __dequant_weight(weight: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):\n            \"\"\"Dequant weight.\"\"\"\n            dim_w0, dim_w1 = weight.shape\n            dim_s0, dim_s1 = scale.shape\n            assert dim_w0 % dim_s0 == 0\n            assert dim_w1 % dim_s1 == 0\n            group0 = dim_w0 // dim_s0\n            group1 = dim_w1 // dim_s1\n            weight = weight.reshape(dim_s0, group0, dim_s1, group1)\n            scale = scale.reshape(dim_s0, 1, dim_s1, 1)\n            weight = weight.to(scale.dtype) * scale\n            weight = weight.to(dtype)\n            weight = weight.reshape(dim_w0, dim_w1)\n            return weight\n\n        def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):\n            \"\"\"Buffer a blocked fp8 weight and its scale, then dequant and load kc and vc once both are present.\"\"\"\n            if name.endswith('.weight'):\n                weight_name = name\n                scale_name = name.replace('.weight', '.weight_scale_inv')\n            elif name.endswith('.weight_scale_inv'):\n                weight_name = name.replace('.weight_scale_inv', '.weight')\n                scale_name = name\n            self._load_buffers[name] = loaded_weight\n            if (weight_name in self._load_buffers and scale_name in self._load_buffers):\n                weight = self._load_buffers.pop(weight_name)\n                scale = 
self._load_buffers.pop(scale_name)\n                kc_param_name = weight_name.replace('.kv_b_proj', '.kc')\n                dtype = params_dict[kc_param_name].dtype\n                weight = __dequant_weight(weight, scale, dtype)\n                __load_kcvc(weight_name, weight)\n\n        for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping:\n            if mod_name not in name:\n                continue\n            if name.endswith('.weight_scale_inv'):\n                weight = loaded_weight\n            else:\n                loaded_weight = loaded_weight.to(device)\n                weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)\n            param = params_dict[name]\n            load_weight(param, weight)\n            break\n        else:\n            if '.kv_b_proj' in name:\n                quantization_config = self.quantization_config\n                quant_method = None\n                if quantization_config is not None:\n                    quant_method = quantization_config.get('quant_method')\n\n                loaded_weight = loaded_weight.to(device)\n                if quant_method == 'fp8':\n                    # update blocked fp8 weight\n                    __load_kcvc_blocked_fp8(name, loaded_weight)\n                else:\n                    __load_kcvc(name, loaded_weight)\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        def __skip_nextn(name, nextn_keys):\n            for nextn_key in nextn_keys:\n                if nextn_key in name:\n                    return True\n            return False\n\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        config = self.config\n\n        
qk_rope_head_dim = config.qk_rope_head_dim\n        kv_lora_rank = config.kv_lora_rank\n        qk_nope_head_dim = config.qk_nope_head_dim\n        q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n        kv_dim = kv_lora_rank + qk_rope_head_dim\n        update_pe_mapping = [('q_proj', q_head_dim, qk_nope_head_dim), ('q_b_proj', q_head_dim, qk_nope_head_dim),\n                             ('kv_a_proj_with_mqa', kv_dim, kv_lora_rank)]\n\n        num_experts = self.config.n_routed_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        num_hidden_layers = self.config.num_hidden_layers\n\n        num_nextn_predict_layers = getattr(self.config, 'num_nextn_predict_layers', 1)\n        nextn_keys = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            # keep nextn\n            if not __skip_nextn(name, nextn_keys):\n                continue\n            if '.layers' in name:\n                layer_idx = int(name.split('layers.')[1].split('.')[0])\n                name = self._rewrite_spec_layer_name(layer_idx, name)\n            if '.experts' in name:\n                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            elif '.self_attn' in name and getattr(config, 'use_mla', True):\n                # attention\n                self._load_weight_attention(name, loaded_weight, params_dict, update_pe_mapping)\n            else:\n                # other\n                for 
(param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n\n    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:\n        \"\"\"Rewrite the weight name to match the format of the original model.\n\n        Add .mtp_block for modules in transformer layer block for spec layer\n        \"\"\"\n        spec_layer_weight_names = ['embed_tokens', 'enorm', 'hnorm', 'eh_proj', 'shared_head']\n        spec_layer_weight = False\n        for weight_name in spec_layer_weight_names:\n            if weight_name in name:\n                spec_layer_weight = True\n                break\n        if not spec_layer_weight:\n            # treat rest weights as weights for transformer layer block\n            name = name.replace(f'model.layers.{spec_layer}.', f'model.layers.{spec_layer}.mtp_block.')\n        return name\n"
  },
  {
    "path": "lmdeploy/pytorch/models/deepseek_v2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport math\nfrom copy import deepcopy\nfrom enum import Enum, auto\nfrom os import getenv\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager, get_step_ctx_manager\nfrom lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, ParallelEmbedding, RMSNorm, RopeType, SiluAndMul,\n                                 build_rotary_embedding, build_rotary_params)\nfrom lmdeploy.pytorch.nn.eplb import EPLBDispatchInfo, EPLBManager\nfrom lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.nn.moe import MoeType, SoftmaxTopK, build_fused_moe\nfrom lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\n# microbatch\nclass ExecType(Enum):\n    \"\"\"Batch exec type.\"\"\"\n    One = auto()\n    Two0101 = auto()\n    Two0110 = auto()\n    TwoLikeOne = auto()\n    TwoPrefill = auto()\n    TwoDecode = auto()\n\n\nclass BatchWorker:\n\n    def __init__(self, tag: str, generator):\n        self._tag = tag\n        self._generator = generator\n        self._count = 0\n        self.output = None\n\n    def next(self):\n        assert not self.done\n\n        try:\n            next(self._generator)\n        except StopIteration as e:\n            assert e.value is not None\n            self.output = e.value\n\n        self._count += 1\n\n    @property\n    def done(self):\n        return self.output is not None\n\n\ndef 
execute_batch(inputs: list, fn, delta_stages: int = 0, exec_type: ExecType = ExecType.One, extern_tag: str = ''):\n    worker_list = [BatchWorker(str(idx), fn(**input, tag=str(idx) + extern_tag)) for idx, input in enumerate(inputs)]\n\n    if exec_type == ExecType.One:\n        assert len(inputs) == 1\n        i = 0\n        while not worker_list[0].done:\n            worker_list[0].next()\n            i += 1\n\n    if exec_type == ExecType.TwoLikeOne:\n        assert len(inputs) == 2\n        i = 0\n        while not worker_list[0].done:\n            worker_list[0].next()\n            i += 1\n        i = 0\n        while not worker_list[1].done:\n            worker_list[1].next()\n            i += 1\n\n    if exec_type == ExecType.Two0101:\n        assert len(inputs) == 2\n\n        for _ in range(delta_stages):\n            worker_list[0].next()\n        i = 0\n        while not worker_list[0].done:\n            worker_list[0].next()\n            worker_list[1].next()\n            i += 1\n\n        while not worker_list[1].done:\n            worker_list[1].next()\n\n    if exec_type == ExecType.Two0110:\n        assert len(inputs) == 2\n\n        for _ in range(delta_stages):\n            worker_list[0].next()\n        i = 0\n        while not worker_list[0].done:\n            if i % 2 == 0:\n                worker_list[0].next()\n                worker_list[1].next()\n            else:\n                worker_list[1].next()\n                worker_list[0].next()\n            i += 1\n\n        while not worker_list[1].done:\n            worker_list[1].next()\n\n    if exec_type == ExecType.TwoPrefill:\n        \"\"\"\n        before:\n        A-attn0->A-attn1\n        roll:\n        B-attn0->B-attn1->A-dis->A-dis_wait->A-moe->B-dis->B-dis_wait->A-comb->\n        B-moe->(A-share->A-comb_wait)->B-comb->A-attn0->A-attn1->(B-share->B-comb_wait)\n        after:\n        B-dis_wait->B-moe->B-comb->B-comb_wait and end\n        \"\"\"\n        assert len(inputs) == 2 and 
delta_stages in [0, 2]\n\n        for _ in range(2):\n            worker_list[0].next()\n\n        pipeline = [\n            '1-attn0', '1-attn1', '0-dis', '0-dis_wait', '0-moe', '1-dis', '1-dis_wait', '0-comb', '1-moe',\n            '0-share+0-comb_wait', '1-comb', '0-attn0', '0-attn1', '1-share+1-comb_wait'\n        ]\n        pipeline_length = len(pipeline)\n        i = 0\n        while not worker_list[0].done:\n            worker_list[int(pipeline[i % pipeline_length][0])].next()\n            i += 1\n\n        while not worker_list[1].done:\n            worker_list[1].next()\n\n    if exec_type == ExecType.TwoDecode:\n        \"\"\"\n        before:\n        A-attn0->A-attn1->(A-dis->A-share)\n        roll:\n        B-attn0->A-dis_wait->A-moe->A-comb->B-attn1->A-comb_wait->(B-dis->B-share)->\n        A-attn0->B-dis_wait->B-moe->B-comb->A-attn1->B-comb_wait->(A-dis->A-share)\n        after:\n        B-dis_wait->B-moe->B-comb->B-comb_wait and end\n        \"\"\"\n        assert len(inputs) == 2 and delta_stages in [0, 3]\n\n        for _ in range(3):\n            worker_list[0].next()\n\n        pipeline = [\n            '1-attn0', '0-dis_wait', '0-moe', '0-comb', '1-attn1', '0-comb_wait', '1-dis+1-share', '0-attn0',\n            '1-dis_wait', '1-moe', '1-comb', '0-attn1', '1-comb_wait', '0-dis+0-share'\n        ]\n        pipeline_length = len(pipeline)\n        i = 0\n        while not worker_list[0].done:\n            worker_list[int(pipeline[i % pipeline_length][0])].next()\n            i += 1\n\n        while not worker_list[1].done:\n            worker_list[1].next()\n\n    for worker in worker_list:\n        assert worker.done\n    return [worker.output for worker in worker_list]\n\n\ndef get_new_meta(attn_metadata, start_idx: int, end_idx: int):\n    new_attn_metadata = deepcopy(attn_metadata)\n    new_attn_metadata.block_offsets = attn_metadata.block_offsets[start_idx:end_idx, ...]\n    new_attn_metadata.q_start_loc = 
attn_metadata.q_start_loc[start_idx:end_idx] - attn_metadata.q_start_loc[start_idx]\n    new_attn_metadata.kv_start_loc = attn_metadata.kv_start_loc[start_idx:end_idx] - \\\n        attn_metadata.kv_start_loc[start_idx] if attn_metadata.kv_start_loc is not None else None\n    new_attn_metadata.q_seqlens = attn_metadata.q_seqlens[start_idx:end_idx]\n    new_attn_metadata.kv_seqlens = attn_metadata.kv_seqlens[start_idx:end_idx] \\\n        if attn_metadata.kv_seqlens is not None else None\n    new_attn_metadata.kv_flatten_size = sum(new_attn_metadata.kv_seqlens.tolist()) \\\n        if attn_metadata.kv_flatten_size is not None else None\n    # create buffers for flash mla\n    if attn_metadata.num_splits is not None:\n        Attention.update_meta_flashmla(new_attn_metadata,\n                                       get_step_ctx_manager().current_context().model_config.num_attention_heads)\n    return new_attn_metadata\n\n\ndef get_new_rotary_pos_emb(rotary_pos_emb, start_loc, end_loc):\n    new_rotary_pos_emb = (rotary_pos_emb[0][start_loc:end_loc, ...].contiguous(), rotary_pos_emb[1][start_loc:end_loc,\n                                                                                                    ...].contiguous())\n    return new_rotary_pos_emb\n\n\ndef get_new_input(hidden_states, rotary_pos_emb, past_key_values, residual, attn_metadata, start_idx, end_idx,\n                  start_loc, end_loc):\n    new_hidden_states = hidden_states[:, start_loc:end_loc, :].contiguous()\n    new_rotary_pos_emb = get_new_rotary_pos_emb(rotary_pos_emb, start_loc, end_loc)\n    new_past_key_values = past_key_values\n    new_residual = residual[:, start_loc:end_loc, :].contiguous() if residual is not None else None\n    new_attn_metadata = get_new_meta(attn_metadata, start_idx, end_idx)\n    return new_hidden_states, new_rotary_pos_emb, new_past_key_values, new_residual, new_attn_metadata\n\n\ndef get_split_flags(attn_metadata, num=2):\n    \"\"\"Split flags for seqlens and 
startloc, support 2 only.\"\"\"\n    assert num == 2\n    if attn_metadata.is_decoding:\n        batch_size = attn_metadata.q_start_loc.numel()\n        flag_a = {\n            'start_idx': 0,\n            'end_idx': batch_size // 2,\n            'start_loc': 0,\n            'end_loc': batch_size // 2,\n        }\n        flag_b = {\n            'start_idx': batch_size // 2,\n            'end_idx': batch_size,\n            'start_loc': batch_size // 2,\n            'end_loc': batch_size,\n        }\n    else:\n        q_start_loc = attn_metadata.q_start_loc.tolist()\n        q_seqlens = attn_metadata.q_seqlens.tolist()\n        total_len = sum(q_seqlens)\n        min_diff = total_len\n        split_flag = 1\n        for idx in range(1, len(q_seqlens)):\n            diff = abs(sum(q_seqlens[:idx]) - sum(q_seqlens[idx:]))\n            if diff < min_diff:\n                min_diff = diff\n                split_flag = idx\n        flag_a = {\n            'start_idx': 0,\n            'end_idx': split_flag,\n            'start_loc': q_start_loc[0],\n            'end_loc': q_start_loc[split_flag],\n        }\n        flag_b = {\n            'start_idx': split_flag,\n            'end_idx': len(q_seqlens),\n            'start_loc': q_start_loc[split_flag],\n            'end_loc': q_start_loc[-1] + q_seqlens[-1],\n        }\n    return [flag_a, flag_b]\n\n\ndef split_input(hidden_states,\n                rotary_pos_emb,\n                past_key_values,\n                residual,\n                attn_metadata,\n                moe_start_idx,\n                moe_end_idx,\n                num=2):\n    \"\"\"Split input, support 1 or 2 only.\"\"\"\n    # one batch\n    if num == 1:\n        input = {\n            'hidden_states': hidden_states,\n            'rotary_pos_emb': rotary_pos_emb,\n            'past_key_values': past_key_values,\n            'residual': residual,\n            'attn_metadata': attn_metadata,\n            'start_idx': moe_start_idx,\n            
'end_idx': moe_end_idx\n        }\n        extern_tag = 'D' if attn_metadata.is_decoding else 'P'\n        return [input], ExecType.One, 0, extern_tag\n    else:\n        # two batch or more\n        flag_list = get_split_flags(attn_metadata, num=num)\n\n        inputs = []\n        for flag in flag_list:\n            (hidden_states_splited, rotary_pos_emb_splited, past_key_values_splited, residual_splited,\n             attn_metadata_splited) = get_new_input(hidden_states, rotary_pos_emb, past_key_values, residual,\n                                                    attn_metadata, flag['start_idx'], flag['end_idx'],\n                                                    flag['start_loc'], flag['end_loc'])\n            input = {\n                'hidden_states': hidden_states_splited,\n                'rotary_pos_emb': rotary_pos_emb_splited,\n                'past_key_values': past_key_values,\n                'residual': residual_splited,\n                'attn_metadata': attn_metadata_splited,\n                'start_idx': moe_start_idx,\n                'end_idx': moe_end_idx\n            }\n            inputs.append(input)\n\n        if attn_metadata.is_decoding:\n            exec_type = ExecType.TwoDecode\n            delta_stages = 0\n            extern_tag = 'D'\n        else:\n            exec_type = ExecType.TwoPrefill\n            delta_stages = 0\n            extern_tag = 'P'\n\n        return inputs, exec_type, delta_stages, extern_tag\n\n\ndef merge_output(output_list):\n    # one batch\n    if len(output_list) == 1:\n        return output_list[0]\n    # two batch or more\n    hidden_states = torch.concat([output[0] for output in output_list], dim=1)\n    residual = None\n    if output_list[0][1] is not None:\n        residual = torch.concat([output[1] for output in output_list], dim=1)\n    return hidden_states, residual\n\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 
1.0\n\n\nclass DeepseekV2BMM(nn.Module):\n    \"\"\"Wrapped bmm.\"\"\"\n\n    def __init__(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device):\n        super().__init__()\n        batch = self._update_batch(batch)\n\n        weight = self.create_weight(batch, in_features, out_features, dtype=dtype, device=device)\n        weight = torch.nn.Parameter(weight, requires_grad=False)\n        self.register_parameter('weight', weight)\n        weight.weight_loader = self.weight_loader\n\n        self.batch = batch\n        self.in_features = in_features\n        self.out_features = out_features\n        self.dtype = dtype\n        self.device = device\n\n    def _update_batch(self, batch: int):\n        \"\"\"Split the batch (head) dimension across attention TP ranks.\"\"\"\n        world_size, _ = get_tp_world_rank('attn')\n        batch = batch // world_size\n        return batch\n\n    def create_weight(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device):\n        \"\"\"Create weight.\"\"\"\n        return torch.empty((batch, in_features, out_features), dtype=dtype, device=device)\n\n    def weight_loader(self, param: nn.Parameter, weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = get_tp_world_rank('attn')\n        weight = weight.chunk(world_size, 0)[rank]\n        param.data.copy_(weight)\n\n    def forward(self, x: torch.Tensor, output: torch.Tensor):\n        \"\"\"Per-head bmm: x (q_len, num_heads, in) @ weight -> output (q_len, num_heads, out).\"\"\"\n        torch.bmm(x.transpose(0, 1), self.weight, out=output.transpose(0, 1))\n\n\nclass DeepseekV2Attention(nn.Module):\n    \"\"\"Deepseekv2 attention.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.q_lora_rank = config.q_lora_rank\n        self.hidden_size = config.hidden_size\n        self.num_heads = 
config.num_attention_heads\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        num_key_value_heads = getattr(config, 'num_key_value_heads', 1)\n        use_flash_mla = getattr(config, 'use_flash_mla', False)\n\n        if self.q_lora_rank is None:\n            self.q_proj = build_colwise_linear(\n                self.hidden_size,\n                self.num_heads * self.q_head_dim,\n                bias=False,\n                dtype=dtype,\n                device=device,\n                is_tp=True,\n                quant_config=quantization_config,\n                dp_disable_tp=True,\n            )\n        else:\n            self.q_a_proj = build_colwise_linear(\n                self.hidden_size,\n                config.q_lora_rank,\n                bias=config.attention_bias,\n                dtype=dtype,\n                device=device,\n                is_tp=False,\n                quant_config=quantization_config,\n            )\n            self.q_a_layernorm = RMSNorm(config.q_lora_rank,\n                                         1e-6,\n                                         quant_config=quantization_config,\n                                         dtype=dtype,\n                                         device=device)\n            self.q_b_proj = build_colwise_linear(\n                config.q_lora_rank,\n                self.num_heads * self.q_head_dim,\n                bias=False,\n                dtype=dtype,\n                device=device,\n                is_tp=True,\n                quant_config=quantization_config,\n                dp_disable_tp=True,\n            )\n\n        self.kv_a_proj_with_mqa = build_colwise_linear(\n     
       self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n            quant_config=quantization_config,\n        )\n        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,\n                                      1e-6,\n                                      quant_config=quantization_config,\n                                      dtype=dtype,\n                                      device=device)\n        self.kc = DeepseekV2BMM(self.num_heads,\n                                config.qk_nope_head_dim,\n                                config.kv_lora_rank,\n                                dtype=dtype,\n                                device=device)\n\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        self.softmax_scale = self.q_head_dim**(-0.5)\n\n        rope_scaling = get_rope_parameters(config)\n        if rope_scaling is not None:\n            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)\n            scaling_factor = rope_scaling.get('factor', 1.0)\n            if mscale_all_dim:\n                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n\n        self.attn_fwd = Attention(self.num_heads,\n                                  config.kv_lora_rank + self.qk_rope_head_dim,\n                                  scale=self.softmax_scale,\n                                  num_kv_heads=num_key_value_heads,\n                                  v_head_size=config.kv_lora_rank,\n                                  num_replicate_kv_heads=num_replicate_kv_heads,\n                                  use_flash_mla=use_flash_mla)\n\n        self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device)\n        self.o_proj = build_o_proj(\n            self.num_heads * self.v_head_dim,\n   
         self.hidden_size,\n            bias=config.attention_bias,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            quant_config=quantization_config,\n        )\n\n    def _q_proj(self, hidden_states, num_heads: int, nope_size: int, pe_size: int):\n        \"\"\"Q proj.\"\"\"\n        q_len = hidden_states.size(1)\n\n        query_states = hidden_states.new_empty(q_len, num_heads, nope_size + pe_size)\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(q_len, num_heads, self.q_head_dim)\n        # q_pe: (q_len, num_heads, qk_rope_head_dim)\n        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)\n        # q_nope: (q_len, num_heads, kv_lora_rank)\n        q_nope_out = query_states[..., :nope_size]\n        self.kc(q_nope, q_nope_out)\n        return query_states, q_pe\n\n    def _kv_proj(self, hidden_states, nope_size: int):\n        \"\"\"Kv proj.\"\"\"\n        # (q_len, 1, nope_size + pe_size)\n        key_states = self.kv_a_proj_with_mqa(hidden_states[0, :, None])\n        # (q_len, 1, pe_size)\n        k_pe = key_states[..., nope_size:]\n        # kv_a_layernorm\n        value_states = key_states[..., :nope_size]\n        value_states = self.kv_a_layernorm(value_states)\n        key_states[..., :nope_size] = value_states\n        return key_states, value_states, k_pe\n\n    def _qkv_proj(self, hidden_states: torch.Tensor, num_heads: int):\n        \"\"\"Qkv proj.\"\"\"\n        nope_size = self.kv_lora_rank\n        pe_size = self.qk_rope_head_dim\n        query_states, q_pe = self._q_proj(hidden_states, num_heads, nope_size, pe_size)\n        key_states, value_states, k_pe = self._kv_proj(hidden_states, nope_size)\n\n        return query_states, key_states, value_states, q_pe, k_pe\n\n    def forward(\n        self,\n        
hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        dist_config = get_dist_manager().current_config()\n        if dist_config.dp > 1:\n            num_heads = self.num_heads\n        else:\n            world_size = dist_config.world_size\n            num_heads = self.num_heads // world_size\n        nope_size = self.kv_lora_rank\n        q_len = hidden_states.size(1)\n\n        # qkv_proj\n        query_states, key_states, value_states, q_pe, k_pe = self._qkv_proj(hidden_states, num_heads=num_heads)\n\n        cos, sin = rotary_pos_emb\n        q_pe, k_pe = self.apply_rotary_pos_emb(\n            q_pe,\n            k_pe,\n            cos,\n            sin,\n            inplace=False,\n        )\n        query_states[..., nope_size:] = q_pe\n        key_states[..., nope_size:] = k_pe\n\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[0][..., :nope_size],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim)\n\n        self.vc(attn_output, attn_bmm_out)\n        attn_output = attn_bmm_out.flatten(-2, -1)[None]\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass MoEGate(nn.Module):\n    \"\"\"Deepseek Gate.\"\"\"\n\n    def __init__(self,\n                 config: Any,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 info: EPLBDispatchInfo = None):\n        
super().__init__()\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.n_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.scoring_func = config.scoring_func\n        self.topk_method = config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n        self.norm_topk_prob = config.norm_topk_prob\n        self.renormalize = self.top_k > 1 and self.norm_topk_prob\n        self.router_n_groups = getattr(config, 'router_n_groups', -1)\n        assert self.top_k % self.router_n_groups == 0, f'{self.top_k} cannot be divided by {self.router_n_groups}'\n        # topk selection algorithm\n        self.norm_topk_prob = config.norm_topk_prob\n        self.gating_dim = config.hidden_size\n        self.weight = nn.Parameter(\n            torch.empty((self.n_routed_experts, self.gating_dim), dtype=torch.float32, device=device))\n        if self.topk_method == 'noaux_tc':\n            from lmdeploy.pytorch.nn.moe.route import NoauxTCRouter\n            self.e_score_correction_bias = nn.Parameter(\n                torch.empty((self.n_routed_experts, ), dtype=torch.float32, device=device))\n            self.noaux_tc_router = NoauxTCRouter(self.scoring_func,\n                                                 top_k=self.top_k,\n                                                 n_group=self.n_group,\n                                                 topk_group=self.topk_group,\n                                                 n_routed_experts=self.n_routed_experts,\n                                                 routed_scaling_factor=self.routed_scaling_factor,\n                                                 renormalize=self.renormalize,\n                                                 router_n_groups=self.router_n_groups)\n        self.softmax_topk = SoftmaxTopK(self.top_k, n_groups=self.router_n_groups)\n        
self.fake_eplb = getenv('LMDEPLOY_FAKE_EPLB', 'False').lower() == 'true'\n        self.eplb_dispatch_info = info\n\n    def _compute_scores(self, logits: torch.Tensor):\n        \"\"\"Compute scores.\"\"\"\n        if self.scoring_func == 'softmax':\n            scores = logits.softmax(dim=-1, dtype=torch.float32)\n        elif self.scoring_func == 'sigmoid':\n            scores = logits.sigmoid()\n        else:\n            raise NotImplementedError('unsupported scoring function '\n                                      f'for MoE gating: {self.scoring_func}')\n        return scores\n\n    def _postprocess_topk_weight(self, topk_weight: torch.Tensor):\n        if self.renormalize:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n            if not topk_weight.is_contiguous():\n                topk_weight = topk_weight.contiguous()\n        if not self.renormalize:\n            topk_weight = topk_weight * self.routed_scaling_factor\n        return topk_weight\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight)\n        if self.fake_eplb:\n            # Forcefully manipulate router_logits to simulate expert load balancing (EPLB).\n            # This is a benchmark-only hack to achieve optimal performance metrics.\n            router_logits = torch.randn_like(router_logits)\n\n        if self.topk_method == 'greedy':\n            topk_weight, topk_idx = self.softmax_topk(router_logits)\n\n            topk_weight = self._postprocess_topk_weight(topk_weight)\n        elif self.topk_method == 'group_limited_greedy':\n            scores = router_logits\n            grouped_logits = scores.unflatten(-1, (self.n_group, -1))\n            group_scores = (grouped_logits.max(-1).values)\n            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, 
top_k_group]\n            group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n            group_mask = ~group_mask.bool()[..., None]\n            grouped_logits = grouped_logits.masked_fill(group_mask, 0.0)\n            scores = grouped_logits.flatten(1, 2)\n            topk_weight, topk_idx = self.softmax_topk(scores)\n\n            topk_weight = self._postprocess_topk_weight(topk_weight)\n        elif self.topk_method == 'noaux_tc':\n            topk_weight, topk_idx = self.noaux_tc_router(router_logits, self.e_score_correction_bias)\n        else:\n            raise RuntimeError(f'Unsupported topk_method: {self.topk_method}')\n\n        if self.eplb_dispatch_info is not None:\n            topk_idx = EPLBManager.topk_ids_logical_to_physical(topk_idx, self.eplb_dispatch_info)\n\n        return topk_weight, topk_idx\n\n\nclass DeepseekV2MoE(nn.Module):\n    \"\"\"Deepseek v2 MoE.\"\"\"\n\n    def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.n_routed_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.renormalize = self.top_k > 1 and self.norm_topk_prob\n        self.topk_method = config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        dist_ctx = get_dist_manager().current_context()\n        dist_config = dist_ctx.dist_config\n        dp = dist_config.dp\n        world_size = dist_config.world_size\n        moe_all_reduce = dp > 1 and dist_config.tp > 1\n        if 
get_dist_manager().current_context().dist_config.enable_eplb:\n            eplb_dispatch_info = EPLBManager.get_dispatch_info(\n                ep_rank=dist_ctx.ep_rank,\n                layer_idx=layer_idx,\n            )\n            self.num_experts = EPLBManager.num_physical_experts()\n            self.gate = MoEGate(config, dtype=dtype, device=device, info=eplb_dispatch_info)\n        else:\n            self.gate = MoEGate(config, dtype=dtype, device=device, info=None)\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=False,\n            dtype=dtype,\n            device=device,\n            all_reduce=moe_all_reduce,\n            quant_config=quantization_config,\n            layer_idx=layer_idx,\n        )\n        self.shared_experts = None\n        if config.n_shared_experts is not None:\n            intermediate_size = (config.moe_intermediate_size * config.n_shared_experts)\n            self.shared_experts = DeepseekV2MLP(\n                config=config,\n                intermediate_size=intermediate_size,\n                dtype=dtype,\n                device=device,\n                is_shared_expert=True,\n            )\n\n        if dp == 1 and world_size > 1:\n            self._all_reduce = True\n        else:\n            self._all_reduce = False\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        topk_weights, topk_ids = self.gate(hidden_states)\n\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        if self.shared_experts is not None:\n            shared_states = self.shared_experts(hidden_states)\n            out_states += shared_states\n        out_states = 
out_states.reshape(batch_size, sequence_length, -1)\n\n        if self._all_reduce:\n            dist.all_reduce(out_states)\n\n        return out_states\n\n\nclass DeepseekV2MLP(nn.Module):\n    \"\"\"Deepseek v2 mlp.\"\"\"\n\n    def __init__(self,\n                 config: Any,\n                 intermediate_size: int = None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_shared_expert: bool = False):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        if is_shared_expert:\n            dist_config = get_dist_manager().current_config()\n            dp = dist_config.dp\n            if dp == 1:\n                # split weight, do all reduce in moe\n                is_tp = True\n                all_reduce = False\n            else:\n                # do not split weight on dp\n                # TODO: support dp+tp?\n                is_tp = False\n                all_reduce = False\n        else:\n            all_reduce = True\n            is_tp = True\n\n        # gate up\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=is_tp,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(\n            intermediate_size,\n            config.hidden_size,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=is_tp,\n            all_reduce=all_reduce,\n        )\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = 
self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass DeepseekV2DecoderLayer(nn.Module):\n    \"\"\"Deepseekv2 decoder layer.\"\"\"\n\n    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = None\n\n        # build attention layer\n        if getattr(config, 'use_mla', True):\n            self.self_attn = DeepseekV2Attention(config, dtype=dtype, device=device)\n        else:\n            # deepseek-vl2-tiny uses MHA LlamaAttention structure\n            from lmdeploy.pytorch.models.llama import LlamaAttention\n            self.self_attn = LlamaAttention(config, dtype=dtype, device=device)\n\n        # mlp\n        self.mlp = (DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device) if\n                    (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace\n                     and layer_idx % config.moe_layer_freq == 0) else DeepseekV2MLP(config, dtype=dtype, device=device))\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:\n\n        if residual is None:\n    
        residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n    def forward_yield(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n        tag: Any = None,\n    ):\n        \"\"\"forward_yield.\"\"\"\n        is_decoding = attn_metadata.is_decoding\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # yield for attn0 and attn1\n        yield\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n\n        # MOE\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        topk_weights, topk_idx = self.mlp.gate(hidden_states)\n\n        topk_weights = 
self.mlp.experts.renormalize(topk_weights)\n        topk_weights = topk_weights.to(torch.float32)\n        topk_idx = topk_idx.to(torch.int64)\n        hidden_shape = hidden_states.shape\n        shared_states = None\n\n        state = {\n            'hidden_states': hidden_states,\n            'topk_idx': topk_idx,\n            'topk_weights': topk_weights,\n            'raw_hidden_shape': hidden_shape,\n            'moe_type': MoeType.DSAsyncDecode if is_decoding else MoeType.DSAsyncPrefill,\n        }\n\n        self.mlp.experts.before_dispatch(state)\n\n        # yield for attn1, dis (+share)\n        yield\n        recv_state = self.mlp.experts.dispatch(state)\n        if self.mlp.shared_experts is not None and is_decoding:\n            shared_states = self.mlp.shared_experts(hidden_states)\n        # yield for dis, dis_wait\n        yield\n        self.mlp.experts.wait(recv_state)\n        # yield for dis_wait, moe\n        yield\n        gemm_state = self.mlp.experts.gemm(recv_state)\n        # yield for moe, comb\n        yield\n        out_state = self.mlp.experts.combine(gemm_state)\n        # yield for comb, (+share) comb_wait\n        yield\n        if self.mlp.shared_experts is not None and not is_decoding:\n            shared_states = self.mlp.shared_experts(hidden_states)\n        self.mlp.experts.wait(out_state)\n        # yield for (+share) comb_wait, (+share) attn0\n        yield\n        out_hidden_states = out_state['hidden_states'].view(hidden_shape)\n        if shared_states is not None:\n            out_hidden_states += shared_states\n        elif self.mlp.shared_experts is not None:\n            shared_states = self.mlp.shared_experts(hidden_states)\n            out_hidden_states += shared_states\n        else:\n            pass\n        out_hidden_states = out_hidden_states.reshape(batch_size, sequence_length, -1)\n        outputs = (out_hidden_states, residual)\n        return outputs\n\n\nclass DeepseekV2Model(nn.Module):\n    
\"\"\"Deepseek v2 model.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = ParallelEmbedding(config.vocab_size,\n                                              config.hidden_size,\n                                              self.padding_idx,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n        if get_dist_manager().current_context().dist_config.enable_eplb:\n            ep_size_, _ = get_ep_world_rank()\n            EPLBManager.init_global_eplb_metadata(ep_size_, config.n_routed_experts, config.num_hidden_layers)\n        self.layers = nn.ModuleList([\n            DeepseekV2DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, quant_config=None, dtype=dtype, device=device)\n\n        emb_type = RopeType.LinearScaling\n        rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size //\n                                                                                     config.num_attention_heads)\n        rope_max_pos_emb = config.max_position_embeddings\n        rope_base = get_rope_theta(config)\n\n        rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)\n        update_params = build_rotary_params(config)\n        rope_params.update(update_params)\n        self.rotary_emb = build_rotary_embedding(**rope_params)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: 
Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n        residual = None\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def forward_microbatch(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Forward pass that overlaps the MoE layers across two micro batches.\"\"\"\n        assert self.config.moe_layer_freq == 1\n        moe_start_idx = min(self.config.first_k_dense_replace, len(self.layers))\n\n        # embeddings and the leading dense (MLP) layers\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n        hidden_states = inputs_embeds\n        residual = None\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        for idx, decoder_layer in enumerate(self.layers[:moe_start_idx]):\n            past_key_value = past_key_values[idx]\n            hidden_states, 
residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        if moe_start_idx < len(self.layers):\n            # run two micro batch\n            num = 2\n            input_list, exec_type, delta_stages, extern_tag = split_input(hidden_states,\n                                                                          rotary_pos_emb,\n                                                                          past_key_values,\n                                                                          residual,\n                                                                          attn_metadata,\n                                                                          moe_start_idx,\n                                                                          len(self.layers),\n                                                                          num=num)\n\n            output_list = execute_batch(inputs=input_list,\n                                        fn=self.forward_yieldlayers,\n                                        delta_stages=delta_stages,\n                                        exec_type=exec_type,\n                                        extern_tag=extern_tag)\n            hidden_states, residual = merge_output(output_list)\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def forward_yieldlayers(self,\n                            hidden_states: torch.Tensor,\n                            rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n                            past_key_values: Optional[List[torch.FloatTensor]] = None,\n                            residual: Optional[torch.Tensor] = None,\n                            attn_metadata: Any = None,\n                            start_idx: 
int = -1,\n                            end_idx: int = -1,\n                            tag: Any = None):\n        \"\"\"forward_yieldlayers.\"\"\"\n        for idx in range(start_idx, end_idx):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = yield from self.layers[idx].forward_yield(hidden_states,\n                                                                                rotary_pos_emb=rotary_pos_emb,\n                                                                                past_key_value=past_key_value,\n                                                                                residual=residual,\n                                                                                attn_metadata=attn_metadata,\n                                                                                tag=tag)\n        return hidden_states, residual\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass DeepseekV2ForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Mixture model for causalLM.\"\"\"\n\n    def __init__(self,\n                 config: Any,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.quantization_config = getattr(config, 'quantization_config', None)\n        self.dtype = dtype\n        self.ctx_mgr = ctx_mgr\n        self.model = DeepseekV2Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n        self._load_buffers = dict()\n\n    def forward(\n        self,\n 
       input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        if get_step_ctx_manager().current_context().enable_microbatch:\n            hidden_states = self.model.forward_microbatch(\n                input_ids=input_ids,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                attn_metadata=attn_metadata,\n                inputs_embeds=inputs_embeds,\n            )\n        else:\n            hidden_states = self.model.forward(\n                input_ids=input_ids,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                attn_metadata=attn_metadata,\n                inputs_embeds=inputs_embeds,\n            )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             
expert_params_mapping: List):\n        \"\"\"Load expert weights.\"\"\"\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                               update_pe_mapping: List):\n        \"\"\"Load attention weights.\"\"\"\n        device = next(iter(params_dict.values())).device\n\n        def __update_pe(weight, head_dim: int, pe_dim_offset: int):\n            # (num_heads, head_dim, input_dim)\n            weight = weight.unflatten(0, (-1, head_dim))\n            # rope slice: (num_heads, rope_head_dim, input_dim)\n            w_pe = weight[:, pe_dim_offset:]\n            # (num_heads, rope_head_dim//2, 2, input_dim)\n            new_w_pe = w_pe.unflatten(1, (-1, 2))\n            # (num_heads, rope_head_dim, input_dim), de-interleaved: even rows first, then odd rows\n            new_w_pe = new_w_pe.transpose(1, 2).flatten(1, 2)\n            weight[:, pe_dim_offset:] = new_w_pe\n            weight = weight.flatten(0, 1)\n            return weight\n\n        def __load_kcvc(name: str, weight: torch.Tensor):\n            \"\"\"Load kc and vc from weight.\"\"\"\n            config = self.config\n            v_head_dim = config.v_head_dim\n            qk_nope_head_dim = config.qk_nope_head_dim\n            w_kc, w_vc = weight.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split([qk_nope_head_dim, v_head_dim],\n                                                                                        dim=1)\n            w_vc = w_vc.transpose(1, 2).contiguous()\n            kc_param_name = 
name.replace('.kv_b_proj', '.kc')\n            param_kc = params_dict[kc_param_name]\n            load_weight(param_kc, w_kc)\n            vc_param_name = name.replace('.kv_b_proj', '.vc')\n            param_vc = params_dict[vc_param_name]\n            load_weight(param_vc, w_vc)\n\n        def __dequant_weight(weight: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):\n            \"\"\"Dequantize a block-quantized weight with its per-block scale.\"\"\"\n            dim_w0, dim_w1 = weight.shape\n            dim_s0, dim_s1 = scale.shape\n            assert dim_w0 % dim_s0 == 0\n            assert dim_w1 % dim_s1 == 0\n            group0 = dim_w0 // dim_s0\n            group1 = dim_w1 // dim_s1\n            weight = weight.reshape(dim_s0, group0, dim_s1, group1)\n            scale = scale.reshape(dim_s0, 1, dim_s1, 1)\n            weight = weight.to(scale.dtype) * scale\n            weight = weight.to(dtype)\n            weight = weight.reshape(dim_w0, dim_w1)\n            return weight\n\n        def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):\n            \"\"\"Buffer blocked fp8 weight and scale; dequantize and load kc/vc once both arrive.\"\"\"\n            if name.endswith('.weight'):\n                weight_name = name\n                scale_name = name.replace('.weight', '.weight_scale_inv')\n            elif name.endswith('.weight_scale_inv'):\n                weight_name = name.replace('.weight_scale_inv', '.weight')\n                scale_name = name\n            self._load_buffers[name] = loaded_weight\n            if (weight_name in self._load_buffers and scale_name in self._load_buffers):\n                weight = self._load_buffers.pop(weight_name)\n                scale = self._load_buffers.pop(scale_name)\n                kc_param_name = weight_name.replace('.kv_b_proj', '.kc')\n                dtype = params_dict[kc_param_name].dtype\n                weight = __dequant_weight(weight, scale, dtype)\n                __load_kcvc(weight_name, weight)\n\n        for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping:\n            if mod_name 
not in name:\n                continue\n            if name.endswith('.weight_scale_inv'):\n                weight = loaded_weight\n            else:\n                loaded_weight = loaded_weight.to(device)\n                weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)\n            param = params_dict[name]\n            load_weight(param, weight)\n            break\n        else:\n            if '.kv_b_proj' in name:\n                quantization_config = self.quantization_config\n                quant_method = None\n                if quantization_config is not None:\n                    quant_method = quantization_config.get('quant_method')\n\n                loaded_weight = loaded_weight.to(device)\n                if quant_method == 'fp8':\n                    # update blocked fp8 weight\n                    __load_kcvc_blocked_fp8(name, loaded_weight)\n                else:\n                    __load_kcvc(name, loaded_weight)\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        def __skip_nextn(name, nextn_keys):\n            for nextn_key in nextn_keys:\n                if nextn_key in name:\n                    return True\n            return False\n\n        def __skip_layers():\n            \"\"\"Skip layers beyond num_hidden_layers; the layer count may be reduced\n            to debug the model with fewer GPUs.\"\"\"\n            import re\n            matches = re.findall(r'\\.layers\\.(\\d+)\\.', name)\n            layer_id = int(matches[0])\n            return layer_id >= self.config.num_hidden_layers\n\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        config = self.config\n\n        update_pe_mapping = []\n        if 
getattr(config, 'use_mla', True):\n            qk_rope_head_dim = config.qk_rope_head_dim\n            kv_lora_rank = config.kv_lora_rank\n            qk_nope_head_dim = config.qk_nope_head_dim\n            q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n            kv_dim = kv_lora_rank + qk_rope_head_dim\n            update_pe_mapping = [('q_proj', q_head_dim, qk_nope_head_dim), ('q_b_proj', q_head_dim, qk_nope_head_dim),\n                                 ('kv_a_proj_with_mqa', kv_dim, kv_lora_rank)]\n        else:\n            # deepseek-vl2-tiny uses MHA LlamaAttention, weight loading differs from MLA\n            stacked_params_mapping.extend([\n                # (param_name, shard_name, shard_id)\n                ('.qkv_proj', '.q_proj', 'q'),\n                ('.qkv_proj', '.k_proj', 'k'),\n                ('.qkv_proj', '.v_proj', 'v'),\n            ])\n\n        num_experts = self.config.n_routed_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        num_hidden_layers = self.config.num_hidden_layers\n\n        num_nextn_predict_layers = getattr(self.config, 'num_nextn_predict_layers', 1)\n        nextn_keys = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if '.layers' in name:\n                # skip nextn\n                if 
__skip_nextn(name, nextn_keys):\n                    continue\n\n                if __skip_layers():\n                    continue\n\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            if '.experts' in name:\n                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            elif '.self_attn' in name and getattr(config, 'use_mla', True):\n                # attention\n                self._load_weight_attention(name, loaded_weight, params_dict, update_pe_mapping)\n            else:\n                # other\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/deepseek_v32.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Sequence, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank\nfrom lmdeploy.pytorch.model_inputs import StepContextManager\nfrom lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, build_rotary_embedding,\n                                 build_rotary_params)\nfrom lmdeploy.pytorch.nn.eplb import EPLBManager\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.nsa import IndexerTopKFP8\nfrom lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta\n\nfrom .deepseek_v2 import (DeepseekV2Attention, DeepseekV2BMM, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,\n                          DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, yarn_get_mscale)\n\n\ndef rotate_activation(x: torch.Tensor) -> torch.Tensor:\n    assert x.dtype == torch.bfloat16\n    from fast_hadamard_transform import hadamard_transform\n    hidden_size = x.size(-1)\n    return hadamard_transform(x, scale=hidden_size**-0.5)\n\n\nclass LayerNorm(nn.Module):\n    \"\"\"Layer Normalization.\"\"\"\n\n    def __init__(self, dim: int, eps: float = 1e-6, device: torch.device = None):\n        super().__init__()\n        if device is None:\n            device = 'cuda'\n        self.dim = dim\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32, device=device))\n        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32, device=device))\n\n    def forward(self, x: torch.Tensor):\n        return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias, self.eps).type_as(x)\n\n\nclass Indexer(nn.Module):\n\n    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n       
 try:\n            import fast_hadamard_transform  # noqa: F401\n        except ImportError:\n            raise ImportError('Please install fast_hadamard_transform package.')\n        quant_config = getattr(config, 'quantization_config', None)\n        self.layer_idx = layer_idx\n        # self.dim: int = 2048\n        self.dim: int = config.hidden_size\n        self.n_heads: int = config.index_n_heads\n        self.n_local_heads = config.index_n_heads\n        self.head_dim: int = config.index_head_dim\n        self.rope_head_dim: int = config.qk_rope_head_dim\n        self.index_topk: int = config.index_topk\n        self.q_lora_rank: int = config.q_lora_rank\n        self.wq_b = build_colwise_linear(self.q_lora_rank,\n                                         self.n_heads * self.head_dim,\n                                         bias=False,\n                                         dtype=dtype,\n                                         device=device,\n                                         is_tp=False,\n                                         quant_config=quant_config)\n        self.wk = build_colwise_linear(self.dim,\n                                       self.head_dim,\n                                       bias=False,\n                                       dtype=dtype,\n                                       device=device,\n                                       is_tp=False,\n                                       quant_config=quant_config)\n        self.k_norm = LayerNorm(self.head_dim, device=device)\n        self.weights_proj = build_colwise_linear(self.dim,\n                                                 self.n_heads,\n                                                 bias=False,\n                                                 dtype=dtype,\n                                                 device=device,\n                                                 is_tp=False)\n        self.softmax_scale = self.head_dim**-0.5\n        
self.apply_rotary_pos_emb = ApplyRotaryEmb()\n        self.indexer_topk = IndexerTopKFP8(self.index_topk, self.softmax_scale, block_size=128, fill=-1)\n\n    def forward(self,\n                x: torch.Tensor,\n                qr: torch.Tensor,\n                freqs_cis: torch.Tensor,\n                index_cache: Tuple[torch.Tensor, torch.Tensor],\n                attn_metadata: Any = None):\n        q = self.wq_b(qr)\n        q = q.unflatten(-1, (-1, self.head_dim))\n        q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)\n        k = self.wk(x)\n        k = self.k_norm(k)\n        k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)\n\n        # apply rotary embedding\n        cos, sin = freqs_cis\n        q_pe, k_pe = self.apply_rotary_pos_emb(\n            q_pe,\n            k_pe[..., None, :],\n            cos,\n            sin,\n            inplace=False,\n        )\n        k_pe = k_pe[0, :]\n        k_nope = k_nope[0, :, None]\n        q = torch.cat([q_pe, q_nope], dim=-1)\n        k = torch.cat([k_pe, k_nope], dim=-1)\n        q = rotate_activation(q)\n        k = rotate_activation(k)\n\n        weights = self.weights_proj(x) * self.n_heads**-0.5\n\n        return self.indexer_topk(q[0], k[:, 0], weights[0], index_cache[0], index_cache[1], attn_metadata=attn_metadata)\n\n\nclass DeepseekV32Attention(DeepseekV2Attention):\n\n    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):\n        nn.Module.__init__(self)\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.q_lora_rank = config.q_lora_rank\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = 
config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        num_key_value_heads = getattr(config, 'num_key_value_heads', 1)\n        use_flash_mla = getattr(config, 'use_flash_mla', False)\n\n        if self.q_lora_rank is None:\n            self.q_proj = build_colwise_linear(\n                self.hidden_size,\n                self.num_heads * self.q_head_dim,\n                bias=False,\n                dtype=dtype,\n                device=device,\n                is_tp=True,\n                quant_config=quantization_config,\n                dp_disable_tp=True,\n            )\n        else:\n            self.q_a_proj = build_colwise_linear(\n                self.hidden_size,\n                config.q_lora_rank,\n                bias=config.attention_bias,\n                dtype=dtype,\n                device=device,\n                is_tp=False,\n                quant_config=quantization_config,\n            )\n            self.q_a_layernorm = RMSNorm(config.q_lora_rank,\n                                         1e-6,\n                                         quant_config=quantization_config,\n                                         dtype=torch.float32,\n                                         device=device)\n            self.q_b_proj = build_colwise_linear(\n                config.q_lora_rank,\n                self.num_heads * self.q_head_dim,\n                bias=False,\n                dtype=dtype,\n                device=device,\n                is_tp=True,\n                quant_config=quantization_config,\n                dp_disable_tp=True,\n            )\n\n        self.kv_a_proj_with_mqa = build_colwise_linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n            
dtype=dtype,\n            device=device,\n            is_tp=False,\n            quant_config=quantization_config,\n        )\n        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,\n                                      1e-6,\n                                      quant_config=quantization_config,\n                                      dtype=torch.float32,\n                                      device=device)\n        self.kc = DeepseekV2BMM(self.num_heads,\n                                config.qk_nope_head_dim,\n                                config.kv_lora_rank,\n                                dtype=dtype,\n                                device=device)\n\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        self.softmax_scale = self.q_head_dim**(-0.5)\n\n        rope_scaling = get_rope_parameters(config)\n        if rope_scaling is not None:\n            mscale_all_dim = rope_scaling.get('mscale_all_dim', 0)\n            if mscale_all_dim:\n                scaling_factor = rope_scaling['factor']\n                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n\n        self.attn_fwd = Attention(self.num_heads,\n                                  config.kv_lora_rank + self.qk_rope_head_dim,\n                                  scale=self.softmax_scale,\n                                  num_kv_heads=num_key_value_heads,\n                                  v_head_size=config.kv_lora_rank,\n                                  num_replicate_kv_heads=num_replicate_kv_heads,\n                                  use_flash_mla=use_flash_mla)\n\n        self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device)\n        self.o_proj = build_o_proj(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n            dtype=dtype,\n            device=device,\n            
is_tp=True,\n            quant_config=quantization_config,\n        )\n\n        self.indexer = Indexer(config, layer_idx, dtype=dtype, device=device)\n\n    def _q_proj(self, hidden_states, num_heads: int, nope_size: int, pe_size: int):\n        \"\"\"Q proj.\"\"\"\n        q_len = hidden_states.size(1)\n\n        query_states = hidden_states.new_empty(q_len, num_heads, nope_size + pe_size)\n\n        if self.q_lora_rank is None:\n            qr = hidden_states\n            q = self.q_proj(hidden_states)\n        else:\n            qr = self.q_a_layernorm(self.q_a_proj(hidden_states))\n            q = self.q_b_proj(qr)\n        q = q.view(q_len, num_heads, self.q_head_dim)\n        # q_pe: (q_len, num_heads, qk_rope_head_dim)\n        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)\n        # q_nope: (q_len, num_heads, kv_lora_rank)\n        q_nope_out = query_states[..., :nope_size]\n        self.kc(q_nope, q_nope_out)\n        return query_states, q_pe, qr\n\n    def _kv_proj(self, hidden_states, nope_size: int):\n        \"\"\"Kv proj.\"\"\"\n        # (q_len, 1, nope_size + pe_size)\n        key_states = self.kv_a_proj_with_mqa(hidden_states[0, :, None])\n        # (q_len, 1, pe_size)\n        k_pe = key_states[..., nope_size:]\n        # kv_a_layernorm\n        value_states = key_states[..., :nope_size]\n        value_states = self.kv_a_layernorm(value_states)\n        key_states[..., :nope_size] = value_states\n        return key_states, value_states, k_pe\n\n    def _qkv_proj(self, hidden_states: torch.Tensor, num_heads: int):\n        \"\"\"Qkv proj.\"\"\"\n        nope_size = self.kv_lora_rank\n        pe_size = self.qk_rope_head_dim\n        query_states, q_pe, qr = self._q_proj(hidden_states, num_heads, nope_size, pe_size)\n        key_states, value_states, k_pe = self._kv_proj(hidden_states, nope_size)\n\n        return query_states, key_states, value_states, q_pe, k_pe, qr\n\n    def forward(\n        self,\n      
  hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Sequence[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        dist_ctx = get_dist_manager().current_context()\n        tp_world_size = dist_ctx.dist_config.attn_tp\n        num_heads = self.num_heads // tp_world_size\n        nope_size = self.kv_lora_rank\n        q_len = hidden_states.size(1)\n\n        # qkv_proj\n        query_states, key_states, value_states, q_pe, k_pe, qr = self._qkv_proj(hidden_states, num_heads=num_heads)\n\n        cos, sin = rotary_pos_emb\n        q_pe, k_pe = self.apply_rotary_pos_emb(\n            q_pe,\n            k_pe,\n            cos,\n            sin,\n            inplace=False,\n        )\n        query_states[..., nope_size:] = q_pe\n        key_states[..., nope_size:] = k_pe\n\n        topk_indices = self.indexer(hidden_states, qr, rotary_pos_emb, past_key_value[-2:], attn_metadata=attn_metadata)\n\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[0][..., :nope_size],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            nsa_indices=topk_indices,\n        )\n        attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim)\n\n        self.vc(attn_output, attn_bmm_out)\n        attn_output = attn_bmm_out.flatten(-2, -1)[None]\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass DeepseekV32DecoderLayer(DeepseekV2DecoderLayer):\n\n    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):\n        nn.Module.__init__(self)\n        self.layer_idx = 
layer_idx\n        quantization_config = None\n\n        # build attention layer\n        if getattr(config, 'use_mla', True):\n            self.self_attn = DeepseekV32Attention(config, layer_idx, dtype=dtype, device=device)\n        else:\n            # deepseek-vl2-tiny uses MHA LlamaAttention structure\n            from lmdeploy.pytorch.models.llama import LlamaAttention\n            self.self_attn = LlamaAttention(config, dtype=dtype, device=device)\n\n        # mlp\n        self.mlp = (DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device) if\n                    (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace\n                     and layer_idx % config.moe_layer_freq == 0) else DeepseekV2MLP(config, dtype=dtype, device=device))\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=torch.float32,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                dtype=torch.float32,\n                                                device=device)\n\n\nclass DeepseekV32Model(DeepseekV2Model):\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        nn.Module.__init__(self)\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                
                         device=device)\n        if get_dist_manager().current_context().dist_config.enable_eplb:\n            ep_size_, _ = get_ep_world_rank()\n            EPLBManager.init_global_eplb_metadata(ep_size_, config.n_routed_experts, config.num_hidden_layers)\n        self.layers = nn.ModuleList([\n            DeepseekV32DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size,\n                            config.rms_norm_eps,\n                            quant_config=None,\n                            dtype=torch.float32,\n                            device=device)\n\n        emb_type = RopeType.LinearScaling\n        rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size //\n                                                                                     config.num_attention_heads)\n        rope_max_pos_emb = config.max_position_embeddings\n        rope_base = get_rope_theta(config)\n\n        rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)\n        update_params = build_rotary_params(config)\n        rope_params.update(update_params)\n        self.rotary_emb = build_rotary_embedding(**rope_params)\n\n\nclass DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):\n\n    def __init__(self,\n                 config: Any,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        nn.Module.__init__(self)\n        self.config = config\n        self.quantization_config = getattr(config, 'quantization_config', None)\n        self.dtype = dtype\n        self.ctx_mgr = ctx_mgr\n        self.model = DeepseekV32Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = 
build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n        self._load_buffers = dict()\n"
  },
  {
    "path": "lmdeploy/pytorch/models/deepseek_vl2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/main/deepseek_vl2/models/modeling_deepseek_vl_v2.py\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .deepseek_v2 import DeepseekV2ForCausalLM\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixin, vlm_model\n\n\n@vlm_model\nclass MlpProjector(nn.Module):\n\n    def __init__(self, cfg, dtype):\n\n        super().__init__()\n\n        self.cfg = cfg\n\n        if cfg.projector_type == 'identity':\n            modules = nn.Identity()\n\n        elif cfg.projector_type == 'linear':\n            modules = nn.Linear(cfg.input_dim, cfg.n_embed, dtype=dtype)\n\n        elif cfg.projector_type == 'mlp_gelu':\n            mlp_depth = cfg.depth\n            modules = [nn.Linear(cfg.input_dim, cfg.n_embed, dtype=dtype)]\n            for _ in range(1, mlp_depth):\n                modules.append(nn.GELU())\n                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed, dtype=dtype))\n            modules = nn.Sequential(*modules)\n\n        elif cfg.projector_type == 'downsample_mlp_gelu':\n            mlp_depth = cfg.depth\n            mlp_ratio = cfg.mlp_ratio\n            modules = [\n                nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio,\n                          cfg.n_embed * mlp_ratio,\n                          dtype=dtype)\n            ]\n            for _ in range(1, 
mlp_depth - 1):\n                modules.append(nn.GELU())\n                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio, dtype=dtype))\n            modules.append(nn.GELU())\n            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed, dtype=dtype))\n            modules = nn.Sequential(*modules)\n\n        else:\n            raise ValueError(f'Unknown projector type: {cfg.projector_type}')\n\n        if cfg.token_pooling:\n            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim, dtype=dtype)\n\n        self.layers = modules\n\n    def forward(self, x):\n        if self.cfg.token_pooling:\n            batch_size, wxh, channels = x.shape\n            w = h = int(wxh**0.5)\n            x = x.view(batch_size, w, h, channels)\n            x = x.permute(0, 3, 1, 2)\n            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)\n            batch_size, channels, h_patches, w_patches, _, _ = patches.size()\n            # concatenate along the channel dimension\n            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)\n\n            # pass through the linear layer\n            patches = patches.permute(0, 2, 1, 3).contiguous()\n            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)\n\n            x = self.token_pooling_layer(patches)\n\n        elif self.cfg.projector_type == 'downsample_mlp_gelu':\n            bs, hw, input_dim = x.shape\n            h = w = int((hw)**0.5)\n            \"\"\"Compute padding.\"\"\"\n            if h % self.cfg.downsample_ratio:\n                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio\n            else:\n                pad = 0\n            x = x.reshape(bs, h, w, input_dim)\n            if pad > 0:\n                x = F.pad(x, (0, 0, 0, pad, 0, pad), 'constant', 0)\n            \"\"\"4 to 1 concat\"\"\"\n            x = x.permute(0, 3, 1, 2)  # B, C, H, W\n            x = F.unfold(x, 
kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio,\n                         padding=0)  # B, C*4, HW // 4\n            x = x.permute(0, 2, 1)\n\n        return self.layers(x)\n\n\nclass DeepseekVLV2ForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # ----------- vision encoder ------------\n        self.vision = self._init_vision_module(dtype=dtype)\n\n        # ----------- vl projector ------------\n        projector_config = config.projector_config\n        self.projector = MlpProjector(projector_config, dtype)\n\n        # image token format\n        self.tile_tag = config.tile_tag\n        self.global_view_pos = config.global_view_pos\n\n        # special tokens used to format image token sequence\n        embed_std = 1 / torch.sqrt(torch.tensor(projector_config.n_embed, dtype=torch.float32))\n        if self.tile_tag == '2D':\n            # <|view_separator|>, <|\\n|>\n            self.image_newline = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)\n            # fix the typo: view_seperater\n            self.view_seperator = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)\n        elif self.tile_tag == '1D':\n            # <|tile_x|>, <|tile_global|>\n            candidate_resolutions = config.candidate_resolutions\n            if len(candidate_resolutions) == 0:\n                raise ValueError(\n                    f'len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}')\n            tile_variants_num = len(candidate_resolutions)\n            self.tile_indicators = nn.Parameter(\n                torch.randn(size=(tile_variants_num + 1, 
config.aligner.params.n_embed)) * embed_std)\n        else:\n            raise ValueError(f'tile tag should be either 1D or 2D, but got {self.tile_tag}')\n\n        # ----------- language model ------------\n        language_config = config.language_config\n        self.language = DeepseekV2ForCausalLM(config=language_config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n\n        # ----------- input processor ------------\n        self.input_processor = DeepSeekVLV2InputProcessor(config, dtype)\n\n    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_vl2.py#L359\n    def _init_vision_module(\n        self,\n        dtype: torch.dtype,\n    ) -> nn.Module:\n        try:\n            import timm\n        except ImportError as e:\n            raise ImportError('Please install timm') from e\n\n        model = timm.create_model(\n            'vit_so400m_patch14_siglip_384.webli',\n            pretrained=False,\n            num_classes=0,\n            dynamic_img_size=True,\n            dynamic_img_pad=True,\n        )\n        model = model.to(dtype=dtype)\n        return model\n\n    def prepare_inputs_embeds(self,\n                              input_ids: torch.LongTensor,\n                              images: Optional[torch.FloatTensor] = None,\n                              images_seq_mask: Optional[torch.BoolTensor] = None,\n                              images_spatial_crop: Optional[torch.LongTensor] = None,\n                              **ignore_kwargs):\n        \"\"\"Prepare input embeddings with image features scattered into the image-token positions.\n\n        Args:\n            input_ids (torch.LongTensor): [b, T]\n            images (torch.FloatTensor): [b, max_n_images, 3, height, width]\n            images_seq_mask (torch.BoolTensor): [b, T]\n            images_spatial_crop (torch.LongTensor): [b, max_n_images, 2]\n\n        Returns:\n            input_embeds (torch.Tensor): [b, T, D]\n        \"\"\"\n\n        if images is None or images_spatial_crop.sum() == 0:\n            return 
self.language.get_input_embeddings()(input_ids)\n\n        bs, max_n_images, _ = images_spatial_crop.shape\n        batch_num_tiles = [0 for _ in range(bs)]\n        total_tiles = []\n        for idx in range(bs):\n            for jdx in range(max_n_images):\n                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]\n                if num_width_tiles == 0 or num_height_tiles == 0:\n                    break\n                batch_num_tiles[idx] += (1 + num_width_tiles * num_height_tiles)\n\n            total_tiles.append(images[idx, :batch_num_tiles[idx]])\n\n        # [batch_all_tiles, 3, height, width]\n        total_tiles = torch.cat(total_tiles, dim=0)\n        assert total_tiles.shape[0] == sum(batch_num_tiles)\n        if total_tiles.shape[0] == 0:\n            return self.language.get_input_embeddings()(input_ids)\n\n        # [batch_all_tiles, vit_seq_len, c]\n        images_feature = self.vision.forward_features(total_tiles)  # timm siglip forward_features\n\n        # [batch_all_tiles, hw, D]\n        images_embeds = self.projector(images_feature)\n        _, hw, n_dim = images_embeds.shape\n        h = w = int(hw**0.5)\n\n        # put image tokens into the input_embeds, [b, T, D]\n        input_embeds = self.language.get_input_embeddings()(input_ids)\n\n        # fill image token sequence according to self.tile_tag & self.global_view_pos\n        tile_index = 0\n        for idx in range(images_spatial_crop.shape[0]):\n            images_in_this_batch = []\n            for jdx in range(images_spatial_crop.shape[1]):\n\n                # extra global & local features\n                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]\n                if num_width_tiles == 0 or num_height_tiles == 0:\n                    break\n\n                num_tiles_in_image = num_width_tiles * num_height_tiles\n\n                # [hw, D]\n                global_features = images_embeds[tile_index]\n\n                # 
[num_height_tiles * num_width_tiles, hw, D]\n                local_features = images_embeds[tile_index + 1:tile_index + 1 + num_tiles_in_image]\n\n                tile_index += num_tiles_in_image + 1\n\n                # format global and local features\n                if self.tile_tag == '2D':\n\n                    # ----------------- global view add newline -----------------\n                    # [hw, D] -> [h, w, D]\n                    global_features = global_features.view(h, w, n_dim)\n                    # [D]     -> [h, 1, D]\n                    new_lines_in_global = repeat(self.image_newline, 'd -> h 1 d', h=h)\n                    # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]\n                    global_features = torch.cat([global_features, new_lines_in_global], dim=1)\n                    # [h, w + 1, D] -> [h * (w + 1), D]\n                    global_features = global_features.view(-1, n_dim)\n\n                    # ----------------- local view add newline -----------------\n                    # [num_height_tiles * num_width_tiles, h * w, D] -> [num_height_tiles * h, num_width_tiles * w, D]\n                    local_features = rearrange(local_features,\n                                               '(th tw) (h w) d -> (th h) (tw w) d',\n                                               th=num_height_tiles,\n                                               tw=num_width_tiles,\n                                               h=h,\n                                               w=w)\n\n                    # [D] -> [num_height_tiles * h, 1, D]\n                    new_lines_in_local = repeat(self.image_newline, 'd -> (th h) 1 d', th=num_height_tiles, h=h)\n\n                    # [num_height_tiles * h, num_width_tiles * w + 1, D]\n                    local_features = torch.cat([local_features, new_lines_in_local], dim=1)\n\n                    # [num_height_tiles * h, num_width_tiles * w + 1, D]\n                    #   --> [(num_height_tiles * h) 
* (num_width_tiles * w + 1), D]\n                    local_features = local_features.view(-1, n_dim)\n\n                    # ----------------- merge global and local tiles -----------------\n                    if self.global_view_pos == 'head':\n                        global_local_features = torch.cat(\n                            [global_features, self.view_seperator[None, :], local_features], dim=0)\n                    else:\n                        global_local_features = torch.cat(\n                            [local_features, self.view_seperator[None, :], global_features], dim=0)\n\n                else:\n                    # deprecated '1D' tile tag; this branch is not reached in practice\n                    global_features = torch.cat([self.tile_indicators[0:1], global_features], dim=0)\n                    local_features = torch.cat(\n                        [self.tile_indicators[1:num_tiles_in_image + 1].unsqueeze(1), local_features], dim=1)\n                    local_features = rearrange(local_features, 'crop_num hw d -> (crop_num hw) d')\n\n                    if self.global_view_pos == 'head':\n                        global_local_features = torch.cat([global_features, local_features], dim=0)\n                    else:\n                        global_local_features = torch.cat([local_features, global_features], dim=0)\n\n                images_in_this_batch.append(global_local_features)\n\n            if len(images_in_this_batch) > 0:\n                images_in_this_batch = torch.cat(images_in_this_batch, dim=0).to(input_embeds.dtype)\n                crt_image_mask = images_seq_mask[idx].unsqueeze(-1).to(input_embeds.device)\n                input_embeds[idx].masked_scatter_(crt_image_mask, images_in_this_batch)\n\n        return input_embeds\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        pixel_values: torch.Tensor = 
None,\n        image_mask: torch.Tensor = None,\n        images_spatial_crop: torch.Tensor = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        # process image embeddings\n        if inputs_embeds is None and pixel_values is not None:\n            inputs_embeds = self.prepare_inputs_embeds(input_ids=input_ids,\n                                                       images=pixel_values,\n                                                       images_seq_mask=image_mask,\n                                                       images_spatial_crop=images_spatial_crop)\n\n        outputs = self.language.forward(\n            input_ids=input_ids,\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n        )\n        return outputs\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.language.get_logits(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.language.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # vision inputs\n        pixel_values = None\n        images_spatial_crop = None\n        image_mask = None\n        if context.input_multimodals is not None:\n            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            images_spatial_crop = [p_value[0].meta.get('images_spatial_crop', None) for p_value in pixel_values]\n            # flatten batch\n            pixel_values 
= [data for im_data in pixel_values for data in im_data]\n            if len(pixel_values) > 0:\n                image_token_id = pixel_values[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                pixel_values = torch.cat([data.data for data in pixel_values]).unsqueeze(0)\n            else:\n                pixel_values = None\n                image_mask = None\n\n            if len(images_spatial_crop) > 0:\n                images_spatial_crop = torch.cat([crop for crop in images_spatial_crop]).unsqueeze(0)\n            else:\n                images_spatial_crop = None\n\n        return dict(\n            input_ids=input_ids,  # [b, T]\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            pixel_values=pixel_values,  # [b, max_n_images, 3, height, width]\n            images_spatial_crop=images_spatial_crop,  # [b, max_n_images, 2]\n            image_mask=image_mask,  # [b, T]\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        lang_prefix = 'language.'\n        lang_prefix_length = len(lang_prefix)\n        new_weights = dict()\n        params_dict = dict(self.named_parameters())\n\n        for name, loaded_weight in weights:\n            if name.startswith(lang_prefix):\n                new_key = name[lang_prefix_length:]\n                new_weights[new_key] = loaded_weight\n                continue\n\n            if 'qkv' in name and 'vision' not in name:\n                param = params_dict[name]\n                q, k, v = param.weight_spliter(loaded_weight)\n                load_weight(param, q, shard_id='q')\n                load_weight(param, k, shard_id='k')\n                load_weight(param, v, shard_id='v')\n            else:\n                param = params_dict[name]\n                
load_weight(param, loaded_weight)\n\n        self.language.load_weights(new_weights.items())\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass DeepSeekVLV2InputProcessor(BaseModelInputProcessor):\n    \"\"\"Deepseek-vl2 input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n        vision_config = config.vision_config\n        self.patch_size = vision_config.patch_size\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            images_spatial_crop = input_mm.get('images_spatial_crop', None)\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     meta=dict(\n                                         image_token_id=image_token_id,\n                                         images_spatial_crop=images_spatial_crop,\n                                     ))\n\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            
input_multimodals=dict(image=input_imgs),\n        )\n\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/gemma.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport math\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, GeluAndMul, RMSNorm, RopeType, build_rotary_embedding,\n                                 build_rotary_embedding_from_config)\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass GemmaAttention(nn.Module):\n    \"\"\"Rewrite module of GemmaAttention.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = config.head_dim\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(hidden_size,\n                                       num_q_heads=num_heads,\n                                       num_kv_heads=num_key_value_heads,\n                                       head_size=head_dim,\n                                       bias=config.attention_bias,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n      
                                 device=device,\n                                       num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n        self.model_type = config.model_type\n\n        # attention\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.num_kv_heads = num_key_value_heads\n        self.scaling = 1 / math.sqrt(config.head_dim)\n        if hasattr(config, 'query_pre_attn_scalar'):\n            self.scaling = config.query_pre_attn_scalar**-0.5\n        if self.model_type == 'gemma3_text':\n            sliding_window_pattern = getattr(config, 'sliding_window_pattern', 6)\n            is_sliding = bool((layer_idx + 1) % sliding_window_pattern)\n            self.sliding_window = (getattr(config, 'sliding_window', -1) if is_sliding else -1)\n        else:\n            self.sliding_window = (getattr(config, 'sliding_window', -1) if not bool(layer_idx % 2) else -1)\n        logit_softcapping = getattr(config, 'attn_logit_softcapping', 0.0)\n        if logit_softcapping is None:\n            logit_softcapping = 0.0\n        self.attn_fwd = Attention(num_heads,\n                                  head_dim,\n                                  scale=self.scaling,\n                                  num_kv_heads=num_key_value_heads,\n                                  sliding_window=self.sliding_window,\n                                  logit_softcapping=logit_softcapping)\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=config.attention_bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n        if self.model_type == 'gemma3_text':\n            self.q_norm = 
RMSNorm(config.head_dim,\n                                  config.rms_norm_eps,\n                                  quant_config=quantization_config,\n                                  dtype=dtype,\n                                  device=device)\n            self.k_norm = RMSNorm(config.head_dim,\n                                  config.rms_norm_eps,\n                                  quant_config=quantization_config,\n                                  dtype=dtype,\n                                  device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        rotary_pos_emb_local: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n        global_attn_masks: torch.Tensor = None,\n        local_attn_masks: torch.Tensor = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        if self.model_type == 'gemma3_text':\n            query_states = self.q_norm(query_states)\n            key_states = self.k_norm(key_states)\n\n        # apply rotary embedding\n        if rotary_pos_emb_local is not None and self.sliding_window != -1:\n            cos, sin = rotary_pos_emb_local\n        else:\n            cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        gemma3_naive_attn_with_masks = global_attn_masks is not None and local_attn_masks is not None\n        # attention\n        attn_output = self.attn_fwd(\n            
query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=not gemma3_naive_attn_with_masks,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # gemma3 VL applied different attn masks\n        # intentionally compute attn twice to fill kv cache\n        if gemma3_naive_attn_with_masks is True:\n            attn_masks = local_attn_masks if self.sliding_window > 0 else global_attn_masks\n\n            attn_output = self.naive_attn_with_masks(query_states,\n                                                     key_states,\n                                                     value_states,\n                                                     out=attn_output,\n                                                     attn_masks=attn_masks,\n                                                     seq_lens=attn_metadata.q_seqlens)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n    # adapted from https://github.com/vllm-project/vllm/blob/5eeabc2a4400fde9b030f2f72746a2b03db059bd/vllm/model_executor/models/gemma3.py#L218  # noqa\n    def naive_attn_with_masks(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        out: torch.Tensor,\n        attn_masks: torch.Tensor,\n        seq_lens: torch.Tensor,\n    ) -> torch.Tensor:\n        q_len = q.shape[0]\n        q = q.view(q_len, -1, self.head_dim)\n        # Expand the key and value to handle GQA.\n        num_queries_per_kv = self.num_heads // self.num_kv_heads\n        k = k.view(q_len, -1, self.head_dim)\n        k = k.repeat_interleave(num_queries_per_kv, dim=-2)\n        v = v.view(q_len, -1, 
self.head_dim)\n        v = v.repeat_interleave(num_queries_per_kv, dim=-2)\n\n        start_idx = 0\n        for seq_len, attn_mask in zip(seq_lens, attn_masks):\n            end_idx = start_idx + seq_len\n            query = q[start_idx:end_idx].unsqueeze(0)\n            key = k[start_idx:end_idx].unsqueeze(0)\n            value = v[start_idx:end_idx].unsqueeze(0)\n\n            # Transpose.\n            query = query.transpose(1, 2)\n            key = key.transpose(1, 2)\n            value = value.transpose(1, 2)\n\n            output = F.scaled_dot_product_attention(\n                query,\n                key,\n                value,\n                attn_mask,\n                self.scaling,\n            )\n            output = output.transpose(1, 2).flatten(-2, -1)\n            out[start_idx:end_idx] = output\n            start_idx = end_idx\n        return out\n\n\nclass GemmaMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        hidden_activation = getattr(config, 'hidden_activation', None)\n        if hidden_activation is None:\n            hidden_activation = 'gelu_pytorch_tanh'\n            assert hidden_activation == 'gelu_pytorch_tanh'\n        self.act_fn = GeluAndMul(approximate='tanh')\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=False,\n                               
            quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        out = self.down_proj(act)\n        return out\n\n\nclass GemmaDecoderLayer(nn.Module):\n    \"\"\"Gemma decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = GemmaAttention(config, layer_idx, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = GemmaMLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n        self.model_type = config.model_type\n        if self.model_type in ('gemma2', 'gemma3_text'):\n            self.pre_feedforward_layernorm = RMSNorm(config.hidden_size,\n                                                     
config.rms_norm_eps,\n                                                     quant_config=quantization_config,\n                                                     dtype=dtype,\n                                                     device=device)\n            self.post_feedforward_layernorm = RMSNorm(config.hidden_size,\n                                                      config.rms_norm_eps,\n                                                      quant_config=quantization_config,\n                                                      dtype=dtype,\n                                                      device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        rotary_pos_emb_local: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n        global_attn_masks: torch.Tensor = None,\n        local_attn_masks: torch.Tensor = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            rotary_pos_emb_local=rotary_pos_emb_local,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n            global_attn_masks=global_attn_masks,\n            local_attn_masks=local_attn_masks,\n        )\n\n        # Fully Connected\n\n        if self.model_type in ('gemma2', 'gemma3_text'):\n            hidden_states = self.post_attention_layernorm(hidden_states)\n            hidden_states, residual = self.pre_feedforward_layernorm(hidden_states, 
residual)\n            hidden_states = self.mlp(hidden_states)\n            hidden_states = self.post_feedforward_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n            hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Gemma3TextScaledWordEmbedding(nn.Embedding):\n    \"\"\"This module overrides nn.Embedding's forward by multiplying the\n    embeddings by the embedding scale.\"\"\"\n\n    def __init__(self,\n                 num_embeddings: int,\n                 embedding_dim: int,\n                 padding_idx: int,\n                 dtype: torch.dtype = None,\n                 embed_scale: Optional[float] = 1.0):\n        super().__init__(num_embeddings, embedding_dim, padding_idx, dtype=dtype)\n        self.embed_scale = embed_scale\n\n    def forward(self, input_ids: torch.Tensor):\n        return super().forward(input_ids) * self.embed_scale\n\n\nclass GemmaModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.model_type = config.model_type\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        if self.config.model_type == 'gemma3_text':\n            self.embed_tokens = Gemma3TextScaledWordEmbedding(config.vocab_size,\n                                                              config.hidden_size,\n                                                              self.padding_idx,\n                                                              dtype=dtype,\n                                                              embed_scale=config.hidden_size**0.5)\n        else:\n            self.embed_tokens = nn.Embedding(config.vocab_size,\n                                             config.hidden_size,\n             
                                self.padding_idx,\n                                             dtype=dtype,\n                                             device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            GemmaDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.build_rope_emb(config)\n\n    def build_rope_emb(self, config: PretrainedConfig):\n        rope_dim = config.head_dim\n        rope_max_pos_emb = config.max_position_embeddings\n\n        if self.model_type != 'gemma3_text':\n            self.rotary_emb = build_rotary_embedding_from_config(config)\n            return\n\n        # for gemma3\n        if hasattr(config, 'rope_local_base_freq'):\n            rope_base = config.rope_local_base_freq\n            self.rotary_emb = build_rotary_embedding_from_config(config)\n\n            if self.model_type == 'gemma3_text':\n                self.rotary_emb_local = build_rotary_embedding(\n                    rope_dim,\n                    rope_max_pos_emb,\n                    rope_base,\n                    emb_type=RopeType.Default,\n                )\n        else:\n            # for transformers>=5\n            rope_dim = config.head_dim\n            from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters\n            rope_parameters = get_rope_parameters(config)\n            full_attention = rope_parameters['full_attention']\n            sliding_attention = rope_parameters['sliding_attention']\n            # note that emb type has been fixed.\n            self.rotary_emb = build_rotary_embedding(\n                rope_dim,\n                rope_max_pos_emb,\n                base=full_attention['rope_theta'],\n                
scaling_factor=full_attention['factor'],\n                emb_type=RopeType.LinearScaling,\n            )\n            self.rotary_emb_local = build_rotary_embedding(\n                rope_dim,\n                rope_max_pos_emb,\n                base=sliding_attention['rope_theta'],\n                emb_type=RopeType.Default,\n            )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        global_attn_masks: torch.Tensor = None,\n        local_attn_masks: torch.Tensor = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n        if self.model_type != 'gemma3_text':\n            hidden_states = hidden_states * (self.config.hidden_size**0.5)\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n        rotary_pos_emb_local = None\n        if self.model_type == 'gemma3_text':\n            cos_local, sin_local = self.rotary_emb_local(hidden_states, position_ids)\n            cos_local, sin_local = cos_local[0], sin_local[0]\n            rotary_pos_emb_local = (cos_local, sin_local)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                rotary_pos_emb_local=rotary_pos_emb_local,\n                past_key_value=past_key_value,\n                residual=residual,\n                
attn_metadata=attn_metadata,\n                global_attn_masks=global_attn_masks,\n                local_attn_masks=local_attn_masks,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass GemmaForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = GemmaModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n        self.final_logit_softcapping = getattr(config, 'final_logit_softcapping', None)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        global_attn_masks: torch.Tensor = None,\n        local_attn_masks: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            
input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            global_attn_masks=global_attn_masks,\n            local_attn_masks=local_attn_masks,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        logits = self.lm_head(hidden_states)\n        if self.final_logit_softcapping is not None:\n            logits = logits / self.final_logit_softcapping\n            logits = torch.tanh(logits)\n            logits = logits * self.final_logit_softcapping\n        return logits\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            
inputs_embeds=inputs_embeds,\n        )\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        self.lm_head.weight = self.model.embed_tokens.weight\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n        norm_layers = [\n            '.norm', '.input_layernorm', '.post_attention_layernorm', 'pre_feedforward_layernorm',\n            'post_feedforward_layernorm', 'q_norm', 'k_norm'\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if 'lm_head' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                for weight_name in norm_layers:\n                    if weight_name not in name:\n                        continue\n                    loaded_weight += 1\n                    break\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/gemma3_vl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import build_model_from_hf_config\nfrom .siglip import SiglipVisionModel\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixin\n\n\nclass Gemma3RMSNorm(nn.Module):\n\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.zeros(dim))\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        output = self._norm(x.float())\n        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)\n        # See https://github.com/huggingface/transformers/pull/29402\n        output = output * (1.0 + self.weight.float())\n        return output.type_as(x)\n\n    def extra_repr(self):\n        return f'{tuple(self.weight.shape)}, eps={self.eps}'\n\n\nclass Gemma3MultiModalProjector(nn.Module):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        self.mm_input_projection_weight = nn.Parameter(\n            torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size, dtype=dtype, device=device))\n\n        self.mm_soft_emb_norm = 
Gemma3RMSNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)\n\n        self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)\n        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)\n        self.kernel_size = self.patches_per_image // self.tokens_per_side\n        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)\n\n    def forward(self, vision_outputs: torch.Tensor):\n        batch_size, _, seq_length = vision_outputs.shape\n\n        reshaped_vision_outputs = vision_outputs.transpose(1, 2)\n        reshaped_vision_outputs = reshaped_vision_outputs.reshape(batch_size, seq_length, self.patches_per_image,\n                                                                  self.patches_per_image)\n        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()\n\n        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)\n        pooled_vision_outputs = pooled_vision_outputs.flatten(2)\n        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)\n\n        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)\n\n        projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)\n        return projected_vision_outputs.type_as(vision_outputs)\n\n\nclass Gemma3VLInputProcessor(BaseModelInputProcessor):\n    \"\"\"Gemma3-VL input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n        vision_config = config.vision_config\n        self.image_size = vision_config.image_size\n        self.patch_size = vision_config.patch_size\n        self.num_patches = (self.image_size // self.patch_size)**2\n        self.num_positions = self.num_patches + 1\n        self.vision_token_num = self.num_patches // 4\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n     
                    input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     meta=dict(image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n\n\nclass Gemma3ForConditionalGeneration(nn.Module, CudaGraphMixin, DeployModelMixin):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        text_config = config.text_config\n        self.sliding_window = text_config.sliding_window\n        self.language_model = build_model_from_hf_config(text_config, dtype=dtype, device=device)\n        self.vision_tower = SiglipVisionModel(config=config.vision_config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n        self.multi_modal_projector = Gemma3MultiModalProjector(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n        self.input_processor = 
Gemma3VLInputProcessor(self.config, dtype=dtype)\n\n    def get_input_embeddings(self):\n        return self.language_model.get_input_embeddings()\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.language_model.get_logits(hidden_states)\n\n    def get_image_features(self, pixel_values: torch.Tensor):\n        \"\"\"Projects the last hidden state from the vision model into language\n        model space.\n\n        Args:\n            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):\n                The tensors corresponding to the input images.\n\n        Returns:\n            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.\n        \"\"\"\n        vision_outputs = self.vision_tower(pixel_values=pixel_values)\n        image_features = self.multi_modal_projector(vision_outputs)\n        return image_features\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        pixel_values: torch.FloatTensor = None,\n        image_mask: torch.Tensor = None,\n        inputs_embeds: torch.Tensor = None,\n        vision_embedding_indexing: torch.Tensor = None,\n        text_embedding_indexing: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return hidden states.\n\n        Args:\n            pixel_values (`torch.FloatTensor`, *optional*): Pixel values of all images in the batch.\n            image_mask (`torch.Tensor`, *optional*): Boolean mask over `input_ids` that is `True` at image-token\n                positions. When `inputs_embeds` is not given, the image features are scattered into the text\n                embeddings at these positions.\n\n        When `pixel_values` is given, per-sequence global and sliding-window attention masks are also built by\n        `prepare_attn_masks` and passed to the language model through `kwargs`.\n        \"\"\"\n\n        if inputs_embeds is None and pixel_values is not None:\n            # extract feature\n            vit_embeds = self.get_image_features(pixel_values)\n            lang_embeds = self.get_input_embeddings()(input_ids)\n            lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)\n\n            inputs_embeds = lang_embeds\n        if pixel_values is not None:\n            kwargs = self.prepare_attn_masks(input_ids[0], position_ids[0], mask_dtype=pixel_values.dtype, **kwargs)\n\n        hidden_states = self.language_model(input_ids,\n                                            position_ids,\n                                            inputs_embeds=inputs_embeds,\n                                            past_key_values=past_key_values,\n                                            attn_metadata=attn_metadata,\n                                            **kwargs)\n\n        return hidden_states\n\n    # modified from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py#L539\n    def prepare_attn_masks(\n        self,\n        
input_ids: torch.Tensor,\n        positions: torch.Tensor,\n        mask_dtype: torch.dtype,\n        **kwargs,\n    ):\n        kwargs['has_images'] = True\n        start_indices = (positions == 0).cpu().nonzero()\n        num_seqs = len(start_indices)\n        seq_lens = []\n        for i in range(num_seqs):\n            start_idx = start_indices[i].item()\n            if i < num_seqs - 1:\n                end_idx = start_indices[i + 1].item()\n            else:\n                end_idx = len(input_ids)\n            seq_lens.append(end_idx - start_idx)\n        kwargs['seq_lens'] = seq_lens\n\n        global_attn_masks = []\n        local_attn_masks = []\n        start_idx = 0\n        for seq_len in seq_lens:\n            end_idx = start_idx + seq_len\n            input_token_ids = input_ids[start_idx:end_idx]\n            start_idx = end_idx\n            # Create a global causal mask.\n            global_attn_mask = torch.empty(\n                1,\n                1,\n                seq_len,\n                seq_len,\n                dtype=mask_dtype,\n                device=input_ids.device,\n            )\n            global_attn_mask.fill_(float('-inf'))\n            # Fill the lower triangle with 0.\n            global_attn_mask = global_attn_mask.triu(diagonal=1)\n\n            # Allow bidirectional attention between image tokens.\n            img_mask = torch.zeros_like(global_attn_mask)\n            img_pos = (input_token_ids == self.config.image_token_index)\n            img_mask[:, :, :, img_pos] += 1\n            img_mask[:, :, img_pos, :] += 1\n            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)\n            global_attn_masks.append(global_attn_mask)\n\n            # Create a local causal mask with a sliding window of `self.sliding_window` tokens.\n            local_attn_mask = torch.ones_like(global_attn_mask)\n            local_attn_mask = torch.tril(local_attn_mask, diagonal=-self.sliding_window)\n            local_attn_mask = 
torch.where(local_attn_mask == 0, global_attn_mask, float('-inf'))\n            local_attn_masks.append(local_attn_mask)\n        kwargs['global_attn_masks'] = global_attn_masks\n        kwargs['local_attn_masks'] = local_attn_masks\n        return kwargs\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values=None,\n        inputs_embeds=None,\n        context: StepContext = None,\n        **kwargs,\n    ):\n        # Overwritten -- custom `position_ids` and `pixel_values` handling\n        model_inputs = self.language_model.prepare_inputs_for_generation(\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            context=context,\n            **kwargs,\n        )\n\n        # vision inputs\n        pixel_values = None\n        image_mask = None\n        if context.input_multimodals is not None:\n            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            pixel_values = [data for im_data in pixel_values for data in im_data]\n            if len(pixel_values) > 0:\n                image_token_id = pixel_values[0].meta['image_token_id']\n                image_mask = model_inputs['input_ids'] == image_token_id\n                pixel_values = torch.cat([data.data for data in pixel_values])\n            else:\n                pixel_values = None\n                image_mask = None\n        model_inputs['image_mask'] = image_mask\n        model_inputs['pixel_values'] = pixel_values\n        return model_inputs\n\n    def tie_weights(self):\n        return self.language_model.tie_weights()\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n        ]\n\n    
    lang_prefix = 'language_model.'\n        lang_prefix_length = len(lang_prefix)\n        new_weights = dict()\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if name.startswith(lang_prefix):\n                new_key = name[lang_prefix_length:]\n                new_weights[new_key] = loaded_weight\n                continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n\n        self.language_model.load_weights(new_weights.items())\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n"
  },
  {
    "path": "lmdeploy/pytorch/models/glm4.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass Glm4Attention(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(hidden_size,\n                                       num_q_heads=num_heads,\n                                       num_kv_heads=num_key_value_heads,\n                                       head_size=head_dim,\n                                       bias=config.attention_bias,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device,\n                                       num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = 
Attention(num_heads, head_dim, num_kv_heads=num_key_value_heads, v_head_size=head_dim)\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=False,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n    @staticmethod\n    def _extract_rope(states: torch.Tensor):\n        \"\"\"Extract rope.\"\"\"\n        rope = states.chunk(2, -1)[0]\n        rope = rope.unflatten(-1, (-1, 2))\n        rope = rope.transpose(-2, -1).flatten(-2, -1).contiguous()\n        return rope\n\n    @staticmethod\n    def _fill_rope(states: torch.Tensor, rope: torch.Tensor):\n        \"\"\"Fill rope.\"\"\"\n        rope_part = states.chunk(2, -1)[0]\n        rope = rope.unflatten(-1, (2, -1))\n        rope = rope.transpose(-2, -1).flatten(-2, -1)\n        rope_part.copy_(rope)\n        return states\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        # chatglm series, glm4-0414 have special treatments for rope\n        cos, sin = rotary_pos_emb\n        q_rope = self._extract_rope(query_states)\n        k_rope = self._extract_rope(key_states)\n        q_rope, k_rope = self.apply_rotary_pos_emb(\n            q_rope,\n            k_rope,\n            cos,\n            sin,\n            
inplace=True,\n        )\n        query_states = self._fill_rope(query_states, q_rope)\n        key_states = self._fill_rope(key_states, k_rope)\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Glm4MLP(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n  
      return self.down_proj(act)\n\n\nclass Glm4DecoderLayer(nn.Module):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Glm4Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = Glm4MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build post attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n        # build post self attention layer norm\n        self.post_self_attn_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n        # build post MLP layer norm\n        self.post_mlp_layernorm = RMSNorm(config.hidden_size,\n                                          config.rms_norm_eps,\n                                          quant_config=quantization_config,\n                  
                        dtype=dtype,\n                                          device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # self attn\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # post self attention layer norm\n        hidden_states = self.post_self_attn_layernorm(hidden_states)\n\n        # post attention layer norm\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n\n        # MLP\n        hidden_states = self.mlp(hidden_states)\n\n        # post MLP layer norm\n        hidden_states = self.post_mlp_layernorm(hidden_states)\n\n        return (hidden_states, residual)\n\n\nclass Glm4Model(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Glm4DecoderLayer(config, 
layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass Glm4ForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n    
    self.model = Glm4Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return hidden states.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        if self.config.tie_word_embeddings:\n            self.lm_head.weight = self.model.embed_tokens.weight\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        # Glm4Model does not define get_input_embeddings, so return its embedding layer directly\n        return self.model.embed_tokens\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadata\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modified from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                # GLM4 gate up proj weights are packed\n                if '.gate_up_proj' in name:\n                    param = params_dict[name]\n                    gate, up = param.weight_spliter(loaded_weight)\n                    load_weight(param, gate, shard_id=0)\n                    load_weight(param, up, shard_id=1)\n                    continue\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/glm4_1v.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from:\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py\n\nfrom typing import Any, Callable, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .glm4 import Glm4DecoderLayer\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\nfrom .utils.model import DeployModelMixin, vlm_model\n\n\ndef _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: List[int],\n                           position_ids: torch.Tensor, rotary_emb_func: Callable):\n    _mrope_position_ids = torch.zeros(3, position_ids.shape[-1], dtype=position_ids.dtype, device=position_ids.device)\n    _mrope_position_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids\n    cos, sin = rotary_emb_func(hidden_states, _mrope_position_ids)\n    _cos = torch.zeros(cos.shape[1], cos.shape[-1], dtype=cos.dtype, device=cos.device)\n    _sin = torch.zeros_like(_cos)\n    mrope_section = mrope_section * 2\n\n    def _apply_split(src, dst):\n        start = 0\n        for i, m in enumerate(src.split(mrope_section, dim=-1)):\n            dst[:, start:start + mrope_section[i]] = m[i % 3]\n            start += mrope_section[i]\n\n    _apply_split(cos, _cos)\n    
_apply_split(sin, _sin)\n\n    return _cos, _sin\n\n\nclass Glm4vTextModel(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.mrope_section = config.rope_scaling['mrope_section']\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Glm4DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        mrope_position_ids: torch.LongTensor = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        if mrope_position_ids is None:\n            cos, sin = self.rotary_emb(hidden_states, position_ids)\n            cos, sin = cos[0], sin[0]\n        else:\n            cos, sin = _apply_mrope_selection(hidden_states, mrope_position_ids, 
self.mrope_section, position_ids,\n                                              self.rotary_emb)\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass Glm4VisionMLP(nn.Module):\n    \"\"\"Vision MLP.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 bias: bool = False,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            in_features=config.hidden_size,\n            all_out_features=[config.out_hidden_size, config.out_hidden_size],\n            bias=bias,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(in_features=config.out_hidden_size,\n                                              out_features=config.hidden_size,\n                                              bias=bias,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n    
def forward(self, x):\n        \"\"\"forward.\"\"\"\n        return self.down_proj(self.act_fn(self.gate_up_proj(x)))\n\n\nclass Glm4vVisionPatchEmbed(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None) -> None:\n        super().__init__()\n        self.patch_size = config.patch_size\n        self.temporal_patch_size = config.temporal_patch_size\n        self.in_channels = config.in_channels\n        self.embed_dim = config.hidden_size\n\n        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]\n        self.proj = nn.Conv3d(self.in_channels,\n                              self.embed_dim,\n                              kernel_size=kernel_size,\n                              stride=kernel_size,\n                              dtype=dtype,\n                              device=device)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        target_dtype = self.proj.weight.dtype\n        hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size,\n                                           self.patch_size)\n        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)\n        return hidden_states\n\n\nclass Glm4vVisionRotaryEmbedding(nn.Module):\n    \"\"\"Vision rotary embedding.\"\"\"\n\n    def __init__(self, dim: int, theta: float = 10000.0, device: torch.device = None) -> None:\n        super().__init__()\n        inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))\n        self.register_buffer('inv_freq', inv_freq, persistent=False)\n\n    def forward(self, seqlen: int) -> torch.Tensor:\n        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n        freqs = torch.outer(seq, self.inv_freq)\n        return freqs\n\n\nclass Glm4vVisionPatchMerger(nn.Module):\n\n    def __init__(self,\n                 
dim: int,\n                 context_dim: int,\n                 hidden_act: str,\n                 bias: bool = False,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None) -> None:\n        super().__init__()\n\n        self.proj = nn.Linear(dim, dim, bias=bias, dtype=dtype, device=device)\n        self.post_projection_norm = nn.LayerNorm(dim, dtype=dtype, device=device)\n\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            in_features=dim,\n            all_out_features=[context_dim, context_dim],\n            bias=bias,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n        )\n\n        # down\n        self.down_proj = build_rowwise_linear(in_features=context_dim,\n                                              out_features=dim,\n                                              bias=bias,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n        # gelu\n        self.act1 = nn.GELU()\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:\n        hidden_state = self.proj(hidden_state)\n        hidden_state = self.act1(self.post_projection_norm(hidden_state))\n        return self.down_proj(self.act_fn(self.gate_up_proj(hidden_state)))\n\n\nclass Glm4vVisionEmbeddings(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.num_patches = (self.image_size // self.patch_size)**2\n        self.num_positions = self.num_patches\n        self.position_embedding = 
nn.Embedding(self.num_positions, self.embed_dim, dtype=dtype, device=device)\n        self.register_buffer('position_ids', torch.arange(self.num_positions).expand((1, -1)), persistent=False)\n\n    def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:\n        \"\"\"Forward pass with integrated position encoding adaptation using 2D\n        interpolation.\n\n        Args:\n            embeddings: Input embeddings tensor\n            lengths (torch.Tensor): Sequence lengths for each image in the batch.\n            image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).\n            h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.\n            w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.\n\n        Returns:\n            torch.Tensor: Embeddings with adapted position encoding added.\n        \"\"\"\n        # Get position embedding parameters\n        pos_embed_weight = self.position_embedding.weight\n        hidden_size = pos_embed_weight.shape[1]\n        total_seq = h_coords.shape[0]\n        device = pos_embed_weight.device\n\n        # Move coordinates to correct device\n        h_coords, w_coords = h_coords.to(device), w_coords.to(device)\n\n        # Handle empty sequence case\n        if total_seq == 0:\n            adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)\n        else:\n            # Convert inputs to tensors if needed\n            if isinstance(lengths, list):\n                lengths = torch.tensor(lengths, device=device, dtype=torch.long)\n            if not isinstance(image_shapes, torch.Tensor):\n                image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long)\n\n            # Prepare 2D position embedding\n            orig_size_sq = pos_embed_weight.shape[0]\n            orig_size = 
int(orig_size_sq**0.5)\n            pos_embed_2d = (pos_embed_weight.view(orig_size, orig_size,\n                                                  hidden_size).permute(2, 0, 1).unsqueeze(0).to(device=device,\n                                                                                                dtype=torch.float32))\n\n            # Calculate target dimensions for each patch\n            target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i])\n                                  for i in range(len(lengths))]).to(device=device, dtype=torch.float32)\n            target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i])\n                                  for i in range(len(lengths))]).to(device=device, dtype=torch.float32)\n\n            # Normalize coordinates to [-1, 1] range for grid_sample\n            h_coords = h_coords.to(device=device, dtype=torch.float32)\n            w_coords = w_coords.to(device=device, dtype=torch.float32)\n            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1\n            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1\n\n            # Create sampling grid\n            grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)\n\n            # Perform bicubic interpolation\n            interpolated_embed_fp32 = F.grid_sample(pos_embed_2d,\n                                                    grid,\n                                                    mode='bicubic',\n                                                    align_corners=False,\n                                                    padding_mode='border')\n\n            # Reshape and convert back to original dtype\n            adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)\n            adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)\n\n        # Add adapted position encoding to embeddings\n        embeddings = embeddings + adapted_pos_embed\n        return 
embeddings\n\n\nclass Glm4vVisionAttention(nn.Module):\n    \"\"\"Vision attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        dim = config.hidden_size\n        num_heads = config.num_heads\n        head_dim = dim // num_heads\n        self.head_dim = head_dim\n\n        # packed qkv\n        self.qkv = build_qkv_proj(\n            dim,\n            num_q_heads=num_heads,\n            num_kv_heads=num_heads,\n            head_size=head_dim,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attention = FlashAttention(\n            num_heads,\n            head_dim,\n            causal=False,\n        )\n\n        # o_proj\n        self.proj = build_rowwise_linear(dim,\n                                         dim,\n                                         bias=True,\n                                         quant_config=quantization_config,\n                                         dtype=dtype,\n                                         device=device,\n                                         is_tp=True)\n\n    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,\n                rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor]) -> torch.Tensor:\n        seq_length = hidden_states.shape[0]\n        # qkv proj\n        qkv_states = self.qkv(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        q, k, v = self.qkv.split_qkv(qkv_states)\n\n        cos, sin = rotary_pos_emb\n        q, k = self.apply_rotary_pos_emb(q, k, cos, sin)\n\n        attn_output = self.attention(\n            q,\n    
        k,\n            v,\n            q_start_loc=cu_seqlens[:-1],\n            q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1],\n        )\n\n        attn_output = attn_output.reshape(seq_length, -1)\n\n        # o proj\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Glm4vVisionBlock(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None) -> None:\n        super().__init__()\n        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n        self.attn = Glm4vVisionAttention(config, dtype=dtype, device=device)\n        self.mlp = Glm4VisionMLP(config, bias=False, dtype=dtype, device=device)\n\n    def forward(self,\n                hidden_states,\n                cu_seqlens,\n                rotary_pos_emb,\n                residual: Optional[torch.Tensor] = None) -> torch.Tensor:\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.norm1(hidden_states)\n        else:\n            hidden_states, residual = self.norm1(hidden_states, residual)\n\n        hidden_states = self.attn(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)\n\n        hidden_states, residual = self.norm2(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n        return hidden_states, residual\n\n\nclass Glm4vVisionModel(nn.Module):\n    \"\"\"Vision transformer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.spatial_merge_size = config.spatial_merge_size\n        self.patch_size = config.patch_size\n\n        self.embeddings = Glm4vVisionEmbeddings(config, dtype=dtype, device=device)\n        self.patch_embed = 
Glm4vVisionPatchEmbed(config, dtype=dtype, device=device)\n\n        head_dim = config.hidden_size // config.num_heads\n        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2, device=device)\n\n        self.blocks = nn.ModuleList([Glm4vVisionBlock(config, dtype=dtype, device=device) for _ in range(config.depth)])\n        self.merger = Glm4vVisionPatchMerger(dim=config.out_hidden_size,\n                                             context_dim=config.intermediate_size,\n                                             hidden_act=config.hidden_act,\n                                             dtype=dtype,\n                                             device=device)\n\n        self.post_conv_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n        self.downsample = nn.Conv2d(\n            in_channels=config.hidden_size,\n            out_channels=config.out_hidden_size,\n            kernel_size=config.spatial_merge_size,\n            stride=config.spatial_merge_size,\n            dtype=dtype,\n            device=device,\n        )\n        self.post_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n\n    def rot_pos_emb(self, grid_thw):\n        \"\"\"Rotary position embedding.\"\"\"\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // 
self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n        pos_ids = torch.cat(pos_ids, dim=0)\n        max_grid_size = grid_thw[:, 1:].max()\n        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n        return rotary_pos_emb, pos_ids\n\n    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor,\n                grid_thw: torch.Tensor, image_type_ids: List[torch.Tensor]) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.patch_embed(hidden_states)\n        hidden_states = self.post_conv_layernorm(hidden_states)\n\n        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)\n        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()\n        hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])\n\n        residual = None\n        for blk in self.blocks:\n            hidden_states, residual = blk(hidden_states,\n                                          cu_seqlens=cu_seqlens,\n                                          rotary_pos_emb=rotary_pos_emb,\n                                          residual=residual)\n\n        hidden_states = hidden_states + residual\n\n        hidden_states = self.post_layernorm(hidden_states)\n\n        hidden_states = hidden_states.view(-1, self.spatial_merge_size, self.spatial_merge_size,\n                                           hidden_states.shape[-1])\n        hidden_states = hidden_states.permute(0, 3, 1, 2)\n        hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size)\n\n        hidden_states = self.merger(hidden_states)\n        return 
hidden_states\n\n\n@vlm_model\nclass Glm4vForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # preprocessor\n        self.input_processor = Glm4vInputProcessor(self.config)\n\n        # build vision model\n        self.visual = Glm4vVisionModel(config.vision_config, dtype=dtype, device=device)\n\n        # build language model\n        self.language_model = Glm4vTextModel(config, dtype=dtype, device=device)\n\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        mrope_position_ids: torch.Tensor = None,\n        pixel_values: torch.Tensor = None,\n        vis_cu_seqlens: torch.Tensor = None,\n        vis_pos_emb: torch.Tensor = None,\n        image_type_ids: List[torch.Tensor] = None,\n        grid_thw: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        if inputs_embeds 
is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n            if pixel_values is not None:\n                dtype = inputs_embeds.dtype\n                pixel_values = pixel_values.to(dtype)\n                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))\n                image_embeds = self.visual(pixel_values,\n                                           cu_seqlens=vis_cu_seqlens,\n                                           rotary_pos_emb=vis_pos_emb,\n                                           image_type_ids=image_type_ids,\n                                           grid_thw=grid_thw)\n                inputs_embeds = inputs_embeds.masked_scatter(image_mask[..., None], image_embeds)\n\n        hidden_states = self.language_model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        if self.config.tie_word_embeddings:\n            self.lm_head.weight = self.language_model.embed_tokens.weight\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.language_model.embed_tokens\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = 
context.attn_metadata\n\n        pixel_values = None\n        vis_cu_seqlens = None\n        vis_pos_emb = None\n        image_type_ids = None\n        image_mask = None\n        grid_thw = None\n        if context.input_multimodals is not None:\n            image_data = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            if len(image_data) > 0:\n                # flatten batch\n                image_data = [data for im_data in image_data for data in im_data]\n                pixel_values = torch.cat([data.data for data in image_data])\n                image_token_id = image_data[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu()\n                vis_pos_emb, image_type_ids = self.visual.rot_pos_emb(grid_thw)\n                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],\n                                                         grid_thw[:, 0]).to(pixel_values.device)\n                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)\n                vis_pos_emb = vis_pos_emb.repeat(1, 2)\n                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())\n\n        mrope_position_ids = getattr(context, 'mrope_position_ids', None)\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            
attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n            pixel_values=pixel_values,\n            vis_cu_seqlens=vis_cu_seqlens,\n            vis_pos_emb=vis_pos_emb,\n            image_type_ids=image_type_ids,\n            grid_thw=grid_thw,\n            image_mask=image_mask,\n        )\n\n    @classmethod\n    def rename_weight(cls, name: str) -> str:\n        \"\"\"Rename weight.\"\"\"\n        if name.startswith('model.language_model.'):\n            return 'language_model.' + name[len('model.language_model.'):]\n        elif name.startswith('model.visual.'):\n            return 'visual.' + name[len('model.visual.'):]\n        elif name.startswith('model.'):\n            return name[len('model.'):]\n        return name\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n  
              break\n            else:\n                if '.qkv.' in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                elif '.gate_up_proj' in name:\n                    param = params_dict[name]\n                    gate, up = param.weight_spliter(loaded_weight)\n                    load_weight(param, gate, shard_id=0)\n                    load_weight(param, up, shard_id=1)\n                    continue\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n        max_tokens = graph_meta.max_tokens\n\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)\n\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n\n        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n\n        input_ids = kwargs.get('input_ids')\n        num_tokens = input_ids.size(-1)\n        new_batch_size = graph_meta.max_batchs\n\n        is_decoding = graph_meta.is_decoding\n        input_buffers = graph_meta.input_buffers\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids\n            if 
is_decoding:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]\n            else:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']\n\n        return new_inputs\n\n    def _get_model_metas(self, context: StepContext):\n        \"\"\"Get model metas.\"\"\"\n        model_metas = context.model_metas\n        if model_metas is None:\n            batch_size = context.q_seqlens.numel()\n            return [dict(mrope_delta=0)] * batch_size\n        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]\n\n    def _update_model_meta_decoding(self, context: StepContext):\n        \"\"\"Update model meta for decoding.\"\"\"\n        model_metas = self._get_model_metas(context)\n        position_ids = context.position_ids\n\n        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]\n        mrope_deltas = position_ids.new_tensor(mrope_deltas)\n        mrope_position_ids = position_ids + mrope_deltas[None]\n        mrope_position_ids = mrope_position_ids.expand(3, -1)\n\n        context.mrope_position_ids = mrope_position_ids\n        return model_metas\n\n    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):\n        \"\"\"Get mrope ids.\"\"\"\n        t, h, w = grid_thw\n        h //= 2\n        w //= 2\n        stride = torch.tensor([h * w, w, 1], device=device)[:, None]\n        size = torch.tensor([t, h, w], device=device)[:, None]\n        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)\n        pos_ids = pos_ids // stride % size\n        return pos_ids\n\n    def _update_model_meta_prefilling(self, context: StepContext):\n        \"\"\"Update model meta for prefilling.\"\"\"\n        model_metas = self._get_model_metas(context)\n        input_multimodals = context.input_multimodals\n        if input_multimodals is None:\n            input_multimodals = [None] * len(model_metas)\n        
position_ids = context.position_ids\n        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())\n        mrope_position_ids = []\n        new_model_metas = []\n        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):\n            images = []\n            if input_mm is not None:\n                images = input_mm.get('image', [])\n            if model_meta is None or 'mrope_delta' not in model_meta:\n                mrope_delta = 0\n            else:\n                mrope_delta = model_meta['mrope_delta']\n\n            pos_start = pos_ids[0].item()\n            mrope_pos_ids = pos_ids + mrope_delta\n            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()\n            for img in images:\n                grid_thw = img.meta['grid_thw'][0].tolist()\n                _, h, w = grid_thw\n                h //= 2\n                w //= 2\n                num_pad = img.end - img.start - max(h, w)\n                mrope_delta -= num_pad\n                fill_start = img.start - pos_start\n                fill_end = img.end - pos_start\n                img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)\n                img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]\n                mrope_pos_ids[:, fill_end:] -= num_pad\n                mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids\n\n            mrope_position_ids.append(mrope_pos_ids)\n            new_model_metas.append(dict(mrope_delta=mrope_delta))\n\n        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)\n        context.mrope_position_ids = mrope_position_ids\n\n        return new_model_metas\n\n    def update_model_metas(self,\n                           past_key_values: List[List[torch.Tensor]],\n                           inputs_embeds: Optional[torch.Tensor] = None,\n                           context: StepContext = None):\n        \"\"\"Update model meta.\"\"\"\n        if context.is_decoding:\n 
           return self._update_model_meta_decoding(context)\n        else:\n            return self._update_model_meta_prefilling(context)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass Glm4vInputProcessor(BaseModelInputProcessor):\n    \"\"\"Glm4v input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig) -> None:\n        self.config = config\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values']\n            image_grid_thw = input_mm['image_grid_thw']\n            offset = input_mm['offset']\n            start = offset\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=start,\n                                     end=start + num_pad,\n                                     meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/glm4_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass Glm4MoeAttention(nn.Module):\n    \"\"\"Rewrite module of Qwen3MoeAttention.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        self.use_qk_norm = config.use_qk_norm\n\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            num_replicate_kv_heads=num_replicate_kv_heads,\n            dtype=dtype,\n            device=device,\n            
is_tp=is_tp,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(num_heads, head_dim, num_kv_heads=num_key_value_heads, v_head_size=head_dim)\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(num_heads * head_dim,\n                                           hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=is_tp)\n\n        # q, k norm\n        if self.use_qk_norm:\n            self.q_norm = RMSNorm(head_dim,\n                                  config.rms_norm_eps,\n                                  quant_config=quantization_config,\n                                  dtype=dtype,\n                                  device=device)\n            self.k_norm = RMSNorm(head_dim,\n                                  config.rms_norm_eps,\n                                  quant_config=quantization_config,\n                                  dtype=dtype,\n                                  device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply q, k norm\n        if self.use_qk_norm:\n            query_states = self.q_norm(query_states)\n            key_states = self.k_norm(key_states)\n\n        
# apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Glm4MoeMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 intermediate_size: int = None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True,\n                 all_reduce: bool = True):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=is_tp,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(intermediate_size,\n                                              config.hidden_size,\n                                     
         bias=False,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=is_tp,\n                                              all_reduce=all_reduce)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Glm4MoE(nn.Module):\n    \"\"\"Moe block.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.n_routed_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n        self.renormalize = self.norm_topk_prob\n\n        self.routed_scaling_factor = config.routed_scaling_factor\n\n        # build gate\n        # refers to https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/glm4_moe.py\n        # NOTE In the transformers implementation, the gate isn't an nn.Linear,\n        # https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260\n        self.gate = nn.Linear(\n            config.hidden_size,\n            config.n_routed_experts,\n            bias=False,\n            dtype=torch.float32,\n        )\n        self.gate.e_score_correction_bias = nn.Parameter(torch.empty(config.n_routed_experts, dtype=torch.float32))\n\n        self.softmax_topk = SoftmaxTopK(\n            self.top_k,\n        
    n_groups=getattr(config, 'router_n_groups', -1),\n        )\n\n        # build experts\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=self.renormalize,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            all_reduce=False,\n            layer_idx=layer_idx,\n        )\n\n        # build shared experts\n        intermediate_size = config.moe_intermediate_size * config.n_shared_experts\n        self.shared_experts = Glm4MoeMLP(\n            config=config,\n            intermediate_size=intermediate_size,\n            dtype=dtype,\n            device=device,\n            is_tp=is_tp,\n            all_reduce=False,\n        )\n\n        # get all reduce\n        world_size, _ = get_tp_world_rank()\n        if world_size > 1:\n            self._all_reduce = True\n        else:\n            self._all_reduce = False\n\n    def forward(self, hidden_states: torch.Tensor):\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n\n        # gate\n        router_logits = self.gate(hidden_states.to(dtype=torch.float32))\n        topk_weights, topk_ids = self.softmax_topk(router_logits)\n\n        # experts\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n        out_states = out_states * self.routed_scaling_factor\n\n        # shared experts\n        shared_states = self.shared_experts(hidden_states)\n\n        out_states += shared_states\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n\n        if self._all_reduce:\n            dist.all_reduce(out_states)\n        return out_states\n\n\nclass Glm4MoeDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 
config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Glm4MoeAttention(config, dtype=dtype, device=device)\n\n        if layer_idx >= config.first_k_dense_replace:\n            self.mlp = Glm4MoE(config, layer_idx=layer_idx, dtype=dtype, device=device)\n        else:\n            self.mlp = Glm4MoeMLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # self attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # fully connected\n        hidden_states, residual = 
self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Glm4MoeModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Glm4MoeDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = self._build_rotary_embedding(config)\n\n    def _build_rotary_embedding(self, config: PretrainedConfig):\n        \"\"\"Build rotary embedding.\"\"\"\n        return build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, 
position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass Glm4MoeForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # build model\n        self.model = Glm4MoeModel(\n            config=config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build lm_head\n        self.lm_head = build_rowwise_linear(\n            config.hidden_size,\n            config.vocab_size,\n            bias=False,\n            dtype=dtype,\n            device=device,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n 
           input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n        # load fused weights\n        if any([k in name for k in ['fused_w1w3', 
'fused_w2']]):\n            return self._load_weight_fused_experts(name, loaded_weight, params_dict)\n\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):\n        \"\"\"Load weight of fused expert weights.\"\"\"\n        num_experts = self.config.n_routed_experts\n        fused_gateup_name = 'fused_w1w3'\n        fused_down_name = 'fused_w2'\n        if fused_gateup_name in name:\n            chunk_size = loaded_weight.shape[0] // num_experts\n\n            for expert_id in range(num_experts):\n                param_name = name.replace(f'experts.{fused_gateup_name}', 'experts.gate_up')\n                param = params_dict[param_name]\n                w1 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size // 2)\n                w3 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id + chunk_size // 2, length=chunk_size // 2)\n                load_weight(param, w1, expert_id=expert_id, shard_id='gate')\n                load_weight(param, w3, expert_id=expert_id, shard_id='up')\n\n        elif fused_down_name in name:\n            chunk_size = loaded_weight.shape[0] // num_experts\n\n            for expert_id in range(num_experts):\n                param_name = name.replace(f'experts.{fused_down_name}', 'experts.down')\n                param = params_dict[param_name]\n                w2 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size)\n                load_weight(param, w2, 
expert_id=expert_id, shard_id='down')\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        mtp_param_list = []\n        if hasattr(self.config, 'num_nextn_predict_layers'):\n            num_hidden_layers = self.config.num_hidden_layers\n            num_nextn_predict_layers = self.config.num_nextn_predict_layers\n            mtp_param_list = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)]\n\n        # expert map\n        num_experts = self.config.n_routed_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = dict(self.named_parameters())\n\n        for name, loaded_weight in weights:\n            # skip MTP related weights\n            if any(mtp_param_name in name for mtp_param_name in mtp_param_list):\n                continue\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            name = name.replace('.block_sparse_moe.', '.mlp.')\n            if '.experts' in name:\n               
 self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/glm4moe_mtp.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Iterable\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContextManager\nfrom lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .deepseek_mtp import DeepseekMTPModel\nfrom .glm4_moe import Glm4MoE, Glm4MoeAttention, Glm4MoeDecoderLayer, Glm4MoeMLP\n\n\nclass Glm4MoeMTPDecoderLayer(Glm4MoeDecoderLayer):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        nn.Module.__init__(self)\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Glm4MoeAttention(config, dtype=dtype, device=device, is_tp=False)\n\n        if layer_idx >= config.first_k_dense_replace:\n            self.mlp = Glm4MoE(config, layer_idx=layer_idx, dtype=dtype, device=device, is_tp=False)\n            self.mlp._all_reduce = False\n        else:\n            self.mlp = Glm4MoeMLP(config, dtype=dtype, device=device, is_tp=False, all_reduce=False)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n\nclass Glm4MoeMTPModel(DeepseekMTPModel):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n  
  packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__(\n            config,\n            ctx_mgr,\n            dtype=dtype,\n            device=device,\n            decoder_layer_cls=Glm4MoeMTPDecoderLayer,\n            build_rotary_embedding_func=build_rotary_embedding_from_config,\n        )\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter],\n                             expert_params_mapping: list[list[str]]):\n        \"\"\"Load weight experts.\"\"\"\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        def __skip_nextn(name, nextn_keys):\n            for nextn_key in nextn_keys:\n                if nextn_key in name:\n                    return True\n            return False\n\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', 
'.up_proj', 1),\n        ]\n\n        num_hidden_layers = self.config.num_hidden_layers\n\n        num_nextn_predict_layers = getattr(self.config, 'num_nextn_predict_layers', 1)\n        nextn_keys = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)]\n\n        # expert map\n        num_experts = self.config.n_routed_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = dict(self.named_parameters())\n\n        for name, loaded_weight in weights:\n            # keep nextn\n            if not __skip_nextn(name, nextn_keys):\n                continue\n            if '.layers' in name:\n                layer_idx = int(name.split('layers.')[1].split('.')[0])\n                name = self._rewrite_spec_layer_name(layer_idx, name)\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            name = name.replace('.block_sparse_moe.', '.mlp.')\n            if '.experts' in name:\n                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    
load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/gpt_oss.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport functools\nfrom typing import Any, Callable, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_o_proj, build_qkv_proj\nfrom lmdeploy.pytorch.nn.moe import build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import get_build_model_context\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n\n\nclass GptOssAttention(nn.Module):\n    \"\"\"attention.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 attention_type: str,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.layer_idx = layer_idx\n        num_attention_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        scaling = head_dim**-0.5\n\n        self.qkv_proj = build_qkv_proj(hidden_size,\n                                       num_q_heads=num_attention_heads,\n                                       num_kv_heads=num_key_value_heads,\n                                       head_size=head_dim,\n                                       bias=config.attention_bias,\n                           
            quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device,\n                                       num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        if attention_type == 'sliding_attention':\n            sliding_window = config.sliding_window\n        elif attention_type == 'full_attention':\n            sliding_window = None\n        else:\n            raise ValueError(f'Unsupported attention type: {attention_type}')\n        # attention\n        self.attn_fwd = Attention(\n            num_attention_heads,\n            head_dim,\n            scale=scaling,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=sliding_window,\n            learnable_sink=True,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_attention_heads * head_dim,\n                                   hidden_size,\n                                   bias=config.attention_bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n        # sinks\n        self.sinks = self.build_sinks(config, device)\n\n    @classmethod\n    def build_sinks(cls, config: PretrainedConfig, device):\n        \"\"\"Build sinks.\"\"\"\n        from lmdeploy.pytorch.distributed import get_tp_world_rank\n        world_size, _ = get_tp_world_rank()\n        num_attention_heads = config.num_attention_heads\n        assert num_attention_heads % world_size == 0, (\n            f'num_attention_heads={num_attention_heads} should be divisible by TP={world_size}')\n        num_attention_heads = num_attention_heads // world_size\n        sinks = nn.Parameter(torch.empty(num_attention_heads, 
device=device))\n        sinks.weight_loader = cls.weight_loader_sinks\n        return sinks\n\n    @classmethod\n    def weight_loader_sinks(cls, param: nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Load weight of sinks.\"\"\"\n        from lmdeploy.pytorch.distributed import get_tp_world_rank\n        world_size, rank = get_tp_world_rank()\n        loaded_weight = loaded_weight.chunk(world_size, dim=0)[rank]\n        param.data.copy_(loaded_weight)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            s_aux=self.sinks,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass GateupAct:\n\n    def __init__(self, limit: float = 7.0, alpha: float = 1.702):\n        self.limit = limit\n        
self.alpha = alpha\n        self._run: Callable = None\n\n    def _impl(self, gateup: torch.Tensor) -> torch.Tensor:\n        \"\"\"Moe act.\"\"\"\n        gate, up = gateup.chunk(2, dim=-1)\n        gate = gate.clamp(min=None, max=self.limit)\n        up = up.clamp(min=-self.limit, max=self.limit)\n        glu = gate * torch.sigmoid(gate * self.alpha)\n        return (up + 1) * glu\n\n    @staticmethod\n    @functools.lru_cache(maxsize=None)\n    def build(limit: float, alpha: float):\n        return GateupAct(limit, alpha)\n\n    def _try_compile(self, gateup: torch.Tensor) -> Callable:\n        try:\n            run = torch.compile(self._impl, dynamic=True)\n            run(gateup)\n            self._run = run\n        except Exception:\n            self._run = self._impl\n\n    def __call__(self, gateup: torch.Tensor) -> torch.Tensor:\n        \"\"\"Call the act function.\"\"\"\n        if self._run is None:\n            self._try_compile(gateup)\n\n        return self._run(gateup)\n\n\nclass GptOssExperts(nn.Module):\n    \"\"\"experts.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.layer_idx = layer_idx\n        self.intermediate_size = config.intermediate_size\n        self.num_experts = config.num_local_experts\n        self.hidden_size = config.hidden_size\n        self.expert_dim = self.intermediate_size\n        self.top_k = config.num_experts_per_tok\n        self.alpha = 1.702\n        self.limit = 7.0\n        self._gateup_act = GateupAct.build(self.limit, self.alpha)\n\n        self.experts = build_fused_moe(\n            self.hidden_size,\n            self.expert_dim,\n            self.num_experts,\n            bias=True,\n            top_k=self.top_k,\n            
renormalize=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            all_reduce=True,\n            layer_idx=layer_idx,\n            act_func=self._gateup_act,\n        )\n\n    def forward(self, hidden_states: torch.Tensor, router_indices, routing_weights) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, _ = hidden_states.shape\n        out_states = self.experts(\n            hidden_states[0],\n            routing_weights,\n            router_indices,\n        )\n\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n        return out_states\n\n\nclass GptOssTopKRouter(nn.Module):\n    \"\"\"Gate + topk + softmax.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.top_k = config.num_experts_per_tok\n        self.num_experts = config.num_local_experts\n        self.hidden_dim = config.hidden_size\n        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, dtype=dtype, device=device))\n        self.bias = nn.Parameter(torch.empty(self.num_experts, dtype=dtype, device=device))\n\n    def forward(self, hidden_states):\n        hidden_states = hidden_states.reshape(-1, self.hidden_dim)\n        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)\n        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)\n        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)\n        router_scores = router_top_value\n        return router_scores, router_indices\n\n\nclass GptOssMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: 
torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.router = GptOssTopKRouter(config, dtype=dtype, device=device)\n        self.experts = GptOssExperts(config, layer_idx, dtype=dtype, device=device)\n\n    def forward(self, hidden_states, all_routed_experts: torch.Tensor = None):\n        router_scores, router_indices = self.router(hidden_states)  # (num_experts, seq_len)\n        if all_routed_experts is not None:\n            all_routed_experts[:, self.layer_idx, :] = router_indices\n        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)\n        return routed_out\n\n\nclass GptOssDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.attention_type = config.layer_types[layer_idx]\n\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = GptOssAttention(config, self.attention_type, layer_idx=layer_idx, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = GptOssMLP(config, layer_idx, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                
quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n        all_routed_experts: torch.Tensor = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states, all_routed_experts=all_routed_experts)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass GptOssModel(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n\n        self.embed_tokens = build_embedding(\n            config.vocab_size,\n            config.hidden_size,\n            config.pad_token_id,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            GptOssDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, 
device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        all_routed_experts: torch.Tensor = None,\n    ):\n        \"\"\"Rewrite of forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n                all_routed_experts=all_routed_experts,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass GptOssForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n       
          device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = GptOssModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n        # for router replay\n        bm_ctx = get_build_model_context()\n        self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        # router replay\n        all_routed_experts = None\n        if self.enable_return_routed_experts:\n            if inputs_embeds is not None:\n                num_tokens = inputs_embeds.size(1)\n            else:\n                num_tokens = input_ids.size(1)\n            all_routed_experts = position_ids.new_empty(\n                (num_tokens, self.config.num_hidden_layers, self.config.num_experts_per_tok), dtype=torch.uint16)\n\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            all_routed_experts=all_routed_experts,\n        )\n\n        if all_routed_experts is None:\n            return hidden_states\n        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n  
      inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def _load_weight_experts_gate_up(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str,\n                                                                                                     nn.Parameter]):\n        \"\"\"Load weight of experts gate up.\"\"\"\n        num_experts = self.config.num_local_experts\n\n        loaded_weight = loaded_weight.cuda()\n        if 'gate_up_proj_bias' in name:\n            param_name = name.replace('experts.gate_up_proj_bias', 'experts.experts.gate_up.bias')\n        elif 'gate_up_proj' in name:\n            param_name = name.replace('experts.gate_up_proj', 'experts.experts.gate_up.weight')\n            loaded_weight = loaded_weight.transpose(1, 2)\n        param = params_dict[param_name]\n        for expert_id in range(num_experts):\n            w1 = loaded_weight[expert_id, ::2]\n            w3 = loaded_weight[expert_id, 1::2]\n            load_weight(param, w1, expert_id=expert_id, 
shard_id='gate')\n            load_weight(param, w3, expert_id=expert_id, shard_id='up')\n\n    def _load_weight_experts_down(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):\n        \"\"\"Load weight of experts down.\"\"\"\n        num_experts = self.config.num_local_experts\n\n        loaded_weight = loaded_weight.cuda()\n        if 'down_proj_bias' in name:\n            param_name = name.replace('experts.down_proj_bias', 'experts.experts.down.bias')\n        elif 'down_proj' in name:\n            param_name = name.replace('experts.down_proj', 'experts.experts.down.weight')\n            loaded_weight = loaded_weight.transpose(1, 2)\n        param = params_dict[param_name]\n        for expert_id in range(num_experts):\n            w2 = loaded_weight[expert_id]\n            load_weight(param, w2, expert_id=expert_id, shard_id='down')\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):\n        \"\"\"Load weight of fused expert weights.\"\"\"\n        if 'gate_up' in name:\n            self._load_weight_experts_gate_up(name, loaded_weight, params_dict)\n\n        elif 'down' in name:\n            self._load_weight_experts_down(name, loaded_weight, params_dict)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 
'lm_head.weight' in name:\n                continue\n            name = name.replace('.block_sparse_moe.', '.mlp.')\n            if '.experts' in name:\n                self._load_weight_experts(name, loaded_weight, params_dict)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/internlm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass InternLMAttention(nn.Module):\n    \"\"\"Rewrite module of LlamaAttention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(hidden_size,\n                                       num_q_heads=num_heads,\n                                       num_kv_heads=num_key_value_heads,\n                                       head_size=head_dim,\n                                       bias=config.bias,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device,\n                                       num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = 
ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=config.bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = 
self.o_proj(attn_output)\n        return attn_output\n\n\nclass InternLMMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=config.bias,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=config.bias,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass InternLMDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = InternLMAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = InternLMMLP(config, dtype=dtype, device=device)\n\n        
# build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass InternLMModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        
self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            InternLMDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding in LlamaModel\n        rope_dim = config.hidden_size // config.num_attention_heads\n        rope_max_pos_emb = config.max_position_embeddings\n        scaling_factor = 1.0\n        rope_scaling = config.rotary\n        rope_base = rope_scaling['base']\n        rope_type = rope_scaling['type']\n        if rope_type == 'dynamic':\n            emb_type = RopeType.DynamicNTKScaling\n            scaling_factor = rope_scaling.get('scaling_factor', 1.0)\n        elif rope_type == 'origin':\n            emb_type = RopeType.LinearScaling\n        else:\n            raise RuntimeError(f'Unsupported rope type: {rope_type}')\n\n        self.rotary_emb = build_rotary_embedding(\n            rope_dim,\n            rope_max_pos_emb,\n            rope_base,\n            scaling_factor,\n            emb_type=emb_type,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds 
= self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass InternLMForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Rewrote model of LlamaForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build LLamaModel\n        self.model = InternLMModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n 
       self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            
inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/internlm2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n\n\nclass InternLM2Attention(nn.Module):\n    \"\"\"Rewrite module of InternLM2Attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = hidden_size // num_heads\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        self.wqkv = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            num_replicate_kv_heads=num_replicate_kv_heads,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n        )\n\n        # o_proj\n        self.wo = 
build_o_proj(num_heads * head_dim,\n                               hidden_size,\n                               bias=config.bias,\n                               quant_config=quantization_config,\n                               dtype=dtype,\n                               device=device,\n                               is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of InternLM2Attention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.wqkv(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.wqkv.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.wo(attn_output)\n        return attn_output\n\n\nclass InternLM2MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 
'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.w2 = build_down_linear(config.intermediate_size,\n                                    config.hidden_size,\n                                    bias=False,\n                                    quant_config=quantization_config,\n                                    dtype=dtype,\n                                    device=device,\n                                    is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.w2(act)\n\n\nclass InternLM2DecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.attention = InternLM2Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.feed_forward = InternLM2MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.attention_norm = RMSNorm(config.hidden_size,\n                                      config.rms_norm_eps,\n                                      quant_config=quantization_config,\n                                      dtype=dtype,\n                                      device=device)\n\n        # build 
attention layer norm\n        self.ffn_norm = RMSNorm(config.hidden_size,\n                                config.rms_norm_eps,\n                                quant_config=quantization_config,\n                                dtype=dtype,\n                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.attention_norm(hidden_states)\n        else:\n            hidden_states, residual = self.attention_norm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.attention(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.ffn_norm(hidden_states, residual)\n        hidden_states = self.feed_forward(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass InternLM2Model(nn.Module):\n    \"\"\"Internlm2 model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.tok_embeddings = build_embedding(\n            config.vocab_size,\n            config.hidden_size,\n            self.padding_idx,\n            dtype=dtype,\n            device=device,\n        )\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            InternLM2DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx 
in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding in Model\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.tok_embeddings(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.tok_embeddings\n\n\nclass InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"Rewrote model of InternLM2ForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'gate_up_proj': [\n            'w1',\n            'w3',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: 
StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build Model\n        self.model = InternLM2Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.output = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n        self.lm_head = self.output\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = 
self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):\n        \"\"\"Load lora weights.\"\"\"\n\n        from lmdeploy.pytorch.adapter.adapter import load_lora_weights\n\n        num_heads = self.config.num_attention_heads\n        num_key_value_heads = self.config.num_key_value_heads\n        hidden_size = self.config.hidden_size\n        head_dim = hidden_size // num_heads\n        group_size = num_heads // num_key_value_heads\n\n        def _rearange_wqkv(weights):\n            for name, loaded_weight in weights:\n                if 'wqkv.lora_B' in name:\n                    loaded_weight = loaded_weight.unflatten(0, (-1, 2 + group_size, head_dim))\n                    q = loaded_weight[:, :-2].flatten(0, 2)\n                    k = loaded_weight[:, -2].flatten(0, 1)\n                    v = loaded_weight[:, -1].flatten(0, 1)\n                    loaded_weight = torch.cat([q, k, v], dim=0)\n                yield name, loaded_weight\n\n        weights_iter = _rearange_wqkv(weights)\n        load_lora_weights(self, weights_iter, adapter_id)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.gate_up_proj', '.w1', 0),\n            ('.gate_up_proj', '.w3', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                
continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '.wqkv' in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight, layout='hgd')\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/internlm2_reward.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn.linear import build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .internlm2 import InternLM2Model\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass InternLM2ForRewardModel(nn.Module, CudaGraphMixin):\n    \"\"\"Rewrote model of InternLM2ForRewardModel.\"\"\"\n\n    packed_modules_mapping = {\n        'gate_up_proj': [\n            'w1',\n            'w3',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build Model\n        self.model = InternLM2Model(config, dtype=dtype, device=device)\n        # build v_head\n        self.v_head = build_rowwise_linear(config.hidden_size, 1, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model 
output.\"\"\"\n        return self.v_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        vision_embeddings = context.input_embeddings\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            raise ValueError('InternLM2RewardModel does not support vision embedding')\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):\n        \"\"\"Load lora weights.\"\"\"\n\n        from lmdeploy.pytorch.adapter.adapter import load_lora_weights\n\n        num_heads = self.config.num_attention_heads\n        num_key_value_heads = self.config.num_key_value_heads\n        hidden_size = self.config.hidden_size\n        head_dim = hidden_size // num_heads\n        group_size = num_heads // num_key_value_heads\n\n        def _rearange_wqkv(weights):\n            for name, loaded_weight in weights:\n                if 'wqkv.lora_B' in name:\n                    loaded_weight = loaded_weight.unflatten(0, (-1, 2 + group_size, head_dim))\n                    q = loaded_weight[:, :-2].flatten(0, 2)\n                    k = loaded_weight[:, -2].flatten(0, 1)\n                    v = loaded_weight[:, -1].flatten(0, 
1)\n                    loaded_weight = torch.cat([q, k, v], dim=0)\n                yield name, loaded_weight\n\n        weights_iter = _rearange_wqkv(weights)\n        load_lora_weights(self, weights_iter, adapter_id)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.gate_up_proj', '.w1', 0),\n            ('.gate_up_proj', '.w3', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '.wqkv' in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight, layout='hgd')\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/internlm2_ve.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.models.internlm2 import InternLM2Attention, InternLM2MLP\nfrom lmdeploy.pytorch.nn import RMSNorm, RopeType, build_rotary_embedding\nfrom lmdeploy.pytorch.nn.linear import build_rowwise_linear\nfrom lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters, get_rope_theta\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass InternLM2VEDecoderLayer(nn.Module):\n    \"\"\"Decoder layer with visual expert.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.attention = InternLM2Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.feed_forward = InternLM2MLP(config, dtype=dtype, device=device)\n\n        # build visual expert\n        self.feed_forward_ve = InternLM2MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.attention_norm = RMSNorm(config.hidden_size,\n                                      config.rms_norm_eps,\n                                      quant_config=quantization_config,\n                                      dtype=dtype,\n                                      device=device)\n\n        # build attention layer norm\n        self.ffn_norm = RMSNorm(config.hidden_size,\n                                
config.rms_norm_eps,\n                                quant_config=quantization_config,\n                                dtype=dtype,\n                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n        vision_embedding_indexing: Optional[torch.Tensor] = None,\n        text_embedding_indexing: Optional[torch.Tensor] = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.attention_norm(hidden_states)\n        else:\n            hidden_states, residual = self.attention_norm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.attention(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.ffn_norm(hidden_states, residual)\n        if vision_embedding_indexing is not None:\n            hidden_states[:, vision_embedding_indexing, :] = self.feed_forward_ve(\n                hidden_states[:, vision_embedding_indexing, :].reshape(-1, self.hidden_size)).unsqueeze(0)\n            if text_embedding_indexing is not None:\n                hidden_states[:, text_embedding_indexing, :] = self.feed_forward(\n                    hidden_states[:, text_embedding_indexing, :].reshape(-1, self.hidden_size)).unsqueeze(0)\n        else:\n            hidden_states = self.feed_forward(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass InternLM2VEModel(nn.Module):\n    \"\"\"Internlm2 model with visual expert.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype 
= None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.tok_embeddings = nn.Embedding(config.vocab_size,\n                                           config.hidden_size,\n                                           self.padding_idx,\n                                           dtype=dtype,\n                                           device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            InternLM2VEDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding in Model\n        rope_scaling = get_rope_parameters(config)\n        scaling_factor = 1.0\n        emb_type = RopeType.LinearScaling\n        if rope_scaling is not None:\n            scaling_factor = rope_scaling.get('factor', scaling_factor)\n            rope_type = rope_scaling['type']\n            if rope_type == 'linear':\n                emb_type = RopeType.LinearScaling\n            elif rope_type == 'dynamic':\n                emb_type = RopeType.DynamicNTKScaling\n            else:\n                raise RuntimeError(f'Unsupported rope type: {rope_type}')\n        rope_dim = config.hidden_size // config.num_attention_heads\n        rope_max_pos_emb = config.max_position_embeddings\n        rope_base = get_rope_theta(config)\n        self.rotary_emb = build_rotary_embedding(\n            rope_dim,\n            rope_max_pos_emb,\n            rope_base,\n            scaling_factor,\n            emb_type=emb_type,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] 
= None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        vision_embedding_indexing: Optional[torch.Tensor] = None,\n        text_embedding_indexing: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"Rewrite of forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.tok_embeddings(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n                vision_embedding_indexing=vision_embedding_indexing,\n                text_embedding_indexing=text_embedding_indexing,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.tok_embeddings\n\n\nclass InternLM2VEForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Rewrote model of InternLM2ForCausalLM with visual expert.\"\"\"\n\n    packed_modules_mapping = {\n        'gate_up_proj': [\n            'w1',\n            'w3',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build Model\n 
       self.model = InternLM2VEModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.output = build_rowwise_linear(config.hidden_size,\n                                           config.vocab_size,\n                                           bias=False,\n                                           dtype=dtype,\n                                           device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        vision_embedding_indexing: Optional[torch.Tensor] = None,\n        text_embedding_indexing: Optional[torch.Tensor] = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            vision_embedding_indexing=vision_embedding_indexing,\n            text_embedding_indexing=text_embedding_indexing,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.output(hidden_states)\n\n    def support_cuda_graph(\n        self,\n        input_ids: torch.Tensor,\n        attn_metadata: Any = None,\n        **kwargs,\n    ):\n        \"\"\"Support cudagraph.\"\"\"\n        if not attn_metadata.is_decoding:\n            return False\n        seq_lens = input_ids.size(1)\n        if seq_lens <= 512:\n            return True\n        return False\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: 
List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.gate_up_proj', '.w1', 0),\n            ('.gate_up_proj', '.w3', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, 
shard_id=shard_id)\n                break\n            else:\n                if '.wqkv' in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight, layout='hgd')\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/internlm3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n\n\nclass InternLM3Attention(nn.Module):\n    \"\"\"Rewrite module of InternLM3Attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.qkv_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            num_replicate_kv_heads=num_replicate_kv_heads,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            
v_head_size=head_dim,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=config.bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of InternLM3Attention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass InternLM3MLP(nn.Module):\n    \"\"\"Internlm3 mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, 
device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        mlp_bias = getattr(config, 'bias', False)\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=mlp_bias,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=mlp_bias,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass InternLM3DecoderLayer(nn.Module):\n    \"\"\"InternLM3 decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = InternLM3Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = InternLM3MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       
config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass InternLM3Model(nn.Module):\n    \"\"\"Internlm3 model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = build_embedding(\n            config.vocab_size,\n   
         config.hidden_size,\n            self.padding_idx,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            InternLM3DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of InternLM3Model.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass InternLM3ForCausalLM(nn.Module, 
DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"Rewritten model of InternLM3ForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build InternLM3Model\n        self.model = InternLM3Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = 
context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n         
   else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/interns1_pro.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\nfrom lmdeploy.vl.constants import Modality\n\nfrom .interns1_pro_ts import InternS1ProTimeSeriesModel\nfrom .patch import add_prefix, get_build_model_context\nfrom .qwen3_moe import Qwen3MoeModel\nfrom .qwen3_vl import Qwen3VLVisionModel\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1\n\n\nclass InternS1ProForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # build preprocessor\n        self.input_processor = InternS1ProInputProcessor(self.config, dtype)\n\n        # build vision model\n        self.visual = Qwen3VLVisionModel(\n            config.vision_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('visual', prefix=prefix),\n        )\n\n        # build text model\n        self.language_model = 
Qwen3MoeModel(config.text_config,\n                                            dtype=dtype,\n                                            device=device,\n                                            prefix=add_prefix('language_model', prefix=prefix))\n\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.text_config.hidden_size,\n                                          config.text_config.vocab_size,\n                                          bias=False,\n                                          dtype=dtype,\n                                          device=device)\n\n        # build time series model\n        if hasattr(config, 'ts_config'):\n            self.time_series = InternS1ProTimeSeriesModel(config.ts_config, dtype=dtype, device=device)\n\n        # for router replay\n        bm_ctx = get_build_model_context()\n        self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        pixel_values: torch.Tensor = None,\n        vis_cu_seqlens: torch.Tensor = None,\n        vis_pos_emb: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        pos_embeds: torch.Tensor = None,\n        grid_thw: torch.Tensor = None,\n        # for time series\n        ts_values: torch.Tensor = None,\n        ts_lens: torch.Tensor = None,\n        ts_sr: torch.Tensor = None,\n        ts_mask: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n\n            if pixel_values is not None:\n                dtype = inputs_embeds.dtype\n                pixel_values = pixel_values.to(dtype)\n                vis_pos_emb = (vis_pos_emb[0].to(dtype), 
vis_pos_emb[1].to(dtype))\n\n                # get image embeds\n                # different from qwen3vl, interns1_1 does not use deepstack visual embeds\n                image_embeds, _ = self.visual(pixel_values,\n                                              cu_seqlens=vis_cu_seqlens,\n                                              rotary_pos_emb=vis_pos_emb,\n                                              pos_embeds=pos_embeds)\n\n                # split image embeds per sample\n                split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()\n                image_embeds = torch.split(image_embeds, split_sizes)\n                image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype)\n\n                # mask and scatter to create final input embeddings\n                expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)\n                inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds)\n\n            elif ts_values is not None:\n                ts_embeds = self.time_series(ts_values, ts_lens, ts_sr)  # [B, T, C]\n                inputs_embeds = inputs_embeds.masked_scatter_(ts_mask[..., None], ts_embeds)\n\n        # router replay\n        all_routed_experts = None\n        if self.enable_return_routed_experts:\n            all_routed_experts = input_ids.new_empty((input_ids.size(1), self.config.text_config.num_hidden_layers,\n                                                      self.config.text_config.num_experts_per_tok),\n                                                     dtype=torch.uint16)\n\n        hidden_states = self.language_model(input_ids=input_ids,\n                                            position_ids=position_ids,\n                                            past_key_values=past_key_values,\n                                            attn_metadata=attn_metadata,\n                                            inputs_embeds=inputs_embeds,\n   
                                         all_routed_experts=all_routed_experts)\n\n        if all_routed_experts is None:\n            return hidden_states\n        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.language_model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        pixel_values = None\n        vis_cu_seqlens = None\n        vis_pos_emb = None\n        image_mask = None\n        grid_thw = None\n        pos_embeds = None\n        # for time series\n        ts_values = None\n        ts_lens = None\n        ts_sr = None\n        ts_mask = None\n        if context.input_multimodals is not None:\n            mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            mm_inputs = [item for sublist in mm_inputs for item in sublist]\n\n            if len(mm_inputs) > 0:\n                modality = mm_inputs[0].modality\n                image_token_id = mm_inputs[0].meta.get('image_token_id')\n                video_token_id = mm_inputs[0].meta.get('video_token_id')\n                ts_token_id = mm_inputs[0].meta.get('ts_token_id')\n\n                if modality == Modality.TIME_SERIES:\n                    ts_values = torch.cat([inp.data for inp in mm_inputs])\n                    ts_mask = input_ids == ts_token_id\n\n                    ts_lens = mm_inputs[0].meta['ts_lens']\n                    ts_sr = mm_inputs[0].meta['ts_sr']\n               
 else:\n                    pixel_values = torch.cat([inp.data for inp in mm_inputs])\n                    mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id\n                    image_mask = (input_ids == mm_token_id)\n\n                    grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu()\n                    vis_pos_emb = self.visual.rot_pos_emb(grid_thw)\n                    pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw)\n                    vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],\n                                                             grid_thw[:, 0]).to(pixel_values.device)\n                    vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)\n                    vis_pos_emb = vis_pos_emb.repeat(1, 2)\n                    vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            pixel_values=pixel_values,\n            vis_cu_seqlens=vis_cu_seqlens,\n            vis_pos_emb=vis_pos_emb,\n            image_mask=image_mask,\n            grid_thw=grid_thw,\n            pos_embeds=pos_embeds,\n            # for time series\n            ts_values=ts_values,\n            ts_lens=ts_lens,\n            ts_sr=ts_sr,\n            ts_mask=ts_mask,\n      
  )\n\n    @classmethod\n    def rename_weight(cls, name: str) -> str:\n        \"\"\"Rename weight.\"\"\"\n        if name.startswith('model.language_model.'):\n            return 'language_model.' + name[len('model.language_model.'):]\n        elif name.startswith('model.visual.'):\n            return 'visual.' + name[len('model.visual.'):]\n        elif name.startswith('model.'):\n            return name[len('model.'):]\n        return name\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    # modify from vllm qwen3vlmoe fused expert loading\n    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                                   fused_expert_params_mapping: List):\n        \"\"\"Load weight of fused expert weights.\"\"\"\n        num_experts = self.config.text_config.num_experts\n\n        for (param_name, weight_name) in fused_expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n\n            loaded_weight = loaded_weight.transpose(-1, -2)  # no bias\n            if 'gate_up' in name:\n                loaded_weight = loaded_weight.chunk(2, dim=-2)\n                w1 = loaded_weight[0]\n                w3 = loaded_weight[1]\n                
for expert_id in range(num_experts):\n                    load_weight(param, w1[expert_id], expert_id=expert_id, shard_id='gate')\n                    load_weight(param, w3[expert_id], expert_id=expert_id, shard_id='up')\n            elif 'down' in name:\n                w2 = loaded_weight\n                for expert_id in range(num_experts):\n                    load_weight(param, w2[expert_id], expert_id=expert_id, shard_id='down')\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        # expert mapping\n        num_experts = self.config.text_config.num_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            # (param_name, weight_name, expert_id, shard_id)\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        # fused expert mapping\n        fused_expert_params_mapping = [\n            # (param_name, weight_name)\n            ('.experts.gate_up.weight', '.experts.gate_up_proj'),\n            ('.experts.down.weight', '.experts.down_proj'),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        buffers_dict = dict(self.named_buffers())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' 
in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            name = name.replace('.block_sparse_moe.', '.mlp.')\n            if '.experts' in name:\n                is_fused_expert = ('experts.gate_up_proj' in name or 'experts.down_proj' in name)\n                if is_fused_expert:\n                    self._load_weight_fused_experts(name,\n                                                    loaded_weight,\n                                                    params_dict,\n                                                    fused_expert_params_mapping=fused_expert_params_mapping)\n                else:\n                    self._load_weight_experts(name,\n                                              loaded_weight,\n                                              params_dict,\n                                              expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    if '.qkv.' 
in name:\n                        param = params_dict[name]\n                        q, k, v = param.weight_spliter(loaded_weight)\n                        load_weight(param, q, shard_id='q')\n                        load_weight(param, k, shard_id='k')\n                        load_weight(param, v, shard_id='v')\n                    else:\n                        if name in params_dict:\n                            param = params_dict[name]\n                            load_weight(param, loaded_weight)\n                        elif name in buffers_dict:\n                            param = buffers_dict[name]\n                            load_weight(param, loaded_weight)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass InternS1ProInputProcessor(BaseModelInputProcessor):\n    \"\"\"InternS1Pro input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n    def _make_image_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:\n        \"\"\"Make image MultiModalData.\"\"\"\n        pixel_values = input_mm['pixel_values'].to(self.dtype)\n        image_grid_thw = input_mm['image_grid_thw']\n        offset = input_mm['offset']\n        start = offset\n        image_token_id = input_mm['image_token_id']\n        num_pad = input_mm['image_tokens']\n        if isinstance(num_pad, torch.Tensor):\n            num_pad = num_pad.item()\n\n        mm_data = MultiModalData(modality=Modality.IMAGE,\n                                 data=pixel_values,\n                                 start=start,\n                                 end=start + num_pad,\n                                 meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))\n        return mm_data\n\n    def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:\n        
\"\"\"Make video MultiModalData.\"\"\"\n        pixel_values_videos = input_mm['pixel_values_videos'].to(self.dtype)\n        video_grid_thw = input_mm['video_grid_thw']\n        offset = input_mm['offset']\n        start = offset\n        video_token_id = input_mm['video_token_id']\n        num_pad = input_mm['video_tokens']\n        if isinstance(num_pad, torch.Tensor):\n            num_pad = num_pad.item()\n\n        mm_data = MultiModalData(modality=Modality.VIDEO,\n                                 data=pixel_values_videos,\n                                 start=start,\n                                 end=start + num_pad,\n                                 meta=dict(\n                                     grid_thw=video_grid_thw,\n                                     video_token_id=video_token_id,\n                                 ))\n        return mm_data\n\n    def _make_time_series_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:\n        \"\"\"Make time series MultiModalData.\"\"\"\n        ts_values = input_mm['ts_values'].to(self.dtype)\n        offset = input_mm['offset']\n        ts_token_id = input_mm['ts_token_id']\n        ts_lens = input_mm['ts_lens']\n        ts_sr = input_mm['ts_sr']\n        num_pad = input_mm['ts_tokens']\n        if isinstance(num_pad, torch.Tensor):\n            num_pad = num_pad.item()\n\n        mm_data = MultiModalData(modality=Modality.TIME_SERIES,\n                                 data=ts_values,\n                                 start=offset,\n                                 end=offset + num_pad,\n                                 meta=dict(ts_lens=ts_lens, ts_sr=ts_sr, ts_token_id=ts_token_id))\n        return mm_data\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals 
is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_mm_data = []\n        for input_mm in input_multimodals:\n            modality = input_mm.get('modality')\n            if modality == Modality.IMAGE:\n                mm_data = self._make_image_mm_data(input_mm)\n            elif modality == Modality.VIDEO:\n                mm_data = self._make_video_mm_data(input_mm)\n            elif modality == Modality.TIME_SERIES:\n                mm_data = self._make_time_series_mm_data(input_mm)\n            input_mm_data.append(mm_data)\n\n        result = PreprocessInputResult(input_ids=input_ids, input_multimodals=dict(mm_data=input_mm_data))\n\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/interns1_pro_ts.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport math\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.nn import LayerNorm\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear\n\nfrom .whisper import WhisperEncoderLayer\n\n\nclass InternS1ProTimeSeriesEncoder(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n\n        self.embed_dim = config.d_model\n        self.num_mel_bins = config.num_mel_bins\n        self.padding_idx = config.pad_token_id\n        self.max_source_positions = config.max_source_positions\n        self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0\n\n        self.conv1 = nn.Conv1d(self.num_mel_bins, self.embed_dim, kernel_size=3, padding=1, dtype=dtype, device=device)\n        self.conv2 = nn.Conv1d(self.embed_dim,\n                               self.embed_dim,\n                               kernel_size=3,\n                               stride=2,\n                               padding=1,\n                               dtype=dtype,\n                               device=device)\n        self.embed_positions = nn.Embedding(self.max_source_positions, self.embed_dim, dtype=dtype, device=device)\n\n        self.layers = nn.ModuleList(\n            [WhisperEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.encoder_layers)])\n        self.layer_norm = LayerNorm(config.d_model, eps=1e-5, dtype=dtype, device=device)\n\n        self.adapt_in = build_colwise_linear(\n            in_features=config.ts_adapt_in_dim,\n            out_features=80,\n            bias=True,\n            dtype=dtype,\n            device=device,\n        )\n        self.adapt_out = 
build_rowwise_linear(\n            in_features=self.embed_dim,\n            out_features=config.ts_adapt_out_dim,\n            bias=True,\n            dtype=dtype,\n            device=device,\n        )\n\n    def _make_causal_mask(self,\n                          input_ids_shape: torch.Size,\n                          dtype: torch.dtype,\n                          device: torch.device,\n                          past_key_values_length: int = 0):\n        \"\"\"Make causal mask used for bi-directional self-attention.\"\"\"\n        bsz, tgt_len = input_ids_shape\n        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n        mask_cond = torch.arange(mask.size(-1), device=device)\n        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n        mask = mask.to(dtype)\n\n        if past_key_values_length > 0:\n            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)\n        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n\n    def _prepare_decoder_attention_mask(self, input_shape, inputs_embeds, past_key_values_length):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n\n        if input_shape[-1] > 1:\n            combined_attention_mask = self._make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        return combined_attention_mask\n\n    def forward(self, input_features):\n        # (N, T, C) -> (T, N, C) -> (N, C, T)\n        input_features = input_features.permute(1, 0, 2)\n        input_features = self.adapt_in(input_features)\n        input_features = input_features.permute(1, 2, 0)\n\n        # (N, C, T) -> (N, C, T//2)\n        inputs_embeds = 
nn.functional.gelu(self.conv1(input_features))\n        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))\n\n        # (N, C, T) -> (N, T, C)\n        inputs_embeds = inputs_embeds.permute(0, 2, 1)\n        embed_pos = self.embed_positions.weight\n\n        if inputs_embeds.shape[1] > embed_pos.shape[0]:\n            target_len = inputs_embeds.shape[1]\n            padding = [0, 0, 0, target_len - embed_pos.shape[0]]\n\n            embed_pos = nn.functional.pad(embed_pos, pad=padding, mode='constant', value=0)\n            hidden_states = inputs_embeds[:, :embed_pos.shape[0], :] + embed_pos\n        else:\n            hidden_states = inputs_embeds + embed_pos[:inputs_embeds.shape[1], :]\n\n        input_shape = inputs_embeds.size()[:-1]\n        past_key_values_length = 0\n        attention_mask = self._prepare_decoder_attention_mask(input_shape, inputs_embeds, past_key_values_length)\n\n        for idx, encoder_layer in enumerate(self.layers):\n            layer_outputs = encoder_layer(hidden_states, attention_mask)\n            hidden_states = layer_outputs\n\n        # (N, T, C) -> (T, N, C)\n        hidden_states = hidden_states.permute(1, 0, 2)\n        hidden_states = self.layer_norm(hidden_states)\n        hidden_states = self.adapt_out(hidden_states)\n\n        # (T, N, C) -> (N, T, C)\n        hidden_states = hidden_states.permute(1, 0, 2)\n\n        return hidden_states\n\n\nclass InternS1ProTimeSeriesConcatSubsampling(nn.Module):\n\n    def __init__(self, in_channels: int, concat_size: int):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels * concat_size\n\n    def forward(self, ts_signals: torch.Tensor, ts_lens: torch.Tensor):\n        if ts_signals.shape[1] % 2 != 0:\n            ts_signals = ts_signals[:, :-1, :]\n        even_frames = ts_signals[:, ::2, :]\n        odd_frames = ts_signals[:, 1::2, :]\n        ts_signals = torch.cat((even_frames, odd_frames), dim=2)\n        
ts_lens = ts_lens // 2\n        return ts_signals, ts_lens\n\n\nclass InternS1ProTimeSeriesFixPositionalEncoding(nn.Module):\n\n    def __init__(self, d_model, max_len=20000, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        pe = torch.zeros(max_len, d_model, dtype=torch.float)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        # hf forces float32 during init, but becomes bf16 during forward\n        pe = pe.unsqueeze(0).transpose(0, 1).to(dtype=dtype, device=device)  # (max_len, 1, d_model)\n        self.register_buffer('pe', pe, persistent=True)\n\n    def forward(self, x):\n        # x: (seq_len, batch_size, d_model)\n        x = x + self.pe[:x.size(0), :]\n        return x.clone()\n\n\nclass InternS1ProTimeSeriesMultiChannelAdaptiveSubsampling(nn.Module):\n\n    def __init__(self,\n                 hidden_dim=128,\n                 nhead=8,\n                 num_encoder_layers=1,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.conv = nn.Conv1d(in_channels=1,\n                              out_channels=hidden_dim,\n                              kernel_size=5,\n                              stride=1,\n                              padding=2,\n                              dtype=dtype,\n                              device=device)\n        encoder_layers = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, dtype=dtype, device=device)\n        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)\n        self.pos_encoder = InternS1ProTimeSeriesFixPositionalEncoding(d_model=hidden_dim, dtype=dtype, device=device)\n        self.subsampling = 
InternS1ProTimeSeriesConcatSubsampling(128, 2)\n\n    def forward(self, inputs, input_lens, sr):\n        sr = torch.as_tensor(sr, dtype=torch.float32)\n        strides = torch.floor(160 / ((1 + torch.exp(-sr / 100))**6))\n        patch_sizes = strides * 2\n        patched_outputs = []\n        output_lens = []\n\n        for i in range(len(inputs)):\n            seq = inputs[i]  # [seq_len, num_channel]\n            ps = patch_sizes[i].item()\n            st = strides[i].item()\n            le = input_lens[i]\n\n            output_len = torch.ceil((le - ps) / st) + 1\n            pad_len = ((output_len - 1) * st + ps - le).long().item()\n            if seq.ndim == 1:\n                seq = seq.unsqueeze(-1)\n            seq = nn.functional.pad(seq, (0, 0, 0, pad_len), 'constant', 0)\n            assert output_len > 0, (seq.shape, ps, st, le, output_len)\n            output_lens.append(output_len)\n            indices = (torch.arange(0, output_len * st, st).unsqueeze(1) + torch.arange(ps)).long()\n            patched = seq[indices]\n\n            output = self.forward_encoder(patched)  # [num_patch, D]\n            patched_outputs.append(output)\n\n        outputs = nn.utils.rnn.pad_sequence(patched_outputs, batch_first=True)\n        output_lens = torch.tensor(output_lens).squeeze().to(outputs.device).long()\n        if output_lens.ndim == 0:\n            output_lens = output_lens.unsqueeze(0)\n\n        outputs, output_lens = self.subsampling(outputs.clone(), output_lens.clone())\n        return outputs, output_lens\n\n    def forward_encoder(self, x):\n        num_patch, patch_len, C = x.shape\n        # conv1\n        # treat each channel as an independent sample and feed it into conv1\n        x = x.reshape(num_patch * C, 1, patch_len)\n        x = nn.functional.relu((self.conv(x)))  # [B*C, D1, L]\n        x = x.permute(2, 0, 1)  # [L, B*C, D1]\n\n        x = self.pos_encoder(x)  # [L, B*C, D1]\n        x = self.transformer_encoder(x)\n        x = 
x.mean(0)\n\n        x = x.reshape(num_patch, C, -1)\n\n        return x.mean(1)\n\n\nclass InternS1ProTimeSeriesProjector(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.layer_norm = LayerNorm(config.ts_hidden_dim, eps=1e-5, dtype=dtype, device=device)\n        self.linear_1 = build_colwise_linear(in_features=config.ts_hidden_dim,\n                                             out_features=config.out_hidden_size,\n                                             bias=True,\n                                             dtype=dtype,\n                                             device=device)\n        self.act = ACT2FN[config.activation_function]\n        self.linear_2 = build_rowwise_linear(in_features=config.out_hidden_size,\n                                             out_features=config.out_hidden_size,\n                                             bias=True,\n                                             dtype=dtype,\n                                             device=device)\n\n    def forward(self, ts_features):\n        hidden_states = self.layer_norm(ts_features)\n        hidden_states = self.linear_1(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass InternS1ProTimeSeriesModel(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.encoder_embed = InternS1ProTimeSeriesMultiChannelAdaptiveSubsampling(dtype=dtype, device=device)\n        self.encoder = InternS1ProTimeSeriesEncoder(config, dtype=dtype, device=device)\n        self.projector = InternS1ProTimeSeriesProjector(config, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        time_series_signals: Optional[torch.FloatTensor] = None,\n  
      ts_lens: Optional[torch.Tensor] = None,\n        sr: Optional[torch.Tensor] = None,\n        time_series_embeds: Optional[torch.FloatTensor] = None,\n    ) -> Union[Tuple]:\n        if time_series_signals is None and time_series_embeds is None:\n            raise ValueError('You have to specify time_series_signals or time_series_embeds')\n\n        # embedded values can be passed in directly, but the dimensions must match\n        if time_series_embeds is not None and len(\n                time_series_embeds.shape) == 3 and time_series_embeds.shape[-1] == self.config.ts_adapt_in_dim:\n            time_series_embeds = time_series_embeds\n        else:\n            if ((isinstance(time_series_signals, list) and len(time_series_signals[0].shape) == 2)\n                    or (isinstance(time_series_signals, torch.Tensor) and len(time_series_signals.shape) == 3)):\n                time_series_embeds, ts_lens = self.encoder_embed(time_series_signals, ts_lens, sr)\n            else:\n                raise ValueError(f'wrong time_series_signals size: {time_series_signals[0].shape}')\n\n        # [B, 64000, 1] -> [B, 200, 256] -> [B, 100, 1024]\n        last_hidden_state = self.encoder(input_features=time_series_embeds)\n\n        # ts_lens after encoder\n        ts_lens = (ts_lens + 1) // 2\n        assert torch.all(ts_lens > 0), f'The length of time_series_embeds is too small. ts_lens: {ts_lens}'\n\n        last_hidden_state = self.projector(last_hidden_state)\n        return last_hidden_state\n"
  },
  {
    "path": "lmdeploy/pytorch/models/internvl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom packaging import version\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import LayerNorm, RMSNorm\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import build_model_from_hf_config\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, vlm_model\n\n\nclass Gating(nn.Module):\n\n    def __init__(self, hidden_size=2048, expansion_factor=4, dtype=None, device=None):\n        super().__init__()\n\n        mid_dim = hidden_size * expansion_factor\n\n        def mlp_block(in_dim, out_dim):\n            return nn.Sequential(\n                nn.Linear(in_dim, out_dim, bias=True, dtype=dtype, device=device),\n                nn.GELU(),\n                nn.Identity(),\n                nn.Linear(out_dim, in_dim, bias=True, dtype=dtype, device=device),\n                nn.Identity(),\n                nn.LayerNorm(in_dim, dtype=dtype, device=device),\n            )\n\n        self.block1 = mlp_block(hidden_size, mid_dim)\n        self.block2 = mlp_block(hidden_size, mid_dim)\n        self.block3 = mlp_block(hidden_size, mid_dim)\n        self.block4 = mlp_block(hidden_size, mid_dim)\n\n        self.gate = nn.Sequential(\n            nn.LayerNorm(hidden_size, 
dtype=dtype, device=device),\n            nn.Linear(hidden_size, 2, bias=True, dtype=dtype, device=device)  # 2 experts\n        )\n\n    def forward(self, x):\n        x = x + self.block1(x)\n        x = x + self.block2(x)\n        x = x + self.block3(x)\n        x = x + self.block4(x)\n\n        logits = self.gate(x)  # shape: [B, 2]\n        probs = torch.softmax(logits, dim=-1)\n        return probs\n\n\nclass CrossAttentionPooling(nn.Module):\n\n    def __init__(self, dim, num_heads=16, dtype=None, device=None):\n        super().__init__()\n        self.query_token = nn.Parameter(torch.randn(1, dim, dtype=dtype, device=device))  # [1, D]\n\n        self.attn1 = nn.MultiheadAttention(embed_dim=dim,\n                                           num_heads=num_heads,\n                                           batch_first=True,\n                                           dtype=dtype,\n                                           device=device)\n        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)\n\n        self.attn2 = nn.MultiheadAttention(embed_dim=dim,\n                                           num_heads=num_heads,\n                                           batch_first=True,\n                                           dtype=dtype,\n                                           device=device)\n        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)\n\n        self.attn3 = nn.MultiheadAttention(embed_dim=dim,\n                                           num_heads=num_heads,\n                                           batch_first=True,\n                                           dtype=dtype,\n                                           device=device)\n        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)\n\n        self.attn4 = nn.MultiheadAttention(embed_dim=dim,\n                                           num_heads=num_heads,\n                                           batch_first=True,\n                                           
dtype=dtype,\n                                           device=device)\n        self.norm4 = nn.LayerNorm(dim, dtype=dtype, device=device)\n\n    def forward(self, batched_tokens: list[torch.Tensor]):\n        \"\"\"\n        batched_tokens: List of Tensors of shape [Ti, D], length = B\n        \"\"\"\n        B = len(batched_tokens)\n        D = batched_tokens[0].shape[-1]\n        device = batched_tokens[0].device\n\n        # 1. Padding\n        max_len = max(t.shape[0] for t in batched_tokens)\n        dtype = self.query_token.dtype\n        padded = torch.zeros(B, max_len, D, dtype=dtype, device=device)\n        padding_mask = torch.ones(B, max_len, dtype=torch.bool, device=device)\n\n        for i, t in enumerate(batched_tokens):\n            L = t.shape[0]\n            padded[i, :L] = t\n            padding_mask[i, :L] = False\n\n        # 2. Query token: [B, 1, D]\n        query = self.query_token.unsqueeze(0).expand(B, -1, -1)  # learnable token for each sample\n\n        # 3. First attention\n        out1, _ = self.attn1(query, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]\n        out1 = self.norm1(out1)\n\n        # 4. 
Second attention\n        out2, _ = self.attn2(out1, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]\n        out2 = self.norm2(out2)\n\n        out3, _ = self.attn2(out2, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]\n        out3 = self.norm2(out3)\n\n        out4, _ = self.attn2(out3, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]\n        out4 = self.norm2(out4)\n\n        return out4.squeeze(1)\n\n\nclass InternVisionEmbeddings(nn.Module):\n    \"\"\"Intern vision embedding.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.class_embedding = nn.Parameter(torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), )\n\n        self.patch_embedding = nn.Conv2d(in_channels=3,\n                                         out_channels=self.embed_dim,\n                                         kernel_size=self.patch_size,\n                                         stride=self.patch_size,\n                                         dtype=dtype,\n                                         device=device)\n\n        self.num_patches = (self.image_size // self.patch_size)**2\n        self.num_positions = self.num_patches + 1\n\n        self.position_embedding = nn.Parameter(\n            torch.empty(1, self.num_positions, self.embed_dim, dtype=dtype, device=device))\n\n    def _get_pos_embed(self, pos_embed, H, W):\n        target_dtype = pos_embed.dtype\n        pos_embed = pos_embed.float().reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size,\n                                              -1).permute(0, 3, 1, 2)\n        pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic',\n                                  
align_corners=False).reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)\n        return pos_embed\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        target_dtype = self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, width, height]\n        batch_size, _, height, width = patch_embeds.shape\n        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)\n        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)\n        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)\n        position_embedding = torch.cat(\n            [self.position_embedding[:, :1, :],\n             self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)],\n            dim=1)\n        embeddings = embeddings + position_embedding.to(target_dtype)\n        return embeddings\n\n\nNORM2FN = {\n    'rms_norm': RMSNorm,\n    'layer_norm': LayerNorm,\n}\n\n\n@torch.compile(dynamic=True)\ndef pre_rms_norm(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:\n    \"\"\"Pre rms norm.\"\"\"\n    q = q.to(torch.float32)\n    k = k.to(torch.float32)\n    variance_q = (q * q).sum(-1, keepdim=True)\n    variance_k = (k * k).sum(-1, keepdim=True)\n    variance = torch.stack([variance_q, variance_k], dim=0)\n    return variance\n\n\n@torch.compile(dynamic=True)\ndef post_rms_norm(q: torch.Tensor, k: torch.Tensor, weight_q: torch.Tensor, weight_k: torch.Tensor,\n                  variance: torch.Tensor, eps: float, embed_dim: int, dtype: torch.dtype):\n    \"\"\"Post rms norm.\"\"\"\n    q = q.to(torch.float32)\n    k = k.to(torch.float32)\n    variance = variance / embed_dim + eps\n    variance_q, variance_k = variance\n    q = q * torch.rsqrt(variance_q)\n    q = q.to(dtype) * weight_q\n    k = k * torch.rsqrt(variance_k)\n    k = k.to(dtype) * weight_k\n    return q, k\n\n\nclass InternAttention(nn.Module):\n    \"\"\"Intern vl attention.\"\"\"\n\n    def 
__init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n\n        self.qkv = build_qkv_proj(\n            self.embed_dim,\n            num_q_heads=self.num_heads,\n            num_kv_heads=self.num_heads,\n            head_size=self.head_dim,\n            bias=config.qkv_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        self.qk_normalization = config.qk_normalization\n\n        if self.qk_normalization:\n            self.q_norm = RMSNorm(\n                self.embed_dim,\n                eps=config.layer_norm_eps,\n                dtype=dtype,\n                device=device,\n                tp=True,\n                align=self.head_dim,\n            )\n            self.k_norm = RMSNorm(\n                self.embed_dim,\n                eps=config.layer_norm_eps,\n                dtype=dtype,\n                device=device,\n                tp=True,\n                align=self.head_dim,\n            )\n\n        self.scale = self.head_dim**-0.5\n\n        # o_proj\n        self.proj = build_o_proj(self.embed_dim,\n                                 self.embed_dim,\n                                 bias=True,\n                                 quant_config=quantization_config,\n                                 dtype=dtype,\n                                 device=device,\n                                 is_tp=True,\n                                 tp_align_size=self.head_dim)\n\n    def pre_rms_norm(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:\n        \"\"\"Pre rms norm.\"\"\"\n        return pre_rms_norm(q, k)\n\n    def post_rms_norm(self, q: 
torch.Tensor, k: torch.Tensor, variance: torch.Tensor,\n                      dtype: torch.dtype) -> torch.Tensor:\n        \"\"\"Post rms norm.\"\"\"\n        eps = self.config.layer_norm_eps\n        return post_rms_norm(q, k, self.q_norm.weight, self.k_norm.weight, variance, eps, self.embed_dim, dtype)\n\n    def qkv_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        import lmdeploy.pytorch.distributed as dist\n        q_shape = q.shape\n        k_shape = k.shape\n        q = q.flatten(-2, -1)\n        k = k.flatten(-2, -1)\n\n        tp, _ = get_tp_world_rank()\n        if tp == 1:\n            q = self.q_norm(q).view(q_shape)\n            k = self.k_norm(k).view(k_shape)\n            return q, k\n\n        # variance\n        variance = self.pre_rms_norm(q, k)\n        dist.all_reduce(variance)\n        q, k = self.post_rms_norm(q, k, variance, q.dtype)\n        q = q.view(q_shape)\n        k = k.view(k_shape)\n\n        return q, k\n\n    def forward(self, hidden_states):\n        \"\"\"forward.\"\"\"\n\n        # qkv proj\n        qkv_states = self.qkv(hidden_states)\n        q, k, v = self.qkv.split_qkv(qkv_states)\n\n        if self.qk_normalization:\n            q, k = self.qkv_norm(q, k)\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)\n\n        # o proj\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.flatten(-2, -1)\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass InternMLP(nn.Module):\n    \"\"\"Intern vl mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        from transformers.activations import ACT2FN\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n     
   self.act = ACT2FN[config.hidden_act]\n\n        self.fc1 = build_colwise_linear(\n            config.hidden_size,\n            config.intermediate_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n            dp_disable_tp=True,\n        )\n\n        self.fc2 = build_rowwise_linear(\n            config.intermediate_size,\n            config.hidden_size,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            dp_disable_tp=True,\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass InternVisionEncoderLayer(nn.Module):\n    \"\"\"Intern vision encoder layer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.norm_type = getattr(config, 'norm_type', 'rms_norm')\n\n        self.attn = InternAttention(config, dtype=dtype, device=device)\n        self.mlp = InternMLP(config, dtype=dtype, device=device)\n        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n\n        self.ls1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))\n        self.ls2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))\n\n    @enable_micro_batch(param_name='hidden_states', index=0)\n    def _attn(self, hidden_states):\n 
       hidden_states = hidden_states + self.attn(self.norm1(hidden_states).to(hidden_states[0].dtype)) * self.ls1\n        return hidden_states\n\n    @enable_micro_batch(param_name='hidden_states', index=0)\n    def _mlp(self, hidden_states):\n        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2\n        return hidden_states\n\n    def forward(\n        self,\n        hidden_states,\n    ):\n        hidden_states = self._attn(hidden_states)\n        hidden_states = self._mlp(hidden_states)\n        return hidden_states\n\n\nclass InternVisionEncoder(nn.Module):\n    \"\"\"Intern vision encoder.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [InternVisionEncoderLayer(config, dtype=dtype, device=device) for idx in range(config.num_hidden_layers)])\n\n    def forward(\n        self,\n        inputs_embeds,\n    ):\n        \"\"\"forward.\"\"\"\n        hidden_states = inputs_embeds\n        for _, encoder_layer in enumerate(self.layers):\n            layer_outputs = encoder_layer(hidden_states, )\n            hidden_states = layer_outputs\n        return hidden_states\n\n\n@vlm_model\nclass InternVisionModel(nn.Module):\n    \"\"\"Intern vision model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n\n        self.embeddings = InternVisionEmbeddings(config, dtype=dtype, device=device)\n        self.encoder = InternVisionEncoder(config, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n        assert pixel_values.dim() == 4\n        hidden_states = self.embeddings(pixel_values)\n\n        encoder_outputs = 
self.encoder(inputs_embeds=hidden_states)\n        last_hidden_state = encoder_outputs\n\n        return last_hidden_state\n\n\nclass InternVLChatModel(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        self.select_layer = config.select_layer\n\n        llm_config = config.llm_config\n        self.llm_arch_name = llm_config.architectures[0]\n        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'\n\n        vision_config = config.vision_config\n        if self.is_mono:\n            from .internvl_patch import InternVisionPatchModel\n            self.vision_model = InternVisionPatchModel(\n                vision_config,\n                dtype=dtype,\n                device=device,\n            )\n        else:\n            self.vision_model = InternVisionModel(vision_config, dtype=dtype, device=device)\n\n        self.language_model = build_model_from_hf_config(llm_config, dtype=dtype, device=device)\n        self.lm_head = self.language_model.lm_head\n        vit_hidden_size = config.vision_config.hidden_size\n        llm_hidden_size = config.llm_config.hidden_size\n        self.downsample_ratio = config.downsample_ratio\n        self.mlp1 = nn.Sequential(\n            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, dtype=dtype, device=device),\n            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,\n                      llm_hidden_size,\n                      bias=True,\n                      dtype=dtype,\n                      device=device), nn.GELU(),\n            nn.Linear(llm_hidden_size, llm_hidden_size, bias=True, dtype=dtype, device=device))\n\n        # for Mono-InternVL\n        if self.is_mono:\n      
      assert dtype != torch.float16, ('Currently Mono-InternVL does not support FP16 due to '\n                                            'numerical instability. Please use BF16 instead.')\n\n        self.input_processor = InternVLInputProcessor(self.config, dtype)\n\n        self.compile_vit = False\n\n        self.flash_mode = getattr(config, 'flash_mode', None)\n        if self.flash_mode is not None:\n            self.flash_relative_threshold = config.flash_relative_threshold\n            self.flash_absolute_threshold = config.flash_absolute_threshold\n\n            self.mlp2 = nn.Sequential(\n                nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**4, dtype=dtype, device=device),\n                nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**4,\n                          llm_hidden_size * 2,\n                          bias=True,\n                          dtype=dtype,\n                          device=device), nn.GELU(), nn.Identity(),\n                nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, bias=True, dtype=dtype, device=device), nn.GELU(),\n                nn.Identity(), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device))\n\n            self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size, dtype=dtype, device=device)\n            self.gating = Gating(hidden_size=vit_hidden_size, dtype=dtype, device=device)\n\n    def compile_model(self):\n        torch_version = version.parse(torch.__version__)\n        if torch_version < version.parse('2.5.0'):\n            return\n\n        tp, _ = get_tp_world_rank()\n        if torch_version >= version.parse('2.6.0') and tp > 1:\n            torch._inductor.config.reorder_for_compute_comm_overlap = True\n            if isinstance(self.vision_model, InternVisionModel):\n                self.vision_model.encoder.forward = split_batch(self.vision_model.encoder.forward,\n                                                       
         'inputs_embeds',\n                                                                index=0)\n\n        self.extract_feature = torch.compile(self.extract_feature, mode='max-autotune-no-cudagraphs')\n        self.compile_vit = True\n        self.has_compiled_vit = False\n\n    def _mark_dynamic_once(self, pixel_values, dims):\n        \"\"\"Call torch._dynamo.mark_dynamic to avoid recompile.\"\"\"\n        if not self.compile_vit or self.has_compiled_vit or pixel_values is None:\n            return\n\n        torch._dynamo.mark_dynamic(pixel_values, dims)\n        self.has_compiled_vit = True\n\n    def pixel_shuffle(self, x, scale_factor=0.5):\n        n, w, h, c = x.size()\n        # N, W, H, C --> N, W, H * scale, C // scale\n        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))\n        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale\n        x = x.permute(0, 2, 1, 3).contiguous()\n        # N, H * scale, W, C // scale -->\n        # N, H * scale, W * scale, C // (scale ** 2)\n        x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))\n        x = x.permute(0, 2, 1, 3).contiguous()\n        return x\n\n    def extract_feature(self, pixel_values):\n        \"\"\"Extract vision feature.\"\"\"\n        assert self.select_layer == -1\n        vit_embeds = self.vision_model(pixel_values)\n        if self.is_mono:\n            if int(vit_embeds.shape[1]**0.5)**2 != vit_embeds.shape[1]:\n                vit_embeds = vit_embeds[:, 1:, :]\n        else:\n            vit_embeds = vit_embeds[:, 1:, :]\n\n        h = w = int(vit_embeds.shape[1]**0.5)\n        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)\n        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)\n        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])\n        vit_embeds = self.mlp1(vit_embeds)\n        return vit_embeds\n\n    def 
compress_visual_tokens_in_sentence(\n        self,\n        input_embeds: torch.Tensor,\n        input_ids: torch.Tensor,\n        img_context_token_id: int,\n        gate_result,\n    ) -> tuple:\n        # reshape\n        B, N, C = input_embeds.shape\n        input_embeds = input_embeds.reshape(B * N, C)\n        input_ids = input_ids.reshape(B * N)\n\n        N, C = input_embeds.shape\n        lengths, starts, ends = self.get_image_num_per_sample(input_ids, img_context_token_id)\n\n        keep_mask = torch.ones(N, dtype=torch.bool, device=input_embeds.device)\n\n        # every image token span must consist of whole 256-token tiles\n        block_counts = []\n        for length in lengths.tolist():\n            if length % 256 != 0:\n                raise ValueError(f'Image token span length must be a multiple of 256, got {length}')\n            block_counts.append(length // 256)\n\n        flag_idx = 0\n        for s, num_blocks in zip(starts.tolist(), block_counts):\n            for i in range(num_blocks):\n                block_start = s + i * 256\n                block_end = block_start + 256\n\n                compress = gate_result[flag_idx]\n                flag_idx += 1\n\n                # a compressed tile keeps only its first 64 tokens; drop the other 192 placeholders\n                if compress:\n                    keep_mask[block_start + 64:block_end] = False\n\n        # update\n        new_input_embeds = input_embeds[keep_mask.to(input_embeds.device), :]\n        new_input_ids = input_ids[keep_mask.to(input_ids.device)]\n        new_image_mask = (new_input_ids == img_context_token_id)\n\n        # reshape back\n        new_input_ids = new_input_ids.reshape(B, -1)\n        new_input_embeds = new_input_embeds.reshape(B, -1, C)\n\n        # multiple sequences may be concatenated together, so each seqlen must be updated individually:\n        # count the tokens kept in each sequence to get its new length\n        crt_ctx = self.ctx_mgr.current_context()\n        seq_lengths = 
crt_ctx.q_seqlens\n        # split the keep_mask into chunks corresponding to each original sequence\n        mask_chunks = torch.split(keep_mask, seq_lengths.tolist())\n        # the new length of each sequence is the number of tokens kept (sum of True values)\n        new_seq_lengths = [chunk.sum().item() for chunk in mask_chunks]\n\n        return new_input_embeds, new_input_ids, new_image_mask, new_seq_lengths\n\n    def get_image_num_per_sample(self, input_ids: torch.Tensor, img_context_token_id: int):\n        input_ids = input_ids.squeeze(0)  # (N,)\n        selected = (input_ids == img_context_token_id)\n        padded = torch.cat(\n            [torch.tensor([0], device=selected.device),\n             selected.int(),\n             torch.tensor([0], device=selected.device)])\n        diff = torch.diff(padded)\n\n        starts = (diff == 1).nonzero(as_tuple=True)[0]\n        ends = (diff == -1).nonzero(as_tuple=True)[0]\n        lengths = ends - starts\n\n        return lengths, starts, ends\n\n    def split_and_merge(self, features: torch.Tensor, split_sizes: List[int]):\n        \"\"\"\n        features: Tensor of shape [T, H, W, C], one entry per tile\n        split_sizes: list of ints like [3, 3, 4], the number of tiles of each sample\n\n        returns: List of Tensors of shape [tile_i * H * W, C]\n        \"\"\"\n        # split features -> a list of tiles for each sample\n        tile_splits = torch.split(features, split_sizes, dim=0)\n\n        # flatten all but the channel dim: [tile_i, H, W, C] -> [tile_i * H * W, C]\n        merged = [x.reshape(-1, x.shape[-1]) for x in tile_splits]\n\n        return merged\n\n    def extract_feature_flash(self, pixel_values, lengths):\n\n        vit_embeds_1024 = self.vision_model(pixel_values)\n\n        vit_embeds_1024 = vit_embeds_1024[:, 1:, :]\n        h = w = int(vit_embeds_1024.shape[1]**0.5)\n        vit_embeds_1024 = vit_embeds_1024.reshape(vit_embeds_1024.shape[0], h, w, -1)\n\n        # begin moe\n        lengths = [int(x) for x in lengths.tolist()]\n
        vit_embeds_1024_split_and_merge = self.split_and_merge(vit_embeds_1024, lengths)\n\n        gate = self.pooling_before_gating(vit_embeds_1024_split_and_merge)\n        gate = self.gating(gate)\n\n        vit_embeds_256 = vit_embeds_1024\n\n        with torch.no_grad():\n            vit_embeds_64 = self.pixel_shuffle(vit_embeds_1024, scale_factor=self.downsample_ratio**2)\n            vit_embeds_64 = vit_embeds_64.reshape(vit_embeds_64.shape[0], -1, vit_embeds_64.shape[-1])\n            vit_embeds_64 = self.mlp2(vit_embeds_64)\n\n            vit_embeds_256 = self.pixel_shuffle(vit_embeds_256, scale_factor=self.downsample_ratio)\n            vit_embeds_256 = vit_embeds_256.reshape(vit_embeds_256.shape[0], -1, vit_embeds_256.shape[-1])\n            vit_embeds_256 = self.mlp1(vit_embeds_256)\n\n        return vit_embeds_64, vit_embeds_256, gate\n\n    def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, img_context_token_id: int):\n        lang_embeds = self.language_model.get_input_embeddings()(input_ids)\n\n        self._mark_dynamic_once(pixel_values, [0])\n\n        lengths, starts, ends = self.get_image_num_per_sample(input_ids, img_context_token_id)\n        # number of 256-token tiles in each image span; each tile is then treated as its own group of size 1\n        lengths = lengths // 256\n        lengths = torch.ones(int(lengths.sum().item()), dtype=torch.int64)\n        vit_embeds_64, vit_embeds_256, gate_result = self.extract_feature_flash(pixel_values, lengths)\n\n        # compress a tile only if its gate score passes both the relative (quantile) and absolute thresholds\n        relative_threshold_value = torch.quantile(gate_result[:, 0].to(torch.float32), self.flash_relative_threshold)\n        gate_result = (gate_result[:, 0] > relative_threshold_value) & (gate_result[:, 0]\n                                                                        >= self.flash_absolute_threshold)\n\n        selected_embeds = [\n            vit_embeds_64[i] if gate_result[i] else vit_embeds_256[i] for i in range(gate_result.size(0))\n        ]\n\n        vit_embeds = torch.cat(selected_embeds, 
dim=0)\n\n        # compress visual tokens in sentence\n        new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths = self.compress_visual_tokens_in_sentence(\n            input_embeds=lang_embeds,\n            input_ids=input_ids,\n            img_context_token_id=img_context_token_id,\n            gate_result=gate_result,\n        )\n\n        return vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths\n\n    def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int],\n                              context: StepContext) -> Tuple[torch.Tensor, Any]:\n        \"\"\"Update model metas and context seqlens; return the new position_ids and attention metadata.\"\"\"\n        from lmdeploy.pytorch.model_inputs import ModelInputs\n\n        crt_ctx = self.ctx_mgr.current_context()\n        assert crt_ctx is not None, 'Current context cannot be None.'\n\n        # update model metas\n        prev_lens = [0] * len(new_seqlens)\n        has_model_metas = context.model_metas is not None and context.model_metas[0] is not None\n        context.is_model_meta_updated = has_model_metas\n        if has_model_metas:\n            prev_lens = [meta.get('new_seqlen', 0) for meta in context.model_metas]\n\n            for idx, meta in enumerate(context.model_metas):\n                meta.update({'new_seqlen': prev_lens[idx] + new_seqlens[idx]})\n        else:\n            context.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_seqlens]\n\n        # create new model inputs and context, to get updated position_ids and attn_metadata\n        device = input_ids.device\n        total_msgs = len(new_seqlens)\n        kv_seqlens = torch.tensor([meta['new_seqlen'] for meta in context.model_metas], dtype=torch.long)\n        new_model_inputs = ModelInputs(input_ids=input_ids,\n                                       seq_length=torch.tensor(new_seqlens, device=device, dtype=torch.long),\n                                       history_lengths=torch.tensor(prev_lens, 
device=device, dtype=torch.long),\n                                       block_offsets=crt_ctx.block_offsets,\n                                       is_decoding=False,\n                                       num_ignored_history=torch.zeros(total_msgs, device=device, dtype=torch.long),\n                                       max_q_seqlen=kv_seqlens.max().item(),\n                                       max_kv_seqlen=kv_seqlens.max().item(),\n                                       sum_kv_seqlen=kv_seqlens.sum().item(),\n                                       model_metas=context.model_metas)\n        new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config, crt_ctx.cache_config)\n\n        # update attributes of the context in model agent\n        context.q_seqlens = new_ctx.q_seqlens\n\n        return new_ctx.position_ids, new_ctx.attn_metadata\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        pixel_values: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        inputs_embeds: torch.Tensor = None,\n        vision_embedding_indexing: torch.Tensor = None,\n        text_embedding_indexing: torch.Tensor = None,\n        image_token_id: int = None,\n        context: StepContext = None,\n        **kwargs,\n    ):\n        if inputs_embeds is None and pixel_values is not None:\n            if self.flash_mode:\n                # extract feature and compress visual tokens\n                vit_embeds, lang_embeds, input_ids, image_mask, new_seqlens = self.extract_and_compress(\n                    pixel_values, input_ids, image_token_id)\n\n                # update forward inputs\n                position_ids, attn_metadata = self.update_forward_inputs(input_ids, new_seqlens, context)\n            else:\n                # extract feature\n                self._mark_dynamic_once(pixel_values, 
[0])\n                vit_embeds = self.extract_feature(pixel_values)\n                lang_embeds = self.language_model.get_input_embeddings()(input_ids)\n\n            lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)\n\n            inputs_embeds = lang_embeds\n\n        if self.is_mono:\n            return self.language_model.forward(input_ids=input_ids,\n                                               inputs_embeds=inputs_embeds,\n                                               past_key_values=past_key_values,\n                                               position_ids=position_ids,\n                                               attn_metadata=attn_metadata,\n                                               vision_embedding_indexing=vision_embedding_indexing,\n                                               text_embedding_indexing=text_embedding_indexing)\n        else:\n            return self.language_model.forward(input_ids=input_ids,\n                                               inputs_embeds=inputs_embeds,\n                                               past_key_values=past_key_values,\n                                               position_ids=position_ids,\n                                               attn_metadata=attn_metadata)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.language_model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = None\n\n        # vision inputs\n        pixel_values = None\n        image_mask = None\n        image_token_id = None\n      
  if context.input_multimodals is not None:\n            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            pixel_values = [data for im_data in pixel_values for data in im_data]\n            if len(pixel_values) > 0:\n                image_token_id = pixel_values[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                pixel_values = torch.cat([data.data for data in pixel_values])\n            else:\n                pixel_values = None\n                image_mask = None\n\n        if self.is_mono and pixel_values is not None:\n            vision_embedding_indexing = torch.arange(input_ids.shape[1], device=input_ids.device)\n            vision_embedding_indexing = vision_embedding_indexing[image_mask[0]]\n\n        # get inputs from context\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            vision_embedding_indexing = context.input_embedding_indexing\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        has_model_metas = context.model_metas is not None and context.model_metas[0] is not None\n        context.is_model_meta_updated = has_model_metas\n        if context.is_decoding:\n            if has_model_metas:\n                # NOTE, zhouxinyu, we need to consider the increasing batch in the decoding stage\n                # the current implementation keeps the batch size the same as in the prefill stage\n\n                # model meta from the previous step, therefore +1 for the current decoding step\n                new_kv_seqlens = [(meta['new_seqlen'] + 1) for meta in context.model_metas]\n\n                # update model metas for the next step\n                context.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_kv_seqlens]\n\n
                # update position ids, attn_metadata\n                new_kv_seqlens = torch.tensor(new_kv_seqlens, device=input_ids.device, dtype=torch.long)\n                position_ids = new_kv_seqlens - 1\n                attn_metadata.kv_seqlens = new_kv_seqlens\n                attn_metadata.cu_seqlens_k = torch.nn.functional.pad(\n                    torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32), (1, 0))\n        else:\n            # for long contexts, a message may be split into multiple segments that are prefilled sequentially\n            # 1. this is only handled when flash_mode is on\n            # 2. for a text segment, model metas are updated before the forward pass\n            # 3. for an image segment, model metas are updated later, after the vision forward / compression\n            is_text_segment = (inputs_embeds is None) and (pixel_values is None)\n\n            if self.flash_mode and is_text_segment:\n                crt_ctx = self.ctx_mgr.current_context()\n                seq_lengths = crt_ctx.q_seqlens\n\n                if has_model_metas:\n                    prev_lens = [meta.get('new_seqlen', 0) for meta in context.model_metas]\n\n                    for idx, meta in enumerate(context.model_metas):\n                        meta.update({'new_seqlen': prev_lens[idx] + seq_lengths[idx].item()})\n\n                    # update position ids, attn_metadata\n                    prev_lens = torch.tensor(prev_lens, device=input_ids.device, dtype=torch.long)\n                    ranges = torch.arange(0, input_ids.shape[1], device=input_ids.device)\n                    position_ids = prev_lens[:, None] + ranges[None, :]\n                    attn_metadata.kv_seqlens = prev_lens + seq_lengths\n                else:\n                    # init model metas\n                    context.model_metas = [{'new_seqlen': seqlen} for seqlen in seq_lengths.tolist()]\n\n        if self.is_mono and 
vision_embedding_indexing is not None:\n            all_indices = torch.arange(input_ids.shape[1]).to(input_ids)\n            text_embedding_indexing = all_indices[~torch.isin(all_indices, vision_embedding_indexing)]\n            if vision_embedding_indexing.numel() == 0:\n                vision_embedding_indexing = None\n            if text_embedding_indexing.numel() == 0:\n                text_embedding_indexing = None\n            return dict(input_ids=input_ids,\n                        position_ids=position_ids,\n                        past_key_values=past_key_values,\n                        attn_metadata=attn_metadata,\n                        pixel_values=pixel_values,\n                        image_mask=image_mask,\n                        inputs_embeds=inputs_embeds,\n                        vision_embedding_indexing=vision_embedding_indexing,\n                        text_embedding_indexing=text_embedding_indexing,\n                        image_token_id=image_token_id,\n                        context=context)\n        else:\n            return dict(input_ids=input_ids,\n                        position_ids=position_ids,\n                        past_key_values=past_key_values,\n                        attn_metadata=attn_metadata,\n                        pixel_values=pixel_values,\n                        image_mask=image_mask,\n                        inputs_embeds=inputs_embeds,\n                        image_token_id=image_token_id,\n                        context=context)\n\n    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):\n        \"\"\"Load lora weights.\"\"\"\n\n        if hasattr(self.language_model, 'load_lora_weights'):\n            return self.language_model.load_lora_weights(weights, adapter_id)\n        else:\n            from lmdeploy.pytorch.adapter.adapter import load_lora_weights\n\n            return load_lora_weights(weights, adapter_id)\n\n    def load_weights(self, weights: 
Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        lang_prefix = 'language_model.'\n        lang_prefix_length = len(lang_prefix)\n        new_weights = dict()\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if name.startswith(lang_prefix):\n                new_key = name[lang_prefix_length:]\n                new_weights[new_key] = loaded_weight\n                continue\n\n            if 'qkv' in name:\n                param = params_dict[name]\n                q, k, v = param.weight_spliter(loaded_weight)\n                load_weight(param, q, shard_id='q')\n                load_weight(param, k, shard_id='k')\n                load_weight(param, v, shard_id='v')\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n\n        self.language_model.load_weights(new_weights.items())\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass InternVLInputProcessor(BaseModelInputProcessor):\n    \"\"\"Internvl input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n        vision_config = config.vision_config\n        self.image_size = vision_config.image_size\n        self.patch_size = vision_config.patch_size\n        self.num_patches = (self.image_size // self.patch_size)**2\n        self.num_positions = self.num_patches + 1\n        self.vision_token_num = self.num_patches // 4\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return 
input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     meta=dict(image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/internvl3_hf.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom packaging import version\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.models.utils.micro_batch import enable_micro_batch, split_batch\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import LayerNorm, RMSNorm\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import build_model_from_hf_config\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, vlm_model\n\n\n@torch.compile(dynamic=True)\ndef pre_rms_norm(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:\n    \"\"\"Pre rms norm.\"\"\"\n    q = q.to(torch.float32)\n    k = k.to(torch.float32)\n    variance_q = (q * q).sum(-1, keepdim=True)\n    variance_k = (k * k).sum(-1, keepdim=True)\n    variance = torch.stack([variance_q, variance_k], dim=0)\n    return variance\n\n\n@torch.compile(dynamic=True)\ndef post_rms_norm(q: torch.Tensor, k: torch.Tensor, weight_q: torch.Tensor, weight_k: torch.Tensor,\n                  variance: torch.Tensor, eps: float, embed_dim: int, dtype: torch.dtype):\n    \"\"\"Post rms norm.\"\"\"\n    q = q.to(torch.float32)\n    k = k.to(torch.float32)\n    variance = variance / embed_dim + eps\n    variance_q, variance_k = variance\n    q = q * torch.rsqrt(variance_q)\n    q = q.to(dtype) * weight_q\n    k = k * 
torch.rsqrt(variance_k)\n    k = k.to(dtype) * weight_k\n    return q, k\n\n\nclass InternVLVisionPatchEmbeddings(nn.Module):\n    \"\"\"This class turns `pixel_values` of shape `(batch_size, num_channels,\n    height, width)` into the initial `hidden_states` (patch embeddings) of\n    shape `(batch_size, seq_length, hidden_size)` to be consumed by a\n    Transformer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        image_size, patch_size = config.image_size, config.patch_size\n        num_channels, hidden_size = config.num_channels, config.hidden_size\n\n        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])\n        patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])\n        self.image_size = image_size\n        self.patch_size = patch_size\n        self.num_channels = num_channels\n        self.num_patches = num_patches\n        self.patch_shape = patch_shape\n\n        self.projection = nn.Conv2d(num_channels,\n                                    hidden_size,\n                                    kernel_size=patch_size,\n                                    stride=patch_size,\n                                    dtype=dtype,\n                                    device=device)\n\n    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:\n        batch_size, num_channels, height, width = pixel_values.shape\n        if num_channels != self.num_channels:\n            raise ValueError(\n                'Make sure that the channel dimension of the pixel values matches the one set in the configuration.')\n\n        embeddings = self.projection(pixel_values)\n        embeddings = embeddings.flatten(2).transpose(1, 2)\n\n        return embeddings\n\n\nclass InternVLVisionEmbeddings(nn.Module):\n    \"\"\"Intern vision embedding.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: 
torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.cls_token = nn.Parameter(torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device))\n        if config.use_mask_token:\n            self.mask_token = nn.Parameter(torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device))\n        else:\n            self.mask_token = None\n        self.patch_embeddings = InternVLVisionPatchEmbeddings(config, dtype=dtype, device=device)\n\n        self.num_positions = self.patch_embeddings.num_patches + 1\n\n        if config.use_absolute_position_embeddings:\n            self.position_embeddings = nn.Parameter(\n                torch.empty(1, self.num_positions, self.embed_dim, dtype=dtype, device=device))\n        else:\n            self.position_embeddings = None\n\n    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int):\n        num_patches = embeddings.shape[1] - 1\n        num_positions = self.position_embeddings.shape[1] - 1\n\n        # reuse the stored position embeddings as-is when the patch count matches and the image is square\n        if num_patches == num_positions and height == width:\n            return self.position_embeddings\n\n        target_dtype = embeddings.dtype\n        class_pos_embed = self.position_embeddings[:, :1]\n        patch_pos_embed = self.position_embeddings[:, 1:]\n        dim = embeddings.shape[-1]\n        new_height = height // self.patch_size[0]\n        new_width = width // self.patch_size[1]\n        sqrt_num_positions = int(num_positions**0.5)\n        patch_pos_embed = patch_pos_embed.float().reshape(1, sqrt_num_positions, sqrt_num_positions,\n                                                          -1).permute(0, 3, 1, 2)\n        patch_pos_embed = F.interpolate(patch_pos_embed,\n              
                          size=(new_height, new_width),\n                                        mode='bicubic',\n                                        align_corners=False)\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim).to(target_dtype)\n\n        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        _, _, height, width = pixel_values.shape\n        patch_embeds = self.patch_embeddings(pixel_values)  # shape = [batch_size, num_patches, hidden_size]\n        batch_size = patch_embeds.shape[0]\n        cls_token = self.cls_token.expand(batch_size, 1, -1)\n        embeddings = torch.cat([cls_token, patch_embeds], dim=1)\n        if self.position_embeddings is not None:\n            position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)\n            embeddings = embeddings + position_embeddings\n        return embeddings\n\n\nNORM2FN = {\n    'rms_norm': RMSNorm,\n    'layer_norm': LayerNorm,\n}\n\n\nclass InternVLVisionAttention(nn.Module):\n    \"\"\"InternVL vision attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n\n        self.qkv_proj = build_qkv_proj(\n            self.embed_dim,\n            num_q_heads=self.num_heads,\n            num_kv_heads=self.num_heads,\n            head_size=self.head_dim,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        self.use_qk_norm = config.use_qk_norm\n\n        if self.use_qk_norm:\n            self.q_norm = RMSNorm(\n         
       self.embed_dim,\n                eps=config.layer_norm_eps,\n                dtype=dtype,\n                device=device,\n                tp=True,\n                align=self.head_dim,\n            )\n            self.k_norm = RMSNorm(\n                self.embed_dim,\n                eps=config.layer_norm_eps,\n                dtype=dtype,\n                device=device,\n                tp=True,\n                align=self.head_dim,\n            )\n\n        self.scale = self.head_dim**-0.5\n\n        # o_proj\n        self.projection_layer = build_o_proj(self.embed_dim,\n                                             self.embed_dim,\n                                             bias=True,\n                                             quant_config=quantization_config,\n                                             dtype=dtype,\n                                             device=device,\n                                             is_tp=True,\n                                             tp_align_size=self.head_dim)\n\n    def pre_rms_norm(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:\n        \"\"\"Pre rms norm.\"\"\"\n        return pre_rms_norm(q, k)\n\n    def post_rms_norm(self, q: torch.Tensor, k: torch.Tensor, variance: torch.Tensor, dtype: torch.dtype):\n        \"\"\"Post rms norm.\"\"\"\n        eps = self.config.layer_norm_eps\n        return post_rms_norm(q, k, self.q_norm.weight, self.k_norm.weight, variance, eps, self.embed_dim, dtype)\n\n    def qkv_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        import lmdeploy.pytorch.distributed as dist\n        q_shape = q.shape\n        k_shape = k.shape\n        q = q.flatten(-2, -1)\n        k = k.flatten(-2, -1)\n\n        tp, _ = get_tp_world_rank()\n        if tp == 1:\n            q = self.q_norm(q).view(q_shape)\n            k = self.k_norm(k).view(k_shape)\n            return q, k\n\n        # variance\n        variance = self.pre_rms_norm(q, 
k)\n        dist.all_reduce(variance)\n        q, k = self.post_rms_norm(q, k, variance, q.dtype)\n        q = q.view(q_shape)\n        k = k.view(k_shape)\n\n        return q, k\n\n    def forward(self, hidden_states):\n        \"\"\"forward.\"\"\"\n\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        q, k, v = self.qkv_proj.split_qkv(qkv_states)\n\n        if self.use_qk_norm:\n            q, k = self.qkv_norm(q, k)\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale)\n\n        # o proj\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.flatten(-2, -1)\n        attn_output = self.projection_layer(attn_output)\n        return attn_output\n\n\nclass InternVLVisionMLP(nn.Module):\n    \"\"\"Intern vl mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.act = ACT2FN[config.hidden_act]\n\n        self.fc1 = build_colwise_linear(\n            config.hidden_size,\n            config.intermediate_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n            dp_disable_tp=True,\n        )\n\n        self.fc2 = build_rowwise_linear(\n            config.intermediate_size,\n            config.hidden_size,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            dp_disable_tp=True,\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.act(hidden_states)\n        
hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass InternVLVisionLayer(nn.Module):\n    \"\"\"Intern vision layer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.norm_type = getattr(config, 'norm_type', 'rms_norm')\n\n        self.attention = InternVLVisionAttention(config, dtype=dtype, device=device)\n        self.mlp = InternVLVisionMLP(config, dtype=dtype, device=device)\n        self.layernorm_before = NORM2FN[self.norm_type](self.embed_dim,\n                                                        eps=config.layer_norm_eps,\n                                                        dtype=dtype,\n                                                        device=device)\n        self.layernorm_after = NORM2FN[self.norm_type](self.embed_dim,\n                                                       eps=config.layer_norm_eps,\n                                                       dtype=dtype,\n                                                       device=device)\n\n        self.lambda_1 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))\n        self.lambda_2 = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))\n\n    @enable_micro_batch(param_name='hidden_states', index=0)\n    def _attn(self, hidden_states):\n        hidden_states = hidden_states + self.attention(self.layernorm_before(hidden_states).to(\n            hidden_states[0].dtype)) * self.lambda_1\n        return hidden_states\n\n    @enable_micro_batch(param_name='hidden_states', index=0)\n    def _mlp(self, hidden_states):\n        hidden_states = hidden_states + self.mlp(self.layernorm_after(hidden_states).to(\n            hidden_states.dtype)) * self.lambda_2\n        return hidden_states\n\n    
def forward(\n        self,\n        hidden_states,\n    ):\n        hidden_states = self._attn(hidden_states)\n        hidden_states = self._mlp(hidden_states)\n        return hidden_states\n\n\nclass InternVLVisionEncoder(nn.Module):\n    \"\"\"Intern vision encoder.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.layer = nn.ModuleList(\n            [InternVLVisionLayer(config, dtype=dtype, device=device) for idx in range(config.num_hidden_layers)])\n\n    def forward(\n        self,\n        inputs_embeds,\n    ):\n        \"\"\"forward.\"\"\"\n        hidden_states = inputs_embeds\n        for _, encoder_layer in enumerate(self.layer):\n            layer_outputs = encoder_layer(hidden_states, )\n            hidden_states = layer_outputs\n        return hidden_states\n\n\n@vlm_model\nclass InternVLVisionModel(nn.Module):\n    \"\"\"Intern vision model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n\n        self.embeddings = InternVLVisionEmbeddings(config, dtype=dtype, device=device)\n        self.encoder = InternVLVisionEncoder(config, dtype=dtype, device=device)\n        self.layernorm = None\n        if not config.use_mean_pooling:\n            self.layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=dtype, device=device)\n\n    def get_input_embeddings(self):\n        return self.embeddings.patch_embeddings\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n        assert pixel_values.dim() == 4\n        hidden_states = self.embeddings(pixel_values)\n        hidden_states = self.encoder(inputs_embeds=hidden_states)\n        last_hidden_state = hidden_states\n        if self.layernorm is not 
None:\n            last_hidden_state = self.layernorm(hidden_states)\n\n        return hidden_states, last_hidden_state\n\n\nclass InternVLMultiModalProjector(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        input_dim = config.vision_config.hidden_size * int(1 / config.downsample_ratio)**2\n        self.layer_norm = LayerNorm(input_dim, eps=1e-5, dtype=dtype, device=device)\n\n        quantization_config = getattr(config.text_config, 'quantization_config', None)\n        self.act = ACT2FN[config.projector_hidden_act]\n        self.linear_1 = build_colwise_linear(\n            input_dim,\n            config.text_config.hidden_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n            dp_disable_tp=True,\n        )\n\n        self.linear_2 = build_rowwise_linear(\n            config.text_config.hidden_size,\n            config.text_config.hidden_size,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            dp_disable_tp=True,\n        )\n\n    def forward(self, image_features):\n        hidden_states = self.layer_norm(image_features)\n        hidden_states = self.linear_1(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass InternVLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        self.vision_tower = 
InternVLVisionModel(config.vision_config, dtype=dtype, device=device)\n        self.multi_modal_projector = InternVLMultiModalProjector(config, dtype=dtype, device=device)\n        self.language_model = build_model_from_hf_config(config.text_config, dtype=dtype, device=device)\n        self.lm_head = self.language_model.lm_head\n        self.vision_feature_layer = config.vision_feature_layer\n        self.vision_feature_select_strategy = config.vision_feature_select_strategy\n\n        self.input_processor = InternVLProcessor(self.config, dtype)\n\n        self.compile_vit = False\n\n    def compile_model(self):\n        torch_version = version.parse(torch.__version__)\n        if torch_version < version.parse('2.5.0'):\n            return\n\n        tp, _ = get_tp_world_rank()\n        if torch_version >= version.parse('2.6.0') and tp > 1:\n            torch._inductor.config.reorder_for_compute_comm_overlap = True\n            if isinstance(self.vision_tower, InternVLVisionModel):\n                self.vision_tower.encoder.forward = split_batch(self.vision_tower.encoder.forward,\n                                                                'inputs_embeds',\n                                                                index=0)\n\n        self.get_image_features = torch.compile(self.get_image_features, mode='max-autotune-no-cudagraphs')\n        self.compile_vit = True\n        self.has_compiled_vit = False\n\n    def _mark_dynamic_once(self, pixel_values, dims):\n        \"\"\"Call torch._dynamo.mark_dynamic to avoid recompile.\"\"\"\n        if not self.compile_vit or self.has_compiled_vit or pixel_values is None:\n            return\n\n        torch._dynamo.mark_dynamic(pixel_values, dims)\n        self.has_compiled_vit = True\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.language_model.get_input_embeddings()\n\n    def get_image_features(\n        self,\n        pixel_values: torch.FloatTensor,\n     
   vision_feature_layer: Union[int, List[int]],\n        vision_feature_select_strategy: str,\n        **kwargs,\n    ):\n        \"\"\"Obtain image hidden states from the vision tower and apply the\n        multimodal projection.\n\n        Args:\n            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):\n                The tensors corresponding to the input images.\n            vision_feature_layer (`int` or `List[int]`):\n                Layer index or list of layer indices to extract features from.\n            vision_feature_select_strategy (`str`):\n                Feature selection strategy; `'default'` drops the leading CLS token.\n        Returns:\n            vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.\n        \"\"\"\n        downsample_ratio = self.config.downsample_ratio\n        hidden_states, last_hidden_state = self.vision_tower(pixel_values=pixel_values)\n        if vision_feature_layer == -1:\n            vision_features = last_hidden_state\n        else:\n            vision_features = hidden_states[vision_feature_layer]\n        if vision_feature_select_strategy == 'default':\n            vision_features = vision_features[:, 1:, :]\n\n        # Calculate dimensions based on vision features (dim 1 is the number of patch tokens)\n        num_patches = vision_features.shape[1]\n        feature_size = int(num_patches**0.5)\n        batch_size = vision_features.shape[0]\n\n        # Reshape tensor to spatial dimensions\n        vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)\n\n        # Apply downsampling using pixel shuffle\n        vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)\n\n        # Reshape tensor to prepare for projection\n        vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])\n\n        # Project features through multi-modal projector\n        vision_features = self.multi_modal_projector(vision_features)\n\n        return vision_features\n\n    def pixel_shuffle(self, vision_features: torch.Tensor, 
scale_factor: float = 0.5):\n        \"\"\"Perform pixel shuffle downsampling on vision features.\n\n        Args:\n            vision_features (`torch.Tensor`):\n                Input tensor of shape (batch_size, width, height, channels).\n            scale_factor (`float`, *optional*, defaults to `0.5`):\n                Factor by which to downsample. Default is 0.5, which halves the dimensions.\n\n        Returns:\n            vision_features (`torch.Tensor`):\n                Downsampled tensor of shape\n                (batch_size, width*scale_factor, height*scale_factor, channels/(scale_factor^2)).\n        \"\"\"\n        batch_size, width, height, channels = vision_features.size()\n\n        # height * scale_factor and width * scale_factor must be whole numbers for the views below\n        if (height * scale_factor) % 1 != 0 or (width * scale_factor) % 1 != 0:\n            raise ValueError('Height and width must be divisible by 1 / scale_factor for proper downsampling.')\n\n        # Reshape to allow downsampling\n        vision_features = vision_features.view(batch_size, width, int(height * scale_factor),\n                                               int(channels / scale_factor))\n        # Permute dimensions to align downsampled axis correctly\n        vision_features = vision_features.permute(0, 2, 1, 3).contiguous()\n\n        # Reshape to achieve final downsampled dimensions\n        vision_features = vision_features.view(batch_size, int(height * scale_factor), int(width * scale_factor),\n                                               int(channels / (scale_factor**2)))\n\n        # Swap height and width back for proper orientation\n        vision_features = vision_features.permute(0, 2, 1, 3).contiguous()\n\n        return vision_features\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        pixel_values: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        
inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        if inputs_embeds is None and pixel_values is not None:\n            # extract feature\n            self._mark_dynamic_once(pixel_values, [0])\n            vit_embeds = self.get_image_features(\n                pixel_values,\n                self.vision_feature_layer,\n                self.vision_feature_select_strategy,\n            )\n            lang_embeds = self.get_input_embeddings()(input_ids)\n            lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)\n\n            inputs_embeds = lang_embeds\n            input_ids = None\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError('You must specify exactly one of input_ids or inputs_embeds')\n\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n\n        outputs = self.language_model.forward(input_ids=input_ids,\n                                              inputs_embeds=inputs_embeds,\n                                              past_key_values=past_key_values,\n                                              position_ids=position_ids,\n                                              attn_metadata=attn_metadata)\n        return outputs\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = None\n\n        # vision inputs\n        pixel_values = None\n        image_mask = None\n        if context.input_multimodals is not None:\n            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # 
flatten batch\n            pixel_values = [data for im_data in pixel_values for data in im_data]\n            if len(pixel_values) > 0:\n                image_token_id = pixel_values[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                pixel_values = torch.cat([data.data for data in pixel_values])\n            else:\n                pixel_values = None\n                image_mask = None\n\n        # get inputs from context\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            vision_embedding_indexing = context.input_embedding_indexing\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            pixel_values=pixel_values,\n            image_mask=image_mask,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):\n        \"\"\"Load LoRA weights.\"\"\"\n\n        # this class has no `self.model`; the text model lives in `self.language_model`\n        if hasattr(self.language_model, 'load_lora_weights'):\n            return self.language_model.load_lora_weights(weights, adapter_id)\n        else:\n            from lmdeploy.pytorch.adapter.adapter import load_lora_weights\n\n            return load_lora_weights(weights, adapter_id)\n\n    @classmethod\n    def rename_weight(cls, name: str) -> str:\n        \"\"\"Rename weight.\"\"\"\n        if name == 'lm_head.weight':\n            return 'language_model.lm_head.weight'\n        elif name.startswith('model.language_model.'):\n            return 'language_model.model.' 
+ name[len('model.language_model.'):]\n        elif name.startswith('model.'):\n            return name[len('model.'):]\n        return name\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        lang_prefix = 'language_model.'\n        lang_prefix_length = len(lang_prefix)\n        new_weights = dict()\n        params_dict = dict(self.named_parameters())\n        vision_stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n        ]\n        for name, loaded_weight in weights:\n\n            if name.startswith(lang_prefix):\n                new_key = name[lang_prefix_length:]\n                new_weights[new_key] = loaded_weight\n                continue\n\n            for (param_name, weight_name, shard_id) in vision_stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n\n        self.language_model.load_weights(new_weights.items())\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass InternVLProcessor(BaseModelInputProcessor):\n    \"\"\"Internvl input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> 
PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     meta=dict(image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/internvl_patch.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\n\nclass InternVisionEmbeddings(nn.Module):\n    \"\"\"Mono vision.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.class_embedding = nn.Parameter(torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), )\n\n        self.patch_embedding = nn.Conv2d(in_channels=3,\n                                         out_channels=self.embed_dim,\n                                         kernel_size=self.patch_size,\n                                         stride=self.patch_size,\n                                         dtype=dtype,\n                                         device=device)\n\n        self.num_patches = (self.image_size // self.patch_size)**2\n        self.num_positions = self.num_patches + 1\n\n        self.position_embedding = nn.Parameter(\n            torch.empty(1, self.num_positions, self.embed_dim, dtype=dtype, device=device))\n\n    def _get_pos_embed(self, pos_embed, H, W):\n        target_dtype = pos_embed.dtype\n        pos_embed = pos_embed.float().reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size,\n                                              -1).permute(0, 3, 1, 2)\n        pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False)\n        pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)\n        return pos_embed\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        target_dtype = self.patch_embedding.weight.dtype\n        
patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, height, width]\n        batch_size, _, height, width = patch_embeds.shape\n        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)\n        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)\n        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)\n        position_embedding = torch.cat(\n            [self.position_embedding[:, :1, :],\n             self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)],\n            dim=1)\n        embeddings = embeddings + position_embedding.to(target_dtype)\n        return embeddings\n\n\nclass InternVisionPatchModel(nn.Module):\n    \"\"\"Mono vision.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embeddings = InternVisionEmbeddings(config, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        if len(pixel_values.shape) != 4:\n            raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')\n\n        hidden_states = self.embeddings(pixel_values)[:, 1:]\n        return hidden_states\n"
  },
  {
    "path": "lmdeploy/pytorch/models/llama.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.models.llama import LlamaConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass LlamaAttention(nn.Module):\n    \"\"\"Rewrite module of LlamaAttention.\"\"\"\n\n    def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None, is_tp: bool = True):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            num_replicate_kv_heads=num_replicate_kv_heads,\n            is_tp=is_tp,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            
num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=config.attention_bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=is_tp)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass LlamaMLP(nn.Module):\n    \"\"\"Llama mlp.\"\"\"\n\n    def __init__(self, config: 
LlamaConfig, dtype: torch.dtype = None, device: torch.device = None, is_tp: bool = True):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        mlp_bias = getattr(config, 'mlp_bias', False)\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=mlp_bias,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=is_tp,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=mlp_bias,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=is_tp)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass LlamaDecoderLayer(nn.Module):\n    \"\"\"Llama decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: LlamaConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = LlamaAttention(config, dtype=dtype, device=device, is_tp=is_tp)\n\n        # build MLP\n        self.mlp = LlamaMLP(config, dtype=dtype, device=device, is_tp=is_tp)\n\n        # build input layer 
norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass LlamaModel(nn.Module):\n    \"\"\"Llama model.\"\"\"\n\n    def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = 
config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            LlamaDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n        self.aux_hidden_state_layers: Tuple[int] = getattr(config, 'aux_hidden_state_layers', tuple())\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding in LlamaModel\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # for eagle3\n        aux_hidden_states = []\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            if idx in self.aux_hidden_state_layers:\n                aux_hidden_states.append(hidden_states + residual)\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n  
              rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        if len(aux_hidden_states) > 0:\n            aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)\n            return dict(hidden_states=hidden_states, aux_hidden_states=aux_hidden_states)\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass LlamaForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Rewrote model of LlamaForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: LlamaConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        self.dtype = dtype\n        # build LLamaModel\n        self.model = LlamaModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model 
forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        if self.config.tie_word_embeddings:\n            self.lm_head.weight = self.model.embed_tokens.weight\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        hidden_states = hidden_states.to(dtype=self.dtype)\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: torch.Tensor, **kwargs):\n        \"\"\"Get outputs from buffers.\"\"\"\n        num_tokens = input_ids.size(-1)\n        outputs = dict()\n        outputs['hidden_states'] = output_buffers['hidden_states'][:, :num_tokens]\n        if 'aux_hidden_states' in output_buffers:\n            outputs['aux_hidden_states'] = output_buffers['aux_hidden_states'][:, :num_tokens]\n        return outputs\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and 
len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/llama4.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.models.llama4 import Llama4Config, Llama4TextConfig, Llama4VisionConfig\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.nn.moe import build_fused_moe\nfrom lmdeploy.pytorch.nn.rotary_embedding import get_rope_theta\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass Llama4TextAttention(nn.Module):\n    \"\"\"attention.\"\"\"\n\n    def __init__(self,\n                 config: Llama4TextConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.scaling = self.head_dim**-0.5\n        self.attn_scale = config.attn_scale\n        self.floor_scale = config.floor_scale\n        self.attn_temperature_tuning = config.attn_temperature_tuning\n        self.is_causal = True\n    
    self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers\n        self.attn_bias = config.attention_bias\n\n        # qkv\n        self.qkv_proj = build_qkv_proj(\n            config.hidden_size,\n            num_q_heads=self.num_attention_heads,\n            num_kv_heads=self.num_key_value_heads,\n            head_size=self.head_dim,\n            bias=self.attn_bias,\n            dtype=dtype,\n            device=device,\n        )\n\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        self.attn_fwd = Attention(\n            self.num_attention_heads,\n            self.head_dim,\n            num_kv_heads=self.num_key_value_heads,\n            v_head_size=self.head_dim,\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(config.num_attention_heads * self.head_dim,\n                                           config.hidden_size,\n                                           bias=self.attn_bias,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n        if self.config.use_qk_norm and self.use_rope:\n            self.qk_norm = RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"forward.\"\"\"\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        if self.use_rope:\n            cos, sin = rotary_pos_emb\n            # TODO: fuse apply rotary pos emb\n            query_states = query_states.unflatten(-1, (-1, 2)).transpose(-1, -2).flatten(-2)\n            
key_states = key_states.unflatten(-1, (-1, 2)).transpose(-1, -2).flatten(-2)\n            query_states, key_states = self.apply_rotary_pos_emb(\n                query_states,\n                key_states,\n                cos,\n                sin,\n            )\n            query_states = query_states.unflatten(-1, (2, -1)).transpose(-1, -2).flatten(-2)\n            key_states = key_states.unflatten(-1, (2, -1)).transpose(-1, -2).flatten(-2)\n\n        if hasattr(self, 'qk_norm'):\n            query_states = self.qk_norm(query_states)\n            key_states = self.qk_norm(key_states)\n\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass Llama4TextMLP(nn.Module):\n    \"\"\"Llama4 text mlp.\"\"\"\n\n    def __init__(self,\n                 config: Llama4TextConfig,\n                 intermediate_size: int = None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True,\n                 all_reduce: bool = True):\n        super().__init__()\n\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n\n        self.config = config\n\n        mlp_bias = False\n        mlp_args = dict(\n            bias=mlp_bias,\n            dtype=dtype,\n            device=device,\n            is_tp=is_tp,\n        )\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [intermediate_size, 
intermediate_size],\n            **mlp_args,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(\n            intermediate_size,\n            config.hidden_size,\n            all_reduce=all_reduce,\n            **mlp_args,\n        )\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Llama4TextMoe(nn.Module):\n    \"\"\"Llama4 text moe.\"\"\"\n\n    def __init__(self, config: Llama4TextConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.intermediate_size\n        self.num_experts = config.num_local_experts\n\n        self.router = build_rowwise_linear(\n            self.hidden_dim,\n            self.num_experts,\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n            quant_config=None,\n        )\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=1,\n            renormalize=False,\n            dtype=dtype,\n            device=device,\n            all_reduce=False,\n            quant_config=quantization_config,\n        )\n        self.shared_expert = Llama4TextMLP(config, dtype=dtype, device=device, is_tp=True, all_reduce=False)\n\n        dist_config = dist.get_dist_manager().current_config()\n        self.tp = dist_config.tp\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch, seq_len, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, 
hidden_dim)\n        router_logits = self.router(hidden_states)\n\n        topk_weights, topk_ids = torch.topk(router_logits, self.top_k, dim=-1)\n        input_weight = topk_weights.float().sigmoid().to(hidden_states.dtype)\n\n        moe_hidden_states = hidden_states[:, None, :] * input_weight[:, :, None]\n        moe_hidden_states = moe_hidden_states.view(-1, hidden_dim)\n        topk_weights = torch.ones_like(input_weight).reshape(-1, 1)\n        topk_ids = topk_ids.reshape(-1, 1)\n\n        out_states = self.experts(\n            moe_hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        out_states = out_states.reshape(-1, self.top_k, hidden_dim)\n        out_states = out_states.sum(1)\n\n        shared_states = self.shared_expert(hidden_states)\n        out_states += shared_states\n        out_states = out_states.reshape(batch, seq_len, -1)\n\n        if self.tp > 1:\n            dist.all_reduce(out_states)\n\n        return out_states\n\n\nclass Llama4TextDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: Llama4TextConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = Llama4TextAttention(config, layer_idx, dtype=dtype, device=device)\n        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope\n        self.is_moe_layer = layer_idx in config.moe_layers\n        if self.is_moe_layer:  # the 128E model interleaves dense / sparse\n            self.feed_forward = Llama4TextMoe(config, dtype=dtype, device=device)\n        else:\n            self.feed_forward = Llama4TextMLP(config,\n                                              intermediate_size=config.intermediate_size_mlp,\n                                  
            dtype=dtype,\n                                              device=device)\n\n        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"forward.\"\"\"\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.feed_forward(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Llama4TextModel(nn.Module):\n    \"\"\"Llama4 text model.\"\"\"\n\n    def __init__(self, config: Llama4TextConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n        
self.layers = nn.ModuleList([\n            Llama4TextDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n\n        self.rotary_emb = self.build_llama4_rotary_embedding(config)\n\n    @staticmethod\n    def build_llama4_rotary_embedding(config: Llama4TextConfig):\n        \"\"\"Build llama4 rotary embedding.\"\"\"\n        return build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward.\"\"\"\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n        return hidden_states\n\n\nclass Llama4ForCausalLM(nn.Module):\n\n    def __init__(self,\n                 config: Llama4TextConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        self.model = Llama4TextModel(config, dtype=dtype, device=device)\n        
self.vocab_size = config.vocab_size\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            device=device,\n                                            dtype=dtype)\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward.\"\"\"\n        outputs = self.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            **kwargs,\n        )\n\n        return outputs\n\n    def get_input_embeddings(self):\n        \"\"\"Input embeddings.\"\"\"\n        return self.model.embed_tokens\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n\nclass Llama4MultiModalProjector(nn.Module):\n\n    def __init__(self, config: Llama4Config, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.linear_1 = nn.Linear(\n            config.vision_config.vision_output_dim,\n            config.text_config.hidden_size,\n            bias=False,\n            dtype=dtype,\n            device=device,\n        )\n\n    def forward(self, image_features):\n        \"\"\"forward.\"\"\"\n        hidden_states = self.linear_1(image_features)\n        return hidden_states\n\n\nclass Llama4UnfoldConvolution(nn.Module):\n    \"\"\"Llama4 unfold conv.\"\"\"\n\n    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        kernel_size = config.patch_size\n        if 
isinstance(kernel_size, int):\n            kernel_size = (kernel_size, kernel_size)\n        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)\n        self.linear = nn.Linear(\n            config.num_channels * kernel_size[0] * kernel_size[1],\n            config.hidden_size,\n            bias=False,\n            dtype=dtype,\n            device=device,\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.unfold(hidden_states)\n        hidden_states = hidden_states.permute(0, 2, 1)\n        hidden_states = self.linear(hidden_states)\n        return hidden_states\n\n\nclass Llama4VisionRotaryEmbedding(nn.Module):\n\n    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        idx = config.image_size // config.patch_size\n        img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)\n        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)\n        img_idx[-1, -1] = -2  # ID_CLS_TOKEN\n        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x\n        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y\n        freq_dim = config.hidden_size // config.num_attention_heads // 2\n        rope_freq = 1.0 / (get_rope_theta(config)**(torch.arange(0, freq_dim, 2)[:(freq_dim // 2)].float() / freq_dim))\n        freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)\n        freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)\n        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]\n        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)\n        freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))\n        self.freqs_ci = 
freq_cis.to(device)  # idx**2, idx**2, idx * 2\n\n    def forward(self, hidden_states):\n        return self.freqs_ci\n\n\ndef reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):\n    ndim = query.ndim\n    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)]\n    return freqs_ci.view(*shape)\n\n\ndef vision_apply_rotary_emb(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    freqs_ci: torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))\n    key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))\n    freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_)  # freqs_ci[:,:,None,:]\n    freqs_ci = freqs_ci.to(query_.device)\n    query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)\n    key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)\n    return query_out.type_as(query), key_out.type_as(key)  # but this drops to 8e-3\n\n\nclass Llama4VisionAttention(nn.Module):\n    \"\"\"Vision attn.\"\"\"\n\n    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = config.hidden_size // config.num_attention_heads\n\n        # qkv\n        self.qkv_proj = build_qkv_proj(\n            self.embed_dim,\n            num_q_heads=self.num_heads,\n            num_kv_heads=self.num_heads,\n            head_size=self.head_dim,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(self.num_heads * self.head_dim,\n                                           self.embed_dim,\n                                           bias=True,\n                                           
dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        freqs_ci: torch.Tensor,\n    ):\n        \"\"\"forward.\"\"\"\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)\n        query_states = query_states.reshape(hidden_shape)\n        key_states = key_states.reshape(hidden_shape)\n        value_states = value_states.reshape(hidden_shape)\n\n        query_states, key_states = vision_apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci)\n\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n        attention_interface = ALL_ATTENTION_FUNCTIONS['sdpa']\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            None,\n            dropout=0.0,\n            scaling=None,\n            is_causal=False,  # HAS TO BE ENFORCED\n            output_attentions=False,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Llama4VisionMLP(nn.Module):\n    \"\"\"Vision mlp.\"\"\"\n\n    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.activation_fn = nn.GELU()\n        self.fc1 = build_colwise_linear(config.hidden_size,\n                                        
config.intermediate_size,\n                                        bias=True,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n        self.fc2 = build_rowwise_linear(config.intermediate_size,\n                                        config.hidden_size,\n                                        bias=True,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Llama4VisionEncoderLayer(nn.Module):\n    \"\"\"Vision encoder layer.\"\"\"\n\n    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n\n        self.self_attn = Llama4VisionAttention(config, dtype=dtype, device=device)\n        self.mlp = Llama4VisionMLP(config, dtype=dtype, device=device)\n\n        self.input_layernorm = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device)\n        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        freqs_ci: torch.Tensor,\n    ):\n        \"\"\"forward.\"\"\"\n        # Self Attention\n        residual = hidden_state\n\n        hidden_state = self.input_layernorm(hidden_state)\n\n        hidden_state = self.self_attn(\n            hidden_state,\n            freqs_ci=freqs_ci,\n        )\n        hidden_state = residual + hidden_state\n\n        # Feed forward\n        residual = hidden_state\n        hidden_state = 
self.post_attention_layernorm(hidden_state)\n        hidden_state = self.mlp(hidden_state)\n        hidden_state = residual + hidden_state\n\n        return hidden_state\n\n\nclass Llama4VisionEncoder(nn.Module):\n    \"\"\"Vision encoder.\"\"\"\n\n    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [Llama4VisionEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.num_hidden_layers)])\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        freqs_ci: torch.Tensor,\n    ):\n        \"\"\"forward.\"\"\"\n        for encoder_layer in self.layers:\n            hidden_states = encoder_layer(\n                hidden_state=hidden_states,\n                freqs_ci=freqs_ci,\n            )\n        return hidden_states\n\n\ndef pixel_shuffle(input_tensor: torch.Tensor, shuffle_ratio: int):\n    # input_tensor: [batch_size, num_patches, channels]\n    import math\n    batch_size, num_patches, channels = input_tensor.shape\n    patch_size = int(math.sqrt(num_patches))\n\n    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)\n    batch_size, height, width, channels = input_tensor.size()\n\n    reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))\n    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()\n\n    reshaped_tensor = reshaped_tensor.view(batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio),\n                                           int(channels / (shuffle_ratio**2)))\n    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()\n\n    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])\n    return output_tensor\n\n\nclass Llama4VisionMLP2(torch.nn.Module):\n\n    def __init__(self, config: Llama4VisionConfig, 
dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.fc1 = build_colwise_linear(self.intermediate_size,\n                                        config.projector_input_dim,\n                                        bias=False,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n        self.fc2 = build_rowwise_linear(config.projector_output_dim,\n                                        config.projector_output_dim,\n                                        bias=False,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n        self.activation_fn = nn.GELU()  # ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_states):\n        \"\"\"forward.\"\"\"\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        return self.activation_fn(self.fc2(hidden_states))\n\n\nclass Llama4VisionPixelShuffleMLP(nn.Module):\n\n    def __init__(self, config: Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio\n        self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2))\n        self.output_dim = config.projector_output_dim\n        self.mlp = Llama4VisionMLP2(config, dtype=dtype, device=device)\n\n    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:\n        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)\n        return self.mlp(encoded_patches)\n\n\nclass Llama4VisionModel(nn.Module):\n    \"\"\"Llama4 vision model.\"\"\"\n\n    def __init__(self, config: 
Llama4VisionConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n        self.hidden_size = config.hidden_size\n        self.num_channels = config.num_channels\n\n        self.num_patches = (self.image_size // self.patch_size)**2 + 1\n        self.scale = config.hidden_size**-0.5\n\n        self.patch_embedding = Llama4UnfoldConvolution(config, dtype=dtype, device=device)\n\n        self.class_embedding = nn.Parameter(self.scale * torch.empty(self.hidden_size, dtype=dtype, device=device))\n        self.positional_embedding_vlm = nn.Parameter(\n            self.scale * torch.empty(self.num_patches, self.hidden_size, dtype=dtype, device=device))\n        self.rotary_embedding = Llama4VisionRotaryEmbedding(config, dtype=dtype, device=device)\n\n        # layer norms\n        self.layernorm_pre = nn.LayerNorm(self.hidden_size, dtype=dtype, device=device)\n        self.layernorm_post = nn.LayerNorm(self.hidden_size, dtype=dtype, device=device)\n\n        # encoders\n        self.model = Llama4VisionEncoder(config, dtype=dtype, device=device)\n        self.vision_adapter = Llama4VisionPixelShuffleMLP(config, dtype=dtype, device=device)\n\n    def get_input_embeddings(self):\n        \"\"\"This function is used to fetch the first embedding layer to activate\n        grads on inputs.\"\"\"\n        return self.patch_embedding\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n    ):\n        \"\"\"forward.\"\"\"\n        batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape\n        num_concurrent_media = 1\n        num_chunks = 1\n        hidden_state = self.patch_embedding(pixel_values)\n        _, num_patches, hidden_dim = hidden_state.shape\n\n        # Add cls token\n        hidden_state = hidden_state.reshape(batch_size_times_num_tiles * num_concurrent_media * 
num_chunks, num_patches,\n                                            hidden_dim)\n        class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1])\n        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)\n        num_patches += 1\n\n        # Position embeddings\n        hidden_state = hidden_state.reshape(batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches,\n                                            hidden_dim)\n        positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device)\n        hidden_state = hidden_state + positional_embedding\n\n        hidden_state = self.layernorm_pre(hidden_state)\n\n        hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)\n        freqs_ci = self.rotary_embedding(pixel_values)\n\n        output = self.model(\n            hidden_state,\n            freqs_ci=freqs_ci,\n        )\n\n        hidden_state = output\n\n        hidden_state = self.layernorm_post(hidden_state)\n\n        hidden_state = hidden_state[:, :-1, :]\n\n        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings\n        hidden_state = self.vision_adapter(hidden_state)\n\n        return hidden_state\n\n\nclass Llama4ForConditionalGeneration(nn.Module, CudaGraphMixin):\n\n    def __init__(self,\n                 config: Llama4Config,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        self.vision_model = Llama4VisionModel(config.vision_config, dtype=dtype, device=device)\n\n        self.multi_modal_projector = Llama4MultiModalProjector(config, dtype=dtype, device=device)\n\n        self._update_quant_config(config)\n        self.language_model = Llama4ForCausalLM(config.text_config, ctx_mgr, dtype=dtype, 
device=device)\n        self.vocab_size = config.text_config.vocab_size\n\n        self.input_processor = Llama4InputProcessor(config, dtype)\n\n    @staticmethod\n    def _update_quant_config(config: Llama4Config):\n        \"\"\"Update quant config.\"\"\"\n        quant_config = getattr(config, 'quantization_config', None)\n\n        if quant_config is None:\n            return config\n\n        quantization_config = dict(\n            quant_dtype='float8_e4m3fn',\n            quant_method='smooth_quant',\n        )\n        text_config = config.text_config\n        setattr(text_config, 'quantization_config', quantization_config)\n\n        return config\n\n    def get_image_features(\n        self,\n        pixel_values: torch.FloatTensor,\n        **kwargs,\n    ):\n        \"\"\"Get image features.\"\"\"\n        kwargs = {k: v for k, v in kwargs.items() if v is not None}\n        hidden_state = self.vision_model(pixel_values, **kwargs)\n        return hidden_state\n\n    def get_input_embeddings(self):\n        \"\"\"Input embeddings.\"\"\"\n        return self.language_model.get_input_embeddings()\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        pixel_values: torch.FloatTensor = None,\n        image_mask: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward.\"\"\"\n        image_embeds = None\n        if pixel_values is not None:\n            image_features = self.get_image_features(pixel_values=pixel_values, )\n            vision_flat = image_features.view(-1, image_features.size(-1))\n            image_embeds = self.multi_modal_projector(vision_flat)\n\n        lang_embeds: torch.Tensor = self.get_input_embeddings()(input_ids)\n\n        if image_embeds is not None:\n            lang_embeds.masked_scatter_(image_mask[..., None], image_embeds)\n\n        inputs_embeds = 
lang_embeds\n\n        return self.language_model(\n            inputs_embeds,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n        )\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.language_model.get_logits(hidden_states)\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # vision inputs\n        pixel_values = None\n        image_mask = None\n        if context.input_multimodals is not None:\n            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            pixel_values = [data for im_data in pixel_values for data in im_data]\n            if len(pixel_values) > 0:\n                image_token_id = pixel_values[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                pixel_values = torch.cat([data.data for data in pixel_values])\n            else:\n                pixel_values = None\n                image_mask = None\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            pixel_values=pixel_values,\n            image_mask=image_mask,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        def 
_load_experts_bf16(name, loaded_weight):\n            if '.gate_up_proj' in name:\n                loaded_weight = loaded_weight.to(device)\n                name = name.replace('.gate_up_proj', '.gate_up.weight')\n                param = params_dict[name]\n                for exp_id in range(num_experts):\n                    weight_gate, weight_up = loaded_weight[exp_id].chunk(2, -1)\n                    load_weight(param, weight_gate.t(), expert_id=exp_id, shard_id='gate')\n                    load_weight(param, weight_up.t(), expert_id=exp_id, shard_id='up')\n            elif '.down_proj' in name:\n                loaded_weight = loaded_weight.to(device)\n                name = name.replace('.down_proj', '.down.weight')\n                param = params_dict[name]\n                for exp_id in range(num_experts):\n                    weight = loaded_weight[exp_id].t()\n                    load_weight(param, weight, expert_id=exp_id, shard_id='down')\n\n        def _load_experts_fp8(name, loaded_weight):\n            name = name.replace('.weight_scale', '.scale')\n            for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n\n        def _load_experts(name, loaded_weight):\n            \"\"\"Load experts weight.\"\"\"\n            quantization_config = getattr(self.config, 'quantization_config', None)\n            if quantization_config is None:\n                _load_experts_bf16(name, loaded_weight)\n            else:\n                _load_experts_fp8(name, loaded_weight)\n\n        # modify from vllm\n        stacked_params_mapping 
= [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        num_experts = self.config.text_config.num_local_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = dict(self.named_parameters())\n        device = next(iter(params_dict.values())).device\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            if '.experts' in name:\n                _load_experts(name, loaded_weight)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return 
self.input_processor\n\n\nclass Llama4InputProcessor(BaseModelInputProcessor):\n    \"\"\"Llama4 input processor.\"\"\"\n\n    def __init__(self, config: Llama4Config, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n        self.vision_config = config.vision_config\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     meta=dict(image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/llama_eagle.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext\nfrom lmdeploy.pytorch.nn import build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .llama import LlamaDecoderLayer\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\n\n\nclass EagleLlamaDecoderLayer(LlamaDecoderLayer):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None) -> None:\n        super().__init__(config, layer_idx, dtype=dtype, device=device, is_tp=False)\n\n        # Skip the input_layernorm\n        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427\n        if layer_idx == 0:\n            del self.input_layernorm\n            setattr(self, 'input_layernorm', lambda x: x)\n\n\nclass EagleLlamaModel(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            EagleLlamaDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n        # build fc\n        self.fc = nn.Linear(\n            
config.hidden_size * 2,\n            config.hidden_size,\n            bias=False,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build rotary embedding in LlamaModel\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        previous_hidden_states: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n        # token embedding\n        if inputs_embeds is None:\n            assert input_ids is not None\n            inputs_embeds = self.embed_tokens(input_ids)\n        previous_hidden_states = previous_hidden_states.to(inputs_embeds)\n        hidden_states = torch.cat([inputs_embeds, previous_hidden_states], dim=-1)\n        hidden_states = self.fc(hidden_states)\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n        hidden_states = hidden_states + residual\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass EagleLlamaForCausalLM(nn.Module, CudaGraphMixin):\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            
'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self, config, ctx_mgr, dtype=None, device=None):\n        nn.Module.__init__(self)\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        self.dtype = dtype\n        # build LLamaModel\n        self.model = EagleLlamaModel(config, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        target_hidden_states: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            previous_hidden_states=target_hidden_states,\n        )\n        return hidden_states\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n        target_hidden_states = context.target_hidden_states\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            target_hidden_states=target_hidden_states,\n        )\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make 
cudagraph buffers from forward inputs.\"\"\"\n        max_tokens = graph_meta.max_tokens\n\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        input_buffers['target_hidden_states'] = input_buffers['input_ids'].new_zeros(1,\n                                                                                     max_tokens,\n                                                                                     self.config.hidden_size,\n                                                                                     dtype=self.dtype)\n\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n\n        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n\n        num_tokens = kwargs['input_ids'].size(-1)\n\n        is_decoding = graph_meta.is_decoding\n        input_buffers = graph_meta.input_buffers\n        padded_num_tokens = new_inputs['input_ids'].size(-1)\n\n        target_hidden_states = kwargs.get('target_hidden_states')\n        assert target_hidden_states is not None\n        input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states\n        if is_decoding:\n            new_inputs['target_hidden_states'] = input_buffers['target_hidden_states'][:, :padded_num_tokens, :]\n        else:\n            new_inputs['target_hidden_states'] = input_buffers['target_hidden_states']\n\n        return new_inputs\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        if self.config.tie_word_embeddings:\n            self.lm_head.weight = self.model.embed_tokens.weight\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        stacked_params_mapping = [\n      
      # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            name = 'model.' + name\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/llama_eagle3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext\nfrom lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .llama import LlamaDecoderLayer\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\n\n\nclass Eagle3LlamaDecoderLayer(LlamaDecoderLayer):\n    \"\"\"Llama decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__(config, layer_idx, dtype=dtype, device=device, is_tp=False)\n        self.layer_idx = layer_idx\n\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n\n        # override attention qkv\n        self.self_attn.qkv_proj = build_qkv_proj(\n            2 * hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n\n        self.hidden_norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        embeds: torch.Tensor,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        
past_key_value: Optional[List[torch.FloatTensor]],\n        attn_metadata: Any = None,\n    ):\n\n        residual = hidden_states\n        embeds = self.input_layernorm(embeds)\n        hidden_states = self.hidden_norm(hidden_states)\n        hidden_states = torch.cat([embeds, hidden_states], dim=-1)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Eagle3LlamaModel(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.dtype = dtype\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build layer\n        self.midlayer = Eagle3LlamaDecoderLayer(config, layer_idx=0, dtype=dtype, device=device)\n        target_hidden_size = getattr(config, 'target_hidden_size', config.hidden_size)\n        self.fc = build_rowwise_linear(\n            target_hidden_size * 3,\n            config.hidden_size,\n            bias=False,\n            dtype=dtype,\n            device=device,\n        )\n\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n        # build rotary embedding in LlamaModel\n        
self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        previous_hidden_states: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n        # token embedding\n        if inputs_embeds is None:\n            assert input_ids is not None\n            inputs_embeds = self.embed_tokens(input_ids).to(self.dtype)\n        previous_hidden_states = previous_hidden_states.to(inputs_embeds)\n        if previous_hidden_states.shape[-1] != inputs_embeds.shape[-1]:\n            # previous_hidden_states if from target model\n            previous_hidden_states = self.fc(previous_hidden_states)\n        # rotary embedding\n        cos, sin = self.rotary_emb(previous_hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        past_key_value = past_key_values[0]\n        hidden_states, residual = self.midlayer(\n            inputs_embeds,\n            previous_hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n        hidden_states, hidden_states_prenorm = self.norm(hidden_states, residual)\n        outputs = dict(hidden_states=hidden_states, hidden_states_prenorm=hidden_states_prenorm)\n        return outputs\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass Eagle3LlamaForCausalLM(nn.Module, CudaGraphMixin):\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        
'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self, config, ctx_mgr, dtype=None, device=None):\n        nn.Module.__init__(self)\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        self.dtype = dtype\n\n        if config.num_hidden_layers != 1:\n            raise ValueError('eagle3 only supports 1 decode layer')\n\n        # build LLamaModel\n        self.model = Eagle3LlamaModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.draft_vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n        self.draft_id_to_target_id = nn.Parameter(\n            torch.zeros(self.config.draft_vocab_size, dtype=torch.long, device=device),\n            requires_grad=False,\n        )\n        self.include_embed_tokens = False\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        target_hidden_states: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            previous_hidden_states=target_hidden_states,\n        )\n        return hidden_states\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        
\"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n        target_hidden_states = context.target_hidden_states\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            target_hidden_states=target_hidden_states,\n        )\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        logits = self.lm_head(hidden_states)\n        return logits\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n        max_tokens = graph_meta.max_tokens\n        target_hidden_states = kwargs.get('target_hidden_states')\n        assert target_hidden_states is not None\n        target_hidden_size = target_hidden_states.size(-1)\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        input_buffers['target_hidden_states'] = input_buffers['input_ids'].new_zeros(1,\n                                                                                     max_tokens,\n                                                                                     target_hidden_size,\n                                                                                     dtype=self.dtype)\n\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n\n        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n\n        num_tokens = kwargs['input_ids'].size(-1)\n\n        input_buffers = graph_meta.input_buffers\n\n        target_hidden_states = kwargs.get('target_hidden_states')\n        assert 
target_hidden_states is not None\n        input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states\n\n        new_inputs['target_hidden_states'] = input_buffers['target_hidden_states']\n\n        return new_inputs\n\n    def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: torch.Tensor, **kwargs):\n        \"\"\"Get outputs from buffers.\"\"\"\n        num_tokens = input_ids.size(-1)\n        outputs = dict()\n        outputs['hidden_states'] = output_buffers['hidden_states'][:, :num_tokens]\n        outputs['hidden_states_prenorm'] = output_buffers['hidden_states_prenorm'][:, :num_tokens]\n        return outputs\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'd2t' in name:\n                name = 'draft_id_to_target_id'\n                base = torch.arange(self.config.draft_vocab_size,\n                                    device=loaded_weight.device,\n                                    dtype=loaded_weight.dtype)\n                loaded_weight += base\n            elif 'lm_head.weight' not in name:\n                name = 'model.' 
+ name\n            if 'embed_tokens' in name:\n                self.include_embed_tokens = True\n            if 't2d' in name:\n                continue\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/llava.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_outputs import BaseModelOutputWithPooling\nfrom transformers.models.llava.configuration_llava import LlavaConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import build_model_from_hf_config\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixin, vlm_model\n\n\nclass LlavaMultiModalProjector(nn.Module):\n\n    def __init__(self, config: LlavaConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        from transformers.activations import ACT2FN\n\n        self.linear_1 = nn.Linear(config.vision_config.hidden_size,\n                                  config.text_config.hidden_size,\n                                  bias=True,\n                                  dtype=dtype,\n                                  device=device)\n        self.act = ACT2FN[config.projector_hidden_act]\n        self.linear_2 = nn.Linear(config.text_config.hidden_size,\n                                  config.text_config.hidden_size,\n                                  bias=True,\n                                  dtype=dtype,\n                                  device=device)\n\n    def forward(self, image_features):\n        hidden_states = self.linear_1(image_features)\n        hidden_states = self.act(hidden_states)\n        hidden_states = 
self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass CLIPVisionEmbeddings(nn.Module):\n    \"\"\"Clip vision embedding.\"\"\"\n\n    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.class_embedding = nn.Parameter(torch.empty(self.embed_dim, dtype=dtype, device=device))\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            bias=False,\n            dtype=dtype,\n            device=device,\n        )\n\n        self.num_patches = (self.image_size // self.patch_size)**2\n        self.num_positions = self.num_patches + 1\n        self.position_embedding = nn.Embedding(\n            self.num_positions,\n            self.embed_dim,\n            dtype=dtype,\n            device=device,\n        )\n        self.register_buffer('position_ids',\n                             torch.arange(self.num_positions, device=device).expand((1, -1)),\n                             persistent=False)\n\n    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:\n        \"\"\"This method allows to interpolate the pre-trained position\n        encodings, to be able to use the model on higher resolution images.\n\n        This method is also adapted to support torch.jit tracing.\n        \"\"\"\n\n        num_patches = embeddings.shape[1] - 1\n        position_embedding = self.position_embedding.weight.unsqueeze(0)\n        num_positions = position_embedding.shape[1] - 1\n\n        # always interpolate when tracing\n        # to ensure the exported model works for dynamic input shapes\n        if not torch.jit.is_tracing() 
and num_patches == num_positions and height == width:\n            return self.position_embedding(self.position_ids)\n\n        from transformers.utils import torch_int\n\n        class_pos_embed = position_embedding[:, :1]\n        patch_pos_embed = position_embedding[:, 1:]\n\n        dim = embeddings.shape[-1]\n\n        new_height = height // self.patch_size\n        new_width = width // self.patch_size\n\n        sqrt_num_positions = torch_int(num_positions**0.5)\n        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)\n        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)\n\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed,\n            size=(new_height, new_width),\n            mode='bicubic',\n            align_corners=False,\n        )\n\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n\n        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)\n\n    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:\n        batch_size, _, height, width = pixel_values.shape\n        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):\n            raise ValueError(f\"Input image size ({height}*{width}) doesn't match model\"\n                             f' ({self.image_size}*{self.image_size}).')\n        target_dtype = self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]\n        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)\n\n        class_embeds = self.class_embedding.expand(batch_size, 1, -1)\n        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)\n        if interpolate_pos_encoding:\n            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)\n        else:\n            embeddings = embeddings + 
self.position_embedding(self.position_ids)\n        return embeddings\n\n\nclass CLIPAttention(nn.Module):\n    \"\"\"Clip attention.\"\"\"\n\n    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n\n        self.qkv_proj = build_qkv_proj(\n            self.embed_dim,\n            num_q_heads=self.num_heads,\n            num_kv_heads=self.num_heads,\n            head_size=self.head_dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        self.scale = self.head_dim**-0.5\n\n        # o_proj\n        self.out_proj = build_rowwise_linear(self.embed_dim,\n                                             self.embed_dim,\n                                             bias=True,\n                                             quant_config=quantization_config,\n                                             dtype=dtype,\n                                             device=device,\n                                             is_tp=True)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        q, k, v = self.qkv_proj.split_qkv(qkv_states)\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        if attention_mask is not None and causal_attention_mask is not None:\n            attn_mask = attention_mask + causal_attention_mask\n        elif causal_attention_mask is not None:\n            
attn_mask = causal_attention_mask\n        else:\n            attn_mask = attention_mask\n\n        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)\n\n        # o proj\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.flatten(-2, -1)\n        attn_output = self.out_proj(attn_output)\n        return attn_output\n\n\nclass CLIPMLP(nn.Module):\n    \"\"\"Clip mlp.\"\"\"\n\n    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n        from transformers.activations import ACT2FN\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = build_colwise_linear(\n            config.hidden_size,\n            config.intermediate_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n        self.fc2 = build_rowwise_linear(\n            config.intermediate_size,\n            config.hidden_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass CLIPEncoderLayer(nn.Module):\n    \"\"\"Clip encoder layer.\"\"\"\n\n    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = CLIPAttention(config, dtype=dtype, device=device)\n        self.layer_norm1 = 
nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n        self.mlp = CLIPMLP(config, dtype=dtype, device=device)\n        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        causal_attention_mask: torch.Tensor,\n    ):\n        \"\"\"forward.\"\"\"\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            causal_attention_mask=causal_attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass CLIPEncoder(nn.Module):\n    \"\"\"Clip encoder.\"\"\"\n\n    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [CLIPEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.num_hidden_layers)])\n\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n        vision_feature_layer: int = -1,\n    ):\n        \"\"\"forward.\"\"\"\n        hidden_states = inputs_embeds\n        num_vision_layers = len(self.layers) + vision_feature_layer + 1\n        for _, encoder_layer in enumerate(self.layers[:num_vision_layers]):\n            layer_outputs = encoder_layer(\n                hidden_states,\n                attention_mask,\n                
causal_attention_mask=causal_attention_mask,\n            )\n\n            hidden_states = layer_outputs\n\n        return hidden_states\n\n\nclass CLIPVisionTransformer(nn.Module):\n    \"\"\"Clip vision transformer.\"\"\"\n\n    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        embed_dim = config.hidden_size\n\n        self.embeddings = CLIPVisionEmbeddings(config, dtype=dtype, device=device)\n        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n        self.encoder = CLIPEncoder(config, dtype=dtype, device=device)\n        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        pixel_values: torch.FloatTensor,\n        interpolate_pos_encoding: bool = False,\n        vision_feature_layer: int = -1,\n    ) -> BaseModelOutputWithPooling:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)\n        hidden_states = self.pre_layrnorm(hidden_states)\n\n        encoder_outputs = self.encoder(inputs_embeds=hidden_states, vision_feature_layer=vision_feature_layer)\n\n        last_hidden_state = encoder_outputs\n        pooled_output = last_hidden_state[:, 0, :]\n        pooled_output = self.post_layernorm(pooled_output)\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            pooler_output=pooled_output,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\n@vlm_model\nclass CLIPVisionModel(nn.Module):\n    \"\"\"Clip vision model.\"\"\"\n\n    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.vision_model = CLIPVisionTransformer(config, dtype=dtype, device=device)\n\n    def forward(self,\n        
        pixel_values: torch.FloatTensor,\n                interpolate_pos_encoding: bool = False,\n                vision_feature_layer: int = -1,\n                **kwargs):\n        \"\"\"forward.\"\"\"\n        return self.vision_model(pixel_values,\n                                 interpolate_pos_encoding=interpolate_pos_encoding,\n                                 vision_feature_layer=vision_feature_layer)\n\n\ndef build_vision_model(vision_config, dtype: torch.dtype = None, device: torch.device = None):\n    \"\"\"Build vision model.\"\"\"\n    model_type = vision_config.model_type\n\n    if model_type == 'clip_vision_model':\n        return CLIPVisionModel(vision_config, dtype, device)\n    else:\n        raise NotImplementedError(f'<{model_type}> is not implemented.')\n\n\nclass LlavaForConditionalGeneration(nn.Module, CudaGraphMixin, DeployModelMixin):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        text_config = config.text_config\n\n        self.vision_tower = build_vision_model(config.vision_config, dtype=dtype, device=device)\n\n        self.language_model = build_model_from_hf_config(text_config, dtype=dtype, device=device)\n\n        self.multi_modal_projector = LlavaMultiModalProjector(config, dtype=dtype, device=device)\n\n        self.input_processor = LLavaInputProcessor(config, dtype)\n\n    def get_image_features(self,\n                           pixel_values,\n                           vision_feature_layer: int = -1,\n                           vision_feature_select_strategy: str = 'default'):\n        \"\"\"Get image features.\"\"\"\n        selected_image_feature = self.vision_tower(pixel_values, vision_feature_layer=vision_feature_layer)[0]\n        if vision_feature_select_strategy 
== 'default':\n            selected_image_feature = selected_image_feature[:, 1:]\n        elif vision_feature_select_strategy == 'full':\n            selected_image_feature = selected_image_feature\n        else:\n            raise ValueError(f'Unexpected select feature strategy: {vision_feature_select_strategy}'  # noqa: E501\n                             )\n        image_features = self.multi_modal_projector(selected_image_feature)\n        image_features = image_features.flatten(0, 1)[None]\n\n        return image_features\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        pixel_values: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        if inputs_embeds is None:\n            image_features = None\n            if pixel_values is not None:\n                vision_feature_layer = self.config.vision_feature_layer\n                select_strategy = self.config.vision_feature_select_strategy\n                image_features = self.get_image_features(pixel_values,\n                                                         vision_feature_layer=vision_feature_layer,\n                                                         vision_feature_select_strategy=select_strategy)\n            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)\n            if pixel_values is not None:\n                inputs_embeds.masked_scatter_(image_mask[..., None], image_features)\n\n        return self.language_model.forward(input_ids=input_ids,\n                                           inputs_embeds=inputs_embeds,\n                                           past_key_values=past_key_values,\n                                           position_ids=position_ids,\n                                           
attn_metadata=attn_metadata)\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.language_model.get_logits(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.language_model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # vision inputs\n        pixel_values = None\n        image_mask = None\n        if context.input_multimodals is not None:\n            pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            pixel_values = [data for im_data in pixel_values for data in im_data]\n            if len(pixel_values) > 0:\n                image_token_id = pixel_values[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                pixel_values = torch.cat([data.data for data in pixel_values])\n            else:\n                pixel_values = None\n                image_mask = None\n\n        # get inputs from context\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n          
  attn_metadata=attn_metadata,\n            pixel_values=pixel_values,\n            image_mask=image_mask,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n        ]\n\n        # vis model\n        lang_prefix = 'language_model.'\n        prefix_length = len(lang_prefix)\n        new_weights = dict()\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if name.startswith(lang_prefix):\n                new_key = name[prefix_length:]\n                new_weights[new_key] = loaded_weight\n                continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n\n        self.language_model.load_weights(new_weights.items())\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass LLavaInputProcessor(BaseModelInputProcessor):\n    \"\"\"Llava input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         
**kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     meta=dict(image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n\n\ndef get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):\n\n    from transformers.image_processing_utils import select_best_resolution\n\n    if not isinstance(grid_pinpoints, list):\n        raise TypeError('grid_pinpoints should be a list of tuples or lists')\n\n    if not isinstance(image_size, (list, tuple)):\n        image_size = image_size.tolist()\n\n    height, width = select_best_resolution(image_size, grid_pinpoints)\n    return height // patch_size, width // patch_size\n\n\ndef unpad_image(tensor, original_size):\n    \"\"\"Unpads a PyTorch tensor of a padded and resized image.\"\"\"\n    if not isinstance(original_size, (list, tuple)):\n        original_size = original_size.tolist()\n    original_height, original_width = original_size\n    current_height, current_width = tensor.shape[1:]\n\n    original_aspect_ratio = original_width / original_height\n    current_aspect_ratio = current_width / current_height\n\n    
if original_aspect_ratio > current_aspect_ratio:\n        scale_factor = current_width / original_width\n        new_height = int(round(original_height * scale_factor, 7))\n        padding = (current_height - new_height) // 2\n        unpadded_tensor = tensor[:, padding:current_height - padding, :]\n    else:\n        scale_factor = current_height / original_height\n        new_width = int(round(original_width * scale_factor, 7))\n        padding = (current_width - new_width) // 2\n        unpadded_tensor = tensor[:, :, padding:current_width - padding]\n\n    return unpadded_tensor\n\n\ndef image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):\n    \"\"\"Calculate the number of patches after the preprocessing for images of\n    any resolution.\"\"\"\n    from transformers.image_processing_utils import select_best_resolution\n    if not isinstance(grid_pinpoints, list):\n        raise TypeError('grid_pinpoints should be a list of tuples or lists')\n\n    if not isinstance(image_size, (list, tuple)):\n        image_size = image_size.tolist()\n\n    best_resolution = select_best_resolution(image_size, grid_pinpoints)\n    height, width = best_resolution\n\n    num_patches = (height // patch_size) * (width // patch_size)\n    # add the base patch\n    num_patches += 1\n    return num_patches\n\n\nclass LlavaNextForConditionalGeneration(LlavaForConditionalGeneration):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=dtype, device=device))\n        self.input_processor = LLavaNextInputProcessor(config, dtype)\n\n    def get_image_features(\n        self,\n        pixel_values: torch.FloatTensor,\n        image_sizes: 
torch.Tensor,\n        vision_feature_layer: int,\n        vision_feature_select_strategy: str,\n    ):\n        # ! infer image_num_patches from image_sizes\n        image_num_patches = [\n            image_size_to_num_patches(\n                image_size=imsize,\n                grid_pinpoints=self.config.image_grid_pinpoints,\n                patch_size=self.config.vision_config.image_size,\n            ) for imsize in image_sizes\n        ]\n        if pixel_values.dim() == 5:\n            # stacked if input is\n            # (batch_size, num_patches, num_channels, height, width)\n            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]\n            pixel_values = torch.cat(_pixel_values_list, dim=0)\n        elif pixel_values.dim() != 4:\n            # otherwise has to be stacked from list of\n            # (num_patches, num_channels, height, width)\n            raise ValueError(f'pixel_values of shape {pixel_values.shape}, '\n                             'expect to be of 4 or 5 dimensions')\n\n        selected_image_feature = self.vision_tower(pixel_values, vision_feature_layer=vision_feature_layer)[0]\n        if vision_feature_select_strategy == 'default':\n            selected_image_feature = selected_image_feature[:, 1:]\n        elif vision_feature_select_strategy == 'full':\n            selected_image_feature = selected_image_feature\n        image_features = self.multi_modal_projector(selected_image_feature)\n        image_features = torch.split(image_features, image_num_patches, dim=0)\n        return image_features\n\n    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):\n\n        new_image_features = []\n        feature_lens = []\n        for image_idx, image_feature in enumerate(image_features):\n            if image_feature.shape[0] > 1:\n                base_image_feature = image_feature[0]\n                image_feature = 
image_feature[1:]\n                height = width = (self.config.vision_config.image_size // self.config.vision_config.patch_size)\n\n                if vision_feature_select_strategy == 'default':\n                    expected_num_patches = height * width\n                elif vision_feature_select_strategy == 'full':\n                    expected_num_patches = height * width + 1\n                if expected_num_patches != base_image_feature.shape[0]:\n                    raise ValueError('The number of patches is '\n                                     'not consistent with the image size.')\n\n                (num_patch_height, num_patch_width) = get_anyres_image_grid_shape(\n                    image_sizes[image_idx],\n                    self.config.image_grid_pinpoints,\n                    self.config.vision_config.image_size,\n                )\n                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)\n                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()\n                image_feature = image_feature.flatten(1, 2).flatten(2, 3)\n                image_feature = unpad_image(image_feature, image_sizes[image_idx])\n                if image_newline is not None:\n                    image_feature = torch.cat(\n                        (\n                            image_feature,\n                            image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),\n                        ),\n                        dim=-1,\n                    )\n                image_feature = image_feature.flatten(1, 2).transpose(0, 1)\n                image_feature = torch.cat((base_image_feature, image_feature), dim=0)\n            else:\n                image_feature = image_feature[0]\n                if image_newline is not None:\n                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)\n            
new_image_features.append(image_feature)\n            feature_lens.append(image_feature.size(0))\n        image_features = torch.cat(new_image_features, dim=0)\n        return image_features\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        pixel_values: torch.Tensor = None,\n        image_sizes: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        if inputs_embeds is None:\n            image_features = None\n            if pixel_values is not None:\n                vision_feature_layer = self.config.vision_feature_layer\n                select_strategy = self.config.vision_feature_select_strategy\n                image_sizes = image_sizes.tolist()\n                image_features = self.get_image_features(pixel_values,\n                                                         image_sizes,\n                                                         vision_feature_layer=vision_feature_layer,\n                                                         vision_feature_select_strategy=select_strategy)\n                image_features = self.pack_image_features(\n                    image_features,\n                    image_sizes,\n                    vision_feature_select_strategy=select_strategy,\n                    image_newline=self.image_newline,\n                )\n                image_features = image_features[None]\n            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)\n            if pixel_values is not None:\n                inputs_embeds.masked_scatter_(image_mask[..., None], image_features)\n\n        return self.language_model.forward(input_ids=input_ids,\n                                           inputs_embeds=inputs_embeds,\n                                           
past_key_values=past_key_values,\n                                           position_ids=position_ids,\n                                           attn_metadata=attn_metadata)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # vision inputs\n        pixel_values = None\n        image_sizes = None\n        image_mask = None\n        if context.input_multimodals is not None:\n            img_mms = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            img_mms = [data for im_data in img_mms for data in im_data]\n            if len(img_mms) > 0:\n                image_token_id = img_mms[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                pixel_values = torch.cat([data.data.flatten(0, 1) for data in img_mms])\n                image_sizes = torch.cat([data.meta['image_sizes'] for data in img_mms])\n            else:\n                pixel_values = None\n                image_sizes = None\n\n        # get inputs from context\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        return dict(\n            input_ids=input_ids,\n            
position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            pixel_values=pixel_values,\n            image_sizes=image_sizes,\n            image_mask=image_mask,\n            inputs_embeds=inputs_embeds,\n        )\n\n\nclass LLavaNextInputProcessor(BaseModelInputProcessor):\n    \"\"\"LlavaNext input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values'].to(self.dtype)\n            image_sizes = input_mm['image_sizes']\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n                                     end=offset + num_pad,\n                                     meta=dict(image_sizes=image_sizes, image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/minicpm3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport math\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_merged_colwise_linear, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.rotary_embedding import (ApplyRotaryEmb, LongRoPEScalingParameters, get_rope_parameters,\n                                                  get_rope_theta)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\n# TODO use MLA of pytorch engine\nclass MiniCPMAttention(nn.Module):\n    \"\"\"Minicpm3 attention.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = None\n        self.q_lora_rank = config.q_lora_rank\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.hidden_size // config.num_attention_heads\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n\n        if self.q_lora_rank is None:\n            self.q_proj = build_colwise_linear(\n                self.hidden_size,\n                self.num_heads * self.q_head_dim,\n                bias=False,\n                dtype=dtype,\n                device=device,\n                is_tp=True,\n            )\n        else:\n            self.q_a_proj = 
build_colwise_linear(\n                self.hidden_size,\n                config.q_lora_rank,\n                bias=config.attention_bias,\n                dtype=dtype,\n                device=device,\n                is_tp=False,\n            )\n            self.q_a_layernorm = RMSNorm(config.q_lora_rank,\n                                         1e-6,\n                                         quant_config=quantization_config,\n                                         dtype=dtype,\n                                         device=device)\n            self.q_b_proj = build_colwise_linear(\n                config.q_lora_rank,\n                self.num_heads * self.q_head_dim,\n                bias=False,\n                dtype=dtype,\n                device=device,\n                is_tp=True,\n            )\n\n        self.kv_a_proj_with_mqa = build_colwise_linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,\n                                      1e-6,\n                                      quant_config=quantization_config,\n                                      dtype=dtype,\n                                      device=device)\n        self.kv_b_proj = build_colwise_linear(\n            config.kv_lora_rank,\n            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n        self.softmax_scale = self.q_head_dim**(-0.5)\n        self.attn_fwd = Attention(self.num_heads,\n                                  config.kv_lora_rank + self.qk_rope_head_dim,\n                                  scale=self.softmax_scale,\n                        
          num_kv_heads=config.num_key_value_heads)\n\n        self.o_proj = build_rowwise_linear(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        world_size, _ = get_tp_world_rank()\n        num_heads = self.num_heads // world_size\n        bsz, q_len, _ = hidden_states.size()\n\n        # qkv_proj\n        bsz, q_len, _ = hidden_states.size()\n\n        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, num_heads, self.q_head_dim)\n        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)\n        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(bsz, q_len, num_heads,\n                                                                      self.qk_nope_head_dim + self.v_head_dim))\n\n        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        q_pe, k_pe = self.apply_rotary_pos_emb(\n            q_pe,\n            k_pe,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)\n        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope\n        
query_states[:, :, :, self.qk_nope_head_dim:] = q_pe\n\n        key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)\n        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe\n\n        if self.q_head_dim != self.v_head_dim:\n            value_states = torch.nn.functional.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            inplace=False,\n        )\n        if self.q_head_dim != self.v_head_dim:\n            attn_output = attn_output[:, :, :, :self.v_head_dim]\n\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass MiniCPMMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(config.intermediate_size,\n                                              config.hidden_size,\n                                              bias=False,\n                                              quant_config=quantization_config,\n                                              
dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass MiniCPMDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = MiniCPMAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = MiniCPMMLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n        self.scale_depth = config.scale_depth\n        self.num_hidden_layers = config.num_hidden_layers\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        attn_metadata: Any = None,\n    
):\n\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass MiniCPM3Model(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.scale_emb = config.scale_emb\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            MiniCPMDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n        # build rotary embedding\n        emb_type = RopeType.LinearScaling\n        rope_dim = config.qk_rope_head_dim\n        rope_max_pos_emb = config.max_position_embeddings\n        
rope_base = get_rope_theta(config)\n        rope_scaling = get_rope_parameters(config)\n        if rope_scaling is not None:\n            scaling_type = rope_scaling['type']\n            assert scaling_type in ['longrope', 'su']\n            emb_type = RopeType.LongRoPEScaling\n            ori_pos_emb = getattr(config, 'original_max_position_embeddings', rope_max_pos_emb)\n\n            longrope_params = LongRoPEScalingParameters(short_factor=rope_scaling['short_factor'],\n                                                        long_factor=rope_scaling['long_factor'],\n                                                        original_max_position_embeddings=ori_pos_emb)\n            self.rotary_emb = build_rotary_embedding(\n                rope_dim,\n                rope_max_pos_emb,\n                rope_base,\n                longrope_params=longrope_params,\n                emb_type=emb_type,\n            )\n        else:\n            self.rotary_emb = build_rotary_embedding(\n                rope_dim,\n                rope_max_pos_emb,\n                rope_base,\n                emb_type=emb_type,\n            )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids) * self.scale_emb\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n        # decoding\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            
hidden_states, _ = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass MiniCPM3ForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Rewritten model of MiniCPM3ForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build MiniCPM3Model\n        self.model = MiniCPM3Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n        logits 
= self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))\n        return logits\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        if self.config.tie_word_embeddings:\n            self.lm_head.weight = self.model.embed_tokens.weight\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            # ('.qkv_proj', '.q_proj', 'q'),\n            # ('.qkv_proj', '.k_proj', 'k'),\n            # ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n          
  ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/minicpmv26.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass MiniCPMV26Attention(nn.Module):\n    \"\"\"Rewrite module of MiniCPMV26Attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(hidden_size,\n                                       num_q_heads=num_heads,\n                                       num_kv_heads=num_key_value_heads,\n                                       head_size=head_dim,\n                                       bias=True,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device,\n                                       num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = 
Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=config.sliding_window,\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(num_heads * head_dim,\n                                           hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        
attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass MiniCPMV26MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(config.intermediate_size,\n                                              config.hidden_size,\n                                              bias=False,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass MiniCPMV26DecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = MiniCPMV26Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = MiniCPMV26MLP(config, 
dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass MiniCPMV26Model(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        
self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            MiniCPMV26DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        
hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass MiniCPMVForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Rewritten model of MiniCPMVForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = MiniCPMV26Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return hidden states.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def update_weights(self):\n        \"\"\"Update 
weights.\"\"\"\n        if self.config.tie_word_embeddings:\n            self.lm_head.weight = self.model.embed_tokens.weight\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters(prefix='llm'))\n        for name, loaded_weight in weights:\n            if 'vpm' in name or 'resampler' in name:\n       
         continue\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/mistral.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass MistralAttention(nn.Module):\n    \"\"\"Rewrite module of MistralAttention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = config.head_dim\n        if head_dim is None:\n            head_dim = hidden_size // num_heads\n\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=config.sliding_window,\n        )\n\n        # o_proj\n        
self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=False,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass MistralMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config 
= getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass MistralDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = MistralAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = MistralMLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                  
                     device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass MistralModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                      
                   device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            MistralDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass MistralForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': 
[\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = MistralModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return hidden states.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    
):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = 
name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/mixtral.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass MixtralAttention(nn.Module):\n    \"\"\"Mixtral attention.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = hidden_size // num_heads\n\n        # qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.window_size = config.sliding_window or -1\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=self.window_size,\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(num_heads * head_dim,\n                                           hidden_size,\n                                           bias=False,\n                           
                quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass MixtralSparseMoeBlock(nn.Module):\n    \"\"\"Mixtral sparse moe block.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.intermediate_size\n        self.num_experts = config.num_local_experts\n        self.top_k 
= config.num_experts_per_tok\n\n        self.gate = build_rowwise_linear(\n            self.hidden_dim,\n            self.num_experts,\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n            quant_config=None,\n        )\n\n        self.softmax_topk = SoftmaxTopK(\n            self.top_k,\n            n_groups=getattr(config, 'router_n_groups', -1),\n        )\n\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=True,\n            dtype=dtype,\n            device=device,\n            all_reduce=True,\n            quant_config=quantization_config,\n        )\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        router_logits = self.gate(hidden_states)\n\n        topk_weights, topk_ids = self.softmax_topk(router_logits)\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n        return out_states, router_logits\n\n\nclass MixtralDecoderLayer(nn.Module):\n    \"\"\"Mixtral decoder layer.\"\"\"\n\n    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = MixtralAttention(config, dtype=dtype, device=device)\n        self.block_sparse_moe = MixtralSparseMoeBlock(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                         
              config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states, _ = self.block_sparse_moe(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass MixtralModel(nn.Module):\n    \"\"\"Mixtral model.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n     
   self.layers = nn.ModuleList([\n            MixtralDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, quant_config=None, dtype=dtype, device=device)\n\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n        residual = None\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n        for idx, decoder_layer in enumerate(self.layers):\n\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass MixtralForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Mixture model for causalLM.\"\"\"\n\n    def __init__(self,\n                 config: Any,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n 
       self.ctx_mgr = ctx_mgr\n        self.model = MixtralModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n   
     # modified from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n        ]\n\n        num_experts = self.config.num_local_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n                break\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, 
loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/module_map.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nLMDEPLOY_PYTORCH_MODEL_PATH = 'lmdeploy.pytorch.models'\n\n# generic and device-specific module maps\nMODULE_MAP = dict()\nASCEND_MODULE_MAP = dict()\nMACA_MODULE_MAP = dict()\nCAMB_MODULE_MAP = dict()\n\nDEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP, maca=MACA_MODULE_MAP, camb=CAMB_MODULE_MAP)\n\n# llama\nMODULE_MAP.update({\n    'LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaForCausalLM',\n})\n\n# llama4\nMODULE_MAP.update({\n    'Llama4ForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama4.Llama4ForConditionalGeneration',\n})\n\n# baichuan\nMODULE_MAP.update({\n    'BaichuanForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.baichuan.BaichuanForCausalLM',\n})\n\n# chatglm\nMODULE_MAP.update({\n    'ChatGLMForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.chatglm2.ChatGLMForConditionalGeneration',  # noqa: E501\n})\n\n# glm4-0414\nMODULE_MAP.update({\n    'Glm4ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.glm4.Glm4ForCausalLM',\n})\n\n# glm4.1-v\nMODULE_MAP.update({\n    'Glm4vForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.glm4_1v.Glm4vForConditionalGeneration',\n})\n\n# glm4.5\nMODULE_MAP.update({\n    'Glm4MoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.glm4_moe.Glm4MoeForCausalLM',\n})\n\n# glm4.7\nMODULE_MAP.update({'Glm4MoeLiteForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'})\n\n# glm4.7 mtp\nMODULE_MAP.update({\n    'Glm4MoeMTPModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.glm4moe_mtp.Glm4MoeMTPModel',\n})\n\n# glm5\nMODULE_MAP.update({'GlmMoeDsaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v32.DeepseekV32ForCausalLM'})\n\n# internlm\nMODULE_MAP.update({\n    'InternLMForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm.InternLMForCausalLM',\n})\n\n# internlm2\nMODULE_MAP.update({\n    'InternLM2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.InternLM2ForCausalLM',\n})\n\n# 
mistral\nMODULE_MAP.update({\n    'MistralForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.MistralForCausalLM',\n})\n\n# mixtral\nMODULE_MAP.update({\n    'MixtralForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.MixtralForCausalLM',\n})\n\n# gemma\nMODULE_MAP.update({\n    'GemmaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.GemmaForCausalLM',\n})\n\n# gemma2\nMODULE_MAP.update({\n    'Gemma2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.GemmaForCausalLM',\n})\n\n# gemma3 text\nMODULE_MAP.update({\n    'Gemma3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.GemmaForCausalLM',\n})\n\n# gemma3 VL\nMODULE_MAP.update({\n    'Gemma3ForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma3_vl.Gemma3ForConditionalGeneration',\n})\n\n# deepseek\nMODULE_MAP.update({\n    'DeepseekForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek.DeepseekForCausalLM',\n})\n\n# deepseek-v2\nMODULE_MAP.update({'DeepseekV2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'})\n\n# deepseek-v3\nMODULE_MAP.update({'DeepseekV3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'})\n\n# deepseek-v32\nMODULE_MAP.update({'DeepseekV32ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v32.DeepseekV32ForCausalLM'})\n\n# deepseek-vl2\nMODULE_MAP.update({'DeepseekVLV2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_vl2.DeepseekVLV2ForCausalLM'})\n\n# llava\nMODULE_MAP.update({\n    'LlavaForConditionalGeneration': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration',  # noqa: E501\n    'LlavaNextForConditionalGeneration':  # noqa: E501\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaNextForConditionalGeneration'  # noqa: E501\n})\n\n# qwen\nMODULE_MAP.update({\n    'QWenLMHeadModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen.QWenLMHeadModel',\n})\n\n# qwen1.5\nMODULE_MAP.update({\n    'Qwen2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.Qwen2ForCausalLM',\n})\n\n# qwen2 
moe\nMODULE_MAP.update({\n    'Qwen2MoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_moe.Qwen2MoeForCausalLM',\n})\n\n# qwen3\nMODULE_MAP.update({\n    'Qwen3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3.Qwen3ForCausalLM',\n})\n\n# qwen3 moe\nMODULE_MAP.update({\n    'Qwen3MoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_moe.Qwen3MoeForCausalLM',\n})\n\n# qwen2_vl\nMODULE_MAP.update({\n    'Qwen2VLForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_vl.Qwen2VLForConditionalGeneration',\n})\n\n# qwen2_5_vl\nMODULE_MAP.update({\n    'Qwen2_5_VLForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration',\n})\n\n# qwen3_vl\nMODULE_MAP.update({\n    'Qwen3VLForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl.Qwen3VLForConditionalGeneration',\n})\n\n# qwen3_vl_moe\nMODULE_MAP.update({\n    'Qwen3VLMoeForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration',\n})\n\n# qwen3.5\nMODULE_MAP.update({\n    'Qwen3_5ForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5.Qwen3_5ForConditionalGeneration',\n})\n\n# qwen3.5 moe\nMODULE_MAP.update({\n    'Qwen3_5MoeForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_moe.Qwen3_5MoeForConditionalGeneration',\n})\n\n# starcoder2\nMODULE_MAP.update({\n    'Starcoder2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.Starcoder2ForCausalLM',\n})\n\n# phi-3\nMODULE_MAP.update({\n    'Phi3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3ForCausalLM',\n})\n\n# cogvlm\nMODULE_MAP.update({\n    'CogVLMForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.CogVLMForCausalLM',\n})\n\n# internvl\nMODULE_MAP.update({'InternVLChatModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl.InternVLChatModel'})\n\n# internvl3-hf\nMODULE_MAP.update({\n    'InternVLForConditionalGeneration':\n    
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl3_hf.InternVLForConditionalGeneration'\n})\n\n# interns1-hf\nMODULE_MAP.update({\n    'InternS1ForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl3_hf.InternVLForConditionalGeneration'\n})\n\n# interns1-pro\nMODULE_MAP.update({\n    'InternS1ProForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.interns1_pro.InternS1ProForConditionalGeneration',\n})\nMODULE_MAP.update({\n    'InternS1_1_ForConditionalGeneration':\n    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.interns1_pro.InternS1ProForConditionalGeneration',\n})\n\n# mono-internvl\nMODULE_MAP.update({\n    'InternLM2VEForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2_ve.InternLM2VEForCausalLM',\n})\n\n# phi3 vision\nMODULE_MAP.update({\n    'Phi3VForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_v.Phi3VForCausalLM',\n})\n\n# phi-3.5-moe\nMODULE_MAP.update({\n    'PhiMoEForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_moe.PhiMoEForCausalLM',\n})\n\n# minicpm3\nMODULE_MAP.update({\n    'MiniCPM3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.minicpm3.MiniCPM3ForCausalLM',\n})\n\n# minicpmv2_6\nMODULE_MAP.update({\n    'MiniCPMV': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.minicpmv26.MiniCPMVForCausalLM',\n})\n\n# internlm3\nMODULE_MAP.update({\n    'InternLM3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm3.InternLM3ForCausalLM',\n})\n\n# internlm2 reward model\nMODULE_MAP.update(\n    {'InternLM2ForRewardModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2_reward.InternLM2ForRewardModel'})\n\n# qwen2 reward model\nMODULE_MAP.update({'Qwen2ForRewardModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_reward.Qwen2ForRewardModel'})\n\n# gpt-oss\nMODULE_MAP.update({\n    'GptOssForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gpt_oss.GptOssForCausalLM',\n})\n\n# qwen3 next model\nMODULE_MAP.update({\n    'Qwen3NextForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_next.Qwen3NextForCausalLM',\n})\n\n# SDAR\nMODULE_MAP.update({\n    'SDARForCausalLM': 
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM',\n    'SDARMoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar_moe.SDARMoeForCausalLM',\n})\n\nCUSTOM_MODULE_MAP = dict()\n\n# spec models\n# eagle llama\nMODULE_MAP.update({'EagleLlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama_eagle.EagleLlamaForCausalLM'})\n\n# eagle3 llama\nMODULE_MAP.update({'Eagle3LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama_eagle3.Eagle3LlamaForCausalLM'})\n\n# deepseek mtp\nMODULE_MAP.update({'DeepseekMTPModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_mtp.DeepseekMTPModel'})\n"
  },
  {
    "path": "lmdeploy/pytorch/models/patch.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport contextlib\nimport importlib\nimport inspect\nimport os.path as osp\nimport re\nimport sys\nfrom typing import Any, Dict\n\nimport torch\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import BuildModelContext, StepContextManager\nfrom lmdeploy.utils import get_logger\n\nfrom ..config import ModelConfig\nfrom ..devices import get_device_manager\nfrom .module_map import CUSTOM_MODULE_MAP, DEVICE_SPECIAL_MODULE_MAP, MODULE_MAP\n\nlogger = get_logger('lmdeploy')\n\n\ndef _get_rewrite_qualname(origin_qualname: str, module_map: Dict[str, str]) -> str:\n    \"\"\"Get rewrite module from origin module name.\n\n    Args:\n        origin_qualname (str): The origin qualname of the module.\n\n    Returns:\n        str: The rewrite qualname.\n    \"\"\"\n    if origin_qualname in module_map:\n        return module_map[origin_qualname]\n    for key, value in module_map.items():\n        if re.search(key, origin_qualname):\n            return value\n    return None\n\n\ndef _class_from_qualname(qualname: str) -> Any:\n    \"\"\"Import class with qualname.\n\n    Args:\n        qualname (str): Qualname of the class\n\n    Returns:\n        Any: class or builder of the class\n    \"\"\"\n    last_dot = qualname.rfind('.')\n    modname = qualname[:last_dot]\n    clsname = qualname[last_dot + 1:]\n\n    # get class at runtime\n    mod = importlib.import_module(modname)\n    assert mod is not None, f'failed to import module: {modname}'\n    cls_type = getattr(mod, clsname)\n    return cls_type\n\n\ndef _find_rewrite_module_qualname(model, module_map: Dict[str, str]):\n    \"\"\"Find rewrite module.\"\"\"\n    module_name = inspect.getmodule(model).__name__\n    class_name = model.__class__.__name__\n\n    def _find_fullname():\n        origin_qualname = f'{module_name}.{class_name}'\n        rewrite_qualname = _get_rewrite_qualname(origin_qualname, 
 module_map)\n        return rewrite_qualname\n\n    def _find_classname():\n        origin_qualname = class_name\n        rewrite_qualname = _get_rewrite_qualname(origin_qualname, module_map)\n        return rewrite_qualname\n\n    def _find_submodulename():\n        # name prefixed with the last component of the module path\n        mod_name = module_name[module_name.rfind('.') + 1:]\n        origin_qualname = f'{mod_name}.{class_name}'\n        rewrite_qualname = _get_rewrite_qualname(origin_qualname, module_map)\n        return rewrite_qualname\n\n    rewrite_qualname = _find_fullname()\n    if rewrite_qualname is None:\n        rewrite_qualname = _find_classname()\n    if rewrite_qualname is None:\n        rewrite_qualname = _find_submodulename()\n\n    origin_qualname = f'{module_name}.{class_name}'\n    if rewrite_qualname is not None:\n        logger.debug('Found rewrite of module\\n'\n                     f'{origin_qualname} <=> {rewrite_qualname}')\n    return rewrite_qualname\n\n\ndef get_rewrite_cls(model: torch.nn.Module, module_map: Dict[str, str] = None):\n    \"\"\"Get rewrite cls.\"\"\"\n    if module_map is None:\n        module_map = _get_module_map()\n    rewrite_qualname = _find_rewrite_module_qualname(model, module_map=module_map)\n    if rewrite_qualname is None:\n        return None\n    return _class_from_qualname(rewrite_qualname)\n\n\ndef _get_module_map():\n    \"\"\"Get module map.\"\"\"\n    module_map = MODULE_MAP.copy()\n    device_type = get_device_manager().current_context().device_type\n    if device_type != 'cuda':\n        device_map = DEVICE_SPECIAL_MODULE_MAP.get(device_type, dict())\n        module_map.update(device_map)\n    # add custom module map\n    module_map.update(CUSTOM_MODULE_MAP)\n    return module_map\n\n\ndef update_custom_module_map(module_map_path: str):\n    \"\"\"Load custom module map from file.\"\"\"\n    from importlib.machinery import SourceFileLoader\n\n    from lmdeploy.pytorch.models.module_map import LMDEPLOY_PYTORCH_MODEL_PATH\n    
assert osp.exists(module_map_path), (f'custom module map path: \"{module_map_path}\" does not exist.')\n\n    module_map_path = osp.abspath(module_map_path)\n    folder = osp.split(module_map_path)[0]\n    sys.path.append(folder)\n    custom_mod = SourceFileLoader('map_mod', module_map_path).load_module()\n    sys.modules[f'{LMDEPLOY_PYTORCH_MODEL_PATH}._custom_mod'] = custom_mod\n\n    new_mod_map = dict()\n    has_map = False\n    if hasattr(custom_mod, 'MODULE_MAP'):\n        has_map = True\n        mod_map = custom_mod.MODULE_MAP\n        assert isinstance(mod_map, dict)\n        new_mod_map.update(mod_map)\n\n    if hasattr(custom_mod, 'CUSTOM_MODULE_MAP'):\n        has_map = True\n        mod_map = custom_mod.CUSTOM_MODULE_MAP\n        assert isinstance(mod_map, dict)\n        new_mod_map.update(mod_map)\n\n    if not has_map:\n        raise RuntimeError(f'Found no map in \"{module_map_path}\".')\n\n    for k, v in new_mod_map.items():\n        if '.' not in v:\n            v = f'{LMDEPLOY_PYTORCH_MODEL_PATH}._custom_mod.{v}'\n            new_mod_map[k] = v\n\n    CUSTOM_MODULE_MAP.update(new_mod_map)\n\n\ndef _get_model_class(config, module_map):\n    \"\"\"Get model class.\"\"\"\n    auto_map = getattr(config, 'auto_map', dict())\n    if 'AutoModelForCausalLM' in auto_map:\n        mapname = auto_map['AutoModelForCausalLM']\n        if '.' 
in mapname:\n            mapname = mapname.split('.')[-1]\n        if mapname in module_map:\n            qualname = module_map[mapname]\n            module_cls = _class_from_qualname(qualname)\n            return module_cls\n        raise RuntimeError(f'Cannot find rewrite for auto_map: {mapname}')\n\n    architectures = getattr(config, 'architectures', [])\n\n    if architectures is None:\n        # only for deepseek-vl2, which uses a different config format\n        # https://huggingface.co/deepseek-ai/deepseek-vl2/blob/main/config.json\n        assert getattr(config.language_config, 'architectures', []) is not None\n        qualname = module_map['DeepseekVLV2ForCausalLM']\n        module_cls = _class_from_qualname(qualname)\n        return module_cls\n\n    for arch in architectures:\n        if arch in module_map:\n            qualname = module_map[arch]\n            module_cls = _class_from_qualname(qualname)\n            return module_cls\n\n    raise RuntimeError(f'Cannot find rewrite for architectures: {architectures}')\n\n\ndef build_model_from_hf_config(model_config: PretrainedConfig,\n                               dtype: torch.dtype = None,\n                               device: torch.device = None,\n                               ctx_mgr: StepContextManager = None,\n                               build_model_ctx: 'BuildModelContext' = None):\n    \"\"\"Build model from hf config.\"\"\"\n    if ctx_mgr is None:\n        ctx_mgr = StepContextManager(build_model_ctx)\n    module_map = _get_module_map()\n    if device is None:\n        device = torch.device('cuda')\n    model_cls = _get_model_class(model_config, module_map)\n    # update quant config\n    if build_model_ctx is not None and hasattr(model_cls, 'update_quant_config'):\n        build_model_ctx.quant_config = model_cls.update_quant_config(build_model_ctx.quant_config)\n\n    with build_model_context(build_model_ctx):\n        model = model_cls(model_config, ctx_mgr, dtype=dtype, 
device=device)\n    return model.eval()\n\n\n@torch.inference_mode()\ndef build_patched_model(config: ModelConfig, device: torch.device = None, build_model_ctx: 'BuildModelContext' = None):\n    \"\"\"Build patched model.\"\"\"\n    model_config = config.hf_config\n    dtype = config.dtype\n    return build_model_from_hf_config(model_config, dtype=dtype, device=device, build_model_ctx=build_model_ctx)\n\n\n@torch.inference_mode()\ndef add_adapters(model: torch.nn.Module,\n                 adapters: Dict[str, str],\n                 dtype: torch.dtype = torch.float16,\n                 device: torch.device = None):\n    \"\"\"Add adapters.\"\"\"\n    from peft import PeftConfig\n    from peft.tuners.lora import LoraConfig\n    from transformers.modeling_utils import load_state_dict\n\n    from lmdeploy.pytorch.adapter.adapter import find_all_target, get_ranks_and_scalings, load_lora_weights\n    from lmdeploy.pytorch.nn.linear import LoRA\n    num_adapters = len(adapters)\n    if num_adapters == 0:\n        return\n\n    if device is None:\n        device = torch.device('cuda')\n\n    # model could be graph runner\n    if hasattr(model, 'get_model'):\n        model = model.get_model()\n    ctx_mgr = model.ctx_mgr\n\n    adapter_names = list(adapters.keys())\n    adapter_names = sorted(adapter_names)\n\n    adapter_cfgs = [PeftConfig.from_pretrained(adapters[name]) for name in adapter_names]\n\n    # insert one for no adapter\n    adapter_cfgs = [LoraConfig(r=0, target_modules=[])] + adapter_cfgs\n    adapter_names = [None] + adapter_names\n    adapter_id_map = dict(zip(adapter_names, range(len(adapter_names))))\n\n    # target layer name to add adapter\n    target_names = set()\n    for cfg in adapter_cfgs:\n        target_names = target_names.union(cfg.target_modules)\n    target_names = list(target_names)\n    target_names = sorted(target_names)\n\n    target_infos = dict()\n    for _, target_name in enumerate(target_names):\n        # get ranks and scalings\n     
   ranks, scalings = get_ranks_and_scalings(target_name, adapter_cfgs, device=device)\n        # split in case target_name has '.' like 'attention.wo'\n        # which cannot be used as name of a module\n        # and it's not aligned with key in model.packed_modules_mapping\n        target_name = target_name.split('.')[-1]\n        found_mods, pack_idx = find_all_target(model, target_name)\n        sum_rank = ranks.sum().item()\n\n        in_features = 0\n        out_features = 0\n        colwise = True\n        for _, mod in found_mods:\n            assert hasattr(mod, 'lora_adapters')\n            in_features = mod.in_features\n            colwise = mod.colwise\n            if pack_idx is None:\n                base_slice = slice(0, mod.out_features)\n                out_features = mod.out_features\n                lora_b_spliter = getattr(mod, 'weight_spliter_lora_b', None)\n            else:\n                prev_feats = sum(mod.all_out_features[:pack_idx])\n                out_features = mod.all_out_features[pack_idx]\n                base_slice = slice(prev_feats, prev_feats + out_features)\n                lora_b_spliter = None\n            lora_a = torch.empty((sum_rank, in_features), dtype=dtype, device=device)\n            lora_b = torch.empty((sum_rank, out_features), dtype=dtype, device=device)\n\n            lora = LoRA(\n                in_features,\n                out_features,\n                ranks=ranks,\n                scalings=scalings,\n                lora_a=lora_a,\n                lora_b=lora_b,\n                base_slice=base_slice,\n                ctx_mgr=ctx_mgr,\n                colwise=colwise,\n                is_tp=mod.is_tp,\n                lora_b_spliter=lora_b_spliter,\n            )\n            mod.lora_adapters[target_name] = lora\n\n    # fill adapter data\n    for name, path in adapters.items():\n        adapter_id = adapter_id_map[name]\n        checkpoint_path = f'{path}/adapter_model.bin'\n        if not 
 osp.exists(checkpoint_path):\n            checkpoint_path = f'{path}/adapter_model.safetensors'\n        state_dict = load_state_dict(checkpoint_path, map_location=device)\n\n        if hasattr(model, 'load_lora_weights'):\n            model.load_lora_weights(state_dict.items(), adapter_id=adapter_id)\n        else:\n            load_lora_weights(model, state_dict.items(), adapter_id=adapter_id)\n\n    return target_infos\n\n\nBUILD_MODEL_CTX = BuildModelContext()\n\n\n@contextlib.contextmanager\ndef build_model_context(ctx: BuildModelContext):\n    \"\"\"Context manager for building model.\"\"\"\n    global BUILD_MODEL_CTX\n    old_ctx = BUILD_MODEL_CTX\n    ctx = ctx or old_ctx\n    BUILD_MODEL_CTX = ctx\n    try:\n        yield\n    finally:\n        # restore the previous context even if model building raises\n        BUILD_MODEL_CTX = old_ctx\n\n\ndef get_build_model_context() -> BuildModelContext:\n    \"\"\"Get build model context.\"\"\"\n    global BUILD_MODEL_CTX\n    return BUILD_MODEL_CTX\n\n\ndef add_prefix(name: str, prefix: str) -> str:\n    \"\"\"Add prefix to module name.\"\"\"\n    return name if not prefix else f'{prefix}.{name}'\n"
  },
  {
    "path": "lmdeploy/pytorch/models/phi3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul\nfrom lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj\nfrom lmdeploy.pytorch.nn.rotary_embedding import build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1\n\n\nclass Phi3Attention(nn.Module):\n    \"\"\"Rewrite module of Phi3Attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        sliding_window = getattr(config, 'sliding_window', None)\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=sliding_window,\n        
)\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=False,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Phi3MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        
super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Phi3DecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Phi3Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = Phi3MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                
       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Phi3Model(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         
dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Phi3DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass Phi3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    
packed_modules_mapping = {\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = Phi3Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return hidden states.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadata\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n
        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modified from vllm\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            if 'vision_embed_tokens' in name:\n                continue\n            if '.qkv_proj' in name:\n                param = params_dict[name]\n                q, k, v = param.weight_spliter(loaded_weight)\n                load_weight(param, q, shard_id='q')\n                load_weight(param, k, shard_id='k')\n                load_weight(param, v, shard_id='v')\n            elif '.gate_up_proj' in name:\n                param = params_dict[name]\n                gate, up = param.weight_spliter(loaded_weight)\n                load_weight(param, gate, shard_id=0)\n                load_weight(param, up, shard_id=1)\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/phi3_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RopeType\nfrom lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.moe import build_fused_moe\nfrom lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters, build_rotary_embedding,\n                                                  get_rope_parameters, get_rope_theta)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\ndef sparsemixer(scores, top_k, jitter_eps):\n    assert top_k == 2\n\n    with torch.no_grad():\n        # compute mask for sparsity\n        mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)\n        factor = scores.abs().clamp(min=mask_logits_threshold)\n        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)\n\n    # apply mask\n    masked_gates = scores.masked_fill(mask_logits_threshold, float('-inf'))\n    selected_experts = max_ind\n\n    # compute scores for gradients\n    masked_gates = torch.softmax(masked_gates, dim=-1)\n    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)\n    multiplier = multiplier_o\n\n    # masked out first expert\n    masked_scores = torch.scatter(\n        scores,\n        -1,\n        selected_experts,\n        float('-inf'),\n    )\n    with torch.no_grad():\n        # compute mask for sparsity\n        mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)\n        factor = scores.abs().clamp(min=mask_logits_threshold)\n        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)\n\n    # apply mask\n    masked_gates_top2 = 
masked_scores.masked_fill(mask_logits_threshold, float('-inf'))\n    selected_experts_top2 = max_ind\n    # compute scores for gradients\n    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)\n    multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)\n\n    multiplier_top2 = multiplier_top2_o\n\n    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)\n    selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)\n\n    return (\n        multiplier,\n        selected_experts,\n    )\n\n\nclass PhiMoEAttention(nn.Module):\n    \"\"\"PhiMoE attention.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = None\n\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = hidden_size // num_heads\n\n        # qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.sliding_window = getattr(config, 'sliding_window', None)\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            sliding_window=self.sliding_window,\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(num_heads * head_dim,\n                                           hidden_size,\n                                           bias=config.attention_bias,\n                                           quant_config=quantization_config,\n                    
                       dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass PhiMoESparseMoeBlock(nn.Module):\n    \"\"\"PhiMoE sparse moe block.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.intermediate_size\n        self.num_experts = config.num_local_experts\n        self.top_k = config.num_experts_per_tok\n\n        self.gate = build_rowwise_linear(\n            self.hidden_dim,\n            self.num_experts,\n            
bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=2,\n            renormalize=False,\n            dtype=dtype,\n            device=device,\n            all_reduce=True,\n        )\n\n        self.router_jitter_noise = config.router_jitter_noise\n        self.input_jitter_noise = config.input_jitter_noise\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        if self.input_jitter_noise > 0:\n            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.input_jitter_noise,\n                                                                      1.0 + self.input_jitter_noise)\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        router_logits = self.gate(hidden_states)\n\n        topk_weights, topk_ids = sparsemixer(router_logits, top_k=2, jitter_eps=self.router_jitter_noise)\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n        return out_states, router_logits\n\n\nclass PhiMoEDecoderLayer(nn.Module):\n    \"\"\"PhiMoE decoder layer.\"\"\"\n\n    def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n\n        # build attention layer\n        self.self_attn = PhiMoEAttention(config, dtype=dtype, device=device)\n        self.block_sparse_moe = PhiMoESparseMoeBlock(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n\n        # 
build attention layer norm\n        self.post_attention_layernorm = LayerNorm(config.hidden_size,\n                                                  eps=config.rms_norm_eps,\n                                                  dtype=dtype,\n                                                  device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states, _ = self.block_sparse_moe(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass PhiMoEModel(nn.Module):\n    \"\"\"PhiMoE model.\"\"\"\n\n    def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n        self.layers = nn.ModuleList([\n            PhiMoEDecoderLayer(config, 
layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = LayerNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        emb_type = RopeType.LinearScaling\n        rope_dim = config.hidden_size // config.num_attention_heads\n        rope_max_pos_emb = config.max_position_embeddings\n        rope_base = get_rope_theta(config)\n        rope_scaling = get_rope_parameters(config)\n        if rope_scaling is not None:\n            scaling_type = rope_scaling['type']\n            assert scaling_type in ['longrope', 'su']\n            emb_type = RopeType.LongRoPEScaling\n            ori_pos_emb = getattr(config, 'original_max_position_embeddings', rope_max_pos_emb)\n            longrope_params = LongRoPEScalingParameters(short_factor=rope_scaling['short_factor'],\n                                                        long_factor=rope_scaling['long_factor'],\n                                                        original_max_position_embeddings=ori_pos_emb,\n                                                        short_mscale=rope_scaling['short_mscale'],\n                                                        long_mscale=rope_scaling['long_mscale'])\n            self.rotary_emb = build_rotary_embedding(\n                rope_dim,\n                rope_max_pos_emb,\n                rope_base,\n                longrope_params=longrope_params,\n                emb_type=emb_type,\n            )\n        else:\n            self.rotary_emb = build_rotary_embedding(\n                rope_dim,\n                rope_max_pos_emb,\n                rope_base,\n                emb_type=emb_type,\n            )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n     
   attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass PhiMoEForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"Mixture model for causalLM.\"\"\"\n\n    def __init__(self,\n                 config: Any,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        self.model = PhiMoEModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=config.lm_head_bias,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        
attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n        ]\n\n        num_experts = self.config.num_local_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id, 'up')\n            down_param = 
('.experts.down', f'.experts.{exp_id}.w2', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n                break\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/phi3_v.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .phi3 import Phi3ForCausalLM, Phi3Model\nfrom .utils.model import vlm_model\n\nCLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0,\n                                                     dropout=0.0,\n                                                     hidden_act='quick_gelu',\n                                                     hidden_size=1024,\n                                                     image_size=336,\n                                                     initializer_factor=1.0,\n                                                     initializer_range=0.02,\n                                                     intermediate_size=4096,\n                                                     layer_norm_eps=1e-05,\n                                                     num_attention_heads=16,\n                                                     num_channels=3,\n                                                     num_hidden_layers=24,\n                                                     patch_size=14,\n                                                     projection_dim=768)\n\n\n@vlm_model\nclass Phi3ImageEmbedding(nn.Module):\n    \"\"\"Image embedding.\"\"\"\n\n    # from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/c45209e90a4c4f7d16b2e9d48503c7f3e83623ed/image_embedding_phi3_v.py#L83 # noqa: E501\n    def __init__(self,\n                 config: 
PretrainedConfig,\n                 wte=None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 **kwargs):\n        super().__init__()\n        self.config = config\n        hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size\n\n        self.wte = wte\n\n        if (isinstance(config.img_processor, dict) and config.img_processor.get('name', None) == 'clip_vision_model'):\n            assert 'model_name' in config.img_processor, ('model_name must be provided for CLIPVisionModel')\n            assert 'image_dim_out' in config.img_processor, ('image_dim_out must be provided for CLIPVisionModel')\n            assert 'num_img_tokens' in config.img_processor, ('num_img_tokens must be provided for CLIPVisionModel')\n            assert config.img_processor['model_name'] == 'openai/clip-vit-large-patch14-336'\n            clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG\n            self.img_processor = CLIPVisionModel(clip_config).to(device).to(dtype)\n            image_dim_out = config.img_processor['image_dim_out']\n            self.num_img_tokens = config.img_processor['num_img_tokens']\n        else:\n            raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented')\n\n        self.image_dim_out = image_dim_out\n        self.img_sizes = None\n\n        self.use_hd_transform = kwargs.get('use_hd_transform', False)\n        self.with_learnable_separator = kwargs.get('with_learnable_separator', False)\n        self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub')\n        # with_hd_transform and with_learnable_separator should have same value\n        assert (self.use_hd_transform == self.with_learnable_separator), (\n            'use_hd_transform and with_learnable_separator '\n            'should have same value')\n        if self.with_learnable_separator:\n            assert self.use_hd_transform, ('learnable separator is only for hd 
transform')\n            # 1024 * 4, merge spatial to channel dimension\n            self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4], dtype=dtype, device=device))\n            self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4], dtype=dtype, device=device))\n\n        projection_cls = kwargs.get('projection_cls', 'linear')\n        if projection_cls == 'linear':\n            self.img_projection = nn.Linear(image_dim_out, hidden_size, dtype=dtype, device=device)\n        elif projection_cls == 'mlp' and self.use_hd_transform:\n            dim_projection = hidden_size\n            depth = 2\n            layers = [nn.Linear(image_dim_out * 4, dim_projection, dtype=dtype, device=device)]\n            for _ in range(1, depth):\n                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection, dtype=dtype, device=device)])\n            self.img_projection = nn.Sequential(*layers)\n        elif projection_cls == 'mlp':\n            dim_projection = hidden_size\n            depth = 2\n            layers = [nn.Linear(image_dim_out, dim_projection, dtype=dtype, device=device)]\n            for _ in range(1, depth):\n                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection, dtype=dtype, device=device)])\n            self.img_projection = nn.Sequential(*layers)\n        else:\n            raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented')\n\n        self.vocab_size = config.vocab_size\n        self.img_features = None\n\n        if isinstance(config.img_processor, dict):\n            self.layer_idx = config.img_processor.get('layer_idx', -2)\n            self.type_feature = config.img_processor.get('type_feature', 'patch')\n        else:\n            self.layer_idx = -2\n            self.type_feature = 'patch'\n\n    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:\n        LAYER_IDX = self.layer_idx\n        TYPE_FEATURE = 
self.type_feature\n\n        img_processor_output = self.img_processor(img_embeds, output_hidden_states=True)\n        img_feature = img_processor_output.hidden_states[LAYER_IDX]\n\n        if TYPE_FEATURE == 'patch':\n            patch_feature = img_feature[:, 1:]\n            return patch_feature\n\n        if TYPE_FEATURE == 'cls_patch':\n            return img_feature\n\n        raise NotImplementedError\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,\n        pixel_values: torch.FloatTensor,\n        image_sizes=None,\n        image_mask: torch.Tensor = None,\n    ) -> torch.FloatTensor:\n        \"\"\"forward.\"\"\"\n        inputs_embeds = self.wte(input_ids)\n        assert self.use_hd_transform\n        num_images, num_crops, c, h, w = pixel_values.shape\n        assert c == 3 and h == w == 336\n        img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape(num_images, num_crops, -1,\n                                                                                 self.image_dim_out)\n        image_features_proj = self.hd_feature_transform(img_features, image_sizes)\n        # update image feature to inputs_embeds\n        inputs_embeds.masked_scatter_(image_mask[..., None], image_features_proj)\n        return inputs_embeds\n\n    def hd_feature_transform(self, image_features, image_sizes):\n        \"\"\"\n        image_features: (num_images, num_crops+1, 24*24, 1024)\n        \"\"\"\n        assert (self.hd_transform_order == 'sub_glb'), f'hd_transform_order `{self.hd_transform_order}` not implemented'\n        if isinstance(self.img_projection, nn.Sequential):\n            target_device = self.img_projection[0].bias.device\n            target_dtype = self.img_projection[0].bias.dtype\n        else:  # It's a single nn.Linear layer\n            target_device = self.img_projection.bias.device\n            target_dtype = self.img_projection.bias.dtype\n\n        global_image_features = image_features[:, 0]  # 
(num_images, 24*24, 1024)\n        # global feature can be viewed as a special HD case with num_crops 1x1\n        global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)\n        global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)\n\n        all_image_embeddings = []\n        # need a for loop to process each image because of different image sizes\n        # (patch arrangement is different for each image)\n        for i, img_size in enumerate(image_sizes):\n            h, w = img_size\n            h_crop = h // 336\n            w_crop = w // 336\n            num_crops = h_crop * w_crop\n\n            # NOTE: real num_crops is padded\n            # (num_crops, 24*24, 1024)\n            sub_image_features = image_features[i, 1:1 + num_crops]\n            sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop)\n            sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)\n\n            # [sub features, separator, global features]\n            all_image_embeddings.extend([\n                sub_image_features_hd_newline.squeeze(0),  # (h_crop*12*(w_crop*12+1), 4096)\n                self.glb_GN.squeeze(0),\n                global_image_features_hd_newline[i],\n            ])\n\n        image_features_proj = self.img_projection(\n            torch.cat(all_image_embeddings, dim=0).to(target_device).to(target_dtype))\n\n        return image_features_proj\n\n    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):\n        \"\"\"\n        image_features: (num_images*num_crops, 24*24, 1024)\n        output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops\n        \"\"\"\n        N, L, C = image_features.shape\n        assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0\n        num_images = N // (h_crop * w_crop)\n        H = int(L**0.5)\n        image_features_hd = (\n            
image_features.reshape(N, H, H, C)  # N, 24, 24, 1024\n            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024\n            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024\n            .reshape(N, -1, 4 * C)  # N, 144, 4096\n            .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1)  # n_img, h_crop, w_crop, 12, 12, 4096\n            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096\n            .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C)  # n_img, h_crop*12, w_crop*12, 4096\n        )\n        return image_features_hd\n\n    def add_image_newline(self, image_features_hd):\n        \"\"\"\n        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)\n        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)\n        \"\"\"\n        num_images, h, w, hid_dim = image_features_hd.shape\n        # add the newline token to the HD image feature patches\n        newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1)  # (n_img, h, 1, hid_dim)\n        image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings],\n                                              dim=2).reshape(num_images, -1, hid_dim)\n        return image_features_hd_newline\n\n\nclass Phi3VModel(Phi3Model):\n    \"\"\"Phi3v model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__(config=config, dtype=dtype, device=device)\n\n        self.vision_embed_tokens = None\n        if isinstance(config.embd_layer, dict):\n            # vision embedding layer\n            embedding_config = {'embedding_cls': config.embd_layer['embedding_cls'], **config.embd_layer}\n            self.vision_embed_tokens = Phi3ImageEmbedding(config,\n                                                          wte=self.embed_tokens,\n                                                          dtype=dtype,\n                                              
            device=device,\n                                                          **embedding_config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.LongTensor] = None,\n        image_mask: torch.Tensor = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        if inputs_embeds is None and pixel_values is not None:\n            inputs_embeds = self.vision_embed_tokens(\n                input_ids,\n                pixel_values,\n                image_sizes,\n                image_mask,\n            )\n\n        return super().forward(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n\nclass Phi3VForCausalLM(Phi3ForCausalLM):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__(config, ctx_mgr, dtype=dtype, device=device)\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = Phi3VModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n        self.input_processor = Phi3VInputProcessor(config, dtype)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        
attn_metadata: Any = None,\n        pixel_values: torch.Tensor = None,\n        image_sizes: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"forward.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            pixel_values=pixel_values,\n            image_sizes=image_sizes,\n            image_mask=image_mask,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        output = super().prepare_inputs_for_generation(past_key_values=past_key_values,\n                                                       inputs_embeds=inputs_embeds,\n                                                       context=context)\n\n        # vision inputs\n        pixel_values = None\n        if context.input_multimodals is not None:\n            input_mms = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            input_mms = [data for im_data in input_mms for data in im_data]\n            if len(input_mms) > 0:\n                pixel_values = torch.cat([data.data for data in input_mms])\n                image_sizes = torch.cat([data.meta['image_sizes'] for data in input_mms])\n                image_token_id = input_mms[0].meta['image_token_id']\n                image_mask = output['input_ids'] == image_token_id\n                output['pixel_values'] = pixel_values\n                output['image_sizes'] = image_sizes\n                output['image_mask'] = image_mask\n\n        return output\n\n    def 
load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        import itertools\n\n        vis_prefix = 'vision_embed_tokens.'\n        # create two iterators from weights for llm and vlm\n        llm_weights, vlm_weights = itertools.tee(weights, 2)\n        llm_weights = ((name, tensor) for name, tensor in llm_weights if vis_prefix not in name)\n        vlm_weights = ((name, tensor) for name, tensor in vlm_weights if vis_prefix in name)\n        super().load_weights(llm_weights)\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in vlm_weights:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass Phi3VInputProcessor(BaseModelInputProcessor):\n    \"\"\"Phi3V input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype) -> None:\n        self.config = config\n        self.dtype = dtype\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values']\n            image_sizes = input_mm['image_sizes']\n            offset = input_mm['offset']\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=offset,\n   
                                  end=offset + num_pad,\n                                     meta=dict(image_sizes=image_sizes, image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/q_modules.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom dataclasses import dataclass, fields\n\nimport torch\nimport torch.nn as nn\n\nfrom ..kernels.w8a8_triton_kernels import (matmul_kernel_dynamic_quant, per_channel_quant, per_token_quant_int8,\n                                           rms_norm_dynamic_quant)\n\n\n@dataclass\nclass QTensor:\n    \"\"\"A data class representing a Quantized Tensor.\n\n    This class wraps around a regular PyTorch tensor and adds quantization-specific parameters.\n    \"\"\"\n    tensor: torch.Tensor\n    scale: torch.Tensor\n    zero_point: torch.Tensor = None\n\n    def __post_init__(self):\n        self.fields = [field.name for field in fields(self)]\n\n    def __getattr__(self, name: str):\n        \"\"\"Allows attribute access to be forwarded to the wrapped tensor when\n        the attribute doesn't exist in QTensor.\"\"\"\n        if name in self.fields:\n            return super().__getattr__(name)\n        return getattr(self.tensor, name)\n\n\nclass QRMSNorm(nn.Module):\n    \"\"\"It performs traditional RMS normalization and then quantizes the output\n    to 8-bit integers.\"\"\"\n\n    def __init__(self, hidden_size, eps=1e-6, quant_dtype=torch.int8):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n        self.quant_dtype = quant_dtype\n\n    @classmethod\n    def from_float(cls, mod: nn.Module, initialization: bool = True, quant_dtype=torch.int8):\n        \"\"\"Class method to create a QRMSNorm instance from a floating-point\n        module.\n\n        `initialization = True` for real init. 
`initialization = False` for dummy init.\n        \"\"\"\n        hidden_size = mod.weight.shape[0]\n        eps = mod.variance_epsilon\n        q_mod = cls(hidden_size, eps, quant_dtype=quant_dtype)\n        if initialization:\n            q_mod.weight = nn.Parameter(mod.weight.detach())\n        return q_mod\n\n    def forward(self, hidden_states):\n        \"\"\"Defines the computation performed at every call.\n\n        Performs RMS normalization followed by dynamic quantization on hidden_states. Returns a QTensor which wraps the\n        quantized tensor along with its scale factor.\n        \"\"\"\n        hidden_states_quant, rms_scale = rms_norm_dynamic_quant(hidden_states,\n                                                                self.weight,\n                                                                self.variance_epsilon,\n                                                                quant_dtype=self.quant_dtype)\n        return QTensor(hidden_states_quant, rms_scale)\n\n\nclass QLinear(nn.Module):\n    \"\"\"A Linear layer that operates on quantized inputs and weights.\n\n    It performs matrix multiplication in 8-bit precision and dequantizes the results back to float.\n    \"\"\"\n\n    __constants__ = ['in_features', 'out_features']\n    in_features: int\n    out_features: int\n    weight: torch.Tensor\n\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 bias: bool = True,\n                 device=None,\n                 dtype=None,\n                 quant_dtype=torch.int8) -> None:\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.quant_dtype = quant_dtype\n        self.register_buffer('weight', torch.empty((out_features, in_features), device=device, dtype=quant_dtype))\n        self.register_buffer('scale', torch.empty((out_features, 1), 
device=device, dtype=torch.float32))\n        if bias:\n            self.register_buffer('bias', torch.empty(out_features, **factory_kwargs))\n        else:\n            self.register_parameter('bias', None)\n\n    @classmethod\n    def from_float(cls, mod: nn.Module, initialization: bool = True, quant_dtype=torch.int8):\n        \"\"\"Class method to create a QLinear instance from a floating-point\n        module.\n\n        `initialization = True` for real init. `initialization = False` for dummy init.\n        \"\"\"\n        q_mod = cls(mod.in_features,\n                    mod.out_features,\n                    mod.bias is not None,\n                    device=mod.weight.device,\n                    dtype=mod.weight.dtype,\n                    quant_dtype=quant_dtype)\n\n        if initialization:\n            weight_quant, scale = per_channel_quant(mod.weight.detach(), quant_dtype)\n            q_mod.weight.data = weight_quant\n            q_mod.scale = scale\n\n        if mod.bias is not None:\n            q_mod.bias.data = mod.bias.detach()\n        return q_mod\n\n    def forward(self, input):\n        \"\"\"Defines the computation performed at every call.\n\n        Performs quantization if the input is a tensor, otherwise it assumes the input is already quantized (instance of\n        QTensor). Then, it performs linear transformation using dynamic quantization method, resulting in an 8-bit\n        integer output. 
Finally, it dequantizes the result back to a floating point tensor.\n        \"\"\"\n\n        if isinstance(input, torch.Tensor):\n            input_quant, input_scale = per_token_quant_int8(input, 1e-7, quant_dtype=self.quant_dtype)\n        else:\n            assert isinstance(input, QTensor)\n            input_quant, input_scale = input.tensor, input.scale\n\n        out = matmul_kernel_dynamic_quant(input_quant,\n                                          self.weight,\n                                          input_scale,\n                                          self.scale,\n                                          output_dtype=torch.float16,\n                                          bias=self.bias)\n        return out\n\n    def extra_repr(self) -> str:\n        return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, self.bias\n                                                                 is not None)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass QWenAttention(torch.nn.Module):\n    \"\"\"Parallel self-attention layer abstract class.\n\n    Self-attention layer takes input with size [s, b, h] and returns output of the same size.\n    \"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        self.hidden_size = config.hidden_size\n        self.split_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.projection_size = config.kv_channels * config.num_attention_heads\n        self.num_attention_heads = config.num_attention_heads\n        self.num_kv_heads = self.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.c_attn = build_qkv_proj(\n            config.hidden_size,\n            num_q_heads=self.num_attention_heads,\n            num_kv_heads=self.num_kv_heads,\n            head_size=self.head_dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # apply rotary\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        
self.attn_fwd = Attention(\n            self.num_attention_heads,\n            self.head_dim,\n            num_kv_heads=self.num_kv_heads,\n        )\n\n        # o_proj\n        self.c_proj = build_o_proj(self.projection_size,\n                                   config.hidden_size,\n                                   bias=not config.no_bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.c_attn(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        (query_states, key_states, value_states) = self.c_attn.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.c_proj(attn_output)\n        return attn_output\n\n\nclass 
QWenMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        ff_dim_in = config.intermediate_size // 2\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [ff_dim_in, ff_dim_in],\n            bias=not config.no_bias,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.c_proj = build_down_linear(ff_dim_in,\n                                        config.hidden_size,\n                                        bias=not config.no_bias,\n                                        quant_config=quantization_config,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.c_proj(act)\n\n\nclass QWenBlock(torch.nn.Module):\n    \"\"\"A single transformer layer.\n\n    Transformer layer takes input with size [s, b, h] and returns an output of the same size.\n    \"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_number: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_number = layer_number\n        hidden_size = config.hidden_size\n        self.bf16 = config.bf16\n\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.attn = QWenAttention(config, dtype=dtype, 
device=device)\n\n        # build MLP\n        self.mlp = QWenMLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.ln_1 = RMSNorm(hidden_size,\n                            config.layer_norm_epsilon,\n                            quant_config=quantization_config,\n                            dtype=dtype,\n                            device=device)\n\n        # build attention layer norm\n        self.ln_2 = RMSNorm(hidden_size,\n                            config.layer_norm_epsilon,\n                            quant_config=quantization_config,\n                            dtype=dtype,\n                            device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            layernorm_output = self.ln_1(hidden_states)\n        else:\n            layernorm_output, residual = self.ln_1(hidden_states, residual)\n\n        # Self Attention\n        layernorm_input = self.attn(\n            hidden_states=layernorm_output,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        layernorm_output, residual = self.ln_2(layernorm_input, residual)\n        mlp_output = self.mlp(layernorm_output)\n\n        outputs = (mlp_output, residual)\n        return outputs\n\n\nclass QWenModel(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.vocab_size = config.vocab_size\n        self.embed_dim = config.hidden_size\n        self.wte = nn.Embedding(self.vocab_size, self.embed_dim, dtype=dtype, 
device=device)\n\n        # build all decode layers\n        self.h = nn.ModuleList(\n            [QWenBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.num_hidden_layers)])\n\n        # build rotary embedding\n        emb_type = RopeType.LinearScaling\n        if config.rotary_pct == 1.0:\n            self.rotary_ndims = None\n        else:\n            assert config.rotary_pct < 1\n            self.rotary_ndims = int(config.kv_channels * config.rotary_pct)\n        rope_dim = (self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels)\n        rope_max_pos_emb = getattr(config, 'max_position_embeddings', 4096)\n        rope_base = config.rotary_emb_base\n        self.rotary_emb = build_rotary_embedding(\n            rope_dim,\n            rope_max_pos_emb,\n            rope_base,\n            emb_type=emb_type,\n        )\n\n        self.ln_f = RMSNorm(self.embed_dim, eps=config.layer_norm_epsilon, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.wte(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.h):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n            
    residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, residual = self.ln_f(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.wte\n\n\nclass QWenLMHeadModel(nn.Module, CudaGraphMixin):\n    \"\"\"Rewrote model.\"\"\"\n\n    packed_modules_mapping = {\n        'gate_up_proj': [\n            'w2',\n            'w1',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build Model\n        self.transformer = QWenModel(config, dtype=dtype, device=device)\n\n        # output_layers\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.transformer(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return 
self.lm_head(hidden_states)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.transformer.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.gate_up_proj', '.w2', 0),\n            ('.gate_up_proj', '.w1', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'visual' in name:\n                continue\n            if 'rotary_pos_emb.inv_freq' in name:\n                continue\n            if ('rotary_pos_emb.cos_cached' in name or 'rotary_pos_emb.sin_cached' in name):\n                continue\n  
          if (self.config.tie_word_embeddings and 'lm_head.weight' in name):\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '.c_attn' in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n\n\nclass Qwen2Attention(nn.Module):\n    \"\"\"Rewrite module of Qwen2Attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(hidden_size,\n                                       num_q_heads=num_heads,\n                                       num_kv_heads=num_key_value_heads,\n                                       head_size=head_dim,\n                                       bias=True,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device,\n                                       num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = 
ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=config.sliding_window if config.use_sliding_window else None,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=False,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = 
attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Qwen2MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Qwen2DecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Qwen2Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n      
  self.mlp = Qwen2MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Qwen2Model(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        
super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = build_embedding(\n            config.vocab_size,\n            config.hidden_size,\n            self.padding_idx,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Qwen2DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        
return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass Qwen2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = Qwen2Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n 
       # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param 
= params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen2_5_vl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from:\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.models.qwen2_vl import Qwen2Model\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, RMSNorm, SiluAndMul\nfrom lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import add_prefix\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, vlm_model\n\n\nclass Qwen2_5_PatchEmbed(nn.Module):\n    \"\"\"Patch Embed.\"\"\"\n\n    def __init__(self,\n                 patch_size: int = 14,\n                 temporal_patch_size: int = 2,\n                 in_channels: int = 3,\n                 embed_dim: int = 1152,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None) -> None:\n        super().__init__()\n        self.patch_size = patch_size\n        self.temporal_patch_size = temporal_patch_size\n        self.in_channels = in_channels\n        self.embed_dim = embed_dim\n\n        kernel_size = [temporal_patch_size, patch_size, patch_size]\n        self.proj = nn.Conv3d(in_channels,\n                              embed_dim,\n                              kernel_size=kernel_size,\n                              stride=kernel_size,\n                              bias=False,\n                 
             dtype=dtype,\n                              device=device)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        target_dtype = self.proj.weight.dtype\n        hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size,\n                                           self.patch_size)\n        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)\n        return hidden_states\n\n\nclass Qwen2_5_VisionRotaryEmbedding(nn.Module):\n    \"\"\"Vision rotary embedding.\"\"\"\n\n    def __init__(self, dim: int, theta: float = 10000.0, device: torch.device = None) -> None:\n        super().__init__()\n        inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))\n        self.register_buffer('inv_freq', inv_freq, persistent=False)\n\n    def forward(self, seqlen: int) -> torch.Tensor:\n        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n        freqs = torch.outer(seq, self.inv_freq)\n        return freqs\n\n\nclass Qwen2_5_VLVisionAttention(nn.Module):\n    \"\"\"Vision attention.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        dim = config.hidden_size\n        num_heads = config.num_heads\n        head_dim = dim // num_heads\n        self.head_dim = head_dim\n\n        # packed qkv\n        self.qkv = build_qkv_proj(dim,\n                                  num_q_heads=num_heads,\n                                  num_kv_heads=num_heads,\n                                  head_size=head_dim,\n                                  bias=True,\n                                  quant_config=quantization_config,\n                    
              dtype=dtype,\n                                  device=device,\n                                  prefix=add_prefix('qkv', prefix))\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attention = FlashAttention(\n            num_heads,\n            head_dim,\n            causal=False,\n        )\n\n        # o_proj\n        self.proj = build_rowwise_linear(\n            dim,\n            dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            prefix=add_prefix('proj', prefix),\n        )\n\n    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,\n                rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor]) -> torch.Tensor:\n        seq_length = hidden_states.shape[0]\n        # qkv proj\n        qkv_states = self.qkv(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        q, k, v = self.qkv.split_qkv(qkv_states)\n\n        cos, sin = rotary_pos_emb\n        q, k = self.apply_rotary_pos_emb(q, k, cos, sin)\n\n        attn_output = self.attention(\n            q,\n            k,\n            v,\n            q_start_loc=cu_seqlens[:-1],\n            q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1],\n        )\n\n        attn_output = attn_output.reshape(seq_length, -1)\n\n        # o proj\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Qwen2_5_VLMLP(nn.Module):\n    \"\"\"Vision mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            in_features=config.hidden_size,\n            
all_out_features=[config.intermediate_size, config.intermediate_size],\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(in_features=config.intermediate_size,\n                                              out_features=config.hidden_size,\n                                              bias=True,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        return self.down_proj(self.act_fn(self.gate_up_proj(x)))\n\n\nclass Qwen2_5_VLVisionBlock(nn.Module):\n    \"\"\"Vision block.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.norm1 = RMSNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)\n        self.norm2 = RMSNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)\n\n        self.attn = Qwen2_5_VLVisionAttention(config, dtype=dtype, device=device)\n\n        self.mlp = Qwen2_5_VLMLP(config, dtype=dtype, device=device)\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                cu_seqlens: torch.Tensor,\n                rotary_pos_emb: Optional[torch.Tensor] = None) -> torch.Tensor:\n        hidden_states = hidden_states + self.attn(\n            self.norm1(hidden_states),\n            cu_seqlens=cu_seqlens,\n            rotary_pos_emb=rotary_pos_emb,\n        )\n        
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))\n        return hidden_states\n\n\nclass Qwen2_5_VLPatchMerger(nn.Module):\n    \"\"\"Qwen2_5_VLPatchMerger.\"\"\"\n\n    def __init__(self,\n                 dim: int,\n                 context_dim: int,\n                 spatial_merge_size: int = 2,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None) -> None:\n        super().__init__()\n        self.hidden_size = context_dim * (spatial_merge_size**2)\n        self.ln_q = RMSNorm(context_dim, eps=1e-6, dtype=dtype, device=device)\n\n        self.mlp = nn.Sequential(\n            nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device),\n            nn.GELU(),\n            nn.Linear(self.hidden_size, dim, dtype=dtype, device=device),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))\n        return x\n\n\n@vlm_model\nclass Qwen2_5_VisionTransformerPretrainedModel(nn.Module):\n    \"\"\"Vision transformer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.spatial_merge_size = config.spatial_merge_size\n        self.patch_size = config.patch_size\n        self.fullatt_block_indexes = config.fullatt_block_indexes\n        self.window_size = config.window_size\n        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size\n\n        self.patch_embed = Qwen2_5_PatchEmbed(\n            patch_size=config.patch_size,\n            temporal_patch_size=config.temporal_patch_size,\n            in_channels=config.in_channels,\n            embed_dim=config.hidden_size,\n            dtype=dtype,\n            device=device,\n        )\n\n        head_dim = config.hidden_size // config.num_heads\n        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2, 
device=device)\n\n        self.blocks = nn.ModuleList(\n            [Qwen2_5_VLVisionBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.depth)])\n        self.merger = Qwen2_5_VLPatchMerger(dim=config.out_hidden_size,\n                                            context_dim=config.hidden_size,\n                                            spatial_merge_size=config.spatial_merge_size,\n                                            dtype=dtype,\n                                            device=device)\n\n    def rot_pos_emb(self, grid_thw):\n        \"\"\"Rotary position embedding.\"\"\"\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n        pos_ids = torch.cat(pos_ids, dim=0)\n        max_grid_size = grid_thw[:, 1:].max()\n        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n        return rotary_pos_emb\n\n    def get_window_index(self, grid_thw):\n        window_index: list = []\n        cu_window_seqlens: list = [0]\n        window_index_id = 0\n        vit_merger_window_size = 
self.window_size // self.spatial_merge_size // self.patch_size\n\n        for grid_t, grid_h, grid_w in grid_thw:\n            llm_grid_h, llm_grid_w = (\n                grid_h // self.spatial_merge_size,\n                grid_w // self.spatial_merge_size,\n            )\n            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)\n            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size\n            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size\n            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size\n            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size\n            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)\n            index_padded = index_padded.reshape(\n                grid_t,\n                num_windows_h,\n                vit_merger_window_size,\n                num_windows_w,\n                vit_merger_window_size,\n            )\n            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(\n                grid_t,\n                num_windows_h * num_windows_w,\n                vit_merger_window_size,\n                vit_merger_window_size,\n            )\n            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)\n            index_padded = index_padded.reshape(-1)\n            index_new = index_padded[index_padded != -100]\n            window_index.append(index_new + window_index_id)\n            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]\n            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())\n            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()\n        window_index = torch.cat(window_index, dim=0)\n\n        return window_index, cu_window_seqlens\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                cu_seqlens: torch.Tensor,\n                rotary_pos_emb: 
torch.Tensor,\n                window_index: torch.Tensor = None,\n                cu_window_seqlens: List = None) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.patch_embed(hidden_states)\n\n        # for window-based attention\n        seq_len, _ = hidden_states.size()\n        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)\n        hidden_states = hidden_states[window_index, :, :]\n        hidden_states = hidden_states.reshape(seq_len, -1)\n        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)\n        rotary_pos_emb = rotary_pos_emb[window_index, :, :]\n        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)\n        rotary_pos_emb = rotary_pos_emb.repeat(1, 2)\n        rotary_pos_emb = (rotary_pos_emb.cos(), rotary_pos_emb.sin())\n\n        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)\n\n        for layer_num, blk in enumerate(self.blocks):\n            if layer_num in self.fullatt_block_indexes:\n                cu_seqlens_now = cu_seqlens\n            else:\n                cu_seqlens_now = cu_window_seqlens\n\n            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)\n\n        hidden_states = self.merger(hidden_states)\n        reverse_indices = torch.argsort(window_index)\n        hidden_states = hidden_states[reverse_indices, :]\n\n        return hidden_states\n\n\nclass Qwen2_5_VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n         
        dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # preprocessor\n        self.input_processor = Qwen2_5_VLInputProcessor(self.config)\n\n        # build vision model\n        self.visual = Qwen2_5_VisionTransformerPretrainedModel(\n            config.vision_config,\n            dtype=dtype,\n            device=device,\n        )\n        # get text_config\n        text_config = getattr(config, 'text_config', config)\n        # build model\n        self.model = Qwen2Model(text_config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        mrope_position_ids: torch.Tensor = None,\n        pixel_values: torch.Tensor = None,\n        vis_cu_seqlens: torch.Tensor = None,\n        vis_pos_emb: torch.Tensor = None,\n        window_index: torch.Tensor = None,\n        cu_window_seqlens: List = None,\n        image_mask: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n            if pixel_values is not None:\n                dtype = inputs_embeds.dtype\n                pixel_values = pixel_values.to(dtype)\n                image_embeds = self.visual(pixel_values,\n                                           cu_seqlens=vis_cu_seqlens,\n                                           rotary_pos_emb=vis_pos_emb.to(dtype),\n                                           window_index=window_index,\n                                           
cu_window_seqlens=cu_window_seqlens)\n                inputs_embeds = inputs_embeds.masked_scatter(image_mask[..., None], image_embeds)\n\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        pixel_values = None\n        vis_cu_seqlens = None\n        vis_pos_emb = None\n        image_mask = None\n        window_index = None\n        cu_window_seqlens = None\n        if context.input_multimodals is not None:\n            image_data = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n\n            if len(image_data) > 0:\n                # flatten batch\n                image_data = [data for im_data in image_data for data in im_data]\n                pixel_values = torch.cat([data.data for data in image_data])\n                image_token_id = image_data[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu()\n                vis_pos_emb = self.visual.rot_pos_emb(grid_thw)\n\n                # calculation for window-based attention\n                window_index, 
cu_window_seqlens = self.visual.get_window_index(grid_thw)\n                cu_window_seqlens = torch.tensor(\n                    cu_window_seqlens,\n                    device=pixel_values.device,\n                    dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,\n                )\n                cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)\n\n                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],\n                                                         grid_thw[:, 0]).to(pixel_values.device)\n                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)\n\n        mrope_position_ids = getattr(context, 'mrope_position_ids', None)\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n            pixel_values=pixel_values,\n            vis_cu_seqlens=vis_cu_seqlens,\n            vis_pos_emb=vis_pos_emb,\n            window_index=window_index,\n            cu_window_seqlens=cu_window_seqlens,\n            image_mask=image_mask,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            
('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '.qkv.' 
in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n        max_tokens = graph_meta.max_tokens\n\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)\n\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n\n        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n\n        input_ids = kwargs.get('input_ids')\n        num_tokens = input_ids.size(-1)\n        new_batch_size = graph_meta.max_batchs\n\n        is_decoding = graph_meta.is_decoding\n        input_buffers = graph_meta.input_buffers\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids\n            if is_decoding:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]\n            else:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']\n\n        return new_inputs\n\n    def _get_model_metas(self, context: StepContext):\n        \"\"\"Get model metas.\"\"\"\n        model_metas = 
context.model_metas\n        if model_metas is None:\n            batch_size = context.q_seqlens.numel()\n            return [dict(mrope_delta=0)] * batch_size\n        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]\n\n    def _update_model_meta_decoding(self, context: StepContext):\n        \"\"\"Update model meta for decoding.\"\"\"\n        model_metas = self._get_model_metas(context)\n        position_ids = context.position_ids\n\n        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]\n        mrope_deltas = position_ids.new_tensor(mrope_deltas)\n        mrope_position_ids = position_ids + mrope_deltas[None]\n        mrope_position_ids = mrope_position_ids.expand(3, -1)\n\n        context.mrope_position_ids = mrope_position_ids\n        return model_metas\n\n    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):\n        \"\"\"Get mrope ids.\"\"\"\n        t, h, w = grid_thw\n        h //= 2\n        w //= 2\n        stride = torch.tensor([h * w, w, 1], device=device)[:, None]\n        size = torch.tensor([t, h, w], device=device)[:, None]\n        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)\n        pos_ids = pos_ids // stride % size\n        return pos_ids\n\n    def _update_model_meta_prefilling(self, context: StepContext):\n        \"\"\"Update model meta for prefilling.\"\"\"\n        model_metas = self._get_model_metas(context)\n        input_multimodals = context.input_multimodals\n        if input_multimodals is None:\n            input_multimodals = [None] * len(model_metas)\n        position_ids = context.position_ids\n        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())\n        mrope_position_ids = []\n        new_model_metas = []\n        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):\n            images = []\n            if input_mm is not None:\n                images = 
input_mm.get('image', [])\n            if model_meta is None or 'mrope_delta' not in model_meta:\n                mrope_delta = 0\n            else:\n                mrope_delta = model_meta['mrope_delta']\n\n            pos_start = pos_ids[0].item()\n            mrope_pos_ids = pos_ids + mrope_delta\n            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()\n            for img in images:\n                grid_thw = img.meta['grid_thw'][0].tolist()\n                _, h, w = grid_thw\n                h //= 2\n                w //= 2\n                num_pad = img.end - img.start - max(h, w)\n                mrope_delta -= num_pad\n                fill_start = img.start - pos_start\n                fill_end = img.end - pos_start\n                img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)\n                img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]\n                mrope_pos_ids[:, fill_end:] -= num_pad\n                mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids\n\n            mrope_position_ids.append(mrope_pos_ids)\n            new_model_metas.append(dict(mrope_delta=mrope_delta))\n\n        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)\n        context.mrope_position_ids = mrope_position_ids\n\n        return new_model_metas\n\n    def update_model_metas(self,\n                           past_key_values: List[List[torch.Tensor]],\n                           inputs_embeds: Optional[torch.Tensor] = None,\n                           context: StepContext = None):\n        \"\"\"Update model meta.\"\"\"\n        if context.is_decoding:\n            return self._update_model_meta_decoding(context)\n        else:\n            return self._update_model_meta_prefilling(context)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass Qwen2_5_VLInputProcessor(BaseModelInputProcessor):\n    \"\"\"Qwen2 
input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig) -> None:\n        self.config = config\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values']\n            image_grid_thw = input_mm['image_grid_thw']\n            offset = input_mm['offset']\n            start = offset\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=start,\n                                     end=start + num_pad,\n                                     meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen2_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n\n\nclass Qwen2MoeAttention(nn.Module):\n    \"\"\"Rewrite module of Qwen2MoeAttention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            
num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=config.sliding_window,\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(num_heads * head_dim,\n                                           hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return 
attn_output\n\n\nclass Qwen2MoeMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 intermediate_size: int = None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True,\n                 all_reduce: bool = True):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=is_tp,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(intermediate_size,\n                                              config.hidden_size,\n                                              bias=False,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=is_tp,\n                                              all_reduce=all_reduce)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Qwen2MoeSparseMoeBlock(nn.Module):\n    \"\"\"Moe block.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx 
= layer_idx\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n        self.renormalize = self.norm_topk_prob\n\n        self.gate = build_rowwise_linear(\n            self.hidden_dim,\n            self.num_experts,\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n\n        self.softmax_topk = SoftmaxTopK(\n            self.top_k,\n            n_groups=getattr(config, 'router_n_groups', -1),\n        )\n\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=self.renormalize,\n            dtype=dtype,\n            device=device,\n            all_reduce=False,\n        )\n\n        intermediate_size = config.shared_expert_intermediate_size\n        self.shared_expert = Qwen2MoeMLP(\n            config=config,\n            intermediate_size=intermediate_size,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            all_reduce=False,\n        )\n        self.shared_expert_gate = build_rowwise_linear(config.hidden_size,\n                                                       1,\n                                                       bias=False,\n                                                       dtype=dtype,\n                                                       device=device,\n                                                       all_reduce=False)\n        world_size, _ = get_tp_world_rank()\n        if world_size > 1:\n            self._all_reduce = True\n        else:\n            self._all_reduce = False\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim 
= hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        router_logits = self.gate(hidden_states)\n\n        topk_weights, topk_ids = self.softmax_topk(router_logits)\n\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        shared_states = self.shared_expert(hidden_states)\n        shared_states = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_states\n        out_states += shared_states\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n\n        if self._all_reduce:\n            dist.all_reduce(out_states)\n\n        return out_states\n\n\nclass Qwen2MoeDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Qwen2MoeAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        if (layer_idx not in config.mlp_only_layers) and (config.num_experts\n                                                          > 0) and ((layer_idx + 1) % config.decoder_sparse_step == 0):\n            self.mlp = Qwen2MoeSparseMoeBlock(config, layer_idx=layer_idx, dtype=dtype, device=device)\n        else:\n            self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                     
  device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Qwen2MoeModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = build_embedding(\n            config.vocab_size,\n            config.hidden_size,\n            self.padding_idx,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Qwen2MoeDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = 
RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass Qwen2MoeForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: 
StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = Qwen2MoeModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, 
vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        # expert map\n        num_experts = self.config.num_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = 
dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            if '.experts' in name:\n                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen2_reward.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn.linear import build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .qwen2 import Qwen2Model\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass Qwen2ForRewardModel(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = Qwen2Model(config, dtype=dtype, device=device)\n\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n        self.num_labels = 1\n        self.score = nn.Sequential(\n            build_rowwise_linear(config.hidden_size, config.hidden_size, bias=True, dtype=dtype, device=device),\n            nn.ReLU(), build_rowwise_linear(config.hidden_size, self.num_labels, bias=True, dtype=dtype, device=device))\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: 
List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        logits = self.score(hidden_states)\n        return logits\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        pass\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            # inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            
('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen2_vl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Callable, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention, LayerNorm, RMSNorm, SiluAndMul,\n                                 build_rotary_embedding_from_config)\nfrom lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding, vlm_model\n\n\ndef _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: List[int],\n                           position_ids: torch.Tensor, rotary_emb_func: Callable):\n    _mrope_position_ids = torch.zeros(3, position_ids.shape[-1], dtype=position_ids.dtype, device=position_ids.device)\n    _mrope_position_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids\n    cos, sin = rotary_emb_func(hidden_states, _mrope_position_ids)\n    _cos = torch.zeros(cos.shape[1], cos.shape[-1], dtype=cos.dtype, device=cos.device)\n    _sin = torch.zeros_like(_cos)\n    mrope_section = mrope_section * 2\n\n    def _apply_split(src, dst):\n        start = 0\n        for i, m in enumerate(src.split(mrope_section, dim=-1)):\n            dst[:, start:start + mrope_section[i]] = m[i % 3]\n            start += mrope_section[i]\n\n    _apply_split(cos, _cos)\n    _apply_split(sin, _sin)\n\n    return _cos, _sin\n\n\nclass 
Qwen2Attention(nn.Module):\n    \"\"\"Rewrite module of Qwen2Attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=config.sliding_window,\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(num_heads * head_dim,\n                                           hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = 
self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Qwen2MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(config.intermediate_size,\n                                              config.hidden_size,\n                                              bias=False,\n                                    
          quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Qwen2DecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Qwen2Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = Qwen2MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        
attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Qwen2Model(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.mrope_section = config.rope_scaling['mrope_section']\n\n        self.embed_tokens = build_embedding(\n            config.vocab_size,\n            config.hidden_size,\n            self.padding_idx,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Qwen2DecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        
attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        mrope_position_ids: torch.LongTensor = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        if mrope_position_ids is None:\n            cos, sin = self.rotary_emb(hidden_states, position_ids)\n            cos, sin = cos[0], sin[0]\n        else:\n            cos, sin = _apply_mrope_selection(hidden_states, mrope_position_ids, self.mrope_section, position_ids,\n                                              self.rotary_emb)\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"Patch Embed.\"\"\"\n\n    def __init__(self,\n                 patch_size: int = 14,\n                 temporal_patch_size: int = 2,\n                 in_channels: int = 3,\n                 embed_dim: int = 1152,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None) -> None:\n        super().__init__()\n        self.patch_size = patch_size\n        self.temporal_patch_size = temporal_patch_size\n        self.in_channels = in_channels\n        self.embed_dim = 
embed_dim\n\n        kernel_size = [temporal_patch_size, patch_size, patch_size]\n        self.proj = nn.Conv3d(in_channels,\n                              embed_dim,\n                              kernel_size=kernel_size,\n                              stride=kernel_size,\n                              bias=False,\n                              dtype=dtype,\n                              device=device)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        target_dtype = self.proj.weight.dtype\n        hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size,\n                                           self.patch_size)\n        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)\n        return hidden_states\n\n\nclass VisionRotaryEmbedding(nn.Module):\n    \"\"\"Vision rotary embedding.\"\"\"\n\n    def __init__(self, dim: int, theta: float = 10000.0, device: torch.device = None) -> None:\n        super().__init__()\n        inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))\n        self.register_buffer('inv_freq', inv_freq, persistent=False)\n\n    def forward(self, seqlen: int) -> torch.Tensor:\n        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)\n        freqs = torch.outer(seq, self.inv_freq)\n        return freqs\n\n\nclass VisionAttention(nn.Module):\n    \"\"\"Vision attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        dim = config.embed_dim\n        num_heads = config.num_heads\n        head_dim = dim // num_heads\n        self.head_dim = head_dim\n\n        # packed qkv\n        self.qkv = build_qkv_proj(\n            dim,\n            num_q_heads=num_heads,\n            
num_kv_heads=num_heads,\n            head_size=head_dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attention = FlashAttention(\n            num_heads,\n            head_dim,\n            causal=False,\n        )\n\n        # o_proj\n        self.proj = build_rowwise_linear(dim,\n                                         dim,\n                                         bias=True,\n                                         quant_config=quantization_config,\n                                         dtype=dtype,\n                                         device=device,\n                                         is_tp=True)\n\n    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,\n                rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor]) -> torch.Tensor:\n        seq_length = hidden_states.shape[0]\n        # qkv proj\n        qkv_states = self.qkv(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        q, k, v = self.qkv.split_qkv(qkv_states)\n\n        cos, sin = rotary_pos_emb\n        q, k = self.apply_rotary_pos_emb(q, k, cos, sin)\n\n        attn_output = self.attention(\n            q,\n            k,\n            v,\n            q_start_loc=cu_seqlens[:-1],\n            q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1],\n        )\n\n        attn_output = attn_output.reshape(seq_length, -1)\n\n        # o proj\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass VisionMlp(nn.Module):\n    \"\"\"Vision mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        from transformers.activations import ACT2FN\n        dim = config.embed_dim\n        
hidden_dim = int(config.embed_dim * config.mlp_ratio)\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.fc1 = build_colwise_linear(\n            dim,\n            hidden_dim,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        if config.hidden_act in ['gelu', 'gelu_fast', 'quick_gelu', 'gelu_python']:\n            self.act = nn.GELU()\n        else:\n            self.act = ACT2FN[config.hidden_act]\n\n        # down\n        self.fc2 = build_rowwise_linear(hidden_dim,\n                                        dim,\n                                        bias=True,\n                                        quant_config=quantization_config,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        return self.fc2(self.act(self.fc1(x)))\n\n\nclass Qwen2VLVisionBlock(nn.Module):\n    \"\"\"Vision block.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.norm1 = LayerNorm(config.embed_dim, eps=1e-6, dtype=dtype, device=device)\n        self.norm2 = LayerNorm(config.embed_dim, eps=1e-6, dtype=dtype, device=device)\n\n        self.attn = VisionAttention(config, dtype=dtype, device=device)\n\n        self.mlp = VisionMlp(config, dtype=dtype, device=device)\n\n    def forward(self,\n                hidden_states,\n                cu_seqlens,\n                rotary_pos_emb,\n                residual: Optional[torch.Tensor] = None) -> torch.Tensor:\n        if residual is None:\n   
         residual = hidden_states\n            hidden_states = self.norm1(hidden_states)\n        else:\n            hidden_states, residual = self.norm1(hidden_states, residual)\n\n        hidden_states = self.attn(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)\n\n        hidden_states, residual = self.norm2(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n        return hidden_states, residual\n\n\nclass PatchMerger(nn.Module):\n    \"\"\"PatchMerger.\"\"\"\n\n    def __init__(self,\n                 dim: int,\n                 context_dim: int,\n                 spatial_merge_size: int = 2,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None) -> None:\n        super().__init__()\n        self.hidden_size = context_dim * (spatial_merge_size**2)\n        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6, dtype=dtype, device=device)\n        self.mlp = nn.Sequential(\n            nn.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device),\n            nn.GELU(),\n            nn.Linear(self.hidden_size, dim, dtype=dtype, device=device),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))\n        return x\n\n\n@vlm_model\nclass Qwen2VisionTransformerPretrainedModel(nn.Module):\n    \"\"\"Vision transformer.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.spatial_merge_size = config.spatial_merge_size\n\n        self.patch_embed = PatchEmbed(\n            patch_size=config.patch_size,\n            temporal_patch_size=config.temporal_patch_size,\n            in_channels=config.in_channels,\n            embed_dim=config.embed_dim,\n            dtype=dtype,\n            device=device,\n        )\n\n        head_dim = config.embed_dim // config.num_heads\n   
     self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2, device=device)\n\n        self.blocks = nn.ModuleList(\n            [Qwen2VLVisionBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.depth)])\n        self.merger = PatchMerger(dim=config.hidden_size,\n                                  context_dim=config.embed_dim,\n                                  spatial_merge_size=config.spatial_merge_size,\n                                  dtype=dtype,\n                                  device=device)\n\n    def rot_pos_emb(self, grid_thw):\n        \"\"\"Rotary position embedding.\"\"\"\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n        pos_ids = torch.cat(pos_ids, dim=0)\n        max_grid_size = grid_thw[:, 1:].max()\n        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n        return rotary_pos_emb\n\n    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,\n                rotary_pos_emb: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n   
     hidden_states = self.patch_embed(hidden_states)\n        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)\n\n        residual = None\n        for blk in self.blocks:\n            hidden_states, residual = blk(hidden_states,\n                                          cu_seqlens=cu_seqlens,\n                                          rotary_pos_emb=rotary_pos_emb,\n                                          residual=residual)\n\n        hidden_states = hidden_states + residual\n\n        return self.merger(hidden_states)\n\n\nclass Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # preprocessor\n        self.input_processor = Qwen2VLInputProcessor(self.config)\n\n        # build vision model\n        self.visual = Qwen2VisionTransformerPretrainedModel(\n            config.vision_config,\n            dtype=dtype,\n            device=device,\n        )\n        # get text_config\n        text_config = getattr(config, 'text_config', config)\n        # build model\n        self.model = Qwen2Model(text_config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        
attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        mrope_position_ids: torch.Tensor = None,\n        pixel_values: torch.Tensor = None,\n        vis_cu_seqlens: torch.Tensor = None,\n        vis_pos_emb: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n            if pixel_values is not None:\n                dtype = inputs_embeds.dtype\n                pixel_values = pixel_values.to(dtype)\n                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))\n                image_embeds = self.visual(pixel_values, cu_seqlens=vis_cu_seqlens, rotary_pos_emb=vis_pos_emb)\n                inputs_embeds = inputs_embeds.masked_scatter(image_mask[..., None], image_embeds)\n\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        pixel_values = None\n        vis_cu_seqlens = None\n        vis_pos_emb = None\n        image_mask = None\n        if context.input_multimodals is not None:\n     
       image_data = [input_mm.get('image', []) for input_mm in context.input_multimodals]\n            if len(image_data) > 0:\n                # flatten batch\n                image_data = [data for im_data in image_data for data in im_data]\n                pixel_values = torch.cat([data.data for data in image_data])\n                image_token_id = image_data[0].meta['image_token_id']\n                image_mask = input_ids == image_token_id\n                grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu()\n                vis_pos_emb = self.visual.rot_pos_emb(grid_thw)\n                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],\n                                                         grid_thw[:, 0]).to(pixel_values.device)\n                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)\n                vis_pos_emb = vis_pos_emb.repeat(1, 2)\n                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())\n\n        mrope_position_ids = getattr(context, 'mrope_position_ids', None)\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n            pixel_values=pixel_values,\n            vis_cu_seqlens=vis_cu_seqlens,\n            vis_pos_emb=vis_pos_emb,\n            image_mask=image_mask,\n     
   )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '.qkv.' 
in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n        max_tokens = graph_meta.max_tokens\n\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)\n\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n\n        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n\n        input_ids = kwargs.get('input_ids')\n        num_tokens = input_ids.size(-1)\n        new_batch_size = graph_meta.max_batchs\n\n        is_decoding = graph_meta.is_decoding\n        input_buffers = graph_meta.input_buffers\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids\n            if is_decoding:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]\n            else:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']\n\n        return new_inputs\n\n    def _get_model_metas(self, context: StepContext):\n        \"\"\"Get model metas.\"\"\"\n        model_metas = 
context.model_metas\n        if model_metas is None:\n            batch_size = context.q_seqlens.numel()\n            return [dict(mrope_delta=0)] * batch_size\n        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]\n\n    def _update_model_meta_decoding(self, context: StepContext):\n        \"\"\"Update model meta for decoding.\"\"\"\n        model_metas = self._get_model_metas(context)\n        position_ids = context.position_ids\n\n        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]\n        mrope_deltas = position_ids.new_tensor(mrope_deltas)\n        mrope_position_ids = position_ids + mrope_deltas[None]\n        mrope_position_ids = mrope_position_ids.expand(3, -1)\n\n        context.mrope_position_ids = mrope_position_ids\n        return model_metas\n\n    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):\n        \"\"\"Get mrope ids.\"\"\"\n        t, h, w = grid_thw\n        h //= 2\n        w //= 2\n        stride = torch.tensor([h * w, w, 1], device=device)[:, None]\n        size = torch.tensor([t, h, w], device=device)[:, None]\n        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)\n        pos_ids = pos_ids // stride % size\n        return pos_ids\n\n    def _update_model_meta_prefilling(self, context: StepContext):\n        \"\"\"Update model meta for prefilling.\"\"\"\n        model_metas = self._get_model_metas(context)\n        input_multimodals = context.input_multimodals\n        if input_multimodals is None:\n            input_multimodals = [None] * len(model_metas)\n        position_ids = context.position_ids\n        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())\n        mrope_position_ids = []\n        new_model_metas = []\n        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):\n            images = []\n            if input_mm is not None:\n                images = 
input_mm.get('image', [])\n            if model_meta is None or 'mrope_delta' not in model_meta:\n                mrope_delta = 0\n            else:\n                mrope_delta = model_meta['mrope_delta']\n\n            pos_start = pos_ids[0].item()\n            mrope_pos_ids = pos_ids + mrope_delta\n            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()\n            for img in images:\n                grid_thw = img.meta['grid_thw'][0].tolist()\n                _, h, w = grid_thw\n                h //= 2\n                w //= 2\n                num_pad = img.end - img.start - max(h, w)\n                mrope_delta -= num_pad\n                fill_start = img.start - pos_start\n                fill_end = img.end - pos_start\n                img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)\n                img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]\n                mrope_pos_ids[:, fill_end:] -= num_pad\n                mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids\n\n            mrope_position_ids.append(mrope_pos_ids)\n            new_model_metas.append(dict(mrope_delta=mrope_delta))\n\n        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)\n        context.mrope_position_ids = mrope_position_ids\n\n        return new_model_metas\n\n    def update_model_metas(self,\n                           past_key_values: List[List[torch.Tensor]],\n                           inputs_embeds: Optional[torch.Tensor] = None,\n                           context: StepContext = None):\n        \"\"\"Update model meta.\"\"\"\n        if context.is_decoding:\n            return self._update_model_meta_decoding(context)\n        else:\n            return self._update_model_meta_prefilling(context)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass Qwen2VLInputProcessor(BaseModelInputProcessor):\n    \"\"\"Qwen2 input 
processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig) -> None:\n        self.config = config\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_imgs = []\n        for input_mm in input_multimodals:\n            pixel_values = input_mm['pixel_values']\n            image_grid_thw = input_mm['image_grid_thw']\n            offset = input_mm['offset']\n            start = offset\n            image_token_id = input_mm['image_token_id']\n            num_pad = input_mm['image_tokens']\n            if isinstance(num_pad, torch.Tensor):\n                num_pad = num_pad.item()\n\n            mm_data = MultiModalData(data=pixel_values,\n                                     start=start,\n                                     end=start + num_pad,\n                                     meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))\n            input_imgs.append(mm_data)\n\n        result = PreprocessInputResult(\n            input_ids=input_ids,\n            input_multimodals=dict(image=input_imgs),\n        )\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import add_prefix\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n\n\nclass Qwen3Attention(nn.Module):\n    \"\"\"Rewrite module of Qwen3Attention.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        # Qwen3 uses 'config.attention_bias = False' for q/k/o projections\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            num_replicate_kv_heads=num_replicate_kv_heads,\n            prefix=add_prefix('qkv_proj', prefix),\n        )\n\n        
# rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=getattr(config, 'sliding_window', None),\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(\n            num_heads * head_dim,\n            hidden_size,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            prefix=add_prefix('o_proj', prefix),\n        )\n\n        # q, k norm\n        self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)\n        self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply q, k norm\n        query_states = self.q_norm(query_states)\n        key_states = self.k_norm(key_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n        
    past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Qwen3MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n            prefix=add_prefix('gate_up_proj', prefix),\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(\n            config.intermediate_size,\n            config.hidden_size,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            prefix=add_prefix('down_proj', prefix),\n        )\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Qwen3DecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n         
        device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Qwen3Attention(config, dtype=dtype, device=device, prefix=add_prefix('self_attn', prefix))\n\n        # build MLP\n        self.mlp = Qwen3MLP(config, dtype=dtype, device=device, prefix=add_prefix('mlp', prefix))\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(\n            config.hidden_size,\n            config.rms_norm_eps,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('input_layernorm', prefix),\n        )\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(\n            config.hidden_size,\n            config.rms_norm_eps,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('post_attention_layernorm', prefix),\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = 
self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Qwen3model(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = build_embedding(\n            config.vocab_size,\n            config.hidden_size,\n            self.padding_idx,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Qwen3DecoderLayer(config,\n                              layer_idx,\n                              dtype=dtype,\n                              device=device,\n                              prefix=add_prefix(f'layers.{layer_idx}', prefix))\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = 
self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass Qwen3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = Qwen3model(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n   
 ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            
('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen3_5.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom functools import lru_cache\nfrom typing import Any, Iterable, List, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update\n\nimport lmdeploy.pytorch.nn.gated_delta as gated_delta_util\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RMSNorm, SiluAndMul\nfrom lmdeploy.pytorch.nn.gated_delta import CausalConv1d, GatedDelta, GatedDeltaMeta, build_rmsnorm_gated\nfrom lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight\nfrom lmdeploy.vl.constants import Modality\n\nfrom .patch import add_prefix\nfrom .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3_5VisionRotaryEmbedding\nfrom .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3_5VisionAttention\nfrom .qwen3_vl import Qwen3VLInputProcessor as Qwen3_5InputProcessor\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, vlm_model\n\n\nclass Qwen3_5VisionPatchEmbed(nn.Module):\n\n    def __init__(self, config, dtype: torch.dtype | None = None, device: torch.device | None = None) -> None:\n        super().__init__()\n        self.patch_size = config.patch_size\n        self.temporal_patch_size = config.temporal_patch_size\n        self.in_channels = config.in_channels\n        
self.embed_dim = config.hidden_size\n\n        kernel_size = (self.temporal_patch_size, self.patch_size, self.patch_size)\n        self.proj = nn.Conv3d(self.in_channels,\n                              self.embed_dim,\n                              kernel_size=kernel_size,\n                              stride=kernel_size,\n                              bias=True,\n                              dtype=dtype,\n                              device=device)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        target_dtype = self.proj.weight.dtype\n        hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size,\n                                           self.patch_size)\n        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)\n        return hidden_states\n\n\nclass Qwen3_5VisionMLP(nn.Module):\n    \"\"\"Vision mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n        from transformers.activations import ACT2FN\n        hidden_dim = config.hidden_size\n        intermediate_size = config.intermediate_size\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.linear_fc1 = build_colwise_linear(\n            hidden_dim,\n            intermediate_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n            prefix=add_prefix('linear_fc1', prefix),\n        )\n\n        # gelu_pytorch_tanh\n        self.act = ACT2FN[config.hidden_act]\n\n        # down\n        self.linear_fc2 = build_rowwise_linear(\n            intermediate_size,\n            hidden_dim,\n            bias=True,\n            dtype=dtype,\n 
           device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n            prefix=add_prefix('linear_fc2', prefix),\n        )\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        return self.linear_fc2(self.act(self.linear_fc1(x)))\n\n\nclass Qwen3_5VisionBlock(nn.Module):\n    \"\"\"Vision block.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.norm1 = LayerNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)\n        self.norm2 = LayerNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)\n\n        self.attn = Qwen3_5VisionAttention(config, dtype=dtype, device=device, prefix=add_prefix('attn', prefix))\n\n        self.mlp = Qwen3_5VisionMLP(config, dtype=dtype, device=device, prefix=add_prefix('mlp', prefix))\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                cu_seqlens: torch.Tensor,\n                rotary_pos_emb: torch.Tensor | None = None) -> torch.Tensor:\n        hidden_states = hidden_states + self.attn(\n            self.norm1(hidden_states),\n            cu_seqlens=cu_seqlens,\n            rotary_pos_emb=rotary_pos_emb,\n        )\n        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))\n        return hidden_states\n\n\nclass Qwen3_5VisionPatchMerger(nn.Module):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 use_postshuffle_norm=False,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None) -> None:\n        super().__init__()\n        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)\n        self.use_postshuffle_norm = 
use_postshuffle_norm\n        self.norm = LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size,\n                              eps=1e-6,\n                              dtype=dtype,\n                              device=device)\n        self.linear_fc1 = build_colwise_linear(\n            self.hidden_size,\n            self.hidden_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n        )\n        self.act_fn = nn.GELU()\n        self.linear_fc2 = build_rowwise_linear(\n            self.hidden_size,\n            config.out_hidden_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)\n        x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))\n        return x\n\n\n@vlm_model\nclass Qwen3_5VisionModel(nn.Module):\n    \"\"\"qwen3.5 vision model.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.config = config\n        self.spatial_merge_size = config.spatial_merge_size\n\n        self.patch_embed = Qwen3_5VisionPatchEmbed(config=config, dtype=dtype, device=device)\n\n        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device)\n        self.num_grid_per_side = int(config.num_position_embeddings**0.5)\n\n        head_dim = config.hidden_size // config.num_heads\n        self.rotary_pos_emb = Qwen3_5VisionRotaryEmbedding(head_dim // 2, device=device)\n\n        self.blocks = nn.ModuleList([\n            Qwen3_5VisionBlock(config,\n                               layer_idx,\n             
                  dtype=dtype,\n                               device=device,\n                               prefix=add_prefix(f'blocks.{layer_idx}', prefix)) for layer_idx in range(config.depth)\n        ])\n        self.merger = Qwen3_5VisionPatchMerger(config=config, use_postshuffle_norm=False, dtype=dtype, device=device)\n\n    @staticmethod\n    @lru_cache(maxsize=1024)\n    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:\n        h_div = h // spatial_merge_size\n        w_div = w // spatial_merge_size\n\n        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))\n        hpos_ids = hpos_ids.reshape(\n            h_div,\n            spatial_merge_size,\n            w_div,\n            spatial_merge_size,\n        )\n        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)\n        hpos_ids = hpos_ids.flatten()\n\n        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))\n        wpos_ids = wpos_ids.reshape(\n            h_div,\n            spatial_merge_size,\n            w_div,\n            spatial_merge_size,\n        )\n        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)\n        wpos_ids = wpos_ids.flatten()\n\n        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))\n\n    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:\n        \"\"\"Rotary position embedding.\"\"\"\n        pos_ids = []\n\n        for t, h, w in grid_thw:\n            base = self.rot_pos_ids(int(h), int(w), self.spatial_merge_size)\n            pos_ids.append(base if t == 1 else base.repeat(t, 1))\n\n        pos_ids = torch.cat(pos_ids, dim=0)\n        max_grid_size = grid_thw[:, 1:].max()\n        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n\n        return rotary_pos_emb\n\n    # copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474\n    def fast_pos_embed_interpolate(self, grid_thw: 
List[List[int]]) -> torch.Tensor:\n        num_grid_per_side = self.num_grid_per_side\n        m_size = self.spatial_merge_size\n        hidden_dim = self.pos_embed.embedding_dim\n        device = self.pos_embed.weight.device\n\n        outputs = []\n        for t, h, w in grid_thw:\n            h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device)\n            w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device)\n\n            h_floor = h_idxs.to(torch.long)\n            w_floor = w_idxs.to(torch.long)\n            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)\n            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)\n\n            dh = h_idxs - h_floor\n            dw = w_idxs - w_floor\n\n            # Create meshgrid view for all h, w vars\n            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')\n            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij')\n            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij')\n\n            # original computation of weights\n            # w00 = (1 - dh_grid) * (1 - dw_grid)\n            # w01 = (1 - dh_grid) * dw_grid\n            # w10 = dh_grid * (1 - dw_grid)\n            # w11 = dh_grid * dw_grid\n            # we reuse w11 here to avoid duplicate\n            # dh_grid * dw_grid computation\n            w11 = dh_grid * dw_grid\n            w10 = dh_grid - w11\n            w01 = dw_grid - w11\n            w00 = 1 - dh_grid - w01\n\n            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])\n            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])\n            h_grid_idx = h_grid * num_grid_per_side\n\n            indices = (h_grid_idx + w_grid).reshape(4, -1)\n            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)\n            weights = 
weights.to(dtype=self.pos_embed.weight.dtype, device=device)\n\n            embeds = self.pos_embed(indices)\n            embeds *= weights\n            combined = embeds.sum(dim=0)\n\n            combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)\n            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)\n            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)\n            outputs.append(repeated)\n\n        return torch.cat(outputs, dim=0)\n\n    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor,\n                pos_embeds: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.patch_embed(hidden_states)\n        hidden_states = hidden_states + pos_embeds\n        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)\n\n        for _, blk in enumerate(self.blocks):\n            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)\n\n        hidden_states = self.merger(hidden_states)\n\n        return hidden_states\n\n\nclass Qwen3_5MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 intermediate_size: int | None = None,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 is_tp: bool = True,\n                 all_reduce: bool = True,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n          
  quant_config=quantization_config,\n            is_tp=is_tp,\n            prefix=add_prefix('gate_up_proj', prefix),\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(\n            intermediate_size,\n            config.hidden_size,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=is_tp,\n            all_reduce=all_reduce,\n            prefix=add_prefix('down_proj', prefix),\n        )\n\n    def forward(self, x, all_routed_experts: torch.Tensor | None = None):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Qwen3_5GatedDeltaNet(nn.Module):\n    \"\"\"Gated deltanet.\"\"\"\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        layer_idx: int,\n        dtype: torch.dtype | None = None,\n        device: torch.device | None = None,\n        prefix: str = '',\n    ):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.num_v_heads = config.linear_num_value_heads\n        self.num_k_heads = config.linear_num_key_heads\n        self.head_k_dim = config.linear_key_head_dim\n        self.head_v_dim = config.linear_value_head_dim\n        self.key_dim = self.head_k_dim * self.num_k_heads\n        self.value_dim = self.head_v_dim * self.num_v_heads\n        self.kv_ratio = self.num_v_heads // self.num_k_heads\n\n        self.conv_kernel_size = config.linear_conv_kernel_dim\n        self.layer_idx = layer_idx\n        self.activation = config.hidden_act\n        self.layer_norm_epsilon = config.rms_norm_eps\n\n        # QKV\n        self.conv_dim = self.key_dim * 2 + self.value_dim\n        self.conv1d = CausalConv1d(\n            in_channels=self.conv_dim,\n            out_channels=self.conv_dim,\n            
kernel_size=self.conv_kernel_size,\n            split=[self.key_dim, self.key_dim, self.value_dim],\n            bias=False,\n            groups=self.conv_dim,\n            dtype=dtype,\n            device=device,\n        )\n\n        # projection of the input hidden states\n        projection_size_qkv = self.key_dim * 2 + self.value_dim\n        self.in_proj_qkv = build_colwise_linear(self.hidden_size,\n                                                projection_size_qkv,\n                                                bias=False,\n                                                dtype=dtype,\n                                                device=device,\n                                                is_tp=True)\n        self.in_proj_qkv.weight.weight_loader = self.weight_loader_qkv\n        self.in_proj_zba = build_merged_colwise_linear(\n            self.hidden_size,\n            [self.value_dim, self.num_v_heads, self.num_v_heads],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            out_names=['z', 'b', 'a'],\n        )\n\n        # time step projection (discretization)\n        # instantiate once and copy inv_dt in init_weights of PretrainedModel\n        self.make_params(self.num_v_heads, device=device)\n        self.A_log_exp = None\n\n        self.norm = build_rmsnorm_gated(self.head_v_dim,\n                                        eps=self.layer_norm_epsilon,\n                                        activation=self.activation,\n                                        dtype=dtype,\n                                        device=device)\n        self.out_proj = build_o_proj(self.value_dim,\n                                     self.hidden_size,\n                                     bias=False,\n                                     dtype=dtype,\n                                     device=device,\n                                     is_tp=True)\n\n        self.gated_delta = GatedDelta()\n\n   
 def get_A_log_exp(self):\n        if self.A_log_exp is None:\n            self.A_log_exp = -self.A_log.float().exp()\n\n        return self.A_log_exp\n\n    def make_params(self, num_v_heads: int, device: torch.device | None):\n        tp, _ = get_tp_world_rank()\n        num_v_heads = num_v_heads // tp\n        A = torch.empty(num_v_heads, device=device)\n        dt_bias = torch.empty(num_v_heads, device=device)\n\n        self.register_parameter('A_log', nn.Parameter(torch.log(A)))\n        self.register_parameter('dt_bias', nn.Parameter(dt_bias))\n        self.A_log.weight_loader = self.weight_loader_a_dt\n        self.dt_bias.weight_loader = self.weight_loader_a_dt\n\n    def weight_loader_qkv(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader for qkv projection.\"\"\"\n        tp, rank = get_tp_world_rank()\n        q, k, v = loaded_weight.split([self.key_dim, self.key_dim, self.value_dim], dim=0)\n        q = q.chunk(tp, dim=0)[rank]\n        k = k.chunk(tp, dim=0)[rank]\n        v = v.chunk(tp, dim=0)[rank]\n        loaded_weight = torch.cat([q, k, v], dim=0)\n        default_weight_loader(param, loaded_weight)\n\n    def weight_loader_a_dt(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader for `A_log` and `dt_bias`.\"\"\"\n        tp, rank = get_tp_world_rank()\n        loaded_weight = loaded_weight.chunk(tp, dim=0)[rank]\n        default_weight_loader(param, loaded_weight)\n\n    def fix_zba_ordering(self, mixed_zba: torch.Tensor):\n        \"\"\"Split `mixed_zba` into `z`, `b` and `a`, and reshape `z` to\n        one `head_v_dim` vector per value head.\"\"\"\n\n        # zba\n        split_arg_list_zba = [self.head_v_dim * self.kv_ratio, self.kv_ratio, self.kv_ratio]\n        num_heads = mixed_zba.size(-1) // sum(split_arg_list_zba)\n        split_arg_list_zba = [num_heads * x for x in split_arg_list_zba]\n        z, b, a = torch.split(mixed_zba, split_arg_list_zba, dim=-1)\n        # [..., ng, np/ng * hn] -> [..., np, 
hn]\n        z = z.unflatten(-1, (-1, self.head_v_dim))\n        return z, b, a\n\n    def _load_state(self, past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta):\n        \"\"\"Load states from cache.\"\"\"\n        return gated_delta_util.load_state(past_key_value=past_key_value, gated_delta_meta=gated_delta_meta)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        past_key_value: Tuple[torch.Tensor, torch.Tensor],\n        gated_delta_meta: GatedDeltaMeta,\n    ):\n        \"\"\"forward.\"\"\"\n\n        # load states\n        conv_state, recurrent_state = self._load_state(past_key_value, gated_delta_meta)\n\n        # inputs proj\n        projected_states_qkv = self.in_proj_qkv(hidden_states)\n        projected_states_zba = self.in_proj_zba(hidden_states)\n        z, b, a = self.fix_zba_ordering(projected_states_zba)\n\n        mixed_qkv = projected_states_qkv\n        mixed_qkv, conv_state = self.conv1d(mixed_qkv, conv_state, gated_delta_meta=gated_delta_meta)\n\n        tp = (self.key_dim * 2 + self.value_dim) // mixed_qkv.size(-1)\n        query, key, value = torch.split(\n            mixed_qkv,\n            [\n                self.key_dim // tp,\n                self.key_dim // tp,\n                self.value_dim // tp,\n            ],\n            dim=-1,\n        )\n        query = query.unflatten(-1, (-1, self.head_k_dim))\n        key = key.unflatten(-1, (-1, self.head_k_dim))\n        value = value.unflatten(-1, (-1, self.head_v_dim))\n\n        beta = b.sigmoid()\n        # If the model is loaded in fp16, without the .float() here, A might be -inf\n        g = self.get_A_log_exp() * F.softplus(a.float() + self.dt_bias)\n        if self.kv_ratio > 1:\n            query = query.repeat_interleave(self.kv_ratio, dim=-2)\n            key = key.repeat_interleave(self.kv_ratio, dim=-2)\n\n        core_attn_out, recurrent_state = self.gated_delta(\n            query,\n            key,\n            
value,\n            g=g,\n            beta=beta,\n            recurrent_state=recurrent_state,\n            gated_delta_meta=gated_delta_meta,\n        )\n\n        z_shape_og = z.shape\n        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])\n        z = z.reshape(-1, z.shape[-1])\n        core_attn_out = self.norm(core_attn_out, z)\n        core_attn_out = core_attn_out.reshape(z_shape_og)\n        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)\n\n        output = self.out_proj(core_attn_out)\n        return output\n\n\nclass Qwen3_5Attention(nn.Module):\n    \"\"\"Rewrite module of Qwen3MoeAttention.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        self.head_dim = head_dim\n        self.layer_idx = layer_idx\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n\n        # packed qkv\n        # Qwen3 uses 'config.attention_bias = False' for q/k/o projections\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads * 2,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            num_replicate_kv_heads=num_replicate_kv_heads,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('qkv_proj', prefix),\n        )\n\n        # rotary embedding\n        
self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(\n            num_heads * head_dim,\n            hidden_size,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            prefix=add_prefix('o_proj', prefix),\n        )\n\n        # q, k norm\n        self.q_norm = RMSNorm(\n            head_dim,\n            config.rms_norm_eps,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('q_norm', prefix),\n        )\n        self.k_norm = RMSNorm(\n            head_dim,\n            config.rms_norm_eps,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('k_norm', prefix),\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Tuple[torch.Tensor, torch.Tensor],\n        attn_metadata: Any,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n        query_states, gate = query_states.view(*query_states.shape[:-2], -1, 2 * self.head_dim).chunk(2, dim=-1)\n\n        # apply q, k norm\n        query_states = self.q_norm(query_states)\n        key_states = self.k_norm(key_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = 
self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n        gate = gate.reshape(*hidden_states.shape[:-1], -1)\n        attn_output = attn_output * gate.sigmoid()\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Qwen3_5DecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.layer_type = config.layer_types[layer_idx]\n        if self.layer_type == 'linear_attention':\n            self.linear_attn = Qwen3_5GatedDeltaNet(config,\n                                                    layer_idx,\n                                                    dtype=dtype,\n                                                    device=device,\n                                                    prefix=add_prefix('linear_attn', prefix))\n        elif self.layer_type == 'full_attention':\n            self.self_attn = Qwen3_5Attention(config,\n                                              layer_idx,\n 
                                             dtype=dtype,\n                                              device=device,\n                                              prefix=add_prefix('self_attn', prefix))\n\n        # build MLP\n        self.mlp = Qwen3_5MLP(config,\n                              intermediate_size=config.intermediate_size,\n                              dtype=dtype,\n                              device=device,\n                              prefix=add_prefix('mlp', prefix))\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(\n            config.hidden_size,\n            config.rms_norm_eps,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('input_layernorm', prefix),\n        )\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: List[torch.FloatTensor],\n        residual: torch.Tensor | None,\n        attn_metadata: Any,\n        gated_delta_meta: GatedDeltaMeta,\n        all_routed_experts: torch.Tensor | None = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        if self.layer_type == 'linear_attention':\n            hidden_states = self.linear_attn(\n                hidden_states=hidden_states,\n                past_key_value=past_key_value,\n                gated_delta_meta=gated_delta_meta,\n            )\n        elif self.layer_type == 'full_attention':\n            hidden_states = self.self_attn(\n                
hidden_states=hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                attn_metadata=attn_metadata,\n            )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states, all_routed_experts=all_routed_experts)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Qwen3_5TextRotaryEmbedding(nn.Module):\n    inv_freq: torch.Tensor  # fix linting for `register_buffer`\n\n    def __init__(self, config: PretrainedConfig, device=None):\n        super().__init__()\n        rope_scaling = get_rope_parameters(config)\n        assert rope_scaling is not None, 'RoPE scaling parameters must be provided in the config for Qwen3.5 models.'\n        self.rope_type = rope_scaling.get('rope_type', 'default')\n\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        if self.rope_type != 'default':\n            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n        else:\n            self.rope_init_fn = self.compute_default_rope_parameters\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer('inv_freq', inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n        self.mrope_section = rope_scaling.get('mrope_section', [11, 11, 10])\n\n    @staticmethod\n    def compute_default_rope_parameters(\n        config: PretrainedConfig | None = None,\n        device: torch.device | None = None,\n        seq_len: int | None = None,\n    ) -> tuple['torch.Tensor', float]:\n        \"\"\"\n        Computes the inverse frequencies according to the original RoPE implementation\n        Args:\n            config ([`~transformers.PreTrainedConfig`]):\n                The model 
configuration.\n            device (`torch.device`):\n                The device to use for initialization of the inverse frequencies.\n            seq_len (`int`, *optional*):\n                The current sequence length. Unused for this type of RoPE.\n        Returns:\n            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n        \"\"\"\n        rope_parameters = get_rope_parameters(config)\n        base = rope_parameters['rope_theta']\n        partial_rotary_factor = rope_parameters.get('partial_rotary_factor', 1.0)\n        head_dim = getattr(config, 'head_dim', None) or config.hidden_size // config.num_attention_heads\n        dim = int(head_dim * partial_rotary_factor)\n\n        attention_factor = 1.0  # Unused in this type of RoPE\n\n        # Compute the inverse frequencies\n        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))\n        return inv_freq, attention_factor\n\n    def apply_interleaved_mrope(self, freqs, mrope_section):\n        \"\"\"Apply interleaved MRoPE to 3D rotary embeddings.\n\n        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to\n        interleaved [THWTHWTHW...], preserving frequency continuity. Slots\n        beyond the H and W sections keep their T frequencies.\n        Args:\n            freqs: (3, bs, seq_len, head_dim // 2)\n            mrope_section: (3,)\n        Returns:\n            freqs_t: (bs, seq_len, head_dim // 2)\n        \"\"\"\n        freqs_t = freqs[0]  # just overwrite the first dimension T\n        for dim, offset in enumerate((1, 2), start=1):  # H, W\n            length = mrope_section[dim] * 3\n            idx = slice(offset, length, 3)\n            freqs_t[..., idx] = freqs[dim, ..., idx]\n        return freqs_t\n\n    @torch.no_grad()\n    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. 
dynamic rope)\n    def forward(self, x, position_ids):\n        # In contrast to other models, Qwen3VL has different position ids for the grids\n        # So we expand the inv_freq to shape (3, ...)\n        if position_ids.ndim == 2:\n            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)\n        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)\n        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)\n\n        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)\n        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)\n        emb = torch.cat((freqs, freqs), dim=-1)\n        cos = emb.cos() * self.attention_scaling\n        sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass Qwen3_5TextModel(nn.Module):\n    \"\"\"qwen3.5 text model.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        # TODO: use full config.num_hidden_layers\n        self.layers = nn.ModuleList([\n            Qwen3_5DecoderLayer(config,\n                                layer_idx,\n                                dtype=dtype,\n                                device=device,\n                                
prefix=add_prefix(f'layers.{layer_idx}', prefix))\n            for layer_idx in range(self.config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,\n        position_ids: torch.LongTensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any,\n        state_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor | None = None,\n        mrope_position_ids: torch.Tensor | None = None,\n        all_routed_experts: torch.Tensor | None = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        if mrope_position_ids is None:\n            cos, sin = self.rotary_emb(hidden_states, position_ids)\n        else:\n            mrope_position_ids = mrope_position_ids.unsqueeze(1)\n            cos, sin = self.rotary_emb(hidden_states, mrope_position_ids)\n\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # make seq_idx\n        gated_delta_meta = GatedDeltaMeta(hidden_states.size(1), self.config.linear_conv_kernel_dim, state_ids,\n                                          attn_metadata)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_values[idx],\n                residual=residual,\n                attn_metadata=attn_metadata,\n                gated_delta_meta=gated_delta_meta,\n                
all_routed_experts=all_routed_experts,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass Qwen3_5Model(nn.Module):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n\n        self.visual = Qwen3_5VisionModel(config.vision_config,\n                                         dtype=dtype,\n                                         device=device,\n                                         prefix=add_prefix('visual', prefix))\n        self.language_model = Qwen3_5TextModel(config.text_config,\n                                               dtype=dtype,\n                                               device=device,\n                                               prefix=add_prefix('language_model', prefix))\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any,\n        state_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor | None = None,\n        mrope_position_ids: torch.Tensor | None = None,\n        pixel_values: torch.Tensor | None = None,\n        vis_cu_seqlens: torch.Tensor | None = None,\n        vis_pos_emb: torch.Tensor | None = None,\n        image_mask: torch.Tensor | None = None,\n        pos_embeds: torch.Tensor | None = None,\n        grid_thw: torch.Tensor | None = None,\n        all_routed_experts: torch.Tensor | None = None,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n\n            if pixel_values is not None:\n   
             dtype = inputs_embeds.dtype\n                pixel_values = pixel_values.to(dtype)\n                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))\n\n                # get image embeds and deepstack visual embeds\n                image_embeds = self.visual(pixel_values,\n                                           cu_seqlens=vis_cu_seqlens,\n                                           rotary_pos_emb=vis_pos_emb,\n                                           pos_embeds=pos_embeds)\n\n                # split image embeds per sample\n                split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()\n                image_embeds = torch.split(image_embeds, split_sizes)\n                image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype)\n\n                # mask and scatter to create final input embeddings\n                expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)\n                inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds)\n\n        hidden_states = self.language_model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            state_ids=state_ids,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n            all_routed_experts=all_routed_experts,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.language_model.get_input_embeddings()\n\n\nclass Qwen3_5ForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n 
       ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # build preprocessor\n        self.input_processor = Qwen3_5InputProcessor(self.config)\n\n        # build model\n        self.model = Qwen3_5Model(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.text_config.hidden_size,\n                                          config.text_config.vocab_size,\n                                          bias=False,\n                                          dtype=dtype,\n                                          device=device)\n        # dense model\n        self.enable_return_routed_experts = False\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any,\n        state_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor | None = None,\n        mrope_position_ids: torch.Tensor | None = None,\n        pixel_values: torch.Tensor | None = None,\n        vis_cu_seqlens: torch.Tensor | None = None,\n        vis_pos_emb: torch.Tensor | None = None,\n        image_mask: torch.Tensor | None = None,\n        pos_embeds: torch.Tensor | None = None,\n        grid_thw: torch.Tensor | None = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        all_routed_experts = None\n        if self.enable_return_routed_experts:\n            config = self.config.text_config\n            num_tokens = input_ids.size(1)\n            all_routed_experts = position_ids.new_empty(\n                (num_tokens, 
config.num_hidden_layers, config.num_experts_per_tok), dtype=torch.uint16)\n\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            state_ids=state_ids,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n            pixel_values=pixel_values,\n            vis_cu_seqlens=vis_cu_seqlens,\n            vis_pos_emb=vis_pos_emb,\n            image_mask=image_mask,\n            pos_embeds=pos_embeds,\n            grid_thw=grid_thw,\n            all_routed_experts=all_routed_experts,\n        )\n        if all_routed_experts is None:\n            return hidden_states\n        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor | None = None,\n        context: StepContext | None = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # make past_key_values\n        state_caches = list(cache.transpose(0, 1) for cache in context.state_caches)\n        state_caches = list(zip(state_caches[0], state_caches[1]))\n        past_key_values = list(past_key_values)\n        new_past_key_values = []\n        for layer_type in self.config.text_config.layer_types:\n            if layer_type == 'linear_attention':\n                new_past_key_values.append(state_caches.pop(0))\n            elif layer_type == 'full_attention':\n                
new_past_key_values.append(past_key_values.pop(0))\n\n        # vlm inputs\n        pixel_values = None\n        vis_cu_seqlens = None\n        vis_pos_emb = None\n        image_mask = None\n        grid_thw = None\n        pos_embeds = None\n        if context.input_multimodals is not None:\n            mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            mm_inputs = [item for sublist in mm_inputs for item in sublist]\n\n            if len(mm_inputs) > 0:\n                modality = mm_inputs[0].modality\n                pixel_values = torch.cat([inp.data for inp in mm_inputs])\n\n                image_token_id = mm_inputs[0].meta.get('image_token_id')\n                video_token_id = mm_inputs[0].meta.get('video_token_id')\n                mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id\n                image_mask = (input_ids == mm_token_id)\n\n                grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu()\n                vis_pos_emb = self.model.visual.rot_pos_emb(grid_thw)\n                pos_embeds = self.model.visual.fast_pos_embed_interpolate(grid_thw)\n                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],\n                                                         grid_thw[:, 0]).to(pixel_values.device)\n                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)\n                vis_pos_emb = vis_pos_emb.repeat(1, 2)\n                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())\n\n        mrope_position_ids = getattr(context, 'mrope_position_ids', None)\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = 
self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=new_past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            state_ids=context.state_offsets,\n            # vl inputs\n            mrope_position_ids=mrope_position_ids,\n            pixel_values=pixel_values,\n            vis_cu_seqlens=vis_cu_seqlens,\n            vis_pos_emb=vis_pos_emb,\n            image_mask=image_mask,\n            grid_thw=grid_thw,\n            pos_embeds=pos_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        def __skip_layers(name):\n            \"\"\"We may reduce the number of layers so that the model can be debugged\n            with fewer GPUs.\"\"\"\n            import re\n            if '.layers.' 
not in name:\n                return False\n            matches = re.findall(r'\\.layers\\.(\\d+)\\.', name)\n            layer_id = int(matches[0])\n            return layer_id >= self.config.text_config.num_hidden_layers\n\n        # modified from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n            ('.in_proj_zba', '.in_proj_z', 'z'),\n            ('.in_proj_zba', '.in_proj_b', 'b'),\n            ('.in_proj_zba', '.in_proj_a', 'a'),\n        ]\n\n        rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n\n            if __skip_layers(name):\n                continue\n\n            if 'mtp.' in name:\n                continue\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '.qkv.' 
in name:\n                    # vl attention\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    for rms_norm_key in rms_norm_keys:\n                        if rms_norm_key in name and 'weight' in name:\n                            loaded_weight = loaded_weight + 1\n                            break\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n\n        max_batchs = graph_meta.max_batchs\n        device = graph_meta.device\n        max_tokens = graph_meta.max_tokens\n\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)\n\n        state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device)\n        input_buffers['state_ids'] = state_ids\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n        input_buffers = graph_meta.input_buffers\n        new_inputs = super().fill_buffers_cudagraph(graph_meta, *args, **kwargs)\n        state_ids = kwargs['state_ids']\n        input_buffers['state_ids'].fill_(-1)\n        input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids)\n        new_inputs['state_ids'] = input_buffers['state_ids']\n\n        input_ids = kwargs.get('input_ids')\n        num_tokens = 
input_ids.size(-1)\n        new_batch_size = graph_meta.max_batchs\n\n        is_decoding = graph_meta.is_decoding\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids\n            if is_decoding:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]\n            else:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']\n\n        return new_inputs\n\n    def _get_model_metas(self, context: StepContext):\n        \"\"\"Get model metas.\"\"\"\n        model_metas = context.model_metas\n        if model_metas is None:\n            batch_size = context.q_seqlens.numel()\n            return [dict(mrope_delta=0)] * batch_size\n        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]\n\n    def _update_model_meta_decoding(self, context: StepContext):\n        \"\"\"Update model meta for decoding.\"\"\"\n        model_metas = self._get_model_metas(context)\n        position_ids = context.position_ids\n\n        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]\n        mrope_deltas_cpu = torch.tensor(mrope_deltas, device='cpu')\n        if (mrope_deltas_cpu == mrope_deltas_cpu[0]).all():\n            mrope_deltas = position_ids.new_full((len(mrope_deltas), ), mrope_deltas[0])\n        else:\n            mrope_deltas = position_ids.new_tensor(mrope_deltas)\n        mrope_position_ids = position_ids + mrope_deltas[None]\n        mrope_position_ids = mrope_position_ids.expand(3, -1)\n\n        context.mrope_position_ids = mrope_position_ids\n        return model_metas\n\n    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):\n        \"\"\"Get mrope ids.\"\"\"\n        t, h, w = grid_thw\n        h //= 2\n        w //= 2\n        stride = torch.tensor([h * w, w, 1], 
device=device)[:, None]\n        size = torch.tensor([t, h, w], device=device)[:, None]\n        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)\n        pos_ids = pos_ids // stride % size\n        return pos_ids\n\n    def _update_model_meta_prefilling(self, context: StepContext):\n        \"\"\"Update model meta for prefilling.\"\"\"\n        model_metas = self._get_model_metas(context)\n        input_multimodals = context.input_multimodals\n        if input_multimodals is None:\n            input_multimodals = [None] * len(model_metas)\n        position_ids = context.position_ids\n        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())\n        mrope_position_ids = []\n        new_model_metas = []\n        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):\n            mm_data_list = []\n            if input_mm is not None:\n                mm_data_list.extend(input_mm.get('mm_data', []))\n\n            if model_meta is None or 'mrope_delta' not in model_meta:\n                mrope_delta = 0\n            else:\n                mrope_delta = model_meta['mrope_delta']\n\n            pos_start = pos_ids[0].item()\n            mrope_pos_ids = pos_ids + mrope_delta\n            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()\n\n            for mm_data in mm_data_list:\n                if mm_data.modality == Modality.IMAGE:\n                    grid_thw = mm_data.meta['grid_thw'][0].tolist()\n                    _, h, w = grid_thw\n                    h //= 2\n                    w //= 2\n                    num_pad = mm_data.end - mm_data.start - max(h, w)\n                    mrope_delta -= num_pad\n                    fill_start = mm_data.start - pos_start\n                    fill_end = mm_data.end - pos_start\n                    img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)\n                    img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 
1]\n                    mrope_pos_ids[:, fill_end:] -= num_pad\n                    mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids\n                elif mm_data.modality == Modality.VIDEO:\n                    video_token_id = self.config.video_token_id\n                    grid_thw = mm_data.meta['grid_thw']\n\n                    grid_thw = torch.repeat_interleave(grid_thw, grid_thw[:, 0], dim=0)\n                    grid_thw[:, 0] = 1\n\n                    position_ids_list = []\n                    input_tokens = context.input_ids.tolist()[0]\n\n                    st = 0\n                    # treat each frame separately as a single image\n                    for video_idx in range(grid_thw.shape[0]):\n                        # text before video. e.g. <0.3 seconds><|vision_start|> ...\n                        ed_video = input_tokens.index(video_token_id, st)\n                        ed = ed_video\n                        text_len = ed - st\n                        st_idx = position_ids_list[-1].max() + 1 if len(position_ids_list) > 0 else 0\n                        text_pos_ids = torch.arange(text_len, device=pos_ids.device).view(1, -1).expand(3, -1) + st_idx\n                        position_ids_list.append(text_pos_ids)\n\n                        # video frame. <video_pad> ... 
<|video_end|>\n                        t, h, w = (\n                            grid_thw[video_idx][0],\n                            grid_thw[video_idx][1] // 2,\n                            grid_thw[video_idx][2] // 2,\n                        )\n                        video_pos_ids = self._get_multimodal_pos_ids(grid_thw[video_idx], pos_ids.device)\n                        position_ids_list.append(video_pos_ids + text_len + st_idx)\n\n                        st = ed + t * h * w\n\n                    # text after video, <|vision_end|> ...\n                    if st < len(input_tokens):\n                        st_idx = position_ids_list[-1].max() + 1 if len(position_ids_list) > 0 else 0\n                        text_len = len(input_tokens) - st\n                        text_pos_ids = torch.arange(text_len, device=pos_ids.device).view(1, -1).expand(3, -1) + st_idx\n                        position_ids_list.append(text_pos_ids)\n\n                    mrope_pos_ids = torch.cat(position_ids_list, dim=1).reshape(3, -1)\n                    mrope_delta = mrope_pos_ids.max() + 1 - pos_ids.size(0)\n                    mrope_pos_ids += pos_start  # add back the original position offset\n\n            mrope_position_ids.append(mrope_pos_ids)\n            new_model_metas.append(dict(mrope_delta=mrope_delta))\n\n        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)\n        context.mrope_position_ids = mrope_position_ids\n\n        return new_model_metas\n\n    def update_model_metas(self, past_key_values: List[List[torch.Tensor]], inputs_embeds: torch.Tensor | None,\n                           context: StepContext):\n        \"\"\"Update model meta.\"\"\"\n        if context.is_decoding:\n            return self._update_model_meta_decoding(context)\n        else:\n            return self._update_model_meta_prefilling(context)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return 
self.input_processor\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen3_5_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Dict, Iterable, List, Tuple\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.distributed import get_dist_manager\nfrom lmdeploy.pytorch.model_inputs import StepContextManager\nfrom lmdeploy.pytorch.nn import RMSNorm\nfrom lmdeploy.pytorch.nn.moe import build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import add_prefix, get_build_model_context\nfrom .qwen3_5 import (Qwen3_5Attention, Qwen3_5DecoderLayer, Qwen3_5ForConditionalGeneration, Qwen3_5GatedDeltaNet,\n                      Qwen3_5MLP, Qwen3_5Model, Qwen3_5TextModel, Qwen3_5TextRotaryEmbedding)\nfrom .qwen3_5 import Qwen3_5VisionModel as Qwen3_5MoeVisionModel\nfrom .qwen3_vl import Qwen3VLInputProcessor as Qwen3_5MoeInputProcessor\n\n\nclass Qwen3_5MoeTopKRouter(nn.Module):\n\n    def __init__(self, config, dtype: torch.dtype | None = None, device: torch.device | None = None):\n        super().__init__()\n        self.top_k = config.num_experts_per_tok\n        self.num_experts = config.num_experts\n        self.hidden_dim = config.hidden_size\n        self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, dtype=dtype, device=device))\n\n    def forward(self, hidden_states):\n        hidden_states = hidden_states.reshape(-1, self.hidden_dim)\n        router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)\n        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)\n        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)\n        router_top_value /= router_top_value.sum(dim=-1, keepdim=True)\n        router_top_value = router_top_value.to(router_logits.dtype)\n        router_scores = 
router_top_value\n        return router_logits, router_scores, router_indices\n\n\nclass Qwen3_5MoeSparseMoeBlock(nn.Module):\n    \"\"\"Sparse MoE block.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.layer_idx = layer_idx\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n\n        self.gate = Qwen3_5MoeTopKRouter(config, dtype=dtype, device=device)\n\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            all_reduce=False,\n            layer_idx=layer_idx,\n            prefix=add_prefix('experts', prefix),\n        )\n\n        self.shared_expert = Qwen3_5MLP(\n            config=config,\n            intermediate_size=config.shared_expert_intermediate_size,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            all_reduce=False,\n            prefix=add_prefix('shared_expert', prefix),\n        )\n        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False, device=device, dtype=dtype)\n\n        # get all reduce\n        dist_ctx = get_dist_manager().current_context()\n        dp = dist_ctx.dist_config.dp\n        world_size = dist_ctx.dist_config.moe_tp\n        if dp == 1 and world_size > 1:\n            self._all_reduce = True\n        else:\n            self._all_reduce = False\n\n    def 
forward(self, hidden_states: torch.Tensor, all_routed_experts: torch.Tensor | None = None):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.reshape(-1, hidden_dim)\n        router_logits, topk_weights, topk_ids = self.gate(hidden_states)\n        if all_routed_experts is not None:\n            all_routed_experts[:, self.layer_idx, :] = topk_ids\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        shared_states = self.shared_expert(hidden_states)\n        shared_states = self.shared_expert_gate(hidden_states).sigmoid() * shared_states\n\n        out_states += shared_states\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n\n        if self._all_reduce:\n            dist.all_reduce(out_states)\n        return out_states\n\n\nclass Qwen3_5MoeDecoderLayer(Qwen3_5DecoderLayer):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        layer_idx: int,\n        dtype: torch.dtype | None = None,\n        device: torch.device | None = None,\n        prefix: str = '',\n    ):\n        nn.Module.__init__(self)\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.layer_type = config.layer_types[layer_idx]\n        if self.layer_type == 'linear_attention':\n            self.linear_attn = Qwen3_5GatedDeltaNet(config,\n                                                    layer_idx,\n                                                    dtype=dtype,\n                                                    device=device,\n                                                    prefix=add_prefix('linear_attn', prefix))\n        elif self.layer_type == 'full_attention':\n            self.self_attn = Qwen3_5Attention(config,\n               
                               layer_idx,\n                                              dtype=dtype,\n                                              device=device,\n                                              prefix=add_prefix('self_attn', prefix))\n\n        # build MLP\n        self.mlp = Qwen3_5MoeSparseMoeBlock(config,\n                                            layer_idx,\n                                            dtype=dtype,\n                                            device=device,\n                                            prefix=add_prefix('mlp', prefix))\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(\n            config.hidden_size,\n            config.rms_norm_eps,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('input_layernorm', prefix),\n        )\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n\nclass Qwen3_5MoeTextModel(Qwen3_5TextModel):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        nn.Module.__init__(self)\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        # TODO: use full config.num_hidden_layers\n        self.layers = nn.ModuleList([\n            Qwen3_5MoeDecoderLayer(config,\n                                   layer_idx,\n           
                        dtype=dtype,\n                                   device=device,\n                                   prefix=add_prefix(f'layers.{layer_idx}', prefix))\n            for layer_idx in range(self.config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = Qwen3_5TextRotaryEmbedding(config, device=device)\n\n\nclass Qwen3_5MoeModel(Qwen3_5Model):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        nn.Module.__init__(self)\n\n        self.visual = Qwen3_5MoeVisionModel(config.vision_config,\n                                            dtype=dtype,\n                                            device=device,\n                                            prefix=add_prefix('visual', prefix))\n        self.language_model = Qwen3_5MoeTextModel(config.text_config,\n                                                  dtype=dtype,\n                                                  device=device,\n                                                  prefix=add_prefix('language_model', prefix))\n\n\nclass Qwen3_5MoeForConditionalGeneration(Qwen3_5ForConditionalGeneration):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None,\n                 prefix: str = ''):\n        nn.Module.__init__(self)\n        
self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # build preprocessor\n        self.input_processor = Qwen3_5MoeInputProcessor(self.config)\n\n        # build model\n        self.model = Qwen3_5MoeModel(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.text_config.hidden_size,\n                                          config.text_config.vocab_size,\n                                          bias=False,\n                                          dtype=dtype,\n                                          device=device)\n        # for router replay\n        bm_ctx = get_build_model_context()\n        self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n        # This function is currently unused, but it shares its layout with the transformers\n        # implementation, so it is kept for now.\n        # load fused weights\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):\n        \"\"\"Load fused expert weights.\"\"\"\n        num_experts = self.config.text_config.num_experts\n        fused_gateup_name = 'gate_up_proj'\n        fused_down_name = 'down_proj'\n        if fused_gateup_name in name:\n\n            for 
expert_id in range(num_experts):\n                param_name = name.replace(f'experts.{fused_gateup_name}', 'experts.gate_up.weight')\n                param = params_dict[param_name]\n                weight = loaded_weight[expert_id]\n                w1, w3 = weight.chunk(2, 0)\n                load_weight(param, w1, expert_id=expert_id, shard_id='gate')\n                load_weight(param, w3, expert_id=expert_id, shard_id='up')\n\n        elif fused_down_name in name:\n\n            for expert_id in range(num_experts):\n                param_name = name.replace(f'experts.{fused_down_name}', 'experts.down.weight')\n                param = params_dict[param_name]\n                w2 = loaded_weight[expert_id]\n                load_weight(param, w2, expert_id=expert_id, shard_id='down')\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        def __skip_layers(name):\n            \"\"\"We may reduce the number of layers so that the model can be debugged\n            with fewer GPUs.\"\"\"\n            import re\n            if '.layers.' 
not in name:\n                return False\n            matches = re.findall(r'\\.layers\\.(\\d+)\\.', name)\n            layer_id = int(matches[0])\n            return layer_id >= self.config.text_config.num_hidden_layers\n\n        # modified from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n            ('.in_proj_zba', '.in_proj_z', 'z'),\n            ('.in_proj_zba', '.in_proj_b', 'b'),\n            ('.in_proj_zba', '.in_proj_a', 'a'),\n        ]\n\n        # expert map\n        num_experts = self.config.text_config.num_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n\n            if __skip_layers(name):\n                continue\n\n            if 'mtp.' 
in name:\n                continue\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            if '.experts' in name and '.shared_expert' not in name:\n                self._load_weight_fused_experts(name, loaded_weight, params_dict)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    if '.qkv.' in name:\n                        # vl attention\n                        param = params_dict[name]\n                        q, k, v = param.weight_spliter(loaded_weight)\n                        load_weight(param, q, shard_id='q')\n                        load_weight(param, k, shard_id='k')\n                        load_weight(param, v, shard_id='v')\n                    else:\n                        for rms_norm_key in rms_norm_keys:\n                            if rms_norm_key in name and 'weight' in name:\n                                loaded_weight = loaded_weight + 1\n                                break\n                        param = params_dict[name]\n                        load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen3_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.eplb import EPLBManager\nfrom lmdeploy.pytorch.nn.linear import build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import add_prefix, get_build_model_context\nfrom .utils.cudagraph import CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n\n\nclass Qwen3MoeAttention(nn.Module):\n    \"\"\"Rewrite module of Qwen3MoeAttention.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n\n        # packed qkv\n        # Qwen3 uses 'config.attention_bias = False' for q/k/o projections\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.attention_bias,\n            
quant_config=quantization_config,\n            num_replicate_kv_heads=num_replicate_kv_heads,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('qkv_proj', prefix),\n        )\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=getattr(config, 'sliding_window', None),\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(\n            num_heads * head_dim,\n            hidden_size,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            prefix=add_prefix('o_proj', prefix),\n        )\n\n        # q, k norm\n        self.q_norm = RMSNorm(\n            head_dim,\n            config.rms_norm_eps,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('q_norm', prefix),\n        )\n        self.k_norm = RMSNorm(\n            head_dim,\n            config.rms_norm_eps,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('k_norm', prefix),\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Tuple[torch.Tensor] | None = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply 
q, k norm\n        query_states = self.q_norm(query_states)\n        key_states = self.k_norm(key_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Qwen3MoeMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 intermediate_size: int = None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True,\n                 all_reduce: bool = True,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=is_tp,\n            prefix=add_prefix('gate_up_proj', prefix),\n        )\n\n        # silu and mul\n        self.act_fn = 
SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(\n            intermediate_size,\n            config.hidden_size,\n            bias=False,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n            is_tp=is_tp,\n            all_reduce=all_reduce,\n            prefix=add_prefix('down_proj', prefix),\n        )\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Qwen3MoeSparseMoeBlock(nn.Module):\n    \"\"\"Moe block.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.layer_idx = layer_idx\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n        self.renormalize = self.norm_topk_prob\n\n        self.gate = build_rowwise_linear(\n            self.hidden_dim,\n            self.num_experts,\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n        self.softmax_topk = SoftmaxTopK(\n            self.top_k,\n            n_groups=getattr(config, 'router_n_groups', -1),\n        )\n\n        if get_dist_manager().current_context().dist_config.enable_eplb:\n            dist_ctx = get_dist_manager().current_context()\n            self.eplb_dispatch_info = EPLBManager.get_dispatch_info(\n                ep_rank=dist_ctx.ep_rank,\n                layer_idx=layer_idx,\n            )\n            
self.num_experts = EPLBManager.num_physical_experts()\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=self.renormalize,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            all_reduce=True,\n            layer_idx=layer_idx,\n            prefix=add_prefix('experts', prefix),\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        all_routed_experts: torch.Tensor = None,\n    ):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        router_logits = self.gate(hidden_states)\n        topk_weights, topk_ids = self.softmax_topk(router_logits)\n        if all_routed_experts is not None:\n            all_routed_experts[:, self.layer_idx, :] = topk_ids\n        if get_dist_manager().current_context().dist_config.enable_eplb:\n            topk_ids = EPLBManager.topk_ids_logical_to_physical(topk_ids, self.eplb_dispatch_info)\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n        return out_states\n\n\nclass Qwen3MoeDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        layer_idx: int,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        prefix: str = '',\n    ):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = Qwen3MoeAttention(config, dtype=dtype, device=device, prefix=add_prefix('self_attn', prefix))\n\n        # build 
MLP\n        if (layer_idx not in config.mlp_only_layers) and (config.num_experts\n                                                          > 0) and ((layer_idx + 1) % config.decoder_sparse_step == 0):\n            self.mlp = Qwen3MoeSparseMoeBlock(config,\n                                              layer_idx=layer_idx,\n                                              dtype=dtype,\n                                              device=device,\n                                              prefix=add_prefix('mlp', prefix))\n        else:\n            self.mlp = Qwen3MoeMLP(config,\n                                   intermediate_size=config.intermediate_size,\n                                   dtype=dtype,\n                                   device=device,\n                                   prefix=add_prefix('mlp', prefix))\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device,\n                                       prefix=add_prefix('input_layernorm', prefix))\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                dtype=dtype,\n                                                device=device,\n                                                prefix=add_prefix('post_attention_layernorm', prefix))\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: List[torch.FloatTensor] | None,\n        residual: torch.Tensor | None = None,\n        attn_metadata: Any = None,\n        all_routed_experts: torch.Tensor = 
None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states, all_routed_experts=all_routed_experts)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Qwen3MoeModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.padding_idx = getattr(config, 'pad_token_id', None)\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = build_embedding(\n            config.vocab_size,\n            config.hidden_size,\n            self.padding_idx,\n            dtype=dtype,\n            device=device,\n        )\n\n        if get_dist_manager().current_context().dist_config.enable_eplb:\n            ep_size, _ = get_ep_world_rank()\n            EPLBManager.init_global_eplb_metadata(\n                ep_size=ep_size,\n                num_routed_experts=config.num_experts,\n                num_hidden_layers=config.num_hidden_layers,\n            )\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Qwen3MoeDecoderLayer(config,\n                                 layer_idx,\n                                 dtype=dtype,\n                                 device=device,\n               
                  prefix=add_prefix(f'layers.{layer_idx}', prefix))\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size,\n                            config.rms_norm_eps,\n                            dtype=dtype,\n                            device=device,\n                            prefix=add_prefix('norm', prefix))\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: torch.LongTensor | None = None,\n        past_key_values: List[torch.FloatTensor] | None = None,\n        attn_metadata: Any = None,\n        inputs_embeds: torch.FloatTensor | None = None,\n        all_routed_experts: torch.Tensor = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n                all_routed_experts=all_routed_experts,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return 
self.embed_tokens\n\n\nclass Qwen3MoeForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        ctx_mgr: StepContextManager,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        prefix: str = '',\n    ):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # build model\n        self.model = Qwen3MoeModel(config, dtype=dtype, device=device, prefix=add_prefix('model', prefix))\n        # build lm_head\n        self.lm_head = self.build_lm_head(\n            config.hidden_size,\n            config.vocab_size,\n            bias=False,\n            dtype=dtype,\n            device=device,\n        )\n        # for router replay\n        bm_ctx = get_build_model_context()\n        self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n\n        # router replay\n        all_routed_experts = None\n        if self.enable_return_routed_experts:\n            if inputs_embeds is not None:\n                num_tokens = inputs_embeds.size(1)\n            else:\n                num_tokens = input_ids.size(1)\n            all_routed_experts = position_ids.new_empty(\n                (num_tokens, self.config.num_hidden_layers, self.config.num_experts_per_tok), dtype=torch.uint16)\n\n        hidden_states = self.model(\n            
input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            all_routed_experts=all_routed_experts,\n        )\n        if all_routed_experts is None:\n            return hidden_states\n        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor | None = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n        # load fused weights\n        if any([k in name for k in ['fused_w1w3', 
'fused_w2']]):\n            return self._load_weight_fused_experts(name, loaded_weight, params_dict)\n\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):\n        \"\"\"Load weight of fused expert weights.\"\"\"\n        num_experts = self.config.num_experts\n        fused_gateup_name = 'fused_w1w3'\n        fused_down_name = 'fused_w2'\n        if fused_gateup_name in name:\n            chunk_size = loaded_weight.shape[0] // num_experts\n\n            for expert_id in range(num_experts):\n                param_name = name.replace(f'experts.{fused_gateup_name}', 'experts.gate_up')\n                param = params_dict[param_name]\n                w1 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size // 2)\n                w3 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id + chunk_size // 2, length=chunk_size // 2)\n                load_weight(param, w1, expert_id=expert_id, shard_id='gate')\n                load_weight(param, w3, expert_id=expert_id, shard_id='up')\n\n        elif fused_down_name in name:\n            chunk_size = loaded_weight.shape[0] // num_experts\n\n            for expert_id in range(num_experts):\n                param_name = name.replace(f'experts.{fused_down_name}', 'experts.down')\n                param = params_dict[param_name]\n                w2 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size)\n                load_weight(param, w2, 
expert_id=expert_id, shard_id='down')\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        # expert map\n        num_experts = self.config.num_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            name = name.replace('.block_sparse_moe.', '.mlp.')\n            if '.experts' in name:\n                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n              
      break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen3_next.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nimport lmdeploy.pytorch.distributed as dist\nimport lmdeploy.pytorch.nn.gated_delta as gated_delta_util\nfrom lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.gated_delta import CausalConv1d, GatedDelta, GatedDeltaMeta, build_rmsnorm_gated\nfrom lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_merged_colwise_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight\n\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, build_embedding\n\n\nclass Qwen3NextGatedDeltaNet(nn.Module):\n    \"\"\"Gated deltanet.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.num_v_heads = config.linear_num_value_heads\n        self.num_k_heads = config.linear_num_key_heads\n        self.head_k_dim = config.linear_key_head_dim\n        self.head_v_dim = config.linear_value_head_dim\n        self.key_dim = self.head_k_dim * self.num_k_heads\n        self.value_dim = self.head_v_dim * self.num_v_heads\n        self.kv_ratio = self.num_v_heads // self.num_k_heads\n\n        
self.conv_kernel_size = config.linear_conv_kernel_dim\n        self.layer_idx = layer_idx\n        self.activation = config.hidden_act\n        self.layer_norm_epsilon = config.rms_norm_eps\n\n        # QKV\n        self.conv_dim = self.key_dim * 2 + self.value_dim\n        self.conv1d = CausalConv1d(\n            in_channels=self.conv_dim,\n            out_channels=self.conv_dim,\n            kernel_size=self.conv_kernel_size,\n            split=[self.key_dim, self.key_dim, self.value_dim],\n            bias=False,\n            groups=self.conv_dim,\n            dtype=dtype,\n            device=device,\n        )\n\n        # projection of the input hidden states\n        projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2\n        projection_size_ba = self.num_v_heads * 2\n        self.in_proj_qkvz = build_colwise_linear(self.hidden_size,\n                                                 projection_size_qkvz,\n                                                 bias=False,\n                                                 dtype=dtype,\n                                                 device=device,\n                                                 is_tp=True)\n        self.in_proj_ba = build_colwise_linear(self.hidden_size,\n                                               projection_size_ba,\n                                               bias=False,\n                                               dtype=dtype,\n                                               device=device,\n                                               is_tp=True)\n\n        # time step projection (discretization)\n        # instantiate once and copy inv_dt in init_weights of PretrainedModel\n        self.make_params(self.num_v_heads, device=device)\n        self.A_log_exp = None\n\n        self.norm = build_rmsnorm_gated(self.head_v_dim,\n                                        eps=self.layer_norm_epsilon,\n                                        activation=self.activation,\n                
                        dtype=dtype,\n                                        device=device)\n        self.out_proj = build_o_proj(self.value_dim,\n                                     self.hidden_size,\n                                     bias=False,\n                                     dtype=dtype,\n                                     device=device,\n                                     is_tp=True)\n\n        self.gated_delta = GatedDelta()\n\n    def get_A_log_exp(self):\n        if self.A_log_exp is None:\n            self.A_log_exp = -self.A_log.float().exp()\n\n        return self.A_log_exp\n\n    def make_params(self, num_v_heads: int, device: torch.device | None):\n        tp, _ = get_tp_world_rank()\n        num_v_heads = num_v_heads // tp\n        A = torch.empty(num_v_heads, device=device)\n        dt_bias = torch.empty(num_v_heads, device=device)\n\n        self.register_parameter('A_log', nn.Parameter(torch.log(A)))\n        self.register_parameter('dt_bias', nn.Parameter(dt_bias))\n        self.A_log.weight_loader = self.weight_loader_a_dt\n        self.dt_bias.weight_loader = self.weight_loader_a_dt\n\n    def weight_loader_a_dt(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        tp, rank = get_tp_world_rank()\n        loaded_weight = loaded_weight.chunk(tp, dim=0)[rank]\n        default_weight_loader(param, loaded_weight)\n\n    def fix_query_key_value_ordering(self, mixed_qkvz: torch.Tensor, mixed_ba: torch.Tensor):\n        \"\"\"Derives `query`, `key` and `value` tensors from `mixed_qkvz` and\n        `mixed_ba`.\"\"\"\n        # qkvz\n        split_arg_list_qkvz = [\n            self.head_k_dim * 2,\n            (self.kv_ratio * self.head_v_dim),\n            (self.kv_ratio * self.head_v_dim),\n        ]\n        mixed_qkvz = mixed_qkvz.unflatten(-1, (-1, sum(split_arg_list_qkvz)))\n        qk, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=-1)\n        qk = qk.unflatten(-1, 
(2, self.head_k_dim))\n        qk = qk.transpose(-3, -2).flatten(-3, -1)\n        value = value.flatten(-2, -1)\n        mixed_qkv = torch.cat((qk, value), dim=-1)\n        # [..., ng, np/ng * hn] -> [..., np, hn]\n        z = z.reshape(*z.shape[:-2], -1, self.head_v_dim)\n\n        # chunk_ba\n        mixed_ba = mixed_ba.unflatten(-1, (-1, 2 * self.kv_ratio))\n        b, a = mixed_ba.chunk(2, -1)\n        # do sigmoid and float here to prevent contiguous kernel\n        b = b.sigmoid().flatten(-2, -1)\n        a = a.float().flatten(-2, -1)\n        return mixed_qkv, z, b, a\n\n    def _load_state(self, past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta):\n        \"\"\"Load states from cache.\"\"\"\n        return gated_delta_util.load_state(past_key_value=past_key_value, gated_delta_meta=gated_delta_meta)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        past_key_value: Tuple[torch.Tensor, torch.Tensor],\n        gated_delta_meta: GatedDeltaMeta,\n    ):\n        \"\"\"forward.\"\"\"\n\n        # load states\n        conv_state, recurrent_state = self._load_state(past_key_value, gated_delta_meta)\n\n        # inputs proj\n        projected_states_qkvz = self.in_proj_qkvz(hidden_states)\n        projected_states_ba = self.in_proj_ba(hidden_states)\n        mixed_qkv, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)\n\n        mixed_qkv, conv_state = self.conv1d(mixed_qkv, conv_state, gated_delta_meta=gated_delta_meta)\n\n        tp = (self.key_dim * 2 + self.value_dim) // mixed_qkv.size(-1)\n        query, key, value = torch.split(\n            mixed_qkv,\n            [\n                self.key_dim // tp,\n                self.key_dim // tp,\n                self.value_dim // tp,\n            ],\n            dim=-1,\n        )\n        query = query.unflatten(-1, (-1, self.head_k_dim))\n        key = key.unflatten(-1, (-1, self.head_k_dim))\n        value = 
value.unflatten(-1, (-1, self.head_v_dim))\n\n        beta = b\n        # get_A_log_exp() casts A_log to float before exp(); in fp16, A might otherwise be -inf\n        g = self.get_A_log_exp() * F.softplus(a + self.dt_bias)\n        if self.kv_ratio > 1:\n            query = query.repeat_interleave(self.kv_ratio, dim=-2)\n            key = key.repeat_interleave(self.kv_ratio, dim=-2)\n\n        core_attn_out, recurrent_state = self.gated_delta(\n            query,\n            key,\n            value,\n            g=g,\n            beta=beta,\n            recurrent_state=recurrent_state,\n            gated_delta_meta=gated_delta_meta,\n        )\n\n        z_shape_og = z.shape\n        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])\n        z = z.reshape(-1, z.shape[-1])\n        core_attn_out = self.norm(core_attn_out, z)\n        core_attn_out = core_attn_out.reshape(z_shape_og)\n        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)\n\n        output = self.out_proj(core_attn_out)\n        return output\n\n\nclass Qwen3NextAttention(nn.Module):\n    \"\"\"Rewrite module of Qwen3NextAttention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        self.head_dim = head_dim\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n\n        # packed qkv\n        # Qwen3 uses 'config.attention_bias = False' for q/k/o projections\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads * 2,\n            num_kv_heads=num_key_value_heads,\n            
head_size=head_dim,\n            bias=config.attention_bias,\n            quant_config=quantization_config,\n            num_replicate_kv_heads=num_replicate_kv_heads,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=config.attention_bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n        # q, k norm\n        self.q_norm = RMSNorm(head_dim,\n                              config.rms_norm_eps,\n                              quant_config=quantization_config,\n                              dtype=dtype,\n                              device=device)\n        self.k_norm = RMSNorm(head_dim,\n                              config.rms_norm_eps,\n                              quant_config=quantization_config,\n                              dtype=dtype,\n                              device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n        
query_states, gate = query_states.view(*query_states.shape[:-2], -1, 2 * self.head_dim).chunk(2, dim=-1)\n\n        # apply q, k norm\n        query_states = self.q_norm(query_states)\n        key_states = self.k_norm(key_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n        gate = gate.reshape(*hidden_states.shape[:-1], -1)\n        attn_output = attn_output * gate.sigmoid()\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Qwen3NextMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 intermediate_size: int = None,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 is_tp: bool = True,\n                 all_reduce: bool = True):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        if intermediate_size is None:\n            intermediate_size = config.intermediate_size\n        # gate up\n        self.gate_up_proj = build_merged_colwise_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            
device=device,\n            quant_config=quantization_config,\n            is_tp=is_tp,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_rowwise_linear(intermediate_size,\n                                              config.hidden_size,\n                                              bias=False,\n                                              quant_config=quantization_config,\n                                              dtype=dtype,\n                                              device=device,\n                                              is_tp=is_tp,\n                                              all_reduce=all_reduce)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass Qwen3NextSparseMoeBlock(nn.Module):\n    \"\"\"Moe block.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.layer_idx = layer_idx\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n        self.renormalize = self.norm_topk_prob\n\n        self.gate = build_rowwise_linear(\n            self.hidden_dim,\n            self.num_experts,\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n\n        self.softmax_topk = SoftmaxTopK(\n            self.top_k,\n            n_groups=getattr(config, 'router_n_groups', -1),\n        )\n\n        self.experts = 
build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=self.renormalize,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            all_reduce=False,\n            layer_idx=layer_idx,\n        )\n\n        self.shared_expert = Qwen3NextMLP(\n            config=config,\n            intermediate_size=config.shared_expert_intermediate_size,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            all_reduce=False,\n        )\n        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False, device=device, dtype=dtype)\n\n        # get all reduce\n        dist_ctx = get_dist_manager().current_context()\n        dp = dist_ctx.dist_config.dp\n        world_size = dist_ctx.dist_config.moe_tp\n        if dp == 1 and world_size > 1:\n            self._all_reduce = True\n        else:\n            self._all_reduce = False\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        router_logits = self.gate(hidden_states)\n        topk_weights, topk_ids = self.softmax_topk(router_logits)\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        shared_states = self.shared_expert(hidden_states)\n        shared_states = self.shared_expert_gate(hidden_states).sigmoid() * shared_states\n\n        out_states += shared_states\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n\n        if self._all_reduce:\n            dist.all_reduce(out_states)\n        return out_states\n\n\nclass Qwen3NextDecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: 
PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.layer_type = config.layer_types[layer_idx]\n        if self.layer_type == 'linear_attention':\n            self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx, dtype=dtype, device=device)\n        elif self.layer_type == 'full_attention':\n            self.self_attn = Qwen3NextAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        if (layer_idx not in config.mlp_only_layers) and (config.num_experts\n                                                          > 0) and ((layer_idx + 1) % config.decoder_sparse_step == 0):\n            self.mlp = Qwen3NextSparseMoeBlock(config, layer_idx=layer_idx, dtype=dtype, device=device)\n        else:\n            self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor],\n        attn_metadata: Any,\n        gated_delta_meta: GatedDeltaMeta,\n    ):\n\n        if residual is None:\n            
residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        if self.layer_type == 'linear_attention':\n            hidden_states = self.linear_attn(\n                hidden_states=hidden_states,\n                past_key_value=past_key_value,\n                gated_delta_meta=gated_delta_meta,\n            )\n        elif self.layer_type == 'full_attention':\n            hidden_states = self.self_attn(\n                hidden_states=hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                attn_metadata=attn_metadata,\n            )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Qwen3NextModel(nn.Module):\n    \"\"\"Qwen3 next model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = build_embedding(\n            config.vocab_size,\n            config.hidden_size,\n            self.padding_idx,\n            dtype=dtype,\n            device=device,\n        )\n\n        # build all decode layers\n        # TODO: use full config.num_hidden_layers\n        self.layers = nn.ModuleList([\n            Qwen3NextDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(self.config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary 
embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,\n        position_ids: torch.LongTensor,\n        past_key_values: List[torch.FloatTensor],\n        attn_metadata: Any,\n        state_ids: torch.Tensor,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # make seq_idx\n        gated_delta_meta = GatedDeltaMeta(hidden_states.size(1), self.config.linear_conv_kernel_dim, state_ids,\n                                          attn_metadata)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_values[idx],\n                residual=residual,\n                attn_metadata=attn_metadata,\n                gated_delta_meta=gated_delta_meta,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass Qwen3NextForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def 
__init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = Qwen3NextModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        state_ids: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            state_ids=state_ids,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # make past_key_values\n        state_caches = list(cache.transpose(0, 1) for cache in context.state_caches)\n        state_caches = list(zip(state_caches[0], state_caches[1]))\n        past_key_values = 
list(past_key_values)\n        new_past_key_values = []\n        for layer_type in self.config.layer_types:\n            if layer_type == 'linear_attention':\n                new_past_key_values.append(state_caches.pop(0))\n            elif layer_type == 'full_attention':\n                new_past_key_values.append(past_key_values.pop(0))\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=new_past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            state_ids=context.state_offsets,\n        )\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n        max_batchs = graph_meta.max_batchs\n        device = graph_meta.device\n\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        state_ids = torch.full((max_batchs, ), -1, dtype=torch.long, device=device)\n        input_buffers['state_ids'] = state_ids\n\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n        input_buffers = graph_meta.input_buffers\n\n        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        state_ids = kwargs['state_ids']\n        input_buffers['state_ids'].fill_(-1)\n        
input_buffers['state_ids'][:state_ids.size(0)].copy_(state_ids)\n        new_inputs['state_ids'] = input_buffers['state_ids']\n\n        return new_inputs\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n        # load fused weights\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n\n        def __skip_layers(name):\n            \"\"\"We might change the number of layers so we can debug the model\n            with fewer GPUs.\"\"\"\n            import re\n            if '.layers.' 
not in name:\n                return False\n            matches = re.findall(r'\\.layers\\.(\\d+)\\.', name)\n            layer_id = int(matches[0])\n            return layer_id >= self.config.num_hidden_layers\n\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        # expert map\n        num_experts = self.config.num_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        rms_norm_keys = ['model.norm', '.input_layernorm', '.post_attention_layernorm', '.q_norm', '.k_norm']\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n\n            if __skip_layers(name):\n                continue\n\n            if 'mtp.' 
in name:\n                continue\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            name = name.replace('.block_sparse_moe.', '.mlp.')\n            if '.experts' in name and '.shared_expert' not in name:\n                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    for rms_norm_key in rms_norm_keys:\n                        if rms_norm_key in name and 'weight' in name:\n                            loaded_weight = loaded_weight + 1\n                            break\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen3_vl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom functools import lru_cache\nfrom typing import Any, Dict, Iterable, List, Tuple\n\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update\n\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.multimodal.data_type import MultiModalData\nfrom lmdeploy.pytorch.nn import LayerNorm\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear\nfrom lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\nfrom lmdeploy.vl.constants import Modality\n\nfrom .patch import add_prefix\nfrom .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3VLVisionRotaryEmbedding\nfrom .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention\nfrom .qwen3 import Qwen3model\nfrom .utils.cudagraph import CudaGraphMeta, CudaGraphMixin\nfrom .utils.model import DeployModelMixinV1, vlm_model\n\n\nclass Qwen3VLTextRotaryEmbedding(nn.Module):\n    inv_freq: torch.Tensor  # fix linting for `register_buffer`\n\n    def __init__(self, config: PretrainedConfig, device=None):\n        super().__init__()\n        if hasattr(config, 'rope_scaling') and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get('rope_type', 'default')\n        else:\n            self.rope_type = 'default'\n\n        self._pack_for_trans5(config)\n\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, 
self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer('inv_freq', inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n        self.mrope_section = config.rope_scaling.get('mrope_section', [24, 20, 20])\n\n    def _pack_for_trans5(self, config):\n        if self.rope_type == 'default' and 'default' not in ROPE_INIT_FUNCTIONS:\n            # transformers 5 removed 'default' from ROPE_INIT_FUNCTIONS\n            self.rope_type = 'linear'\n            rope_parameters = get_rope_parameters(config)\n            if 'factor' not in rope_parameters:\n                rope_parameters['factor'] = 1.0\n\n    def apply_interleaved_mrope(self, freqs, mrope_section):\n        \"\"\"Apply interleaved MRoPE to 3D rotary embeddings.\n\n        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to\n        interleaved [THTHWHTHW...TT], preserving frequency continuity.\n        Args:\n            freqs: (3, bs, seq_len, head_dim // 2)\n            mrope_section: (3,)\n        Returns:\n            freqs_t: (bs, seq_len, head_dim // 2)\n        \"\"\"\n        freqs_t = freqs[0]  # just overwrite the first dimension T\n        for dim, offset in enumerate((1, 2), start=1):  # H, W\n            length = mrope_section[dim] * 3\n            idx = slice(offset, length, 3)\n            freqs_t[..., idx] = freqs[dim, ..., idx]\n        return freqs_t\n\n    @torch.no_grad()\n    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. 
dynamic rope)\n    def forward(self, x, position_ids):\n        # In contrast to other models, Qwen3VL has different position ids for the grids\n        # So we expand the inv_freq to shape (3, ...)\n        if position_ids.ndim == 2:\n            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)\n        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)\n        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != 'mps' else 'cpu'\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)\n            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass Qwen3VLTextModel(Qwen3model):\n    \"\"\"Text part of Qwen3VL.\n\n    not a pure text-only model, as DeepStack integrates visual features into the early hidden states.\n    \"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__(config=config, dtype=dtype, device=device, prefix=prefix)\n\n        # build rotary embedding\n        # TODO: zhouxinyu, add triton kernel for interleaved mrope\n        self.rotary_emb = Qwen3VLTextRotaryEmbedding(config, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: torch.LongTensor | None = None,\n        past_key_values: List[torch.FloatTensor] | None = None,\n        
attn_metadata: Any = None,\n        inputs_embeds: torch.FloatTensor | None = None,\n        mrope_position_ids: torch.LongTensor = None,\n        # args for deepstack\n        visual_pos_masks: torch.Tensor | None = None,\n        deepstack_visual_embeds: List[torch.Tensor] | None = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        if mrope_position_ids is None:\n            cos, sin = self.rotary_emb(hidden_states, position_ids)\n        else:\n            mrope_position_ids = mrope_position_ids.unsqueeze(1)\n            cos, sin = self.rotary_emb(hidden_states, mrope_position_ids)\n\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n            # add visual features to the hidden states of first several layers\n            if deepstack_visual_embeds is not None and idx in range(len(deepstack_visual_embeds)):\n                hidden_states = hidden_states + residual\n                hidden_states = self._deepstack_process(\n                    hidden_states,\n                    visual_pos_masks,\n                    deepstack_visual_embeds[idx],\n                )\n                residual = None\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: 
torch.Tensor,\n                           visual_embeds: torch.Tensor):\n        visual_pos_masks = visual_pos_masks.to(hidden_states.device)\n        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)\n        local = torch.zeros_like(hidden_states)\n        local.masked_scatter_(visual_pos_masks, visual_embeds)\n        hidden_states += local\n        return hidden_states\n\n\nclass Qwen3VLVisionPatchEmbed(nn.Module):\n\n    def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None:\n        super().__init__()\n        self.patch_size = config.patch_size\n        self.temporal_patch_size = config.temporal_patch_size\n        self.in_channels = config.in_channels\n        self.embed_dim = config.hidden_size\n\n        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]\n        self.proj = nn.Conv3d(self.in_channels,\n                              self.embed_dim,\n                              kernel_size=kernel_size,\n                              stride=kernel_size,\n                              bias=True,\n                              dtype=dtype,\n                              device=device)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        target_dtype = self.proj.weight.dtype\n        hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size,\n                                           self.patch_size)\n        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)\n        return hidden_states\n\n\nclass Qwen3VLVisionMLP(nn.Module):\n    \"\"\"Vision mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        from transformers.activations import ACT2FN\n        hidden_dim = 
config.hidden_size\n        intermediate_size = config.intermediate_size\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.linear_fc1 = build_colwise_linear(hidden_dim,\n                                               intermediate_size,\n                                               bias=True,\n                                               dtype=dtype,\n                                               device=device,\n                                               quant_config=quantization_config,\n                                               is_tp=True,\n                                               prefix=add_prefix('linear_fc1', prefix))\n\n        # gelu_pytorch_tanh\n        self.act = ACT2FN[config.hidden_act]\n\n        # down\n        self.linear_fc2 = build_rowwise_linear(\n            intermediate_size,\n            hidden_dim,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n            prefix=add_prefix('linear_fc2', prefix),\n        )\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        return self.linear_fc2(self.act(self.linear_fc1(x)))\n\n\nclass Qwen3VLVisionBlock(nn.Module):\n    \"\"\"Vision block.\"\"\"\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        layer_idx: int,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        prefix: str = '',\n    ):\n        super().__init__()\n        self.layer_idx = layer_idx\n        self.norm1 = LayerNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)\n        self.norm2 = LayerNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)\n\n        self.attn = Qwen3VLVisionAttention(config, dtype=dtype, device=device, prefix=add_prefix('attn', prefix))\n\n        self.mlp = Qwen3VLVisionMLP(config, dtype=dtype, device=device, prefix=add_prefix('mlp', prefix))\n\n   
 def forward(self,\n                hidden_states: torch.Tensor,\n                cu_seqlens: torch.Tensor,\n                rotary_pos_emb: torch.Tensor | None = None) -> torch.Tensor:\n        hidden_states = hidden_states + self.attn(\n            self.norm1(hidden_states),\n            cu_seqlens=cu_seqlens,\n            rotary_pos_emb=rotary_pos_emb,\n        )\n        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))\n        return hidden_states\n\n\nclass Qwen3VLVisionPatchMerger(nn.Module):\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 use_postshuffle_norm=False,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = '') -> None:\n        super().__init__()\n        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)\n        self.use_postshuffle_norm = use_postshuffle_norm\n        self.norm = LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size,\n                              eps=1e-6,\n                              dtype=dtype,\n                              device=device)\n        self.linear_fc1 = build_colwise_linear(\n            self.hidden_size,\n            self.hidden_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            prefix=add_prefix('linear_fc1', prefix),\n        )\n        self.act_fn = nn.GELU()\n        self.linear_fc2 = build_rowwise_linear(\n            self.hidden_size,\n            config.out_hidden_size,\n            bias=True,\n            dtype=dtype,\n            device=device,\n            is_tp=True,\n            prefix=add_prefix('linear_fc2', prefix),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)\n        x = 
self.linear_fc2(self.act_fn(self.linear_fc1(x)))\n        return x\n\n\n@vlm_model\nclass Qwen3VLVisionModel(nn.Module):\n    \"\"\"Vision transformer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__()\n        self.config = config\n        self.spatial_merge_size = config.spatial_merge_size\n\n        self.patch_embed = Qwen3VLVisionPatchEmbed(config=config, dtype=dtype, device=device)\n\n        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device)\n        self.num_grid_per_side = int(config.num_position_embeddings**0.5)\n\n        head_dim = config.hidden_size // config.num_heads\n        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2, device=device)\n\n        self.blocks = nn.ModuleList([\n            Qwen3VLVisionBlock(config,\n                               layer_idx,\n                               dtype=dtype,\n                               device=device,\n                               prefix=add_prefix(f'blocks.{layer_idx}', prefix)) for layer_idx in range(config.depth)\n        ])\n        self.merger = Qwen3VLVisionPatchMerger(config=config,\n                                               use_postshuffle_norm=False,\n                                               dtype=dtype,\n                                               device=device,\n                                               prefix=add_prefix('merger', prefix))\n\n        if hasattr(config, 'deepstack_visual_indexes'):\n            self.deepstack_visual_indexes = config.deepstack_visual_indexes\n            self.deepstack_merger_list = nn.ModuleList([\n                Qwen3VLVisionPatchMerger(config=config,\n                                         use_postshuffle_norm=True,\n                                         dtype=dtype,\n         
                                device=device,\n                                         prefix=add_prefix(f'deepstack_merger_list.{dvi}', prefix))\n                for dvi in range(len(config.deepstack_visual_indexes))\n            ])\n\n    @staticmethod\n    @lru_cache(maxsize=1024)\n    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:\n        h_div = h // spatial_merge_size\n        w_div = w // spatial_merge_size\n\n        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))\n        hpos_ids = hpos_ids.reshape(\n            h_div,\n            spatial_merge_size,\n            w_div,\n            spatial_merge_size,\n        )\n        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)\n        hpos_ids = hpos_ids.flatten()\n\n        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))\n        wpos_ids = wpos_ids.reshape(\n            h_div,\n            spatial_merge_size,\n            w_div,\n            spatial_merge_size,\n        )\n        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)\n        wpos_ids = wpos_ids.flatten()\n\n        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))\n\n    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:\n        \"\"\"Rotary position embedding.\"\"\"\n        pos_ids = []\n\n        for t, h, w in grid_thw:\n            base = self.rot_pos_ids(int(h), int(w), self.spatial_merge_size)\n            pos_ids.append(base if t == 1 else base.repeat(t, 1))\n\n        pos_ids = torch.cat(pos_ids, dim=0)\n        max_grid_size = grid_thw[:, 1:].max()\n        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n\n        return rotary_pos_emb\n\n    # copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474\n    def fast_pos_embed_interpolate(self, grid_thw: List[List[int]]) -> torch.Tensor:\n        num_grid_per_side = self.num_grid_per_side\n        
m_size = self.spatial_merge_size\n        hidden_dim = self.pos_embed.embedding_dim\n        device = self.pos_embed.weight.device\n\n        outputs = []\n        for t, h, w in grid_thw:\n            h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device)\n            w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device)\n\n            h_floor = h_idxs.to(torch.long)\n            w_floor = w_idxs.to(torch.long)\n            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)\n            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)\n\n            dh = h_idxs - h_floor\n            dw = w_idxs - w_floor\n\n            # Create meshgrid view for all h, w vars\n            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')\n            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij')\n            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij')\n\n            # original computation of weights\n            # w00 = (1 - dh_grid) * (1 - dw_grid)\n            # w01 = (1 - dh_grid) * dw_grid\n            # w10 = dh_grid * (1 - dw_grid)\n            # w11 = dh_grid * dw_grid\n            # we reuse w11 here to avoid duplicate\n            # dh_grid * dw_grid computation\n            w11 = dh_grid * dw_grid\n            w10 = dh_grid - w11\n            w01 = dw_grid - w11\n            w00 = 1 - dh_grid - w01\n\n            h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])\n            w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])\n            h_grid_idx = h_grid * num_grid_per_side\n\n            indices = (h_grid_idx + w_grid).reshape(4, -1)\n            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)\n            weights = weights.to(dtype=self.pos_embed.weight.dtype, device=device)\n\n            embeds = self.pos_embed(indices)\n            
embeds *= weights\n            combined = embeds.sum(dim=0)\n\n            combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)\n            combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)\n            repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)\n            outputs.append(repeated)\n\n        return torch.cat(outputs, dim=0)\n\n    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor,\n                pos_embeds: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.patch_embed(hidden_states)\n        hidden_states = hidden_states + pos_embeds\n        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)\n\n        deepstack_feature_lists = []\n        for layer_num, blk in enumerate(self.blocks):\n            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)\n            if hasattr(self, 'deepstack_visual_indexes') and layer_num in self.deepstack_visual_indexes:\n                deepstack_merge_idx = self.deepstack_visual_indexes.index(layer_num)\n                deepstack_feature = self.deepstack_merger_list[deepstack_merge_idx](hidden_states)\n                deepstack_feature_lists.append(deepstack_feature)\n\n        hidden_states = self.merger(hidden_states)\n\n        return hidden_states, deepstack_feature_lists\n\n\nclass Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        ctx_mgr: StepContextManager,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        
prefix: str = '',\n    ):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n\n        # build preprocessor\n        self.input_processor = Qwen3VLInputProcessor(self.config)\n\n        # build vision model\n        self.visual = Qwen3VLVisionModel(\n            config.vision_config,\n            dtype=dtype,\n            device=device,\n            prefix=add_prefix('visual', prefix),\n        )\n\n        # build text model\n        self.language_model = Qwen3VLTextModel(config.text_config,\n                                               dtype=dtype,\n                                               device=device,\n                                               prefix=add_prefix('language_model', prefix))\n\n        # build lm_head\n        self.lm_head = self.build_lm_head(config.text_config.hidden_size,\n                                          config.text_config.vocab_size,\n                                          bias=False,\n                                          dtype=dtype,\n                                          device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        mrope_position_ids: torch.Tensor = None,\n        pixel_values: torch.Tensor = None,\n        vis_cu_seqlens: torch.Tensor = None,\n        vis_pos_emb: torch.Tensor = None,\n        image_mask: torch.Tensor = None,\n        pos_embeds: torch.Tensor = None,\n        grid_thw: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n\n        visual_pos_masks = None\n        deepstack_visual_embeds = None\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(input_ids)\n\n            if pixel_values is not None:\n                dtype = inputs_embeds.dtype\n        
        pixel_values = pixel_values.to(dtype)\n                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))\n\n                # get image embeds and deepstack visual embeds\n                image_embeds, deepstack_visual_embeds = self.visual(pixel_values,\n                                                                    cu_seqlens=vis_cu_seqlens,\n                                                                    rotary_pos_emb=vis_pos_emb,\n                                                                    pos_embeds=pos_embeds)\n\n                # split image embeds per sample\n                split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()\n                image_embeds = torch.split(image_embeds, split_sizes)\n                image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype)\n\n                # mask and scatter to create final input embeddings\n                expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)\n                inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds)\n\n                visual_pos_masks = expanded_image_mask\n\n        hidden_states = self.language_model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n            # args for deepstack\n            visual_pos_masks=visual_pos_masks,\n            deepstack_visual_embeds=deepstack_visual_embeds,\n        )\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.language_model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: torch.Tensor | None = None,\n        
context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        pixel_values = None\n        vis_cu_seqlens = None\n        vis_pos_emb = None\n        image_mask = None\n        grid_thw = None\n        pos_embeds = None\n        if context.input_multimodals is not None:\n            mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals]\n            # flatten batch\n            mm_inputs = [item for sublist in mm_inputs for item in sublist]\n\n            if len(mm_inputs) > 0:\n                modality = mm_inputs[0].modality\n                pixel_values = torch.cat([inp.data for inp in mm_inputs])\n\n                image_token_id = mm_inputs[0].meta.get('image_token_id')\n                video_token_id = mm_inputs[0].meta.get('video_token_id')\n                mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id\n                image_mask = (input_ids == mm_token_id)\n\n                grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu()\n                vis_pos_emb = self.visual.rot_pos_emb(grid_thw)\n                pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw)\n                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],\n                                                         grid_thw[:, 0]).to(pixel_values.device)\n                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)\n                vis_pos_emb = vis_pos_emb.repeat(1, 2)\n                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())\n\n        mrope_position_ids = getattr(context, 'mrope_position_ids', None)\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = 
context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n            mrope_position_ids=mrope_position_ids,\n            pixel_values=pixel_values,\n            vis_cu_seqlens=vis_cu_seqlens,\n            vis_pos_emb=vis_pos_emb,\n            image_mask=image_mask,\n            grid_thw=grid_thw,\n            pos_embeds=pos_embeds,\n        )\n\n    @classmethod\n    def rename_weight(cls, name: str) -> str:\n        \"\"\"Rename weight.\"\"\"\n        if name.startswith('model.language_model.'):\n            return 'language_model.' + name[len('model.language_model.'):]\n        elif name.startswith('model.visual.'):\n            return 'visual.' 
+ name[len('model.visual.'):]\n        elif name.startswith('model.'):\n            return name[len('model.'):]\n        return name\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                if '.qkv.' 
in name:\n                    param = params_dict[name]\n                    q, k, v = param.weight_spliter(loaded_weight)\n                    load_weight(param, q, shard_id='q')\n                    load_weight(param, k, shard_id='k')\n                    load_weight(param, v, shard_id='v')\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n        max_tokens = graph_meta.max_tokens\n\n        input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)\n\n        return input_buffers\n\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n\n        new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)\n\n        input_ids = kwargs.get('input_ids')\n        num_tokens = input_ids.size(-1)\n        new_batch_size = graph_meta.max_batchs\n\n        is_decoding = graph_meta.is_decoding\n        input_buffers = graph_meta.input_buffers\n        mrope_position_ids = kwargs.get('mrope_position_ids', None)\n        if mrope_position_ids is not None:\n            input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids\n            if is_decoding:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]\n            else:\n                new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']\n\n        return new_inputs\n\n    def _get_model_metas(self, context: StepContext):\n        \"\"\"Get model metas.\"\"\"\n        model_metas = 
context.model_metas\n        if model_metas is None:\n            batch_size = context.q_seqlens.numel()\n            return [dict(mrope_delta=0)] * batch_size\n        return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]\n\n    def _update_model_meta_decoding(self, context: StepContext):\n        \"\"\"Update model meta for decoding.\"\"\"\n        model_metas = self._get_model_metas(context)\n        position_ids = context.position_ids\n\n        mrope_deltas = [meta['mrope_delta'] for meta in model_metas]\n        mrope_deltas = position_ids.new_tensor(mrope_deltas)\n        mrope_position_ids = position_ids + mrope_deltas[None]\n        mrope_position_ids = mrope_position_ids.expand(3, -1)\n\n        context.mrope_position_ids = mrope_position_ids\n        return model_metas\n\n    def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):\n        \"\"\"Get mrope ids.\"\"\"\n        t, h, w = grid_thw\n        h //= 2\n        w //= 2\n        stride = torch.tensor([h * w, w, 1], device=device)[:, None]\n        size = torch.tensor([t, h, w], device=device)[:, None]\n        pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)\n        pos_ids = pos_ids // stride % size\n        return pos_ids\n\n    def _update_model_meta_prefilling(self, context: StepContext):\n        \"\"\"Update model meta for prefilling.\"\"\"\n        model_metas = self._get_model_metas(context)\n        input_multimodals = context.input_multimodals\n        if input_multimodals is None:\n            input_multimodals = [None] * len(model_metas)\n        position_ids = context.position_ids\n        batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())\n        mrope_position_ids = []\n        new_model_metas = []\n        for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):\n            mm_data_list = []\n            if input_mm is not None:\n                
mm_data_list.extend(input_mm.get('mm_data', []))\n\n            if model_meta is None or 'mrope_delta' not in model_meta:\n                mrope_delta = 0\n            else:\n                mrope_delta = model_meta['mrope_delta']\n\n            pos_start = pos_ids[0].item()\n            mrope_pos_ids = pos_ids + mrope_delta\n            mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()\n\n            for mm_data in mm_data_list:\n                if mm_data.modality == Modality.IMAGE:\n                    grid_thw = mm_data.meta['grid_thw'][0].tolist()\n                    _, h, w = grid_thw\n                    h //= 2\n                    w //= 2\n                    num_pad = mm_data.end - mm_data.start - max(h, w)\n                    mrope_delta -= num_pad\n                    fill_start = mm_data.start - pos_start\n                    fill_end = mm_data.end - pos_start\n                    img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)\n                    img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]\n                    mrope_pos_ids[:, fill_end:] -= num_pad\n                    mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids\n                elif mm_data.modality == Modality.VIDEO:\n                    video_token_id = self.config.video_token_id\n                    grid_thw = mm_data.meta['grid_thw']\n\n                    grid_thw = torch.repeat_interleave(grid_thw, grid_thw[:, 0], dim=0)\n                    grid_thw[:, 0] = 1\n\n                    position_ids_list = []\n                    input_tokens = context.input_ids.tolist()[0]\n\n                    st = 0\n                    # treat each frame separately as a single image\n                    for video_idx in range(grid_thw.shape[0]):\n                        # text before video. e.g. 
<0.3 seconds><|vision_start|> ...\n                        ed_video = input_tokens.index(video_token_id, st)\n                        ed = ed_video\n                        text_len = ed - st\n                        st_idx = position_ids_list[-1].max() + 1 if len(position_ids_list) > 0 else 0\n                        text_pos_ids = torch.arange(text_len, device=pos_ids.device).view(1, -1).expand(3, -1) + st_idx\n                        position_ids_list.append(text_pos_ids)\n\n                        # video frame. <video_pad> ... <|video_end|>\n                        t, h, w = (\n                            grid_thw[video_idx][0],\n                            grid_thw[video_idx][1] // 2,\n                            grid_thw[video_idx][2] // 2,\n                        )\n                        video_pos_ids = self._get_multimodal_pos_ids(grid_thw[video_idx], pos_ids.device)\n                        position_ids_list.append(video_pos_ids + text_len + st_idx)\n\n                        st = ed + t * h * w\n\n                    # text after video, <|vision_end|> ...\n                    if st < len(input_tokens):\n                        st_idx = position_ids_list[-1].max() + 1 if len(position_ids_list) > 0 else 0\n                        text_len = len(input_tokens) - st\n                        text_pos_ids = torch.arange(text_len, device=pos_ids.device).view(1, -1).expand(3, -1) + st_idx\n                        position_ids_list.append(text_pos_ids)\n\n                    mrope_pos_ids = torch.cat(position_ids_list, dim=1).reshape(3, -1)\n                    mrope_delta = mrope_pos_ids.max() + 1 - pos_ids.size(0)\n                    mrope_pos_ids += pos_start  # add back the original position offset\n\n            mrope_position_ids.append(mrope_pos_ids)\n            new_model_metas.append(dict(mrope_delta=mrope_delta))\n\n        mrope_position_ids = torch.cat(mrope_position_ids, dim=1)\n        context.mrope_position_ids = mrope_position_ids\n\n        
return new_model_metas\n\n    def update_model_metas(self,\n                           past_key_values: List[List[torch.Tensor]],\n                           inputs_embeds: torch.Tensor | None = None,\n                           context: StepContext = None):\n        \"\"\"Update model meta.\"\"\"\n        if context.is_decoding:\n            return self._update_model_meta_decoding(context)\n        else:\n            return self._update_model_meta_prefilling(context)\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return self.input_processor\n\n\nclass Qwen3VLInputProcessor(BaseModelInputProcessor):\n    \"\"\"Qwen3 input processor.\"\"\"\n\n    def __init__(self, config: PretrainedConfig) -> None:\n        self.config = config\n\n    def _make_image_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:\n        \"\"\"Make image MultiModalData.\"\"\"\n        pixel_values = input_mm['pixel_values']\n        image_grid_thw = input_mm['image_grid_thw']\n        offset = input_mm['offset']\n        start = offset\n        image_token_id = input_mm['image_token_id']\n        num_pad = input_mm['image_tokens']\n        if isinstance(num_pad, torch.Tensor):\n            num_pad = num_pad.item()\n\n        mm_data = MultiModalData(modality=Modality.IMAGE,\n                                 data=pixel_values,\n                                 start=start,\n                                 end=start + num_pad,\n                                 meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))\n        return mm_data\n\n    def _make_video_mm_data(self, input_mm: Dict[str, Any]) -> MultiModalData:\n        \"\"\"Make video MultiModalData.\"\"\"\n        pixel_values_videos = input_mm['pixel_values_videos']\n        video_grid_thw = input_mm['video_grid_thw']\n        offset = input_mm['offset']\n        start = offset\n        video_token_id = input_mm['video_token_id']\n        num_pad 
= input_mm['video_tokens']\n        if isinstance(num_pad, torch.Tensor):\n            num_pad = num_pad.item()\n\n        mm_data = MultiModalData(modality=Modality.VIDEO,\n                                 data=pixel_values_videos,\n                                 start=start,\n                                 end=start + num_pad,\n                                 meta=dict(\n                                     grid_thw=video_grid_thw,\n                                     video_token_id=video_token_id,\n                                 ))\n        return mm_data\n\n    def preprocess_input(self,\n                         input_ids: List[int],\n                         input_multimodals: List[Dict[str, Any]] = None,\n                         **kwargs) -> PreprocessInputResult:\n        \"\"\"Prepare multimodal input.\"\"\"\n        if input_multimodals is None or len(input_multimodals) == 0:\n            return input_ids, input_multimodals\n\n        input_mm_data = []\n        for input_mm in input_multimodals:\n            modality = input_mm.get('modality')\n            if modality == Modality.IMAGE:\n                mm_data = self._make_image_mm_data(input_mm)\n            elif modality == Modality.VIDEO:\n                mm_data = self._make_video_mm_data(input_mm)\n            input_mm_data.append(mm_data)\n\n        result = PreprocessInputResult(input_ids=input_ids, input_multimodals=dict(mm_data=input_mm_data))\n\n        return result\n"
  },
  {
    "path": "lmdeploy/pytorch/models/qwen3_vl_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContextManager\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .patch import add_prefix\nfrom .qwen3_moe import Qwen3MoeModel\nfrom .qwen3_vl import Qwen3VLForConditionalGeneration\nfrom .qwen3_vl import Qwen3VLTextRotaryEmbedding as Qwen3VLMoeTextRotaryEmbedding\n\n\nclass Qwen3VLMoeTextModel(Qwen3MoeModel):\n    \"\"\"Text part of Qwen3VL.\n\n    not a pure text-only model, as DeepStack integrates visual features into the early hidden states.\n    \"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 prefix: str = ''):\n        super().__init__(config=config, dtype=dtype, device=device, prefix=prefix)\n\n        # build rotary embedding\n        # TODO: zhouxinyu, add triton kernel for interleaved mrope\n        self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config, device=device)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        mrope_position_ids: torch.LongTensor = None,\n        # args for deepstack\n        visual_pos_masks: Optional[torch.Tensor] = None,\n        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # 
rotary embedding\n        if mrope_position_ids is None:\n            cos, sin = self.rotary_emb(hidden_states, position_ids)\n        else:\n            mrope_position_ids = mrope_position_ids.unsqueeze(1)\n            cos, sin = self.rotary_emb(hidden_states, mrope_position_ids)\n\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n            # add visual features to the hidden states of first several layers\n            if deepstack_visual_embeds is not None and idx in range(len(deepstack_visual_embeds)):\n                hidden_states = hidden_states + residual\n                hidden_states = self._deepstack_process(\n                    hidden_states,\n                    visual_pos_masks,\n                    deepstack_visual_embeds[idx],\n                )\n                residual = None\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor,\n                           visual_embeds: torch.Tensor):\n        visual_pos_masks = visual_pos_masks.to(hidden_states.device)\n        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)\n        local = torch.zeros_like(hidden_states)\n        local.masked_scatter_(visual_pos_masks, visual_embeds)\n        hidden_states += local\n        return hidden_states\n\n\nclass Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    
packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        ctx_mgr: StepContextManager,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        prefix: str = '',\n    ):\n        super().__init__(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device, prefix=prefix)\n\n        self.language_model = Qwen3VLMoeTextModel(config.text_config,\n                                                  dtype=dtype,\n                                                  device=device,\n                                                  prefix=add_prefix('language_model', prefix))\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    # modify from vllm qwen3vlmoe fused expert loading\n    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                                   fused_expert_params_mapping: List):\n        \"\"\"Load weight of fused expert weights.\"\"\"\n        num_experts = self.config.text_config.num_experts\n\n        for (param_name, weight_name) in fused_expert_params_mapping:\n            if weight_name not in 
name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n\n            loaded_weight = loaded_weight.transpose(-1, -2)  # no bias\n            if 'gate_up' in name:\n                loaded_weight = loaded_weight.chunk(2, dim=-2)\n                w1 = loaded_weight[0]\n                w3 = loaded_weight[1]\n                for expert_id in range(num_experts):\n                    load_weight(param, w1[expert_id], expert_id=expert_id, shard_id='gate')\n                    load_weight(param, w3[expert_id], expert_id=expert_id, shard_id='up')\n            elif 'down' in name:\n                w2 = loaded_weight\n                for expert_id in range(num_experts):\n                    load_weight(param, w2[expert_id], expert_id=expert_id, shard_id='down')\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        # expert mapping\n        num_experts = self.config.text_config.num_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            # (param_name, weight_name, expert_id, shard_id)\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        # fused expert mapping\n        fused_expert_params_mapping = [\n            # (param_name, weight_name)\n       
     ('.experts.gate_up.weight', '.experts.gate_up_proj'),\n            ('.experts.down.weight', '.experts.down_proj'),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            name = name.replace('.block_sparse_moe.', '.mlp.')\n            if '.experts' in name:\n                is_fused_expert = ('experts.gate_up_proj' in name or 'experts.down_proj' in name)\n                if is_fused_expert:\n                    self._load_weight_fused_experts(name,\n                                                    loaded_weight,\n                                                    params_dict,\n                                                    fused_expert_params_mapping=fused_expert_params_mapping)\n                else:\n                    self._load_weight_experts(name,\n                                              loaded_weight,\n                                              params_dict,\n                                              expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    if '.qkv.' 
in name:\n                        param = params_dict[name]\n                        q, k, v = param.weight_spliter(loaded_weight)\n                        load_weight(param, q, shard_id='q')\n                        load_weight(param, k, shard_id='k')\n                        load_weight(param, v, shard_id='v')\n                    else:\n                        param = params_dict[name]\n                        load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/sdar.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass SDARAttention(nn.Module):\n    \"\"\"attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        # Qwen3 uses 'config.attention_bias = False' for q/k/o projections\n        self.qkv_proj = build_qkv_proj(hidden_size,\n                                       num_q_heads=num_heads,\n                                       num_kv_heads=num_key_value_heads,\n                                       head_size=head_dim,\n                                       bias=config.attention_bias,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device,\n                                       num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # 
rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n        dllm_block_length = config.dllm_block_length\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=config.sliding_window,\n            block_sparse_size=dllm_block_length,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=config.attention_bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n        # q, k norm\n        self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)\n        self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply q, k norm\n        query_states = self.q_norm(query_states)\n        key_states = self.k_norm(key_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n        )\n        # attention\n        attn_output = self.attn_fwd(\n 
           query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass SDARMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [config.intermediate_size, config.intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return self.down_proj(act)\n\n\nclass SDARDecoderLayer(nn.Module):\n    \"\"\"Decode layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n           
      layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = SDARAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = SDARMLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n      
  hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass SDARModel(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            SDARDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, 
decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass SDARForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length\n        # build model\n        self.model = SDARModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        
\"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        if self.config.tie_word_embeddings:\n            self.lm_head.weight = self.model.embed_tokens.weight\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: 
Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/sdar_moe.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,\n                                        build_rowwise_linear)\nfrom lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass SDARMoeAttention(nn.Module):\n    \"\"\"attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n        num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)\n        # packed qkv\n        # Qwen3 uses 'config.attention_bias = False' for q/k/o projections\n        self.qkv_proj = build_qkv_proj(hidden_size,\n                                       num_q_heads=num_heads,\n                                       num_kv_heads=num_key_value_heads,\n                                       head_size=head_dim,\n                                       bias=config.attention_bias,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device,\n                          
             num_replicate_kv_heads=num_replicate_kv_heads)\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n        dllm_block_length = config.dllm_block_length\n\n        # attention\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=config.sliding_window,\n            block_sparse_size=dllm_block_length,\n        )\n\n        # o_proj\n        self.o_proj = build_o_proj(num_heads * head_dim,\n                                   hidden_size,\n                                   bias=config.attention_bias,\n                                   quant_config=quantization_config,\n                                   dtype=dtype,\n                                   device=device,\n                                   is_tp=True)\n\n        # q, k norm\n        self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)\n        self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply q, k norm\n        query_states = self.q_norm(query_states)\n        key_states = self.k_norm(key_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            
sin,\n        )\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass SDARMoeMLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 intermediate_size: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        # gate up\n        self.gate_up_proj = build_gateup_linear(\n            config.hidden_size,\n            [intermediate_size, intermediate_size],\n            bias=False,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # silu and mul\n        self.act_fn = SiluAndMul(inplace=True)\n\n        # down\n        self.down_proj = build_down_linear(intermediate_size,\n                                           config.hidden_size,\n                                           bias=False,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        gate_up = self.gate_up_proj(x)\n        act = self.act_fn(gate_up)\n        return 
self.down_proj(act)\n\n\nclass SDARMoeSparseMoeBlock(nn.Module):\n    \"\"\"SDARMoeSparseMoeBlock.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_intermediate_size\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.renormalize = config.norm_topk_prob\n\n        self.gate = build_rowwise_linear(\n            self.hidden_dim,\n            self.num_experts,\n            bias=False,\n            dtype=dtype,\n            device=device,\n            is_tp=False,\n        )\n\n        self.softmax_topk = SoftmaxTopK(\n            self.top_k,\n            n_groups=getattr(config, 'router_n_groups', -1),\n        )\n\n        self.experts = build_fused_moe(\n            self.hidden_dim,\n            self.ffn_dim,\n            self.num_experts,\n            top_k=self.top_k,\n            renormalize=self.renormalize,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            all_reduce=True,\n            layer_idx=layer_idx,\n        )\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        router_logits = self.gate(hidden_states)\n        topk_weights, topk_ids = self.softmax_topk(router_logits)\n        out_states = self.experts(\n            hidden_states,\n            topk_weights,\n            topk_ids,\n        )\n\n        out_states = out_states.reshape(batch_size, sequence_length, -1)\n        return out_states\n\n\nclass 
SDARMoeDecoderLayer(nn.Module):\n    \"\"\"Decode layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        # build attention layer\n        self.self_attn = SDARMoeAttention(config, dtype=dtype, device=device)\n\n        # build MLP\n        if (layer_idx not in config.mlp_only_layers) and (config.num_experts > 0 and\n                                                          (layer_idx + 1) % config.decoder_sparse_step == 0):\n            self.mlp = SDARMoeSparseMoeBlock(config, layer_idx, dtype=dtype, device=device)\n        else:\n            self.mlp = SDARMoeMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = RMSNorm(config.hidden_size,\n                                       config.rms_norm_eps,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # build attention layer norm\n        self.post_attention_layernorm = RMSNorm(config.hidden_size,\n                                                config.rms_norm_eps,\n                                                quant_config=quantization_config,\n                                                dtype=dtype,\n                                                device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n\n        if 
residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass SDARMoeModel(nn.Module):\n    \"\"\"SDAR model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            SDARMoeDecoderLayer(config, layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = 
None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass SDARMoeForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n        'gate_up_proj': [\n            'gate_proj',\n            'up_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length\n        # build model\n        self.model = SDARMoeModel(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = 
build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        if self.config.tie_word_embeddings:\n            self.lm_head.weight = self.model.embed_tokens.weight\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and 
len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],\n                             expert_params_mapping: List):\n        \"\"\"Load weight experts.\"\"\"\n        # load fused weights\n        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:\n            if weight_name not in name:\n                continue\n            name = name.replace(weight_name, param_name)\n            param = params_dict[name]\n            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)\n            break\n        else:\n            param = params_dict[name]\n            load_weight(param, loaded_weight)\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n            ('.gate_up_proj', '.gate_proj', 0),\n            ('.gate_up_proj', '.up_proj', 1),\n        ]\n\n        # expert map\n        num_experts = self.config.num_experts\n        expert_params_mapping = []\n        for exp_id in range(num_experts):\n            gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')\n            up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', 
exp_id, 'up')\n            down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')\n            expert_params_mapping += [gate_param, up_param, down_param]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n\n            if '.experts' in name:\n                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)\n            else:\n                for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                    if weight_name not in name:\n                        continue\n                    name = name.replace(weight_name, param_name)\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight, shard_id=shard_id)\n                    break\n                else:\n                    param = params_dict[name]\n                    load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/siglip.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport math\nfrom typing import Iterable, Set, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom transformers import SiglipVisionConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContextManager\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\n\nclass SiglipVisionEmbeddings(nn.Module):\n\n    def __init__(self,\n                 config: SiglipVisionConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 **kwargs):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.patch_embedding = nn.Conv2d(in_channels=config.num_channels,\n                                         out_channels=self.embed_dim,\n                                         kernel_size=self.patch_size,\n                                         stride=self.patch_size,\n                                         padding='valid',\n                                         dtype=dtype,\n                                         device=device)\n\n        self.num_patches = (self.image_size // self.patch_size)**2\n        self.num_positions = self.num_patches\n        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim, dtype=dtype, device=device)\n        self.register_buffer('position_ids', torch.arange(self.num_positions).expand((1, -1)), persistent=False)\n\n    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:\n        \"\"\"This method allows to interpolate the pre-trained position\n        encodings, to be able to use the model on higher resolution images.\n        
This method is also adapted to support torch.jit tracing and no class\n        embeddings.\n\n        Adapted from:\n        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and\n        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211\n        \"\"\"  # noqa\n\n        num_patches = embeddings.shape[1]\n        num_positions = self.position_embedding.weight.shape[0]\n\n        # always interpolate when tracing to ensure the exported model works for dynamic input shapes\n        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:\n            return self.position_embedding(self.position_ids)\n\n        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)\n\n        dim = embeddings.shape[-1]\n\n        new_height = height // self.patch_size\n        new_width = width // self.patch_size\n\n        sqrt_num_positions = int(math.sqrt(num_positions))\n        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)\n        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)\n\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed,\n            size=(new_height, new_width),\n            mode='bicubic',\n            align_corners=False,\n        )\n\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return patch_pos_embed\n\n    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:\n        _, _, height, width = pixel_values.shape\n        target_dtype = self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n\n        if interpolate_pos_encoding:\n            embeddings = 
embeddings + self.interpolate_pos_encoding(embeddings, height, width)\n        else:\n            embeddings = embeddings + self.position_embedding(self.position_ids)\n        return embeddings\n\n\nclass SiglipAttention(nn.Module):\n\n    def __init__(self,\n                 config: SiglipVisionConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 **kwargs) -> None:\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'\n                f' {self.num_heads}).')\n\n        self.scale = self.head_dim**-0.5\n        self.dropout = config.attention_dropout\n        self.qkv_proj = build_qkv_proj(self.embed_dim,\n                                       num_q_heads=self.num_heads,\n                                       num_kv_heads=self.num_heads,\n                                       head_size=self.head_dim,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       bias=True,\n                                       device=device)\n\n        self.out_proj = build_rowwise_linear(self.embed_dim,\n                                             self.embed_dim,\n                                             bias=True,\n                                             quant_config=quantization_config,\n                                             dtype=dtype,\n                                           
  device=device,\n                                             is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Input shape: Batch x Time x Channel.\"\"\"\n        batch_size, q_len, _ = hidden_states.size()\n        qkv_states = self.qkv_proj(hidden_states)\n        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)\n        query_states = query_states.view(batch_size, q_len, -1, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(batch_size, q_len, -1, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(batch_size, q_len, -1, self.head_dim).transpose(1, 2)\n\n        out = nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, scale=self.scale)\n        out = out.transpose(1, 2).contiguous().view(batch_size, q_len, -1)\n        attn_output = self.out_proj(out)\n\n        return attn_output, None\n\n\nclass SiglipMLP(nn.Module):\n\n    def __init__(self,\n                 config: SiglipVisionConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 **kwargs) -> None:\n        super().__init__()\n        from transformers.activations import ACT2FN\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        self.activation_fn = ACT2FN[config.hidden_act]\n        quantization_config = getattr(config, 'quantization_config', None)\n        self.fc1 = build_colwise_linear(config.hidden_size,\n                                        config.intermediate_size,\n                                        bias=True,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True,\n                                        quant_config=quantization_config)\n        self.fc2 = 
build_rowwise_linear(config.intermediate_size,\n                                        config.hidden_size,\n                                        bias=True,\n                                        quant_config=quantization_config,\n                                        dtype=dtype,\n                                        device=device,\n                                        is_tp=True)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass SiglipEncoderLayer(nn.Module):\n\n    def __init__(self,\n                 config: SiglipVisionConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 **kwargs) -> None:\n        super().__init__()\n\n        self.embed_dim = config.hidden_size\n\n        self.self_attn = SiglipAttention(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n        self.mlp = SiglipMLP(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n    ) -> Tuple[torch.Tensor, None]:\n        residual = hidden_states\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, _ = self.self_attn(hidden_states=hidden_states)\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return 
hidden_states, None\n\n\nclass SiglipEncoder(nn.Module):\n\n    def __init__(self,\n                 config: SiglipVisionConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None,\n                 **kwargs) -> None:\n        super().__init__()\n\n        self.config = config\n        num_hidden_layers = config.num_hidden_layers\n\n        self.layers = nn.ModuleList([\n            SiglipEncoderLayer(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n            for layer_idx in range(num_hidden_layers)\n        ])\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        **kwargs,\n    ) -> Union[torch.Tensor, list[torch.Tensor]]:\n        hidden_states = inputs_embeds\n\n        for encoder_layer in self.layers:\n            hidden_states, _ = encoder_layer(hidden_states)\n        return hidden_states\n\n\nclass SiglipMultiheadAttentionPoolingHead(nn.Module):\n    \"\"\"Multihead Attention Pooling.\"\"\"\n\n    def __init__(\n        self,\n        config: SiglipVisionConfig,\n        ctx_mgr: StepContextManager,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n\n        # create the probe and attention weights with the same dtype/device as the rest of the model\n        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size, dtype=dtype, device=device))\n        self.attention = torch.nn.MultiheadAttention(config.hidden_size,\n                                                     config.num_attention_heads,\n                                                     batch_first=True,\n                                                     dtype=dtype,\n                                                     device=device)\n        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=dtype, device=device)\n        self.mlp = SiglipMLP(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n\n    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:\n        batch_size = hidden_state.shape[0]\n        probe = self.probe.repeat(batch_size, 1, 1)\n\n        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]\n\n        residual = hidden_state\n        hidden_state = 
self.layernorm(hidden_state)\n        hidden_state = residual + self.mlp(hidden_state)\n\n        return hidden_state[:, 0]\n\n\nclass SiglipVisionTransformer(nn.Module):\n\n    def __init__(\n        self,\n        config: SiglipVisionConfig,\n        ctx_mgr: StepContextManager,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n\n        self.config = config\n        embed_dim = config.hidden_size\n\n        self.embeddings = SiglipVisionEmbeddings(config, ctx_mgr=ctx_mgr, device=device, dtype=dtype)\n\n        self.encoder = SiglipEncoder(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n\n        num_hidden_layers = config.num_hidden_layers\n        if len(self.encoder.layers) > config.num_hidden_layers:\n            raise ValueError(f'The original encoder only has {num_hidden_layers} '\n                             f'layers, but you requested {len(self.encoder.layers)} layers.')\n\n        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=dtype, device=device)\n\n        self.use_head = (True if not hasattr(config, 'vision_use_head') else config.vision_use_head)\n        if self.use_head:\n            self.head = SiglipMultiheadAttentionPoolingHead(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        interpolate_pos_encoding: bool = True,\n    ) -> torch.Tensor:\n\n        hidden_states = self.embeddings(\n            pixel_values,\n            interpolate_pos_encoding=interpolate_pos_encoding,\n        )\n        last_hidden_state = self.encoder(inputs_embeds=hidden_states)\n\n        last_hidden_state = self.post_layernorm(last_hidden_state)\n\n        return last_hidden_state\n\n\nclass SiglipVisionModel(nn.Module):\n    config_class = SiglipVisionConfig\n    main_input_name = 'pixel_values'\n\n    def __init__(\n        self,\n        config: 
SiglipVisionConfig,\n        ctx_mgr: StepContextManager,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n\n        self.vision_model = SiglipVisionTransformer(config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)\n\n    def get_input_embeddings(self) -> nn.Module:\n        return self.vision_model.embeddings.patch_embedding\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        interpolate_pos_encoding: bool = False,\n    ) -> torch.Tensor:\n        return self.vision_model(\n            pixel_values=pixel_values,\n            interpolate_pos_encoding=interpolate_pos_encoding,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('qkv_proj', 'q_proj', 'q'),\n            ('qkv_proj', 'k_proj', 'k'),\n            ('qkv_proj', 'v_proj', 'v'),\n        ]\n        params_dict = dict(self.named_parameters())\n        loaded_params: Set[str] = set()\n        layer_count = len(self.vision_model.encoder.layers)\n\n        for name, loaded_weight in weights:\n            # post_layernorm is optional in SiglipVisionModel\n            if (name.startswith('vision_model.post_layernorm') and self.vision_model.post_layernorm is None):\n                continue\n\n            # omit layers when num_hidden_layers_override is set\n            if name.startswith('vision_model.encoder.layers'):\n                layer_idx = int(name.split('.')[3])\n                if layer_idx >= layer_count:\n                    continue\n\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n\n                param = params_dict[name]\n                weight_loader = 
param.weight_loader\n                weight_loader(param, loaded_weight, shard_id)\n                break\n            else:\n                param = params_dict[name]\n                load_weight(param, loaded_weight)\n            loaded_params.add(name)\n        return loaded_params\n"
  },
  {
    "path": "lmdeploy/pytorch/models/starcoder2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Iterable, List, Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, StepContextManager\nfrom lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, build_rotary_embedding_from_config\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight\n\nfrom .utils.cudagraph import CudaGraphMixin\n\n\nclass Starcoder2Attention(nn.Module):\n    \"\"\"Rewrite module of Starcoder2Attention.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        quantization_config = getattr(config, 'quantization_config', None)\n        num_heads = config.num_attention_heads\n        num_key_value_heads = config.num_key_value_heads\n        hidden_size = config.hidden_size\n        head_dim = getattr(config, 'head_dim', hidden_size // num_heads)\n\n        # packed qkv\n        self.qkv_proj = build_qkv_proj(\n            hidden_size,\n            num_q_heads=num_heads,\n            num_kv_heads=num_key_value_heads,\n            head_size=head_dim,\n            bias=config.use_bias,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n\n        # rotary embedding\n        self.apply_rotary_pos_emb = ApplyRotaryEmb()\n\n        # attention\n        sliding_window = getattr(config, 'sliding_window', None)\n        self.attn_fwd = Attention(\n            num_heads,\n            head_dim,\n            num_kv_heads=num_key_value_heads,\n            v_head_size=head_dim,\n            sliding_window=sliding_window,\n        )\n\n        # o_proj\n        self.o_proj = build_rowwise_linear(num_heads * head_dim,\n   
                                        hidden_size,\n                                           bias=config.use_bias,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_metadata: Any = None,\n    ):\n        \"\"\"Rewrite of LlamaAttention.forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        # (-1, heads, head_dim)\n        qkv_states = qkv_states.flatten(0, -2)\n        query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)\n\n        # apply rotary embedding\n        cos, sin = rotary_pos_emb\n        query_states, key_states = self.apply_rotary_pos_emb(\n            query_states,\n            key_states,\n            cos,\n            sin,\n            inplace=True,\n        )\n\n        # attention\n        attn_output = self.attn_fwd(\n            query_states,\n            key_states,\n            value_states,\n            past_key_value[0],\n            past_key_value[1],\n            attn_metadata,\n            k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],\n            v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],\n            inplace=True,\n        )\n        attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)\n\n        # o proj\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Starcoder2MLP(nn.Module):\n    \"\"\"mlp.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        
quantization_config = getattr(config, 'quantization_config', None)\n        # up projection\n        self.c_fc = build_colwise_linear(\n            config.hidden_size,\n            config.intermediate_size,\n            bias=config.use_bias,\n            dtype=dtype,\n            device=device,\n            quant_config=quantization_config,\n            is_tp=True,\n        )\n\n        # activation (only the tanh-approximated GELU is supported)\n        hidden_act = config.hidden_act\n        if hidden_act is None:\n            hidden_act = 'gelu_pytorch_tanh'\n        assert hidden_act == 'gelu_pytorch_tanh', f'Unsupported hidden_act: {hidden_act}'\n        self.act_fn = nn.GELU(approximate='tanh')\n\n        # down projection\n        self.c_proj = build_rowwise_linear(config.intermediate_size,\n                                           config.hidden_size,\n                                           bias=config.use_bias,\n                                           quant_config=quantization_config,\n                                           dtype=dtype,\n                                           device=device,\n                                           is_tp=True)\n\n    def forward(self, x):\n        \"\"\"forward.\"\"\"\n        hidden = self.c_fc(x)\n        act = self.act_fn(hidden)\n        return self.c_proj(act)\n\n\nclass Starcoder2DecoderLayer(nn.Module):\n    \"\"\"Decoder layer.\"\"\"\n\n    def __init__(self,\n                 config: PretrainedConfig,\n                 layer_idx: int,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.layer_idx = layer_idx\n\n        # build attention layer\n        self.self_attn = Starcoder2Attention(config, dtype=dtype, device=device)\n\n        # build MLP\n        self.mlp = Starcoder2MLP(config, dtype=dtype, device=device)\n\n        # build input layer norm\n        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_epsilon, dtype=dtype, device=device)\n\n        # build attention layer norm\n        
self.post_attention_layernorm = LayerNorm(config.hidden_size,\n                                                  eps=config.norm_epsilon,\n                                                  dtype=dtype,\n                                                  device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],\n        past_key_value: Optional[List[torch.FloatTensor]],\n        residual: Optional[torch.Tensor] = None,\n        attn_metadata: Any = None,\n    ):\n        if residual is None:\n            residual = hidden_states\n            hidden_states = self.input_layernorm(hidden_states)\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            rotary_pos_emb=rotary_pos_emb,\n            past_key_value=past_key_value,\n            attn_metadata=attn_metadata,\n        )\n\n        # Fully Connected\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n\n        outputs = (hidden_states, residual)\n        return outputs\n\n\nclass Starcoder2Model(nn.Module):\n    \"\"\"model.\"\"\"\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):\n        super().__init__()\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size,\n                                         config.hidden_size,\n                                         self.padding_idx,\n                                         dtype=dtype,\n                                         device=device)\n\n        # build all decode layers\n        self.layers = nn.ModuleList([\n            Starcoder2DecoderLayer(config, 
layer_idx, dtype=dtype, device=device)\n            for layer_idx in range(config.num_hidden_layers)\n        ])\n\n        # build norm\n        self.norm = LayerNorm(config.hidden_size, eps=config.norm_epsilon, dtype=dtype, device=device)\n\n        # build rotary embedding\n        self.rotary_emb = build_rotary_embedding_from_config(config)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        attn_metadata: Any = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        \"\"\"Rewrite of LlamaModel.forward.\"\"\"\n\n        # token embedding\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # rotary embedding\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        cos, sin = cos[0], sin[0]\n        rotary_pos_emb = (cos, sin)\n\n        # decoding\n        residual = None\n        for idx, decoder_layer in enumerate(self.layers):\n            past_key_value = past_key_values[idx]\n            hidden_states, residual = decoder_layer(\n                hidden_states,\n                rotary_pos_emb=rotary_pos_emb,\n                past_key_value=past_key_value,\n                residual=residual,\n                attn_metadata=attn_metadata,\n            )\n\n        # norm\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.embed_tokens\n\n\nclass Starcoder2ForCausalLM(nn.Module, CudaGraphMixin):\n    \"\"\"ModelForCausalLM.\"\"\"\n\n    packed_modules_mapping = {\n        'qkv_proj': [\n            'q_proj',\n            'k_proj',\n            'v_proj',\n        ],\n    }\n\n    def __init__(self,\n                 config: 
PretrainedConfig,\n                 ctx_mgr: StepContextManager,\n                 dtype: torch.dtype = None,\n                 device: torch.device = None):\n        super().__init__()\n        self.config = config\n        self.ctx_mgr = ctx_mgr\n        # build model\n        self.model = Starcoder2Model(config, dtype=dtype, device=device)\n        # build lm_head\n        self.lm_head = build_rowwise_linear(config.hidden_size,\n                                            config.vocab_size,\n                                            bias=False,\n                                            dtype=dtype,\n                                            device=device)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Model forward, return logits.\"\"\"\n        hidden_states = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n        return hidden_states\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return self.lm_head(hidden_states)\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        self.lm_head.weight = self.model.embed_tokens.weight\n\n    def get_input_embeddings(self):\n        \"\"\"Get input embeddings.\"\"\"\n        return self.model.get_input_embeddings()\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        # get input_ids, position_ids and attention 
metadatas\n        input_ids = context.input_ids\n        position_ids = context.position_ids\n        attn_metadata = context.attn_metadata\n\n        # process vision embeddings\n        vision_embeddings = context.input_embeddings\n        vision_embedding_indexing = context.input_embedding_indexing\n        if vision_embeddings is not None and len(vision_embeddings) > 0:\n            if inputs_embeds is None:\n                inputs_embeds = self.get_input_embeddings()(input_ids)\n            inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)\n\n        # inputs of forward\n        return dict(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n            inputs_embeds=inputs_embeds,\n        )\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        # modify from vllm\n        stacked_params_mapping = [\n            # (param_name, shard_name, shard_id)\n            ('.qkv_proj', '.q_proj', 'q'),\n            ('.qkv_proj', '.k_proj', 'k'),\n            ('.qkv_proj', '.v_proj', 'v'),\n        ]\n\n        params_dict = dict(self.named_parameters())\n        for name, loaded_weight in weights:\n            if 'rotary_emb.inv_freq' in name:\n                continue\n            if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):\n                continue\n            if self.config.tie_word_embeddings and 'lm_head.weight' in name:\n                continue\n            for (param_name, weight_name, shard_id) in stacked_params_mapping:\n                if weight_name not in name:\n                    continue\n                name = name.replace(weight_name, param_name)\n                param = params_dict[name]\n                load_weight(param, loaded_weight, shard_id=shard_id)\n                break\n            else:\n               
 param = params_dict[name]\n                load_weight(param, loaded_weight)\n"
  },
  {
    "path": "lmdeploy/pytorch/models/utils/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/models/utils/cudagraph.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager\n\nBuffType = Dict[str, Tensor]\n\n\ndef _get_meta_flashattn(\n        batch_size: int,\n        max_seqlen_q: int,\n        max_seqlen_k: int,\n        num_heads_q: int,\n        num_heads_kv: int,\n        headdim: int,\n        cache_seqlens: torch.Tensor,\n        qkv_dtype=torch.bfloat16,\n        headdim_v=None,\n        cu_seqlens_q: Optional[torch.Tensor] = None,\n        cu_seqlens_k_new: Optional[torch.Tensor] = None,\n        page_size: Optional[int] = None,\n        causal=True,\n        window_size=(-1, -1),  # -1 means infinite context window\n        num_splits=0,\n):\n    \"\"\"Get scheduler metadata for flash attn.\"\"\"\n    from flash_attn_interface import get_scheduler_metadata\n\n    metadata = get_scheduler_metadata(\n        batch_size,\n        max_seqlen_q,\n        max_seqlen_k,\n        num_heads_q,\n        num_heads_kv,\n        headdim,\n        cache_seqlens,\n        qkv_dtype=qkv_dtype,\n        headdim_v=headdim_v,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k_new=cu_seqlens_k_new,\n        page_size=page_size,\n        causal=causal,\n        window_size=window_size,\n        num_splits=num_splits,\n    )\n    return metadata\n\n\ndef next_power_of_2(n: int):\n    \"\"\"Return the smallest power of 2 greater than or equal to n (for n >= 1).\"\"\"\n    # subtract 1 so exact powers of 2 map to themselves, copy the highest set bit\n    # into every lower bit, then add 1 to carry into the next power of 2\n    n -= 1\n    n |= n >> 1\n    n |= n >> 2\n    n |= n >> 4\n    n |= n >> 8\n    n |= n >> 16\n    n |= n >> 32\n    n += 1\n    return n\n\n\n@dataclass\nclass CudaGraphMeta:\n    \"\"\"Meta info of cudagraph.\"\"\"\n    max_batchs: int\n    max_tokens: int\n    num_blocks: int\n    is_decoding: int\n    device: torch.device\n    input_buffers: BuffType = None\n    output_buffers: 
BuffType = None\n    vocab_size: int = 1\n    use_mla_fp8_cache: bool = False\n    use_flash_mla: bool = False\n    mla_index_topk: Optional[int] = None\n    decode_query_len: int = 1\n    use_fa3_decoding: bool = False\n\n\nclass CudaGraphMixin:\n    \"\"\"Mixin class to support cudagraph.\"\"\"\n\n    def support_cuda_graph(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        past_key_values: List[List[torch.Tensor]],\n        attn_metadata: Any = None,\n        inputs_embeds: torch.Tensor = None,\n        **kwargs,\n    ):\n        \"\"\"Return True if the model supports cudagraph.\"\"\"\n        return attn_metadata.is_decoding\n\n    def make_output_buffers(self, output):\n        \"\"\"Make output buffers.\"\"\"\n        if isinstance(output, torch.Tensor):\n            output_buffers = dict(hidden_states=output)\n        else:\n            assert isinstance(output, Dict)\n            output_buffers = output\n        return output_buffers\n\n    def update_meta_flashattn(self, graph_meta: CudaGraphMeta, block_size: int, max_seqlen_k: int,\n                              cache_seqlens: torch.Tensor):\n        \"\"\"Build flash attention scheduler metadata.\"\"\"\n        ctx_mgr = get_step_ctx_manager()\n        step_ctx = ctx_mgr.current_context()\n        model_config = step_ctx.model_config\n        batch_size = graph_meta.max_batchs\n        max_seqlen_q = graph_meta.decode_query_len\n        sliding_window = model_config.sliding_window\n        num_attention_heads = model_config.num_attention_heads\n        num_key_value_heads = model_config.num_key_value_heads\n        headdim = model_config.head_dim\n        torch_dtype = model_config.dtype\n        if sliding_window is None:\n            window_size = (-1, -1)\n        elif isinstance(sliding_window, int):\n            window_size = (sliding_window, sliding_window)\n        else:\n            raise TypeError(f'Unsupported sliding_window type: {type(sliding_window)}')\n        cache_seqlens = cache_seqlens.to(torch.int32)\n        scheduler_metadata = _get_meta_flashattn(\n            
batch_size=batch_size,\n            max_seqlen_q=max_seqlen_q,\n            max_seqlen_k=max_seqlen_k,\n            num_heads_q=num_attention_heads,\n            num_heads_kv=num_key_value_heads,\n            headdim=headdim,\n            cache_seqlens=cache_seqlens,\n            qkv_dtype=torch_dtype,\n            page_size=block_size,\n            window_size=window_size,\n        )\n        return scheduler_metadata\n\n    def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, past_key_values: List, **kwargs) -> BuffType:\n        \"\"\"Make cudagraph buffers from forward inputs.\"\"\"\n        max_batches = graph_meta.max_batchs\n        max_tokens = graph_meta.max_tokens\n        num_blocks = graph_meta.num_blocks\n        device = graph_meta.device\n        decode_query_len = graph_meta.decode_query_len\n\n        input_buffers: BuffType = dict()\n        input_buffers['input_ids'] = torch.randint(0,\n                                                   graph_meta.vocab_size, (1, max_tokens),\n                                                   dtype=torch.int64,\n                                                   device=device)\n        input_buffers['position_ids'] = torch.zeros((1, max_tokens), dtype=torch.int64, device=device)\n\n        # flash_mla requires block_offsets and kv_lens int32\n        input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int32, device=device)\n        input_buffers['qkv_lens'] = torch.zeros(3, max_batches, dtype=torch.int32, device=device)\n\n        input_buffers['q_start_loc'] = input_buffers['qkv_lens'][0]\n        input_buffers['q_seqlens'] = input_buffers['qkv_lens'][1]\n        input_buffers['kv_seqlens'] = input_buffers['qkv_lens'][2]\n        input_buffers['qkv_seqlens'] = input_buffers['qkv_lens'][1:]\n        input_buffers['local_adapter_ids'] = torch.zeros(max_batches, dtype=torch.int64, device=device)\n\n        input_buffers['cu_seqlens'] = torch.zeros(2, max_batches + 1, 
dtype=torch.int32, device=device)\n        input_buffers['cu_seqlens_q'] = input_buffers['cu_seqlens'][0]\n        input_buffers['cu_seqlens_k'] = input_buffers['cu_seqlens'][1]\n\n        if graph_meta.use_flash_mla is True:\n            import flash_mla\n\n            # create buffers for flash mla\n            num_attention_heads = self.config.num_attention_heads\n            index_topk = graph_meta.mla_index_topk\n            num_heads_q = None if index_topk is None else num_attention_heads\n            input_buffers['tile_scheduler_metadata'], input_buffers['num_splits'] = flash_mla.get_mla_metadata(\n                torch.ones(max_batches, dtype=torch.int32, device=device),\n                num_attention_heads * decode_query_len,\n                num_heads_k=1,\n                num_heads_q=num_heads_q,\n                is_fp8_kvcache=graph_meta.use_mla_fp8_cache,\n                topk=index_topk)\n\n        # use fa3 decode kernel for spec decode\n        elif graph_meta.use_fa3_decoding is True:\n            block_size = past_key_values[0][0].size(1)\n            input_buffers['scheduler_metadata'] = self.update_meta_flashattn(graph_meta,\n                                                                             block_size=block_size,\n                                                                             max_seqlen_k=decode_query_len,\n                                                                             cache_seqlens=input_buffers['kv_seqlens'])\n\n        return input_buffers\n\n    @record_function('fill_buffers_cudagraph')\n    def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor,\n                               past_key_values: List, attn_metadata: Any, inputs_embeds: Tensor,\n                               **kwargs) -> Dict[str, Tensor]:\n        \"\"\"Fill cudagraph buffers from forward inputs.\"\"\"\n\n        block_offsets: Tensor = attn_metadata.block_offsets\n        q_start_loc: 
Tensor = attn_metadata.q_start_loc\n        q_seqlens: Tensor = attn_metadata.q_seqlens\n        kv_seqlens: Tensor = attn_metadata.kv_seqlens\n        input_buffers: BuffType = graph_meta.input_buffers\n\n        batch_size, num_blocks = block_offsets.size()\n        num_tokens = input_ids.size(-1)\n        decode_query_len = graph_meta.decode_query_len\n        # fill buffer\n        input_buffers['input_ids'].random_(0, graph_meta.vocab_size)\n        input_buffers['input_ids'][:, :num_tokens] = input_ids\n        input_buffers['position_ids'][:, :num_tokens] = position_ids\n        input_buffers['block_offsets'][:batch_size, :num_blocks] = block_offsets\n\n        qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens))\n        input_buffers['qkv_lens'].zero_()\n        input_buffers['q_seqlens'].fill_(graph_meta.max_tokens // graph_meta.max_batchs)\n        input_buffers['qkv_lens'][:, :batch_size] = qkv\n        input_buffers['cu_seqlens'][:, 1:] = input_buffers['qkv_seqlens'].cumsum(1)\n        if inputs_embeds is not None:\n            emb_size = inputs_embeds.size(-1)\n            if 'inputs_embeds' not in input_buffers:\n                max_num_tokens = input_buffers['input_ids'].size(-1)\n                input_buffers['inputs_embeds'] = inputs_embeds.new_zeros(1, max_num_tokens, emb_size)\n            input_buffers['inputs_embeds'][:, :num_tokens] = inputs_embeds\n\n        # create inputs\n        new_batch_size = input_buffers['block_offsets'].size(0)\n        attn_metadata.block_offsets = input_buffers['block_offsets']\n        attn_metadata.q_start_loc = input_buffers['q_start_loc']\n        attn_metadata.q_seqlens = input_buffers['q_seqlens']\n        attn_metadata.kv_seqlens = input_buffers['kv_seqlens']\n        attn_metadata.cu_seqlens_q = input_buffers['cu_seqlens_q']\n        attn_metadata.cu_seqlens_k = input_buffers['cu_seqlens_k']\n\n        if graph_meta.use_flash_mla is True:\n            import flash_mla\n            num_attention_heads = 
self.config.num_attention_heads\n            index_topk = graph_meta.mla_index_topk\n            num_heads_q = None if index_topk is None else num_attention_heads\n            tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(\n                attn_metadata.kv_seqlens.to(torch.int32),\n                num_attention_heads * decode_query_len,\n                num_heads_k=1,\n                num_heads_q=num_heads_q,\n                is_fp8_kvcache=graph_meta.use_mla_fp8_cache,\n                topk=index_topk)\n            # here we use copy_ instead of = to avoid using new allocated mem for cuda graph\n            input_buffers['tile_scheduler_metadata'].copy_(tile_scheduler_metadata)\n            input_buffers['num_splits'][:new_batch_size + 1].copy_(num_splits[:new_batch_size + 1])\n            attn_metadata.tile_scheduler_metadata = input_buffers['tile_scheduler_metadata']\n            attn_metadata.num_splits = input_buffers['num_splits']\n\n        # use fa3 decode kernel for spec decode\n        elif graph_meta.use_fa3_decoding is True:\n            block_size = past_key_values[0][0].size(1)\n            scheduler_metadata = self.update_meta_flashattn(\n                graph_meta,\n                block_size=block_size,\n                max_seqlen_k=attn_metadata.max_kv_seqlen,\n                cache_seqlens=input_buffers['kv_seqlens'],\n            )\n            assert scheduler_metadata.shape == input_buffers['scheduler_metadata'].shape\n            input_buffers['scheduler_metadata'].copy_(scheduler_metadata)\n            attn_metadata.scheduler_metadata = input_buffers['scheduler_metadata']\n\n        new_inputs = dict(\n            past_key_values=past_key_values,\n            attn_metadata=attn_metadata,\n        )\n\n        new_inputs['input_ids'] = input_buffers['input_ids']\n        new_inputs['position_ids'] = input_buffers['position_ids']\n\n        if inputs_embeds is not None:\n            new_inputs['inputs_embeds'] = 
input_buffers['inputs_embeds']\n\n        new_inputs.update(kwargs)\n        return new_inputs\n\n    def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepContext):\n        \"\"\"Update step context with input buffers.\"\"\"\n        input_buffers = graph_meta.input_buffers\n        local_adapter_ids = context.local_adapter_ids\n        if local_adapter_ids is not None:\n            if input_buffers['local_adapter_ids'].data_ptr() != local_adapter_ids.data_ptr():\n                input_buffers['local_adapter_ids'].fill_(0)\n            batch_size = local_adapter_ids.size(0)\n            input_buffers['local_adapter_ids'][:batch_size] = local_adapter_ids\n            context.local_adapter_ids = input_buffers['local_adapter_ids']\n        context.q_seqlens = input_buffers['q_seqlens']\n        context.kv_seqlens = input_buffers['kv_seqlens']\n        context.q_start_loc = input_buffers['q_start_loc']\n\n    def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: Tensor, **kwargs):\n        \"\"\"Get outputs from buffers.\"\"\"\n        num_tokens = input_ids.size(-1)\n        outputs = dict()\n        outputs['hidden_states'] = output_buffers['hidden_states'][:, :num_tokens]\n        if output_buffers.get('all_routed_experts', None) is not None:\n            outputs['all_routed_experts'] = output_buffers['all_routed_experts'][:num_tokens, ...].clone()\n        return outputs\n"
  },
  {
    "path": "lmdeploy/pytorch/models/utils/micro_batch.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\n\nimport torch\n\n\ndef enable_micro_batch(param_name, index=-1):\n    \"\"\"Decorator factory to enable micro-batch computation.\"\"\"\n\n    def decorator(func):\n\n        @functools.wraps(func)\n        def wrapper(self, *args, **kwargs):\n            if index != -1 and len(args) > index:\n                inputs = args[index]\n            else:\n                inputs = kwargs.get(param_name, None)\n\n            if isinstance(inputs, list):\n                # Apply forward computation to each micro-batch\n                results = []\n                for input in inputs:\n                    if index != -1 and len(args) > index:\n                        args = args[0:index] + (input, ) + args[index + 1:]\n                    else:\n                        kwargs[param_name] = input\n                    result = func(self, *args, **kwargs)\n                    results.append(result)\n                return results\n            else:\n                # If not a list, directly apply the forward computation\n                return func(self, *args, **kwargs)\n\n        return wrapper\n\n    return decorator\n\n\ndef split_batch(func, param_name, index=-1, num_splits=2):\n    \"\"\"Decorator to split along the 0th dimension into a specified number of\n    chunks.\"\"\"\n\n    def wrapper(*args, **kwargs):\n        if index != -1 and len(args) > index:\n            inputs = args[index]\n        else:\n            inputs = kwargs.get(param_name, None)\n\n        if inputs is not None:\n            split_inputs = list(torch.chunk(inputs, num_splits, dim=0))\n            if index != -1 and len(args) > index:\n                args = args[0:index] + (split_inputs, ) + args[index + 1:]\n            else:\n                kwargs[param_name] = split_inputs\n\n            results = func(*args, **kwargs)\n            return torch.cat(results, dim=0)\n        else:\n            return 
func(*args, **kwargs)\n\n    return wrapper\n"
  },
  {
    "path": "lmdeploy/pytorch/models/utils/model.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\nfrom typing import Iterable, List, Optional, Tuple\n\nimport torch\n\nfrom lmdeploy.pytorch.config import QuantizationConfig\nfrom lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, StepContext\nfrom lmdeploy.pytorch.models.patch import get_build_model_context\nfrom lmdeploy.pytorch.nn.embedding import ParallelEmbedding\nfrom lmdeploy.pytorch.nn.linear import build_rowwise_linear\n\n\nclass BaseModelMetaProcessor:\n    \"\"\"Model meta processor base class.\"\"\"\n\n    def update_inputs(self, inputs: ModelInputs, device: torch.device) -> ModelInputs:\n        \"\"\"Update model inputs.\"\"\"\n        return inputs\n\n    def update_delta(self, inputs: ModelInputs, delta: ModelInputsDelta) -> ModelInputs:\n        \"\"\"Update model inputs for delta.\"\"\"\n        return inputs\n\n    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:\n        \"\"\"Merge model inputs with deltas.\"\"\"\n        return inputs\n\n\nclass DeployModelMixin:\n\n    def forward(self, *args, **kwargs):\n        \"\"\"Forward of model.\"\"\"\n        raise NotImplementedError('Not Implemented')\n\n    def prepare_inputs_for_generation(\n        self,\n        past_key_values: List[List[torch.Tensor]],\n        inputs_embeds: Optional[torch.Tensor] = None,\n        context: StepContext = None,\n    ):\n        \"\"\"Prepare input.\"\"\"\n        raise NotImplementedError('Not Implemented')\n\n    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):\n        \"\"\"Load weights.\"\"\"\n        raise NotImplementedError('Not Implemented')\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        return hidden_states\n\n    @classmethod\n    def rename_weight(cls, name: str) -> str:\n        \"\"\"Rename 
weight.\"\"\"\n        return name\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        pass\n\n    def update_model_metas(self,\n                           past_key_values: List[List[torch.Tensor]],\n                           inputs_embeds: Optional[torch.Tensor] = None,\n                           context: StepContext = None):\n        \"\"\"Update model meta.\"\"\"\n        return None\n\n    def get_input_processor(self) -> BaseModelInputProcessor:\n        \"\"\"Get input processor.\"\"\"\n        return None\n\n    def get_modelmeta_processor(self) -> BaseModelMetaProcessor:\n        \"\"\"Get model meta preprocessor.\"\"\"\n        return BaseModelMetaProcessor()\n\n    @classmethod\n    def update_quant_config(cls, quant_config: QuantizationConfig):\n        \"\"\"Update quant config.\"\"\"\n        if quant_config is None:\n            return\n        ignored_layers = [cls.rename_weight(name) for name in quant_config.ignored_layers]\n\n        added_ignore_layers = set()\n\n        for layer_name in ignored_layers:\n            if '.q_proj' in layer_name:\n                added_ignore_layers.add(layer_name.replace(\n                    '.q_proj',\n                    '.qkv_proj',\n                ))\n            elif '.gate_proj' in layer_name:\n                if '.experts' in layer_name:\n                    added_ignore_layers.add(layer_name.split('.experts', 1)[0] + '.experts')\n                else:\n                    added_ignore_layers.add(layer_name.replace('.gate_proj', '.gate_up_proj'))\n            elif '.down_proj' in layer_name:\n                if '.experts' in layer_name:\n                    added_ignore_layers.add(layer_name.split('.experts', 1)[0] + '.experts')\n                else:\n                    added_ignore_layers.add(layer_name)\n\n        added_ignore_layers = list(added_ignore_layers)\n\n        ignored_layers.extend(added_ignore_layers)\n        quant_config.ignored_layers = ignored_layers\n\n    
    return quant_config\n\n\nclass DeployModelMixinV1(DeployModelMixin):\n\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Compute logits of the model output.\"\"\"\n        head_dtype = self.lm_head.weight.dtype\n        if hidden_states.dtype != head_dtype:\n            hidden_states = hidden_states.to(dtype=head_dtype)\n        hidden_states = self.lm_head(hidden_states)\n        return hidden_states\n\n    def get_input_embeddings(self):\n        \"\"\"Get embeds.\"\"\"\n        raise NotImplementedError('Not Implemented')\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        if getattr(self.config, 'tie_word_embeddings', False):\n            self.lm_head.weight = self.get_input_embeddings().weight\n\n    def build_lm_head(self,\n                      hidden_size: int,\n                      vocab_size: int,\n                      bias: bool = False,\n                      dtype: Optional[torch.dtype] = None,\n                      device: Optional[torch.device] = None,\n                      **kwargs):\n        \"\"\"Build LM Head.\"\"\"\n        bm_ctx = get_build_model_context()\n        head_dtype = torch.float32 if bm_ctx.fp32_lm_head else dtype\n        lm_head = build_rowwise_linear(\n            hidden_size,\n            vocab_size,\n            bias,\n            dtype=head_dtype,\n            device=device,\n            **kwargs,\n        )\n        return lm_head\n\n\ndef vlm_model(vlm_cls):\n    if not issubclass(vlm_cls, torch.nn.Module):\n        raise ValueError('Only subclasses of nn.Module can be decorated with @vlm_model.')\n\n    @functools.wraps(vlm_cls)\n    def wrapper(*args, **kwargs):\n        bm_ctx = get_build_model_context()\n        disable_vision_encoder = bm_ctx.disable_vision_encoder\n        if disable_vision_encoder:\n            mod = torch.nn.Identity()\n            mod._is_dummy_mod = True\n            return mod\n        else:\n            return vlm_cls(*args, **kwargs)\n\n    
return wrapper\n\n\ndef build_embedding(vocab_size: int,\n                    hidden_size: int,\n                    padding_idx: int,\n                    dtype: torch.dtype = None,\n                    device: torch.device = None,\n                    is_tp: bool = False,\n                    **kwargs):\n    \"\"\"Build embedding.\"\"\"\n    bm_ctx = get_build_model_context()\n\n    # run in fp32 only when the embedding shares weights with lm_head\n    force_dtype = None\n    if bm_ctx.fp32_lm_head and bm_ctx.tie_word_embeddings:\n        force_dtype = torch.float32\n\n    return ParallelEmbedding(\n        vocab_size,\n        hidden_size,\n        padding_idx,\n        dtype=dtype,\n        device=device,\n        is_tp=is_tp,\n        force_dtype=force_dtype,\n        **kwargs,\n    )\n"
  },
  {
    "path": "lmdeploy/pytorch/models/whisper.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py\n\nimport torch\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom lmdeploy.pytorch.nn import LayerNorm\nfrom lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear\n\n\nclass WhisperAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper.\"\"\"\n\n    def __init__(\n        self,\n        embed_dim: int,\n        num_heads: int,\n        bias: bool = True,\n        config: PretrainedConfig = None,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n    ) -> None:\n        super().__init__()\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.head_dim = embed_dim // num_heads\n\n        if (self.head_dim * num_heads) != self.embed_dim:\n            raise ValueError(f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}'\n                             f' and `num_heads`: {num_heads}).')\n        self.scaling = self.head_dim**-0.5\n\n        # packed qkv\n        # TODO, zhouxinyu, hf whisper hard-codes k_proj bias = False, may need to double check\n        self.qkv_proj = build_qkv_proj(self.embed_dim,\n                                       num_q_heads=self.num_heads,\n                                       num_kv_heads=self.num_heads,\n                                       head_size=self.head_dim,\n                                       bias=bias,\n                                       quant_config=quantization_config,\n                                       dtype=dtype,\n                                       device=device)\n\n        # 
o_proj\n        self.out_proj = build_rowwise_linear(self.embed_dim,\n                                             self.embed_dim,\n                                             bias=bias,\n                                             quant_config=quantization_config,\n                                             dtype=dtype,\n                                             device=device,\n                                             is_tp=True)\n\n    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        # qkv proj\n        qkv_states = self.qkv_proj(hidden_states)\n        q, k, v = self.qkv_proj.split_qkv(qkv_states)\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n        q = q * self.scaling\n\n        # attention\n        attn_output = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, scale=1.0)\n\n        # o proj\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.flatten(-2, -1)\n        attn_output = self.out_proj(attn_output)\n        return attn_output\n\n\nclass WhisperEncoderLayer(nn.Module):\n\n    def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None) -> None:\n        super().__init__()\n        self.config = config\n        quantization_config = getattr(config, 'quantization_config', None)\n\n        self.act = ACT2FN[config.activation_function]\n        self.embed_dim = config.d_model\n\n        self.self_attn = WhisperAttention(\n            embed_dim=self.embed_dim,\n            num_heads=config.encoder_attention_heads,\n            config=config,\n            dtype=dtype,\n            device=device,\n        )\n        self.self_attn_layer_norm = LayerNorm(self.embed_dim, dtype=dtype, device=device)\n        self.fc1 = build_colwise_linear(\n            self.embed_dim,\n            config.encoder_ffn_dim,\n            
bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n        self.fc2 = build_rowwise_linear(\n            config.encoder_ffn_dim,\n            self.embed_dim,\n            bias=True,\n            quant_config=quantization_config,\n            dtype=dtype,\n            device=device,\n        )\n        self.final_layer_norm = LayerNorm(self.embed_dim, dtype=dtype, device=device)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n        \"\"\"\n        residual = hidden_states\n        hidden_states = self.self_attn_layer_norm(hidden_states)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.final_layer_norm(hidden_states)\n        hidden_states = self.act(self.fc1(hidden_states))\n        hidden_states = self.fc2(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n"
  },
  {
    "path": "lmdeploy/pytorch/multimodal/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .data_type import MultiModalData\n\n__all__ = ['MultiModalData']\n"
  },
  {
    "path": "lmdeploy/pytorch/multimodal/data_type.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass, fields\nfrom typing import Any, Dict, List, Union\n\nfrom torch import Tensor\n\nfrom lmdeploy.vl.constants import Modality\n\nNestedTensor = Union[Tensor, List[Tensor]]\n\n\n@dataclass\nclass MultiModalData:\n    data: NestedTensor\n    start: int\n    end: int = None\n    meta: Dict[str, Any] = None\n\n    modality: Modality = Modality.IMAGE\n\n    def __post_init__(self):\n        if self.end is None:\n            self.end = self.start\n\n    def to_device(self, device: str, non_blocking: bool = False):\n        \"\"\"To device.\"\"\"\n        out_dict = dict()\n        for f in fields(self):\n            k = f.name\n            if k in ('data', 'meta'):\n                continue\n            v = getattr(self, k)\n            out_dict[k] = v\n\n        if isinstance(self.data, Tensor):\n            data = self.data.to(device=device, non_blocking=non_blocking)\n        else:\n            data = [d.to(device=device, non_blocking=non_blocking) for d in self.data]\n        out_dict['data'] = data\n\n        new_meta = None\n        if self.meta is not None:\n            new_meta = dict()\n            for k, v in self.meta.items():\n                if isinstance(v, Tensor):\n                    v = v.to(device=device, non_blocking=non_blocking)\n                elif hasattr(v, 'to_device'):\n                    v = v.to_device(device=device, non_blocking=non_blocking)\n                new_meta[k] = v\n\n        out_dict['meta'] = new_meta\n        return MultiModalData(**out_dict)\n\n\nMultiModalInputs = Dict[str, List[MultiModalData]]\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# attention module is modified from:\n# https://github.com/vllm-project/vllm/blob/main/vllm/attention/\nfrom .activation import GeluAndMul, SiluAndMul  # noqa: F401\nfrom .attention import Attention, FlashAttention  # noqa: F401\nfrom .embedding import ParallelEmbedding  # noqa: F401\nfrom .norm import LayerNorm, RMSNorm  # noqa: F401\nfrom .rotary_embedding import ApplyRotaryEmb  # noqa: F401\nfrom .rotary_embedding import RopeType  # noqa: F401\nfrom .rotary_embedding import YarnParameters  # noqa: F401\nfrom .rotary_embedding import build_rotary_embedding  # noqa: F401\nfrom .rotary_embedding import build_rotary_embedding_from_config  # noqa: F401\nfrom .rotary_embedding import build_rotary_params  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/activation.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom torch import Tensor, nn\n\nfrom ..backends import OpType, get_backend\n\n\nclass SiluAndMul(nn.Module):\n    \"\"\"Silu and elementwise multiply.\"\"\"\n\n    def __init__(self, inplace: bool = True):\n        super().__init__()\n        backend = get_backend()\n        builder = backend.get_layer_impl_builder(OpType.SiluAndMul)\n        self.impl = builder.build(inplace)\n\n    def forward(self, x: Tensor):\n        \"\"\"forward.\"\"\"\n        return self.impl.forward(x)\n\n\nclass GeluAndMul(nn.Module):\n    \"\"\"Gelu and elementwise multiply.\"\"\"\n\n    def __init__(self, approximate: str = 'none'):\n        super().__init__()\n        backend = get_backend()\n        builder = backend.get_layer_impl_builder(OpType.GeluAndMul)\n        self.impl = builder.build(approximate)\n\n    def forward(self, x: Tensor):\n        \"\"\"forward.\"\"\"\n        return self.impl.forward(x)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/attention.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\n\nfrom ..backends import OpType, get_backend\nfrom ..backends.attention import AttentionMetadata\nfrom .utils import get_distribute_size\n\n\ndef _update_num_heads(num_heads: int, num_kv_heads: int):\n    \"\"\"Update heads.\"\"\"\n    world_size, rank = get_tp_world_rank('attn')\n    num_heads = get_distribute_size(num_heads, world_size, rank)\n    num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)\n    return num_heads, num_kv_heads\n\n\nclass Attention(nn.Module):\n    \"\"\"Attention layer.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_size: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_size: int = None,\n        alibi: bool = False,\n        sliding_window: int = None,\n        logit_softcapping: float = 0.0,\n        causal: bool = True,\n        use_flash_mla: bool = False,\n        learnable_sink: bool = False,\n        block_sparse_size: int = 1,\n        **kwargs,\n    ):\n        super().__init__()\n        if num_kv_heads is None:\n            num_kv_heads = num_heads\n        if v_head_size is None:\n            v_head_size = head_size\n        self.origin_num_heads = num_heads\n        num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads)\n        self.num_heads = num_heads\n\n        layer_backend = get_backend()\n        impl_builder = layer_backend.get_layer_impl_builder(OpType.PagedAttention)\n\n        self.impl = impl_builder.build(\n            num_heads=num_heads,\n            head_size=head_size,\n            scale=scale,\n            num_kv_heads=num_kv_heads,\n            v_head_size=v_head_size,\n            alibi=alibi,\n            sliding_window=sliding_window,\n            logit_softcapping=logit_softcapping,\n            causal=causal,\n            
use_flash_mla=use_flash_mla,\n            learnable_sink=learnable_sink,\n            block_sparse_size=block_sparse_size,\n            **kwargs,\n        )\n\n        if alibi:\n            self.alibi_ready = False\n        else:\n            self.alibi_ready = True\n\n    def _lazy_init(self, device):\n        \"\"\"Lazy init.\"\"\"\n        if not self.alibi_ready:\n            _, rank = get_tp_world_rank('attn')\n            start = self.num_heads * rank\n            end = start + self.num_heads\n            alibi_slopes = self.impl.make_alibi_slopes(start,\n                                                       end,\n                                                       self.origin_num_heads,\n                                                       alibi_scale=1,\n                                                       dtype=torch.float32,\n                                                       device=device)\n            self.impl.set_alibi_slopes(alibi_slopes)\n            self.alibi_ready = True\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        attn_metadata: AttentionMetadata,\n        k_scales_zeros: torch.Tensor = None,\n        v_scales_zeros: torch.Tensor = None,\n        s_aux: torch.Tensor = None,\n        nsa_indices: torch.Tensor = None,\n        inplace: bool = True,\n    ) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n        self._lazy_init(query.device)\n\n        kwargs = dict()\n        if nsa_indices is not None:\n            kwargs['nsa_indices'] = nsa_indices\n        if s_aux is not None:\n            kwargs['learnable_sink'] = s_aux\n        return self.impl.forward(\n            query,\n            key,\n            value,\n            k_cache,\n            v_cache,\n            attn_metadata=attn_metadata,\n            k_scales_zeros=k_scales_zeros,\n            
v_scales_zeros=v_scales_zeros,\n            inplace=inplace,\n            **kwargs,\n        )\n\n    @staticmethod\n    def update_meta_flashmla(attn_metadata: AttentionMetadata, num_attention_heads):\n        get_backend().update_meta_flashmla(attn_metadata, num_attention_heads)\n\n\nclass FlashAttention(nn.Module):\n    \"\"\"Flash attention w/o paging.\"\"\"\n\n    def __init__(\n        self,\n        num_heads: int,\n        head_dim: int,\n        scale: float = None,\n        num_kv_heads: int = None,\n        v_head_dim: int = None,\n        causal: bool = True,\n        sliding_window: int = None,\n        logit_softcapping: float = 0.0,\n        **kwargs,\n    ):\n        super().__init__()\n        if num_kv_heads is None:\n            num_kv_heads = num_heads\n        if v_head_dim is None:\n            v_head_dim = head_dim\n        num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads)\n\n        layer_backend = get_backend()\n\n        impl_builder = layer_backend.get_layer_impl_builder(OpType.FlashAttention)\n\n        self.impl = impl_builder.build(\n            num_heads=num_heads,\n            head_dim=head_dim,\n            scale=scale,\n            num_kv_heads=num_kv_heads,\n            v_head_dim=v_head_dim,\n            causal=causal,\n            sliding_window=sliding_window,\n            logit_softcapping=logit_softcapping,\n            **kwargs,\n        )\n\n    def forward(self,\n                query: torch.Tensor,\n                key: torch.Tensor,\n                value: torch.Tensor,\n                q_start_loc: torch.Tensor,\n                q_seqlens: torch.Tensor,\n                kv_start_loc: torch.Tensor = None,\n                kv_seqlens: torch.Tensor = None,\n                max_q_seqlen: int = None) -> torch.Tensor:\n        \"\"\"forward.\"\"\"\n\n        if max_q_seqlen is None:\n            max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))\n\n        if kv_start_loc is None and 
kv_seqlens is None:\n            kv_start_loc = q_start_loc\n            kv_seqlens = q_seqlens\n\n        assert kv_start_loc is not None\n        assert kv_seqlens is not None\n\n        return self.impl.forward(\n            query,\n            key,\n            value,\n            q_start_loc=q_start_loc,\n            q_seqlens=q_seqlens,\n            kv_start_loc=kv_start_loc,\n            kv_seqlens=kv_seqlens,\n            max_q_seqlen=max_q_seqlen,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/embedding.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.distributed import get_dist_group, get_dist_manager, get_tp_world_rank\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader\n\nDEFAULT_VOCAB_PADDING_SIZE = 64\n\n\ndef pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:\n    \"\"\"Pad the vocab size to the given value.\"\"\"\n    return ((vocab_size + pad_to - 1) // pad_to) * pad_to\n\n\nclass ParallelEmbedding(nn.Module):\n\n    def __init__(\n        self,\n        vocab_size: int,\n        hidden_size: int,\n        padding_idx: int,\n        dtype: torch.dtype = None,\n        device: torch.device = None,\n        is_tp: bool = False,\n        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,\n        layer_type: str = 'attn',\n        force_dtype: torch.dtype = None,\n    ):\n        self.dist_ctx = get_dist_manager().current_context()\n        super().__init__()\n\n        self.is_tp = is_tp\n        self.vocab_size = vocab_size\n        self.padding_size = padding_size\n        if padding_idx is not None:\n            if padding_idx < 0:\n                padding_idx = vocab_size + padding_idx\n            assert padding_idx >= 0 and padding_idx < vocab_size\n        self.padding_idx = padding_idx\n\n        dist_cfg = get_dist_manager().current_config()\n        _, self.rank = get_tp_world_rank(layer_type)\n        self.tp, _ = dist_cfg.get_tp_by_layer(layer_type)\n\n        dist_group = get_dist_group(layer_type=layer_type)\n        self.tp_group = dist_group.gpu_group\n\n        if is_tp and self.tp > 1:\n            self.vocab_size_padded = pad_vocab_size(self.vocab_size, self.padding_size)\n            assert self.vocab_size_padded % self.tp == 0, \\\n                f'vocab_size_padded({self.vocab_size_padded}) must be divisible by tp({self.tp})'\n            
self.vocab_size_padded = self.vocab_size_padded // self.tp\n        else:\n            self.vocab_size_padded = self.vocab_size\n\n        self.out_dtype = dtype\n        self.start_index = self.rank * self.vocab_size_padded\n        self.end_index = (self.rank + 1) * self.vocab_size_padded\n        weight_dtype = force_dtype or dtype\n        self.register_parameter('weight', self.create_weight(self.vocab_size_padded, hidden_size, weight_dtype, device))\n        self.weight.weight_loader = self.weight_loader\n\n        backend = get_backend()\n        builder = backend.get_layer_impl_builder(OpType.Embedding)\n        self.impl = builder.build(self.start_index, self.end_index)\n\n        self.all_reduce = self.is_tp and self.tp > 1\n\n    @staticmethod\n    def create_weight(vocab_size: int, hidden_size: int, dtype: torch.dtype = None, device: torch.device = None):\n        \"\"\"Create weight.\"\"\"\n        if dtype is None:\n            dtype = torch.float16\n        if device is None:\n            device = 'cuda'\n        weight = torch.nn.Parameter(torch.zeros((vocab_size, hidden_size), dtype=dtype, device=device),\n                                    requires_grad=False)\n        return weight\n\n    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader for rowwise embedding.\"\"\"\n        loaded_weight = loaded_weight.to(param.device)\n\n        shard_size = self.vocab_size_padded\n        if self.end_index > loaded_weight.shape[0]:\n            shard_size = loaded_weight.shape[0] - self.start_index\n\n        loaded_weight = loaded_weight.narrow(0, self.start_index, shard_size)\n        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        if not self.all_reduce:\n            default_weight_loader(param, loaded_weight)\n            if self.padding_idx is 
not None:\n                self.weight[self.padding_idx] = 0\n        else:\n            self._weight_loader_tp_rowwise(param, loaded_weight)\n            if (self.padding_idx is not None and self.padding_idx >= self.start_index\n                    and self.padding_idx < self.end_index):\n                self.weight[self.padding_idx - self.start_index] = 0\n\n    def forward(self, x: torch.Tensor):\n        embeddings = self.impl.forward(x, self.weight, all_reduce=self.all_reduce, group=self.tp_group)\n        if self.out_dtype is not None and embeddings.dtype != self.out_dtype:\n            embeddings = embeddings.to(dtype=self.out_dtype)\n        return embeddings\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/eplb.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\n\nclass EPLBDispatchInfo:\n\n    def __init__(self, info) -> None:\n        self.info = info\n\n\nclass EPLBManager:\n    eplb = None\n\n    @classmethod\n    def init_global_eplb_metadata(cls, ep_size: int, num_routed_experts: int, num_hidden_layers: int):\n        assert ep_size > 1, 'eplb requires ep_size > 1'\n        from dlblas.layers.moe import eplb\n        EPLBManager.eplb = eplb\n        eplb.init_global_eplb_metadata(ep_size=ep_size,\n                                       num_routed_experts=num_routed_experts,\n                                       num_hidden_layers=num_hidden_layers)\n\n    @classmethod\n    def num_physical_experts(cls) -> int:\n        return EPLBManager.eplb.get_global_eplb_metadata().num_physical_experts()\n\n    @classmethod\n    def topk_ids_logical_to_physical(cls, topk_ids: torch.Tensor, eplb_dispatch_info: EPLBDispatchInfo):\n        return EPLBManager.eplb.topk_ids_logical_to_physical(topk_ids=topk_ids, info=eplb_dispatch_info.info)\n\n    @classmethod\n    def get_dispatch_info(cls, ep_rank, layer_idx) -> EPLBDispatchInfo:\n        info = EPLBManager.eplb.EPLBDispatchInfo.init_new(ep_rank=ep_rank, layer_idx=layer_idx)\n        return EPLBDispatchInfo(info)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/gated_delta.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Any, Sequence, Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\ndef build_rmsnorm_gated(hidden_size: int, eps=1e-6, **kwargs):\n    # TODO: used custom kernel\n    from fla.modules import FusedRMSNormGated\n    try:\n        # avoid unwanted specialize\n        from fla.modules.fused_norm_gate import layer_norm_gated_fwd_kernel\n        keys = layer_norm_gated_fwd_kernel.fn.keys\n        if 'NB' in keys:\n            keys.remove('NB')\n    except Exception:\n        logger.debug('patch layer_norm_gated_fwd_kernel autotuning failed.')\n    return FusedRMSNormGated(hidden_size, eps=eps, **kwargs)\n\n\nclass GatedDeltaMeta:\n\n    def __init__(self, num_tokens: int, conv_kernel_size: int, state_ids: torch.Tensor, attn_metadata: Any):\n        self.num_tokens = num_tokens\n        self.is_decoding = attn_metadata.is_decoding\n        self.cu_seqlens = attn_metadata.cu_seqlens_q\n        device = self.cu_seqlens.device\n\n        # get seq_idx (1, num_tokens)\n        seqlens = attn_metadata.q_seqlens\n        batch_size = seqlens.numel()\n        batch_idx = torch.arange(0, batch_size, dtype=torch.int32, device=device)\n        self.seq_idx = torch.repeat_interleave(batch_idx, seqlens, output_size=num_tokens)[None]\n\n        # conv_idx\n        range_idx = torch.arange(-conv_kernel_size, 0, device=device)\n        self.conv_idx = self.cu_seqlens[1:, None] + range_idx[None]\n        self.conv_idx = self.conv_idx.clamp_min(0)\n\n        self.conv_state_indices = state_ids.to(torch.int32)\n        # we assume 0 is dummy state, shared by all invalid states.\n        
self.valid_state = state_ids >= 0\n        self.state_ids = state_ids.clamp(0)\n\n\nclass CausalConv1dFunc:\n\n    def __init__(self, activation: str = 'silu'):\n        backend = get_backend()\n        builder = backend.get_layer_impl_builder(OpType.CausalConv1d)\n        impl = builder.build()\n        self.causal_conv1d_fn = impl.conv1d_fn\n        self.causal_conv1d_update = impl.update_fn\n        self.activation = activation\n\n    def conv1d_func(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, conv_state: torch.Tensor,\n                    gated_delta_meta: GatedDeltaMeta):\n        \"\"\"\n        x: (b, seqlen, dim)\n        seqlen: (b)\n        out: (b, seqlen, dim)\n        conv_state: (b, dim, kernel_size)\n        \"\"\"\n        seq_idx = gated_delta_meta.seq_idx\n        conv_idx = gated_delta_meta.conv_idx\n        state_ids = gated_delta_meta.state_ids\n\n        assert x.dim() == 3\n        x = x.transpose(-2, -1)\n        if weight.dim() == 3:\n            assert weight.size(1) == 1\n            weight = weight[:, 0]\n\n        # fill conv state\n        # TODO: find efficient way to fill conv state without gather + scatter\n        final_state = conv_state.index_select(0, state_ids)\n        batch_size = conv_state.size(0)\n        conv_idx = conv_idx[:, None].expand(-1, x.size(1), -1)\n        torch.gather(x.expand(batch_size, -1, -1), -1, conv_idx, out=final_state)\n        conv_state = conv_state.index_copy_(0, state_ids, final_state)\n\n        out = self.causal_conv1d_fn(\n            x,\n            weight,\n            bias,\n            seq_idx,\n            return_final_states=False,\n            activation=self.activation,\n        )\n\n        out = out.transpose(-2, -1)\n\n        # store conv_state\n        return out, conv_state\n\n    def conv1d_update(\n        self,\n        x: torch.Tensor,\n        weight: torch.Tensor,\n        bias: torch.Tensor,\n        conv_state: torch.Tensor,\n        
conv_state_indices: torch.Tensor,\n    ):\n        if weight.dim() == 3:\n            assert weight.size(1) == 1\n            weight = weight[:, 0]\n        out = self.causal_conv1d_update(x[0],\n                                        conv_state,\n                                        weight,\n                                        bias,\n                                        activation=self.activation,\n                                        conv_state_indices=conv_state_indices)\n        return out[None], conv_state\n\n    @record_function('causal_conv1d')\n    def __call__(\n        self,\n        x: torch.Tensor,\n        weight: torch.Tensor,\n        bias: torch.Tensor,\n        conv_state: torch.Tensor,\n        gated_delta_meta: GatedDeltaMeta,\n    ):\n        if gated_delta_meta.is_decoding:\n            conv_state_indices = gated_delta_meta.conv_state_indices\n            return self.conv1d_update(x, weight, bias, conv_state, conv_state_indices)\n        return self.conv1d_func(x, weight, bias, conv_state, gated_delta_meta=gated_delta_meta)\n\n\nclass GatedDelta:\n\n    def __init__(self, use_qk_l2norm_in_kernel: bool = True):\n        backend = get_backend()\n        builder = backend.get_layer_impl_builder(OpType.GatedDeltaRule)\n        self.impl = builder.build()\n        self.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel\n\n    def __call__(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        g: torch.Tensor,\n        beta: torch.Tensor,\n        recurrent_state: torch.Tensor,\n        gated_delta_meta: GatedDeltaMeta,\n    ):\n        \"\"\"call.\"\"\"\n        is_decoding = gated_delta_meta.is_decoding\n        cu_seqlens = gated_delta_meta.cu_seqlens\n        state_ids = gated_delta_meta.state_ids\n\n        if not is_decoding:\n            core_attn_out, last_recurrent_state = self.impl.chunk_gated_delta_rule(\n                query,\n                key,\n                
value,\n                g=g,\n                beta=beta,\n                initial_state=recurrent_state,\n                state_indices=state_ids,\n                output_final_state=True,\n                use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel,\n                cu_seqlens=cu_seqlens,\n            )\n        else:\n            # qkvgb (1, seqlen, ...) -> (seqlen, 1, ...)\n            core_attn_out, last_recurrent_state = self.impl.fused_recurrent_gated_delta_rule(\n                query[0, :, None],\n                key[0, :, None],\n                value[0, :, None],\n                g=g[0, :, None],\n                beta=beta[0, :, None],\n                initial_state=recurrent_state,\n                output_final_state=True,\n                use_qk_l2norm_in_kernel=self.use_qk_l2norm_in_kernel,\n                state_indices=state_ids,\n            )\n            # out (seqlen, 1, ...) -> (1, seqlen, ...)\n            core_attn_out = core_attn_out[None, :, 0]\n        return core_attn_out, last_recurrent_state\n\n\nclass CausalConv1d(nn.Module):\n    \"\"\"Causal conv1d wrapper.\"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int | Tuple[int],\n        split: Sequence[int],\n        groups: int = 1,\n        bias: bool = True,\n        device: str | torch.device | None = None,\n        dtype: torch.dtype | None = None,\n    ):\n        super().__init__()\n        tp, rank = get_tp_world_rank()\n        self.tp = tp\n        self.rank = rank\n        in_channels = in_channels // tp\n        out_channels = out_channels // tp\n        groups = groups // tp\n        assert len(split) == 3\n        self.split = split\n\n        weight, w_bias = self.make_weight(\n            in_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            groups=groups,\n            bias=bias,\n            device=device,\n            dtype=dtype,\n        )\n\n        
self.register_weight(weight, w_bias)\n        self.causal_conv1d_func = CausalConv1dFunc(activation='silu')\n\n    @staticmethod\n    def make_weight(\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int | Tuple[int],\n        groups: int = 1,\n        bias: bool = True,\n        device: str | torch.device | None = None,\n        dtype: torch.dtype | None = None,\n    ):\n        weight_shape = (out_channels, in_channels // groups,\n                        kernel_size if isinstance(kernel_size, int) else kernel_size[0])\n        bias_shape = (out_channels, ) if bias else None\n\n        weight = torch.empty(weight_shape, device=device, dtype=dtype)\n        if bias_shape is not None:\n            w_bias = torch.empty(bias_shape, device=device, dtype=dtype)\n        else:\n            w_bias = None\n        return weight, w_bias\n\n    def register_weight(self, weight: torch.Tensor, w_bias: torch.Tensor | None = None):\n        self.register_parameter('weight', nn.Parameter(weight))\n        self.weight.weight_loader = self.weight_loader\n        if w_bias is not None:\n            self.register_parameter('bias', nn.Parameter(w_bias))\n            self.bias.weight_loader = self.weight_loader\n        else:\n            self.register_parameter('bias', None)\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        q, k, v = loaded_weight.split(self.split, dim=0)\n        q = q.chunk(self.tp, dim=0)[self.rank]\n        k = k.chunk(self.tp, dim=0)[self.rank]\n        v = v.chunk(self.tp, dim=0)[self.rank]\n        loaded_weight = torch.cat([q, k, v], dim=0)\n        default_weight_loader(param, loaded_weight)\n\n    def forward(self, x: torch.Tensor, conv_state: torch.Tensor, gated_delta_meta: GatedDeltaMeta):\n        \"\"\"forward.\"\"\"\n        return self.causal_conv1d_func(x, self.weight, self.bias, conv_state, 
gated_delta_meta=gated_delta_meta)\n\n\n@record_function('gated_delta_load_state')\ndef load_state(past_key_value: Tuple[torch.Tensor, torch.Tensor], gated_delta_meta: GatedDeltaMeta):\n    \"\"\"Load states from cache.\"\"\"\n    return past_key_value[:2]\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/linear/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, List, Optional\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.config import TPMode\nfrom lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank\nfrom lmdeploy.pytorch.models.patch import get_build_model_context\n\nfrom .awq import AwqLinear, MergedAwqLinear, QKVAwqLinear\nfrom .blocked_fp8 import BlockedF8Linear, MergedBlockedF8Linear, QKVBlockedF8Linear\nfrom .default import BaseLinear, MergedBaseLinear, QKVBaseLinear\nfrom .lora import LoRA  # noqa: F401\nfrom .w8a8 import MergedW8A8Linear, QKVW8A8Linear, W8A8Linear\n\n\ndef build_linear(\n    in_features: int,\n    out_features: int,\n    bias: bool,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n    colwise: bool = True,\n    is_tp: bool = False,\n    quant_config: Dict = None,\n    all_reduce: bool = True,\n    tp_align_size: int = 1,\n    dp_gather: bool = False,\n    layer_type: str = 'attn',\n    prefix: str = '',\n) -> nn.Module:\n    \"\"\"Build linear.\"\"\"\n    if layer_type is None:\n        layer_type = 'attn'\n    all_reduce = all_reduce if is_tp else False\n    quant_method = None\n    if quant_config is not None:\n        quant_config = get_build_model_context().quant_config\n        quant_method = quant_config.get_quant_method(prefix)\n\n    if dp_gather and quant_method is not None:\n        assert quant_method in ['fp8'], (f'Do not support dp_gather with quant_method={quant_method}')\n\n    if quant_method is None:\n        return BaseLinear(\n            in_features,\n            out_features,\n            bias=bias,\n            dtype=dtype,\n            device=device,\n            colwise=colwise,\n            is_tp=is_tp,\n            all_reduce=all_reduce,\n            tp_align_size=tp_align_size,\n            dp_gather=dp_gather,\n            layer_type=layer_type,\n        )\n\n    if quant_method == 'awq':\n        return 
AwqLinear(\n            in_features,\n            out_features,\n            w_bit=quant_config.bits,\n            group_size=quant_config.group_size,\n            bias=bias,\n            device=device,\n            colwise=colwise,\n            is_tp=is_tp,\n            all_reduce=all_reduce,\n            layer_type=layer_type,\n        )\n    if quant_method == 'smooth_quant':\n        return W8A8Linear(in_features,\n                          out_features,\n                          bias=bias,\n                          dtype=dtype,\n                          device=device,\n                          colwise=colwise,\n                          is_tp=is_tp,\n                          all_reduce=all_reduce,\n                          quant_dtype=quant_config.quant_dtype,\n                          layer_type=layer_type)\n    elif quant_method == 'fp8':\n        return BlockedF8Linear(\n            in_features,\n            out_features,\n            bias=bias,\n            fp8_dtype=quant_config.quant_dtype,\n            scale_fmt=quant_config.scale_fmt,\n            dtype=dtype,\n            device=device,\n            colwise=colwise,\n            is_tp=is_tp,\n            all_reduce=all_reduce,\n            dp_gather=dp_gather,\n            layer_type=layer_type,\n        )\n    else:\n        raise RuntimeError(f'Unsupported quant method: {quant_method}')\n\n\ndef build_colwise_linear(\n    in_features: int,\n    out_features: int,\n    bias: bool,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n    is_tp: bool = False,\n    tp_align_size: int = 1,\n    quant_config: Dict = None,\n    dp_disable_tp: bool = False,\n    dp_gather: bool = False,\n    check_dist: bool = True,\n    layer_type: str = 'attn',\n    prefix: str = '',\n) -> nn.Module:\n    \"\"\"Build columnwise parallel linear layer.\"\"\"\n    if check_dist:\n        dist_config = get_dist_manager().current_config()\n        tp, tp_mode = 
dist_config.get_tp_by_layer(layer_type)\n\n        # check is_tp\n        is_tp = is_tp if tp > 1 else False\n        is_tp = False if (dp_disable_tp and dist_config.dp > 1) else is_tp\n\n        # check dp_gather\n        dp_gather = dp_gather if is_tp and tp_mode == TPMode.DP_TP else False\n\n    return build_linear(\n        in_features=in_features,\n        out_features=out_features,\n        bias=bias,\n        dtype=dtype,\n        device=device,\n        colwise=True,\n        is_tp=is_tp,\n        quant_config=quant_config,\n        all_reduce=False,\n        tp_align_size=tp_align_size,\n        dp_gather=dp_gather,\n        layer_type=layer_type,\n        prefix=prefix,\n    )\n\n\ndef build_rowwise_linear(\n    in_features: int,\n    out_features: int,\n    bias: bool,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n    is_tp: bool = False,\n    tp_align_size: int = 1,\n    quant_config: Dict = None,\n    all_reduce: bool = True,\n    dp_disable_tp: bool = False,\n    check_dist: bool = True,\n    layer_type: str = 'attn',\n    prefix: str = '',\n) -> nn.Module:\n    \"\"\"Build rowwise parallel linear layer.\"\"\"\n    if check_dist:\n        dist_config = get_dist_manager().current_config()\n        tp, _ = dist_config.get_tp_by_layer(layer_type)\n        is_tp = is_tp if tp > 1 else False\n        is_tp = False if (dp_disable_tp and dist_config.dp > 1) else is_tp\n    return build_linear(\n        in_features=in_features,\n        out_features=out_features,\n        bias=bias,\n        dtype=dtype,\n        device=device,\n        colwise=False,\n        is_tp=is_tp,\n        quant_config=quant_config,\n        all_reduce=all_reduce,\n        tp_align_size=tp_align_size,\n        layer_type=layer_type,\n        prefix=prefix,\n    )\n\n\ndef build_merged_colwise_linear(\n    in_features: int,\n    all_out_features: List[int],\n    bias: bool,\n    dtype: Optional[torch.dtype] = None,\n    device: 
Optional[torch.device] = None,\n    quant_config: Dict = None,\n    is_tp: bool = True,\n    out_names: List[Any] = None,\n    dp_gather: bool = False,\n    check_dist: bool = True,\n    layer_type: str = 'attn',\n    prefix: str = '',\n):\n    \"\"\"Merge linear.\"\"\"\n    if check_dist and is_tp:\n        is_tp = get_tp_world_rank(layer_type)[0] > 1\n    quant_method = None\n    if quant_config is not None:\n        quant_config = get_build_model_context().quant_config\n        quant_method = quant_config.get_quant_method(prefix)\n    if dp_gather and quant_method is not None:\n        assert quant_method in ['fp8'], (f'Do not support dp_gather with quant_method={quant_method}')\n\n    if quant_method is None:\n        return MergedBaseLinear(in_features=in_features,\n                                all_out_features=all_out_features,\n                                bias=bias,\n                                dtype=dtype,\n                                device=device,\n                                is_tp=is_tp,\n                                out_names=out_names,\n                                dp_gather=dp_gather,\n                                layer_type=layer_type)\n\n    if quant_method == 'awq':\n        return MergedAwqLinear(\n            in_features,\n            all_out_features=all_out_features,\n            w_bit=quant_config.bits,\n            group_size=quant_config.group_size,\n            bias=bias,\n            device=device,\n            is_tp=is_tp,\n            layer_type=layer_type,\n        )\n    if quant_method == 'smooth_quant':\n        return MergedW8A8Linear(in_features=in_features,\n                                all_out_features=all_out_features,\n                                bias=bias,\n                                dtype=dtype,\n                                device=device,\n                                is_tp=is_tp,\n                                out_names=out_names,\n                                
quant_dtype=quant_config.quant_dtype,\n                                layer_type=layer_type)\n    elif quant_method == 'fp8':\n        return MergedBlockedF8Linear(\n            in_features=in_features,\n            all_out_features=all_out_features,\n            bias=bias,\n            fp8_dtype=quant_config.quant_dtype,\n            scale_fmt=quant_config.scale_fmt,\n            dtype=dtype,\n            device=device,\n            is_tp=is_tp,\n            out_names=out_names,\n            dp_gather=dp_gather,\n            layer_type=layer_type,\n        )\n    else:\n        raise RuntimeError(f'Unsupported quant method: {quant_method}')\n\n\ndef build_qkv_proj(in_features: int,\n                   num_q_heads: int,\n                   num_kv_heads: int,\n                   head_size: int,\n                   head_size_v: int = None,\n                   bias: bool = False,\n                   quant_config: Dict = None,\n                   dtype: Optional[torch.dtype] = None,\n                   device: Optional[torch.device] = None,\n                   is_tp: bool = True,\n                   num_replicate_kv_heads: int = 1,\n                   prefix: str = ''):\n    \"\"\"Build qkv proj.\"\"\"\n    dist_config = get_dist_manager().current_config()\n    is_tp = is_tp if dist_config.attn_tp > 1 else False\n    quant_method = None\n    if quant_config is not None:\n        quant_config = get_build_model_context().quant_config\n        quant_method = quant_config.get_quant_method(prefix)\n    if head_size_v is None:\n        head_size_v = head_size\n\n    if quant_method is None:\n        return QKVBaseLinear(in_features=in_features,\n                             num_q_heads=num_q_heads,\n                             num_kv_heads=num_kv_heads,\n                             head_size=head_size,\n                             head_size_v=head_size_v,\n                             bias=bias,\n                             dtype=dtype,\n                             
device=device,\n                             is_tp=is_tp,\n                             num_replicate_kv_heads=num_replicate_kv_heads)\n\n    if quant_method == 'awq':\n        return QKVAwqLinear(in_features=in_features,\n                            num_q_heads=num_q_heads,\n                            num_kv_heads=num_kv_heads,\n                            head_size=head_size,\n                            head_size_v=head_size_v,\n                            w_bit=quant_config.bits,\n                            group_size=quant_config.group_size,\n                            bias=bias,\n                            device=device,\n                            is_tp=is_tp,\n                            num_replicate_kv_heads=num_replicate_kv_heads)\n    if quant_method == 'smooth_quant':\n        return QKVW8A8Linear(in_features=in_features,\n                             num_q_heads=num_q_heads,\n                             num_kv_heads=num_kv_heads,\n                             head_size=head_size,\n                             head_size_v=head_size_v,\n                             bias=bias,\n                             dtype=dtype,\n                             device=device,\n                             is_tp=is_tp,\n                             num_replicate_kv_heads=num_replicate_kv_heads,\n                             quant_dtype=quant_config.quant_dtype)\n    if quant_method == 'fp8':\n        return QKVBlockedF8Linear(in_features=in_features,\n                                  num_q_heads=num_q_heads,\n                                  num_kv_heads=num_kv_heads,\n                                  head_size=head_size,\n                                  head_size_v=head_size_v,\n                                  bias=bias,\n                                  fp8_dtype=quant_config.quant_dtype,\n                                  scale_fmt=quant_config.scale_fmt,\n                                  dtype=dtype,\n                                  
device=device,\n                                  is_tp=is_tp,\n                                  dp_gather=False,\n                                  num_replicate_kv_heads=num_replicate_kv_heads)\n    else:\n        raise RuntimeError(f'Unsupported quant method: {quant_method}')\n\n\ndef build_o_proj(\n    in_features: int,\n    out_features: int,\n    bias: bool,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n    is_tp: bool = False,\n    tp_align_size: int = 1,\n    quant_config: Dict = None,\n    all_reduce: bool = True,\n    prefix: str = '',\n) -> nn.Module:\n    \"\"\"Build o proj.\"\"\"\n    dist_config = get_dist_manager().current_config()\n    is_tp = is_tp if dist_config.attn_tp > 1 else False\n\n    return build_rowwise_linear(\n        in_features=in_features,\n        out_features=out_features,\n        bias=bias,\n        dtype=dtype,\n        device=device,\n        is_tp=is_tp,\n        tp_align_size=tp_align_size,\n        quant_config=quant_config,\n        all_reduce=all_reduce,\n        check_dist=False,\n        layer_type='attn',\n        prefix=prefix,\n    )\n\n\ndef build_gateup_linear(\n    in_features: int,\n    all_out_features: List[int],\n    bias: bool,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n    quant_config: Dict = None,\n    is_tp: bool = True,\n    out_names: List[Any] = None,\n    dp_gather: bool = True,\n    prefix: str = '',\n):\n    \"\"\"Build gate up linear.\"\"\"\n    dist_config = get_dist_manager().current_config()\n    tp, tp_mode = dist_config.get_tp_by_layer('mlp')\n    is_tp = is_tp if tp > 1 else False\n    dp_gather = dp_gather if is_tp and tp_mode == TPMode.DP_TP else False\n\n    return build_merged_colwise_linear(\n        in_features=in_features,\n        all_out_features=all_out_features,\n        bias=bias,\n        dtype=dtype,\n        device=device,\n        quant_config=quant_config,\n        is_tp=is_tp,\n        
out_names=out_names,\n        dp_gather=dp_gather,\n        check_dist=False,\n        layer_type='mlp',\n        prefix=prefix,\n    )\n\n\ndef build_down_linear(\n    in_features: int,\n    out_features: int,\n    bias: bool,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n    is_tp: bool = False,\n    tp_align_size: int = 1,\n    quant_config: Dict = None,\n    all_reduce: bool = True,\n    prefix: str = '',\n) -> nn.Module:\n    \"\"\"Build down linear.\"\"\"\n    dist_config = get_dist_manager().current_config()\n    is_tp = is_tp if dist_config.mlp_tp > 1 else False\n\n    return build_rowwise_linear(\n        in_features=in_features,\n        out_features=out_features,\n        bias=bias,\n        dtype=dtype,\n        device=device,\n        is_tp=is_tp,\n        tp_align_size=tp_align_size,\n        quant_config=quant_config,\n        all_reduce=all_reduce,\n        check_dist=False,\n        layer_type='mlp',\n        prefix=prefix,\n    )\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/linear/awq.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, List, Optional\n\nimport torch\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader\n\nfrom ..utils import chunk_aligned, get_distribute_size\nfrom .base import LinearBase\nfrom .utils import QKVMixin, check_qkv_split_layout\n\n\nclass AwqLinear(LinearBase):\n    \"\"\"W4a16 linear.\"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        w_bit: int,\n        group_size: int,\n        bias: bool,\n        device: Optional[torch.device] = None,\n        colwise: bool = True,\n        is_tp: bool = False,\n        all_reduce: bool = True,\n        layer_type: str = 'attn',\n    ):\n        super().__init__(dtype=torch.float16,\n                         device=device,\n                         colwise=colwise,\n                         is_tp=is_tp,\n                         all_reduce=all_reduce,\n                         layer_type=layer_type)\n        if self.is_tp:\n            in_features, out_features = self._get_io_features(in_features, out_features, w_bit, group_size, colwise)\n        qweight, scales, qzeros, bias = self.create_weights(in_features, out_features, w_bit, group_size, bias,\n                                                            self.dtype, self.device)\n        impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW4A16)\n        self.impl = impl_builder.build(in_features,\n                                       out_features,\n                                       w_bit,\n                                       group_size,\n                                       bias is not None,\n                                       dtype=scales.dtype)\n        self.register_all_parameters(qweight, scales, qzeros, bias)\n\n        self.in_features = in_features\n        self.out_features = out_features\n        
self.w_bit = w_bit\n        self.group_size = group_size\n        self.elem_per_int = 32 // w_bit\n\n    def setup_loaders(self):\n        \"\"\"Setup weight loaders.\"\"\"\n        self.qweight.weight_loader = self.weight_loader\n        self.qweight._weight_type = 'qweight'\n        self.scales.weight_loader = self.weight_loader\n        self.scales._weight_type = 'scales'\n        self.qzeros.weight_loader = self.weight_loader\n        self.qzeros._weight_type = 'qzeros'\n        if self.bias is not None:\n            self.bias.weight_loader = self.weight_loader\n            self.bias._weight_type = 'bias'\n\n    def register_all_parameters(self,\n                                qweight: torch.Tensor,\n                                scales: torch.Tensor,\n                                qzeros: torch.Tensor,\n                                bias: Optional[torch.Tensor] = None):\n        \"\"\"Register all parameters.\"\"\"\n        qweight = torch.nn.Parameter(qweight, requires_grad=False)\n        scales = torch.nn.Parameter(scales, requires_grad=False)\n        qzeros = torch.nn.Parameter(qzeros, requires_grad=False)\n        if bias is not None:\n            bias = torch.nn.Parameter(bias, requires_grad=False)\n        self.register_parameter('qweight', qweight)\n        self.register_parameter('scales', scales)\n        self.register_parameter('qzeros', qzeros)\n        self.register_parameter('bias', bias)\n        self.setup_loaders()\n\n    def _get_io_features(self, in_features: int, out_features: int, w_bit: int, group_size: int, colwise: bool):\n        \"\"\"Get io features.\"\"\"\n        align = max(32 // w_bit, group_size)\n        world_size, rank = self.get_tp_world_rank()\n        if colwise:\n            out_features = get_distribute_size(out_features, world_size, rank, align=align)\n        else:\n            in_features = get_distribute_size(in_features, world_size, rank, align=align)\n        return in_features, out_features\n\n    def 
_weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,\n                                  world_size: int):\n        \"\"\"Weight loader for colwise linear.\"\"\"\n        if loaded_weight.dim() == 1:\n            # bias\n            align = max(self.elem_per_int, self.group_size)\n            weight = chunk_aligned(loaded_weight, world_size, 0, align)[rank]\n            return default_weight_loader(param, weight)\n\n        if param._weight_type == 'scales':\n            # scales are not packed, so align on unpacked output features\n            align = max(self.elem_per_int, self.group_size)\n            weight = chunk_aligned(loaded_weight, world_size, 1, align)[rank]\n            return default_weight_loader(param, weight)\n\n        # qweight/qzeros pack elem_per_int values per int32 along dim 1\n        align = max(self.elem_per_int, self.group_size) // self.elem_per_int\n        weight = chunk_aligned(loaded_weight, world_size, 1, align)[rank]\n        return default_weight_loader(param, weight)\n\n    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,\n                                  world_size: int):\n        \"\"\"Weight loader for rowwise linear.\"\"\"\n        if loaded_weight.dim() == 1:\n            # bias: only rank 0 keeps it, so the all-reduce adds it exactly once\n            if rank != 0:\n                loaded_weight = torch.zeros_like(loaded_weight)\n            return default_weight_loader(param, loaded_weight)\n\n        if param._weight_type == 'qweight':\n            # qweight rows are unpacked input features\n            align = max(self.elem_per_int, self.group_size)\n            weight = chunk_aligned(loaded_weight, world_size, 0, align)[rank]\n            return default_weight_loader(param, weight)\n\n        # scales/qzeros rows are input-feature groups\n        align = max(self.elem_per_int, self.group_size) // self.group_size\n        weight = chunk_aligned(loaded_weight, world_size, 0, align)[rank]\n        return default_weight_loader(param, weight)\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        if 
not self.is_tp:\n            return default_weight_loader(param, loaded_weight)\n\n        world_size, rank = self.get_tp_world_rank()\n        if self.colwise:\n            return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size)\n        else:\n            return self._weight_loader_tp_rowwise(param, loaded_weight, rank, world_size)\n\n    def create_weights(self, in_features: int, out_features: int, w_bit: int, group_size: int, bias: bool,\n                       dtype: torch.dtype, device: torch.device):\n        \"\"\"Create weights.\"\"\"\n        assert in_features % group_size == 0\n        elem_per_int = 32 // w_bit\n        assert out_features % elem_per_int == 0\n\n        grouped_in_feats = in_features // group_size\n        quant_out_feats = out_features // elem_per_int\n        qweight = torch.empty((in_features, quant_out_feats), dtype=torch.int32, device=device)\n        scales = torch.empty((grouped_in_feats, out_features), dtype=dtype, device=device)\n        qzeros = torch.empty((grouped_in_feats, quant_out_feats), dtype=torch.int32, device=device)\n        if bias:\n            bias = torch.empty((out_features, ), dtype=dtype, device=device)\n        else:\n            bias = None\n        return qweight, scales, qzeros, bias\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        qweight, scales, qzeros, bias = self.impl.update_weights(self.qweight, self.scales, self.qzeros, self.bias)\n        self.register_all_parameters(qweight, scales, qzeros, bias)\n\n    def _forward_default(self, x, all_reduce, tp_sizes):\n        \"\"\"Default forward implement.\"\"\"\n        return self.impl.forward(x, self.qweight, self.scales, self.qzeros, self.bias, all_reduce, group=self.tp_group)\n\n\nclass MergedAwqLinear(AwqLinear):\n    \"\"\"Merged awq linear.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 all_out_features: List[int],\n                 w_bit: int,\n             
    group_size: int,\n                 bias: bool,\n                 device: Optional[torch.device] = None,\n                 is_tp: bool = True,\n                 out_names: Optional[List[int]] = None,\n                 layer_type: str = 'attn'):\n        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type)\n\n        self.split_section_s = all_out_features\n        elem_per_int = 32 // w_bit\n        self.split_section_wz = [size // elem_per_int for size in all_out_features]\n\n        all_out_features = self._update_all_out_features(all_out_features, w_bit, group_size)\n        self.all_out_features = all_out_features\n        if out_names is None:\n            out_names = torch.arange(len(self.all_out_features)).tolist()\n        assert len(out_names) == len(self.all_out_features)\n        self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names))\n        out_features = sum(all_out_features)\n        super().__init__(in_features,\n                         out_features,\n                         w_bit,\n                         group_size,\n                         bias,\n                         device,\n                         colwise=True,\n                         is_tp=is_tp,\n                         layer_type=layer_type)\n        self.setup_loaders()\n\n    def setup_loaders(self):\n        \"\"\"Setup weight loaders.\"\"\"\n        self.qweight.weight_loader = self.weight_loader\n        self.qweight.weight_spliter = self.weight_spliter_wz\n        self.qweight._weight_type = 'qweight'\n        self.scales.weight_loader = self.weight_loader\n        self.scales.weight_spliter = self.weight_spliter_s\n        self.scales._weight_type = 'scales'\n        self.qzeros.weight_loader = self.weight_loader\n        self.qzeros.weight_spliter = self.weight_spliter_wz\n        self.qzeros._weight_type = 'qzeros'\n        if self.bias is not None:\n            self.bias.weight_loader = self.weight_loader\n            
self.bias.weight_spliter = self.weight_spliter_s\n            self.bias._weight_type = 'bias'\n\n    def _get_io_features(self, in_features: int, out_features: int, w_bit: int, group_size: int, colwise: bool):\n        \"\"\"Get io features.\"\"\"\n        return in_features, out_features\n\n    def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int):\n        \"\"\"Update all out features.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        new_all_out_features = []\n        align = max(32 // w_bit, group_size)\n        for out_feat in all_out_features:\n            new_out_feat = get_distribute_size(out_feat, world_size, rank, align)\n            new_all_out_features.append(new_out_feat)\n        return new_all_out_features\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        shard_idx = self.out_names_map[shard_id]\n        if loaded_weight.dim() == 1:\n            # bias\n            align = max(self.elem_per_int, self.group_size)\n            param_w = param.data.split(self.all_out_features, 0)[shard_idx]\n            weight = chunk_aligned(loaded_weight, world_size, 0, align)[rank]\n            param_w.copy_(weight)\n            return\n\n        if param._weight_type in ['scales', 'bias']:\n            # scales\n            align = max(self.elem_per_int, self.group_size)\n            param_w = param.data.split(self.all_out_features, -1)[shard_idx]\n        else:\n            # qweight or qzeros\n            align = max(self.elem_per_int, self.group_size) // self.elem_per_int\n            quanted_out_feats = [feat // self.elem_per_int for feat in self.all_out_features]\n            param_w = param.data.split(quanted_out_feats, 1)[shard_idx]\n\n        weight = chunk_aligned(loaded_weight, world_size, -1, align)[rank]\n        param_w.copy_(weight)\n\n    def weight_spliter_wz(self, 
loaded_weight: torch.Tensor):\n        \"\"\"Weight spliter.\"\"\"\n        return loaded_weight.split(self.split_section_wz, dim=1)\n\n    def weight_spliter_s(self, loaded_weight: torch.Tensor):\n        \"\"\"Weight spliter.\"\"\"\n        return loaded_weight.split(self.split_section_s, dim=-1)\n\n\nclass QKVAwqLinear(MergedAwqLinear, QKVMixin):\n    \"\"\"Qkv awq linear.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 num_q_heads: int,\n                 num_kv_heads: int,\n                 head_size: int,\n                 head_size_v: int,\n                 w_bit: int,\n                 group_size: int,\n                 bias: bool = False,\n                 device: Optional[torch.device] = None,\n                 is_tp: bool = True,\n                 num_replicate_kv_heads: int = 1):\n        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn')\n        QKVMixin.__init__(self,\n                          num_q_heads=num_q_heads,\n                          num_kv_heads=num_kv_heads,\n                          head_size=head_size,\n                          head_size_v=head_size_v,\n                          num_replicate_kv_heads=num_replicate_kv_heads,\n                          is_tp=is_tp,\n                          tp=self.tp,\n                          tp_rank=self.tp_rank)\n\n        elem_per_int = 32 // w_bit\n        self.qkv_split_section_s = self.qkv_split_section\n        self.qkv_split_section_wz = [size // elem_per_int for size in self.qkv_split_section_s]\n        all_out_features = self.get_qkv_out_feautures()\n        out_names = ('q', 'k', 'v')\n        super().__init__(in_features,\n                         all_out_features,\n                         w_bit=w_bit,\n                         group_size=group_size,\n                         bias=bias,\n                         device=device,\n                         is_tp=is_tp,\n                         out_names=out_names,\n          
               layer_type='attn')\n\n    def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int):\n        \"\"\"Update all out features.\"\"\"\n        return all_out_features\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        chunk_size, chunk_idx = world_size, rank\n        shard_idx = self.out_names_map[shard_id]\n\n        if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']:\n            # update to duplicate k/v for tp_size > num_kv_heads\n            chunk_size = world_size // self.num_replicate_kv_heads\n            chunk_idx = rank // self.num_replicate_kv_heads\n\n        if loaded_weight.dim() == 1:\n            # bias\n            align = max(self.elem_per_int, self.group_size)\n            param_w = param.data.split(self.all_out_features, 0)[shard_idx]\n            weight = chunk_aligned(loaded_weight, chunk_size, 0, align)[chunk_idx]\n            param_w.copy_(weight)\n            return\n\n        if param._weight_type in ['scales', 'bias']:\n            # scales\n            align = max(self.elem_per_int, self.group_size)\n            param_w = param.data.split(self.all_out_features, -1)[shard_idx]\n        else:\n            # qweight or qzeros\n            align = max(self.elem_per_int, self.group_size) // self.elem_per_int\n            quanted_out_feats = [feat // self.elem_per_int for feat in self.all_out_features]\n            param_w = param.data.split(quanted_out_feats, 1)[shard_idx]\n\n        weight = chunk_aligned(loaded_weight, chunk_size, -1, align)[chunk_idx]\n        param_w.copy_(weight)\n\n    def weight_spliter_wz(self, loaded_weight: torch.Tensor, layout: str = 'default'):\n        \"\"\"Weight spliter.\"\"\"\n        check_qkv_split_layout(layout)\n        if layout == 'default':\n            return 
loaded_weight.split(self.qkv_split_section_wz, dim=1)\n        elif layout == 'hgd':\n            assert self.head_size == self.head_size_v\n            heads = [sec // self.head_size for sec in self.qkv_split_section_s]\n            kv_heads = heads[-1]\n            loaded_weight = loaded_weight.unflatten(1, (kv_heads, -1, self.head_size // self.elem_per_int))\n            q = loaded_weight[:, :, :-2].flatten(1, 3)\n            k = loaded_weight[:, :, -2].flatten(1, 2)\n            v = loaded_weight[:, :, -1].flatten(1, 2)\n            return q, k, v\n        else:\n            raise RuntimeError(f'Unsupported layout: {layout}')\n\n    def weight_spliter_s(self, loaded_weight: torch.Tensor, layout: str = 'default'):\n        \"\"\"Weight spliter.\"\"\"\n        check_qkv_split_layout(layout)\n        if layout == 'default':\n            return loaded_weight.split(self.qkv_split_section_s, dim=-1)\n        elif layout == 'hgd':\n            assert self.head_size == self.head_size_v\n            heads = [sec // self.head_size for sec in self.qkv_split_section_s]\n            kv_heads = heads[-1]\n            loaded_weight = loaded_weight.unflatten(1, (kv_heads, -1, self.head_size))\n            q = loaded_weight[:, :, :-2].flatten(1, 3)\n            k = loaded_weight[:, :, -2].flatten(1, 2)\n            v = loaded_weight[:, :, -1].flatten(1, 2)\n            return q, k, v\n        else:\n            raise RuntimeError(f'Unsupported layout: {layout}')\n\n    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):\n        return loaded_weight.split(self.qkv_split_section_s, dim=0)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/linear/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Callable, List, Optional\n\nimport torch\nimport torch.distributed as dist\nfrom torch import nn\n\nfrom lmdeploy.pytorch.config import TPMode\nfrom lmdeploy.pytorch.distributed import (gather_by_tp_sizes, get_dist_group, get_dist_manager, get_tp_world_rank,\n                                          reduce_scatter_by_tp_sizes)\nfrom lmdeploy.pytorch.model_inputs import get_step_ctx_manager\n\nfrom .utils import update_tp_args\n\n\nclass LinearForwardDPTP:\n\n    def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192):\n        \"\"\"Linear forward dp tp.\"\"\"\n        self.gemm_func = gemm_func\n        self.dist_ctx = get_dist_manager().current_context()\n        self.dist_config = self.dist_ctx.dist_config\n        self.tp = self.dist_config.mlp_tp\n        self.attn_tp = self.dist_config.attn_tp\n\n        tp_group = self.dist_ctx.mlp_tp_group\n        self.rank = tp_group.rank\n        self.gather_rank = self.rank // self.attn_tp\n        self.gather_group = tp_group.gpu_gather_group\n        self.tp_group = tp_group.gpu_group\n        self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp // 2\n\n    def all_gather(self, hidden_states: torch.Tensor, tp_sizes: List[int]):\n        \"\"\"All gather.\"\"\"\n        # use lmdeploy's gather_by_tp_sizes; torch.distributed has no such function\n        hidden_states, handle = gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True)\n        return hidden_states, handle\n\n    def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]):\n        \"\"\"Reduce scatter.\"\"\"\n        hidden_states_list = list(hidden_states.split(tp_sizes, -2))\n        cur_out_states = hidden_states_list[self.gather_rank]\n        out_states.copy_(cur_out_states)\n        hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)]\n        hidden_states_list[self.rank] = out_states\n        handle = 
dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True)\n        return out_states, handle\n\n    def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, output_states: torch.Tensor, tp_sizes: List[int],\n                                 handle: dist.Work):\n        \"\"\"Gemm and reduce scatter.\"\"\"\n        handle.wait()\n        cur_out = self.gemm_func(hidden_states)\n        return self.reduce_scatter(cur_out, output_states, tp_sizes)\n\n    def forward(self, hidden_states: torch.Tensor):\n        \"\"\"forward.\"\"\"\n\n        def __slice_tensor(tensor: torch.Tensor, slice_size: int):\n            \"\"\"Slice tensor.\"\"\"\n            cur_tensor = tensor[:slice_size]\n            tensor = tensor[slice_size:]\n            return cur_tensor, tensor\n\n        def __slice_and_gather():\n            \"\"\"Slice and gather.\"\"\"\n            nonlocal hidden_states, tp_sizes, output_states\n            cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round)\n            tp_sizes -= cur_tp_sizes\n            cur_tp_sizes = cur_tp_sizes.tolist()\n\n            slice_size = cur_tp_sizes[self.gather_rank]\n            cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size)\n            cur_output, output_states = __slice_tensor(output_states, slice_size)\n\n            # all gather\n            cur_hidden_states, handle = self.all_gather(cur_hidden_states, cur_tp_sizes)\n            return dict(hidden_states=cur_hidden_states, output_states=cur_output, handle=handle, tp_sizes=cur_tp_sizes)\n\n        step_ctx = get_step_ctx_manager().current_context()\n        tp_sizes = step_ctx.dp_meta.moe_tp_sizes\n        tp_sizes = torch.tensor(tp_sizes)\n        max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round)\n\n        output_states = torch.empty_like(hidden_states)\n        return_states = output_states\n\n        # pre\n        cur_inputs = __slice_and_gather()\n        handles = []\n\n       
 # main loop\n        while tp_sizes.sum() > 0:\n            next_inputs = __slice_and_gather()\n            _, handle = self._gemm_and_reduce_scatter(**cur_inputs)\n            handles.append(handle)\n            cur_inputs = next_inputs\n\n        # post\n        _, handle = self._gemm_and_reduce_scatter(**cur_inputs)\n        handles.append(handle)\n        for handle in handles:\n            handle.wait()\n        return return_states\n\n\nclass LinearBase(nn.Module):\n    \"\"\"Base class for linear layers.\"\"\"\n\n    def __init__(\n        self,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n        colwise: bool = True,\n        is_tp: bool = False,\n        all_reduce: bool = True,\n        tp_align_size: int = 1,\n        dp_gather: bool = False,\n        layer_type: str = 'attn',\n    ):\n        super().__init__()\n        self.init_tp_args(is_tp, all_reduce, colwise, layer_type)\n        self.colwise = colwise\n        self.tp_align_size = tp_align_size\n        self.dp_gather = dp_gather\n        if device is None:\n            device = torch.device('cpu')\n        if dtype is None:\n            dtype = torch.float16\n        self.device = device\n        self.dtype = dtype\n        self.layer_type = layer_type\n\n        self.lora_adapters = nn.ModuleDict()\n\n    def init_tp_args(self, is_tp: bool, all_reduce: bool, colwise: bool, layer_type: str):\n        if getattr(self, '_tp_args_initialized', False):\n            return\n        is_tp, all_reduce = update_tp_args(is_tp, all_reduce, colwise, layer_type=layer_type)\n        self.is_tp = is_tp\n        self.all_reduce = all_reduce\n        if is_tp:\n            dist_cfg = get_dist_manager().current_config()\n            _, rank = get_tp_world_rank(layer_type)\n            tp, tp_mode = dist_cfg.get_tp_by_layer(layer_type)\n            self.tp_rank = rank\n            self.tp = tp\n            self.tp_mode = tp_mode\n            dist_group = 
get_dist_group(layer_type=layer_type)\n            self.tp_group = dist_group.gpu_group\n            self.gather_group = dist_group.gpu_gather_group\n        else:\n            self.tp_rank = 0\n            self.tp = 1\n            self.tp_mode = TPMode.DEFAULT\n            self.tp_group = None\n            self.gather_group = None\n\n        if self.tp > 1 and self.tp_mode == TPMode.DP_TP:\n\n            def _gemm_func(self, x):\n                out = self._forward_default(x, False, None)\n\n                for lora_adapter in self.lora_adapters.values():\n                    out = lora_adapter(x, out)\n                return out\n\n            self.linear_dptp_forward = LinearForwardDPTP(_gemm_func)\n\n        self._tp_args_initialized = True\n\n    def get_tp_world_rank(self):\n        \"\"\"Get tp world rank.\"\"\"\n        assert hasattr(self, 'tp') and hasattr(self, 'tp_rank'), 'Please run init_tp_args first.'\n        return self.tp, self.tp_rank\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        raise NotImplementedError('This method should be implemented in subclasses.')\n\n    def _forward_default(self, x, all_reduce: bool, tp_sizes: List[int]):\n        \"\"\"Default forward implement.\"\"\"\n        raise NotImplementedError('This method should be implemented in subclasses.')\n\n    def _forward_lora(self, x, tp_sizes: List[int] = None):\n        \"\"\"Forward with LoRA.\"\"\"\n        out = self._forward_default(x, False, tp_sizes)\n\n        for lora_adapter in self.lora_adapters.values():\n            out = lora_adapter(x, out)\n        if self.all_reduce:\n            if self.tp_mode == TPMode.DP_TP:\n                out = reduce_scatter_by_tp_sizes(out, self.tp_rank, tp_sizes, group=self.tp_group)\n            else:\n                dist.all_reduce(out, group=self.tp_group)\n        return out\n\n    def _forward_dp_tp(self, x):\n        \"\"\"Forward dp_tp.\"\"\"\n        if self.dp_gather and self.all_reduce:\n         
   return self.linear_dptp_forward.forward(x)\n\n        step_ctx = get_step_ctx_manager().current_context()\n        dp_meta = step_ctx.dp_meta\n        tp_sizes = dp_meta.tp_sizes\n\n        if self.dp_gather:\n            x = gather_by_tp_sizes(x, tp_sizes, group=self.gather_group)\n\n        if len(self.lora_adapters) == 0:\n            return self._forward_default(x, self.all_reduce, tp_sizes)\n        else:\n            return self._forward_lora(x, tp_sizes)\n\n    def forward(self, x):\n        \"\"\"Forward of linear layer.\"\"\"\n        if self.tp > 1 and self.tp_mode == TPMode.DP_TP:\n            return self._forward_dp_tp(x)\n\n        if len(self.lora_adapters) == 0:\n            return self._forward_default(x, self.all_reduce, None)\n        else:\n            return self._forward_lora(x)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/linear/blocked_fp8.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, List, Optional\n\nimport torch\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.config import TPMode\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader\n\nfrom ..quant_utils import quant_blocked_fp8\nfrom ..utils import div_up, get_distribute_size\nfrom .base import LinearBase\nfrom .utils import QKVMixin, check_qkv_split_layout\n\n\nclass BlockedF8Linear(LinearBase):\n    \"\"\"Blocked f8 linear.\"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        bias: bool,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n        fp8_dtype: torch.dtype = torch.float8_e4m3fn,\n        scale_fmt: Optional[str] = None,\n        colwise: bool = True,\n        is_tp: bool = False,\n        all_reduce: bool = True,\n        dp_gather: bool = False,\n        layer_type: str = 'attn',\n    ):\n        super().__init__(dtype=dtype,\n                         device=device,\n                         colwise=colwise,\n                         is_tp=is_tp,\n                         all_reduce=all_reduce,\n                         dp_gather=dp_gather,\n                         layer_type=layer_type)\n        self.block_size = 128\n        self.fp8_dtype = fp8_dtype\n        self.scale_fmt = scale_fmt\n        if self.is_tp:\n            in_features, out_features = self._get_io_features(in_features, out_features, colwise)\n        impl_builder = get_backend().get_layer_impl_builder(OpType.LinearBlockedF8)\n        self.impl = impl_builder.build(in_features,\n                                       out_features,\n                                       block_size=128,\n                                       bias=bias is not None,\n                                       dtype=self.dtype)\n        self.impl.set_scale_fmt(scale_fmt)\n        
weight, weight_scale_inv, bias = self.create_weights(in_features, out_features, bias, self.dtype, self.device)\n        self.register_all_parameters(weight, weight_scale_inv, bias)\n\n        self.in_features = in_features\n        self.out_features = out_features\n\n    def setup_loaders(self):\n        \"\"\"Setup weight loaders.\"\"\"\n        self.weight.weight_loader = self.weight_loader_with_quant\n        self.weight_scale_inv.weight_loader = self.weight_loader\n        if self.bias is not None:\n            self.bias.weight_loader = self.weight_loader\n\n    def register_all_parameters(self,\n                                weight: torch.Tensor,\n                                weight_scale_inv: torch.Tensor,\n                                bias: Optional[torch.Tensor] = None):\n        \"\"\"Register all parameters.\"\"\"\n        weight = torch.nn.Parameter(weight, requires_grad=False)\n        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)\n        if bias is not None:\n            bias = torch.nn.Parameter(bias, requires_grad=False)\n        self.register_parameter('weight', weight)\n        self.register_parameter('weight_scale_inv', weight_scale_inv)\n        self.register_parameter('bias', bias)\n        self.setup_loaders()\n\n    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):\n        \"\"\"Get io features.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        if colwise:\n            out_features = get_distribute_size(out_features, world_size, rank)\n        else:\n            in_features = get_distribute_size(in_features, world_size, rank)\n        return in_features, out_features\n\n    def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,\n                                  world_size: int):\n        \"\"\"Weight loader for colwise linear.\"\"\"\n        weight = loaded_weight.chunk(world_size, 0)[rank]\n        return 
default_weight_loader(param, weight)\n\n    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,\n                                  world_size: int):\n        \"\"\"Weight loader for rowwise linear.\"\"\"\n        if loaded_weight.dim() == 2:\n            loaded_weight = loaded_weight.to(param.device)\n            weight = loaded_weight.chunk(world_size, 1)[rank]\n            return default_weight_loader(param, weight)\n        else:\n            # bias\n            if rank != 0:\n                loaded_weight = torch.zeros_like(loaded_weight)\n            return default_weight_loader(param, loaded_weight)\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        if not self.is_tp:\n            return default_weight_loader(param, loaded_weight)\n\n        world_size, rank = self.get_tp_world_rank()\n        if self.colwise:\n            return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size)\n        else:\n            return self._weight_loader_tp_rowwise(param, loaded_weight, rank, world_size)\n\n    def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader with weight quant.\"\"\"\n        if loaded_weight.dtype != param.dtype:\n            # quant loaded weight\n            quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device),\n                                                        param.dtype,\n                                                        self.block_size,\n                                                        scale_fmt=self.scale_fmt)\n            self.weight_loader(self.weight, quanted_weight)\n            self.weight_loader(self.weight_scale_inv, scaling)\n        else:\n            return self.weight_loader(param, loaded_weight)\n\n    def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: 
torch.dtype, device: torch.device):\n        \"\"\"Create weights.\"\"\"\n        weight = torch.empty((out_features, in_features), dtype=self.fp8_dtype, device=device)\n        weight_scale_inv = torch.empty((div_up(out_features, self.block_size), div_up(in_features, self.block_size)),\n                                       dtype=torch.float32,\n                                       device=device)\n        if bias:\n            bias = torch.empty((out_features, ), dtype=dtype, device=device)\n        else:\n            bias = None\n        return weight, weight_scale_inv, bias\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        weight, weight_scale_inv, bias = self.impl.update_weights(self.weight, self.weight_scale_inv, self.bias)\n        self.register_all_parameters(weight, weight_scale_inv, bias)\n\n    def _forward_default(self, x, all_reduce, tp_sizes):\n        \"\"\"Default forward implement.\"\"\"\n        if self.tp_mode == TPMode.DP_TP:\n            rank = self.tp_rank\n            return self.impl.forward(x,\n                                     self.weight,\n                                     self.weight_scale_inv,\n                                     self.bias,\n                                     all_reduce,\n                                     group=self.tp_group,\n                                     rank=rank,\n                                     scatter_size=tp_sizes)\n        else:\n            return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce, group=self.tp_group)\n\n\nclass MergedBlockedF8Linear(BlockedF8Linear):\n    \"\"\"Merged blocked fp8 linear.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 all_out_features: List[int],\n                 bias: bool,\n                 fp8_dtype: torch.dtype = torch.float8_e4m3fn,\n                 scale_fmt: Optional[str] = None,\n                 replicate: Optional[List[bool]] = None,\n           
      dtype: Optional[torch.dtype] = None,\n                 device: Optional[torch.device] = None,\n                 is_tp: bool = True,\n                 out_names: Optional[List[int]] = None,\n                 dp_gather: bool = False,\n                 layer_type: str = 'attn'):\n        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type)\n        if replicate is None:\n            replicate = tuple(False for _ in all_out_features)\n        self.block_size = 128\n        self.split_section = all_out_features\n        self.scale_split_section = [section // self.block_size for section in self.split_section]\n        all_out_features = self._update_all_out_features(all_out_features, replicate)\n        self.all_out_features = all_out_features\n        self.replicate = replicate\n        if out_names is None:\n            out_names = torch.arange(len(self.all_out_features)).tolist()\n        assert len(out_names) == len(self.all_out_features)\n        self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names))\n        out_features = sum(all_out_features)\n        super().__init__(in_features,\n                         out_features,\n                         bias,\n                         dtype,\n                         device,\n                         fp8_dtype=fp8_dtype,\n                         scale_fmt=scale_fmt,\n                         colwise=True,\n                         is_tp=is_tp,\n                         dp_gather=dp_gather,\n                         layer_type=layer_type)\n        self.setup_loaders()\n\n    def setup_loaders(self):\n        \"\"\"Setup weight loaders.\"\"\"\n        self.weight.weight_loader = self.weight_loader_with_quant\n        self.weight.weight_spliter = self.weight_spliter\n        self.weight._weight_type = 'qweight'\n        self.weight_scale_inv.weight_loader = self.weight_loader\n        self.weight_scale_inv.weight_spliter = self.weight_spliter\n        
self.weight_scale_inv._weight_type = 'scales'\n        if self.bias is not None:\n            self.bias.weight_loader = self.weight_loader\n            self.bias.weight_spliter = self.weight_spliter\n            self.bias._weight_type = 'bias'\n\n    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):\n        \"\"\"Get io features.\"\"\"\n        return in_features, out_features\n\n    def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]):\n        \"\"\"Update all out features.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        new_all_out_features = []\n        for out_feat, rep in zip(all_out_features, replicate):\n            if rep:\n                new_all_out_features.append(out_feat)\n                continue\n            new_out_feat = get_distribute_size(out_feat, world_size, rank)\n            new_all_out_features.append(new_out_feat)\n        return new_all_out_features\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        shard_idx = self.out_names_map[shard_id]\n        if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:\n            loaded_weight = loaded_weight.to(torch.float32)\n            all_out_features = [feats // self.block_size for feats in self.all_out_features]\n            param_w = param.data.split(all_out_features, 0)[shard_idx]\n        else:\n            param_w = param.data.split(self.all_out_features, 0)[shard_idx]\n        if not self.replicate[shard_idx]:\n            loaded_weight = loaded_weight.chunk(world_size, 0)[rank]\n        param_w.copy_(loaded_weight)\n\n    def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader with weight quant.\"\"\"\n        if loaded_weight.dtype != param.dtype:\n            # quant loaded 
weight\n            quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device),\n                                                        param.dtype,\n                                                        self.block_size,\n                                                        scale_fmt=self.scale_fmt)\n            self.weight_loader(self.weight, quanted_weight, shard_id)\n            self.weight_loader(self.weight_scale_inv, scaling, shard_id)\n        else:\n            return self.weight_loader(param, loaded_weight, shard_id)\n\n    def weight_spliter(self, loaded_weight: torch.Tensor):\n        \"\"\"Weight spliter.\"\"\"\n        if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:\n            return loaded_weight.split(self.scale_split_section, dim=0)\n        return loaded_weight.split(self.split_section, dim=0)\n\n    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):\n        return loaded_weight.split(self.split_section, dim=0)\n\n\nclass QKVBlockedF8Linear(MergedBlockedF8Linear, QKVMixin):\n    \"\"\"Qkv blockedf8 linear.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 num_q_heads: int,\n                 num_kv_heads: int,\n                 head_size: int,\n                 head_size_v: int,\n                 bias: bool = False,\n                 fp8_dtype: torch.dtype = torch.float8_e4m3fn,\n                 scale_fmt: Optional[str] = None,\n                 dtype: Optional[torch.dtype] = None,\n                 device: Optional[torch.device] = None,\n                 is_tp: bool = True,\n                 dp_gather: bool = False,\n                 num_replicate_kv_heads: int = 1):\n        self.block_size = 128\n        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn')\n        QKVMixin.__init__(self,\n                          num_q_heads=num_q_heads,\n                          num_kv_heads=num_kv_heads,\n                          
head_size=head_size,\n                          head_size_v=head_size_v,\n                          num_replicate_kv_heads=num_replicate_kv_heads,\n                          is_tp=is_tp,\n                          tp=self.tp,\n                          tp_rank=self.tp_rank)\n\n        all_out_features = self.get_qkv_out_feautures()\n        out_names = ('q', 'k', 'v')\n        super().__init__(in_features,\n                         all_out_features,\n                         dtype=dtype,\n                         fp8_dtype=fp8_dtype,\n                         scale_fmt=scale_fmt,\n                         bias=bias,\n                         device=device,\n                         is_tp=is_tp,\n                         out_names=out_names,\n                         dp_gather=dp_gather,\n                         layer_type='attn')\n\n    def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]):\n        \"\"\"Update all out features.\"\"\"\n        return all_out_features\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader.\"\"\"\n        _, rank = self.get_tp_world_rank()\n        shard_idx = self.out_names_map[shard_id]\n\n        num_head = self.num_q_heads if shard_id == 'q' \\\n            else self.num_kv_heads\n        head_dim = self.head_size if shard_id in ['q', 'k'] \\\n            else self.head_size_v\n        # update to duplicate k/v for tp_size > num_kv_heads\n        rank_idx = rank if shard_id == 'q' \\\n            else rank // self.num_replicate_kv_heads\n        sec_len = num_head * head_dim\n        all_out_features = self.all_out_features\n        if param._weight_type == 'scales':\n            loaded_weight = loaded_weight.to(torch.float32)\n            all_out_features = [sec // self.block_size for sec in all_out_features]\n            sec_len = sec_len // self.block_size\n\n        sec_start = rank_idx * sec_len\n\n       
 loaded_weight = loaded_weight.narrow(dim=0, start=sec_start, length=sec_len)\n        param_w = param.data.split(all_out_features, 0)[shard_idx]\n        param_w.copy_(loaded_weight)\n\n    def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader with weight quant.\"\"\"\n        if loaded_weight.dtype != param.dtype:\n            # quant loaded weight\n            quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device),\n                                                        param.dtype,\n                                                        self.block_size,\n                                                        scale_fmt=self.scale_fmt)\n            self.weight_loader(self.weight, quanted_weight, shard_id)\n            self.weight_loader(self.weight_scale_inv, scaling, shard_id)\n        else:\n            return self.weight_loader(param, loaded_weight, shard_id)\n\n    def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'):\n        \"\"\"Weight spliter.\"\"\"\n        check_qkv_split_layout(layout)\n        assert layout == 'default'\n        qkv_split_section = self.qkv_split_section\n        if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:\n            qkv_split_section = [sec // self.block_size for sec in qkv_split_section]\n        return loaded_weight.split(qkv_split_section, dim=0)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/linear/default.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, List, Optional\n\nimport torch\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.config import TPMode\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader\n\nfrom ..utils import chunk_aligned, get_distribute_size\nfrom .base import LinearBase\nfrom .utils import QKVMixin, check_qkv_split_layout\n\n\nclass BaseLinear(LinearBase):\n    \"\"\"Linear layer.\"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        bias: bool,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n        colwise: bool = True,\n        is_tp: bool = False,\n        all_reduce: bool = True,\n        tp_align_size: int = 1,\n        dp_gather: bool = False,\n        layer_type: str = 'attn',\n    ):\n        super().__init__(dtype=dtype,\n                         device=device,\n                         colwise=colwise,\n                         is_tp=is_tp,\n                         all_reduce=all_reduce,\n                         tp_align_size=tp_align_size,\n                         dp_gather=dp_gather,\n                         layer_type=layer_type)\n        if self.is_tp:\n            in_features, out_features = self._get_io_features(in_features, out_features, colwise)\n        impl_builder = get_backend().get_layer_impl_builder(OpType.Linear)\n        self.impl = impl_builder.build(in_features, out_features, bias is not None, dtype=self.dtype)\n        weight, bias = self.create_weights(in_features, out_features, bias, self.dtype, self.device)\n        self.register_all_parameters(weight, bias)\n\n        self.in_features = in_features\n        self.out_features = out_features\n\n    def setup_loaders(self):\n        \"\"\"Setup loaders.\"\"\"\n        self.weight.weight_loader = self.weight_loader\n        if self.bias is not None:\n         
   self.bias.weight_loader = self.weight_loader\n\n    def register_all_parameters(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):\n        \"\"\"Register all parameters.\"\"\"\n        weight = torch.nn.Parameter(weight, requires_grad=False)\n        if bias is not None:\n            bias = torch.nn.Parameter(bias, requires_grad=False)\n        self.register_parameter('weight', weight)\n        self.register_parameter('bias', bias)\n        self.setup_loaders()\n\n    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):\n        \"\"\"Get io features.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        if colwise:\n            out_features = get_distribute_size(out_features, world_size, rank, align=self.tp_align_size)\n        else:\n            in_features = get_distribute_size(in_features, world_size, rank, align=self.tp_align_size)\n        return in_features, out_features\n\n    def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,\n                                  world_size: int):\n        \"\"\"Weight loader for colwise linear.\"\"\"\n        weight = chunk_aligned(loaded_weight, world_size, 0, self.tp_align_size)[rank]\n        return default_weight_loader(param, weight)\n\n    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,\n                                  world_size: int):\n        \"\"\"Weight loader for rowwise linear.\"\"\"\n        if loaded_weight.dim() == 2:\n            loaded_weight = loaded_weight.to(param.device)\n            weight = chunk_aligned(loaded_weight, world_size, 1, self.tp_align_size)[rank]\n            return default_weight_loader(param, weight)\n        else:\n            # bias\n            if rank != 0:\n                loaded_weight = torch.zeros_like(loaded_weight)\n            return default_weight_loader(param, loaded_weight)\n\n    def weight_loader(self, 
param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        if not self.is_tp:\n            return default_weight_loader(param, loaded_weight)\n\n        world_size, rank = self.get_tp_world_rank()\n        if self.colwise:\n            return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size)\n        else:\n            return self._weight_loader_tp_rowwise(param, loaded_weight, rank, world_size)\n\n    def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: torch.dtype, device: torch.device):\n        \"\"\"Create weights.\"\"\"\n        weight = torch.empty((out_features, in_features), dtype=dtype, device=device)\n        if bias:\n            bias = torch.empty((out_features, ), dtype=dtype, device=device)\n        else:\n            bias = None\n        return weight, bias\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        weight, bias = self.impl.update_weights(self.weight, self.bias)\n        self.register_all_parameters(weight, bias)\n\n    def _forward_default(self, x, all_reduce, tp_sizes):\n        \"\"\"Default forward implement.\"\"\"\n        if self.tp_mode == TPMode.DP_TP:\n            rank = self.tp_rank\n            return self.impl.forward(x,\n                                     self.weight,\n                                     self.bias,\n                                     all_reduce,\n                                     group=self.tp_group,\n                                     rank=rank,\n                                     scatter_size=tp_sizes)\n        else:\n            return self.impl.forward(x, self.weight, self.bias, all_reduce, group=self.tp_group)\n\n\nclass MergedBaseLinear(BaseLinear):\n    \"\"\"Merged base linear.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 all_out_features: List[int],\n                 bias: bool,\n                 dtype: Optional[torch.dtype] = 
None,\n                 device: Optional[torch.device] = None,\n                 is_tp: bool = True,\n                 out_names: Optional[List[int]] = None,\n                 dp_gather: bool = False,\n                 layer_type: str = 'attn'):\n        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type)\n        self.split_section = all_out_features\n        all_out_features = self._update_all_out_features(all_out_features)\n        self.all_out_features = all_out_features\n        if out_names is None:\n            out_names = torch.arange(len(self.all_out_features)).tolist()\n        assert len(out_names) == len(self.all_out_features)\n        self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names))\n        out_features = sum(all_out_features)\n        super().__init__(in_features,\n                         out_features,\n                         bias,\n                         dtype,\n                         device,\n                         colwise=True,\n                         is_tp=is_tp,\n                         dp_gather=dp_gather,\n                         layer_type=layer_type)\n        self.setup_loaders()\n\n    def setup_loaders(self):\n        \"\"\"Update loaders.\"\"\"\n        self.weight.weight_loader = self.weight_loader\n        self.weight.weight_spliter = self.weight_spliter\n        if self.bias is not None:\n            self.bias.weight_loader = self.weight_loader\n            self.bias.weight_spliter = self.weight_spliter\n\n    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):\n        \"\"\"Get io features.\"\"\"\n        return in_features, out_features\n\n    def _update_all_out_features(self, all_out_features: List[int]):\n        \"\"\"Update all out features.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        new_all_out_features = []\n        for out_feat in all_out_features:\n            new_out_feat = get_distribute_size(out_feat, 
world_size, rank)\n            new_all_out_features.append(new_out_feat)\n        return new_all_out_features\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        shard_idx = self.out_names_map[shard_id]\n        param_w = param.data.split(self.all_out_features, 0)[shard_idx]\n        loaded_weight = loaded_weight.chunk(world_size, 0)[rank]\n        param_w.copy_(loaded_weight)\n\n    def weight_spliter(self, loaded_weight: torch.Tensor):\n        \"\"\"Weight spliter.\"\"\"\n        return loaded_weight.split(self.split_section, dim=0)\n\n    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):\n        return loaded_weight.split(self.split_section, dim=0)\n\n\nclass QKVBaseLinear(MergedBaseLinear, QKVMixin):\n    \"\"\"Qkv base linear.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 num_q_heads: int,\n                 num_kv_heads: int,\n                 head_size: int,\n                 head_size_v: int,\n                 bias: bool = False,\n                 dtype: Optional[torch.dtype] = None,\n                 device: Optional[torch.device] = None,\n                 is_tp: bool = True,\n                 num_replicate_kv_heads: int = 1):\n        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn')\n        QKVMixin.__init__(self,\n                          num_q_heads=num_q_heads,\n                          num_kv_heads=num_kv_heads,\n                          head_size=head_size,\n                          head_size_v=head_size_v,\n                          num_replicate_kv_heads=num_replicate_kv_heads,\n                          is_tp=is_tp,\n                          tp=self.tp,\n                          tp_rank=self.tp_rank)\n\n        all_out_features = self.get_qkv_out_feautures()\n        out_names = ('q', 'k', 'v')\n        
super().__init__(in_features,\n                         all_out_features,\n                         bias=bias,\n                         dtype=dtype,\n                         device=device,\n                         is_tp=is_tp,\n                         out_names=out_names,\n                         layer_type='attn')\n\n    def _update_all_out_features(self, all_out_features: List[int]):\n        \"\"\"Update all out features.\"\"\"\n        return all_out_features\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        chunk_size, chunk_idx = world_size, rank\n        shard_idx = self.out_names_map[shard_id]\n        param_w = param.data.split(self.all_out_features, 0)[shard_idx]\n\n        if self.num_replicate_kv_heads > 1 and shard_id in ['k', 'v']:\n            # update to duplicate k/v for tp_size > num_kv_heads\n            chunk_size = world_size // self.num_replicate_kv_heads\n            chunk_idx = rank // self.num_replicate_kv_heads\n        if shard_idx in [0, 1]:\n            loaded_weight = chunk_aligned(loaded_weight, chunk_size, 0, self.head_size)[chunk_idx]\n        elif shard_idx == 2:\n            loaded_weight = chunk_aligned(loaded_weight, chunk_size, 0, self.head_size_v)[chunk_idx]\n        param_w.copy_(loaded_weight)\n\n    def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'):\n        \"\"\"Weight spliter.\"\"\"\n        check_qkv_split_layout(layout)\n        if layout == 'default':\n            return loaded_weight.split(self.qkv_split_section, dim=0)\n        elif layout == 'hgd':\n            assert self.head_size == self.head_size_v\n            heads = [sec // self.head_size for sec in self.qkv_split_section]\n            kv_heads = heads[-1]\n            loaded_weight = loaded_weight.unflatten(0, (kv_heads, -1, self.head_size))\n            q = loaded_weight[:, 
:-2].flatten(0, 2)\n            k = loaded_weight[:, -2].flatten(0, 1)\n            v = loaded_weight[:, -1].flatten(0, 1)\n            return q, k, v\n        else:\n            raise RuntimeError(f'Unsupported layout: {layout}')\n\n    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):\n        return loaded_weight.split(self.qkv_split_section, dim=0)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/linear/lora.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.backends.lora import AdapterInfo\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\n\n\nclass LoRA(nn.Module):\n    \"\"\"LoRA layer.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 ranks: torch.Tensor,\n                 scalings: torch.Tensor,\n                 lora_a: torch.Tensor,\n                 lora_b: torch.Tensor,\n                 base_slice: slice,\n                 ctx_mgr: Any = None,\n                 colwise: bool = True,\n                 is_tp: bool = True,\n                 lora_b_spliter: Any = None):\n        super().__init__()\n        self.adapter_info = AdapterInfo(\n            in_features=in_features,\n            out_features=out_features,\n            ranks=ranks,\n            scalings=scalings,\n            base_slice=base_slice,\n        )\n        impl_builder = get_backend().get_layer_impl_builder(OpType.LoRA)\n        self.impl = impl_builder.build()\n\n        lora_A = nn.Parameter(lora_a, requires_grad=False)\n        lora_B = nn.Parameter(lora_b, requires_grad=False)\n        self.register_parameter('lora_A', lora_A)\n        self.register_parameter('lora_B', lora_B)\n        lora_A.weight_loader = self.weight_loader_A\n        lora_B.weight_loader = self.weight_loader_B\n        self.is_tp = is_tp\n        self.ctx_mgr = ctx_mgr\n        self.colwise = colwise\n        self.lora_b_spliter = lora_b_spliter\n\n    def forward(self, x, base_output=None):\n        \"\"\"Forward of loraA@loraB.\"\"\"\n        return self.impl.forward(x,\n                                 self.lora_A,\n                                 self.lora_B,\n                                 base_output,\n                                 self.adapter_info,\n                  
               ctx_mgr=self.ctx_mgr,\n                                 colwise=self.colwise,\n                                 is_tp=self.is_tp)\n\n    def weight_loader_A(self, param: nn.Parameter, loaded_weight: torch.Tensor, adapter_id: int):\n        \"\"\"Weight loader.\"\"\"\n        rank = self.adapter_info.ranks[adapter_id].item()\n        r_start = self.adapter_info.rank_offsets[adapter_id].item()\n        r_end = r_start + rank\n        param_r = param.data[r_start:r_end]\n\n        if self.is_tp and not self.colwise:\n            world_size, rank = get_tp_world_rank()\n            loaded_weight = loaded_weight.to(param_r.device)\n            loaded_weight = loaded_weight.chunk(world_size, dim=1)[rank]\n\n        param_r.copy_(loaded_weight)\n\n    def weight_loader_B(self, param: nn.Parameter, loaded_weight: torch.Tensor, adapter_id: int):\n        \"\"\"Weight loader.\"\"\"\n        rank = self.adapter_info.ranks[adapter_id].item()\n        r_start = self.adapter_info.rank_offsets[adapter_id].item()\n        r_end = r_start + rank\n        param_r = param.data[r_start:r_end]\n\n        if self.is_tp and self.colwise:\n            world_size, rank = get_tp_world_rank()\n            if self.lora_b_spliter is not None:\n                loaded_weights = self.lora_b_spliter(loaded_weight)\n                new_weights = []\n                for w in loaded_weights:\n                    w = w.chunk(world_size, dim=0)[rank]\n                    new_weights.append(w)\n                loaded_weight = torch.cat(new_weights, dim=0)\n            else:\n                loaded_weight = loaded_weight.chunk(world_size, dim=0)[rank]\n\n        param_r.copy_(loaded_weight.t())\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/linear/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.utils import get_logger\n\nfrom ..utils import get_distribute_size\n\nlogger = get_logger('lmdeploy')\n\nQKV_SPLIT_LAYOUTS = ['default', 'hgd']\n\n\ndef check_qkv_split_layout(layout: str):\n    if layout not in QKV_SPLIT_LAYOUTS:\n        raise RuntimeError(f'Expect qkv split layout in {QKV_SPLIT_LAYOUTS}, '\n                           f'but get: {layout}')\n\n\ndef update_tp_args(is_tp: bool, all_reduce: bool, colwise: bool, layer_type: str = 'attn'):\n    \"\"\"Update tp args according to the environment.\"\"\"\n    if is_tp:\n        world, _ = get_tp_world_rank(layer_type)\n        is_tp = world > 1\n\n    if not is_tp or colwise:\n        all_reduce = False\n\n    return is_tp, all_reduce\n\n\nclass QKVMixin:\n    \"\"\"Qkv mixin.\"\"\"\n\n    def __init__(self,\n                 num_q_heads: int,\n                 num_kv_heads: int,\n                 head_size: int,\n                 head_size_v: int,\n                 num_replicate_kv_heads: int = 1,\n                 is_tp: bool = False,\n                 tp: int = 1,\n                 tp_rank: int = 0):\n        qkv_split_section = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v,\n                                                       num_replicate_kv_heads)\n        num_q_heads, num_kv_heads = self._update_num_heads(is_tp, tp, tp_rank, num_q_heads, num_kv_heads)\n        self.num_q_heads = num_q_heads\n        self.num_kv_heads = num_kv_heads\n        self.head_size = head_size\n        self.head_size_v = head_size_v\n        self.num_replicate_kv_heads = num_replicate_kv_heads\n        self.qkv_split_section = qkv_split_section\n\n    def get_qkv_out_feautures(self):\n        \"\"\"Get qkv out features.\"\"\"\n        return self._get_qkv_out_features(self.num_q_heads, self.num_kv_heads, self.head_size, 
self.head_size_v)\n\n    def _get_qkv_out_features(self,\n                              num_q_heads: int,\n                              num_kv_heads: int,\n                              head_size: int,\n                              head_size_v: int,\n                              num_replicate_kv_heads: int = 1):\n        \"\"\"Get io features.\"\"\"\n        num_kv_heads_real = num_kv_heads // num_replicate_kv_heads\n        all_out_features = (num_q_heads * head_size, num_kv_heads_real * head_size, num_kv_heads_real * head_size_v)\n        return all_out_features\n\n    def _update_num_heads(self, is_tp: bool, tp: int, tp_rank: int, num_q_heads: int, num_kv_heads: int):\n        \"\"\"Update num heads.\"\"\"\n        if not is_tp:\n            return num_q_heads, num_kv_heads\n        world_size, rank = tp, tp_rank\n        num_q_heads = get_distribute_size(num_q_heads, world_size, rank)\n        num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank)\n\n        return num_q_heads, num_kv_heads\n\n    def split_qkv(self, x: torch.Tensor):\n        \"\"\"Split query, key and value.\"\"\"\n        num_q_heads = self.num_q_heads\n        num_kv_heads = self.num_kv_heads\n        head_size = self.head_size\n        head_size_v = self.head_size_v\n\n        sections = self.all_out_features\n        q, k, v = x.split(sections, dim=-1)\n        q = q.unflatten(-1, (num_q_heads, head_size))\n        k = k.unflatten(-1, (num_kv_heads, head_size))\n        v = v.unflatten(-1, (num_kv_heads, head_size_v))\n        return q, k, v\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/linear/w8a8.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, List, Optional\n\nimport torch\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader\n\nfrom ..utils import get_distribute_size\nfrom .base import LinearBase\nfrom .utils import QKVMixin, check_qkv_split_layout\n\n\nclass W8A8Linear(LinearBase):\n    \"\"\"W8a8 linear.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 bias: bool,\n                 dtype: Optional[torch.dtype] = None,\n                 device: Optional[torch.device] = None,\n                 colwise: bool = True,\n                 is_tp: bool = False,\n                 all_reduce: bool = True,\n                 quant_dtype: Optional[torch.dtype] = torch.int8,\n                 layer_type: str = 'attn'):\n        super().__init__(dtype=torch.float16,\n                         device=device,\n                         colwise=colwise,\n                         is_tp=is_tp,\n                         all_reduce=all_reduce,\n                         layer_type=layer_type)\n        if self.is_tp:\n            in_features, out_features = self._get_io_features(in_features, out_features, colwise)\n        impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8)\n        self.quant_dtype = quant_dtype\n        self.impl = impl_builder.build(in_features,\n                                       out_features,\n                                       bias is not None,\n                                       dtype=self.dtype,\n                                       quant_dtype=quant_dtype)\n        weight, scale, bias = self.create_weights(in_features, out_features, bias, self.dtype, self.device)\n        self.register_all_parameters(weight, scale, bias)\n\n        self.in_features = in_features\n        self.out_features = out_features\n\n    def 
setup_loaders(self):\n        \"\"\"Setup weight loaders.\"\"\"\n        self.weight.weight_loader = self.weight_loader\n        self.scale.weight_loader = self.weight_loader\n        if self.bias is not None:\n            self.bias.weight_loader = self.weight_loader\n\n    def register_all_parameters(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None):\n        \"\"\"Register all parameters.\"\"\"\n        weight = torch.nn.Parameter(weight, requires_grad=False)\n        scale = torch.nn.Parameter(scale, requires_grad=False)\n        if bias is not None:\n            bias = torch.nn.Parameter(bias, requires_grad=False)\n        self.register_parameter('weight', weight)\n        self.register_parameter('scale', scale)\n        self.register_parameter('bias', bias)\n        self.setup_loaders()\n\n    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):\n        \"\"\"Get io features.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        if colwise:\n            out_features = get_distribute_size(out_features, world_size, rank)\n        else:\n            in_features = get_distribute_size(in_features, world_size, rank)\n        return in_features, out_features\n\n    def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,\n                                  world_size: int):\n        \"\"\"Weight loader for colwise linear.\"\"\"\n        weight = loaded_weight.chunk(world_size, 0)[rank]\n        return default_weight_loader(param, weight)\n\n    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, rank: int,\n                                  world_size: int):\n        \"\"\"Weight loader for rowwise linear.\"\"\"\n        if loaded_weight.dim() == 2 and param.dtype in (torch.int8, torch.float8_e4m3fn, torch.float8_e5m2):\n            loaded_weight = loaded_weight.to(param.device)\n            weight = 
loaded_weight.chunk(world_size, 1)[rank]\n            return default_weight_loader(param, weight)\n        elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1:\n            # scaling\n            return default_weight_loader(param, loaded_weight)\n        else:\n            # bias\n            if rank != 0:\n                loaded_weight = torch.zeros_like(loaded_weight)\n            return default_weight_loader(param, loaded_weight)\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        if not self.is_tp:\n            return default_weight_loader(param, loaded_weight)\n\n        world_size, rank = self.get_tp_world_rank()\n        if self.colwise:\n            return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size)\n        else:\n            return self._weight_loader_tp_rowwise(param, loaded_weight, rank, world_size)\n\n    def create_weights(self, in_features: int, out_features: int, bias: bool, dtype: torch.dtype, device: torch.device):\n        \"\"\"Create weights.\"\"\"\n        weight = torch.empty((out_features, in_features), dtype=self.quant_dtype, device=device)\n        scale = torch.empty((out_features, 1), dtype=torch.float32, device=device)\n        if bias:\n            bias = torch.empty((out_features, ), dtype=dtype, device=device)\n        else:\n            bias = None\n        return weight, scale, bias\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        weight, scale, bias = self.impl.update_weights(self.weight, self.scale, self.bias)\n        self.register_all_parameters(weight, scale, bias)\n\n    def _forward_default(self, x, all_reduce, tp_sizes):\n        \"\"\"Default forward implement.\"\"\"\n        return self.impl.forward(x, self.weight, self.scale, self.bias, all_reduce, group=self.tp_group)\n\n\nclass MergedW8A8Linear(W8A8Linear):\n    \"\"\"Merged w8a8 linear.\"\"\"\n\n    def __init__(self,\n    
             in_features: int,\n                 all_out_features: List[int],\n                 bias: bool,\n                 dtype: Optional[torch.dtype] = None,\n                 device: Optional[torch.device] = None,\n                 is_tp: bool = True,\n                 out_names: Optional[List[int]] = None,\n                 quant_dtype: torch.dtype = torch.int8,\n                 layer_type: str = 'attn'):\n        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type)\n        self.split_section = all_out_features\n        all_out_features = self._update_all_out_features(all_out_features)\n        self.all_out_features = all_out_features\n        if out_names is None:\n            out_names = torch.arange(len(self.all_out_features)).tolist()\n        assert len(out_names) == len(self.all_out_features)\n        self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names))\n        out_features = sum(all_out_features)\n        super().__init__(in_features,\n                         out_features,\n                         bias,\n                         dtype,\n                         device,\n                         colwise=True,\n                         is_tp=is_tp,\n                         quant_dtype=quant_dtype,\n                         layer_type=layer_type)\n        self.setup_loaders()\n\n    def setup_loaders(self):\n        \"\"\"Setup weight loaders.\"\"\"\n        self.weight.weight_loader = self.weight_loader\n        self.scale.weight_loader = self.weight_loader\n        self.weight.weight_spliter = self.weight_spliter\n        self.scale.weight_spliter = self.weight_spliter\n        if self.bias is not None:\n            self.bias.weight_loader = self.weight_loader\n            self.bias.weight_spliter = self.weight_spliter\n\n    def _get_io_features(self, in_features: int, out_features: int, colwise: bool):\n        \"\"\"Get io features.\"\"\"\n        return in_features, out_features\n\n    def 
_update_all_out_features(self, all_out_features: List[int]):\n        \"\"\"Update all out features.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        new_all_out_features = []\n        for out_feat in all_out_features:\n            new_out_feat = get_distribute_size(out_feat, world_size, rank)\n            new_all_out_features.append(new_out_feat)\n        return new_all_out_features\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = self.get_tp_world_rank()\n        shard_idx = self.out_names_map[shard_id]\n        param_w = param.data.split(self.all_out_features, 0)[shard_idx]\n        loaded_weight = loaded_weight.chunk(world_size, 0)[rank]\n        param_w.copy_(loaded_weight)\n\n    def weight_spliter(self, loaded_weight: torch.Tensor):\n        \"\"\"Weight spliter.\"\"\"\n        return loaded_weight.split(self.split_section, dim=0)\n\n    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):\n        return loaded_weight.split(self.split_section, dim=0)\n\n\nclass QKVW8A8Linear(MergedW8A8Linear, QKVMixin):\n    \"\"\"Qkv w8a8 linear.\"\"\"\n\n    def __init__(self,\n                 in_features: int,\n                 num_q_heads: int,\n                 num_kv_heads: int,\n                 head_size: int,\n                 head_size_v: int,\n                 bias: bool = False,\n                 dtype: Optional[torch.dtype] = None,\n                 device: Optional[torch.device] = None,\n                 is_tp: bool = True,\n                 num_replicate_kv_heads: int = 1,\n                 quant_dtype: torch.dtype = torch.int8):\n        self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn')\n        QKVMixin.__init__(self,\n                          num_q_heads=num_q_heads,\n                          num_kv_heads=num_kv_heads,\n                          head_size=head_size,\n                  
        head_size_v=head_size_v,\n                          num_replicate_kv_heads=num_replicate_kv_heads,\n                          is_tp=is_tp,\n                          tp=self.tp,\n                          tp_rank=self.tp_rank)\n\n        all_out_features = self.get_qkv_out_feautures()\n        out_names = ('q', 'k', 'v')\n        super().__init__(in_features,\n                         all_out_features,\n                         bias=bias,\n                         dtype=dtype,\n                         device=device,\n                         is_tp=is_tp,\n                         out_names=out_names,\n                         quant_dtype=quant_dtype,\n                         layer_type='attn')\n\n    def _update_all_out_features(self, all_out_features: List[int]):\n        \"\"\"Update all out features.\"\"\"\n        return all_out_features\n\n    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):\n        \"\"\"Weight loader.\"\"\"\n        _, rank = self.get_tp_world_rank()\n        shard_idx = self.out_names_map[shard_id]\n        param_w = param.data.split(self.all_out_features, 0)[shard_idx]\n        num_head = self.num_q_heads if shard_id == 'q' \\\n            else self.num_kv_heads\n        head_dim = self.head_size if shard_id in ['q', 'k'] \\\n            else self.head_size_v\n        # update to duplicate k/v for tp_size > num_kv_heads\n        rank_idx = rank if shard_id == 'q' \\\n            else rank // self.num_replicate_kv_heads\n        sec_start = rank_idx * num_head * head_dim\n        sec_len = num_head * head_dim\n        loaded_weight = loaded_weight.narrow(dim=0, start=sec_start, length=sec_len)\n        param_w.copy_(loaded_weight)\n\n    def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'):\n        \"\"\"Weight spliter.\"\"\"\n        check_qkv_split_layout(layout)\n        if layout == 'default':\n            return 
loaded_weight.split(self.qkv_split_section, dim=0)\n        elif layout == 'hgd':\n            assert self.head_size == self.head_size_v\n            heads = [sec // self.head_size for sec in self.qkv_split_section]\n            kv_heads = heads[-1]\n            loaded_weight = loaded_weight.unflatten(0, (kv_heads, -1, self.head_size))\n            q = loaded_weight[:, :-2].flatten(0, 2)\n            k = loaded_weight[:, -2].flatten(0, 1)\n            v = loaded_weight[:, -1].flatten(0, 1)\n            return q, k, v\n        else:\n            raise RuntimeError(f'Unsupported layout: {layout}')\n\n    def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):\n        return loaded_weight.split(self.qkv_split_section, dim=0)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/moe/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Callable, Dict, Optional\n\nimport torch\n\nfrom lmdeploy.pytorch.models.patch import get_build_model_context\n\nfrom .base import MoeType, SoftmaxTopK  # noqa: F401\n\n\ndef build_fused_moe(\n    hidden_dim: int,\n    ffn_dim: int,\n    num_experts: int,\n    top_k: int,\n    bias: bool = False,\n    renormalize: bool = False,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n    all_reduce: bool = True,\n    enable_ep: bool = False,\n    quant_config: Dict = None,\n    layer_idx: int = 0,\n    act_func: Callable = None,\n    prefix: str = '',\n):\n    \"\"\"Fused moe builder.\"\"\"\n    quant_method = None\n    if quant_config is not None:\n        quant_config = get_build_model_context().quant_config\n        quant_method = quant_config.get_quant_method(prefix)\n\n    if quant_method is None:\n        from .default import FusedMoE\n        return FusedMoE(\n            hidden_dim=hidden_dim,\n            ffn_dim=ffn_dim,\n            num_experts=num_experts,\n            top_k=top_k,\n            bias=bias,\n            renormalize=renormalize,\n            dtype=dtype,\n            device=device,\n            all_reduce=all_reduce,\n            layer_idx=layer_idx,\n            act_func=act_func,\n        )\n\n    if quant_method == 'smooth_quant':\n        assert not bias, 'Quant model does not support bias for now.'\n        assert act_func is None, ('Quant model does not support activation function for now.')\n        from .w8a8 import FusedMoEW8A8\n        return FusedMoEW8A8(\n            hidden_dim=hidden_dim,\n            ffn_dim=ffn_dim,\n            num_experts=num_experts,\n            top_k=top_k,\n            renormalize=renormalize,\n            dtype=dtype,\n            quant_dtype=quant_config.quant_dtype,\n            device=device,\n            all_reduce=all_reduce,\n        )\n    elif quant_method == 'fp8':\n        from 
.blocked_fp8 import FusedMoEBlockedF8\n        return FusedMoEBlockedF8(\n            hidden_dim=hidden_dim,\n            ffn_dim=ffn_dim,\n            num_experts=num_experts,\n            top_k=top_k,\n            bias=bias,\n            renormalize=renormalize,\n            fp8_dtype=quant_config.quant_dtype,\n            scale_fmt=quant_config.scale_fmt,\n            dtype=dtype,\n            device=device,\n            all_reduce=all_reduce,\n            layer_idx=layer_idx,\n            act_func=act_func,\n        )\n    else:\n        raise RuntimeError(f'Unsupported quant method: {quant_method}')\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/moe/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\nimport torch.nn as nn\n\nimport lmdeploy.pytorch.distributed as dist\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.config import TPMode\nfrom lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank\nfrom lmdeploy.pytorch.model_inputs import get_step_ctx_manager\n\n\nclass MoeType(Enum):\n    \"\"\"MoE batch execution type.\"\"\"\n    Default = auto()\n    DSAsyncDecode = auto()\n    DSAsyncPrefill = auto()\n\n\nclass SoftmaxTopK(nn.Module):\n    \"\"\"Softmax topk.\"\"\"\n\n    def __init__(self, top_k: int, dim: int = -1, n_groups: int = -1):\n        super().__init__()\n        self.top_k = top_k\n        impl_builder = get_backend().get_layer_impl_builder(OpType.SoftmaxTopK)\n        self.impl = impl_builder.build(top_k, dim, n_groups=n_groups)\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"forward.\"\"\"\n        return self.impl.forward(x)\n\n\ndef update_dims(hidden_dim: int, ffn_dim: int):\n    \"\"\"Update dims.\"\"\"\n    world_size, _ = get_tp_world_rank('moe')\n    assert ffn_dim % world_size == 0\n    ffn_dim = ffn_dim // world_size\n    return hidden_dim, ffn_dim\n\n\ndef split_size(size: int, world_size: int, align: int):\n    size = size // align\n    base = size // world_size\n    remain = size % world_size\n    split_size = [base + 1] * remain + [base] * (world_size - remain)\n    split_size = [s * align for s in split_size]\n    return split_size\n\n\ndef moe_gather_inputs(hidden_states, topk_weights, topk_ids, group: Optional[dist.ProcessGroup] = None):\n    dist_config = get_dist_manager().current_config()\n    tp = dist_config.moe_tp\n    if tp == 1:\n        return hidden_states, topk_weights, topk_ids\n\n    tp_mode = dist_config.moe_tp_mode\n    if tp_mode == TPMode.DEFAULT:\n        return 
hidden_states, topk_weights, topk_ids\n    elif tp_mode == TPMode.DP_TP:\n        step_ctx = get_step_ctx_manager().current_context()\n        dp_meta = step_ctx.dp_meta\n        tp_sizes = dp_meta.moe_tp_sizes\n        hidden_states = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=group)\n        topk_weights = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=group)\n        topk_ids = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=group)\n    else:\n        raise RuntimeError('Not supported.')\n\n    return hidden_states, topk_weights, topk_ids\n\n\ndef moe_reduce(ret, rank: int, tp_mode: TPMode, group: Optional[dist.ProcessGroup] = None):\n    dist_config = get_dist_manager().current_config()\n    if dist_config.moe_tp == 1:\n        return ret\n\n    if tp_mode == TPMode.DEFAULT:\n        dist.all_reduce(ret, group=group)\n        return ret\n    elif tp_mode == TPMode.DP_TP:\n        step_ctx = get_step_ctx_manager().current_context()\n        dp_meta = step_ctx.dp_meta\n        tp_size = dp_meta.moe_tp_sizes\n        ret = dist.reduce_scatter_by_tp_sizes(ret, rank, tp_size, group=group)\n        return ret\n    else:\n        raise RuntimeError('Not supported.')\n\n\nclass MoEForwardDPTP:\n\n    def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192):\n        \"\"\"MoE forward dp tp.\"\"\"\n        self.gemm_func = gemm_func\n        self.dist_ctx = get_dist_manager().current_context()\n        self.dist_config = self.dist_ctx.dist_config\n        self.tp = self.dist_config.moe_tp\n        self.attn_tp = self.dist_config.attn_tp\n\n        tp_group = self.dist_ctx.moe_tp_group\n        self.rank = tp_group.rank\n        self.gather_rank = self.rank // self.attn_tp\n        self.gather_group = tp_group.gpu_gather_group\n        self.tp_group = tp_group.gpu_group\n        self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp\n\n    def all_gather(self, hidden_states: torch.Tensor, topk_weights: 
torch.Tensor, topk_ids: torch.Tensor,\n                   tp_sizes: List[int]):\n        \"\"\"All gather.\"\"\"\n        hidden_states, h0 = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True)\n        topk_weights, h1 = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True)\n        topk_ids, h2 = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True)\n        return hidden_states, topk_weights, topk_ids, (h0, h1, h2)\n\n    def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]):\n        \"\"\"Reduce scatter.\"\"\"\n        hidden_states_list = list(hidden_states.split(tp_sizes, -2))\n        cur_out_states = hidden_states_list[self.gather_rank]\n        out_states.copy_(cur_out_states)\n        hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)]\n        hidden_states_list[self.rank] = out_states\n        handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True)\n        return out_states, handle\n\n    def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,\n                                 output_states: torch.Tensor, tp_sizes: List[int], handles: List[dist.Work]):\n        \"\"\"Gemm and reduce scatter.\"\"\"\n        for handle in handles:\n            handle.wait()\n        cur_out = self.gemm_func(hidden_states, topk_weights, topk_ids)\n        return self.reduce_scatter(cur_out, output_states, tp_sizes)\n\n    def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor):\n        \"\"\"forward.\"\"\"\n\n        def __slice_tensor(tensor: torch.Tensor, slice_size: int):\n            \"\"\"Slice tensor.\"\"\"\n            cur_tensor = tensor[:slice_size]\n            tensor = tensor[slice_size:]\n            return cur_tensor, tensor\n\n        
def __slice_and_gather():\n            \"\"\"Slice and gather.\"\"\"\n            nonlocal hidden_states, topk_weights, topk_ids, tp_sizes, output_states\n            cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round)\n            tp_sizes -= cur_tp_sizes\n            cur_tp_sizes = cur_tp_sizes.tolist()\n\n            slice_size = cur_tp_sizes[self.gather_rank]\n            cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size)\n            cur_topk_weights, topk_weights = __slice_tensor(topk_weights, slice_size)\n            cur_topk_ids, topk_ids = __slice_tensor(topk_ids, slice_size)\n            cur_output, output_states = __slice_tensor(output_states, slice_size)\n\n            # all gather\n            cur_hidden_states, cur_topk_weights, cur_topk_ids, handles = self.all_gather(\n                cur_hidden_states, cur_topk_weights, cur_topk_ids, cur_tp_sizes)\n            return dict(hidden_states=cur_hidden_states,\n                        topk_weights=cur_topk_weights,\n                        topk_ids=cur_topk_ids,\n                        output_states=cur_output,\n                        handles=handles,\n                        tp_sizes=cur_tp_sizes)\n\n        step_ctx = get_step_ctx_manager().current_context()\n        tp_sizes = step_ctx.dp_meta.moe_tp_sizes\n        tp_sizes = torch.tensor(tp_sizes)\n        max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round)\n\n        output_states = torch.empty_like(hidden_states)\n        return_states = output_states\n\n        # pre\n        cur_inputs = __slice_and_gather()\n\n        out_handles = []\n        # main loop\n        while tp_sizes.sum() > 0:\n            next_inputs = __slice_and_gather()\n            _, handle = self._gemm_and_reduce_scatter(**cur_inputs)\n            out_handles.append(handle)\n            cur_inputs = next_inputs\n\n        # post\n        _, handle = self._gemm_and_reduce_scatter(**cur_inputs)\n        out_handles.append(handle)\n 
       for handle in out_handles:\n            handle.wait()\n        return return_states\n\n\ndef _renormalize(topk_weights: torch.Tensor, renormalize: bool):\n    if renormalize:\n        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)\n    if not topk_weights.is_contiguous():\n        topk_weights = topk_weights.contiguous()\n    return topk_weights\n\n\n@dataclass\nclass DispatchInputs:\n    \"\"\"Dispatch inputs.\"\"\"\n    hidden_states: torch.Tensor\n    topk_weights: torch.Tensor\n    topk_idx: torch.LongTensor\n    moe_type: MoeType = MoeType.Default\n\n    @classmethod\n    def from_dict(cls, input: Dict):\n        \"\"\"From dict.\"\"\"\n        # check each required key; testing a list with `in` against a dict raises TypeError (unhashable)\n        assert all(key in input for key in ('hidden_states', 'topk_weights', 'topk_idx'))\n        moe_type = input.get('moe_type', MoeType.Default)\n        return cls(\n            hidden_states=input['hidden_states'],\n            topk_weights=input['topk_weights'],\n            topk_idx=input['topk_idx'],\n            moe_type=moe_type,\n        )\n\n    def to_dict(self) -> Dict:\n        \"\"\"To dict.\"\"\"\n        return {\n            'hidden_states': self.hidden_states,\n            'topk_weights': self.topk_weights,\n            'topk_idx': self.topk_idx,\n            'moe_type': self.moe_type,\n        }\n\n\nclass FusedMoEBase(nn.Module):\n    \"\"\"Fused MoE base.\"\"\"\n\n    def __init__(self, tp: int, tp_mode: TPMode, do_renormalize: bool):\n        super().__init__()\n        self.tp = tp\n        self.tp_mode = tp_mode\n        self.do_renormalize = do_renormalize\n\n    def init_dist_args(self, all_reduce: bool):\n        \"\"\"Init dist args.\"\"\"\n        dist_ctx = get_dist_manager().current_context()\n        dist_cfg = dist_ctx.dist_config\n        _, tp_mode = dist_cfg.get_tp_by_layer('moe')\n        tp, tp_rank = get_tp_world_rank('moe')\n        all_reduce = all_reduce if tp > 1 else False\n\n        self.ep = dist_cfg.ep\n        self.tp = tp\n        self.tp_rank = tp_rank\n        
self.tp_mode = tp_mode\n        self.all_reduce = all_reduce\n        self.tp_group = dist_ctx.moe_tp_group.gpu_group\n        self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group\n\n        if self.tp > 1 and self.tp_mode == TPMode.DP_TP:\n\n            def __gemm_func(hidden_states, topk_weights, topk_ids):\n                return self.gemm(\n                    dict(\n                        hidden_states=hidden_states,\n                        topk_weights=topk_weights,\n                        topk_idx=topk_ids,\n                        moe_type=MoeType.Default,\n                    ))['hidden_states']\n\n            self._forward_dptp = MoEForwardDPTP(__gemm_func)\n        else:\n            self._forward_dptp = None\n\n    def before_dispatch(self, state: DispatchInputs):\n        \"\"\"Before dispatch.\"\"\"\n        raise NotImplementedError\n\n    def dispatch(self, state: Dict):\n        \"\"\"dispatch.\"\"\"\n        raise NotImplementedError\n\n    def gemm(self, state: Dict):\n        \"\"\"gemm.\"\"\"\n        raise NotImplementedError\n\n    def combine(self, state: Dict):\n        \"\"\"combine.\"\"\"\n        raise NotImplementedError\n\n    def wait(self, state: Dict):\n        \"\"\"wait.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def forward_dptp(self) -> MoEForwardDPTP:\n        \"\"\"Forward dptp.\"\"\"\n        return self._forward_dptp\n\n    def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor):\n        \"\"\"Default forward.\"\"\"\n        state = {\n            'hidden_states': hidden_states,\n            'topk_idx': topk_idx,\n            'topk_weights': topk_weights,\n            'moe_type': MoeType.Default,\n        }\n        recv_state = self.dispatch(state)\n        gemm_state = self.gemm(recv_state)\n        out_state = self.combine(gemm_state)\n        return out_state['hidden_states']\n\n    def forward(self, hidden_states: torch.Tensor, 
topk_weights: torch.Tensor, topk_idx: torch.LongTensor):\n        \"\"\"forward.\"\"\"\n        if self.tp > 1 and self.tp_mode == TPMode.DP_TP:\n            return self.forward_dptp.forward(hidden_states, topk_weights, topk_idx)\n        else:\n            return self.forward_default(hidden_states, topk_weights, topk_idx)\n\n    def renormalize(self, topk_weights):\n        \"\"\"renormalize.\"\"\"\n        return _renormalize(topk_weights, self.do_renormalize)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/moe/blocked_fp8.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank\n\nfrom ..quant_utils import quant_blocked_fp8\nfrom ..utils import div_up\nfrom .base import DispatchInputs, FusedMoEBase, MoeType, moe_gather_inputs, moe_reduce\nfrom .base import split_size as _split_size\nfrom .default import LinearWeights\n\n\nclass LinearWeightsBlockedF8(LinearWeights):\n    \"\"\"Fused moe linear blocked fp8 weights.\"\"\"\n\n    def __init__(self,\n                 num_experts: int,\n                 in_features: int,\n                 out_features: int,\n                 weight_type: str,\n                 block_size: int,\n                 dtype: torch.dtype,\n                 device: torch.device,\n                 bias: bool = False,\n                 expert_list: List[int] = None,\n                 scale_fmt: Optional[str] = None):\n        super().__init__(num_experts=num_experts,\n                         in_features=in_features,\n                         out_features=out_features,\n                         weight_type=weight_type,\n                         dtype=dtype,\n                         device=device,\n                         bias=bias,\n                         expert_list=expert_list)\n        self.scale_fmt = scale_fmt\n        self.block_size = block_size\n        weight_scale_inv = torch.empty((num_experts, div_up(out_features, block_size), div_up(in_features, block_size)),\n                                       dtype=torch.float32,\n                                       device=device)\n        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)\n        self.register_parameter('weight_scale_inv', weight_scale_inv)\n\n        if self.ep:\n            self.weight._base_weight_loader = 
self.weight.weight_loader\n            self.weight_scale_inv.weight_loader = self.weight_loader_scale_ep\n        else:\n            self.weight._base_weight_loader = self.weight_loader_tp_blocked_fp8\n            self.weight_scale_inv.weight_loader = self.weight_loader_scale_tp\n        self.weight.weight_loader = self.weight_loader_with_quant\n\n    def update_weight(self, weight: torch.Tensor, weight_scale_inv: torch.Tensor):\n        \"\"\"Update weight.\"\"\"\n        super().update_weight(weight=weight)\n        weight_loader = self.weight_scale_inv.weight_loader\n        weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)\n        weight_scale_inv.weight_loader = weight_loader\n        self.register_parameter('weight_scale_inv', weight_scale_inv)\n\n    def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,\n                               shard_id: str):\n        expert_list = self.expert_list\n        if expert_id not in expert_list:\n            return\n        expert_ids = self.expert_map[expert_id]\n        for expert_id in expert_ids:\n            self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id)\n\n    def _chunk_weight_tp(self, weight: torch.Tensor, dim: int, world_size: int, rank: int, align: int):\n        \"\"\"Chunk with align.\"\"\"\n        split_size = _split_size(weight.size(dim), world_size, align)\n        return weight.split(split_size, dim=dim)[rank]\n\n    def weight_loader_tp_blocked_fp8(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,\n                                     shard_id: str):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = get_tp_world_rank('moe')\n        if shard_id == 'gate':\n            param_data = param.data[expert_id, :self.half_out]\n            weight = self._chunk_weight_tp(loaded_weight,\n                                           dim=0,\n                                     
      world_size=world_size,\n                                           rank=rank,\n                                           align=self.block_size)\n        elif shard_id == 'up':\n            param_data = param.data[expert_id, self.half_out:]\n            weight = self._chunk_weight_tp(loaded_weight,\n                                           dim=0,\n                                           world_size=world_size,\n                                           rank=rank,\n                                           align=self.block_size)\n        elif shard_id == 'down':\n            param_data = param.data[expert_id]\n            # weight is not contiguous, chunk and copy in cpu is slow\n            weight = loaded_weight.to(param_data.device)\n            if weight.dim() > 1:\n                weight = self._chunk_weight_tp(weight, dim=1, world_size=world_size, rank=rank, align=self.block_size)\n            elif weight.dim() == 1 and rank != 0:\n                # bias with rank>0 should be 0\n                weight = torch.zeros_like(weight)\n        else:\n            raise RuntimeError(f'Unknown shard_id: {shard_id}')\n        param_data.copy_(weight)\n\n    def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,\n                               shard_id: str):\n        \"\"\"Weight loader scale tp.\"\"\"\n        world_size, rank = get_tp_world_rank('moe')\n        block_size = self.block_size\n        half_out = self.half_out // block_size\n        if shard_id == 'gate':\n            param_data = param.data[expert_id, :half_out]\n            weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1)\n        elif shard_id == 'up':\n            param_data = param.data[expert_id, half_out:]\n            weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1)\n        elif shard_id == 'down':\n            param_data = 
param.data[expert_id]\n            loaded_weight = loaded_weight.to(param_data.device)\n            weight = self._chunk_weight_tp(loaded_weight, dim=1, world_size=world_size, rank=rank, align=1)\n        else:\n            raise RuntimeError(f'Unknown shard_id: {shard_id}')\n        param_data.copy_(weight)\n\n    def weight_loader_with_quant(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,\n                                 shard_id: str):\n        \"\"\"Weight load with quant.\"\"\"\n        if loaded_weight.dtype != param.dtype:\n            # quant loaded weight\n            quanted_weight, scaling = quant_blocked_fp8(loaded_weight.to(param.device),\n                                                        param.dtype,\n                                                        self.block_size,\n                                                        scale_fmt=self.scale_fmt)\n            self.weight._base_weight_loader(self.weight, quanted_weight, expert_id, shard_id)\n            self.weight_scale_inv.weight_loader(self.weight_scale_inv, scaling, expert_id, shard_id)\n        else:\n            return self.weight._base_weight_loader(param, loaded_weight, expert_id, shard_id)\n\n\nclass FusedMoEBlockedF8(FusedMoEBase):\n    \"\"\"Fused moe blocked f8.\"\"\"\n\n    def __init__(self,\n                 hidden_dim: int,\n                 ffn_dim: int,\n                 num_experts: int,\n                 top_k: int,\n                 bias: bool = False,\n                 renormalize: bool = False,\n                 fp8_dtype: torch.dtype = torch.float8_e4m3fn,\n                 scale_fmt: Optional[str] = None,\n                 dtype: Optional[torch.dtype] = None,\n                 device: Optional[torch.device] = None,\n                 all_reduce: bool = True,\n                 layer_idx: int = 0,\n                 act_func: Callable = None):\n\n        device = device or torch.device('cpu')\n        dtype = dtype or torch.float16\n     
   # init distributed tp arguments\n        self.block_size = 128\n        self.init_dist_args(all_reduce)\n        self.scale_fmt = scale_fmt\n\n        super().__init__(\n            tp=self.tp,\n            tp_mode=self.tp_mode,\n            do_renormalize=renormalize,\n        )\n\n        dist_ctx = get_dist_manager().current_context()\n        self.ep_size, rank = get_ep_world_rank()\n        impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEBlockedF8)\n        self.impl = impl_builder.build(top_k,\n                                       num_experts,\n                                       hidden_dim,\n                                       renormalize,\n                                       block_size=self.block_size,\n                                       ep_size=self.ep_size,\n                                       ep_group=dist_ctx.ep_gpu_group,\n                                       out_dtype=dtype,\n                                       layer_idx=layer_idx,\n                                       custom_gateup_act=act_func is not None)\n        self.impl.set_scale_fmt(scale_fmt)\n\n        if self.ep_size > 1:\n            expert_list = self.impl.ep_expert_list(self.ep_size, rank)\n            num_experts = len(expert_list)\n        else:\n            hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim, align=self.block_size)\n            expert_list = None\n        self.expert_list = expert_list\n\n        # create weights\n        self.gate_up = LinearWeightsBlockedF8(num_experts,\n                                              hidden_dim,\n                                              ffn_dim * 2,\n                                              weight_type='gate_up',\n                                              block_size=self.block_size,\n                                              dtype=fp8_dtype,\n                                              device=device,\n                                              bias=bias,\n  
                                            expert_list=expert_list,\n                                              scale_fmt=scale_fmt)\n        self.down = LinearWeightsBlockedF8(num_experts,\n                                           ffn_dim,\n                                           hidden_dim,\n                                           weight_type='down',\n                                           block_size=self.block_size,\n                                           dtype=fp8_dtype,\n                                           device=device,\n                                           bias=bias,\n                                           expert_list=expert_list,\n                                           scale_fmt=scale_fmt)\n\n        self.hidden_dim = hidden_dim\n        self.ffn_dim = ffn_dim\n        self.num_experts = num_experts\n        self.dtype = dtype\n        self.device = device\n        self.act_func = act_func\n\n    @staticmethod\n    def _update_args(hidden_dim: int, ffn_dim: int, align: int):\n        world_size, rank = get_tp_world_rank('moe')\n        split_size = _split_size(ffn_dim, world_size, align)\n        ffn_dim = split_size[rank]\n        return hidden_dim, ffn_dim\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        (gate_up_weights, down_weights, gate_up_scale,\n         down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.weight_scale_inv,\n                                                self.down.weight_scale_inv)\n        self.gate_up.update_weight(gate_up_weights, gate_up_scale)\n        self.down.update_weight(down_weights, down_scale)\n\n    def before_dispatch(self, state: DispatchInputs):\n        \"\"\"Before dispatch.\"\"\"\n        if not isinstance(state, Dict):\n            state = state.to_dict()\n\n        moe_type = state['moe_type']\n        if moe_type == MoeType.DSAsyncPrefill:\n            fusedmoe = 
self.fusedmoe_build(low_latency_mode=False)\n            state['fusedmoe'] = fusedmoe\n            if hasattr(fusedmoe, 'per_token_group_quant_fp8'):\n                state['hidden_states'] = fusedmoe.per_token_group_quant_fp8(state['hidden_states'])\n            previous_event = fusedmoe.capture()\n            state['previous_event'] = previous_event\n        return state\n\n    def dispatch(self, state: Dict):\n        moe_type = state['moe_type']\n        if moe_type == MoeType.DSAsyncPrefill:\n            fusedmoe = state['fusedmoe']\n            previous_event = state['previous_event']\n            (\n                recv_hidden_states,\n                recv_topk_idx,\n                recv_topk_weights,\n                recv_tokens_per_expert,\n                handle,\n                event,\n            ) = fusedmoe.dispatch_async(state['hidden_states'],\n                                        state['topk_idx'],\n                                        state['topk_weights'],\n                                        previous_event=previous_event,\n                                        async_finish=True)\n            recv_state = {\n                'fusedmoe': fusedmoe,\n                'recv_hidden_states': recv_hidden_states,\n                'recv_topk_idx': recv_topk_idx,\n                'recv_topk_weights': recv_topk_weights,\n                'recv_tokens_per_expert': recv_tokens_per_expert,\n                'handle': handle,\n                'event': event,\n                'num_experts': self.num_experts,\n                'moe_type': state['moe_type']\n            }\n        elif moe_type == MoeType.DSAsyncDecode:\n            fusedmoe = self.fusedmoe_build(low_latency_mode=True)\n            use_event = False\n            (recv_hidden_states, recv_expert_count, handle, event,\n             hook) = fusedmoe.dispatch_async(state['hidden_states'],\n                                             state['topk_idx'],\n                                         
    use_fp8=True,\n                                             async_finish=use_event)\n            recv_state = {\n                'fusedmoe': fusedmoe,\n                'recv_hidden_states': recv_hidden_states,\n                'recv_expert_count': recv_expert_count,\n                'topk_idx': state['topk_idx'],\n                'topk_weights': state['topk_weights'],\n                'raw_hidden_shape': state['raw_hidden_shape'],\n                'handle': handle,\n                'moe_type': state['moe_type']\n            }\n            if use_event:\n                recv_state['event'] = event\n            else:\n                recv_state['hook'] = hook\n        else:  # MoeType.Default\n            hidden_states, topk_weights, topk_idx = moe_gather_inputs(state['hidden_states'],\n                                                                      state['topk_weights'],\n                                                                      state['topk_idx'],\n                                                                      group=self.gather_group)\n            recv_state = {\n                'hidden_states': hidden_states,\n                'topk_idx': topk_idx,\n                'topk_weights': topk_weights,\n                'moe_type': state['moe_type']\n            }\n        return recv_state\n\n    def gemm(self, state: Dict):\n        moe_type = state['moe_type']\n        if moe_type == MoeType.DSAsyncPrefill:\n            if (state['recv_hidden_states'][0]\n                    if isinstance(state['recv_hidden_states'], tuple) else state['recv_hidden_states']).shape[0] > 0:\n                state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight,\n                                                                                 self.gate_up.weight_scale_inv,\n                                                                                 self.down.weight,\n                                                        
                         self.down.weight_scale_inv)\n            gemm_state = {\n                'fusedmoe': state['fusedmoe'],\n                'hidden_states': state['recv_hidden_states'],\n                'handle': state['handle'],\n                'moe_type': state['moe_type']\n            }\n        elif moe_type == MoeType.DSAsyncDecode:\n            state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight,\n                                                                             self.gate_up.weight_scale_inv,\n                                                                             self.down.weight,\n                                                                             self.down.weight_scale_inv)\n            gemm_state = {\n                'fusedmoe': state['fusedmoe'],\n                'hidden_states': state['recv_hidden_states'],\n                'topk_idx': state['topk_idx'],\n                'topk_weights': state['topk_weights'],\n                'handle': state['handle'],\n                'moe_type': state['moe_type']\n            }\n        else:  # MoeType.Default\n            if self.gate_up.weight.numel() == 0:\n                # current rank get no expert chunk\n                # create a zero tensor with the same shape as hidden_states\n                gemm_state = {'hidden_states': torch.zeros_like(state['hidden_states']), 'moe_type': state['moe_type']}\n            else:\n                # default fused moe\n                hidden_states = self.impl.forward(state['hidden_states'],\n                                                  state['topk_weights'],\n                                                  state['topk_idx'],\n                                                  self.gate_up.weight,\n                                                  self.gate_up.weight_scale_inv,\n                                                  self.down.weight,\n                                                  
self.down.weight_scale_inv,\n                                                  gate_up_bias=self.gate_up.bias,\n                                                  down_bias=self.down.bias,\n                                                  expert_list=self.expert_list,\n                                                  act_func=self.act_func)\n                gemm_state = {'hidden_states': hidden_states, 'moe_type': state['moe_type']}\n        return gemm_state\n\n    def combine(self, state: Dict):\n        moe_type = state['moe_type']\n        if moe_type == MoeType.DSAsyncPrefill:\n            fusedmoe = state['fusedmoe']\n            previous_event = fusedmoe.capture()\n            out_hidden_states, event = fusedmoe.combine_async(state['hidden_states'],\n                                                              state['handle'],\n                                                              previous_event=previous_event,\n                                                              async_finish=True)\n            out_state = {\n                'fusedmoe': state['fusedmoe'],\n                'hidden_states': out_hidden_states,\n                'event': event,\n                'moe_type': state['moe_type']\n            }\n        elif moe_type == MoeType.DSAsyncDecode:\n            fusedmoe = state['fusedmoe']\n            use_event = False\n            out_hidden_states, event, hook = fusedmoe.combine_async(state['hidden_states'],\n                                                                    state['topk_idx'],\n                                                                    state['topk_weights'],\n                                                                    state['handle'],\n                                                                    async_finish=use_event)\n            out_state = {\n                'fusedmoe': state['fusedmoe'],\n                'hidden_states': out_hidden_states,\n                'moe_type': state['moe_type']\n   
         }\n            if use_event:\n                out_state['event'] = event\n            else:\n                out_state['hook'] = hook\n        else:  # MoeType.Default\n            if self.all_reduce:\n                state['hidden_states'] = moe_reduce(state['hidden_states'],\n                                                    rank=self.tp_rank,\n                                                    tp_mode=self.tp_mode,\n                                                    group=self.tp_group)\n            out_state = {'hidden_states': state['hidden_states'], 'moe_type': state['moe_type']}\n        return out_state\n\n    def wait(self, state):\n        if state.get('event', None) is not None:\n            state['fusedmoe'].wait(state['event'])\n            return True\n        elif state.get('hook', None) is not None:\n            state['hook']()\n            return True\n        else:\n            return False\n\n    def fusedmoe_build(self, low_latency_mode: bool = False):\n        return self.impl.fusedmoe_build(low_latency_mode)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/moe/default.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom collections import defaultdict\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank\n\nfrom .base import DispatchInputs, FusedMoEBase, MoeType, moe_gather_inputs, moe_reduce, update_dims\n\n\nclass LinearWeights(nn.Module):\n    \"\"\"Fused moe linear weights.\"\"\"\n\n    def __init__(self,\n                 num_experts: int,\n                 in_features: int,\n                 out_features: int,\n                 weight_type: str,\n                 dtype: torch.dtype,\n                 device: torch.device,\n                 bias: bool = False,\n                 expert_list: Optional[List[int]] = None):\n        super().__init__()\n        weight = torch.empty((num_experts, out_features, in_features), dtype=dtype, device=device)\n        weight = torch.nn.Parameter(weight, requires_grad=False)\n        self.register_parameter('weight', weight)\n\n        if bias:\n            bias = torch.empty((num_experts, out_features), dtype=dtype, device=device)\n            bias = torch.nn.Parameter(bias, requires_grad=False)\n            self.register_parameter('bias', bias)\n        else:\n            self.bias = None\n\n        self.ep = expert_list is not None\n        self.expert_list = expert_list\n        self.weight_type = weight_type\n        self.half_out = out_features // 2\n\n        self.setup_weight_loader()\n\n    def setup_weight_loader(self):\n        \"\"\"Setup weight loader.\"\"\"\n        if self.expert_list is not None:\n            self.expert_map = defaultdict(list)\n            for idx, eid in enumerate(self.expert_list):\n                self.expert_map[eid].append(idx)\n            self.weight.weight_loader = self.weight_loader_ep\n            if self.bias is not None:\n                
self.bias.weight_loader = self.weight_loader_ep\n        else:\n            self.weight.weight_loader = self.weight_loader_tp\n            if self.bias is not None:\n                self.bias.weight_loader = self.weight_loader_tp\n\n    def update_weight(self, weight: torch.Tensor):\n        \"\"\"Update weight.\"\"\"\n        weight_loader = self.weight.weight_loader\n        weight = torch.nn.Parameter(weight, requires_grad=False)\n        weight.weight_loader = weight_loader\n        self.register_parameter('weight', weight)\n\n    def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = get_tp_world_rank('moe')\n        if shard_id == 'gate':\n            param_data = param.data[expert_id, :self.half_out]\n            weight = loaded_weight.chunk(world_size, dim=0)[rank]\n        elif shard_id == 'up':\n            param_data = param.data[expert_id, self.half_out:]\n            weight = loaded_weight.chunk(world_size, dim=0)[rank]\n        elif shard_id == 'down':\n            param_data = param.data[expert_id]\n            # weight is not contiguous, chunk and copy in cpu is slow\n            weight = loaded_weight.to(param_data.device)\n            if weight.dim() > 1:\n                weight = weight.chunk(world_size, dim=1)[rank]\n            elif weight.dim() == 1 and rank != 0:\n                # bias with rank>0 should be 0\n                weight = torch.zeros_like(weight)\n        else:\n            raise RuntimeError(f'Unknown shard_id: {shard_id}')\n        param_data.copy_(weight)\n\n    def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str):\n        \"\"\"Weight loader.\"\"\"\n        expert_list = self.expert_list\n        if expert_id not in expert_list:\n            return\n\n        expert_map = self.expert_map\n        param_ids = expert_map[expert_id]\n        for 
param_id in param_ids:\n            if shard_id == 'gate':\n                param_data = param.data[param_id, :self.half_out]\n            elif shard_id == 'up':\n                param_data = param.data[param_id, self.half_out:]\n            elif shard_id == 'down':\n                param_data = param.data[param_id]\n            else:\n                raise RuntimeError(f'Unknown shard_id: {shard_id}')\n            param_data.copy_(loaded_weight)\n\n\nclass FusedMoE(FusedMoEBase):\n    \"\"\"Fused MoE.\"\"\"\n\n    def __init__(self,\n                 hidden_dim: int,\n                 ffn_dim: int,\n                 num_experts: int,\n                 top_k: int,\n                 bias: bool = False,\n                 renormalize: bool = False,\n                 dtype: Optional[torch.dtype] = None,\n                 device: Optional[torch.device] = None,\n                 all_reduce: bool = True,\n                 layer_idx: int = 0,\n                 act_func: Callable = None):\n\n        device = device or torch.device('cpu')\n        dtype = dtype or torch.float16\n        # init distributed tp arguments\n        self.init_dist_args(all_reduce)\n\n        super().__init__(\n            tp=self.tp,\n            tp_mode=self.tp_mode,\n            do_renormalize=renormalize,\n        )\n\n        # create implementation\n        dist_ctx = get_dist_manager().current_context()\n        self.ep_size, rank = get_ep_world_rank()\n        impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE)\n        self.impl = impl_builder.build(\n            top_k,\n            num_experts,\n            renormalize,\n            hidden_dim=hidden_dim,\n            ep_size=self.ep_size,\n            ep_group=dist_ctx.ep_gpu_group,\n            layer_idx=layer_idx,\n        )\n\n        # create weights\n        if self.ep_size > 1:\n            expert_list = self.impl.ep_expert_list(self.ep_size, rank)\n            num_experts = len(expert_list)\n        else:\n        
    hidden_dim, ffn_dim = update_dims(hidden_dim, ffn_dim)\n            expert_list = None\n        self.expert_list = expert_list\n        self.gate_up = LinearWeights(num_experts,\n                                     hidden_dim,\n                                     ffn_dim * 2,\n                                     weight_type='gate_up',\n                                     dtype=dtype,\n                                     device=device,\n                                     bias=bias,\n                                     expert_list=expert_list)\n        self.down = LinearWeights(\n            num_experts,\n            ffn_dim,\n            hidden_dim,\n            weight_type='down',\n            dtype=dtype,\n            device=device,\n            bias=bias,\n            expert_list=expert_list,\n        )\n\n        self.hidden_dim = hidden_dim\n        self.ffn_dim = ffn_dim\n        self.num_experts = num_experts\n        self.dtype = dtype\n        self.device = device\n        self.act_func = act_func\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        gate_up_weights, down_weights = self.impl.update_weights(self.gate_up.weight, self.down.weight)\n        self.gate_up.update_weight(gate_up_weights)\n        self.down.update_weight(down_weights)\n\n    def before_dispatch(self, state: DispatchInputs):\n        \"\"\"Before dispatch.\"\"\"\n        if not isinstance(state, Dict):\n            state = state.to_dict()\n\n        moe_type = state['moe_type']\n        if moe_type == MoeType.DSAsyncPrefill:\n            fusedmoe = self.fusedmoe_build(low_latency_mode=False)\n            state['fusedmoe'] = fusedmoe\n            previous_event = fusedmoe.capture()\n            state['previous_event'] = previous_event\n        return state\n\n    def dispatch(self, state: Dict):\n        \"\"\"dispatch.\"\"\"\n        moe_type = state['moe_type']\n        if moe_type == MoeType.DSAsyncPrefill:\n            fusedmoe = 
state['fusedmoe']\n            previous_event = state['previous_event']\n            (\n                recv_hidden_states,\n                recv_topk_idx,\n                recv_topk_weights,\n                recv_tokens_per_expert,\n                handle,\n                event,\n            ) = fusedmoe.dispatch_async(state['hidden_states'],\n                                        state['topk_idx'],\n                                        state['topk_weights'],\n                                        previous_event=previous_event,\n                                        async_finish=True)\n            recv_state = {\n                'fusedmoe': fusedmoe,\n                'recv_hidden_states': recv_hidden_states,\n                'recv_topk_idx': recv_topk_idx,\n                'recv_topk_weights': recv_topk_weights,\n                'recv_tokens_per_expert': recv_tokens_per_expert,\n                'handle': handle,\n                'event': event,\n                'num_experts': self.num_experts,\n                'moe_type': state['moe_type']\n            }\n        elif moe_type == MoeType.DSAsyncDecode:\n            fusedmoe = self.fusedmoe_build(low_latency_mode=True)\n            use_event = False\n            (recv_hidden_states, recv_expert_count, handle, event,\n             hook) = fusedmoe.dispatch_async(state['hidden_states'],\n                                             state['topk_idx'],\n                                             use_fp8=False,\n                                             async_finish=use_event)\n            recv_state = {\n                'fusedmoe': fusedmoe,\n                'recv_hidden_states': recv_hidden_states,\n                'recv_expert_count': recv_expert_count,\n                'topk_idx': state['topk_idx'],\n                'topk_weights': state['topk_weights'],\n                'raw_hidden_shape': state['raw_hidden_shape'],\n                'handle': handle,\n                'moe_type': state['moe_type']\n   
         }\n            if use_event:\n                recv_state['event'] = event\n            else:\n                recv_state['hook'] = hook\n        elif moe_type == MoeType.Default:\n            hidden_states, topk_weights, topk_idx = moe_gather_inputs(state['hidden_states'],\n                                                                      state['topk_weights'],\n                                                                      state['topk_idx'],\n                                                                      group=self.gather_group)\n            recv_state = {\n                'hidden_states': hidden_states,\n                'topk_idx': topk_idx,\n                'topk_weights': topk_weights,\n                'moe_type': moe_type\n            }\n        else:\n            raise NotImplementedError(f'Not supported moe type: {moe_type}')\n        return recv_state\n\n    def gemm(self, state: Dict):\n        \"\"\"gemm.\"\"\"\n        moe_type = state['moe_type']\n        if moe_type == MoeType.DSAsyncPrefill:\n            if (state['recv_hidden_states'][0]\n                    if isinstance(state['recv_hidden_states'], tuple) else state['recv_hidden_states']).shape[0] > 0:\n                state['recv_hidden_states'] = state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight,\n                                                                                 self.gate_up.weight_scale_inv,\n                                                                                 self.down.weight,\n                                                                                 self.down.weight_scale_inv)\n            gemm_state = {\n                'fusedmoe': state['fusedmoe'],\n                'hidden_states': state['recv_hidden_states'],\n                'handle': state['handle'],\n                'moe_type': state['moe_type']\n            }\n        elif moe_type == MoeType.DSAsyncDecode:\n            state['recv_hidden_states'] = 
state['fusedmoe'].fusedmoe_forward(state, self.gate_up.weight,\n                                                                             self.gate_up.weight_scale_inv,\n                                                                             self.down.weight,\n                                                                             self.down.weight_scale_inv)\n            gemm_state = {\n                'fusedmoe': state['fusedmoe'],\n                'hidden_states': state['recv_hidden_states'],\n                'topk_idx': state['topk_idx'],\n                'topk_weights': state['topk_weights'],\n                'handle': state['handle'],\n                'moe_type': state['moe_type']\n            }\n        else:\n            hidden_states = state['hidden_states']\n            topk_weights = state['topk_weights']\n            topk_ids = state['topk_idx']\n\n            hidden_states = self.impl.forward(hidden_states,\n                                              topk_weights,\n                                              topk_ids,\n                                              self.gate_up.weight,\n                                              self.down.weight,\n                                              self.gate_up.bias,\n                                              self.down.bias,\n                                              self.expert_list,\n                                              act_func=self.act_func)\n            gemm_state = {'hidden_states': hidden_states, 'moe_type': state['moe_type']}\n        return gemm_state\n\n    def combine(self, state: Dict):\n        \"\"\"combine.\"\"\"\n        moe_type = state['moe_type']\n        if moe_type == MoeType.DSAsyncPrefill:\n            fusedmoe = state['fusedmoe']\n            previous_event = fusedmoe.capture()\n            out_hidden_states, event = fusedmoe.combine_async(state['hidden_states'],\n                                                              state['handle'],\n       
                                                       previous_event=previous_event,\n                                                              async_finish=True)\n            out_state = {\n                'fusedmoe': state['fusedmoe'],\n                'hidden_states': out_hidden_states,\n                'event': event,\n                'moe_type': state['moe_type']\n            }\n        elif moe_type == MoeType.DSAsyncDecode:\n            fusedmoe = state['fusedmoe']\n            use_event = False\n            out_hidden_states, event, hook = fusedmoe.combine_async(state['hidden_states'],\n                                                                    state['topk_idx'],\n                                                                    state['topk_weights'],\n                                                                    state['handle'],\n                                                                    async_finish=use_event)\n            out_state = {\n                'fusedmoe': state['fusedmoe'],\n                'hidden_states': out_hidden_states,\n                'moe_type': state['moe_type']\n            }\n            if use_event:\n                out_state['event'] = event\n            else:\n                out_state['hook'] = hook\n        elif moe_type == MoeType.Default:\n            if self.all_reduce:\n                state['hidden_states'] = moe_reduce(state['hidden_states'],\n                                                    rank=self.tp_rank,\n                                                    tp_mode=self.tp_mode,\n                                                    group=self.tp_group)\n            out_state = {'hidden_states': state['hidden_states'], 'moe_type': moe_type}\n        else:\n            raise NotImplementedError(f'Not supported moe type: {moe_type}')\n        return out_state\n\n    def wait(self, state: Dict):\n        \"\"\"wait.\"\"\"\n        if state.get('event', None) is not None:\n            
state['fusedmoe'].wait(state['event'])\n            return True\n        elif state.get('hook', None) is not None:\n            state['hook']()\n            return True\n        else:\n            return False\n\n    def fusedmoe_build(self, low_latency_mode: bool = False):\n        \"\"\"Build fused moe.\"\"\"\n        return self.impl.fusedmoe_build(low_latency_mode)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/moe/route.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\nimport torch\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\n\n\nclass NoauxTCRouter(torch.nn.Module):\n\n    def __init__(\n        self,\n        scoring_func: str,\n        top_k: int,\n        n_group: int,\n        topk_group: int,\n        n_routed_experts: int,\n        routed_scaling_factor: float,\n        renormalize: bool = True,\n        router_n_groups: int = -1,\n    ):\n        super().__init__()\n\n        impl_builder = get_backend().get_layer_impl_builder(OpType.RouterNoauxTC)\n        self.impl = impl_builder.build(\n            scoring_func=scoring_func,\n            top_k=top_k,\n            n_group=n_group,\n            topk_group=topk_group,\n            n_routed_experts=n_routed_experts,\n            routed_scaling_factor=routed_scaling_factor,\n            renormalize=renormalize,\n            router_n_groups=router_n_groups,\n        )\n\n    def forward(self, router_logits: torch.Tensor,\n                e_score_correction_bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Router forward.\"\"\"\n        return self.impl.forward(router_logits, e_score_correction_bias)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/moe/w8a8.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List, Optional\n\nimport torch\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\n\nfrom .base import FusedMoEBase, MoeType, moe_gather_inputs, moe_reduce, update_dims\nfrom .default import LinearWeights\n\n\nclass LinearWeightsW8A8(LinearWeights):\n    \"\"\"Fused moe linear w8a8 weights.\"\"\"\n\n    def __init__(self,\n                 num_experts: int,\n                 in_features: int,\n                 out_features: int,\n                 weight_type: str,\n                 device: torch.device,\n                 expert_list: List[int] = None,\n                 quant_dtype: torch.dtype = torch.int8):\n        super().__init__(\n            num_experts=num_experts,\n            in_features=in_features,\n            out_features=out_features,\n            weight_type=weight_type,\n            dtype=quant_dtype,\n            device=device,\n            expert_list=expert_list,\n        )\n        scale = torch.empty((num_experts, out_features, 1), dtype=torch.float32, device=device)\n        scale = torch.nn.Parameter(scale, requires_grad=False)\n        self.register_parameter('scale', scale)\n\n        if self.ep:\n            self.scale.weight_loader = self.weight_loader_ep\n        else:\n            self.scale.weight_loader = self.weight_loader_scale_tp\n\n    def update_weight(self, weight: torch.Tensor, scale: torch.Tensor):\n        \"\"\"Update weight.\"\"\"\n        super().update_weight(weight=weight)\n        weight_loader = self.scale.weight_loader\n        scale = torch.nn.Parameter(scale, requires_grad=False)\n        scale.weight_loader = weight_loader\n        self.register_parameter('scale', scale)\n\n    def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int,\n                               shard_id: str):\n        \"\"\"Weight loader 
scale tp.\"\"\"\n        world_size, rank = get_tp_world_rank('moe')\n        if shard_id == 'gate':\n            param_data = param.data[expert_id, :self.half_out]\n            weight = loaded_weight.chunk(world_size, dim=0)[rank]\n        elif shard_id == 'up':\n            param_data = param.data[expert_id, self.half_out:]\n            weight = loaded_weight.chunk(world_size, dim=0)[rank]\n        elif shard_id == 'down':\n            param_data = param.data[expert_id]\n            weight = loaded_weight\n        else:\n            raise RuntimeError(f'Unknown shard_id: {shard_id}')\n        weight = weight.to(param.dtype)\n        param_data.copy_(weight)\n\n\nclass FusedMoEW8A8(FusedMoEBase):\n    \"\"\"Fused moe w8a8.\"\"\"\n\n    def __init__(self,\n                 hidden_dim: int,\n                 ffn_dim: int,\n                 num_experts: int,\n                 top_k: int,\n                 renormalize: bool = False,\n                 dtype: Optional[torch.dtype] = None,\n                 quant_dtype: Optional[torch.dtype] = torch.int8,\n                 device: Optional[torch.device] = None,\n                 all_reduce: bool = True):\n\n        device = device or torch.device('cpu')\n        dtype = dtype or torch.float16\n        # init distributed tp arguments\n        self.init_dist_args(all_reduce)\n\n        # check ep\n        if self.ep > 1:\n            raise RuntimeError('FusedMoEW8A8 does not support EP mode now.')\n\n        super().__init__(\n            tp=self.tp,\n            tp_mode=self.tp_mode,\n            do_renormalize=renormalize,\n        )\n\n        # create implementation\n        impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEW8A8)\n        self.impl = impl_builder.build(top_k, num_experts, renormalize, dtype, quant_dtype=quant_dtype)\n\n        # create weights\n        hidden_dim, ffn_dim = update_dims(hidden_dim, ffn_dim)\n        expert_list = None\n        self.expert_list = expert_list\n        
self.gate_up = LinearWeightsW8A8(num_experts,\n                                         hidden_dim,\n                                         ffn_dim * 2,\n                                         weight_type='gate_up',\n                                         device=device,\n                                         expert_list=expert_list,\n                                         quant_dtype=quant_dtype)\n        self.down = LinearWeightsW8A8(num_experts,\n                                      ffn_dim,\n                                      hidden_dim,\n                                      weight_type='down',\n                                      device=device,\n                                      expert_list=expert_list,\n                                      quant_dtype=quant_dtype)\n\n        self.hidden_dim = hidden_dim\n        self.ffn_dim = ffn_dim\n        self.num_experts = num_experts\n        self.dtype = dtype\n        self.device = device\n        self.all_reduce = all_reduce\n\n    def update_weights(self):\n        \"\"\"Update weights.\"\"\"\n        (gate_up_weights, down_weights, gate_up_scale,\n         down_scale) = self.impl.update_weights(self.gate_up.weight, self.down.weight, self.gate_up.scale,\n                                                self.down.scale)\n        self.gate_up.update_weight(gate_up_weights, gate_up_scale)\n        self.down.update_weight(down_weights, down_scale)\n\n    def dispatch(self, state: Dict):\n        \"\"\"dispatch.\"\"\"\n        moe_type = state['moe_type']\n        if moe_type == MoeType.Default:\n            hidden_states, topk_weights, topk_idx = moe_gather_inputs(state['hidden_states'],\n                                                                      state['topk_weights'],\n                                                                      state['topk_idx'],\n                                                                      group=self.gather_group)\n            recv_state = {\n        
        'hidden_states': hidden_states,\n                'topk_idx': topk_idx,\n                'topk_weights': topk_weights,\n                'moe_type': moe_type\n            }\n        else:\n            raise NotImplementedError(f'Not supported moe type: {moe_type}')\n        return recv_state\n\n    def gemm(self, state: Dict):\n        \"\"\"gemm.\"\"\"\n        hidden_states = state['hidden_states']\n        topk_weights = state['topk_weights']\n        topk_ids = state['topk_idx']\n\n        ret = self.impl.forward(hidden_states, topk_weights, topk_ids, self.gate_up.weight, self.gate_up.scale,\n                                self.down.weight, self.down.scale, self.expert_list)\n        return dict(hidden_states=ret, moe_type=state['moe_type'])\n\n    def combine(self, state: Dict):\n        \"\"\"combine.\"\"\"\n        moe_type = state['moe_type']\n        if moe_type == MoeType.Default:\n            if self.all_reduce:\n                state['hidden_states'] = moe_reduce(state['hidden_states'],\n                                                    rank=self.tp_rank,\n                                                    tp_mode=self.tp_mode,\n                                                    group=self.tp_group)\n            out_state = {'hidden_states': state['hidden_states'], 'moe_type': moe_type}\n        else:\n            raise NotImplementedError(f'Not supported moe type: {moe_type}')\n        return out_state\n\n    def wait(self, state: Dict):\n        \"\"\"wait.\"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/multinomial_sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom ..backends import OpType, get_backend\n\n\ndef multinomial_sampling(scores: torch.Tensor,\n                         seeds: torch.LongTensor,\n                         offsets: torch.LongTensor,\n                         indices: torch.Tensor = None):\n    \"\"\"Multinomial sampling op.\"\"\"\n    impl_builder = get_backend().get_layer_impl_builder(OpType.MultinomialSampling)\n    return impl_builder.build().forward(scores, seeds, offsets, indices)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/norm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict\n\nimport torch\nfrom torch import nn\n\nfrom lmdeploy.pytorch.distributed import get_tp_world_rank\nfrom lmdeploy.pytorch.models.patch import get_build_model_context\n\nfrom ..backends import OpType, get_backend\nfrom .utils import chunk_aligned, get_distribute_size\n\n\nclass RMSNorm(nn.Module):\n    \"\"\"RMS Norm with add residual.\"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        eps: float = 1e-6,\n        dtype: torch.dtype | None = None,\n        device: torch.device | None = None,\n        quant_config: Dict | None = None,\n        tp: bool = False,\n        align: int = 1,\n        prefix: str = '',\n    ):\n        super().__init__()\n        backend = get_backend()\n\n        quant_method = None\n        if quant_config is not None:\n            quant_config = get_build_model_context().quant_config\n            quant_method = quant_config.get_quant_method(prefix)\n\n        w8a8_flag = quant_method == 'smooth_quant'\n\n        if w8a8_flag:\n            builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8)\n        else:\n            builder = backend.get_layer_impl_builder(OpType.RMSNorm)\n\n        if tp:\n            world_size, rank = get_tp_world_rank('attn')\n            hidden_size = get_distribute_size(hidden_size, world_size, rank, align=align)\n\n        self.register_parameter('weight', self.create_weight(hidden_size, dtype, device))\n        if w8a8_flag:\n            self.impl = builder.build(hidden_size, eps, quant_dtype=quant_config.quant_dtype)\n        else:\n            self.impl = builder.build(hidden_size, eps)\n\n        if tp:\n            self.weight.weight_loader = self.weight_loader\n        self.align = align\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        world_size, rank = get_tp_world_rank('attn')\n        loaded_weight = 
chunk_aligned(loaded_weight, world_size, 0, self.align)[rank]\n        param.copy_(loaded_weight)\n\n    @staticmethod\n    def create_weight(hidden_size: int, dtype: torch.dtype | None = None, device: torch.device | None = None):\n        \"\"\"Create weight.\"\"\"\n        if dtype is None:\n            dtype = torch.float16\n        if device is None:\n            device = 'cuda'\n        weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device), requires_grad=False)\n        return weight\n\n    def forward(self, x: torch.Tensor, residual: torch.Tensor = None):\n        \"\"\"forward.\"\"\"\n        return self.impl.forward(x, self.weight, residual)\n\n\nclass LayerNorm(nn.Module):\n    \"\"\"Layer Norm with add residual.\"\"\"\n\n    def __init__(self,\n                 hidden_size: int,\n                 eps: float = 1e-6,\n                 bias: bool = True,\n                 dtype: torch.dtype | None = None,\n                 device: torch.device | None = None):\n        super().__init__()\n        backend = get_backend()\n        builder = backend.get_layer_impl_builder(OpType.LayerNorm)\n        weight, bias = self.create_weight(hidden_size, bias, dtype, device)\n        self.register_parameter('weight', weight)\n        self.register_parameter('bias', bias)\n        self.impl = builder.build(hidden_size, eps)\n\n    @staticmethod\n    def create_weight(hidden_size: int,\n                      bias: bool = True,\n                      dtype: torch.dtype | None = None,\n                      device: torch.device | None = None):\n        \"\"\"Create weight.\"\"\"\n        if dtype is None:\n            dtype = torch.float16\n        if device is None:\n            device = 'cuda'\n        weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device), requires_grad=False)\n        if bias:\n            bias = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device), requires_grad=False)\n        else:\n  
          bias = None\n\n        return weight, bias\n\n    def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):\n        \"\"\"forward.\"\"\"\n        return self.impl.forward(x, self.weight, self.bias, residual)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/nsa.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom torch import Tensor, nn\n\nfrom lmdeploy.pytorch.backends import OpType, get_backend\nfrom lmdeploy.pytorch.backends.attention import AttentionMetadata\nfrom lmdeploy.pytorch.backends.nsa import NSAIndexMeta\nfrom lmdeploy.pytorch.model_inputs import get_step_ctx_manager\n\n\nclass IndexerTopKFP8(nn.Module):\n\n    def __init__(self, topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1):\n        super().__init__()\n        backend = get_backend()\n        index_builder = backend.get_layer_impl_builder(OpType.NSAIndexFP8)\n        self.index_impl = index_builder.build(topk, softmax_scale, block_size, fill)\n\n    def forward(\n        self,\n        q: Tensor,\n        k: Tensor,\n        weights: Tensor,\n        k_cache: Tensor,\n        k_s_cache: Tensor,\n        attn_metadata: AttentionMetadata = None,\n    ):\n        \"\"\"forward.\"\"\"\n        step_ctx = get_step_ctx_manager().current_context()\n        cache_config = step_ctx.cache_config\n        max_tokens = cache_config.block_size * cache_config.num_gpu_blocks\n        is_decoding = attn_metadata.is_decoding\n        # one query token per sequence: treat the batch as decoding\n        if q.size(0) == attn_metadata.kv_seqlens.size(0):\n            is_decoding = True\n        max_q_seqlen = 1 if is_decoding else q.size(0)\n        # we need to make max_kv_seqlen=max_allocated_cache_len to enable cudagraph\n        max_kv_seqlen = max_tokens if is_decoding else attn_metadata.kv_flatten_size\n        meta = NSAIndexMeta(cu_seqlen_q=attn_metadata.cu_seqlens_q,\n                            q_seqlens=attn_metadata.q_seqlens,\n                            k_seqlens=attn_metadata.kv_seqlens,\n                            block_offset=attn_metadata.block_offsets,\n                            max_q_seqlen=max_q_seqlen,\n                            max_kv_seqlen=max_kv_seqlen)\n        ret = self.index_impl.forward(q, k, weights, k_cache, k_s_cache, meta=meta)\n        return ret\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/quant_utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/rotary_embedding.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport math\n\nimport torch\nfrom torch import Tensor, nn\nfrom transformers import PretrainedConfig\n\nfrom ..backends import OpType, get_backend\nfrom ..backends.rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,\n                                         YarnParameters)\n\n\ndef get_rope_parameters(config: PretrainedConfig):\n    \"\"\"Try get rope parameters from config.\"\"\"\n    if hasattr(config, 'rope_parameters'):\n        # for transformers v5\n        return config.rope_parameters\n    else:\n        return getattr(config, 'rope_scaling', None)\n\n\ndef _get_default_rope_parameters(config: PretrainedConfig):\n    \"\"\"Get default rope parameters.\"\"\"\n    return dict(emb_type=RopeType.Default, scaling_factor=1.0)\n\n\ndef _get_linear_scaling_rope_parameters(config: PretrainedConfig):\n    \"\"\"Get linear rope parameters.\"\"\"\n    rope_scaling = get_rope_parameters(config=config)\n    scaling_factor = rope_scaling['factor']\n    return dict(emb_type=RopeType.LinearScaling, scaling_factor=scaling_factor)\n\n\ndef _get_dynamic_ntk_parameters(config: PretrainedConfig):\n    \"\"\"Get dynamic ntk parameters.\"\"\"\n    rope_scaling = get_rope_parameters(config=config)\n    scaling_factor = rope_scaling['factor']\n    return dict(emb_type=RopeType.DynamicNTKScaling, scaling_factor=scaling_factor)\n\n\ndef _get_yarn_parameters(config: PretrainedConfig):\n    \"\"\"Get yarn parameters.\"\"\"\n\n    def get_mscale(scale, mscale=1):\n        if scale <= 1:\n            return 1.0\n        return 0.1 * mscale * math.log(scale) + 1.0\n\n    rope_scaling = get_rope_parameters(config=config)\n    factor = rope_scaling['factor']\n    params = YarnParameters()\n    params.beta_fast = rope_scaling.get('beta_fast', params.beta_fast)\n    params.beta_slow = rope_scaling.get('beta_slow', params.beta_slow)\n    mscale = rope_scaling.get('mscale', 
params.mscale)\n    mscale_all_dim = rope_scaling.get('mscale_all_dim', params.mscale_all_dim)\n    truncate = rope_scaling.get('truncate', params.truncate)\n\n    if 'attention_factor' in rope_scaling:\n        attention_factor = rope_scaling.get('attention_factor')\n    else:\n        if mscale_all_dim and mscale:\n            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))\n        else:\n            attention_factor = get_mscale(factor)\n\n    params.attention_factor = attention_factor\n    params.mscale = mscale\n    params.mscale_all_dim = mscale_all_dim\n    params.truncate = truncate\n\n    ret = dict(emb_type=RopeType.Yarn, scaling_factor=factor, yarn_params=params)\n    if 'original_max_position_embeddings' in rope_scaling:\n        ret['max_position_embeddings'] = rope_scaling['original_max_position_embeddings']\n    return ret\n\n\ndef _get_longrope_parameters(config: PretrainedConfig):\n    \"\"\"Get longrope parameters.\"\"\"\n    rope_scaling = get_rope_parameters(config=config)\n    scaling_factor = rope_scaling.get('factor', 1.0)\n    long_factor = rope_scaling['long_factor']\n    short_factor = rope_scaling['short_factor']\n    original_max_position_embeddings = getattr(config, 'original_max_position_embeddings',\n                                               config.max_position_embeddings)\n    original_max_position_embeddings = rope_scaling.get('original_max_position_embeddings',\n                                                        original_max_position_embeddings)\n    params = LongRoPEScalingParameters(\n        long_factor=long_factor,\n        short_factor=short_factor,\n        original_max_position_embeddings=original_max_position_embeddings,\n    )\n    return dict(emb_type=RopeType.LongRoPEScaling, scaling_factor=scaling_factor, longrope_params=params)\n\n\ndef _get_llama3_parameters(config: PretrainedConfig):\n    \"\"\"Get llama rope parameters.\"\"\"\n    rope_scaling = 
get_rope_parameters(config=config)\n    params = Llama3Parameters()\n    scaling_factor = rope_scaling['factor']\n    params.low_freq_factor = rope_scaling['low_freq_factor']\n    params.high_freq_factor = rope_scaling['high_freq_factor']\n    params.original_max_position_embeddings = rope_scaling.get('original_max_position_embeddings',\n                                                               params.original_max_position_embeddings)\n    return dict(emb_type=RopeType.Llama3, scaling_factor=scaling_factor, llama3_params=params)\n\n\ndef _get_fope_parameters(config: PretrainedConfig):\n    \"\"\"Get fope parameters.\"\"\"\n    # check if fope is used\n    rope_scaling = getattr(config, 'rope_scaling', dict())\n    fope_keys = ['fope_sep_head', 'fope_num_inv_freq']\n    is_fope = any(key in rope_scaling for key in fope_keys)\n    if not is_fope:\n        return dict()\n\n    params = FopeParameters()\n    rope_scaling = get_rope_parameters(config=config)\n    params.num_inv_freq = rope_scaling.get('fope_num_inv_freq', rope_scaling.get('num_inv_freq', params.num_inv_freq))\n    params.num_key_value_heads = config.num_key_value_heads\n    params.fope_sep_head = rope_scaling['fope_sep_head']\n    return dict(fope_params=params)\n\n\ndef build_rotary_params(config: PretrainedConfig):\n    \"\"\"Get scaling_factor rotary params, and emb_type.\"\"\"\n    params = dict(emb_type=RopeType.Default)\n    # cannot access config.rope_scaling when the model is \"Qwen/Qwen2-Math-RM-72B\"\n    rope_scaling = get_rope_parameters(config=config)\n    if rope_scaling is not None:\n        # BC: \"rope_type\" was originally \"type\"\n        rope_type_str = rope_scaling.get('rope_type', rope_scaling.get('type', 'default'))\n        if rope_type_str == 'fope':\n            rope_type_str = 'default'\n        build_funcs = dict(default=_get_default_rope_parameters,\n                           linear=_get_linear_scaling_rope_parameters,\n                           
dynamic=_get_dynamic_ntk_parameters,\n                           yarn=_get_yarn_parameters,\n                           longrope=_get_longrope_parameters,\n                           su=_get_longrope_parameters,\n                           llama3=_get_llama3_parameters)\n        params.update(build_funcs[rope_type_str](config))\n        params.update(_get_fope_parameters(config))\n\n    # update partial_rotary_factor\n    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, 'partial_rotary_factor') else None\n    if partial_rotary_factor is not None:\n        params['partial_rotary_factor'] = partial_rotary_factor\n\n    return params\n\n\ndef build_rotary_embedding(dim: int,\n                           max_position_embeddings: int = 2048,\n                           base: int = 10000,\n                           scaling_factor: float = 1.0,\n                           yarn_params: YarnParameters = None,\n                           longrope_params: LongRoPEScalingParameters = None,\n                           llama3_params: Llama3Parameters = None,\n                           fope_params: FopeParameters = None,\n                           emb_type: RopeType = RopeType.Default,\n                           partial_rotary_factor: float = None,\n                           device: torch.device = None) -> nn.Module:\n    \"\"\"Build rotary embedding op.\"\"\"\n    backend = get_backend()\n\n    builder = backend.get_layer_impl_builder(OpType.RotaryEmbedding)\n\n    # update rope_dim\n    if partial_rotary_factor is not None:\n        dim = int(dim * partial_rotary_factor)\n    impl = builder.build(dim,\n                         max_position_embeddings,\n                         base,\n                         scaling_factor,\n                         yarn_params=yarn_params,\n                         longrope_params=longrope_params,\n                         llama3_params=llama3_params,\n                         emb_type=emb_type)\n\n    if fope_params 
is not None:\n        inv_freq = impl.inv_freq\n        fope_params.inv_freq = inv_freq\n        fope = FopeRotaryEmbedding(dim, max_position_embeddings, scaling_factor, fope_params, device)\n        return fope\n\n    return impl\n\n\ndef get_rope_theta(config: PretrainedConfig, default: int = 10000) -> int:\n    \"\"\"Get rope theta from config.\"\"\"\n    if hasattr(config, 'rope_parameters'):\n        # for transformers v5\n        rope_base = config.rope_parameters.get('rope_theta', default)\n    else:\n        rope_base = getattr(config, 'rope_theta', default)\n    return rope_base\n\n\ndef build_rotary_embedding_from_config(config: PretrainedConfig, device: torch.device = None) -> nn.Module:\n    \"\"\"Build rotary embedding op from config.\"\"\"\n    emb_type = RopeType.LinearScaling\n    rope_dim = getattr(config, 'head_dim', None)\n    if rope_dim is None:\n        rope_dim = config.hidden_size // config.num_attention_heads\n    rope_max_pos_emb = config.max_position_embeddings\n\n    rope_base = get_rope_theta(config, default=10000)\n    rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)\n    update_params = build_rotary_params(config)\n    rope_params.update(update_params)\n    return build_rotary_embedding(**rope_params, device=device)\n\n\nclass ApplyRotaryEmb(nn.Module):\n    \"\"\"Apply rotary embedding.\"\"\"\n\n    def __init__(self):\n        super().__init__()\n        backend = get_backend()\n        builder = backend.get_layer_impl_builder(OpType.ApplyRotaryEmb)\n        self.impl = builder.build()\n\n    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):\n        \"\"\"forward.\"\"\"\n\n        assert cos.dim() <= 3 and sin.dim() <= 3\n\n        need_reshape = False\n        if cos.dim() == 3:\n            # for fope\n            assert query.dim() == key.dim() == 3, 'Expected query key (seq_len, heads, head_dim)'\n            need_reshape = 
True\n            query_shape = query.shape\n            key_shape = key.shape\n            cos = cos.flatten(0, 1)\n            sin = sin.flatten(0, 1)\n            seq_len = cos.size(0)\n            query = query.view(seq_len, -1, query.size(-1))\n            key = key.view(seq_len, -1, key.size(-1))\n\n        query, key = self.impl.forward(query, key, cos, sin, inplace)\n\n        if need_reshape:\n            query = query.view(query_shape)\n            key = key.view(key_shape)\n        return query, key\n\n\nclass FopeRotaryEmbedding(nn.Module):\n    \"\"\"Fope rotary embedding.\"\"\"\n\n    def __init__(self,\n                 dim: int,\n                 max_position_embeddings: int,\n                 attention_scaling: float,\n                 params: FopeParameters,\n                 device: torch.device = None):\n        super().__init__()\n\n        num_key_value_heads, tp = self.update_num_kv_heads(params.num_key_value_heads)\n        self.tp = tp\n        params.num_key_value_heads = num_key_value_heads\n\n        # build impl\n        backend = get_backend()\n        builder = backend.get_layer_impl_builder(OpType.RotaryEmbedding)\n        self.impl = builder.build(dim,\n                                  max_position_embeddings=max_position_embeddings,\n                                  scaling_factor=attention_scaling,\n                                  fope_params=params,\n                                  emb_type=RopeType.Fope)\n\n        # setup params\n        inv_freq = self.impl.inv_freq\n        self.input_dim = inv_freq.shape[-1]\n        self.output_dim = inv_freq.shape[-1]\n        self.cos_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim, device=device),\n                                     requires_grad=False)\n        self.sin_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim, device=device),\n                                     requires_grad=False)\n        if 
self.tp:\n            self.cos_coef.weight_loader = self.weight_loader\n            self.sin_coef.weight_loader = self.weight_loader\n\n    @staticmethod\n    def update_num_kv_heads(num_key_value_heads: int):\n        \"\"\"Update num_key_value_heads.\"\"\"\n        from lmdeploy.pytorch.distributed import get_dist_manager\n        dist_mgr = get_dist_manager()\n        dist_ctx = dist_mgr.current_context()\n        tp = dist_ctx.dist_config.attn_tp\n        # tp = dist_ctx.dist_config.attn_config.tp\n        if tp > 1:\n            num_key_value_heads = max(1, num_key_value_heads // tp)\n        return num_key_value_heads, tp\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):\n        \"\"\"Weight loader.\"\"\"\n        from lmdeploy.pytorch.distributed import get_tp_world_rank\n        world_size, rank = get_tp_world_rank()\n        num_key_value_heads = loaded_weight.size(0)\n\n        if num_key_value_heads < world_size:\n            n_replicate = world_size // num_key_value_heads\n            world_size = num_key_value_heads\n            rank = rank // n_replicate\n\n        loaded_weight = loaded_weight.chunk(world_size, dim=0)[rank]\n        param.copy_(loaded_weight)\n\n    def forward(self, x: Tensor, position_ids: Tensor):\n        \"\"\"forward.\"\"\"\n        return self.impl.forward(x, position_ids, sin_coef=self.sin_coef, cos_coef=self.cos_coef)\n"
  },
  {
    "path": "lmdeploy/pytorch/nn/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\n\ndef div_up(a: int, b: int):\n    \"\"\"Div up.\"\"\"\n    return (a + b - 1) // b\n\n\ndef get_distribute_size(feature_size: int, world_size: int, rank: int, align: int = 1):\n    \"\"\"Get the feature size assigned to `rank`.\"\"\"\n    assert feature_size % align == 0\n    aligned_size = feature_size // align\n    # give every rank the same number of aligned units\n    updated_aligned_size = aligned_size // world_size\n    # hand out the remaining units, one each, to the lowest ranks\n    if rank < aligned_size % world_size:\n        updated_aligned_size += 1\n    return updated_aligned_size * align\n\n\ndef chunk_aligned(weight: torch.Tensor, chunks: int, dim: int, align: int):\n    \"\"\"Split `weight` along `dim` into chunks whose sizes are multiples of `align`.\"\"\"\n    if align == 1:\n        return weight.chunk(chunks, dim=dim)\n    size = weight.size(dim)\n    assert size % align == 0\n    aligned_size = size // align\n\n    # split as evenly as possible; the first `remain` chunks get one extra unit\n    align_per_chunk = aligned_size // chunks\n    remain = aligned_size % chunks\n    sections = [align_per_chunk + int(c < remain) for c in range(chunks)]\n    sections = [sec * align for sec in sections]\n    return weight.split(sections, dim=dim)\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .scheduler import Scheduler\n\n__all__ = ['Scheduler']\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/block_manager/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom ...config import CacheConfig\nfrom .base_block_manager import BaseBlockManager\nfrom .default_block_manager import DefaultBlockManager\nfrom .window_block_manager import WindowBlockManager\n\n\ndef build_block_manager(cache_config: CacheConfig) -> BaseBlockManager:\n    \"\"\"Build a block manager.\n\n    Returns a WindowBlockManager when `cache_config.window_size >= 0`,\n    otherwise a DefaultBlockManager.\n\n    Args:\n        cache_config (CacheConfig): The cache config.\n    \"\"\"\n\n    num_cpu_blocks = cache_config.num_cpu_blocks\n    num_gpu_blocks = cache_config.num_gpu_blocks\n    window_size = cache_config.window_size\n    num_gpu_reserved = cache_config.num_reserved_gpu_blocks\n\n    if window_size < 0:\n        return DefaultBlockManager(num_gpu_blocks, num_cpu_blocks, num_gpu_reserved=num_gpu_reserved)\n    else:\n        return WindowBlockManager(num_gpu_blocks,\n                                  num_cpu_blocks,\n                                  window_size=window_size,\n                                  num_gpu_reserved=num_gpu_reserved)\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/block_manager/base_block_manager.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport time\nfrom typing import Dict\n\nimport numpy as np\n\nfrom ...messages import SchedulerSequence\n\n\nclass LogicalMemory:\n    \"\"\"Logical memory blocks.\"\"\"\n\n    def __init__(self, num_blocks: int) -> None:\n        self._num_blocks = num_blocks\n\n        self.phy_map: np.ndarray = np.zeros(self._num_blocks, dtype=np.int64)\n        self.ref_count: np.ndarray = np.zeros((self._num_blocks, ), dtype=np.int64)\n        self.access_time: np.ndarray = np.zeros((self._num_blocks, ), dtype=np.int64)\n\n    def get_physical_blocks(self, logical_address: np.ndarray):\n        \"\"\"Get physical address.\"\"\"\n        if isinstance(logical_address, np.ndarray) and len(logical_address) == 0:\n            return np.empty((0, ), dtype=np.int64)\n        return self.phy_map[logical_address]\n\n    def num_blocks(self):\n        \"\"\"Get num blocks.\"\"\"\n        return self._num_blocks\n\n\nclass PhysicalAllocator:\n    \"\"\"The physical block allocator.\n\n    The allocator won't allocate real memory. 
It is used to support block manager.\n    \"\"\"\n\n    def __init__(self, num_blocks: int, offset: int = 0):\n        self._num_blocks = num_blocks\n        self._offset = offset\n\n        self._free_blocks = np.arange(num_blocks, dtype=np.int64) + offset\n        self._free_count = num_blocks\n\n    def allocate(self, num_blocks: int):\n        \"\"\"Allocate block from block pool.\"\"\"\n        if self.get_num_free_blocks() >= num_blocks:\n            num_used = self._num_blocks - self._free_count\n            blocks = self._free_blocks[num_used:num_used + num_blocks]\n            self._free_count -= num_blocks\n            return blocks\n        else:\n            raise MemoryError('No enough free memory blocks.')\n\n    def free(self, blocks: np.ndarray):\n        \"\"\"Free block to block pool.\"\"\"\n        freed_blocks = blocks\n        num_freed_blocks = len(freed_blocks)\n        if num_freed_blocks > 0:\n            num_used = self._num_blocks - self._free_count\n            self._free_blocks[num_used - num_freed_blocks:num_used] = freed_blocks\n            self._free_count += num_freed_blocks\n        return freed_blocks\n\n    def get_num_free_blocks(self):\n        \"\"\"Get numbers of free blocks.\"\"\"\n        return self._free_count\n\n\nclass LogicalAllocator:\n    \"\"\"The logical block allocator.\"\"\"\n\n    def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int, num_gpu_reserved: int = 0) -> None:\n        self._log_mem = LogicalMemory(num_cpu_blocks + num_gpu_blocks)\n\n        self._cpu_mem_offset = num_gpu_blocks\n        num_gpu_blocks -= num_gpu_reserved\n        self._gpu_allocator = PhysicalAllocator(num_gpu_blocks, num_gpu_reserved)\n        self._cpu_allocator = PhysicalAllocator(num_cpu_blocks, self._cpu_mem_offset)\n\n        num_blocks = self._log_mem.num_blocks()\n        self._num_blocks = num_blocks\n        self._free_blocks = np.arange(num_blocks)\n        self._free_count = num_blocks\n\n    def 
get_phy_allocator(self, device: str):\n        \"\"\"Get allocator.\"\"\"\n        if device == 'gpu':\n            return self._gpu_allocator\n        elif device == 'cpu':\n            return self._cpu_allocator\n        else:\n            raise ValueError(f'Unsupported device: {device}')\n\n    def allocate(self, num_blocks: int, device: str = 'gpu'):\n        \"\"\"Allocate logical blocks.\"\"\"\n        if num_blocks == 0:\n            return np.empty((0, ), dtype=np.int64)\n        phy_allocator = self.get_phy_allocator(device)\n        logical_enable = self.get_num_free_blocks() >= num_blocks\n        physical_enable = phy_allocator.get_num_free_blocks() >= num_blocks\n        if logical_enable and physical_enable:\n            num_used = self._num_blocks - self._free_count\n            blocks = self._free_blocks[num_used:num_used + num_blocks]\n            phy_blocks = phy_allocator.allocate(num_blocks)\n            self._log_mem.phy_map.put(blocks, phy_blocks)\n            self._log_mem.ref_count.put(blocks, 1)\n            self.update_access_time(blocks)\n            self._free_count -= num_blocks\n            return blocks.copy()\n        else:\n            raise MemoryError('No enough free memory blocks.')\n\n    def free(self, blocks: np.ndarray):\n        \"\"\"Free logical block.\"\"\"\n\n        self.add_ref_count(blocks, -1)\n        self.update_access_time(blocks)\n        ref_count = self.get_ref_count(blocks)\n        freed_blocks = blocks[ref_count == 0]\n        num_freed_blocks = len(freed_blocks)\n        if num_freed_blocks <= 0:\n            return\n\n        # free logical\n        num_used = self._num_blocks - self._free_count\n        self._free_blocks[num_used - num_freed_blocks:num_used] = freed_blocks\n        self._free_count += num_freed_blocks\n\n        # free physical\n        phy_blocks = self.get_physical_blocks(freed_blocks)\n\n        cpu_blocks = phy_blocks[phy_blocks >= self._cpu_mem_offset]\n        gpu_blocks = 
phy_blocks[phy_blocks < self._cpu_mem_offset]\n        if len(cpu_blocks) > 0:\n            self._cpu_allocator.free(cpu_blocks)\n        if len(gpu_blocks) > 0:\n            self._gpu_allocator.free(gpu_blocks)\n\n    def get_num_free_blocks(self):\n        \"\"\"Get numbers of free blocks.\"\"\"\n        return self._free_count\n\n    def get_physical_blocks(self, blocks: np.ndarray):\n        \"\"\"Get physical address.\"\"\"\n        return self._log_mem.get_physical_blocks(blocks)\n\n    def get_ref_count(self, blocks: np.ndarray):\n        \"\"\"Get ref count.\"\"\"\n        return self._log_mem.ref_count[blocks]\n\n    def add_ref_count(self, blocks: np.ndarray, value: np.ndarray):\n        \"\"\"Update ref count.\"\"\"\n        np.add.at(self._log_mem.ref_count, blocks, value)\n\n    def get_access_time(self, blocks: np.ndarray):\n        \"\"\"Get access time.\"\"\"\n        return self._log_mem.access_time[blocks]\n\n    def update_access_time(self, blocks: np.ndarray):\n        \"\"\"Update access time.\"\"\"\n        now = time.perf_counter()\n        self._log_mem.access_time[blocks] = now\n\n    def cpu_mem_offset(self):\n        \"\"\"Get cpu mem offset in unified physical memory.\"\"\"\n        return self._cpu_mem_offset\n\n    def count_cpu_blocks(self, blocks: np.ndarray):\n        \"\"\"Count cpu blocks.\"\"\"\n        phy_blocks = self.get_physical_blocks(blocks)\n        return np.count_nonzero(phy_blocks >= self.cpu_mem_offset())\n\n    def count_gpu_blocks(self, blocks: np.ndarray):\n        \"\"\"Count gpu blocks.\"\"\"\n        phy_blocks = self.get_physical_blocks(blocks)\n        return np.count_nonzero(phy_blocks < self.cpu_mem_offset())\n\n    def update_phy_map(self, log_blocks: np.ndarray, phy_blocks: np.ndarray):\n        \"\"\"Update physical map.\"\"\"\n        assert len(phy_blocks) == len(log_blocks)\n        self._log_mem.phy_map.put(log_blocks, phy_blocks)\n\n    def on_device(self, blocks: np.ndarray, device: str):\n        
\"\"\"Blocks on given device.\"\"\"\n        if len(blocks) == 0:\n            return False\n\n        # TODO: check all blocks\n        cpu_mem_offset = self.cpu_mem_offset()\n\n        phy_blocks = self.get_physical_blocks(blocks[:1])\n        if phy_blocks[0] < cpu_mem_offset:\n            phy_device = 'gpu'\n        else:\n            phy_device = 'cpu'\n        return device == phy_device\n\n\nBlockTable = np.ndarray\n\n\nclass BaseBlockManager:\n    \"\"\"ABC of block manager.\n\n    Args:\n        num_gpu_blocks (int): number of gpu blocks.\n        num_cpu_blocks (int): number of cpu blocks.\n    \"\"\"\n\n    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int, num_gpu_reserved: int = 0) -> None:\n        self.num_gpu_blocks = num_gpu_blocks\n        self.num_cpu_blocks = num_cpu_blocks\n\n        self.allocator = LogicalAllocator(num_cpu_blocks, num_gpu_blocks, num_gpu_reserved)\n\n        self.block_tables: Dict[int, BlockTable] = {}\n\n    @classmethod\n    def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Get num required blocks.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n    def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Return if physical block can be allocated for given message.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n    def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Allocate physical blocks for given message according to logical\n        blocks.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n    def free(self, msg: SchedulerSequence):\n        \"\"\"Free all physical blocks allocated for the session.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n    def try_swap_out(self, msg: SchedulerSequence):\n        \"\"\"Try swap msg out.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n    def try_swap_in(self, msg: 
SchedulerSequence):\n        \"\"\"Try swap msg in.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n    def get_block_table(self, msg: SchedulerSequence):\n        \"\"\"Get the block table of given msg.\n\n        Args:\n            msg (SchedulerSequence): The msg to get block table.\n        \"\"\"\n        logical_blocks = msg.logical_blocks\n        return self.allocator.get_physical_blocks(logical_blocks.get_real_blocks())\n\n    def allocate(self, data: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Allocate stuff.\"\"\"\n        return self.allocate_msg(data, prealloc_size)\n\n    def get_num_free_gpu_blocks(self) -> int:\n        \"\"\"Get number of free gpu blocks.\"\"\"\n        return self.allocator.get_phy_allocator('gpu').get_num_free_blocks()\n\n    def get_num_free_cpu_blocks(self) -> int:\n        \"\"\"Get number of free cpu blocks.\"\"\"\n        return self.allocator.get_phy_allocator('cpu').get_num_free_blocks()\n\n    def on_device(self, msg: SchedulerSequence, device: str):\n        allocator = self.allocator\n        logical_blocks = msg.logical_blocks\n        return allocator.on_device(logical_blocks.get_real_blocks(), device)\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/block_manager/default_block_manager.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from: https://github.com/vllm-project/vllm\nimport numpy as np\n\nfrom ...messages import SchedulerSequence\nfrom .base_block_manager import BaseBlockManager\n\n\ndef _div_up(x, n):\n    \"\"\"Ceiling division of x by n.\"\"\"\n    return (x + n - 1) // n\n\n\nBlockTable = np.ndarray\n\n\nclass DefaultBlockManager(BaseBlockManager):\n    \"\"\"Manage the usage of blocks, generate block tables.\n\n    Args:\n        num_gpu_blocks (int): number of gpu blocks.\n        num_cpu_blocks (int): number of cpu blocks.\n    \"\"\"\n\n    @classmethod\n    def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Get num required blocks.\"\"\"\n        num_tokens = obj.num_all_ids + prealloc_size\n\n        num_all_blocks = _div_up(num_tokens, obj.block_size)\n        return max(0, num_all_blocks - len(obj.logical_blocks))\n\n    def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Return if physical block can be allocated for given message.\"\"\"\n        num_required_blocks = self.num_required_blocks(msg, prealloc_size)\n        num_free_phy = self.get_num_free_gpu_blocks()\n        return num_required_blocks <= num_free_phy\n\n    def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Allocate physical blocks for given message according to logical\n        blocks.\"\"\"\n        logical_blocks = msg.logical_blocks\n        num_required_blocks = self.num_required_blocks(msg, prealloc_size)\n        if num_required_blocks > 0:\n            blocks = self.allocator.allocate(num_required_blocks, 'gpu')\n            logical_blocks.append(blocks)\n\n    def free(self, msg: SchedulerSequence):\n        \"\"\"Free all physical blocks allocated for the session.\"\"\"\n        self.allocator.free(msg.logical_blocks.get_real_blocks())\n        msg.logical_blocks.reset()\n\n    def try_swap_out(self, msg: 
SchedulerSequence):\n        \"\"\"Try swap msg out.\"\"\"\n        swap_map = dict()\n        logical_blocks = msg.logical_blocks\n        cpu_mem_offset = self.allocator.cpu_mem_offset()\n        phy_blocks = self.allocator.get_physical_blocks(logical_blocks)\n        cpu_allocator = self.allocator.get_phy_allocator('cpu')\n        gpu_allocator = self.allocator.get_phy_allocator('gpu')\n\n        def _can_swap():\n            \"\"\"Check swap.\"\"\"\n            if len(logical_blocks) == 0:\n                return False\n\n            # all blocks of a sequence are assumed to be on the same device\n            if phy_blocks[0] >= cpu_mem_offset:\n                return False\n\n            # not enough free cpu blocks\n            num_free = self.get_num_free_cpu_blocks()\n            if num_free < len(phy_blocks):\n                return False\n\n            # don't swap sequences whose blocks are shared (ref count != 1)\n            ref_count = self.allocator.get_ref_count(logical_blocks)\n            if np.count_nonzero(ref_count != 1) > 0:\n                return False\n\n            return True\n\n        def _do_swap():\n            \"\"\"Perform swap.\"\"\"\n            new_blocks = cpu_allocator.allocate(len(logical_blocks))\n\n            old_blocks = phy_blocks\n            swap_map = dict(zip(old_blocks, new_blocks - self.num_gpu_blocks))\n\n            gpu_allocator.free(old_blocks)\n            self.allocator.update_phy_map(logical_blocks.get_real_blocks(), new_blocks)\n            return True, swap_map\n\n        if not _can_swap():\n            return False, swap_map\n        else:\n            return _do_swap()\n\n    def try_swap_in(self, msg: SchedulerSequence):\n        \"\"\"Try swap msg in.\"\"\"\n        swap_map = dict()\n        logical_blocks = msg.logical_blocks\n        cpu_mem_offset = self.allocator.cpu_mem_offset()\n        phy_blocks = self.allocator.get_physical_blocks(logical_blocks)\n        cpu_allocator = self.allocator.get_phy_allocator('cpu')\n        
gpu_allocator = self.allocator.get_phy_allocator('gpu')\n\n        def _can_swap():\n            \"\"\"Check swap.\"\"\"\n            if len(logical_blocks) == 0:\n                return False\n\n            # all blocks of a sequence are assumed to be on the same device\n            if phy_blocks[0] < cpu_mem_offset:\n                return False\n\n            # not enough free gpu blocks\n            num_free = self.get_num_free_gpu_blocks()\n            if num_free < len(phy_blocks):\n                return False\n\n            # don't swap sequences whose blocks are shared (ref count != 1)\n            ref_count = self.allocator.get_ref_count(logical_blocks)\n            if np.count_nonzero(ref_count != 1) > 0:\n                return False\n\n            return True\n\n        def _do_swap():\n            \"\"\"Perform swap.\"\"\"\n            new_blocks = gpu_allocator.allocate(len(logical_blocks))\n\n            old_blocks = phy_blocks\n            swap_map = dict(zip(old_blocks - self.num_gpu_blocks, new_blocks))\n\n            cpu_allocator.free(old_blocks)\n            self.allocator.update_phy_map(logical_blocks.get_real_blocks(), new_blocks)\n            return True, swap_map\n\n        if not _can_swap():\n            return False, swap_map\n        else:\n            return _do_swap()\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/block_manager/window_block_manager.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport numpy as np\n\nfrom ...block import LogicalTokenBlocks\nfrom ...messages import SchedulerSequence\nfrom .default_block_manager import DefaultBlockManager\n\nBlockTable = np.ndarray\n\n\ndef _num_blocks_to_drop(seq: SchedulerSequence, window_size: int):\n    \"\"\"Number of leading blocks that fall outside the attention window.\"\"\"\n    history_len = seq.num_history_ids\n    if seq.num_history_ids <= window_size:\n        return 0\n    block_size = seq.block_size\n    num_blocks = len(seq.logical_blocks)\n    win_start_block_id = (history_len - window_size) // block_size\n    win_end_block_id = (history_len - 1) // block_size\n    num_win_blocks = win_end_block_id - win_start_block_id + 1\n    return max(0, num_blocks - num_win_blocks)\n\n\nclass WindowBlockManager(DefaultBlockManager):\n    \"\"\"Manage the usage of blocks, generate block tables.\n\n    Args:\n        num_gpu_blocks (int): number of gpu blocks.\n        num_cpu_blocks (int): number of cpu blocks.\n        window_size (int): sliding window size, must be > 0.\n    \"\"\"\n\n    def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int, window_size: int, num_gpu_reserved: int = 0):\n        super().__init__(num_gpu_blocks, num_cpu_blocks, num_gpu_reserved)\n        assert window_size > 0, ('expected window_size > 0, '\n                                 f'but got window_size = {window_size}')\n        self.window_size = window_size\n\n    def num_required_blocks(self, obj: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Get num required blocks.\"\"\"\n\n        # history still fits in the window, nothing has been dropped\n        if obj.num_history_ids <= self.window_size:\n            return super().num_required_blocks(obj, prealloc_size)\n\n        return super().num_required_blocks(obj, prealloc_size) - obj.num_ignored_history // obj.block_size\n\n    def can_allocate(self, msg: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Return if physical block can be 
allocated for given message.\"\"\"\n        num_drop_blocks = _num_blocks_to_drop(msg, 
self.window_size)\n        num_required_blocks = self.num_required_blocks(msg, prealloc_size)\n        num_free_phy = self.get_num_free_gpu_blocks()\n        return num_required_blocks <= num_free_phy + num_drop_blocks\n\n    def allocate_msg(self, msg: SchedulerSequence, prealloc_size: int = 0):\n        \"\"\"Allocate physical blocks for given message according to logical\n        blocks.\"\"\"\n        logical_blocks = msg.logical_blocks\n\n        def __get_dropped_blocks(num_drop_blocks):\n            \"\"\"Get dropped blocks.\"\"\"\n            nonlocal logical_blocks\n            dropped_blocks = None\n            if num_drop_blocks > 0:\n                remain_blocks = logical_blocks[num_drop_blocks:]\n                dropped_blocks = logical_blocks[:num_drop_blocks]\n                logical_blocks = LogicalTokenBlocks(remain_blocks)\n                msg.logical_blocks = logical_blocks\n            return dropped_blocks\n\n        def __reuse_dropped_blocks(num_required_blocks, num_drop_blocks, dropped_blocks):\n            \"\"\"Reuse dropped blocks.\"\"\"\n            # reuse as many dropped blocks as possible before allocating new ones;\n            # can_allocate() assumes every dropped block is available for reuse\n            num_used_blocks = min(num_drop_blocks, num_required_blocks)\n            reused_blocks = dropped_blocks[:num_used_blocks]\n            logical_blocks.append(reused_blocks)\n\n            if num_used_blocks < num_drop_blocks:\n                dropped_blocks = dropped_blocks[num_used_blocks:]\n            else:\n                dropped_blocks = None\n            num_required_blocks = num_required_blocks - num_used_blocks\n            return num_required_blocks, dropped_blocks\n\n        num_drop_blocks = _num_blocks_to_drop(msg, self.window_size)\n        num_required_blocks = self.num_required_blocks(msg, prealloc_size)\n        msg.num_ignored_history += num_drop_blocks * msg.block_size\n\n        dropped_blocks = 
__get_dropped_blocks(num_drop_blocks)\n\n        if num_required_blocks > 0:\n            if num_drop_blocks > 0:\n                num_required_blocks, dropped_blocks = __reuse_dropped_blocks(num_required_blocks, num_drop_blocks,\n                                                                             dropped_blocks)\n            if num_required_blocks > 0:\n                blocks = self.allocator.allocate(num_required_blocks, 'gpu')\n                logical_blocks.append(blocks)\n\n        # free dropped blocks that were not reused\n        if dropped_blocks is not None:\n            self.allocator.free(dropped_blocks)\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/block_trie.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport heapq\nfrom dataclasses import dataclass\nfrom typing import Dict, Set\n\nimport numpy as np\n\nfrom lmdeploy.pytorch.messages import SchedulerSequence\n\nfrom ..config import CacheConfig\nfrom .block_manager import BaseBlockManager\n\n\n@dataclass\nclass PrefixCacheStats:\n    \"\"\"Prefix caching stats.\"\"\"\n    num_query_tokens: int = 0\n    num_hit_tokens: int = 0\n\n    def reset(self):\n        self.num_query_tokens = 0\n        self.num_hit_tokens = 0\n\n    def hit_rate(self):\n        return 0.0 if self.num_query_tokens <= 0 else float(self.num_hit_tokens) / self.num_query_tokens\n\n\nclass Node:\n    \"\"\"Node of block trie.\"\"\"\n\n    def __init__(self, hash_key: int, block: int, tokens: np.ndarray, num_matched: int = 0):\n        self.hash_key = hash_key\n        self.block = block\n        self.tokens = tokens\n        self.num_matched = num_matched\n        self.children: Dict[int, 'Node'] = dict()\n        self._parent: 'Node' = None\n\n    @property\n    def parent(self):\n        return self._parent\n\n    @parent.setter\n    def parent(self, val: 'Node'):\n        old_parent = self._parent\n        if old_parent is not None:\n            old_parent.children.pop(self.hash_key)\n        if val is not None:\n            val.children[self.hash_key] = self\n        self._parent = val\n\n    def __lt__(self, other):\n        return True\n\n    def __le__(self, other):\n        return True\n\n\nclass BlockTrie:\n    \"\"\"Block trie for prefix caching.\"\"\"\n\n    def __init__(self, cache_config: CacheConfig, block_manager: BaseBlockManager):\n        self.block_manager = block_manager\n        self.cache_config = cache_config\n        self.allocator = self.block_manager.allocator\n        self.block_size = cache_config.block_size\n        self.enable = self.cache_config.enable_prefix_caching\n\n        # caches with different adapter should not be shared.\n        self._roots: 
Dict[str, Node] = dict()\n        self.leaves: Set[Node] = set()\n        self.stats = PrefixCacheStats()\n\n    def hit_rate(self):\n        \"\"\"Get hit rate.\"\"\"\n        return self.stats.hit_rate()\n\n    def get_root(self, adapter_name: str):\n        \"\"\"Get root by adapter name.\"\"\"\n        if adapter_name not in self._roots:\n            self._roots[adapter_name] = Node(-1, -1, None)\n        return self._roots[adapter_name]\n\n    def match(self, seq: SchedulerSequence):\n        \"\"\"Match sequence and cache.\"\"\"\n        if not self.enable:\n            return\n\n        block_size = self.block_size\n        matched_blocks = []\n\n        logical_blocks = seq.logical_blocks\n        curr: Node = getattr(logical_blocks, 'last_shared_node', None)\n        if curr is None:\n            curr = self.get_root(seq.adapter_name)\n        init_num_matched = curr.num_matched\n        num_matched = curr.num_matched\n\n        def __match_success(node: Node):\n            nonlocal curr, num_matched\n            matched_blocks.append(node.block)\n            curr = node\n            num_matched += block_size\n\n        while num_matched + block_size < seq.num_valid_ids:\n            curr_tokens = seq.history_cache[num_matched:num_matched + block_size]\n\n            key = hash(('random', tuple(curr_tokens)))\n            if key not in curr.children:\n                break\n\n            child = curr.children[key]\n            if not np.array_equal(curr_tokens, child.tokens):\n                break\n\n            __match_success(child)\n\n        if len(matched_blocks) > 0:\n            matched_blocks = np.array(matched_blocks)\n            self.allocator.update_access_time(matched_blocks)\n            self.allocator.add_ref_count(matched_blocks, 1)\n            seq.logical_blocks.append(matched_blocks)\n            seq.set_step(num_matched)\n\n        # record prefix hit\n        self.stats.num_query_tokens += seq.num_all_ids - init_num_matched\n        
self.stats.num_hit_tokens += num_matched - init_num_matched\n\n        seq.logical_blocks.last_shared_node = curr\n\n    def allocate(self, seq: SchedulerSequence):\n        \"\"\"allocate.\"\"\"\n        if not self.enable:\n            return\n\n        block_size = self.block_size\n        logical_blocks = seq.logical_blocks\n        node: Node = getattr(logical_blocks, 'last_shared_node', None)\n        if node is None:\n            node = self.get_root(seq.adapter_name)\n            logical_blocks.last_shared_node = node\n\n        num_matched = node.num_matched\n        num_valid_ids = seq.num_valid_ids\n\n        if num_matched + block_size > num_valid_ids:\n            return\n\n        if len(node.children) == 0 and node.parent is not None:\n            self.leaves.remove(node)\n\n        block_id = num_matched // block_size\n        blocks = []\n        free_blocks = []\n        while num_matched + block_size <= num_valid_ids:\n            curr_tokens = seq.history_cache[num_matched:num_matched + block_size]\n\n            block = logical_blocks[block_id]\n\n            hash_key = hash(('random', tuple(curr_tokens)))\n            parent = node\n            if hash_key in parent.children:\n                child = parent.children[hash_key]\n                if not np.array_equal(curr_tokens, child.tokens):\n                    break\n                node = child\n                free_blocks.append(block)\n                logical_blocks[block_id] = node.block\n            else:\n                node = Node(hash_key=hash_key, block=block, tokens=curr_tokens, num_matched=num_matched + block_size)\n                node.parent = parent\n            blocks.append(node.block)\n            num_matched += block_size\n            block_id += 1\n\n        logical_blocks.last_shared_node = node\n        if node.parent is not None and len(node.children) == 0:\n            # ignore root\n            self.leaves.add(node)\n        if len(blocks) > 0:\n            
self.allocator.add_ref_count(np.array(blocks), 1)\n        if len(free_blocks) > 0:\n            self.allocator.free(np.array(free_blocks))\n\n    def evict(self, max_num_blocks: int):\n        \"\"\"Evict up to max_num_blocks least recently used leaf blocks; return the number evicted.\"\"\"\n        if not self.enable:\n            return 0\n\n        def __remove_leaf(leaves, evicted_blocks):\n            _, leaf = heapq.heappop(leaves)\n            evicted_blocks.append(leaf.block)\n            parent = leaf.parent\n            leaf.parent = None\n            self.leaves.remove(leaf)\n            return parent\n\n        def __add_leaf(leaves, parent):\n            self.leaves.add(parent)\n            if self.allocator.get_ref_count(parent.block) == 1:\n                access_time = self.allocator.get_access_time(parent.block)\n                heapq.heappush(leaves, (access_time, parent))\n\n        if len(self.leaves) == 0:\n            return 0\n\n        evicted_blocks = []\n        leaves = list(self.leaves)\n\n        # keep only leaves with ref-cnt == 1 (the trie owns exactly one ref of the block)\n        leaf_blocks = np.array(list(leaf.block for leaf in leaves))\n        ref_cnt = self.allocator.get_ref_count(leaf_blocks)\n        indices = (ref_cnt == 1).nonzero()[0]\n        if len(indices) == 0:\n            return 0\n\n        # make heap\n        leaves = list(leaves[i] for i in indices)\n        access_times = self.allocator.get_access_time(leaf_blocks)\n        access_times = list(access_times[i] for i in indices)\n        leaves = list(zip(access_times, leaves))\n        heapq.heapify(leaves)\n\n        while len(leaves) > 0 and len(evicted_blocks) < max_num_blocks:\n            parent = __remove_leaf(leaves, evicted_blocks)\n            if parent.parent is None:\n                # ignore root\n                continue\n            if len(parent.children) == 0:\n                __add_leaf(leaves, parent)\n\n        self.allocator.free(np.array(evicted_blocks))\n\n        return len(evicted_blocks)\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/eviction_helper/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\ndef build_eviction_helper(scheduler, eviction_type: str):\n    \"\"\"Build eviction helper.\"\"\"\n    if eviction_type == 'copy':\n        logger.warning('`copy` eviction has been deprecated, '\n                       'use `recompute` instead.')\n        eviction_type = 'recompute'\n    if eviction_type == 'recompute':\n        from .recompute_eviction_helper import RecomputeEvictionHelper\n        return RecomputeEvictionHelper(scheduler)\n    else:\n        raise TypeError(f'Unknown eviction type: {eviction_type}')\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/eviction_helper/base_eviction_helper.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List\n\nfrom ...messages import SchedulerSequence\nfrom ..scheduler import Scheduler\n\nSeqList = List[SchedulerSequence]\n\n\nclass BaseEvictionHelper:\n    \"\"\"Base eviction helper.\"\"\"\n\n    def __init__(self, scheduler: Scheduler):\n        self.scheduler = scheduler\n        self.block_manager = scheduler.block_manager\n        self.block_trie = scheduler.block_trie\n        self.state_manager = scheduler.state_manager\n        self.cache_config = scheduler.cache_config\n\n    def need_swap_in(self, seq: SchedulerSequence):\n        \"\"\"Sequence need swap in.\"\"\"\n        raise NotImplementedError('Not implemented.')\n\n    def evict_for_seq(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence], prealloc_size: int):\n        \"\"\"Evict seqs.\"\"\"\n        raise NotImplementedError('Not implemented.')\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List\n\nfrom ...messages import SchedulerSequence\nfrom ..scheduler import Scheduler\nfrom .base_eviction_helper import BaseEvictionHelper\n\n\nclass RecomputeEvictionHelper(BaseEvictionHelper):\n    \"\"\"Recompute eviction.\"\"\"\n\n    def __init__(self, scheduler: Scheduler):\n        super().__init__(scheduler)\n\n        if len(self.cache_config.states_shapes) == 0:\n            self.evict_for_seq = self._evict_for_seq_default\n        else:\n            self.evict_for_seq = self._evict_for_ssm\n\n    def _evict_for_seq_default(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence],\n                               prealloc_size: int):\n        \"\"\"Evict seqs.\"\"\"\n        block_manager = self.block_manager\n        block_trie = self.block_trie\n        num_required_blocks = block_manager.num_required_blocks(seq, prealloc_size)\n\n        if block_manager.get_num_free_gpu_blocks() >= num_required_blocks:\n            return True\n\n        success = False\n        while len(evictable_seqs) > 0:\n            evict_seq = evictable_seqs.pop(0)\n\n            # skip sequence with no blocks\n            if evict_seq.num_blocks == 0:\n                continue\n\n            evict_seq.state.free()\n            num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks())\n            if num_req <= 0:\n                success = True\n                break\n\n            block_trie.evict(num_req)\n            num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks())\n            if num_req <= 0:\n                success = True\n                break\n\n        # for empty evictable_seqs case\n        num_req = num_required_blocks - block_manager.get_num_free_gpu_blocks()\n        if num_req > 0:\n            block_trie.evict(num_req)\n            if num_required_blocks <= block_manager.get_num_free_gpu_blocks():\n                success = 
True\n\n        return success\n\n    def _evict_for_ssm(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerSequence], prealloc_size: int):\n        \"\"\"Evict seqs.\"\"\"\n        block_manager = self.block_manager\n        state_manager = self.state_manager\n        block_trie = self.block_trie\n        num_required_blocks = block_manager.num_required_blocks(seq, prealloc_size)\n        has_free_state = state_manager.get_num_free() > 0\n\n        if has_free_state and block_manager.get_num_free_gpu_blocks() >= num_required_blocks:\n            return True\n\n        success = False\n        while len(evictable_seqs) > 0:\n            evict_seq = evictable_seqs.pop(0)\n\n            # skip sequence that holds neither blocks nor a state\n            if evict_seq.num_blocks == 0 and evict_seq.logical_state < 0:\n                continue\n\n            # free sequence\n            evict_seq.state.free()\n            has_free_state = True\n            num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks())\n            if num_req <= 0:\n                success = True\n                break\n\n            # clear cached prefix\n            block_trie.evict(num_req)\n            num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks())\n            if num_req <= 0:\n                success = True\n                break\n\n        if not has_free_state:\n            return False\n\n        # for empty evictable_seqs case\n        num_req = num_required_blocks - block_manager.get_num_free_gpu_blocks()\n        if num_req > 0:\n            block_trie.evict(num_req)\n            if num_required_blocks <= block_manager.get_num_free_gpu_blocks():\n                success = True\n\n        return success\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/scheduler.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modify from: https://github.com/vllm-project/vllm\n\nfrom collections import OrderedDict\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom typing import Dict, List\n\nfrom torch.profiler import record_function\n\nfrom lmdeploy.messages import EventType, ScheduleMetrics\nfrom lmdeploy.utils import get_logger\n\nfrom ..config import CacheConfig, SchedulerConfig\nfrom ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta\nfrom .block_manager import build_block_manager\nfrom .block_trie import BlockTrie\nfrom .eviction_helper import build_eviction_helper\nfrom .state_manager import build_state_manager\n\nlogger = get_logger('lmdeploy')\n\nMapType = Dict[int, int]\nSeqList = List[SchedulerSequence]\n\n\n@dataclass\nclass SchedulerOutput:\n    \"\"\"Output of schedule.\"\"\"\n\n    running: SeqList\n    swap_in_map: MapType\n    swap_out_map: MapType\n    copy_map: MapType\n\n\nclass Scheduler:\n    \"\"\"Tools to schedule next step.\n\n    Args:\n        scheduler_config (SchedulerConfig): The config of scheduler.\n        cache_config (CacheConfig): The config of cache info.\n    \"\"\"\n\n    def __init__(\n        self,\n        scheduler_config: SchedulerConfig,\n        cache_config: CacheConfig,\n        seq_meta: SequenceMeta = None,\n    ) -> None:\n        self.scheduler_config = scheduler_config\n        self.cache_config = cache_config\n        self.sessions: Dict[int, SchedulerSession] = OrderedDict()\n\n        # For Disaggregation\n        self.locked_sessions: Dict[int, SchedulerSession] = OrderedDict()\n\n        self.block_manager = build_block_manager(cache_config)\n        self.block_trie = BlockTrie(self.cache_config, self.block_manager)\n        self.state_manager = build_state_manager(self.cache_config)\n        self.is_ssm = len(self.cache_config.states_shapes) > 0\n\n        self.eviction_helper = 
build_eviction_helper(self, self.scheduler_config.eviction_type)\n\n        seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size)\n        self.seq_meta = seq_meta\n        self.seq_manager = SequenceManager(seq_meta)\n\n    @staticmethod\n    def create_status_list_property(status: MessageStatus):\n        \"\"\"Create status list property.\"\"\"\n\n        def _get_status_list(self):\n            seq_map = self.seq_manager.get_sequences(status)\n            return list(seq_map.values())\n\n        return property(_get_status_list)\n\n    @staticmethod\n    def create_num_status_method(status: MessageStatus):\n        \"\"\"Create num status method.\"\"\"\n\n        def _num_status(self):\n            return self.seq_manager.num_sequences(status)\n\n        return _num_status\n\n    @staticmethod\n    def create_has_status_method(status: MessageStatus):\n        \"\"\"Create has status method.\"\"\"\n\n        def _has_status(self):\n            return self.seq_manager.num_sequences(status) > 0\n\n        return _has_status\n\n    # status list properties\n    waiting = create_status_list_property(MessageStatus.WAITING)\n    ready = create_status_list_property(MessageStatus.READY)\n    hanging = create_status_list_property(MessageStatus.STOPPED)\n    running = create_status_list_property(MessageStatus.RUNNING)\n    migration_waiting = create_status_list_property(MessageStatus.MIGRATION_WAITING)\n    migration_done = create_status_list_property(MessageStatus.MIGRATION_DONE)\n\n    # num status methods\n    num_waiting = create_num_status_method(MessageStatus.WAITING)\n    num_ready = create_num_status_method(MessageStatus.READY)\n    num_running = create_num_status_method(MessageStatus.RUNNING)\n    num_migration_waiting = create_num_status_method(MessageStatus.MIGRATION_WAITING)\n    num_migration_done = create_num_status_method(MessageStatus.MIGRATION_DONE)\n\n    # has status methods\n    has_waiting = 
create_has_status_method(MessageStatus.WAITING)\n    has_ready = create_has_status_method(MessageStatus.READY)\n    has_migration_waiting = create_has_status_method(MessageStatus.MIGRATION_WAITING)\n    has_migration_done = create_has_status_method(MessageStatus.MIGRATION_DONE)\n\n    def add_session(self, session_id: int):\n        \"\"\"Add new session.\n\n        Args:\n            session_id (int): New session id.\n        \"\"\"\n        assert session_id not in self.sessions\n        session = SchedulerSession(session_id, seq_manager=self.seq_manager, scheduler=self)\n        self.sessions[session_id] = session\n        return session\n\n    def _schedule_migration(self):\n        migration_ready: SeqList = []\n        migrating_token_count = 0\n\n        def _to_running(seq: SchedulerSequence):\n            \"\"\"To running.\"\"\"\n            seq.state.activate()\n            migration_ready.append(seq)\n            nonlocal migrating_token_count\n            migrating_token_count += seq.num_token_ids\n\n        def __evict_for_seq(seq: SchedulerSequence, waiting):\n            \"\"\"Evict until can append.\"\"\"\n            from itertools import chain\n\n            hanging = reversed(self.hanging)\n            waiting = reversed(waiting)\n            evictable = list(chain(hanging, waiting))\n            return self.eviction_helper.evict_for_seq(seq, evictable, 0)\n\n        def _reorder_migrating():\n            \"\"\"Reorder migration waiting.\"\"\"\n            return sorted(self.migration_waiting, key=lambda seq: seq.arrive_time)\n\n        migration_waiting = _reorder_migrating()\n\n        max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running()\n        while len(migration_waiting) > 0 and len(migration_ready) < max_batches:\n            seq = migration_waiting.pop(0)\n            self.block_trie.match(seq)\n            if not __evict_for_seq(seq, migration_waiting):\n                break\n\n            # 
allocate session memory\n            self.block_manager.allocate(seq)\n            _to_running(seq)\n\n        return migration_ready\n\n    @record_function('schedule_prefill')\n    def _schedule_prefill(self, prealloc_size: int = 0):\n        \"\"\"Schedule for prefilling.\"\"\"\n\n        max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running()\n        eviction_helper = self.eviction_helper\n        swap_out_map: MapType = dict()\n        swap_in_map: MapType = dict()\n        copy_map: MapType = dict()\n        running: SeqList = []\n        token_count = 0\n\n        def _to_running(seq: SchedulerSequence):\n            \"\"\"To running.\"\"\"\n            seq.state.activate()\n            running.append(seq)\n            nonlocal token_count\n            token_count += seq.num_token_ids\n\n        def __evict_for_seq(seq: SchedulerSequence, waiting):\n            \"\"\"Evict until can append.\"\"\"\n            from itertools import chain\n            hanging = reversed(self.hanging)\n            waiting = reversed(waiting)\n            evictable = list(chain(hanging, waiting))\n            return eviction_helper.evict_for_seq(seq, evictable, prealloc_size)\n\n        def _reorder_waiting():\n            \"\"\"Reorder waiting.\"\"\"\n            return sorted(self.waiting, key=lambda seq: seq.arrive_time)\n\n        num_waiting = self.seq_manager.num_sequences(MessageStatus.WAITING)\n        if (len(running) >= max_batches or num_waiting == 0):\n            return running, swap_in_map, swap_out_map, copy_map\n\n        waiting = _reorder_waiting()\n        while len(waiting) > 0 and len(running) < max_batches:\n            seq = waiting.pop(0)\n\n            if (len(running) > 0 and token_count + seq.num_token_ids > self.cache_config.max_prefill_token_num):\n                break\n\n            self.block_trie.match(seq)\n\n            if not __evict_for_seq(seq, waiting):\n                break\n\n            # allocate 
session memory\n            self.block_manager.allocate(seq, prealloc_size)\n            self.block_trie.allocate(seq)\n            if self.is_ssm:\n                self.state_manager.allocate(seq)\n            _to_running(seq)\n\n            seq.record_event(EventType.SCHEDULED)\n\n        return running, swap_in_map, swap_out_map, copy_map\n\n    @record_function('schedule_decoding')\n    def _schedule_decoding(self, prealloc_size: int = 0):\n        \"\"\"Schedule decoding.\"\"\"\n\n        def _reorder_running():\n            \"\"\"Reorder running.\"\"\"\n            return sorted(self.ready, key=lambda seq: seq.arrive_time)\n\n        running = _reorder_running()\n        assert len(running) != 0\n\n        eviction_helper = self.eviction_helper\n        swap_out_map: MapType = dict()\n        swap_in_map: MapType = dict()\n        copy_map: MapType = dict()\n\n        def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int):\n            \"\"\"Evict until can append.\"\"\"\n            if num_required_blocks == 0:\n                # No need to evict, just return True.\n                return True\n            elif num_required_blocks < self.block_manager.get_num_free_gpu_blocks():\n                # Enough free blocks, just return True.\n                return True\n\n            from itertools import chain\n            hanging = reversed(self.hanging)\n            waiting = reversed(self.waiting)\n            evictable = list(chain(hanging, waiting))\n            return eviction_helper.evict_for_seq(seq, evictable, prealloc_size)\n\n        # 1. 
running\n        while len(running) > 0:\n            # token + n\n            seq = running.pop(0)\n            num_required_blocks = self.block_manager.num_required_blocks(seq, prealloc_size)\n            assert seq.num_blocks + num_required_blocks <= self.block_manager.num_gpu_blocks, (\n                'Sequence requires more blocks than total gpu blocks.')\n\n            while not __evict_for_seq(seq, num_required_blocks):\n                if len(running) == 0:\n                    break\n                seq_preempted = running.pop(-1)\n                seq_preempted.state.evict()\n\n            if self.block_manager.get_num_free_gpu_blocks() < num_required_blocks:\n                seq.state.evict()\n                continue\n\n            self.block_manager.allocate(seq, prealloc_size)\n            self.block_trie.allocate(seq)\n\n        return self.ready[:self.scheduler_config.max_batches], swap_in_map, swap_out_map, copy_map\n\n    def schedule(self, is_prefill: bool, prealloc_size: int = 0):\n        \"\"\"Schedule inputs for next steps.\"\"\"\n        if is_prefill:\n            output = self._schedule_prefill(prealloc_size)\n        else:\n            output = self._schedule_decoding(prealloc_size)\n        running, swap_in_map, swap_out_map, copy_map = output\n\n        return SchedulerOutput(running=running, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map)\n\n    @record_function('schedule_running')\n    def schedule_running(self, running: SeqList, num_decode_tokens: int = 1, prealloc_size: int = 1):\n        \"\"\"Schedule running sequences.\n\n        This function adds blocks to running sequences. A sequence is marked as invalid if not enough blocks can be\n        allocated for it.\n        \"\"\"\n        assert len(running) > 0\n        eviction_helper = self.eviction_helper\n\n        valid_mask = [True for _ in running]\n\n        # loop over reverse running\n        rev_running = reversed(running)\n        for idx, 
seq in enumerate(rev_running):\n            if not seq.status == MessageStatus.RUNNING:\n                valid_mask[idx] = False\n                continue\n\n            num_required_blocks = self.block_manager.num_required_blocks(seq, num_decode_tokens)\n\n            if num_required_blocks == 0:\n                continue\n\n            if eviction_helper.evict_for_seq(seq, self.hanging + self.waiting, prealloc_size):\n                self.block_manager.allocate(seq, prealloc_size)\n                self.block_trie.allocate(seq)\n                continue\n\n            # running to ready\n            seq.state.deactivate()\n            # ready to waiting\n            seq.state.evict()\n            valid_mask[idx] = False\n        valid_mask = list(reversed(valid_mask))\n        return valid_mask\n\n    def stop_session(self, session_id: int):\n        \"\"\"Stop session.\n\n        Args:\n            session_id (int): The session id.\n        \"\"\"\n        assert session_id in self.sessions\n        session = self.sessions[session_id]\n        for seq in session.sequences.values():\n            seq.state.stop()\n\n    def end_session(self, session_id: int):\n        \"\"\"End session.\n\n        Args:\n            session_id (int): The session id.\n        \"\"\"\n        if self.seq_meta.sampling_strategy is not None:\n            self.seq_meta.sampling_strategy.on_session_end(session_id)\n        session = self.sessions[session_id]\n        seqs = list(session.sequences.values())\n        for seq in seqs:\n            # stop session so it won't get scheduled again\n            seq.state.stop()\n            session.remove_sequence(seq)\n        self.sessions.pop(session_id)\n\n    def has_unfinished(self):\n        \"\"\"Check if there are any unfinished message.\"\"\"\n        return self.has_ready() or self.has_waiting() or self.has_migration_done()\n\n    def get_block_tables(self, seqs: SeqList):\n        \"\"\"Get block table of the sequences.\"\"\"\n       
 return [self.block_manager.get_block_table(seq) for seq in seqs]\n\n    def evict_seqs(self, running: SeqList):\n        \"\"\"Evict running sequences.\"\"\"\n        for seq in running:\n            seq.state.evict()\n\n    def activate_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.READY):\n        \"\"\"Lock running sequence.\"\"\"\n        for seq in running:\n            if seq.status == filter_status:\n                seq.state.activate()\n\n    def deactivate_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.RUNNING):\n        for seq in running:\n            if seq.status == filter_status:\n                seq.state.deactivate()\n\n    @contextmanager\n    def seqs_activation(self, running: SeqList):\n        \"\"\"Context manager to activate and deactivate sequences.\"\"\"\n        self.activate_seqs(running, MessageStatus.READY)\n        try:\n            yield running\n        finally:\n            self.deactivate_seqs(running, MessageStatus.RUNNING)\n\n    def activate_migration_seqs(self, running: SeqList):\n        \"\"\"Lock running sequence.\"\"\"\n        return self.activate_seqs(running, filter_status=MessageStatus.MIGRATION_READY)\n\n    def deactivate_migration_seqs(self, running: SeqList):\n        \"\"\"Unlock running migration.\"\"\"\n        return self.deactivate_seqs(running, filter_status=MessageStatus.MIGRATION_RUNNING)\n\n    @contextmanager\n    def seqs_migration_activation(self, running: SeqList):\n        \"\"\"Context manager to activate and deactivate sequences.\"\"\"\n        self.activate_migration_seqs(running)\n        try:\n            yield running\n        finally:\n            self.deactivate_migration_seqs(running)\n\n    def collect_migration_done(self):\n        for seq in self.migration_done:\n            seq.state.activate()\n\n    @property\n    def schedule_metrics(self):\n        return ScheduleMetrics(\n            active_seqs=self.num_running(),\n            
waiting_seqs=self.num_waiting() + self.num_ready(),\n            total_blocks=self.block_manager.num_gpu_blocks,\n            free_blocks=self.block_manager.get_num_free_gpu_blocks(),\n            prefix_cache_hit_rate=self.block_trie.hit_rate(),\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/seq_states/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .states import StateBase, build_seq_state  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/seq_states/states.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import TYPE_CHECKING\n\nfrom lmdeploy.pytorch.messages import MessageStatus, SchedulerSequence\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.paging import Scheduler\n\n\ndef _free_seq(seq: SchedulerSequence, scheduler: 'Scheduler'):\n    \"\"\"Free the sequence.\"\"\"\n    if seq.num_blocks > 0:\n        scheduler.block_manager.free(seq)\n    if seq.logical_state >= 0:\n        scheduler.state_manager.free(seq)\n    seq.set_step(0)\n\n\nclass StateBase:\n    status = None\n    _registry = dict()\n\n    def __init_subclass__(cls, **kargs) -> None:\n        super().__init_subclass__(**kargs)\n        if cls.status:\n            cls._registry[cls.status] = cls\n\n    @classmethod\n    def build(cls, scheduler: 'Scheduler', seq: 'SchedulerSequence', status: MessageStatus) -> 'StateBase':\n        \"\"\"Build sequence state.\"\"\"\n        if status not in cls._registry:\n            raise NotImplementedError(f'Unsupported status {status} for building seq state.')\n        return cls._registry[status](seq, scheduler)\n\n    def __init__(self, seq: SchedulerSequence, scheduler: 'Scheduler'):\n        self.seq = seq\n        self.scheduler = scheduler\n\n    def to_state(self, new_state):\n        \"\"\"Transition to a new state.\"\"\"\n        self.scheduler.seq_manager.update_sequence_status(self.seq, new_state.status)\n        self.seq.set_state(new_state(self.seq, self.scheduler))\n\n    def evict(self):\n        \"\"\"Evict the state.\"\"\"\n        raise NotImplementedError(f'evict not implemented for state {self.status}')\n\n    def activate(self):\n        \"\"\"Activate the state.\"\"\"\n        raise NotImplementedError(f'activate not implemented for state {self.status}')\n\n    def deactivate(self):\n        \"\"\"Deactivate the state.\"\"\"\n        raise NotImplementedError(f'deactivate not implemented for state {self.status}')\n\n    def finish(self):\n        \"\"\"Finish the 
state.\"\"\"\n        raise NotImplementedError(f'finish not implemented for state {self.status}')\n\n    def stop(self):\n        \"\"\"Stop the state.\"\"\"\n        self.to_state(StoppedState)\n\n    def free(self):\n        \"\"\"Free the state.\"\"\"\n        _free_seq(self.seq, self.scheduler)\n\n\nclass WaitingState(StateBase):\n    \"\"\"State for waiting sequences.\"\"\"\n    status = MessageStatus.WAITING\n\n    def activate(self):\n        \"\"\"From WAITING to READY.\"\"\"\n        num_req_blocks = self.scheduler.block_manager.num_required_blocks(self.seq)\n        assert self.seq.num_blocks >= num_req_blocks\n        if self.scheduler.is_ssm:\n            assert self.seq.logical_state >= 0\n        self.to_state(ReadyState)\n\n    def evict(self):\n        self.to_state(WaitingState)\n\n\nclass ReadyState(StateBase):\n    \"\"\"State for ready sequences.\"\"\"\n    status = MessageStatus.READY\n\n    def activate(self):\n        \"\"\"From READY to RUNNING.\"\"\"\n        self.to_state(RunningState)\n\n    def evict(self):\n        self.to_state(WaitingState)\n\n\nclass StoppedState(StateBase):\n    \"\"\"State for stopped sequences.\"\"\"\n    status = MessageStatus.STOPPED\n\n    def activate(self):\n        \"\"\"From STOPPED to WAITING.\"\"\"\n        assert self.seq.num_token_ids > 0\n        self.to_state(WaitingState)\n\n    def evict(self):\n        self.to_state(StoppedState)\n\n\nclass RunningState(StateBase):\n    \"\"\"State for running sequences.\"\"\"\n    status = MessageStatus.RUNNING\n\n    def deactivate(self):\n        self.to_state(ReadyState)\n\n    def finish(self):\n        if self.seq.preserve_cache:\n            self.to_state(ToBeMigratedState)\n        else:\n            self.to_state(StoppedState)\n\n\nclass ToBeMigratedState(StateBase):\n    \"\"\"State for to be migrated sequences.\"\"\"\n    status = MessageStatus.TO_BE_MIGRATED\n\n    def finish(self):\n        self.to_state(StoppedState)\n\n\nclass 
MigrationWaitingState(StateBase):\n    \"\"\"State for migration waiting sequences.\"\"\"\n    status = MessageStatus.MIGRATION_WAITING\n\n    def activate(self):\n        self.to_state(MigrationReadyState)\n\n    def evict(self):\n        self.to_state(MigrationWaitingState)\n\n\nclass MigrationReadyState(StateBase):\n    \"\"\"State for migration ready sequences.\"\"\"\n    status = MessageStatus.MIGRATION_READY\n\n    def activate(self):\n        self.to_state(MigrationRunningState)\n\n    def evict(self):\n        self.to_state(MigrationWaitingState)\n\n\nclass MigrationDoneState(StateBase):\n    \"\"\"State for migration done sequences.\"\"\"\n    status = MessageStatus.MIGRATION_DONE\n\n    def activate(self):\n        self.to_state(WaitingState)\n\n    def finish(self):\n        self.to_state(WaitingState)\n\n\nclass MigrationRunningState(StateBase):\n    \"\"\"State for migration running sequences.\"\"\"\n    status = MessageStatus.MIGRATION_RUNNING\n\n    def deactivate(self):\n        self.to_state(MigrationDoneState)\n\n    def finish(self):\n        self.to_state(MigrationDoneState)\n\n\ndef build_seq_state(scheduler: 'Scheduler', seq: 'SchedulerSequence', status: MessageStatus) -> StateBase:\n    \"\"\"Build sequence state.\"\"\"\n    return StateBase.build(scheduler, seq, status)\n"
  },
  {
    "path": "lmdeploy/pytorch/paging/state_manager.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport numpy as np\n\nfrom lmdeploy.pytorch.config import CacheConfig\nfrom lmdeploy.pytorch.messages import SchedulerSequence\n\n\nclass StateAllocator:\n    \"\"\"State allocator.\"\"\"\n\n    def __init__(self, num_states: int, offset: int = 0):\n        self.num_states = num_states\n        self._free_states = np.arange(offset, offset + num_states, dtype=np.int64)\n        self._free_count = num_states\n\n    def allocate(self):\n        \"\"\"allocate.\"\"\"\n        if self.get_num_free() == 0:\n            raise RuntimeError('No free states.')\n        alloc_id = self._free_states[-self._free_count]\n        self._free_count -= 1\n        return alloc_id\n\n    def free(self, state_id: int):\n        \"\"\"free.\"\"\"\n        if self._free_count >= self.num_states:\n            raise RuntimeError('All states are free.')\n        self._free_count += 1\n        self._free_states[-self._free_count] = state_id\n\n    def get_num_free(self):\n        return self._free_count\n\n\nclass StateManager:\n\n    def __init__(self, num_states: int, num_reserved: int = 0):\n        if num_states is None:\n            num_states = 1\n        self.allocator = StateAllocator(num_states, offset=num_reserved)\n\n    def is_allocated(self, seq: SchedulerSequence):\n        \"\"\"Check if a sequence is allocated.\"\"\"\n        return seq.logical_state >= 0\n\n    def allocate(self, seq: SchedulerSequence):\n        \"\"\"Allocate states for a sequence.\"\"\"\n        if self.is_allocated(seq):\n            return None\n        seq.logical_state = self.allocator.allocate()\n\n    def free(self, seq: SchedulerSequence):\n        \"\"\"Free states for a sequence.\"\"\"\n        if not self.is_allocated(seq):\n            return None\n        self.allocator.free(seq.logical_state)\n        seq.logical_state = -1\n\n    def get_num_free(self):\n        \"\"\"Get num free.\"\"\"\n        return 
self.allocator.get_num_free()\n\n\ndef build_state_manager(cache_config: CacheConfig) -> StateManager:\n    \"\"\"Build state manager.\"\"\"\n    num_states = cache_config.num_state_caches\n    # state is different from block, we always reserve one state for system use\n    num_reserved = 1\n    return StateManager(num_states, num_reserved)\n"
  },
  {
    "path": "lmdeploy/pytorch/ray.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nimport time\nfrom typing import Dict, List\n\nimport ray\nfrom ray.util.placement_group import PlacementGroup\n\nfrom lmdeploy.pytorch.devices import get_device_manager\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\nPG_WAIT_TIMEOUT = 1800\n\n\ndef get_device_str(device_type: str = None) -> str:\n    \"\"\"Get device str.\"\"\"\n    device_type = device_type or get_device_manager().current_context().device_type\n    if device_type in ['cuda', 'maca']:\n        device_type = 'GPU'\n    elif device_type == 'ascend':\n        device_type = 'NPU'\n    elif device_type == 'camb':\n        device_type = 'MLU'\n    else:\n        raise ValueError(f'Unsupported device type: {device_type}')\n\n    return device_type\n\n\ndef get_resource_kwargs(device_str: str, resource_used: float = 0.01) -> Dict[str, float]:\n    \"\"\"Get resource kwargs.\"\"\"\n    if device_str == 'GPU':\n        resource_kwargs = {'num_gpus': resource_used}\n    elif device_str == 'NPU':\n        resource_kwargs = {'resources': {device_str: resource_used}}\n    else:\n        raise ValueError(f'Unsupported device type: {device_str}')\n    return resource_kwargs\n\n\ndef _wait_until_pg_ready(current_placement_group: PlacementGroup):\n    \"\"\"Wait until a placement group is ready.\n\n    It prints the informative log messages if the placement group is not created within time.\n    \"\"\"\n    # copy from vLLM\n    # Wait until PG is ready - this will block until all\n    # requested resources are available, and will timeout\n    # if they cannot be provisioned.\n    placement_group_specs = current_placement_group.bundle_specs\n\n    s = time.time()\n    pg_ready_ref = current_placement_group.ready()\n    wait_interval = 10\n    while time.time() - s < PG_WAIT_TIMEOUT:\n        ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval)\n        if len(ready) > 0:\n            break\n\n        # 
Exponential backoff for warning print.\n        wait_interval *= 2\n        logger.info(\n            'Waiting for creating a placement group of specs for '\n            '%d seconds. specs=%s. Check '\n            '`ray status` to see if you have enough resources,'\n            ' and make sure the IP addresses used by ray cluster'\n            ' are the same as VLLM_HOST_IP environment variable'\n            ' specified in each node if you are running on a multi-node.', int(time.time() - s), placement_group_specs)\n\n    try:\n        ray.get(pg_ready_ref, timeout=0)\n    except ray.exceptions.GetTimeoutError:\n        raise ValueError('Cannot provide a placement group of '\n                         f'{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See '\n                         '`ray status` to make sure the cluster has enough resources.') from None\n\n\ndef _get_obj_store_memory(dp: int = 1):\n    \"\"\"Get obj store memory.\"\"\"\n    import psutil\n    DEFAULT_OBJECT_STORE_MEMORY_PROPORTION = os.getenv('RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION', '0.3')\n    DEFAULT_OBJECT_STORE_MEMORY_PROPORTION = float(DEFAULT_OBJECT_STORE_MEMORY_PROPORTION)\n    DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = os.getenv('RAY_DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES', None)\n    if DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES is None:\n        DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = 80 * (10**9)\n    else:\n        DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES = int(DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES)\n    total_mem = psutil.virtual_memory().total\n    obj_store_mem = int(total_mem * DEFAULT_OBJECT_STORE_MEMORY_PROPORTION)\n    obj_store_mem = min(DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES, obj_store_mem)\n    if dp > 1:\n        obj_store_mem = obj_store_mem // min(8, dp)\n    return obj_store_mem\n\n\ndef init_ray_cluster(world_size: int, ray_address: str = None, dp: int = 1, device_type: str = 'cuda'):\n    \"\"\"Init ray cluster.\"\"\"\n    # modifier from vLLM\n    if not 
ray.is_initialized():\n        try:\n            num_cpus = world_size\n            object_store_memory = _get_obj_store_memory(dp=dp)\n            ray.init(address=ray_address,\n                     ignore_reinit_error=True,\n                     num_cpus=num_cpus,\n                     object_store_memory=object_store_memory)\n        except ValueError as e:\n            if e.args is not None and len(e.args) >= 1 and e.args[\n                    0] == 'When connecting to an existing cluster, num_cpus and num_gpus must not be provided.':\n                ray.init(address=ray_address, ignore_reinit_error=True)\n            else:\n                raise\n\n    device_str = get_device_str(device_type)\n\n    # Create placement group for worker processes\n    current_placement_group = ray.util.get_current_placement_group()\n    owned_pg = False\n    if not current_placement_group:\n        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)\n        if world_size > num_devices_in_cluster:\n            logger.warning(\n                'The number of required %ss exceeds the total '\n                'number of available %ss in the placement group.', device_str, device_str)\n        # Create a new placement group\n        placement_group_specs: List[Dict[str, float]] = ([{device_str: 1.0} for _ in range(world_size)])\n\n        # Pin at least one bundle to the local node.\n        # This helps multi-node DP keep each dp_rank process's workers co-located with\n        # the node where the process is launched.\n        current_ip = ray.util.get_node_ip_address()\n        placement_group_specs[0][f'node:{current_ip}'] = 0.001\n\n        # By default, Ray packs resources as much as possible.\n        current_placement_group = ray.util.placement_group(placement_group_specs, strategy='PACK')\n        _wait_until_pg_ready(current_placement_group)\n        owned_pg = True\n\n    assert current_placement_group is not None\n    # Set the placement group in the 
parallel config\n    placement_group = current_placement_group\n    return placement_group, owned_pg\n\n\nclass RayContext:\n    \"\"\"Context manager for Ray.\"\"\"\n\n    def __init__(self, world_size: int, ray_address: str = None, dp: int = 1, device_type: str = 'cuda'):\n        \"\"\"Initialize Ray context.\"\"\"\n        placement_group, owned_pg = init_ray_cluster(world_size=world_size,\n                                                     ray_address=ray_address,\n                                                     dp=dp,\n                                                     device_type=device_type)\n\n        self.placement_group = placement_group\n        self.owned_pg = owned_pg\n\n    def get_placement_group(self):\n        \"\"\"Get the placement group.\"\"\"\n        return self.placement_group\n\n    def shutdown(self):\n        \"\"\"Shutdown Ray.\"\"\"\n        if self.owned_pg:\n            ray.util.remove_placement_group(self.placement_group)\n            logger.debug('RayContext placement group removed.')\n\n        if ray.is_initialized():\n            try:\n                ray.shutdown()\n                logger.debug('Ray shutdown.')\n            except Exception:\n                logger.exception('Error during Ray shutdown.')\n        else:\n            logger.debug('Ray is not initialized, skipping shutdown.')\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom ..config import BackendConfig, SpecDecodeConfig\nfrom ..distributed import DistContext\n\n\ndef build_spec_agent(specdecode_config: SpecDecodeConfig,\n                     backend_config: BackendConfig,\n                     dist_ctx: DistContext,\n                     inputs_strategy,\n                     agent_strategy,\n                     device: str = 'cuda'):\n    \"\"\"Build spec agent.\"\"\"\n    enable = dist_ctx.rank % dist_ctx.dist_config.attn_tp == 0 and specdecode_config is not None\n    if enable:\n        from .spec_agent import SpecModelAgent\n        return SpecModelAgent(specdecode_config, backend_config, inputs_strategy, agent_strategy, device=device)\n    else:\n        from .base import BaseSpecModelAgent\n        return BaseSpecModelAgent()\n\n\n__all__ = ['build_spec_agent']\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Dict\n\nimport torch\n\nfrom ..config import CacheConfig, ModelConfig\nfrom ..engine.logits_process import SamplingInputs\nfrom ..model_inputs import ModelInputs\nfrom ..strategies.base.model_agent import ExtraInputs\n\n\nclass BaseSpecModelAgent:\n    \"\"\"Speculative model agent.\"\"\"\n\n    def __init__(self, enable: bool = False):\n        self._enabled = enable\n\n    def is_enabled(self):\n        return self._enabled\n\n    def set_cache_config(self, cache_config: CacheConfig):\n        \"\"\"Set all cache config.\"\"\"\n        pass\n\n    def set_model_config(self, model_config: ModelConfig):\n        \"\"\"Set model config.\"\"\"\n        pass\n\n    def build_model(self, empty_init: bool, target_model=None, build_model_ctx=None):\n        \"\"\"Build draft model.\"\"\"\n        pass\n\n    def build_graph_runner(self):\n        \"\"\"Build graph runner.\"\"\"\n        pass\n\n    def build_cache_engine(self, cache_stream: torch.cuda.Stream):\n        \"\"\"Build cache engine.\"\"\"\n        pass\n\n    async def async_model_forward(self, next_token_ids: torch.Tensor, model_inputs: ModelInputs,\n                                  extra_inputs: ExtraInputs, sampling_inputs: SamplingInputs):\n        \"\"\"Draft model forward.\"\"\"\n        return extra_inputs\n\n    def warmup(self, max_batches: int, target_model_config: ModelConfig):\n        \"\"\"warmup.\"\"\"\n        pass\n\n    def reset_graph_runner(self):\n        'reset graph runner'\n        pass\n\n    def update_main_model_outputs(self, output: Dict[str, torch.Tensor], model_inputs: ModelInputs):\n        \"\"\"Update outputs of main model.\"\"\"\n        if not self.is_enabled():\n            hidden_states = output.pop('hidden_states')\n            return hidden_states, output\n\n        hidden_states = output['hidden_states']\n        if not model_inputs.is_decoding:\n            logits_indices = 
model_inputs.seq_length.cumsum(0) - 1\n            hidden_states = hidden_states[:, logits_indices]\n        if 'aux_hidden_states' in output:\n            # replace with aux\n            output['hidden_states'] = output.pop('aux_hidden_states')\n        return hidden_states, output\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/proposers/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .deepseek_mtp import DeepseekMTP  # noqa F401\nfrom .eagle import Eagle  # noqa F401\nfrom .eagle3 import Eagle3  # noqa F401\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/proposers/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, List, Optional\n\nimport torch\nfrom mmengine import Registry\nfrom torch.profiler import record_function\n\nfrom lmdeploy.utils import get_logger\n\nfrom ...config import ModelConfig, SpecDecodeConfig\nfrom ...engine.cache_engine import CacheEngine\nfrom ...model_inputs import ModelInputs, step_ctx_manager\nfrom ...models.patch import build_patched_model, update_custom_module_map\nfrom ...strategies.base.model_agent import ExtraInputs\nfrom ...weight_loader.model_weight_loader import load_model_weights\n\nSPEC_PROPOSERS = Registry('spec_proposers')\n\nlogger = get_logger('lmdeploy')\n\n\n@torch.inference_mode()\ndef draft_model_forward(\n    model: torch.nn.Module,\n    inputs: ModelInputs,\n    model_config: Optional[ModelConfig] = None,\n    cache_engine: Optional[CacheEngine] = None,\n):\n    \"\"\"Perform model forward.\"\"\"\n    stream = torch.cuda.current_stream()\n    with torch.cuda.stream(stream), step_ctx_manager(model.ctx_mgr):\n        # forward\n        ctx_mgr = model.ctx_mgr\n        kv_caches = None if cache_engine is None else cache_engine.gpu_cache\n        context = ctx_mgr.build_context(\n            inputs=inputs,\n            model_config=model_config,\n            cache_config=cache_engine.cache_config,\n            kv_caches=kv_caches,\n        )\n        with ctx_mgr.context(context):\n            model_metas = None\n            model_metas = model.update_model_metas(\n                past_key_values=kv_caches,\n                context=context,\n            )\n            input_dict = model.prepare_inputs_for_generation(\n                past_key_values=kv_caches,\n                context=context,\n            )\n            outputs = model(**input_dict)\n            if not isinstance(outputs, dict):\n                outputs = dict(hidden_states=outputs)\n            outputs.update(dict(model_metas=model_metas))\n    return outputs\n\n\nclass 
BaseSpecProposer:\n\n    def __init__(self, specdecode_config: SpecDecodeConfig, device: torch.device = None):\n        self.specdecode_config = specdecode_config\n        self.model = None\n        self.device = device\n        self.lm_head = None\n        self.num_speculative_tokens = specdecode_config.num_speculative_tokens\n        self.target_model = None\n\n    def build_model(self, empty_init: bool, target_model: torch.nn.Module = None, build_model_ctx=None):\n        if self.specdecode_config is None:\n            return\n        model_path = self.specdecode_config.model\n        model_config = self.specdecode_config.model_config\n        custom_module_map = model_config.custom_module_map\n        if custom_module_map is not None:\n            update_custom_module_map(custom_module_map)\n        logger.debug('build draft model')\n        patched_model = build_patched_model(\n            model_config,\n            device=self.device,\n            build_model_ctx=build_model_ctx,\n        )\n        logger.debug('loading weights for draft model.')\n        if not empty_init:\n            load_model_weights(patched_model, model_path, device=self.device)\n        self.model = patched_model\n        self.target_model = target_model\n\n    def get_outputs(self,\n                    model_outputs: Dict[str, torch.Tensor],\n                    model_inputs: ModelInputs,\n                    extra_inputs: ExtraInputs = None):\n        \"\"\"Get outputs.\"\"\"\n        raise NotImplementedError()\n\n    @record_function('draft_model_forward')\n    def _forward(self, model_inputs: ModelInputs, cache_engine: CacheEngine = None):\n        \"\"\"Forward.\"\"\"\n        return draft_model_forward(\n            self.model,\n            model_inputs,\n            model_config=self.specdecode_config.model_config,\n            cache_engine=cache_engine,\n        )\n\n    def update_inputs_decoding(self, model_inputs: ModelInputs, extra_inputs: ExtraInputs, next_input_ids: 
torch.Tensor,\n                               target_hidden_states: torch.Tensor, model_metas: List[Any]):\n        \"\"\"Update to decoding inputs.\"\"\"\n        model_inputs.is_decoding = True\n        batch_size = model_inputs.seq_length.size(0)\n        model_inputs.input_ids = next_input_ids\n        model_inputs.max_q_seqlen = 1\n        model_inputs.max_kv_seqlen += 1\n        model_inputs.sum_kv_seqlen += model_inputs.seq_length.numel()\n        model_inputs.history_lengths += model_inputs.seq_length\n        if extra_inputs.num_rejected_tokens is not None:\n            model_inputs.history_lengths -= extra_inputs.num_rejected_tokens\n        model_inputs.seq_length = model_inputs.seq_length.new_ones(batch_size)\n        model_inputs.target_position_ids = model_inputs.history_lengths.unsqueeze(0).clone()\n        model_inputs.model_metas = model_metas\n        model_inputs.target_hidden_states = target_hidden_states\n        return model_inputs\n\n    @record_function('draft_get_logits')\n    def get_logits(self, hidden_states: torch.Tensor):\n        \"\"\"Get logits of model output.\"\"\"\n        draft_model = self.model\n        if not isinstance(draft_model, torch.nn.Module):\n            draft_model = draft_model.model\n\n        if hasattr(draft_model, 'get_logits'):\n            logits = draft_model.get_logits(hidden_states)\n        else:\n            logits = self.target_model.get_logits(hidden_states)\n        return logits\n\n    def get_target_hidden_size(self, model_config: ModelConfig):\n        \"\"\"Get target hidden size.\"\"\"\n        return model_config.hidden_size\n\n\ndef build_specdecode_proposer(specdecode_config: SpecDecodeConfig, device: str = 'cuda'):\n    \"\"\"Build spec decoding proposer.\"\"\"\n    method = specdecode_config.method\n    if method in SPEC_PROPOSERS.module_dict:\n        spec_cls = SPEC_PROPOSERS.module_dict[method]\n        obj = spec_cls(specdecode_config, device=device)\n        return obj\n    raise 
ValueError(f'{method} not found in {SPEC_PROPOSERS.module_dict.keys()}')\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict\n\nimport torch\n\nfrom lmdeploy.utils import get_logger\n\nfrom ...model_inputs import ModelInputs\nfrom ...strategies.ar_spec.model_agent import ARSpecExtraInputs\nfrom .base import SPEC_PROPOSERS, BaseSpecProposer\n\nlogger = get_logger('lmdeploy')\n\n\n@SPEC_PROPOSERS.register_module(name='deepseek_mtp')\nclass DeepseekMTP(BaseSpecProposer):\n\n    def get_outputs(self,\n                    model_outputs: Dict[str, torch.Tensor],\n                    model_inputs: ModelInputs,\n                    extra_inputs: ARSpecExtraInputs = None):\n        \"\"\"Get outputs.\"\"\"\n        hidden_states = model_outputs['hidden_states']\n        model_metas = model_outputs['model_metas']\n        if extra_inputs is not None and extra_inputs.last_token_indices is not None:\n            # for long input\n            if (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1:\n                hidden_states = hidden_states[:, -1:]\n            else:\n                last_token_loc = extra_inputs.last_token_indices\n                hidden_states = hidden_states[:, last_token_loc]\n\n        logits = self.get_logits(hidden_states)[0]\n        draft_token_ids = logits.argmax(dim=-1, keepdim=True)\n        return draft_token_ids, model_metas, hidden_states\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/proposers/eagle.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .base import SPEC_PROPOSERS\nfrom .deepseek_mtp import DeepseekMTP\n\n\n@SPEC_PROPOSERS.register_module(name='eagle')\nclass Eagle(DeepseekMTP):\n    \"\"\"Eagle.\"\"\"\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/proposers/eagle3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict\n\nimport torch\n\nfrom lmdeploy.utils import get_logger\n\nfrom ...config import ModelConfig\nfrom ...model_inputs import ModelInputs\nfrom ...strategies.base.model_agent import ExtraInputs\nfrom .base import SPEC_PROPOSERS\nfrom .deepseek_mtp import DeepseekMTP\n\nlogger = get_logger('lmdeploy')\n\n\n@SPEC_PROPOSERS.register_module(name='eagle3')\nclass Eagle3(DeepseekMTP):\n\n    def build_model(self, empty_init: bool, target_model: torch.nn.Module = None, build_model_ctx=None):\n        super().build_model(empty_init, target_model=target_model, build_model_ctx=build_model_ctx)\n        self.draft_id_to_target_id = self.model.draft_id_to_target_id\n        if not self.model.include_embed_tokens:\n            logger.info('Using embed_tokens from target model.')\n            del self.model.model.embed_tokens\n            self.model.model.embed_tokens = target_model.get_input_embeddings()\n\n    def get_target_hidden_size(self, model_config: ModelConfig):\n        \"\"\"Get target hidden size.\"\"\"\n        hf_config = self.specdecode_config.model_config.hf_config\n        hidden_size = getattr(hf_config, 'target_hidden_size', hf_config.hidden_size)\n        return hidden_size * 3\n\n    def get_outputs(self,\n                    model_outputs: Dict[str, torch.Tensor],\n                    model_inputs: ModelInputs,\n                    extra_inputs: ExtraInputs = None):\n        \"\"\"Get outputs.\"\"\"\n        hidden_states = model_outputs['hidden_states']\n        hidden_states_prenorm = model_outputs['hidden_states_prenorm']\n        model_metas = model_outputs['model_metas']\n        if extra_inputs is not None and extra_inputs.last_token_indices is not None:\n            # for long input\n            if (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1:\n                hidden_states = hidden_states[:, -1:]\n                hidden_states_prenorm = 
hidden_states_prenorm[:, -1:]\n            else:\n                last_token_loc = extra_inputs.last_token_indices\n                hidden_states = hidden_states[:, last_token_loc]\n                hidden_states_prenorm = hidden_states_prenorm[:, last_token_loc]\n\n        logits = self.get_logits(hidden_states)[0]\n        draft_token_ids = logits.argmax(dim=-1, keepdim=True)\n        # token mapping\n        draft_token_ids = self.draft_id_to_target_id[draft_token_ids]\n        return draft_token_ids, model_metas, hidden_states_prenorm\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/reject_sampler.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport enum\nfrom typing import Optional\n\nimport torch\nfrom torch import LongTensor, Tensor, nn\nfrom torch.profiler import record_function\n\n\nclass SamplePolicy(enum.Enum):\n    \"\"\"Sample policy.\"\"\"\n\n    ALL_GREEDY = enum.auto()\n\n\nclass RejectionSampler(nn.Module):\n\n    def __init__(self, sample_policy: SamplePolicy = SamplePolicy.ALL_GREEDY):\n        super().__init__()\n        self.sample_policy = sample_policy\n\n    def forward(\n        self,\n        target_logits: Tensor,\n        draft_token_ids: LongTensor,\n        bonus_token_ids: LongTensor,\n        draft_probs: Optional[Tensor] = None,\n    ):\n        \"\"\"Forward.\n\n        Args:\n            target_logits (Tensor): The logits of target model in shape of [batch_size, num_spec_tokens, vocab_size].\n            draft_token_ids (LongTensor): The input draft tokens in shape of [batch_size, num_spec_tokens].\n            bonus_token_ids (LongTensor): The bonus token ids in shape of [batch_size, 1].\n            draft_probs (Tensor): The probability of draft model in shape of [batch_size, num_spec_tokens, vocab_size].\n                Defaults to ``None``.\n        \"\"\"\n        output_token_ids, num_rejected_tokens, last_token_ids = rejection_sample(\n            target_logits,\n            draft_token_ids,\n            bonus_token_ids,\n            sample_policy=self.sample_policy,\n            draft_probs=draft_probs,\n        )\n        return output_token_ids, num_rejected_tokens, last_token_ids\n\n\n@record_function('rejection_sample')\ndef rejection_sample(\n    target_probs: Tensor,\n    draft_token_ids: LongTensor,\n    bonus_token_ids: LongTensor,\n    sample_policy: SamplePolicy = SamplePolicy.ALL_GREEDY,\n    draft_probs: Optional[Tensor] = None,\n):\n    \"\"\"Rejection sample.\n\n    Args:\n        target_probs (Tensor): The target model scores (logits or probabilities) in shape of\n            [batch_size, num_spec_tokens, vocab_size]. Only their argmax is used by greedy sampling.\n        draft_token_ids (LongTensor): The draft tokens in shape of [batch_size, num_spec_tokens].\n        bonus_token_ids (LongTensor): The bonus token ids in shape of [batch_size, 1].\n        sample_policy (SamplePolicy): The sampling policy. Only ``SamplePolicy.ALL_GREEDY`` is supported.\n        draft_probs (Tensor): The probabilities of the draft model. Defaults to ``None``.\n    \"\"\"\n    assert draft_probs is None or draft_probs.is_contiguous()\n    assert sample_policy == SamplePolicy.ALL_GREEDY, 'only the all greedy sampling policy is supported'\n\n    target_argmax_tokens = target_probs.argmax(dim=-1)\n    return greedy_reject_sampler(draft_token_ids, target_argmax_tokens, bonus_token_ids)\n\n\ndef greedy_reject_sampler(draft_token_ids, target_token_ids, bonus_token_ids):\n    \"\"\"Greedy reject sampler.\n\n    1. Keep the leading target tokens that match the draft tokens.\n    2. Keep the first target token that differs from its draft token, if any.\n    3. Append the bonus token if all draft tokens are accepted.\n\n    Args:\n        draft_token_ids: (batch_size, num_spec_tokens)\n        target_token_ids: (batch_size, num_spec_tokens)\n        bonus_token_ids: (batch_size, 1)\n    Returns:\n        output_token_ids: (batch_size, num_spec_tokens + 1), with unused positions filled with -1.\n        num_rejected_tokens: (batch_size, )\n        last_token_ids: (batch_size, )\n    \"\"\"\n    masks = draft_token_ids == target_token_ids\n    batch_size, num_spec_tokens = draft_token_ids.shape\n    # equals[b, j] is True iff draft tokens 0..j of sequence b all match the target tokens\n    range_data = torch.arange(num_spec_tokens, device=draft_token_ids.device)[None, :]\n    equals = (masks.cumsum(dim=1) - 1) == range_data\n    num_rejected_tokens = num_spec_tokens - equals.sum(dim=1)\n    first_diff_indices = torch.argmin(equals.int(), dim=1, keepdim=True)\n    keeps = range_data.repeat(batch_size, 1) <= first_diff_indices\n    keeps = keeps | equals\n    keep_token_ids = torch.where(keeps, target_token_ids, -1)\n    # add bonus tokens\n    keep_bonus_ids = torch.where(equals[:, -1:], bonus_token_ids, -1)\n    output_token_ids = torch.cat([keep_token_ids, keep_bonus_ids], dim=1)\n    # get last token ids\n    last_indices = (torch.cat([keeps, equals[:, -1:]], dim=1).cumsum(dim=1) - 1)[:, -1].flatten()\n    last_token_ids = output_token_ids[torch.arange(batch_size, device=draft_token_ids.device), last_indices]\n    return output_token_ids, num_rejected_tokens, last_token_ids\n"
  },
  {
    "path": "lmdeploy/pytorch/spec_decode/spec_agent.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport asyncio\n\nimport torch\n\nfrom lmdeploy.utils import get_logger\n\nfrom ..backends import get_backend\nfrom ..config import BackendConfig, CacheConfig, ModelConfig, SpecDecodeConfig\nfrom ..engine.cache_engine import CacheEngine\nfrom ..engine.logits_process import SamplingInputs\nfrom ..model_inputs import ModelInputs\nfrom ..strategies.ar_spec.model_agent import ARSpecExtraInputs\nfrom ..strategies.base.model_agent import ExtraInputs\nfrom .base import BaseSpecModelAgent\nfrom .proposers.base import build_specdecode_proposer\nfrom .reject_sampler import RejectionSampler\n\nlogger = get_logger('lmdeploy')\n\n\nclass SpecModelAgent(BaseSpecModelAgent):\n    \"\"\"Speculative model agent.\"\"\"\n\n    def __init__(\n        self,\n        specdecode_config: SpecDecodeConfig,\n        backend_config: BackendConfig,\n        inputs_strategy,\n        agent_strategy,\n        device: str = 'cuda',\n    ):\n        super().__init__(enable=True)\n\n        self.backend_config = backend_config\n        self.device = device\n        self.cache_engine = None\n        self.inputs_strategy = inputs_strategy\n        self.agent_strategy = agent_strategy\n        self.rejection_sampler = RejectionSampler()\n        self.proposer = build_specdecode_proposer(specdecode_config, device=device)\n        self.method = specdecode_config.method\n        self.model_config = specdecode_config.model_config\n        self.cache_config = specdecode_config.cache_config\n        self.num_spec_tokens = specdecode_config.num_speculative_tokens\n\n    def set_cache_config(self, cache_config: CacheConfig):\n        \"\"\"Set all cache config.\"\"\"\n        self.cache_config = cache_config\n\n    def set_model_config(self, model_config: ModelConfig):\n        \"\"\"Set model config.\"\"\"\n        self.model_config = model_config\n\n    def build_model(self, empty_init: bool, target_model=None, build_model_ctx=None):\n      
  \"\"\"Build draft model.\"\"\"\n        self.proposer.build_model(empty_init, target_model=target_model, build_model_ctx=build_model_ctx)\n\n    def build_graph_runner(self):\n        \"\"\"Build graph runner.\"\"\"\n        backend = get_backend()\n        self.proposer.model = backend.build_graph_runner(self.proposer.model,\n                                                         model_config=self.model_config,\n                                                         cache_config=self.cache_config,\n                                                         backend_config=self.backend_config,\n                                                         device=self.device)\n\n    def build_cache_engine(self, cache_stream: torch.cuda.Stream):\n        \"\"\"Build cache engine.\"\"\"\n        if self.cache_config is not None:\n            self.cache_engine = CacheEngine(self.cache_config,\n                                            self.model_config,\n                                            rank=0,\n                                            tp_rank=0,\n                                            world_size=1,\n                                            cache_stream=cache_stream)\n\n    def _rejection_sampling(self, next_token_ids, model_inputs: 'ModelInputs', extra_inputs: ARSpecExtraInputs):\n        \"\"\"Do rejection sampling.\"\"\"\n        num_rejected_tokens = torch.zeros_like(model_inputs.seq_length)\n        bonus_token_ids = output_token_ids = next_token_ids.unsqueeze(-1)\n        last_token_indices = model_inputs.seq_length.cumsum(0) - 1\n        if model_inputs.is_decoding:\n            # only do rejection sample for decoding with draft tokens\n            input_draft_token_ids = model_inputs.input_ids.squeeze(0).unflatten(0, (-1, self.num_spec_tokens + 1))[:,\n                                                                                                                   1:]\n            output_token_ids, num_rejected_tokens, next_token_ids = 
self.rejection_sampler(\n                extra_inputs.target_logits,\n                input_draft_token_ids,\n                bonus_token_ids,\n            )\n            # update last token indices\n            last_token_indices = last_token_indices - num_rejected_tokens\n\n        # create new inputs\n        input_ids = model_inputs.input_ids.clone()\n        seq_length = model_inputs.seq_length\n        # # offset by 1 token\n        input_ids[:, :-1] = model_inputs.input_ids[:, 1:]\n        # # update next tokens\n        input_ids[:, last_token_indices] = next_token_ids\n        # use new inputs\n        new_model_inputs = ModelInputs(\n            input_ids=input_ids,\n            seq_length=seq_length,\n            max_kv_seqlen=model_inputs.max_kv_seqlen,\n            max_q_seqlen=model_inputs.max_q_seqlen,\n            sum_kv_seqlen=model_inputs.sum_kv_seqlen,\n            history_lengths=model_inputs.history_lengths.clone(),\n            block_offsets=model_inputs.block_offsets,\n            num_ignored_history=model_inputs.num_ignored_history,\n            is_decoding=model_inputs.is_decoding,\n            target_hidden_states=extra_inputs.target_hidden_states,\n            target_position_ids=extra_inputs.target_position_ids,\n        )\n        new_extra_inputs = ARSpecExtraInputs(\n            next_token_ids=next_token_ids,\n            last_token_indices=last_token_indices,\n            num_rejected_tokens=num_rejected_tokens,\n            output_token_ids=output_token_ids,\n        )\n        return new_model_inputs, new_extra_inputs\n\n    def _forward_impl(self, inputs: ModelInputs):\n        \"\"\"Forward impl.\"\"\"\n        output = self.proposer._forward(inputs, cache_engine=self.cache_engine)\n        return output\n\n    async def _async_forward(self, inputs: ModelInputs):\n        \"\"\"Model forward.\n\n        Args:\n            inputs (Dict): The input data comes from _make_inputs.\n        \"\"\"\n        output = 
self._forward_impl(inputs)\n        await asyncio.sleep(0)\n        return output\n\n    async def _async_model_forward(self, inputs: ModelInputs, extra_inputs: ARSpecExtraInputs,\n                                   sampling_inputs: SamplingInputs):\n        \"\"\"Model forward.\n\n        Args:\n            inputs (Dict): The input data comes from _make_inputs.\n        \"\"\"\n        outputs = await self._async_forward(inputs)\n        if inputs.is_chunk:\n            return torch.zeros_like(inputs.input_ids)\n\n        loop_count = self.num_spec_tokens - 1\n        draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs, extra_inputs)\n        draft_tokens_li = [draft_token_ids]\n        if loop_count > 0:\n            # set last_token_indices to None for decoding\n            extra_inputs.last_token_indices = None\n            inputs = self.proposer.update_inputs_decoding(inputs, extra_inputs, draft_token_ids.transpose(0, 1),\n                                                          target_hidden_states, model_metas)\n            for loop_idx in range(loop_count):\n                outputs = await self._async_forward(inputs)\n                draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs)\n                draft_tokens_li.append(draft_token_ids)\n                if loop_idx < loop_count - 1:\n                    step_seqlens = inputs.seq_length.new_ones(inputs.seq_length.size(0))\n                    inputs.step(draft_token_ids.transpose(0, 1), step_seqlens)\n                    inputs.model_metas = model_metas\n                    inputs.target_hidden_states = target_hidden_states\n                    if inputs.target_position_ids is not None:\n                        inputs.target_position_ids += 1\n\n        output_draft_ids = torch.cat(draft_tokens_li, dim=-1)\n        return output_draft_ids\n\n    async def async_model_forward(\n        self,\n        next_token_ids: 
torch.Tensor,\n        model_inputs: ModelInputs,\n        extra_inputs: ExtraInputs,\n        sampling_inputs: SamplingInputs,\n    ):\n        \"\"\"Draft model forward.\"\"\"\n        draft_model_inputs, draft_extra_inputs = self._rejection_sampling(next_token_ids, model_inputs, extra_inputs)\n        next_draft_ids = await self._async_model_forward(draft_model_inputs, draft_extra_inputs, sampling_inputs)\n        draft_extra_inputs.output_draft_token_ids = next_draft_ids\n        return draft_extra_inputs\n\n    def warmup(self, max_batches: int, target_model_config: ModelConfig):\n        \"\"\"warmup.\"\"\"\n        target_hidden_size = self.proposer.get_target_hidden_size(target_model_config)\n\n        # warmup prefill\n        inputs = self.inputs_strategy.make_dummy(max_batches,\n                                                 is_decoding=False,\n                                                 device='cuda',\n                                                 vocab_size=self.model_config.vocab_size,\n                                                 target_hidden_size=target_hidden_size,\n                                                 target_dtype=self.model_config.dtype)\n\n        self._forward_impl(inputs)\n\n        capture_batch_sizes = self.proposer.model.get_capture_batch_sizes()\n        capture_batch_sizes = sorted(capture_batch_sizes, reverse=True)\n\n        for batch_size in capture_batch_sizes:\n            # decode with num_spec_tokens + 1 per seq\n            inputs = self.inputs_strategy.make_dummy(\n                batch_size,\n                is_decoding=True,\n                device='cuda',\n                vocab_size=self.model_config.vocab_size,\n                max_q_seqlen=self.num_spec_tokens + 1,\n                target_hidden_size=target_hidden_size,\n                target_dtype=self.model_config.dtype,\n            )\n            self._forward_impl(inputs)\n            # decode 1 tokens per sequence\n            inputs = 
self.inputs_strategy.make_dummy(\n                batch_size,\n                is_decoding=True,\n                device='cuda',\n                vocab_size=self.model_config.vocab_size,\n                max_q_seqlen=1,\n                target_hidden_size=self.model_config.hidden_size,\n                target_dtype=self.model_config.dtype,\n            )\n            self._forward_impl(inputs)\n\n    def reset_graph_runner(self):\n        \"\"\"Reset graph runner.\"\"\"\n        if self.proposer.model is not None and hasattr(self.proposer.model, 'reset'):\n            self.proposer.model.reset()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.config import MiscConfig, ModelConfig, SpecDecodeConfig\n\n\ndef build_strategy_factory(model_config: ModelConfig,\n                           misc_config: MiscConfig,\n                           specdecode_config: SpecDecodeConfig = None):\n    \"\"\"Build strategy factory.\"\"\"\n    model_paradigm = model_config.model_paradigm\n\n    if model_paradigm == 'ar':\n        from .ar import ARStrategyFactory\n        return ARStrategyFactory(model_config=model_config)\n    elif model_paradigm == 'dllm':\n        from .dllm import DLLMStrategyFactory\n        return DLLMStrategyFactory(model_config=model_config, dllm_config=misc_config.dllm_config)\n    elif model_paradigm == 'ar_spec':\n        from .ar_spec import ARSpecStrategyFactory\n        assert specdecode_config is not None, 'specdecode_config must be provided for ar_spec model'\n        return ARSpecStrategyFactory(model_config=model_config, specdecode_config=specdecode_config)\n    else:\n        raise RuntimeError(f'Unsupported model paradigm: {model_paradigm}')\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import TYPE_CHECKING\n\nfrom lmdeploy.pytorch.config import ModelConfig\nfrom lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy\n    from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy\n    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy\n    from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy\n    from lmdeploy.pytorch.strategies.base.engine import EngineStrategy\n    from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\n\nfrom ..base import StrategyFactoryBase\n\n\nclass ARStrategyFactory(StrategyFactoryBase):\n\n    def __init__(self, model_config: ModelConfig):\n        \"\"\"config.\"\"\"\n        self.model_config = model_config\n\n    def build_cudagraph_strategy(self) -> 'CudagraphStrategy':\n        \"\"\"Build cudagraph strategy.\"\"\"\n        from .cudagraph import ARCudagraphStrategy\n        return ARCudagraphStrategy()\n\n    def build_sampling_strategy(self) -> 'SamplingStrategy':\n        \"\"\"Build sampling strategy.\"\"\"\n        from .sampling import ARSamplingStrategy\n        pad_token_id = self.model_config.bos_token_id\n        pad_token_id = 0 if pad_token_id is None else pad_token_id\n        return ARSamplingStrategy(pad_token_id)\n\n    def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':\n        \"\"\"Build model inputs strategy.\"\"\"\n        from .model_inputs import ARModelInputsStrategy\n        return ARModelInputsStrategy()\n\n    def build_model_agent_strategy(self) -> 'ModelAgentStrategy':\n        \"\"\"Build model agent strategy.\"\"\"\n        from .model_agent import ARModelAgentStrategy\n        return ARModelAgentStrategy()\n\n    def build_engine_strategy(self, cache_config: 'CacheConfig',\n                              
scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':\n        \"\"\"Build engine strategy.\"\"\"\n        from .engine import AREngineStrategy\n        return AREngineStrategy(cache_config=cache_config, scheduler_config=scheduler_config)\n\n    def build_sequence_strategy(self) -> SequenceStrategy:\n        \"\"\"Build sequence strategy.\"\"\"\n        from .sequence import ARSequenceStrategy\n        return ARSequenceStrategy()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar/cudagraph.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom ..base.cudagraph import CudagraphStrategy\n\n\nclass ARCudagraphStrategy(CudagraphStrategy):\n\n    def get_max_tokens(self, batch_size: int, origin_batch_size: int, num_tokens: int) -> int:\n        \"\"\"Get max tokens.\"\"\"\n        return batch_size\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar/engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\n\nfrom ..base.engine import EngineStrategy\n\n\nclass AREngineStrategy(EngineStrategy):\n    \"\"\"AR Engine Strategy.\"\"\"\n\n    def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> None:\n        self.scheduler_config = scheduler_config\n        self.cache_config = cache_config\n\n    def get_prealloc_size(self, is_decoding: bool):\n        \"\"\"Get prealloc_size.\"\"\"\n        return self.scheduler_config.prefill_interval if is_decoding else 0\n\n    def get_num_loops(self, is_decoding: bool) -> int:\n        \"\"\"Get num_loops.\"\"\"\n        return self.scheduler_config.prefill_interval if is_decoding else 1\n\n    def get_num_decode_tokens(self) -> int:\n        \"\"\"Get num_decode_tokens.\"\"\"\n        return 1\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar/model_agent.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport torch\nimport torch.distributed as dist\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.distributed import DistContext\nfrom lmdeploy.pytorch.engine.logits_process import SamplingInputs\nfrom lmdeploy.pytorch.messages import SchedulerSequence\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria\n\nSeqList = List[SchedulerSequence]\n\n\ndef get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, max_q_seqlen: int,\n                                   model_metas) -> ModelInputs:\n    \"\"\"Next decoding step.\"\"\"\n    if input_ids.dim() == 1:\n        input_ids = input_ids[None, :]\n    state_offsets = inputs.state_offsets\n    if state_offsets is not None:\n        state_offsets = state_offsets.clone()\n    return ModelInputs(\n        input_ids=input_ids,\n        seq_length=torch.full_like(inputs.seq_length, max_q_seqlen),\n        history_lengths=inputs.history_lengths + inputs.seq_length,\n        block_offsets=inputs.block_offsets,\n        is_decoding=True,\n        num_ignored_history=inputs.num_ignored_history.clone(),\n        max_q_seqlen=max_q_seqlen,\n        max_kv_seqlen=inputs.max_kv_seqlen + max_q_seqlen,\n        sum_kv_seqlen=inputs.sum_kv_seqlen + inputs.seq_length.numel() * inputs.max_q_seqlen,\n        local_adapter_ids=inputs.local_adapter_ids,\n        model_metas=model_metas,\n        state_offsets=state_offsets,\n    )\n\n\n@dataclass\nclass ARExtraInputs(ExtraInputs):\n    \"\"\"Ar extra inputs.\"\"\"\n\n\n@dataclass\nclass ARExtraOutputs(ExtraOutputs):\n    \"\"\"Ar extra outputs.\"\"\"\n\n\n@dataclass\nclass ARStoppingCriteria(StoppingCriteria):\n    num_appendable_ids: 
torch.Tensor\n\n    def clone(self):\n        \"\"\"clone.\"\"\"\n        return ARStoppingCriteria(num_appendable_ids=self.num_appendable_ids)\n\n    def merge(self, other: 'ARStoppingCriteria'):\n        \"\"\"Merge two stopping criteria.\"\"\"\n        new_num_appendable = torch.cat([self.num_appendable_ids, other.num_appendable_ids], dim=0)\n        return ARStoppingCriteria(num_appendable_ids=new_num_appendable)\n\n    def update(self, delta: ModelInputsDelta):\n        \"\"\"Update stopping criteria.\"\"\"\n        indices = delta.indices\n        new_num_appendable = self.num_appendable_ids[indices]\n        return ARStoppingCriteria(num_appendable_ids=new_num_appendable)\n\n    @record_function('stopping_criteria')\n    def step(self,\n             token_ids: torch.Tensor,\n             stop_words: torch.Tensor,\n             inputs: Optional[ModelInputs] = None,\n             extra_inputs: Optional[ARExtraInputs] = None):\n        \"\"\"Check whether to stop generation.\"\"\"\n        num_appendable_ids = self.num_appendable_ids - 1\n        stopped = num_appendable_ids <= 0\n        stop_pos = torch.zeros_like(num_appendable_ids)\n        if stop_words is not None:\n            sw_stopped = (token_ids[:, None] == stop_words).any(1)\n            stopped = stopped | sw_stopped\n            one_ids = torch.clamp_max(num_appendable_ids, 0)\n            num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids)\n\n        # I don't know why assign inplace does not works...\n        new_stopping = ARStoppingCriteria(num_appendable_ids=num_appendable_ids)\n        return stopped, stop_pos, new_stopping\n\n\nclass ARModelAgentStrategy(ModelAgentStrategy):\n\n    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:\n        \"\"\"Slice outputs.\"\"\"\n        # batch size == 1\n        if len(seq_length) == 1:\n            return inputs[-1:]\n\n        if len(seq_length) == inputs.size(0):\n            return 
inputs\n        last_idx = seq_length.cumsum(-1) - 1\n        return inputs[last_idx]\n\n    def slice_extra_inputs(self, extra_inputs: ARExtraInputs, model_inputs: ModelInputs,\n                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> ARExtraInputs:\n        \"\"\"Slice outputs.\"\"\"\n        return extra_inputs\n\n    @record_function('step_sampling_inputs')\n    def step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor, **kwargs):\n        \"\"\"step.\"\"\"\n        sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1\n        if sampling_inputs.random_offsets is not None:\n            # random offset is used to generate random numbers for multinomial sampling\n            # so we need to increase it by 1 at each step\n            sampling_inputs.random_offsets += 1\n\n        all_ids = sampling_inputs.all_ids\n        if all_ids is not None:\n            sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1)\n\n        return sampling_inputs\n\n    def make_stopping_criteria(self, seqs: SeqList) -> ARStoppingCriteria:\n        \"\"\"Create stopping criteria.\"\"\"\n        num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]\n        num_appendable = torch.tensor(num_appendable)\n        return ARStoppingCriteria(num_appendable_ids=num_appendable)\n\n    def make_extra_inputs(self, seqs: 'SeqList', model_inputs: 'ModelInputs') -> ExtraInputs:\n        \"\"\"Create extra inputs.\"\"\"\n        return ARExtraInputs()\n\n    def make_extra_outputs(self, extra_inputs: ARExtraInputs) -> ARExtraOutputs:\n        \"\"\"Create extra outputs.\"\"\"\n        return ARExtraOutputs()\n\n    def update_prefill_for_next_step(\n        self,\n        model_inputs: 'ModelInputs',\n        extra_inputs: ARExtraInputs,\n        next_token_ids: torch.Tensor,\n        model_metas: Any,\n        extra_outputs: ARExtraOutputs,\n    ) -> 
Tuple['ModelInputs', ARExtraInputs]:\n        \"\"\"Step next decoding.\"\"\"\n        inputs = get_model_inputs_next_decoding(model_inputs, next_token_ids, max_q_seqlen=1, model_metas=model_metas)\n        return inputs, extra_inputs\n\n    def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,\n                                      extra_inputs: ARExtraInputs, **kwargs):\n        \"\"\"Step next inputs.\"\"\"\n        model_inputs.model_metas = model_metas\n        step_seqlens = model_inputs.seq_length\n        model_inputs.step(next_token_ids, step_seqlens)\n        return model_inputs, extra_inputs\n\n    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,\n                      extra_inputs: ARExtraInputs):\n        \"\"\"Post sampling.\"\"\"\n        return next_token_ids, extra_inputs\n\n    @contextmanager\n    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: DistContext):\n        \"\"\"Broadcast next token ids and extra inputs.\"\"\"\n        tp_gpu_group = dist_ctx.attn_tp_group.gpu_group\n        rank = dist.get_global_rank(tp_gpu_group, 0)\n        handle = dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True)\n        yield\n        handle.wait()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar/model_inputs.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs\n\n\ndef merge_model_inputs(inputs: ModelInputs, other: ModelInputs) -> ModelInputs:\n    \"\"\"Merge model inputs.\"\"\"\n\n    def __try_pad_block_offsets(block_offsets: torch.Tensor, target_size: int):\n        \"\"\"Try pad block offsets to target size.\"\"\"\n        cur_size = block_offsets.size(1)\n        if cur_size < target_size:\n            pad_size = target_size - cur_size\n            pad_tensor = torch.zeros((block_offsets.size(0), pad_size),\n                                     dtype=block_offsets.dtype,\n                                     device=block_offsets.device)\n            block_offsets = torch.cat([block_offsets, pad_tensor], dim=1)\n        return block_offsets\n\n    assert inputs.is_decoding and other.is_decoding, 'Only support merge in decoding.'\n    input_ids = torch.cat([inputs.input_ids, other.input_ids], dim=-1)\n    seq_length = torch.cat([inputs.seq_length, other.seq_length], dim=0)\n    history_lengths = torch.cat([inputs.history_lengths, other.history_lengths], dim=0)\n\n    # block offsets\n    max_blocks = max(inputs.block_offsets.size(1), other.block_offsets.size(1))\n    block_offsets0 = __try_pad_block_offsets(inputs.block_offsets, max_blocks)\n    block_offsets1 = __try_pad_block_offsets(other.block_offsets, max_blocks)\n    block_offsets = torch.cat([block_offsets0, block_offsets1], dim=0)\n    num_ignored_history = torch.cat([inputs.num_ignored_history, other.num_ignored_history], dim=0)\n\n    # lora adapter ids\n    local_adapter_ids = inputs.local_adapter_ids\n    if local_adapter_ids is not None and other.local_adapter_ids is not None:\n        local_adapter_ids = torch.cat([local_adapter_ids, 
other.local_adapter_ids], dim=0)\n\n    # model metas for vl models\n    model_metas = None\n    if inputs.model_metas is not None and other.model_metas is not None:\n        model_metas = inputs.model_metas + other.model_metas\n\n    # ssm\n    state_offsets = None\n    if inputs.state_offsets is not None:\n        state_offsets = torch.cat([inputs.state_offsets, other.state_offsets], dim=0)\n\n    return ModelInputs(\n        input_ids=input_ids,\n        seq_length=seq_length,\n        history_lengths=history_lengths,\n        block_offsets=block_offsets,\n        is_decoding=inputs.is_decoding,\n        num_ignored_history=num_ignored_history,\n        max_q_seqlen=max(inputs.max_q_seqlen, other.max_q_seqlen),\n        max_kv_seqlen=max(inputs.max_kv_seqlen, other.max_kv_seqlen),\n        sum_kv_seqlen=inputs.sum_kv_seqlen + other.sum_kv_seqlen,\n        local_adapter_ids=local_adapter_ids,\n        model_metas=model_metas,\n        state_offsets=state_offsets,\n    )\n\n\nclass ARModelInputsStrategy(ModelInputsStrategy):\n\n    def make_dummy(self,\n                   batch_size: int,\n                   is_decoding: bool,\n                   device: str = 'cpu',\n                   dummy_block_id: int = 0,\n                   vocab_size: int = 1) -> ModelInputs:\n        \"\"\"Create dummy model inputs.\"\"\"\n        return make_dummy_inputs(batch_size,\n                                 max_q_seqlen=1,\n                                 is_decoding=is_decoding,\n                                 device=device,\n                                 dummy_block_id=dummy_block_id,\n                                 vocab_size=vocab_size)\n\n    @record_function('ModelInputs.merge')\n    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:\n        \"\"\"Merge model inputs.\"\"\"\n        return merge_model_inputs(inputs, other)\n\n    @staticmethod\n    def index_select(inputs: ModelInputs,\n                     indices: torch.Tensor,\n             
        indice_cpu: np.ndarray = None,\n                     block_offsets: torch.Tensor = None,\n                     max_q_seqlen: Optional[int] = None,\n                     max_kv_seqlen: Optional[int] = None,\n                     sum_kv_seqlen: Optional[int] = None,\n                     num_ignored_history: Optional[torch.Tensor] = None):\n        \"\"\"Index select.\"\"\"\n        assert inputs.is_decoding, 'Only support index_select in decoding.'\n\n        if len(indices) == len(inputs.seq_length):\n            # we will not change the order of indices\n            # so same length means no change\n            indices = Ellipsis\n\n        # required inputs\n        input_ids = inputs.input_ids[..., indices]\n        seq_length = inputs.seq_length[indices]\n        history_lengths = inputs.history_lengths[indices]\n        if block_offsets is None:\n            block_offsets = inputs.block_offsets[indices]\n        if num_ignored_history is None:\n            num_ignored_history = inputs.num_ignored_history[indices]\n        max_q_seqlen = max_q_seqlen or inputs.max_q_seqlen\n        max_kv_seqlen = max_kv_seqlen or inputs.max_kv_seqlen\n        sum_kv_seqlen = sum_kv_seqlen or inputs.sum_kv_seqlen\n\n        # lora adapter ids\n        local_adapter_ids = inputs.local_adapter_ids\n        if local_adapter_ids is not None:\n            local_adapter_ids = local_adapter_ids[indices]\n\n        # model metas for vl models\n        model_metas = inputs.model_metas\n        if model_metas is not None and indice_cpu is not None:\n            model_metas = [model_metas[i] for i in indice_cpu]\n\n        # for ssm\n        state_offsets = inputs.state_offsets\n        if state_offsets is not None:\n            state_offsets = state_offsets[indices]\n\n        # spec decoding\n        target_hidden_states = inputs.target_hidden_states\n        if target_hidden_states is not None:\n            target_hidden_states = target_hidden_states[indices]\n        
target_position_ids = inputs.target_position_ids\n        if target_position_ids is not None:\n            target_position_ids = target_position_ids[indices]\n\n        # return new inputs\n        return ModelInputs(\n            input_ids=input_ids,\n            seq_length=seq_length,\n            history_lengths=history_lengths,\n            block_offsets=block_offsets,\n            is_decoding=inputs.is_decoding,\n            num_ignored_history=num_ignored_history,\n            max_q_seqlen=max_q_seqlen,\n            max_kv_seqlen=max_kv_seqlen,\n            sum_kv_seqlen=sum_kv_seqlen,\n            local_adapter_ids=local_adapter_ids,\n            model_metas=model_metas,\n            state_offsets=state_offsets,\n            target_hidden_states=target_hidden_states,\n            target_position_ids=target_position_ids,\n        )\n\n    @record_function('ModelInputs.update_inputs')\n    def update_inputs(self, inputs: ModelInputs, delta: 'ModelInputsDelta') -> ModelInputs:\n        \"\"\"Update model inputs with delta.\"\"\"\n        assert inputs.is_decoding, 'Only support update_inputs in decoding.'\n        return self.index_select(\n            inputs=inputs,\n            indices=delta.indices,\n            indice_cpu=delta.indice_cpu,\n            block_offsets=delta.block_offsets,\n            max_q_seqlen=delta.max_q_seqlen,\n            max_kv_seqlen=delta.max_kv_seqlen,\n            sum_kv_seqlen=delta.sum_kv_seqlen,\n            num_ignored_history=delta.num_ignored_history,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar/sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport numpy as np\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.engine.logits_process import SamplingInputs, SamplingInputsDelta\nfrom lmdeploy.pytorch.messages import SchedulerSequence\nfrom lmdeploy.pytorch.model_inputs import ModelInputsDelta\n\nfrom ..base.sampling import SamplingStrategy\n\nSeqList = list[SchedulerSequence]\n\n\ndef _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs):\n    \"\"\"Gather history.\"\"\"\n    if not any(sampling_inputs.logits_processors):\n        return None\n    batch = len(seqs)\n    max_len = max(seq.num_valid_ids for seq in seqs)\n    output = torch.full((batch, max_len), pad_id, dtype=torch.int64)\n    for idx, seq in enumerate(seqs):\n        h_len = seq.num_valid_ids\n        if h_len == 0:\n            continue\n        h_ids = torch.from_numpy(seq.valid_ids)\n        output[idx, -h_len:] = h_ids\n    return output\n\n\ndef _gather_generated_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) -> np.ndarray | None:\n    \"\"\"Gather history.\"\"\"\n    if sampling_inputs.repetition_penalty is None and sampling_inputs.max_repetition_ngram_size == 0:\n        return None\n    batch = len(seqs)\n    max_len = max(seq.num_new_tokens for seq in seqs)\n    output = np.full((batch, max_len), pad_id, dtype=np.int64)\n    for idx, seq in enumerate(seqs):\n        h_len = seq.num_new_tokens\n        if h_len == 0:\n            continue\n        h_ids = seq.generated_ids\n        output[idx, -h_len:] = h_ids\n    return output\n\n\ndef _get_num_ignore_eos(seqs: SeqList):\n    \"\"\"Get num ignore eos.\"\"\"\n    ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs]\n    return torch.tensor(ret)\n\n\nclass ARSamplingStrategy(SamplingStrategy):\n    \"\"\"Sampling strategy for autoregressive models.\"\"\"\n\n    def __init__(self, pad_token_id: int) -> None:\n        
pad_token_id = 0 if pad_token_id is None else pad_token_id\n        self.pad_token_id = pad_token_id\n        self.session_to_cleanup = []\n\n    @record_function('make_sampling_inputs')\n    def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:\n        \"\"\"Create sampling inputs from the sequences.\"\"\"\n        batch_size = len(seqs)\n        temperature = [None] * batch_size\n        repetition_penalty = [None] * batch_size\n        top_k = [None] * batch_size\n        top_p = [None] * batch_size\n        min_p = [None] * batch_size\n        bad_words = [None] * batch_size\n        stop_words = [None] * batch_size\n        random_seeds = [np.random.randint(0xffffffff)] * batch_size\n        random_offsets = [None] * batch_size\n        response_formats = [None] * batch_size\n        logits_processors = [None] * batch_size\n        num_logprobs = [None] * batch_size\n        session_to_cleanup = self.session_to_cleanup\n        self.session_to_cleanup = []\n        repetition_ngram_sizes = [None] * batch_size\n        repetition_ngram_thresholds = [None] * batch_size\n\n        def __gather_params():\n            \"\"\"Gather params.\"\"\"\n            for idx, seq in enumerate(seqs):\n                param = seq.sampling_param\n                temperature[idx] = param.temperature\n                repetition_penalty[idx] = param.repetition_penalty\n                top_k[idx] = max(0, param.top_k)\n                top_p[idx] = param.top_p\n                min_p[idx] = param.min_p\n                random_offsets[idx] = seq.num_valid_ids\n                response_formats[idx] = param.response_format\n                if param.random_seed is not None:\n                    random_seeds[idx] = param.random_seed & 0xffffffff\n\n                bw = param.bad_words\n                sw = param.stop_words\n                if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens):\n                    bw = bw + sw\n                bad_words[idx] 
= bw\n                stop_words[idx] = sw\n                logits_processors[idx] = param.logits_processors\n                num_logprobs[idx] = param.num_logprobs\n                repetition_ngram_sizes[idx] = param.repetition_ngram_size\n                repetition_ngram_thresholds[idx] = param.repetition_ngram_threshold\n\n        def __get_topp(top_p):\n            \"\"\"Get topp.\"\"\"\n            min_top_p = min(top_p)\n            if min_top_p == 1.0:\n                top_p = None\n            else:\n                top_p = torch.tensor(top_p)\n            return top_p, min_top_p\n\n        def __get_minp(min_p):\n            \"\"\"Get minp.\"\"\"\n            max_min_p = max(min_p)\n            if max_min_p == 0.0:\n                min_p = None\n            else:\n                min_p = torch.Tensor(min_p)\n            return min_p\n\n        def __get_bad_words(bad_words):\n            \"\"\"Get bad words.\"\"\"\n            max_bw_len = max(len(bw) for bw in bad_words)\n            if max_bw_len == 0:\n                return None, None\n            if all(len(bw) == max_bw_len for bw in bad_words):\n                ret = torch.tensor(bad_words)\n                mask = torch.ones_like(ret, dtype=bool)\n                return ret, mask\n            ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64)\n            for idx, bw in enumerate(bad_words):\n                bw_len = len(bw)\n                if bw_len == 0:\n                    continue\n                bw = ret.new_tensor(bw)\n                ret[idx, :bw_len] = bw\n\n            mask = ret >= 0\n            return ret, mask\n\n        __gather_params()\n\n        if all(rp == 1.0 for rp in repetition_penalty):\n            repetition_penalty = None\n        else:\n            repetition_penalty = torch.tensor(repetition_penalty)\n\n        temperature = torch.tensor(temperature)\n        if (temperature == 1.0).all():\n            # skip temperature processing if all temperature are 
1.0\n            temperature = None\n\n        bad_words, bad_mask = __get_bad_words(bad_words)\n        stop_words, stop_mask = __get_bad_words(stop_words)\n\n        max_top_k = max(top_k)\n        if min(top_k) <= 0:\n            max_top_k = 0\n        if max_top_k == 1:\n            top_k = None\n            top_p, min_top_p = None, 1.0\n            min_p = None\n            random_seeds = None\n        else:\n            top_k = torch.tensor(top_k)\n            if (top_k == max_top_k).all():\n                # we would perform max_top_k before top_k\n                # if all top_k are same, we do not need to filter topk again\n                top_k = None\n            top_p, min_top_p = __get_topp(top_p)\n            min_p = __get_minp(min_p)\n            random_seeds = torch.tensor(random_seeds)\n        random_offsets = torch.tensor(random_offsets)\n\n        max_num_logprobs = max(num_logprobs)\n\n        session_ctx = [{\n            'session_id': seq.session.session_id,\n            'seq_id': seq.seq_id,\n        } for seq in seqs]\n\n        # repetition ngram\n        max_repetition_ngram_size = max(repetition_ngram_sizes)\n        if max_repetition_ngram_size == 0:\n            repetition_ngram_sizes = None\n            repetition_ngram_thresholds = None\n        else:\n            repetition_ngram_sizes = torch.tensor(repetition_ngram_sizes)\n            repetition_ngram_thresholds = torch.tensor(repetition_ngram_thresholds)\n            repetition_ngram_same_n = (repetition_ngram_sizes == max_repetition_ngram_size).all().item()\n            if repetition_ngram_same_n:\n                repetition_ngram_sizes = None\n\n        sampling_input = SamplingInputs(\n            temperature=temperature,\n            bad_words=bad_words,\n            bad_mask=bad_mask,\n            stop_words=stop_words,\n            stop_mask=stop_mask,\n            repetition_penalty=repetition_penalty,\n            top_k=top_k,\n            top_p=top_p,\n            
min_p=min_p,\n            random_seeds=random_seeds,\n            random_offsets=random_offsets,\n            response_formats=tuple(response_formats),\n            max_top_k=max_top_k,\n            min_top_p=min_top_p,\n            logits_processors=logits_processors,\n            max_num_logprobs=max_num_logprobs,\n            batch_size=batch_size,\n            session_ctx=session_ctx,\n            session_to_cleanup=session_to_cleanup,\n            repetition_ngram_size=repetition_ngram_sizes,\n            repetition_ngram_threshold=repetition_ngram_thresholds,\n            max_repetition_ngram_size=max_repetition_ngram_size,\n        )\n\n        pad_token_id = self.pad_token_id\n        sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input)\n        sampling_input.generated_ids_cpu = _gather_generated_ids(pad_token_id, seqs, sampling_input)\n        sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs)\n        return sampling_input\n\n    def on_session_end(self, session_id: int):\n        self.session_to_cleanup.append(session_id)\n\n    def merge_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        other: 'SamplingInputsDelta',\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Merge two sampling deltas.\"\"\"\n        num_ignore_eos = torch.cat([sampling_delta.num_ignore_eos, other.num_ignore_eos], 0)\n        random_offsets = torch.cat([sampling_delta.random_offsets, other.random_offsets], 0)\n\n        batch_size = num_ignore_eos.size(0)\n        all_ids0 = sampling_delta.all_ids\n        all_ids1 = other.all_ids\n        if all_ids0 is None and all_ids1 is None:\n            all_ids = None\n        else:\n            max_len0 = 0 if all_ids0 is None else all_ids0.size(1)\n            max_len1 = 0 if all_ids1 is None else all_ids1.size(1)\n            max_len = max(max_len0, max_len1)\n            all_ids = torch.full((batch_size, max_len),\n                                 self.pad_token_id,\n   
                              dtype=torch.int64,\n                                 device=num_ignore_eos.device)\n            if all_ids0 is not None:\n                bs0 = all_ids0.size(0)\n                all_ids[:bs0, :max_len0] = all_ids0\n            if all_ids1 is not None:\n                bs1 = all_ids1.size(0)\n                all_ids[-bs1:, :max_len1] = all_ids1\n\n        return SamplingInputsDelta(\n            num_ignore_eos=num_ignore_eos,\n            random_offsets=random_offsets,\n            all_ids=all_ids,\n        )\n\n    def step_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        next_token_ids: torch.Tensor,\n        **kwargs,\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Step next delta.\"\"\"\n        sampling_delta.num_ignore_eos = sampling_delta.num_ignore_eos - 1\n        if sampling_delta.random_offsets is not None:\n            # random offset is used to generate random numbers for multinomial sampling\n            # so we need to increase it by 1 at each step\n            sampling_delta.random_offsets += 1\n\n        all_ids = sampling_delta.all_ids\n        if all_ids is not None:\n            sampling_delta.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1)\n\n        return sampling_delta\n\n    def update_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        delta: 'ModelInputsDelta',\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Update sampling delta with model inputs delta.\"\"\"\n        indices = delta.indices\n        num_ignore_eos = sampling_delta.num_ignore_eos[indices]\n        if sampling_delta.random_offsets is not None:\n            random_offsets = sampling_delta.random_offsets[indices]\n        else:\n            random_offsets = None\n        all_ids = sampling_delta.all_ids\n        if all_ids is not None:\n            all_ids = all_ids[indices]\n        return SamplingInputsDelta(\n            num_ignore_eos=num_ignore_eos,\n      
      random_offsets=random_offsets,\n            all_ids=all_ids,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar/sequence.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport time\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\nfrom torch import Tensor\n\nfrom lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest\nfrom lmdeploy.pytorch.engine.model_agent import BatchedOutputs\nfrom lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam,\n                                       SchedulerSequence, SchedulerSession, UpdateTokenMode, _to_ndarray)\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..base.sequence import SequenceStrategy\n\nSeqList = List[SchedulerSequence]\n\n\n@dataclass\nclass SchedulerSequenceDefault(SchedulerSequence):\n\n    def update_token_ids(self,\n                         token_ids: Tensor,\n                         multimodals: MultiModalInputs = None,\n                         embeddings: List[InputEmbeddings] = None,\n                         model_meta: Dict[str, Any] = None,\n                         mode: UpdateTokenMode = UpdateTokenMode.INPUTS,\n                         routed_experts: np.ndarray = None,\n                         **kwargs):\n        \"\"\"Update token ids, old token ids will be added to history.\"\"\"\n        # update history image nums\n        self._update_embeddings(embeddings)\n\n        # update multimodals\n        self._update_multimodals(multimodals)\n\n        token_ids = _to_ndarray(token_ids)\n\n        num_valid = len(token_ids)\n        # record cached expert ids\n        self.append_routed_experts(routed_experts)\n\n        if mode == UpdateTokenMode.INPUTS:\n            self.arrive_time = time.perf_counter()\n            self.output_start_pos = self.num_all_ids + len(token_ids)\n            self._num_token_ids += num_valid\n            self.num_new_tokens = 0\n        else:\n            self._num_history_ids += self._num_token_ids\n            num_token_ids = num_valid\n 
           self._num_token_ids = num_token_ids\n            self.num_new_tokens += num_token_ids\n\n        self.history_cache.append(token_ids)\n\n        if model_meta is not None:\n            self.model_meta = model_meta\n\n    def set_step(self, step: int):\n        \"\"\"Set step.\"\"\"\n        num_all_ids = self.num_all_ids\n        # update step for vlm\n        if len(self.history_embeddings) > 0:\n            new_step, self._num_history_images, self._num_images = \\\n                self.history_embeddings.get_step(step)\n            assert 0 <= new_step <= step\n            step = new_step\n        self._num_history_ids = step\n        self._num_token_ids = num_all_ids - step\n        self.num_ignored_history = min(step, self.num_ignored_history)\n\n        self.model_meta = None\n\n        if self.return_routed_experts:\n            # chunk long context might not have all routed experts\n            if len(self.all_routed_experts) > step:\n                self.all_routed_experts.resize(step)\n\n\nclass ARSequenceStrategy(SequenceStrategy):\n\n    def make_sequence(self,\n                      seq_id: int,\n                      session: 'SchedulerSession',\n                      sampling_param: 'SamplingParam' = None,\n                      adapter_name: str = None,\n                      migration_request: Optional[MigrationRequest] = None,\n                      resp_cache: bool = False,\n                      preserve_cache: bool = False) -> 'SchedulerSequence':\n        \"\"\"Make sequence.\"\"\"\n        return SchedulerSequenceDefault(\n            seq_id=seq_id,\n            session=session,\n            sampling_param=sampling_param,\n            adapter_name=adapter_name,\n            migration_request=migration_request,\n            resp_cache=resp_cache,\n            preserve_cache=preserve_cache,\n        )\n\n    def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, model_inputs: 'ModelInputs',\n                       
delta: 'ModelInputsDelta') -> None:\n        \"\"\"Update running sequences.\"\"\"\n        next_token_ids = batched_outputs.next_token_ids\n        stopped = batched_outputs.stopped\n        stopped = stopped.tolist()\n        model_metas = batched_outputs.model_metas\n        if model_metas is None:\n            model_metas = [None] * len(running)\n\n        next_token_ids = next_token_ids.numpy()\n        if model_inputs is None:\n            num_tokens = delta.seq_length.tolist()\n            is_decoding = delta.is_decoding\n        else:\n            num_tokens = model_inputs.seq_length.tolist()\n            is_decoding = model_inputs.is_decoding\n        all_routed_experts = [None] * len(num_tokens)\n        if batched_outputs.all_routed_experts is not None:\n            all_routed_experts = batched_outputs.all_routed_experts.split(num_tokens, dim=0)\n            all_routed_experts = [experts.numpy() for experts in all_routed_experts]\n        update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL\n        for token, msg, stop, model_meta, routed_experts in zip(next_token_ids, running, stopped, model_metas,\n                                                                all_routed_experts):\n            if msg.status != MessageStatus.RUNNING:\n                continue\n\n            # fill token\n            msg.update_token_ids(token, model_meta=model_meta, mode=update_mode, routed_experts=routed_experts)\n            if stop:\n                msg.state.finish()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar_spec/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import TYPE_CHECKING\n\nfrom lmdeploy.pytorch.config import ModelConfig, SpecDecodeConfig\nfrom lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy\n    from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy\n    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy\n    from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy\n    from lmdeploy.pytorch.strategies.base.engine import EngineStrategy\n    from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\n\nfrom ..base import StrategyFactoryBase\n\n\nclass ARSpecStrategyFactory(StrategyFactoryBase):\n\n    def __init__(self, model_config: ModelConfig, specdecode_config: SpecDecodeConfig):\n        \"\"\"config.\"\"\"\n        self.model_config = model_config\n        self.specdecode_config = specdecode_config\n        self.pad_token_id = model_config.bos_token_id or 0\n\n    def build_cudagraph_strategy(self) -> 'CudagraphStrategy':\n        \"\"\"Build cudagraph strategy.\"\"\"\n        from .cudagraph import ARSpecCudagraphStrategy\n        return ARSpecCudagraphStrategy(self.specdecode_config.num_speculative_tokens)\n\n    def build_sampling_strategy(self) -> 'SamplingStrategy':\n        \"\"\"Build sampling strategy.\"\"\"\n        from .sampling import ARSpecSamplingStrategy\n        pad_token_id = self.model_config.bos_token_id\n        pad_token_id = 0 if pad_token_id is None else pad_token_id\n        return ARSpecSamplingStrategy(pad_token_id)\n\n    def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':\n        \"\"\"Build model inputs strategy.\"\"\"\n        from .model_inputs import ARSpecModelInputsStrategy\n        return ARSpecModelInputsStrategy(self.specdecode_config.num_speculative_tokens)\n\n    def 
build_model_agent_strategy(self) -> 'ModelAgentStrategy':\n        \"\"\"Build model agent strategy.\"\"\"\n        from .model_agent import ARSpecModelAgentStrategy\n        return ARSpecModelAgentStrategy(self.specdecode_config.num_speculative_tokens)\n\n    def build_engine_strategy(self, cache_config: 'CacheConfig',\n                              scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':\n        \"\"\"Build engine strategy.\"\"\"\n        from .engine import ARSpecEngineStrategy\n        return ARSpecEngineStrategy(cache_config=cache_config,\n                                    scheduler_config=scheduler_config,\n                                    num_spec_tokens=self.specdecode_config.num_speculative_tokens)\n\n    def build_sequence_strategy(self) -> SequenceStrategy:\n        \"\"\"Build sequence strategy.\"\"\"\n        from .sequence import ARSpecSequenceStrategy\n        return ARSpecSequenceStrategy()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar_spec/cudagraph.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom ..base.cudagraph import CudagraphStrategy\n\n\nclass ARSpecCudagraphStrategy(CudagraphStrategy):\n\n    def __init__(self, num_spec_tokens: int):\n        super().__init__()\n        self.num_spec_tokens = num_spec_tokens\n\n    def get_max_tokens(self, batch_size: int, origin_batch_size: int, num_tokens: int) -> int:\n        \"\"\"Get max tokens.\"\"\"\n        if num_tokens == origin_batch_size:\n            return batch_size\n\n        assert num_tokens % (self.num_spec_tokens + 1) == 0, 'num_tokens must be divisible by num_spec_tokens + 1.'\n        return batch_size * (self.num_spec_tokens + 1)\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar_spec/engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\n\nfrom ..base.engine import EngineStrategy\n\n\nclass ARSpecEngineStrategy(EngineStrategy):\n    \"\"\"AR engine strategy with speculative decoding.\"\"\"\n\n    def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, num_spec_tokens: int) -> None:\n        self.scheduler_config = scheduler_config\n        self.cache_config = cache_config\n        self.num_spec_tokens = num_spec_tokens\n\n    def get_prealloc_size(self, is_decoding: bool) -> int:\n        \"\"\"Get prealloc_size.\"\"\"\n        if is_decoding:\n            # prefill_interval decoding loops, each with 1 + num_spec_tokens tokens\n            return self.scheduler_config.prefill_interval * (1 + self.num_spec_tokens)\n        return self.num_spec_tokens\n\n    def get_num_loops(self, is_decoding: bool) -> int:\n        \"\"\"Get num_loops.\"\"\"\n        return self.scheduler_config.prefill_interval if is_decoding else 1\n\n    def get_num_decode_tokens(self) -> int:\n        \"\"\"Get num_decode_tokens.\"\"\"\n        return self.num_spec_tokens + 1\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar_spec/model_agent.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport torch\nimport torch.distributed as dist\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.distributed import DistContext\nfrom lmdeploy.pytorch.engine.logits_process import SamplingInputs\nfrom lmdeploy.pytorch.messages import SchedulerSequence\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..ar.model_agent import ARStoppingCriteria, get_model_inputs_next_decoding\nfrom ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy\n\nSeqList = List[SchedulerSequence]\n\n\n@dataclass\nclass ARSpecExtraInputs(ExtraInputs):\n    \"\"\"ARSpec extra inputs.\"\"\"\n    # draft model inputs\n    target_logits: torch.Tensor = None\n    target_hidden_states: torch.Tensor = None\n    target_position_ids: torch.Tensor = None\n    next_token_ids: torch.LongTensor = None\n    last_token_indices: torch.LongTensor = None\n\n    # draft model outputs\n    output_draft_token_ids: torch.Tensor = None\n    num_rejected_tokens: torch.Tensor = None\n    output_token_ids: torch.Tensor = None\n\n    def __repr__(self):\n        return (f'ARSpecExtraInputs(next_token_ids={self.next_token_ids}, '\n                f'output_draft_token_ids={self.output_draft_token_ids}, '\n                f'last_token_indices={self.last_token_indices}, '\n                f'num_rejected_tokens={self.num_rejected_tokens}, '\n                f'output_token_ids={self.output_token_ids})')\n\n    def broadcast(self, src: int, group, async_op=False):\n        dist.broadcast(self.output_draft_token_ids, src=src, group=group, async_op=async_op)\n        handle = dist.broadcast(self.num_rejected_tokens, src=src, group=group, async_op=async_op)\n        return handle\n\n    def merge(self, other: 'ARSpecExtraInputs'):\n        \"\"\"Merge extra inputs.\"\"\"\n 
       output_token_ids = torch.cat([self.output_token_ids, other.output_token_ids], dim=0)\n        return ARSpecExtraInputs(output_token_ids=output_token_ids)\n\n\n@dataclass\nclass ARSpecExtraOutputs(ExtraOutputs):\n    \"\"\"ARSpec extra outputs.\"\"\"\n    # output the draft tokens to seq only for last loop step\n    draft_token_ids: torch.Tensor = None\n\n    def __repr__(self):\n        return (f'ARSpecExtraOutputs(draft_token_ids={self.draft_token_ids})')\n\n\n@dataclass\nclass ARSpecStoppingCriteria(ARStoppingCriteria):\n    num_appendable_ids: torch.Tensor\n\n    def clone(self):\n        \"\"\"clone.\"\"\"\n        return ARSpecStoppingCriteria(num_appendable_ids=self.num_appendable_ids)\n\n    def merge(self, other: 'ARSpecStoppingCriteria'):\n        \"\"\"Merge two stopping criteria.\"\"\"\n        new_num_appendable = torch.cat([self.num_appendable_ids, other.num_appendable_ids], dim=0)\n        return ARSpecStoppingCriteria(num_appendable_ids=new_num_appendable)\n\n    def update(self, delta: ModelInputsDelta):\n        \"\"\"Update stopping criteria.\"\"\"\n        indices = delta.indices\n        new_num_appendable = self.num_appendable_ids[indices]\n        return ARSpecStoppingCriteria(num_appendable_ids=new_num_appendable)\n\n    @record_function('stopping_criteria')\n    def step(self,\n             next_token_ids: torch.Tensor,\n             stop_words: torch.Tensor,\n             inputs: Optional[ModelInputs] = None,\n             extra_inputs: Optional[ARSpecExtraInputs] = None):\n        \"\"\"Check whether to stop generation.\"\"\"\n        token_ids = extra_inputs.output_token_ids\n\n        if token_ids.ndim == 1:\n            token_ids = token_ids.unsqueeze(-1)\n        valid_tokens = token_ids > -1\n        mask = (self.num_appendable_ids.unsqueeze(-1) - valid_tokens.cumsum(dim=-1)) <= 0\n        if stop_words is not None:\n            token_ids_rsp = token_ids.unsqueeze(-1).repeat(1, 1, stop_words.numel())\n            stop_words_rsp 
= stop_words.reshape(1, 1, -1)\n            assert stop_words_rsp.ndim == token_ids_rsp.ndim == 3\n            stop_mask = (token_ids_rsp == stop_words_rsp).any(-1)\n            mask = mask ^ stop_mask\n        # find the index of first `1`,  if not found, would be 0\n        index = torch.argmax(mask.int(), dim=-1, keepdim=True)\n        # update index of 0 to -1 if not found\n        stop_pos = torch.where(index == 0, mask[:, 0:1].int() - 1, index).ravel()\n        stopped = stop_pos != -1\n        num_valid_tokens = valid_tokens.sum(dim=-1)\n        num_appendable_ids = self.num_appendable_ids - num_valid_tokens\n        one_ids = torch.clamp_max(num_appendable_ids, 0)\n        num_appendable_ids = torch.where(stopped, one_ids, num_appendable_ids)\n        return stopped, stop_pos, ARSpecStoppingCriteria(num_appendable_ids=num_appendable_ids)\n\n\nclass ARSpecModelAgentStrategy(ModelAgentStrategy):\n\n    def __init__(self, num_spec_tokens: int):\n        self.num_spec_tokens = num_spec_tokens\n\n    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:\n        \"\"\"Slice outputs.\"\"\"\n        # batch size == 1\n        if len(seq_length) == 1:\n            return inputs[-1:]\n\n        if len(seq_length) == inputs.size(0):\n            return inputs\n        last_idx = seq_length.cumsum(-1) - 1\n        return inputs[last_idx]\n\n    def slice_extra_inputs(self, extra_inputs: ARSpecExtraInputs, model_inputs: ModelInputs,\n                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> ARSpecExtraInputs:\n        \"\"\"Slice outputs.\"\"\"\n        extra_inputs = ARSpecExtraInputs()\n        extra_inputs.target_hidden_states = model_outputs.get('hidden_states')\n        extra_inputs.target_position_ids = model_outputs.get('position_ids', None)\n        if model_inputs.is_decoding:\n            batch_size = model_inputs.seq_length.size(0)\n            logits = model_outputs['logits'][0]\n            
extra_inputs.target_logits = logits.unflatten(0, (batch_size, -1))[:, :-1]\n        return extra_inputs\n\n    def step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor, **kwargs):\n        \"\"\"step.\"\"\"\n        sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1\n\n        all_ids = sampling_inputs.all_ids\n        if all_ids is not None:\n            sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1)\n\n        return sampling_inputs\n\n    def make_stopping_criteria(self, seqs: SeqList) -> ARSpecStoppingCriteria:\n        \"\"\"Create stopping criteria.\"\"\"\n        num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]\n        num_appendable = torch.tensor(num_appendable)\n        return ARSpecStoppingCriteria(num_appendable_ids=num_appendable)\n\n    def make_extra_inputs(self, seqs: 'SeqList', model_inputs: 'ModelInputs') -> ExtraInputs:\n        \"\"\"Create extra inputs.\"\"\"\n        return ARSpecExtraInputs()\n\n    def update_extra_inputs(self, extra_inputs: ARSpecExtraInputs, delta: 'ModelInputsDelta') -> ARSpecExtraInputs:\n        \"\"\"Update extra inputs with model inputs delta.\"\"\"\n        indices = delta.indices\n        output_token_ids = extra_inputs.output_token_ids[indices]\n        return ARSpecExtraInputs(output_token_ids=output_token_ids)\n\n    def make_extra_outputs(self, extra_inputs: ARSpecExtraInputs) -> ARSpecExtraOutputs:\n        \"\"\"Create extra outputs.\"\"\"\n        output = ARSpecExtraOutputs()\n        output.draft_token_ids = extra_inputs.output_draft_token_ids\n        return output\n\n    def update_prefill_for_next_step(\n        self,\n        model_inputs: 'ModelInputs',\n        extra_inputs: ARSpecExtraInputs,\n        next_token_ids: torch.Tensor,\n        model_metas: Any,\n        extra_outputs: ARSpecExtraOutputs,\n    ) -> Tuple['ModelInputs', ARSpecExtraInputs]:\n        \"\"\"Step next 
decoding.\"\"\"\n        next_token_ids = next_token_ids[:, None]\n        next_token_ids = torch.cat([next_token_ids, extra_outputs.draft_token_ids], dim=-1)\n        max_q_seqlen = next_token_ids.size(-1)\n        next_token_ids = next_token_ids.flatten()[None, :]\n        inputs = get_model_inputs_next_decoding(model_inputs,\n                                                next_token_ids,\n                                                max_q_seqlen=max_q_seqlen,\n                                                model_metas=model_metas)\n        extra_inputs = ARSpecExtraInputs(output_token_ids=extra_outputs.draft_token_ids)\n        return inputs, extra_inputs\n\n    def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,\n                                      extra_inputs: ARSpecExtraInputs, extra_outputs: ARSpecExtraOutputs):\n        \"\"\"Step next inputs.\"\"\"\n        model_inputs.model_metas = model_metas\n        step_seqlens = model_inputs.seq_length\n        batch_size = step_seqlens.size(0)\n\n        # update extra inputs\n        extra_inputs.output_token_ids = extra_outputs.draft_token_ids\n\n        # update inputs\n        step_seqlens = model_inputs.seq_length - extra_inputs.num_rejected_tokens\n        input_ids = next_token_ids.new_empty((batch_size, self.num_spec_tokens + 1))\n        input_ids[:, 0] = next_token_ids\n        input_ids[:, 1:] = extra_inputs.output_draft_token_ids\n        input_ids = input_ids.flatten()[None, :]\n        model_inputs.step(input_ids, step_seqlens)\n        return model_inputs, extra_inputs\n\n    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,\n                      extra_inputs: ARSpecExtraInputs):\n        \"\"\"Post sampling.\"\"\"\n        return next_token_ids, extra_inputs\n\n    def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: ExtraInputs):\n      
  \"\"\"Make dummy next token for broadcast.\"\"\"\n        with torch.inference_mode():\n            next_token_ids = inputs.input_ids.new_zeros(logits.size(0))\n            extra_inputs.output_draft_token_ids = inputs.input_ids.new_zeros((logits.size(0), self.num_spec_tokens))\n            extra_inputs.num_rejected_tokens = inputs.input_ids.new_zeros(logits.size(0))\n        return next_token_ids, extra_inputs\n\n    @contextmanager\n    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ARSpecExtraInputs,\n                             dist_ctx: DistContext):\n        \"\"\"Broadcast next token ids and extra inputs.\"\"\"\n        tp_gpu_group = dist_ctx.attn_tp_group.gpu_group\n        rank = dist.get_global_rank(tp_gpu_group, 0)\n        dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True)\n        handle = extra_inputs.broadcast(src=rank, group=tp_gpu_group, async_op=True)\n        yield\n        handle.wait()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar_spec/model_inputs.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..ar.model_inputs import merge_model_inputs\nfrom ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs\n\n\nclass ARSpecModelInputsStrategy(ModelInputsStrategy):\n\n    def __init__(self, num_spec_tokens: int):\n        self.num_spec_tokens = num_spec_tokens\n\n    def make_dummy(\n        self,\n        batch_size: int,\n        is_decoding: bool,\n        device: str = 'cpu',\n        dummy_block_id: int = 0,\n        vocab_size: int = 1,\n        max_q_seqlen: int = 1,\n        target_hidden_size: int = None,\n        target_dtype: torch.dtype = torch.bfloat16,\n    ) -> ModelInputs:\n        \"\"\"Create dummy model inputs.\"\"\"\n        inputs = make_dummy_inputs(batch_size,\n                                   max_q_seqlen=max_q_seqlen,\n                                   is_decoding=is_decoding,\n                                   device=device,\n                                   dummy_block_id=dummy_block_id,\n                                   vocab_size=vocab_size)\n        if target_hidden_size is not None:\n            inputs.target_hidden_states = torch.randn((1, batch_size * max_q_seqlen, target_hidden_size),\n                                                      dtype=target_dtype,\n                                                      device=device)\n        return inputs\n\n    @record_function('ModelInputs.merge')\n    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:\n        \"\"\"Merge model inputs.\"\"\"\n        return merge_model_inputs(inputs, other)\n\n    @record_function('ModelInputs.update_inputs')\n    def update_inputs(self, inputs: ModelInputs, delta: 'ModelInputsDelta') -> ModelInputs:\n        \"\"\"Update model inputs with delta.\"\"\"\n        assert inputs.is_decoding, 'Only support 
update_delta in decoding.'\n        indices = delta.indices\n        indice_cpu = delta.indice_cpu\n        block_offsets = delta.block_offsets\n        max_q_seqlen = delta.max_q_seqlen\n        max_kv_seqlen = delta.max_kv_seqlen\n        sum_kv_seqlen = delta.sum_kv_seqlen\n        num_ignored_history = delta.num_ignored_history\n\n        # required inputs\n        # input_ids = inputs.input_ids[..., indices]\n        inputs_ids = inputs.input_ids.reshape(1, -1, self.num_spec_tokens + 1)\n        input_ids = inputs_ids[:, indices].reshape(1, -1)\n        seq_length = inputs.seq_length[indices]\n        history_lengths = inputs.history_lengths[indices]\n        if block_offsets is None:\n            block_offsets = inputs.block_offsets[indices]\n        if num_ignored_history is None:\n            num_ignored_history = inputs.num_ignored_history[indices]\n        max_q_seqlen = max_q_seqlen or inputs.max_q_seqlen\n        max_kv_seqlen = max_kv_seqlen or inputs.max_kv_seqlen\n        sum_kv_seqlen = sum_kv_seqlen or inputs.sum_kv_seqlen\n\n        # lora adapter ids\n        local_adapter_ids = inputs.local_adapter_ids\n        if local_adapter_ids is not None:\n            local_adapter_ids = local_adapter_ids[indices]\n\n        # model metas for vl models\n        model_metas = inputs.model_metas\n        if model_metas is not None and indice_cpu is not None:\n            model_metas = [model_metas[i] for i in indice_cpu]\n\n        # for ssm\n        state_offsets = inputs.state_offsets\n        if state_offsets is not None:\n            state_offsets = state_offsets[indices]\n\n        # return new inputs\n        return ModelInputs(\n            input_ids=input_ids,\n            seq_length=seq_length,\n            history_lengths=history_lengths,\n            block_offsets=block_offsets,\n            is_decoding=inputs.is_decoding,\n            num_ignored_history=num_ignored_history,\n            max_q_seqlen=max_q_seqlen,\n            
max_kv_seqlen=max_kv_seqlen,\n            sum_kv_seqlen=sum_kv_seqlen,\n            local_adapter_ids=local_adapter_ids,\n            model_metas=model_metas,\n            state_offsets=state_offsets,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar_spec/sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom ..ar.sampling import ARSamplingStrategy\n\n\nclass ARSpecSamplingStrategy(ARSamplingStrategy):\n    \"\"\"Sampling strategy for AR models with speculative decoding.\"\"\"\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/ar_spec/sequence.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport time\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\nfrom torch import Tensor\n\nfrom lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest\nfrom lmdeploy.pytorch.engine.model_agent import BatchedOutputs\nfrom lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam,\n                                       SchedulerSession, UpdateTokenMode, _to_ndarray)\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..ar.sequence import ARSequenceStrategy, SchedulerSequenceDefault\n\nSeqList = List['SchedulerSequenceARSpec']\n\n\n@dataclass\nclass SchedulerSequenceARSpec(SchedulerSequenceDefault):\n\n    def __post_init__(self):\n        \"\"\"Post init.\"\"\"\n        super().__post_init__()\n        self._num_spec_ids: int = 0\n        self._num_new_valid: int = 0\n        self._num_valid_ids: int = len(self.history_cache)\n        self._strategy: ARSpecSequenceStrategy = self._seq_meta.strategy\n\n    @property\n    def num_valid_ids(self):\n        return self._num_valid_ids\n\n    @property\n    def num_spec_ids(self):\n        return self._num_spec_ids\n\n    @property\n    def generated_ids(self) -> np.ndarray:\n        end = self.num_valid_ids\n        start = end - self.num_new_tokens\n        return self.history_cache[start:end]\n\n    def set_stop_pos(self, pos: int):\n        val = self._num_new_valid - pos - 1\n        self._num_valid_ids -= val\n        self.num_new_tokens -= val\n        self._num_token_ids = 1\n        self._num_history_ids -= val\n\n        self._num_spec_ids = 0\n        self._num_new_valid = 0\n        self.history_cache.resize(self.num_valid_ids)\n\n    def _update_token_ids_inputs(self, token_ids: np.ndarray):\n        \"\"\"Append tokens.\"\"\"\n        num_tokens = len(token_ids)\n        self.output_start_pos = self.num_valid_ids + 
num_tokens\n        self._num_valid_ids = self.num_history_ids + num_tokens\n        self._num_token_ids = num_tokens\n        self.num_new_tokens = 0\n        self._num_spec_ids = 0\n        self._num_new_valid = 0\n        self.history_cache.append(token_ids)\n\n    def _update_token_ids_prefill(self, token_ids: np.ndarray, draft_token_ids: np.ndarray):\n        \"\"\"Update token ids for prefill.\"\"\"\n        num_valid = len(token_ids)\n        self._num_spec_ids = len(draft_token_ids)\n        token_ids = np.concatenate([token_ids, draft_token_ids])\n        num_tokens = len(token_ids)\n        self._num_history_ids += self._num_token_ids\n        self._num_token_ids = num_tokens\n        self.num_new_tokens += num_valid\n        self._num_new_valid = num_valid\n        self._num_valid_ids = self.num_history_ids + num_valid\n        self.history_cache.append(token_ids)\n\n    def _update_token_ids_decode(self, token_ids: np.ndarray, draft_token_ids: np.ndarray = None):\n        \"\"\"Update token ids for decode.\"\"\"\n        valid_ids = token_ids[token_ids > -1]\n        num_valid = len(valid_ids)\n        self.num_new_tokens = self.num_new_tokens + num_valid\n\n        self._num_new_valid = num_valid\n        self._num_valid_ids += num_valid\n        self._num_history_ids = self.num_valid_ids - 1\n\n        # last step has spec ids\n        if self.num_spec_ids > 0:\n            token_ids = valid_ids[-1:]\n        else:\n            token_ids = valid_ids\n\n        num_tokens = len(token_ids)\n\n        if draft_token_ids is not None:\n            num_tokens = 1 + len(draft_token_ids)\n            token_ids = np.concatenate([token_ids, draft_token_ids])\n            self._num_spec_ids = len(draft_token_ids)\n        else:\n            self._num_spec_ids = 0\n\n        self._num_token_ids = num_tokens\n        if self.num_history_ids < len(self.history_cache):\n            self.history_cache.resize(self.num_history_ids)\n        
self.history_cache.append(token_ids)\n\n    def update_token_ids(self,\n                         token_ids: Tensor,\n                         multimodals: MultiModalInputs = None,\n                         embeddings: List[InputEmbeddings] = None,\n                         model_meta: Dict[str, Any] = None,\n                         draft_token_ids: Tensor = None,\n                         mode: UpdateTokenMode = UpdateTokenMode.INPUTS,\n                         **kwargs):\n        \"\"\"Update token ids, old token ids will be added to history.\"\"\"\n        # update history image nums\n        self._update_embeddings(embeddings)\n\n        # update multimodals\n        self._update_multimodals(multimodals)\n\n        self.arrive_time = time.perf_counter()\n\n        token_ids: np.ndarray = _to_ndarray(token_ids)\n        if draft_token_ids is not None:\n            draft_token_ids = _to_ndarray(draft_token_ids)\n        if mode == UpdateTokenMode.INPUTS:\n            self._update_token_ids_inputs(token_ids)\n        elif mode == UpdateTokenMode.PREFILL:\n            self._update_token_ids_prefill(token_ids, draft_token_ids)\n        else:\n            self._update_token_ids_decode(token_ids, draft_token_ids)\n        if model_meta is not None:\n            self.model_meta = model_meta\n\n\nclass ARSpecSequenceStrategy(ARSequenceStrategy):\n\n    def make_sequence(self,\n                      seq_id: int,\n                      session: 'SchedulerSession',\n                      sampling_param: 'SamplingParam' = None,\n                      adapter_name: str = None,\n                      migration_request: Optional[MigrationRequest] = None,\n                      resp_cache: bool = False,\n                      preserve_cache: bool = False) -> 'SchedulerSequenceARSpec':\n        \"\"\"Make sequence.\"\"\"\n        return SchedulerSequenceARSpec(seq_id=seq_id,\n                                       session=session,\n                                       
sampling_param=sampling_param,\n                                       adapter_name=adapter_name,\n                                       migration_request=migration_request,\n                                       resp_cache=resp_cache,\n                                       preserve_cache=preserve_cache)\n\n    def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, model_inputs: 'ModelInputs',\n                       delta: 'ModelInputsDelta', **kwargs) -> None:\n        \"\"\"Update running sequences.\"\"\"\n        next_token_ids = batched_outputs.next_token_ids\n        extra_outputs = batched_outputs.extra_outputs\n        stopped = batched_outputs.stopped\n        stopped = stopped.tolist()\n        model_metas = batched_outputs.model_metas\n        if model_metas is None:\n            model_metas = [None] * len(running)\n        stop_pos = batched_outputs.stop_pos\n\n        if model_inputs is None:\n            is_decoding = delta.is_decoding\n        else:\n            is_decoding = model_inputs.is_decoding\n\n        batch_size = len(running)\n        next_token_ids = next_token_ids.view(batch_size, -1).numpy()\n        if extra_outputs is None or extra_outputs.draft_token_ids is None:\n            draft_token_ids = [None] * batch_size\n        else:\n            draft_token_ids = extra_outputs.draft_token_ids.numpy()\n        stop_pos = stop_pos.tolist()\n        update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL\n\n        for idx, token in enumerate(next_token_ids):\n            msg = running[idx]\n            stop = stopped[idx]\n            model_meta = model_metas[idx]\n            if msg.status != MessageStatus.RUNNING:\n                continue\n            cur_draft_tokens = draft_token_ids[idx]\n            # fill token\n            msg.update_token_ids(token, draft_token_ids=cur_draft_tokens, model_meta=model_meta, mode=update_mode)\n            if stop:\n                
msg.set_stop_pos(stop_pos[idx])\n                msg.state.finish()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/base/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\n\n    from .cudagraph import CudagraphStrategy\n    from .engine import EngineStrategy\n    from .model_agent import ModelAgentStrategy\n    from .model_inputs import ModelInputsStrategy\n    from .sampling import SamplingStrategy\n    from .sequence import SequenceStrategy\n\n\nclass StrategyFactoryBase(ABC):\n\n    @abstractmethod\n    def build_cudagraph_strategy(self) -> 'CudagraphStrategy':\n        \"\"\"Build cudagraph strategy.\"\"\"\n        pass\n\n    @abstractmethod\n    def build_sampling_strategy(self) -> 'SamplingStrategy':\n        \"\"\"Build sampling strategy.\"\"\"\n        pass\n\n    @abstractmethod\n    def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':\n        \"\"\"Build model inputs strategy.\"\"\"\n        pass\n\n    @abstractmethod\n    def build_model_agent_strategy(self) -> 'ModelAgentStrategy':\n        \"\"\"Build model agent strategy.\"\"\"\n        pass\n\n    @abstractmethod\n    def build_engine_strategy(self, cache_config: 'CacheConfig',\n                              scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':\n        \"\"\"Build engine strategy.\"\"\"\n        pass\n\n    @abstractmethod\n    def build_sequence_strategy(self) -> 'SequenceStrategy':\n        \"\"\"Build sequence strategy.\"\"\"\n        pass\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/base/cudagraph.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\n\nclass CudagraphStrategy(ABC):\n\n    @abstractmethod\n    def get_max_tokens(self, batch_size: int, origin_batch_size: int, num_tokens: int) -> int:\n        \"\"\"Get max tokens.\"\"\"\n        pass\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/base/engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\n\nclass EngineStrategy(ABC):\n    \"\"\"Engine strategy.\"\"\"\n\n    @abstractmethod\n    def get_prealloc_size(self, is_decoding: bool) -> int:\n        \"\"\"Get prealloc_size.\"\"\"\n        pass\n\n    @abstractmethod\n    def get_num_loops(self, is_decoding: bool) -> int:\n        \"\"\"Get num_loops.\"\"\"\n        pass\n\n    @abstractmethod\n    def get_num_decode_tokens(self) -> int:\n        \"\"\"Get num_decode_tokens.\"\"\"\n        pass\n\n    def get_num_required_tokens(self) -> int:\n        \"\"\"Get num_required_tokens.\"\"\"\n        return self.get_num_decode_tokens()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/base/model_agent.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass, fields\nfrom typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple\n\nimport numpy as np\nimport torch\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.distributed import DistContext\n    from lmdeploy.pytorch.engine.logits_process import SamplingInputs\n    from lmdeploy.pytorch.messages import SchedulerSequence\n    from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n    SeqList = List[SchedulerSequence]\n\n\ndef to_device(self, device: str, non_blocking: bool = False):\n    \"\"\"To device.\"\"\"\n    out_dict = dict()\n    for f in fields(self):\n        k = f.name\n        v = getattr(self, k)\n        if isinstance(v, torch.Tensor):\n            v = v.to(device, non_blocking=non_blocking)\n        out_dict[k] = v\n\n    return type(self)(**out_dict)\n\n\n@dataclass\nclass ExtraInputs(ABC):\n\n    def to_device(self, device: str, non_blocking: bool = False):\n        \"\"\"To device.\"\"\"\n        return to_device(self, device, non_blocking)\n\n    def broadcast(self, src: int, group, async_op=False):\n        \"\"\"Broadcast extra inputs.\"\"\"\n        pass\n\n    def merge(self, other: 'ExtraInputs'):\n        \"\"\"Merge extra inputs.\"\"\"\n        return self\n\n\n@dataclass\nclass ExtraOutputs(ABC):\n\n    def to_device(self, device: str, non_blocking: bool = False):\n        \"\"\"To device.\"\"\"\n        return to_device(self, device, non_blocking)\n\n    def to_cpu(self):\n        \"\"\"To cpu.\"\"\"\n        return self.to_device('cpu', non_blocking=False)\n\n    def to_numpy(self):\n        \"\"\"To numpy.\"\"\"\n        out = dict()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if isinstance(v, torch.Tensor) and v.dtype != torch.bfloat16:\n                v = v.detach().numpy()\n           
 elif hasattr(v, 'to_numpy'):\n                v = v.to_numpy()\n            out[k] = v\n        return type(self)(**out)\n\n    def to_tensor(self):\n        \"\"\"To tensor.\"\"\"\n        out = dict()\n        for f in fields(self):\n            k = f.name\n            v = getattr(self, k)\n            if isinstance(v, np.ndarray):\n                v = torch.from_numpy(v)\n            elif hasattr(v, 'to_tensor'):\n                v = v.to_tensor()\n            out[k] = v\n        return type(self)(**out)\n\n\n@dataclass\nclass StoppingCriteria(ABC):\n    \"\"\"Base class for stopping criteria.\"\"\"\n\n    @abstractmethod\n    def clone(self) -> 'StoppingCriteria':\n        \"\"\"clone.\"\"\"\n\n    @abstractmethod\n    def merge(self, other: 'StoppingCriteria') -> 'StoppingCriteria':\n        \"\"\"Merge two stopping criteria.\"\"\"\n\n    @abstractmethod\n    def update(self, delta: 'ModelInputsDelta') -> 'StoppingCriteria':\n        \"\"\"Update stopping criteria.\"\"\"\n\n    @abstractmethod\n    def step(self,\n             token_ids: torch.Tensor,\n             stop_words: torch.Tensor,\n             inputs: Optional['ModelInputs'] = None,\n             extra_inputs: Optional[ExtraInputs] = None):\n        \"\"\"Check whether to stop generation.\"\"\"\n        pass\n\n    def to_device(self, device: str, non_blocking: bool = False):\n        \"\"\"To device.\"\"\"\n        return to_device(self, device, non_blocking)\n\n\nclass ModelAgentStrategy(ABC):\n    \"\"\"Base class for model agent strategies.\"\"\"\n\n    @abstractmethod\n    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:\n        \"\"\"Slice outputs.\"\"\"\n        pass\n\n    @abstractmethod\n    def slice_extra_inputs(self, extra_inputs: ExtraInputs, model_inputs: 'ModelInputs',\n                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> ExtraInputs:\n        \"\"\"Slice outputs.\"\"\"\n        pass\n\n    @abstractmethod\n    
def make_stopping_criteria(self, seqs: 'SeqList') -> StoppingCriteria:\n        \"\"\"Create stopping criteria.\"\"\"\n        pass\n\n    @abstractmethod\n    def make_extra_inputs(self, seqs: 'SeqList', model_inputs: 'ModelInputs') -> ExtraInputs:\n        \"\"\"Create extra inputs.\"\"\"\n        pass\n\n    def update_extra_inputs(self, extra_inputs: ExtraInputs, delta: 'ModelInputsDelta') -> ExtraInputs:\n        \"\"\"Update extra inputs with model inputs delta.\"\"\"\n        return extra_inputs\n\n    @abstractmethod\n    def make_extra_outputs(self, extra_inputs: ExtraInputs) -> ExtraOutputs:\n        \"\"\"Create extra outputs.\"\"\"\n        pass\n\n    @abstractmethod\n    def step_sampling_inputs(\n        self,\n        sampling_inputs: 'SamplingInputs',\n        next_token_ids: torch.Tensor,\n        extra_inputs: ExtraInputs,\n    ):\n        \"\"\"step.\"\"\"\n        pass\n\n    @abstractmethod\n    def update_prefill_for_next_step(\n        self,\n        model_inputs: 'ModelInputs',\n        extra_inputs: ExtraInputs,\n        next_token_ids: torch.Tensor,\n        model_metas: Any,\n        extra_outputs: ExtraOutputs,\n    ) -> Tuple['ModelInputs', ExtraInputs]:\n        \"\"\"Step next decoding.\"\"\"\n        pass\n\n    @abstractmethod\n    def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,\n                                      extra_inputs: ExtraInputs,\n                                      extra_outputs: ExtraOutputs) -> Tuple['ModelInputs', ExtraInputs]:\n        \"\"\"Step next inputs.\"\"\"\n        pass\n\n    @abstractmethod\n    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,\n                      extra_inputs: ExtraInputs):\n        \"\"\"Post sampling.\"\"\"\n        pass\n\n    def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: ExtraInputs):\n        \"\"\"Make dummy 
next token for broadcast.\"\"\"\n        with torch.inference_mode():\n            next_token_ids = inputs.input_ids.new_zeros(logits.size(0))\n        return next_token_ids, extra_inputs\n\n    @abstractmethod\n    @contextmanager\n    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: 'DistContext'):\n        \"\"\"Broadcast next token ids and extra inputs.\"\"\"\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/base/model_inputs.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\n\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\n\n@record_function('make_dummy_input')\ndef make_dummy_inputs(batch_size: int,\n                      max_q_seqlen: int,\n                      is_decoding: bool,\n                      device: str = 'cpu',\n                      dummy_block_id: int = 0,\n                      vocab_size: int = 1):\n    \"\"\"Make dummy inputs global implement.\"\"\"\n    num_tokens = batch_size * max_q_seqlen\n    max_kv_seqlen = max_q_seqlen\n    input_ids = torch.randint(0, vocab_size, (\n        1,\n        num_tokens,\n    ), dtype=torch.long, device=device)\n    seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long, device=device)\n    history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device)\n    block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device)\n    num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device)\n    local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device)\n    state_offsets = torch.full((batch_size, ), -1, dtype=torch.long, device=device)\n\n    return ModelInputs(\n        input_ids=input_ids,\n        seq_length=seq_length,\n        history_lengths=history_lengths,\n        block_offsets=block_offsets,\n        is_decoding=is_decoding,\n        num_ignored_history=num_ignored_history,\n        max_q_seqlen=max_q_seqlen,\n        max_kv_seqlen=max_kv_seqlen,\n        sum_kv_seqlen=num_tokens,\n        local_adapter_ids=local_adapter_ids,\n        is_dummy=True,\n        state_offsets=state_offsets,\n    )\n\n\nclass ModelInputsStrategy(ABC):\n\n    @abstractmethod\n    def make_dummy(self,\n                   batch_size: int,\n                   is_decoding: bool,\n                   device: str = 
'cpu',\n                   dummy_block_id: int = 0,\n                   vocab_size: int = 1) -> ModelInputs:\n        \"\"\"Create dummy model inputs.\"\"\"\n        pass\n\n    @abstractmethod\n    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:\n        \"\"\"Merge model inputs.\"\"\"\n        pass\n\n    @abstractmethod\n    def update_inputs(self, inputs: ModelInputs, delta: 'ModelInputsDelta') -> ModelInputs:\n        \"\"\"Update model inputs with delta.\"\"\"\n        pass\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/base/sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import List\n\nimport torch\n\nfrom lmdeploy.pytorch.engine.logits_process import SamplingInputs, SamplingInputsDelta\nfrom lmdeploy.pytorch.messages import SchedulerSequence\nfrom lmdeploy.pytorch.model_inputs import ModelInputsDelta\n\nfrom .model_agent import ExtraInputs\n\nSeqList = List[SchedulerSequence]\n\n\nclass SamplingStrategy(ABC):\n    \"\"\"Base class for sampling strategies.\"\"\"\n\n    @abstractmethod\n    def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:\n        \"\"\"Create sampling inputs from the sequences.\"\"\"\n        pass\n\n    @abstractmethod\n    def on_session_end(self, session_id: int) -> None:\n        \"\"\"Invoked when a session ends.\"\"\"\n        pass\n\n    @abstractmethod\n    def merge_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        other: 'SamplingInputsDelta',\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Merge two sampling deltas.\"\"\"\n\n    @abstractmethod\n    def step_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        next_token_ids: torch.Tensor,\n        extra_inputs: 'ExtraInputs',\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Step next delta.\"\"\"\n        pass\n\n    @abstractmethod\n    def update_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        delta: 'ModelInputsDelta',\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Update sampling delta with model inputs delta.\"\"\"\n        pass\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/base/sequence.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import TYPE_CHECKING, List, Optional\n\nfrom lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.engine.model_agent import BatchedOutputs\n    from lmdeploy.pytorch.messages import SamplingParam, SchedulerSequence, SchedulerSession\n    from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n    SeqList = List[SchedulerSequence]\n\n\nclass SequenceStrategy(ABC):\n\n    @abstractmethod\n    def make_sequence(self,\n                      seq_id: int,\n                      session: 'SchedulerSession',\n                      sampling_param: 'SamplingParam' = None,\n                      adapter_name: str = None,\n                      migration_request: Optional[MigrationRequest] = None,\n                      resp_cache: bool = False,\n                      preserve_cache: bool = False) -> 'SchedulerSequence':\n        \"\"\"Make sequence.\"\"\"\n        pass\n\n    @abstractmethod\n    def update_running(self, running: 'SeqList', batched_outputs: 'BatchedOutputs', model_inputs: 'ModelInputs',\n                       delta: 'ModelInputsDelta') -> None:\n        \"\"\"Update running sequences.\"\"\"\n        pass\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/dllm/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import TYPE_CHECKING\n\nfrom lmdeploy.pytorch.config import DLLMConfig, ModelConfig\nfrom lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy\nfrom lmdeploy.utils import get_logger\n\nif TYPE_CHECKING:\n    from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy\n    from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy\n    from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy\n    from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy\n    from lmdeploy.pytorch.strategies.base.engine import EngineStrategy\n    from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\n\nfrom ..base import StrategyFactoryBase\n\nlogger = get_logger('lmdeploy')\n\n\nclass DLLMStrategyFactory(StrategyFactoryBase):\n\n    def __init__(self, model_config: ModelConfig, dllm_config: DLLMConfig):\n        \"\"\"config.\"\"\"\n        self.model_config = model_config\n        self.dllm_config = dllm_config\n\n        # update dllm_block_length\n        self.dllm_block_length = self._update_dllm_block_length()\n\n    def _update_dllm_block_length(self):\n        \"\"\"Update dllm_block_length.\"\"\"\n        if self.dllm_config.block_length is None:\n            dllm_block_length = self.model_config.dllm_block_length\n            if dllm_block_length is None:\n                dllm_block_length = 4\n                logger.warning('Model does not provide dllm_block_length. 
'\n                               f'Set dllm_block_length={dllm_block_length} as default.')\n        else:\n            dllm_block_length = self.dllm_config.block_length\n\n        assert dllm_block_length is not None, 'dllm_block_length should be set in model_config or dllm_config'\n\n        self.dllm_config.block_length = dllm_block_length\n        self.model_config.dllm_block_length = dllm_block_length\n\n        if self.dllm_config.denoising_steps is None:\n            self.dllm_config.denoising_steps = dllm_block_length\n        return dllm_block_length\n\n    def build_cudagraph_strategy(self) -> 'CudagraphStrategy':\n        \"\"\"Build cudagraph strategy.\"\"\"\n        from .cudagraph import DLLMCudagraphStrategy\n        return DLLMCudagraphStrategy(block_size=self.dllm_block_length)\n\n    def build_sampling_strategy(self) -> 'SamplingStrategy':\n        \"\"\"Build sampling strategy.\"\"\"\n        from .sampling import DLLMSamplingStrategy\n        pad_token_id = self.model_config.bos_token_id\n        pad_token_id = 0 if pad_token_id is None else pad_token_id\n        return DLLMSamplingStrategy(pad_token_id, self.dllm_block_length)\n\n    def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':\n        \"\"\"Build model inputs strategy.\"\"\"\n        from .model_inputs import DLLMModelInputsStrategy\n        return DLLMModelInputsStrategy(block_size=self.dllm_block_length)\n\n    def build_model_agent_strategy(self) -> 'ModelAgentStrategy':\n        \"\"\"Build model agent strategy.\"\"\"\n        from .model_agent import DLLMModelAgentStrategy\n        return DLLMModelAgentStrategy(dllm_config=self.dllm_config, dllm_mask_token=self.model_config.dllm_mask_token)\n\n    def build_engine_strategy(self, cache_config: 'CacheConfig',\n                              scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':\n        \"\"\"Build engine strategy.\"\"\"\n        from .engine import DLLMEngineStrategy\n        return 
DLLMEngineStrategy(cache_config=cache_config,\n                                  scheduler_config=scheduler_config,\n                                  dllm_block_length=self.dllm_block_length)\n\n    def build_sequence_strategy(self) -> SequenceStrategy:\n        from .sequence import DLLMSequenceStrategy\n        return DLLMSequenceStrategy(block_size=self.dllm_block_length,\n                                    dllm_mask_token=self.model_config.dllm_mask_token)\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/dllm/cudagraph.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom ..base.cudagraph import CudagraphStrategy\n\n\nclass DLLMCudagraphStrategy(CudagraphStrategy):\n\n    def __init__(self, block_size: int) -> None:\n        super().__init__()\n        self.block_size = block_size\n\n    def get_max_tokens(self, batch_size: int, origin_batch_size: int, num_tokens: int) -> int:\n        \"\"\"Get max tokens.\"\"\"\n        return batch_size * self.block_size\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/dllm/engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom functools import lru_cache\n\nfrom lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\nfrom lmdeploy.utils import get_logger\n\nfrom ..base.engine import EngineStrategy\n\nlogger = get_logger('lmdeploy')\n\n\nclass DLLMEngineStrategy(EngineStrategy):\n    \"\"\"DLLM Engine Strategy.\"\"\"\n\n    def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, dllm_block_length: int) -> None:\n        self.scheduler_config = scheduler_config\n        self.cache_config = cache_config\n        self.dllm_block_length = dllm_block_length\n\n        self._check()\n\n    def _check(self):\n        \"\"\"check.\"\"\"\n        max_prefill_token_num = self.cache_config.max_prefill_token_num\n        max_batches = self.cache_config.max_batches\n        if self.dllm_block_length * max_batches > max_prefill_token_num:\n            logger.warning(f'dllm_block_length({self.dllm_block_length}) * max_batch_size ({max_batches}) '\n                           f'> max_prefill_token_num ({max_prefill_token_num}). '\n                           'This may lead to OOM. 
Consider reducing max_batch_size or dllm_block_length.')\n\n    @lru_cache(maxsize=2)\n    def get_prealloc_size(self, is_decoding: bool) -> int:\n        \"\"\"Get prealloc_size.\"\"\"\n        if not is_decoding:\n            return 0\n        block_size = self.cache_config.block_size\n        dllm_block_length = self.dllm_block_length\n        num_blocks = min(self.scheduler_config.prefill_interval // 2, block_size // dllm_block_length)\n        return num_blocks * dllm_block_length\n\n    @lru_cache(maxsize=2)\n    def get_num_loops(self, is_decoding: bool) -> int:\n        \"\"\"Get num_loops.\"\"\"\n        if not is_decoding:\n            return 1\n        block_size = self.cache_config.block_size\n        dllm_block_length = self.dllm_block_length\n        max_num_loops = block_size // dllm_block_length * 2\n        num_loops = min(self.scheduler_config.prefill_interval, max_num_loops)\n        return num_loops\n\n    def get_num_decode_tokens(self) -> int:\n        \"\"\"Get num_decode_tokens.\"\"\"\n        return self.dllm_block_length\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/dllm/model_agent.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch import consts\nfrom lmdeploy.pytorch.config import DLLMConfig\nfrom lmdeploy.pytorch.distributed import DistContext\nfrom lmdeploy.pytorch.engine.logits_process import SamplingInputs\nfrom lmdeploy.pytorch.messages import SchedulerSequence\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria\nfrom .unmasking import UnmaskingProcessor\n\nSeqList = List[SchedulerSequence]\n\n\ndef get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, max_q_seqlen,\n                                   step_seqlens: torch.Tensor, model_metas) -> ModelInputs:\n    \"\"\"Next decoding step.\"\"\"\n    if input_ids.dim() == 1:\n        input_ids = input_ids[None, :]\n    step_seqlens = torch.where(step_seqlens > 0, step_seqlens, inputs.seq_length - max_q_seqlen)\n    return ModelInputs(\n        input_ids=input_ids,\n        seq_length=torch.full_like(inputs.seq_length, max_q_seqlen),\n        history_lengths=inputs.history_lengths + step_seqlens,\n        block_offsets=inputs.block_offsets,\n        is_decoding=True,\n        num_ignored_history=inputs.num_ignored_history,\n        max_q_seqlen=max_q_seqlen,\n        max_kv_seqlen=inputs.max_kv_seqlen + max_q_seqlen,\n        sum_kv_seqlen=inputs.sum_kv_seqlen + inputs.seq_length.numel() * inputs.max_q_seqlen,\n        local_adapter_ids=inputs.local_adapter_ids,\n        model_metas=model_metas,\n        state_offsets=inputs.state_offsets,\n    )\n\n\n@dataclass\nclass DLLMExtraInputs(ExtraInputs):\n    \"\"\"DLLM extra inputs.\"\"\"\n    dllm_mask: torch.Tensor\n\n    def 
broadcast(self, src: int, group, async_op=False):\n        return dist.broadcast(self.dllm_mask, src=src, group=group, async_op=async_op)\n\n    def merge(self, other: 'DLLMExtraInputs'):\n        \"\"\"Merge extra inputs.\"\"\"\n        dllm_mask = torch.cat([self.dllm_mask, other.dllm_mask], dim=0)\n        return DLLMExtraInputs(dllm_mask=dllm_mask)\n\n\n@dataclass\nclass DLLMExtraOutputs(ExtraOutputs):\n    \"\"\"DLLM extra outputs.\"\"\"\n    dllm_mask: torch.Tensor\n\n\ndef _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor,\n                          stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor,\n                          output_start_pos: torch.Tensor, inputs: ModelInputs):\n    num_tokens = token_ids.size(0)\n    batch_size = num_appendable_ids.size(0)\n    block_size = num_tokens // batch_size\n\n    # blocks might contain stop words in prev-round chat\n    # these stop words should be ignored\n    kv_seqlens = inputs.history_lengths + inputs.seq_length\n    ignore_pos = (output_start_pos - (kv_seqlens - block_size)).clamp_min(0)\n    ignore_range = torch.arange(0, block_size, dtype=ignore_pos.dtype, device=ignore_pos.device)\n    ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten()\n    token_ids = token_ids.clone()\n    token_ids[ignore_mask] = -1\n\n    # find stop words\n    sw_stopped = (token_ids[:, None] == stop_words).any(1)\n    sw_stopped = sw_stopped.view(batch_size, block_size)\n    sw_stop_pos = sw_stopped.int().argmax(1)\n\n    stop_pos = torch.where(stopped, stop_pos, sw_stop_pos)\n    sw_stopped = sw_stopped.any(dim=1)\n    sw_stopped = sw_stopped & is_unmasked\n    stopped = stopped | sw_stopped\n\n    # update num_appendable_ids\n    one_ids = torch.clamp_max(num_appendable_ids, 0)\n    num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids)\n\n    return stopped, stop_pos, num_appendable_ids\n\n\n@dataclass\nclass 
DLLMStoppingCriteria(StoppingCriteria):\n    num_appendable_ids: torch.Tensor\n    output_start_pos: torch.Tensor\n\n    def clone(self) -> 'DLLMStoppingCriteria':\n        \"\"\"clone.\"\"\"\n        return DLLMStoppingCriteria(num_appendable_ids=self.num_appendable_ids, output_start_pos=self.output_start_pos)\n\n    def merge(self, other: 'DLLMStoppingCriteria') -> 'DLLMStoppingCriteria':\n        \"\"\"Merge two stopping criteria.\"\"\"\n        return DLLMStoppingCriteria(num_appendable_ids=torch.cat([self.num_appendable_ids, other.num_appendable_ids],\n                                                                 dim=0),\n                                    output_start_pos=torch.cat([self.output_start_pos, other.output_start_pos], dim=0))\n\n    def update(self, delta: 'ModelInputsDelta') -> 'DLLMStoppingCriteria':\n        \"\"\"Update stopping criteria.\"\"\"\n        indices = delta.indices\n        return DLLMStoppingCriteria(num_appendable_ids=self.num_appendable_ids[indices],\n                                    output_start_pos=self.output_start_pos[indices])\n\n    @record_function('stopping_criteria')\n    def step(self,\n             token_ids: torch.Tensor,\n             stop_words: torch.Tensor,\n             inputs: Optional[ModelInputs] = None,\n             extra_inputs: Optional[DLLMExtraInputs] = None):\n        \"\"\"Check whether to stop generation.\"\"\"\n        num_appendable_ids = self.num_appendable_ids\n        output_start_pos = self.output_start_pos\n        num_tokens = token_ids.size(0)\n        batch_size = num_appendable_ids.size(0)\n        block_size = num_tokens // batch_size\n\n        dllm_mask = extra_inputs.dllm_mask\n        dllm_mask = dllm_mask.view(batch_size, block_size)\n        is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1)\n\n        # check stop by num_new_tokens\n        num_appendable_ids -= is_unmasked * block_size\n        stopped = num_appendable_ids <= 0\n        stop_pos = block_size - 1 + 
num_appendable_ids\n\n        # check stop words\n        if stop_words is not None:\n            stopped, stop_pos, num_appendable_ids = _check_stopwords_dllm(token_ids,\n                                                                          stop_words,\n                                                                          is_unmasked,\n                                                                          stopped,\n                                                                          stop_pos,\n                                                                          num_appendable_ids,\n                                                                          output_start_pos=output_start_pos,\n                                                                          inputs=inputs)\n\n        new_stopping = DLLMStoppingCriteria(num_appendable_ids=num_appendable_ids, output_start_pos=output_start_pos)\n        return stopped, stop_pos, new_stopping\n\n\nclass DLLMModelAgentStrategy(ModelAgentStrategy):\n\n    def __init__(self, dllm_config: DLLMConfig, dllm_mask_token: int):\n        block_size = dllm_config.block_length\n        self.block_size = block_size\n        self.dllm_mask_token = dllm_mask_token\n\n        self.unmasking_processor = UnmaskingProcessor(dllm_config=dllm_config)\n\n    def _update_dllm(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens: torch.Tensor):\n        \"\"\"Update token_ids and dllm_mask.\"\"\"\n        dllm_mask_token = self.dllm_mask_token\n        dllm_block_length = self.block_size\n\n        # reshape to (batch, dllm_block_length)\n        next_token_ids = next_token_ids.view(-1, dllm_block_length).clone()\n        dllm_mask = dllm_mask.view(-1, dllm_block_length).clone()\n\n        # flags\n        is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1)\n\n        is_masked = (dllm_mask == consts.DLLM_MASKED)\n        next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token\n        
dllm_mask[is_cached] = consts.DLLM_MASKED\n        seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, )))\n\n        return next_token_ids.flatten(), dllm_mask.flatten(), seqlens\n\n    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:\n        \"\"\"Slice outputs.\"\"\"\n        block_length = self.block_size\n        # batch size = 1\n        if len(seq_length) == 1:\n            return inputs[-block_length:]\n\n        if len(seq_length) * block_length == inputs.size(0):\n            return inputs\n        last_idx = seq_length.cumsum(0)\n        block_range = torch.arange(-block_length, 0, device=last_idx.device)\n        index = (last_idx[:, None] + block_range[None, :]).flatten()\n        inputs = inputs[index]\n        return inputs\n\n    def slice_extra_inputs(self, extra_inputs: DLLMExtraInputs, model_inputs: ModelInputs,\n                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> DLLMExtraInputs:\n        \"\"\"Slice outputs.\"\"\"\n        dllm_mask = self.slice_outputs(extra_inputs.dllm_mask, model_inputs.seq_length)\n        return DLLMExtraInputs(dllm_mask=dllm_mask)\n\n    def step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor,\n                             extra_inputs: DLLMExtraInputs, **kwargs):\n        \"\"\"Step sampling inputs.\"\"\"\n        from lmdeploy.pytorch import consts\n        dllm_mask = extra_inputs.dllm_mask\n        dllm_block_size = self.block_size\n        DLLM_UNMASKED = consts.DLLM_UNMASKED\n        is_unmasked = (dllm_mask == DLLM_UNMASKED).view(-1, dllm_block_size).all(dim=1, keepdim=True)\n        num_ignore_eos = sampling_inputs.num_ignore_eos.view(-1, dllm_block_size)\n        num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos)\n        sampling_inputs.num_ignore_eos = num_ignore_eos.flatten()\n        if sampling_inputs.random_offsets is not None:\n          
  # random offset is used to generate random numbers for multinomial sampling\n            # so we need to increase it by 1 at each step\n            sampling_inputs.random_offsets += 1\n        return sampling_inputs\n\n    def make_stopping_criteria(self, seqs: SeqList) -> DLLMStoppingCriteria:\n        \"\"\"Create stopping criteria.\"\"\"\n        # num_appendable\n        num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]\n        num_appendable = torch.tensor(num_appendable)\n        block_size = self.block_size\n        remain = [seq.num_valid_ids % block_size for seq in seqs]\n        num_appendable += torch.tensor(remain)\n\n        # output_start_pos\n        pos = [seq.output_start_pos for seq in seqs]\n        output_start_pos = torch.tensor(pos)\n\n        return DLLMStoppingCriteria(num_appendable_ids=num_appendable, output_start_pos=output_start_pos)\n\n    def make_extra_inputs(self, seqs: 'SeqList', model_inputs: 'ModelInputs') -> ExtraInputs:\n        \"\"\"Create extra inputs.\"\"\"\n        dllm_masks = [seq.dllm_mask for seq in seqs]\n\n        # chunked prefill only require part of the dllm masks\n        if model_inputs.is_chunk:\n            seqlens = model_inputs.seq_length.tolist()\n            dllm_masks = [mask[:length] for mask, length in zip(dllm_masks, seqlens)]\n\n        dllm_masks = torch.as_tensor(np.concatenate(dllm_masks))\n        return DLLMExtraInputs(dllm_mask=dllm_masks)\n\n    def update_extra_inputs(self, extra_inputs: DLLMExtraInputs, delta: 'ModelInputsDelta') -> DLLMExtraInputs:\n        \"\"\"Update extra inputs with model inputs delta.\"\"\"\n        dllm_mask = extra_inputs.dllm_mask\n        dllm_mask = dllm_mask.reshape(-1, self.block_size)\n\n        indices = delta.indices\n        dllm_mask = dllm_mask[indices].flatten()\n\n        return DLLMExtraInputs(dllm_mask=dllm_mask)\n\n    def make_extra_outputs(self, extra_inputs: DLLMExtraInputs) -> DLLMExtraOutputs:\n        
\"\"\"Create extra outputs.\"\"\"\n        dllm_mask = extra_inputs.dllm_mask\n        return DLLMExtraOutputs(dllm_mask=dllm_mask)\n\n    def update_prefill_for_next_step(\n        self,\n        model_inputs: 'ModelInputs',\n        extra_inputs: DLLMExtraInputs,\n        next_token_ids: torch.Tensor,\n        model_metas: Any,\n        extra_outputs: DLLMExtraOutputs,\n    ) -> Tuple['ModelInputs', DLLMExtraInputs]:\n        \"\"\"Step next decoding.\"\"\"\n        dllm_mask = extra_outputs.dllm_mask\n        next_token_ids, dllm_mask, step_seqlens = self._update_dllm(next_token_ids, dllm_mask, model_inputs.seq_length)\n\n        inputs = get_model_inputs_next_decoding(model_inputs,\n                                                next_token_ids,\n                                                model_metas=model_metas,\n                                                max_q_seqlen=self.block_size,\n                                                step_seqlens=step_seqlens)\n        extra_inputs = DLLMExtraInputs(dllm_mask=dllm_mask)\n        return inputs, extra_inputs\n\n    def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,\n                                      extra_inputs: DLLMExtraInputs, **kwargs):\n        \"\"\"Step next inputs.\"\"\"\n        model_inputs.model_metas = model_metas\n        dllm_mask = extra_inputs.dllm_mask\n\n        next_token_ids, dllm_mask, step_seqlens = self._update_dllm(next_token_ids, dllm_mask, model_inputs.seq_length)\n        model_inputs.step(next_token_ids, step_seqlens)\n\n        extra_inputs = DLLMExtraInputs(dllm_mask=dllm_mask)\n        return model_inputs, extra_inputs\n\n    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,\n                      extra_inputs: DLLMExtraInputs):\n        \"\"\"Post sampling.\"\"\"\n        dllm_mask = extra_inputs.dllm_mask\n        input_ids = inputs.input_ids\n       
 input_ids = self.slice_outputs(input_ids.flatten(), inputs.seq_length)\n\n        dllm_mask, next_token_ids = self.unmasking_processor(logits, input_ids, next_token_ids, dllm_mask)\n\n        extra_inputs.dllm_mask = dllm_mask\n        return next_token_ids, extra_inputs\n\n    def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: DLLMExtraInputs):\n        \"\"\"Make dummy next token for broadcast.\"\"\"\n        with torch.inference_mode():\n            next_token_ids = inputs.input_ids.new_zeros(logits.size(0))\n        return next_token_ids, extra_inputs\n\n    @contextmanager\n    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: DLLMExtraInputs, dist_ctx: DistContext):\n        \"\"\"Broadcast next token ids and extra inputs.\"\"\"\n        tp_gpu_group = dist_ctx.attn_tp_group.gpu_group\n        rank = dist.get_global_rank(tp_gpu_group, 0)\n        dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True)\n        handle = extra_inputs.broadcast(src=rank, group=tp_gpu_group, async_op=True)\n        yield\n        handle.wait()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/dllm/model_inputs.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..ar.model_inputs import merge_model_inputs\nfrom ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs\n\n\nclass DLLMModelInputsStrategy(ModelInputsStrategy):\n\n    def __init__(self, block_size: int):\n        self.block_size = block_size\n\n    def make_dummy(self,\n                   batch_size: int,\n                   is_decoding: bool,\n                   device: str = 'cpu',\n                   dummy_block_id: int = 0,\n                   vocab_size: int = 1) -> ModelInputs:\n        \"\"\"Create dummy model inputs.\"\"\"\n        return make_dummy_inputs(batch_size,\n                                 max_q_seqlen=self.block_size,\n                                 is_decoding=is_decoding,\n                                 device=device,\n                                 dummy_block_id=dummy_block_id,\n                                 vocab_size=vocab_size)\n\n    def merge(self, inputs: ModelInputs, other: ModelInputs) -> ModelInputs:\n        \"\"\"Merge model inputs.\"\"\"\n        return merge_model_inputs(inputs, other)\n\n    def update_inputs(self, inputs: ModelInputs, delta: 'ModelInputsDelta') -> ModelInputs:\n        \"\"\"Update model inputs with delta.\"\"\"\n\n        assert inputs.is_decoding, 'Only support index_select in decoding.'\n        indices = delta.indices\n        indice_cpu = delta.indice_cpu\n        block_offsets = delta.block_offsets\n        max_q_seqlen = delta.max_q_seqlen\n        max_kv_seqlen = delta.max_kv_seqlen\n        sum_kv_seqlen = delta.sum_kv_seqlen\n        num_ignored_history = delta.num_ignored_history\n\n        # required inputs\n        # input_ids = inputs.input_ids[..., indices]\n        inputs_ids = inputs.input_ids.reshape(1, -1, self.block_size)\n        input_ids = inputs_ids[:, indices].reshape(1, -1)\n        seq_length = 
inputs.seq_length[indices]\n        history_lengths = inputs.history_lengths[indices]\n        if block_offsets is None:\n            block_offsets = inputs.block_offsets[indices]\n        if num_ignored_history is None:\n            num_ignored_history = inputs.num_ignored_history[indices]\n        max_q_seqlen = max_q_seqlen or inputs.max_q_seqlen\n        max_kv_seqlen = max_kv_seqlen or inputs.max_kv_seqlen\n        sum_kv_seqlen = sum_kv_seqlen or inputs.sum_kv_seqlen\n\n        # lora adapter ids\n        local_adapter_ids = inputs.local_adapter_ids\n        if local_adapter_ids is not None:\n            local_adapter_ids = local_adapter_ids[indices]\n\n        # model metas for vl models\n        model_metas = inputs.model_metas\n        if model_metas is not None and indice_cpu is not None:\n            model_metas = [model_metas[i] for i in indice_cpu]\n\n        # for ssm\n        state_offsets = inputs.state_offsets\n        if state_offsets is not None:\n            state_offsets = state_offsets[indices]\n\n        # return new inputs\n        return ModelInputs(\n            input_ids=input_ids,\n            seq_length=seq_length,\n            history_lengths=history_lengths,\n            block_offsets=block_offsets,\n            is_decoding=inputs.is_decoding,\n            num_ignored_history=num_ignored_history,\n            max_q_seqlen=max_q_seqlen,\n            max_kv_seqlen=max_kv_seqlen,\n            sum_kv_seqlen=sum_kv_seqlen,\n            local_adapter_ids=local_adapter_ids,\n            model_metas=model_metas,\n            state_offsets=state_offsets,\n        )\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/dllm/sampling.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List\n\nimport numpy as np\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch.engine.logits_process import SamplingInputs, SamplingInputsDelta\nfrom lmdeploy.pytorch.messages import SchedulerSequence\nfrom lmdeploy.pytorch.model_inputs import ModelInputsDelta\n\nfrom ..ar.sampling import ARSamplingStrategy\nfrom .model_agent import DLLMExtraInputs\n\nSeqList = List[SchedulerSequence]\n\n\nclass DLLMSamplingStrategy(ARSamplingStrategy):\n    \"\"\"Sampling strategy for autoregressive models.\"\"\"\n\n    def __init__(self, pad_token_id: int, dllm_block_length: int) -> None:\n        super().__init__(pad_token_id)\n        self.dllm_block_length = dllm_block_length\n\n    @record_function('make_sampling_inputs')\n    def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:\n        \"\"\"Create sampling inputs from the sequences.\"\"\"\n        out = super().make_sampling_inputs(seqs)\n        dllm_block_length = self.dllm_block_length\n\n        # repeat tensor\n        update_attr_names = [\n            'temperature',\n            'bad_words',\n            'bad_mask',\n            'stop_words',\n            'stop_mask',\n            'repetition_penalty',\n            'top_k',\n            'top_p',\n            'min_p',\n            'random_seeds',\n            'random_offsets',\n            'all_ids',\n            'num_ignore_eos',\n            'ngram_size',\n            'ngram_threshold',\n        ]\n        for name in update_attr_names:\n            attr = getattr(out, name)\n            if attr is None:\n                continue\n            if attr.dim() == 1:\n                repeats = (dllm_block_length, 1)\n                attr = attr[None].repeat(*repeats).flatten(0, 1)\n            elif attr.dim() == 2:\n                repeats = (1, dllm_block_length, 1)\n                attr = attr[:, None].repeat(*repeats).flatten(0, 1)\n          
  else:\n                repeats = (dllm_block_length, ) + (1, ) * (attr.dim())\n                attr = attr[None].repeat(*repeats).flatten(0, 1)\n            setattr(out, name, attr)\n\n        # update generated_ids_cpu\n        if out.generated_ids_cpu is not None:\n            generated_ids_cpu = out.generated_ids_cpu\n            if generated_ids_cpu.shape[1] == 0:\n                out.generated_ids_cpu = np.repeat(generated_ids_cpu, dllm_block_length, axis=0)\n            else:\n                generated_ids_cpu = np.repeat(generated_ids_cpu[:, None], dllm_block_length, axis=1)\n                generated_ids_cpu = np.reshape(generated_ids_cpu, (-1, generated_ids_cpu.shape[-1]))\n                out.generated_ids_cpu = generated_ids_cpu\n\n        if len(out.response_formats) > 0:\n            new_resp_formats = []\n            for resp in out.response_formats:\n                new_resp_formats += [resp] * dllm_block_length\n            out.response_formats = tuple(new_resp_formats)\n\n        out.batch_size *= dllm_block_length\n\n        return out\n\n    def merge_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        other: 'SamplingInputsDelta',\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Merge two sampling deltas.\"\"\"\n        num_ignore_eos = torch.cat([sampling_delta.num_ignore_eos, other.num_ignore_eos], 0)\n        random_offsets = torch.cat([sampling_delta.random_offsets, other.random_offsets], 0)\n\n        return SamplingInputsDelta(\n            num_ignore_eos=num_ignore_eos,\n            random_offsets=random_offsets,\n            all_ids=None,\n        )\n\n    def update_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        delta: 'ModelInputsDelta',\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Update sampling delta with model inputs delta.\"\"\"\n        indices = delta.indices\n        num_ignore_eos = sampling_delta.num_ignore_eos.view(-1, self.dllm_block_length)\n  
      num_ignore_eos = num_ignore_eos[indices].flatten()\n        if sampling_delta.random_offsets is not None:\n            random_offsets = sampling_delta.random_offsets.view(-1, self.dllm_block_length)\n            random_offsets = random_offsets[indices].flatten()\n        else:\n            random_offsets = None\n        return SamplingInputsDelta(\n            num_ignore_eos=num_ignore_eos,\n            random_offsets=random_offsets,\n            all_ids=None,\n        )\n\n    def step_sampling_delta(\n        self,\n        sampling_delta: 'SamplingInputsDelta',\n        next_token_ids: torch.Tensor,\n        extra_inputs: 'DLLMExtraInputs',\n    ) -> 'SamplingInputsDelta':\n        \"\"\"Step next delta.\"\"\"\n        from lmdeploy.pytorch import consts\n        dllm_mask = extra_inputs.dllm_mask\n        dllm_block_size = self.dllm_block_length\n        DLLM_UNMASKED = consts.DLLM_UNMASKED\n        is_unmasked = (dllm_mask == DLLM_UNMASKED).view(-1, dllm_block_size).all(dim=1, keepdim=True)\n        num_ignore_eos = sampling_delta.num_ignore_eos.view(-1, dllm_block_size)\n        num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos)\n        sampling_delta.num_ignore_eos = num_ignore_eos.flatten()\n        if sampling_delta.random_offsets is not None:\n            # random offset is used to generate random numbers for multinomial sampling\n            # so we need to increase it by 1 at each step\n            sampling_delta.random_offsets += 1\n        return sampling_delta\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/dllm/sequence.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport time\nfrom dataclasses import dataclass, field\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\nfrom torch import Tensor\n\nfrom lmdeploy.pytorch import consts\nfrom lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest\nfrom lmdeploy.pytorch.engine.model_agent import BatchedOutputs\nfrom lmdeploy.pytorch.messages import (HistoryTokenIds, InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam,\n                                       SchedulerSession, UpdateTokenMode, _to_ndarray)\nfrom lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta\n\nfrom ..ar.sequence import SchedulerSequenceDefault\nfrom ..base.sequence import SequenceStrategy\n\nSeqList = List['SchedulerSequenceDLLM']\n\nDLLM_MASKED = consts.DLLM_MASKED\nDLLM_UNMASKED = consts.DLLM_UNMASKED\nDLLM_CACHED = consts.DLLM_CACHED\nDLLM_MASK_DTYPE = np.uint8\n\n\nclass HistoryDLLMMask(HistoryTokenIds):\n\n    def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = DLLM_MASK_DTYPE):\n        super().__init__(token_ids=token_ids, dtype=dtype)\n\n\n@dataclass\nclass SchedulerSequenceDLLM(SchedulerSequenceDefault):\n\n    # For dllm\n    history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask)\n\n    def __post_init__(self):\n        \"\"\"Post init.\"\"\"\n        super().__post_init__()\n        self._num_valid_ids: int = len(self.history_cache)\n        self._strategy: DLLMSequenceStrategy = self._seq_meta.strategy\n\n    @property\n    def dllm_mask(self):\n        start = self.num_history_ids\n        end = start + self._num_token_ids\n        return self.history_dllm_mask[start:end]\n\n    @property\n    def num_valid_ids(self):\n        return self._num_valid_ids\n\n    @property\n    def generated_ids(self) -> np.ndarray:\n        end = self.num_valid_ids\n        start = end - self.num_new_tokens\n        return self.history_cache[start:end]\n\n    
@property\n    def all_dllm_mask(self):\n        return self.history_dllm_mask[:self.num_all_ids]\n\n    @property\n    def dllm_block_length(self):\n        return self._strategy.block_size\n\n    @property\n    def dllm_mask_token(self):\n        return self._strategy.dllm_mask_token\n\n    def set_stop_pos(self, pos: int):\n        dllm_block_length = self.dllm_block_length\n        val = dllm_block_length - pos - 1\n        self._num_valid_ids -= val\n        self.num_new_tokens -= val\n\n    def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray):\n        \"\"\"Append tokens.\"\"\"\n        num_tokens = len(token_ids)\n        dllm_block_length = self.dllm_block_length\n        dllm_mask_token = self.dllm_mask_token\n        new_token_ids = [token_ids]\n        new_dllm_mask = [dllm_mask]\n\n        # add uncached tokens in token_ids\n        # for example, [cccc cccc uumm], the [uu] in last block is remain valid.\n        num_remain_valid = self.num_valid_ids - self.num_history_ids\n        if num_remain_valid != 0:\n            prev_token_ids = self.valid_ids[-num_remain_valid:]\n            prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)\n            new_token_ids = [prev_token_ids] + new_token_ids\n            new_dllm_mask = [prev_dllm_mask] + new_dllm_mask\n            self.history_cache.resize(self.num_history_ids)\n            self.history_dllm_mask.resize(self.num_history_ids)\n            num_tokens += num_remain_valid\n\n        # pad to align with dllm_block_length\n        num_pad = (-num_tokens) % dllm_block_length\n        if num_pad > 0:\n            pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, ))\n            pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, ))\n            new_token_ids += [pad_ids]\n            new_dllm_mask += [pad_mask]\n\n        token_ids = np.concatenate(new_token_ids)\n        dllm_mask = np.concatenate(new_dllm_mask)\n\n  
      assert len(token_ids) % dllm_block_length == 0\n\n        self.history_cache.append(token_ids)\n        self.history_dllm_mask.append(dllm_mask)\n        self.output_start_pos = self._num_valid_ids + len(token_ids)\n        self._num_valid_ids = self.num_history_ids + num_tokens\n        self._num_token_ids = len(token_ids)\n        self.num_new_tokens = 0\n\n    def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray):\n        \"\"\"Update token ids for decode.\"\"\"\n        num_tokens = len(token_ids)\n        dllm_block_length = self.dllm_block_length\n        dllm_mask_token = self.dllm_mask_token\n        assert num_tokens % dllm_block_length == 0\n        num_history_ids = self.num_history_ids\n\n        token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token\n        self.history_cache[num_history_ids:] = token_ids\n        self.history_dllm_mask[num_history_ids:] = dllm_mask\n\n        # check if all blocks are cached\n        last_mask = dllm_mask[-dllm_block_length:]\n        is_unmasked = np.all(last_mask == DLLM_UNMASKED)\n        is_cached = np.all(last_mask == DLLM_CACHED)\n\n        if is_unmasked:\n            num_new = dllm_block_length - self._num_valid_ids % dllm_block_length\n            self._num_valid_ids += num_new\n            self.num_new_tokens += num_new\n\n        if is_cached:\n            # add new block\n            new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, ))\n            new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, ))\n            self.history_cache.append(new_token_ids)\n            self.history_dllm_mask.append(new_dllm_mask)\n            self._num_history_ids += self._num_token_ids\n            self._num_token_ids = dllm_block_length\n\n    def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray):\n        \"\"\"Update token ids for prefill.\"\"\"\n        dllm_block_length = self.dllm_block_length\n   
     num_history_ids = self.num_history_ids\n\n        # fill input cache\n        if self.num_token_ids > dllm_block_length:\n            end = self.num_token_ids - dllm_block_length\n            self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED\n            self._num_history_ids += end\n            self._num_token_ids -= end\n\n        # decoding update\n        self._update_token_ids_decode(token_ids, dllm_mask)\n\n    def update_token_ids(self,\n                         token_ids: Tensor,\n                         multimodals: MultiModalInputs = None,\n                         embeddings: List[InputEmbeddings] = None,\n                         model_meta: Dict[str, Any] = None,\n                         dllm_mask: Tensor = None,\n                         mode: UpdateTokenMode = UpdateTokenMode.INPUTS,\n                         **kwargs):\n        \"\"\"Update token ids, old token ids will be added to history.\"\"\"\n        # update history image nums\n        self._update_embeddings(embeddings)\n\n        # update multimodals\n        self._update_multimodals(multimodals)\n\n        self.arrive_time = time.perf_counter()\n\n        token_ids: np.ndarray = _to_ndarray(token_ids)\n        if dllm_mask is None:\n            dllm_mask = np.full_like(token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)\n        dllm_mask: np.ndarray = _to_ndarray(dllm_mask)\n\n        if mode == UpdateTokenMode.INPUTS:\n            self._update_token_ids_inputs(token_ids, dllm_mask)\n        elif mode == UpdateTokenMode.PREFILL:\n            self._update_token_ids_prefill(token_ids, dllm_mask)\n        else:\n            self._update_token_ids_decode(token_ids, dllm_mask)\n\n        if model_meta is not None:\n            self.model_meta = model_meta\n\n    def set_step(self, step: int):\n        \"\"\"Set step.\"\"\"\n        # reset dllm mask\n        start = min(step, self.num_history_ids)\n        end = self.num_history_ids\n        if end > start:\n            
to_change_mask = self.history_dllm_mask[start:]\n            to_change_mask[to_change_mask == DLLM_CACHED] = DLLM_UNMASKED\n        super().set_step(step)\n\n\nclass DLLMSequenceStrategy(SequenceStrategy):\n\n    def __init__(self, block_size: int, dllm_mask_token: int) -> None:\n        self.block_size = block_size\n        self.dllm_mask_token = dllm_mask_token\n\n    def make_sequence(self,\n                      seq_id: int,\n                      session: 'SchedulerSession',\n                      sampling_param: 'SamplingParam' = None,\n                      adapter_name: str = None,\n                      migration_request: Optional[MigrationRequest] = None,\n                      resp_cache: bool = False,\n                      preserve_cache: bool = False) -> 'SchedulerSequenceDLLM':\n        \"\"\"Make sequence.\"\"\"\n        return SchedulerSequenceDLLM(seq_id=seq_id,\n                                     session=session,\n                                     sampling_param=sampling_param,\n                                     adapter_name=adapter_name,\n                                     migration_request=migration_request,\n                                     resp_cache=resp_cache,\n                                     preserve_cache=preserve_cache)\n\n    def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, model_inputs: 'ModelInputs',\n                       delta: 'ModelInputsDelta', **kwargs) -> None:\n        \"\"\"Update running sequences.\"\"\"\n        next_token_ids = batched_outputs.next_token_ids\n        stopped = batched_outputs.stopped\n        stopped = stopped.tolist()\n        model_metas = batched_outputs.model_metas\n        if model_metas is None:\n            model_metas = [None] * len(running)\n        dllm_mask = batched_outputs.extra_outputs.dllm_mask\n        stop_pos = batched_outputs.stop_pos\n\n        if model_inputs is None:\n            is_decoding = delta.is_decoding\n        else:\n            
is_decoding = model_inputs.is_decoding\n\n        batch_size = len(running)\n        next_token_ids = next_token_ids.view(batch_size, -1).numpy()\n        dllm_mask = dllm_mask.view(batch_size, -1).numpy()\n        stop_pos = stop_pos.tolist()\n        update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL\n        for idx, token in enumerate(next_token_ids):\n            msg = running[idx]\n            stop = stopped[idx]\n            model_meta = model_metas[idx]\n            mask = dllm_mask[idx]\n            if msg.status != MessageStatus.RUNNING:\n                continue\n\n            # fill token\n            msg.update_token_ids(token, dllm_mask=mask, model_meta=model_meta, mode=update_mode)\n            if stop:\n                msg.set_stop_pos(stop_pos[idx])\n                msg.state.finish()\n"
  },
  {
    "path": "lmdeploy/pytorch/strategies/dllm/unmasking.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\nfrom torch.profiler import record_function\n\nfrom lmdeploy.pytorch import consts\nfrom lmdeploy.pytorch.config import DLLMConfig, UnmaskingStrategy\n\nDLLM_MASKED = consts.DLLM_MASKED\nDLLM_UNMASKED = consts.DLLM_UNMASKED\nDLLM_CACHED = consts.DLLM_CACHED\n\n\nclass UnmaskingProcessor:\n\n    def __init__(self, dllm_config: DLLMConfig):\n        self.dllm_config = dllm_config\n\n    def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor):\n        \"\"\"Get scores.\"\"\"\n        scores = logits.softmax(dim=-1)\n        scores = scores.gather(-1, token_ids.unsqueeze(-1)).flatten()\n        return scores\n\n    def _get_denoise_num(self):\n        \"\"\"Get the number of tokens to unmask per block at each denoising step.\"\"\"\n        block_size = self.dllm_config.block_length\n        denoising_steps = self.dllm_config.denoising_steps\n        if denoising_steps is None:\n            denoising_steps = block_size\n        num = block_size // denoising_steps\n        num = max(1, min(num, block_size))\n        return num\n\n    def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):\n        \"\"\"Unmask a fixed number of the most confident masked tokens per block.\"\"\"\n        block_size = self.dllm_config.block_length\n        topk = self._get_denoise_num()\n        scores = self._get_scores(logits, token_ids)\n        is_masked = dllm_mask == DLLM_MASKED\n        scores = torch.where(is_masked, scores, scores.new_zeros((1, )))\n\n        scores = scores.view(-1, block_size)\n        dllm_mask = dllm_mask.view(-1, block_size)\n        _, indices = scores.topk(topk, dim=-1)\n        dllm_unmasked = dllm_mask.scatter(-1, indices, DLLM_UNMASKED)\n\n        is_masked = is_masked.view_as(dllm_mask)\n        dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask)\n        return dllm_mask.flatten()\n\n    def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):\n        
\"\"\"Unmask every masked token whose confidence reaches the threshold.\"\"\"\n        block_size = self.dllm_config.block_length\n        threshold = self.dllm_config.confidence_threshold\n        scores = self._get_scores(logits, token_ids)\n        is_masked = dllm_mask == DLLM_MASKED\n        scores = torch.where(is_masked, scores, scores.new_zeros((1, )))\n\n        scores = scores.view(-1, block_size)\n        dllm_mask = dllm_mask.view(-1, block_size)\n        _, indices = scores.topk(1, dim=-1)\n        # always accept the top-scoring token in each block so every step makes progress\n        scores = scores.scatter(-1, indices, threshold)\n\n        is_masked = is_masked.view_as(dllm_mask)\n        is_masked &= scores >= threshold\n        dllm_mask[is_masked] = DLLM_UNMASKED\n        return dllm_mask.flatten()\n\n    def sequential(self, dllm_mask: torch.Tensor):\n        \"\"\"Unmask masked tokens in order, starting from the first masked position in each block.\"\"\"\n        block_size = self.dllm_config.block_length\n        denoise_num = self._get_denoise_num()\n        dllm_mask = dllm_mask.view(-1, block_size)\n        is_masked = dllm_mask == DLLM_MASKED\n\n        # get indices\n        indices = is_masked.int().argmax(dim=1)\n        ranges = torch.arange(0, denoise_num, device=indices.device, dtype=indices.dtype)\n        indices = indices[:, None] + ranges[None, :]\n        indices = indices % block_size\n\n        dllm_unmasked = dllm_mask.clone()\n        dllm_unmasked = dllm_unmasked.scatter(-1, indices, DLLM_UNMASKED)\n        dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask)\n\n        return dllm_mask.flatten()\n\n    @record_function('unmasking')\n    def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):\n        \"\"\"Run one unmasking step and return the updated dllm mask and token ids.\"\"\"\n        strategy = self.dllm_config.unmasking_strategy\n        if strategy is None:\n            return dllm_mask, token_ids\n\n        # reshape to [num_blocks, block_size]\n        block_size = self.dllm_config.block_length\n        dllm_mask = dllm_mask.unflatten(0, (-1, block_size))\n\n        is_same = (dllm_mask == dllm_mask[:, :1]).all(dim=1)\n        
first_mask = dllm_mask[:, 0]\n\n        # blocks that are fully unmasked are marked as cached\n        is_block_unmasked = is_same & (first_mask == DLLM_UNMASKED)\n        dllm_mask[is_block_unmasked] = DLLM_CACHED\n\n        dllm_mask = dllm_mask.flatten()\n        token_ids = torch.where(dllm_mask != DLLM_MASKED, input_ids, token_ids)\n        if strategy == UnmaskingStrategy.LOW_CONFIDENCE_STATIC:\n            dllm_mask = self.low_confidence_static(logits, token_ids, dllm_mask)\n        elif strategy == UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC:\n            dllm_mask = self.low_confidence_dynamic(logits, token_ids, dllm_mask)\n        elif strategy == UnmaskingStrategy.SEQUENTIAL:\n            dllm_mask = self.sequential(dllm_mask)\n        else:\n            raise RuntimeError(f'strategy {strategy} not supported.')\n\n        return dllm_mask, token_ids\n"
  },
  {
    "path": "lmdeploy/pytorch/third_party/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/third_party/deep_gemm/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom contextlib import contextmanager\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\ntry:\n    import deep_gemm  # noqa: F401\nexcept ImportError:\n    logger.exception('DeepGemm is not installed. Please install https://github.com/deepseek-ai/DeepGEMM.')\n\nfrom deep_gemm import ceil_div, get_m_alignment_for_contiguous_layout  # noqa: F401, E402\n\ntry:\n    from deep_gemm import fp8_gemm_nt\nexcept Exception:\n    from deep_gemm.jit_kernels.gemm import gemm_fp8_fp8_bf16_nt\n\n    @contextmanager\n    def _log_jit_build(M: int, N: int, K: int):\n        from deep_gemm.jit.runtime import RuntimeCache\n\n        if hasattr(RuntimeCache, 'get'):\n            func_name = 'get'\n        else:\n            func_name = '__getitem__'\n        origin_func = getattr(RuntimeCache, func_name)\n\n        def __patched_func(self, *args, **kwargs):\n            ret = origin_func(self, *args, **kwargs)\n            if ret is None:\n                logger.warning(f'DeepGemm build <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. 
Please wait.')\n            return ret\n\n        setattr(RuntimeCache, func_name, __patched_func)\n        try:\n            yield\n        finally:\n            # restore the original lookup even if the wrapped gemm raises\n            setattr(RuntimeCache, func_name, origin_func)\n\n    def fp8_gemm_nt(a, b, d, c, recipe=None, compiled_dim='nk', disable_ue8m0_cast=False):\n        M, K = a[0].shape\n        N, _ = b[0].shape\n        with _log_jit_build(M, N, K):\n            gemm_fp8_fp8_bf16_nt(a, b, d)\n\n\ntry:\n    from deep_gemm import m_grouped_fp8_gemm_nt_contiguous\nexcept Exception:\n    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous\n\n    def m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, recipe=None, compiled_dims='nk', disable_ue8m0_cast=False):\n        assert recipe is None\n        assert compiled_dims == 'nk'\n        assert disable_ue8m0_cast is False\n        return m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(a, b, d, m_indices)\n\n\ntry:\n    from deep_gemm import m_grouped_fp8_gemm_nt_masked\nexcept Exception:\n    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked\n\n    def m_grouped_fp8_gemm_nt_masked(a,\n                                     b,\n                                     d,\n                                     masked_m,\n                                     expected_m,\n                                     recipe=None,\n                                     compiled_dims='nk',\n                                     disable_ue8m0_cast=False):\n        assert recipe is None\n        assert compiled_dims == 'nk'\n        assert disable_ue8m0_cast is False\n        return m_grouped_gemm_fp8_fp8_bf16_nt_masked(a, b, d, masked_m, expected_m)\n\n\ntry:\n    from deep_gemm import get_mn_major_tma_aligned_tensor\nexcept Exception:\n    from deep_gemm import get_col_major_tma_aligned_tensor\n\n    def get_mn_major_tma_aligned_tensor(x):\n        return get_col_major_tma_aligned_tensor(x)\n\n\ntry:\n    from deep_gemm import m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nt_masked  # noqa: F401\nexcept 
Exception:\n    logger.warning('DeepGemm bf16 grouped gemm kernels are not found. '\n                   'Please upgrade DeepGemm to the latest version.')\n"
  },
  {
    "path": "lmdeploy/pytorch/third_party/flash_attn_interface.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport functools\n\nfrom flash_attn_interface import flash_attn_varlen_func as _flash_attn_varlen_func\nfrom flash_attn_interface import flash_attn_with_kvcache as _flash_attn_with_kvcache\n\n\n@functools.wraps(_flash_attn_varlen_func)\ndef flash_attn_varlen_func(*args, **kwargs):\n    output = _flash_attn_varlen_func(*args, **kwargs)\n    if isinstance(output, tuple):\n        # for old api\n        return output[0]\n    return output\n\n\n@functools.wraps(_flash_attn_with_kvcache)\ndef flash_attn_with_kvcache(*args, **kwargs):\n    output = _flash_attn_with_kvcache(*args, **kwargs)\n    return output\n"
  },
  {
    "path": "lmdeploy/pytorch/tools/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .utils import Timer  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/pytorch/tools/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom contextlib import contextmanager\nfrom typing import List\n\n\nclass Timer:\n    \"\"\"Debug timer.\"\"\"\n\n    def __init__(self):\n        self.duration = None\n        self.timer_type = None\n\n    def tic_cpu(self):\n        self.timer_type = 'cpu'\n        import time\n        self._start = time.perf_counter()\n\n    def toc_cpu(self):\n        assert self.timer_type == 'cpu'\n        import time\n        self._end = time.perf_counter()\n        self.duration = (self._end - self._start) * 1000\n        return self\n\n    def tic_cuda(self):\n        self.timer_type = 'cuda'\n        import torch\n        self._start = torch.cuda.Event(enable_timing=True)\n        self._end = torch.cuda.Event(enable_timing=True)\n        self._start.record()\n\n    def toc_cuda(self):\n        assert self.timer_type == 'cuda'\n        import torch\n        self._end.record()\n        torch.cuda.synchronize()\n        self.duration = self._start.elapsed_time(self._end)\n        return self\n\n    @classmethod\n    def tic(cls, is_cuda: bool = False) -> 'Timer':\n        timer = Timer()\n        if is_cuda:\n            timer.tic_cuda()\n        else:\n            timer.tic_cpu()\n        return timer\n\n    def toc(self):\n        if self.timer_type == 'cpu':\n            return self.toc_cpu()\n        elif self.timer_type == 'cuda':\n            return self.toc_cuda()\n        else:\n            raise RuntimeError(f'Unknown timer_type: {self.timer_type}')\n\n    @classmethod\n    @contextmanager\n    def timing(cls, is_cuda: bool = False) -> 'Timer':\n        timer = cls.tic(is_cuda=is_cuda)\n        yield timer\n        timer.toc()\n\n    @staticmethod\n    def format_duration(duration: float, acc: int = 3):\n        \"\"\"Format duration.\"\"\"\n        unit = 'ms'\n        if duration < 1:\n            duration *= 1000\n            unit = 'μs'\n        elif duration > 1000:\n            duration /= 1000\n 
           unit = 's'\n\n        return f'{duration:.{acc}f} {unit}'\n\n    @staticmethod\n    def format_flops(flops: float, acc: int = 3):\n        \"\"\"Compute flops.\"\"\"\n        unit = ''\n        if flops > (1 << 40):\n            flops /= (1 << 40)\n            unit = 'T'\n        elif flops > (1 << 30):\n            flops /= (1 << 30)\n            unit = 'G'\n        elif flops > (1 << 20):\n            flops /= (1 << 20)\n            unit = 'M'\n        elif flops > (1 << 10):\n            flops /= (1 << 10)\n            unit = 'K'\n        return f'{flops:.{acc}f} {unit}Flop/s'\n\n    @staticmethod\n    def formatted_print(out_info: dict, title: str = None):\n        \"\"\"Formatted print.\"\"\"\n        max_key_len = max(len(k) for k in out_info.keys())\n        max_key_len = min(10, max_key_len)\n        max_val_len = max(len(k) for k in out_info.values())\n        max_val_len = min(10, max_val_len)\n\n        if title is not None:\n            print(title)\n        for k, v in out_info.items():\n            print(f'{k:>{max_key_len}} : {v:>{max_val_len}}')\n\n    def print(self, flop: int = None, title: str = None):\n        \"\"\"print.\"\"\"\n        if self.duration is None:\n            print('Please run Timer.tic() first.')\n            return\n\n        out_info = dict()\n\n        formated_dur = self.format_duration(self.duration)\n        out_info['Duration'] = f'{formated_dur}'\n\n        if flop is not None:\n            flops = flop / self.duration * 1000\n            formated_flops = self.format_flops(flops)\n            out_info['Flops'] = f'{formated_flops}'\n\n        self.formatted_print(out_info, title)\n\n    def toc_print(self, flop: int = None, title: str = None):\n        return self.toc().print(flop=flop, title=title)\n\n\ndef visualize_pipe_out(outputs, enable_meta: bool = True):\n    import os\n\n    from lmdeploy.messages import Response\n\n    try:\n        from termcolor import colored\n    except ImportError:\n\n        
def colored(text, color=None, on_color=None, attrs=None):\n            return text\n\n    if isinstance(outputs, Response):\n        outputs = [outputs]\n    elif outputs is None:\n        outputs = [outputs]\n    try:\n        term_size = os.get_terminal_size().columns\n    except Exception:\n        term_size = 100\n\n    border_color = 'cyan'\n    meta_color = 'light_grey'\n    number_color = 'green'\n\n    def _print_title(title: str, color: str = border_color):\n        title_text = f' {title} '\n        print(colored(f'【{title_text}】', color, attrs=['bold']))\n\n    def _print_section(title: str, content: str, color: str = border_color):\n        \"\"\"Simple title and content printing.\"\"\"\n        _print_title(title, color)\n        print(content)\n\n    def _print_meta(out: Response):\n        \"\"\"Enhanced meta information display.\"\"\"\n        # Create a clean table-like format\n        finish_color = 'yellow' if out.finish_reason == 'stop' else 'red'\n        meta_content = [\n            f\"{colored('• Input Tokens:', meta_color)}     {colored(out.input_token_len, number_color)}\",\n            f\"{colored('• Generated Tokens:', meta_color)} {colored(out.generate_token_len, number_color)}\",\n            f\"{colored('• Finish Reason:', meta_color)}    {colored(out.finish_reason, finish_color)}\"\n        ]\n        if out.routed_experts is not None:\n            shape = tuple(out.routed_experts.shape)\n            meta_content.append(f\"{colored('• Routed Experts:', meta_color)}  {colored(shape, number_color)}\")\n        if out.logits is not None:\n            shape = tuple(out.logits.shape)\n            meta_content.append(f\"{colored('• Logits Shape:', meta_color)}     {colored(shape, number_color)}\")\n        if out.logprobs is not None:\n            size = len(out.logprobs)\n            meta_content.append(f\"{colored('• Logprobs:', meta_color)}      {colored(size, number_color)}\")\n\n        lines = '\\n'.join(meta_content)\n        lines 
+= '\\n'\n        _print_section('METADATA', lines, border_color)\n\n    # Main loop\n    print(colored('━' * term_size, border_color))\n\n    outputs: List[Response] = outputs\n    for idx, out in enumerate(outputs):\n        header = f'OUTPUT [{idx + 1}/{len(outputs)}]'\n        header_formatted = colored(f'✦ {header}', 'light_magenta', attrs=['bold'])\n        print(header_formatted)\n        print()\n\n        if out is not None:\n            if enable_meta:\n                _print_meta(out)\n\n            _print_section('TEXT', out.text, border_color)\n\n        if idx < len(outputs) - 1:  # Add separator when it's not the last output\n            print(colored('─' * (term_size), border_color, attrs=['dark']))\n        else:\n            print(colored('━' * term_size, border_color))\n\n\ndef visualize_chat_completions(outputs, enable_meta: bool = True):\n    \"\"\"Visualize chat completions.\"\"\"\n    from openai.types.chat import ChatCompletion\n\n    from lmdeploy.messages import Response\n    if isinstance(outputs, ChatCompletion):\n        outputs = [outputs]\n\n    resps = []\n    for out in outputs:\n        assert isinstance(out, ChatCompletion)\n        choice = out.choices[0]\n        resp = Response(text=choice.message.content,\n                        input_token_len=out.usage.prompt_tokens,\n                        generate_token_len=out.usage.completion_tokens,\n                        finish_reason=choice.finish_reason)\n        resps.append(resp)\n\n    return visualize_pipe_out(resps, enable_meta=enable_meta)\n\n\nsources = None\n\n\ndef dump_tilelang_source(kernel, path: str = 'sources/tvm_kernels.cu'):\n    global sources\n    if sources is not None:\n        return\n    sources = kernel.get_kernel_source()\n    with open(path, 'w') as f:\n        f.write(sources)\n"
  },
  {
    "path": "lmdeploy/pytorch/transformers/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom functools import lru_cache\n\nfrom transformers import AutoConfig\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\n@lru_cache()\ndef register_config(model_type: str):\n    if model_type == 'deepseek_v32':\n        from lmdeploy.pytorch.transformers.configuration_deepseek_v32 import DeepseekV32Config\n        AutoConfig.register(DeepseekV32Config.model_type, DeepseekV32Config)\n    else:\n        logger.debug(f'Cannot register config for model_type: {model_type}')\n\n\ndef config_from_pretrained(pretrained_model_name_or_path: str, **kwargs):\n    try:\n        return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)\n    except ValueError as e:\n        logger.debug(f'AutoConfig.from_pretrained failed: {e}, trying to register the config manually.')\n        # some models (e.g. dsv32) do not provide an auto map for their config\n        from transformers import PretrainedConfig\n        trust_remote_code = kwargs.pop('trust_remote_code', None)\n        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)\n        model_type = config_dict.get('model_type', None)\n        if trust_remote_code is not None:\n            kwargs['trust_remote_code'] = trust_remote_code\n        register_config(model_type)\n\n    return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)\n"
  },
  {
    "path": "lmdeploy/pytorch/transformers/configuration_deepseek_v32.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config\n\n\nclass DeepseekV32Config(DeepseekV3Config):\n    model_type = 'deepseek_v32'\n\n    def __init__(self, index_head_dim=128, index_n_heads=64, index_topk=2048, **kwargs):\n        super().__init__(**kwargs)\n        self.index_head_dim = index_head_dim\n        self.index_n_heads = index_n_heads\n        self.index_topk = index_topk\n"
  },
  {
    "path": "lmdeploy/pytorch/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from: https://github.com/vllm-project/vllm\nimport asyncio\nimport inspect\nfrom contextlib import contextmanager\nfrom inspect import Parameter, Signature\nfrom typing import Dict, Generic, Optional, Sequence, Tuple, TypeVar\n\nimport psutil\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\ndef get_gpu_memory(device_id: Optional[int] = None) -> Tuple[int, int]:\n    \"\"\"Returns the free and total physical memory of the GPU in bytes.\"\"\"\n    import torch\n    if device_id is None:\n        device_id = torch.cuda.current_device()\n    return torch.cuda.mem_get_info(device_id)\n\n\ndef get_cpu_memory() -> int:\n    \"\"\"Returns the total CPU memory of the node in bytes.\"\"\"\n    return psutil.virtual_memory().total\n\n\ndef bind_sigature(input_names: Sequence[str], args: Sequence, kwargs: Dict):\n    \"\"\"Bind args and kwargs to given input names.\"\"\"\n    kind = inspect._ParameterKind.POSITIONAL_OR_KEYWORD\n\n    sig = Signature([Parameter(name, kind) for name in input_names])\n    bind = sig.bind(*args, **kwargs)\n    return bind.arguments\n\n\ndef singleton(cls):\n    \"\"\"Singleton decorator.\"\"\"\n    import multiprocessing as mp\n\n    from lmdeploy.utils import get_logger\n    logger = get_logger('lmdeploy')\n    instances = {}\n\n    def get_instance(*args, **kwargs):\n        if cls not in instances:\n            pid = mp.current_process().pid\n            logger.debug(f'pid:{pid} - Creating instance of singleton class {cls.__name__}')\n            instances[cls] = cls(*args, **kwargs)\n        return instances[cls]\n\n    return get_instance\n\n\nT = TypeVar('T')\n\n\nclass CtxMgrBase(Generic[T]):\n    \"\"\"Context manager base class.\"\"\"\n\n    def __init__(self, default: Optional[T] = None):\n        self._context = default\n\n    def current_context(self) -> Optional[T]:\n        \"\"\"Get current context.\"\"\"\n        return self._context\n\n    def set_context(self, 
context: Optional[T]):\n        \"\"\"Set current context.\"\"\"\n        self._context = context\n\n    @contextmanager\n    def context(self, context: T):\n        \"\"\"Context manager.\"\"\"\n        origin_context = self.current_context()\n        self.set_context(context)\n        try:\n            yield self\n        finally:\n            self.set_context(origin_context)\n\n\n# from vllm\ndef maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:\n    \"\"\"Try to register the HF model configuration class to be serialized by\n    value.\n\n    With trust_remote_code, the config class is typically an instance of a\n    custom class imported from the HF modules cache. The class will not be\n    importable in spawned workers by default (and won't exist at all on\n    other nodes), which breaks serialization of the config.\n    In this function we tell the cloudpickle serialization library to pass\n    instances of these generated classes by value instead of by reference,\n    i.e. the class definition is serialized along with its data so that the\n    class module does not need to be importable on the receiving end. This\n    registration only works if the modules cache has already been\n    initialized.\n    See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs\n    \"\"\"  # noqa: E501\n    if not trust_remote_code:\n        return\n\n    try:\n        import transformers_modules\n    except ImportError:\n        logger.debug('Could not import transformers_modules used for remote'\n                     ' code. 
If remote code is not needed remove'\n                     ' `--trust-remote-code`.')\n        return\n\n    try:\n        import cloudpickle\n        cloudpickle.register_pickle_by_value(transformers_modules)\n\n        # ray vendors its own version of cloudpickle\n        try:\n            import ray\n        except ImportError:\n            return\n\n        ray.cloudpickle.register_pickle_by_value(transformers_modules)\n\n        # multiprocessing uses pickle to serialize arguments when using spawn\n        # Here we get pickle to use cloudpickle to serialize ModelConfig objects\n        # that contain instances of the custom config class to avoid\n        # serialization problems if the generated module (and model) has a `.`\n        # in its name\n        import multiprocessing\n        import pickle\n\n        from lmdeploy.pytorch.config import ModelConfig\n\n        def _reduce_modelconfig(mc: ModelConfig):\n            return (pickle.loads, (cloudpickle.dumps(mc), ))\n\n        multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)\n\n    except Exception as e:\n        logger.warning(\n            'Unable to register remote classes used by'\n            ' trust_remote_code with by-value serialization. This may'\n            ' lead to a later error. If remote code is not needed'\n            ' remove `--trust-remote-code`',\n            exc_info=e)\n\n\ndef monkey_patch_hf_modules_cache():\n    \"\"\"Monkey patch HF_MODULES_CACHE to a temporary directory per process. 
This\n    is necessary to avoid conflicts when multiple processes try to read/write\n    to the same HF_MODULES_CACHE directory, especially in multi-GPU setups.\n\n    modified from: https://github.com/InternLM/xtuner/blob/main/xtuner/v1/utils/misc.py\n    \"\"\"\n    import os\n\n    import transformers\n    from huggingface_hub import constants\n\n    # When using `remote_code` in HF components like tokenizer or config\n    # (e.g., `AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True)`),\n    # the hf_model_path is copied to HF_MODULES_CACHE.\n    # On multi-GPU machines (e.g., 8 GPUs), simultaneous read/write operations\n    # by multiple processes on this shared directory can cause conflicts.\n    # Therefore, we set HF_MODULES_CACHE to a temporary directory per process.\n\n    HF_PATCH_MODULES_CACHE_PREFIX = 'modules_pid_'\n    modules_cache = os.path.join(constants.HF_HOME, f'{HF_PATCH_MODULES_CACHE_PREFIX}{os.getpid()}')\n    os.environ['HF_MODULES_CACHE'] = modules_cache\n\n    transformers.utils.hub.HF_MODULES_CACHE = modules_cache\n\n    # During import, Python creates a new name HF_MODULES_CACHE in the namespace\n    # of the dynamic_module_utils module, binding it to the object referenced by\n    # transformers.utils.HF_MODULES_CACHE at that moment.\n    # Hence, we also need to set transformers.dynamic_module_utils.HF_MODULES_CACHE\n    # to the new modules_cache.\n\n    transformers.dynamic_module_utils.HF_MODULES_CACHE = modules_cache\n    transformers.utils.HF_MODULES_CACHE = modules_cache\n\n    logger.info(f'Set HF_MODULES_CACHE to {modules_cache} for current process {os.getpid()}')\n\n\nasync def wait_for_async_tasks(tasks: Sequence[asyncio.Task],\n                               cancel_pending: bool = True,\n                               ignore_cancellederror: bool = True):\n    \"\"\"Wait for async tasks.\"\"\"\n    if len(tasks) == 0:\n        return [], []\n\n    for task in tasks:\n        if not isinstance(task, 
asyncio.Task):\n            raise ValueError('All inputs must be asyncio.Task instances.')\n\n    try:\n        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)\n\n        if cancel_pending:\n            # cancel all pending tasks\n            for task in pending:\n                task.cancel()\n\n        # raise exception if any\n        for task in done:\n            if task.cancelled():\n                continue\n            if exc := task.exception():\n                if isinstance(exc, asyncio.CancelledError) and ignore_cancellederror:\n                    logger.debug(f'Task <{task.get_name()}> cancelled.')\n                    continue\n                raise exc from None\n    except asyncio.CancelledError:\n        for task in tasks:\n            if not task.done():\n                task.cancel()\n        raise\n\n    return done, pending\n\n\nasync def cancel_async_tasks(tasks: Sequence[asyncio.Task]):\n    \"\"\"Cancel async tasks.\"\"\"\n    if isinstance(tasks, asyncio.Task):\n        tasks = [tasks]\n\n    tasks = list(task for task in tasks if not task.done())\n    for task in tasks:\n        task.cancel()\n    return await asyncio.gather(*tasks, return_exceptions=True)\n"
  },
  {
    "path": "lmdeploy/pytorch/weight_loader/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/pytorch/weight_loader/model_weight_loader.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport json\nimport os.path as osp\n\nimport numpy as np\nimport torch\nfrom safetensors.torch import safe_open\nfrom tqdm.auto import tqdm\nfrom transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME\n\nfrom lmdeploy.pytorch import envs as _envs\nfrom lmdeploy.pytorch.distributed import get_world_rank\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\ndef load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor, **kwargs):\n    \"\"\"Load weight.\"\"\"\n    if hasattr(param, 'weight_loader'):\n        param.weight_loader(param, loaded_weight, **kwargs)\n    else:\n        assert len(kwargs) == 0\n        default_weight_loader(param, loaded_weight)\n\n\ndef default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor):\n    \"\"\"Default weight loader.\"\"\"\n    if param.numel() == 1 and loaded_weight.numel() == 1:\n        param.data.fill_(loaded_weight.item())\n    else:\n        assert param.size() == loaded_weight.size(), (f'Attempted to load weight ({loaded_weight.size()}) '\n                                                      f'into parameter ({param.size()})')\n        param.data.copy_(loaded_weight)\n\n\ndef _get_weight_type(model_path: str, use_safetensors: bool = None):\n    \"\"\"Get weight type.\"\"\"\n    weight_type = None\n    is_sharded = False\n    if use_safetensors is not False and osp.isfile(osp.join(model_path, SAFE_WEIGHTS_NAME)):\n        # Load from a safetensors checkpoint\n        weight_type = 'safetensors'\n    elif use_safetensors is not False and osp.isfile(osp.join(model_path, SAFE_WEIGHTS_INDEX_NAME)):\n        # Load from a sharded safetensors checkpoint\n        weight_type = 'safetensors'\n        is_sharded = True\n    elif osp.isfile(osp.join(model_path, WEIGHTS_NAME)):\n        # Load from a PyTorch checkpoint\n        weight_type = 'pytorch'\n    elif 
osp.isfile(osp.join(model_path, WEIGHTS_INDEX_NAME)):\n        # Load from a sharded PyTorch checkpoint\n        weight_type = 'pytorch'\n        is_sharded = True\n    else:\n        raise RuntimeError('Unknown weight type.')\n\n    return (weight_type, is_sharded)\n\n\ndef _get_weight_map(model_path: str, weight_type: str):\n    \"\"\"Get weight index.\"\"\"\n    if weight_type == 'safetensors':\n        load_index = osp.join(model_path, SAFE_WEIGHTS_INDEX_NAME)\n    elif weight_type == 'pytorch':\n        load_index = osp.join(model_path, WEIGHTS_INDEX_NAME)\n    else:\n        raise RuntimeError(f'Unsupported weight type: {weight_type}.')\n\n    with open(load_index, mode='r', encoding='utf-8') as f:\n        index = json.load(f)\n\n    weight_map = index['weight_map']\n    return weight_map\n\n\ndef _get_weight_path(model_path: str, weight_type: str):\n    \"\"\"Get weight path.\"\"\"\n    if weight_type == 'safetensors':\n        weight_name = SAFE_WEIGHTS_NAME\n    elif weight_type == 'pytorch':\n        weight_name = WEIGHTS_NAME\n    else:\n        raise RuntimeError('Unknown weight type.')\n\n    weight_path = osp.join(model_path, weight_name)\n    return weight_path, weight_name\n\n\ndef _get_safetensors_weights_iterator(file: str, prefix: str):\n    \"\"\"Get safeternsors weights iterator.\"\"\"\n    with safe_open(file, framework='pt') as f:\n        for name in f.keys():\n            param = f.get_tensor(name)\n            if prefix is not None:\n                name = f'{prefix}{name}'\n            yield name, param\n\n\ndef _get_pt_weights_iterator(file: str, prefix: str):\n    \"\"\"Get pt weights iterator.\"\"\"\n    state = torch.load(file, weights_only=True, map_location='cpu')\n    try:\n        if prefix is None:\n            yield from state.items()\n        else:\n            for k, v in state.items():\n                yield f'{prefix}{k}', v\n    finally:\n        del state\n        torch.cuda.empty_cache()\n\n\nclass ModelWeightLoader:\n   
 \"\"\"Model weight loader for sharded weights.\"\"\"\n\n    def __init__(self, model_path: str, prefix: str = None):\n        self.model_path = model_path\n        weight_type, is_sharded = _get_weight_type(model_path)\n\n        self._weight_type = weight_type\n        self._is_sharded = is_sharded\n        self._prefix = prefix\n        self._shard_paths = self._get_shard_paths(model_path, is_sharded, weight_type)\n\n    @staticmethod\n    def _get_shard_paths(model_path: str, is_sharded: bool, weight_type: str):\n        \"\"\"Get shard paths.\"\"\"\n        if is_sharded:\n            weight_map = _get_weight_map(model_path, weight_type)\n            paths = set(weight_map.values())\n            paths = tuple(f'{model_path}/{path}' for path in paths)\n            return paths\n        else:\n            path, _ = _get_weight_path(model_path, weight_type)\n            return (path, )\n\n    def _get_weights_iterator(self, path: str):\n        \"\"\"Get weights iterator.\"\"\"\n        if self._weight_type == 'safetensors':\n            weights_iterator = _get_safetensors_weights_iterator(path, self._prefix)\n        else:\n            weights_iterator = _get_pt_weights_iterator(path, self._prefix)\n        return weights_iterator\n\n    @staticmethod\n    def _skip_dummy_iterator(iterator, dummy_prefix: list):\n        \"\"\"Wrap iterator to skip dummy weights.\"\"\"\n        for name, param in iterator:\n            if not any(name.startswith(prefix) for prefix in dummy_prefix):\n                yield name, param\n\n    @staticmethod\n    def _rename_weights_iterator(iterator, model: torch.nn.Module):\n        \"\"\"Wrap iterator to rename weights.\"\"\"\n        rename_func = getattr(model, 'rename_weight', lambda x: x)\n        for name, param in iterator:\n            new_name = rename_func(name)\n            yield new_name, param\n\n    def load_model_weights(\n        self,\n        model: torch.nn.Module,\n        device: torch.device = None,\n    ):\n   
     \"\"\"Load model weights implementation.\"\"\"\n        assert hasattr(model, 'load_weights')\n        paths = self._shard_paths\n        _, rank = get_world_rank()\n        disable_tqdm = rank != 0\n\n        # get dummy prefix\n        dummy_prefix = []\n        for name, mod in model.named_modules():\n            if getattr(mod, '_is_dummy_mod', False):\n                dummy_prefix.append(f'{name}.')\n\n        paths = sorted(paths)\n        if _envs.random_load_weight:\n            np.random.shuffle(paths)\n        for path in tqdm(paths, desc='Loading weights from safetensors', disable=disable_tqdm):\n            weights_iterator = self._get_weights_iterator(path)\n            weights_iterator = self._rename_weights_iterator(weights_iterator, model)\n            if len(dummy_prefix) > 0:\n                weights_iterator = self._skip_dummy_iterator(weights_iterator, dummy_prefix)\n            model.load_weights(weights_iterator)\n        if device is not None:\n            model.to(device)\n\n\n@torch.inference_mode()\ndef load_model_weights(model: torch.nn.Module, checkpoint_path: str, prefix: str = None, device: torch.device = None):\n    \"\"\"Loading model weights.\"\"\"\n    loader = ModelWeightLoader(checkpoint_path, prefix=prefix)\n    loader.load_model_weights(model, device=device)\n    model.eval()\n    for _, mod in model.named_modules():\n        if not hasattr(mod, 'update_weights'):\n            continue\n        mod.update_weights()\n"
  },
  {
    "path": "lmdeploy/serve/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .core import AsyncEngine, VLAsyncEngine\nfrom .managers import Session, SessionManager\nfrom .processors import MultimodalProcessor\n\n__all__ = [\n    'AsyncEngine',\n    'VLAsyncEngine',\n    'SessionManager',\n    'Session',\n    'MultimodalProcessor',\n]\n"
  },
  {
    "path": "lmdeploy/serve/core/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .async_engine import AsyncEngine\nfrom .vl_async_engine import VLAsyncEngine\n\n__all__ = ['AsyncEngine', 'VLAsyncEngine']\n"
  },
  {
    "path": "lmdeploy/serve/core/async_engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport asyncio\nimport concurrent.futures\nimport dataclasses\nimport random\nfrom contextlib import asynccontextmanager\nfrom copy import deepcopy\nfrom typing import Any, Dict, List, Literal\n\nimport torch\n\nfrom lmdeploy.archs import get_model_arch\nfrom lmdeploy.logger import RequestLogger\nfrom lmdeploy.messages import (EngineOutput, GenerationConfig, PytorchEngineConfig, Response, ResponseType,\n                               SpeculativeConfig, TurbomindEngineConfig)\nfrom lmdeploy.metrics.metrics_processor import metrics_processor\nfrom lmdeploy.metrics.stats import IterationStats, RequestStats, SpeculativeDecodingStats\nfrom lmdeploy.model import ChatTemplateConfig, get_chat_template\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,\n                                                   DistServeInitRequest)\nfrom lmdeploy.serve.managers import Session, SessionManager\nfrom lmdeploy.serve.processors import MultimodalProcessor\nfrom lmdeploy.tokenizer import DetokenizeState, Tokenizer\nfrom lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg, get_logger\n\nfrom .exceptions import SafeRunException\n\nlogger = get_logger('lmdeploy')\n\n\n@dataclasses.dataclass\nclass GenOut:\n    \"\"\"Pack all response information together.\"\"\"\n    response: str\n    history_token_len: int\n    input_token_len: int\n    generate_token_len: int\n    finish_reason: Literal['stop', 'length', 'error'] | None = None\n    token_ids: List[int] | None = None\n    logprobs: List[Dict[int, float]] | None = None\n    logits: Any = None\n    last_hidden_state: Any = None\n    cache_block_ids: List[int] | None = None  # for disaggregation\n    routed_experts: Any = None  # for RL router replay\n\n    def to_response(self, index: int = 0) -> Response:\n        \"\"\"Convert GenOut to Response object.\n\n        Args:\n            index: 
The index position in the batch. Default to 0.\n        \"\"\"\n        return Response(text=self.response,\n                        generate_token_len=self.generate_token_len,\n                        input_token_len=self.input_token_len,\n                        finish_reason=self.finish_reason,\n                        token_ids=self.token_ids or [],\n                        logprobs=self.logprobs,\n                        last_hidden_state=self.last_hidden_state,\n                        logits=self.logits,\n                        routed_experts=self.routed_experts,\n                        index=index)\n\n\n# class AsyncEngine(LogitsMixin):\nclass AsyncEngine:\n    \"\"\"Async inference engine. Maintaining a bunch of tm_model instances.\n\n    Args:\n        model_path (str): the path of a model.\n            It could be one of the following options:\n                - i) A local directory path of a turbomind model which is\n                    converted by `lmdeploy convert` command or download from\n                    ii) and iii).\n                - ii) The model_id of a lmdeploy-quantized model hosted\n                    inside a model repo on huggingface.co, such as\n                    \"InternLM/internlm-chat-20b-4bit\",\n                    \"lmdeploy/llama2-chat-70b-4bit\", etc.\n                - iii) The model_id of a model hosted inside a model repo\n                    on huggingface.co, such as \"internlm/internlm-chat-7b\",\n                    \"Qwen/Qwen-7B-Chat \", \"baichuan-inc/Baichuan2-7B-Chat\"\n                    and so on.\n        model_name (str): needed when model_path is a pytorch model on\n            huggingface.co, such as \"internlm/internlm-chat-7b\",\n            \"Qwen/Qwen-7B-Chat \", \"baichuan-inc/Baichuan2-7B-Chat\" and so on.\n        backend (str): either `turbomind` or `pytorch` backend. 
Default to\n            `turbomind` backend.\n        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend\n            config instance. Default to None.\n        chat_template_config (ChatTemplateConfig): chat template configuration.\n            Default to None.\n        max_log_len (int): Max number of prompt characters or prompt tokens\n            being printed in log. Default: Unlimited\n        speculative_config (SpeculativeConfig): speculative decoding\n            configuration. Only supported by the `pytorch` backend. Default to\n            None.\n    \"\"\"\n\n    def __init__(self,\n                 model_path: str,\n                 model_name: str | None = None,\n                 backend: Literal['turbomind', 'pytorch'] = 'turbomind',\n                 backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None,\n                 chat_template_config: ChatTemplateConfig | None = None,\n                 max_log_len: int | None = None,\n                 speculative_config: SpeculativeConfig | None = None,\n                 **kwargs) -> None:\n        logger.info(f'input backend={backend}, backend_config={backend_config}')\n        logger.info(f'speculative_config={speculative_config}')\n        backend_config = backend_config or (TurbomindEngineConfig()\n                                            if backend == 'turbomind' else PytorchEngineConfig())\n        self.model_name = model_name if model_name else model_path\n        self.chat_template = get_chat_template(model_path, chat_template_config)\n        self.tokenizer = Tokenizer(model_path)\n        self.prompt_processor = MultimodalProcessor(self.tokenizer, self.chat_template)\n        self.hf_gen_cfg = get_hf_gen_cfg(model_path)\n        self.arch, self.hf_cfg = get_model_arch(model_path)\n        self.session_len = (_get_and_verify_max_len(self.hf_cfg, None)\n                            if backend_config.session_len is None else backend_config.session_len)\n        backend_config.session_len = self.session_len\n        if speculative_config is not None and backend == 'turbomind':\n            logger.warning('speculative decoding 
is not supported by turbomind ')\n        # build backend engine\n        if backend == 'turbomind':\n            self.engine = self._build_turbomind(model_path=model_path, backend_config=backend_config, **kwargs)\n        elif backend == 'pytorch':\n            self.engine = self._build_pytorch(model_path=model_path,\n                                              backend_config=backend_config,\n                                              speculative_config=speculative_config,\n                                              **kwargs)\n        else:\n            raise ValueError(f'unsupported backend {backend}')\n        self.backend_config = self.engine.engine_config\n        self.is_sleeping = backend_config.empty_init\n        self.sleeping_tags: set[str] = set() if not backend_config.empty_init else {'weights', 'kv_cache'}\n        logger.info(f'updated backend_config={self.backend_config}')\n\n        # parameters for member functions\n        self.stop_words = _stop_words(self.chat_template.stop_words, self.tokenizer)\n        if self.stop_words is not None:\n            self.stop_words = self.stop_words[0][0].tolist()\n        self.backend = backend\n        self.request_logger = RequestLogger(max_log_len)\n\n        self.num_spec_token = 0 if backend == 'turbomind' or speculative_config is None \\\n            else speculative_config.num_speculative_tokens\n\n        self.session_mgr = SessionManager()\n        self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size)\n\n        # build stat loggers\n        self._build_stat_loggers()\n        self.epoch = 0\n\n    def close(self):\n        self.session_mgr.clear()\n        self.engine.close()\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.close()\n\n    def _build_turbomind(self, model_path: str, backend_config: TurbomindEngineConfig | None = None, **kwargs):\n        \"\"\"Inner build method for 
turbomind backend.\"\"\"\n        from lmdeploy import turbomind as tm\n        return tm.TurboMind.from_pretrained(model_path, engine_config=backend_config, **kwargs)\n\n    def _build_pytorch(self,\n                       model_path: str,\n                       backend_config: PytorchEngineConfig | None = None,\n                       speculative_config: SpeculativeConfig | None = None,\n                       **kwargs):\n        \"\"\"Inner build method for pytorch backend.\"\"\"\n        from lmdeploy.pytorch.engine import Engine\n        return Engine.from_pretrained(model_path, engine_config=backend_config, speculative_config=speculative_config)\n\n    def _build_stat_loggers(self):\n        self.stat_loggers = []\n\n        if getattr(self.backend_config, 'enable_metrics', False):\n            from lmdeploy.metrics.loggers import LoggingStatLogger, PrometheusStatLogger\n\n            # currently, metrics in TM engine doesn't support dp\n            dp_rank = self.backend_config.dp_rank if self.backend == 'pytorch' else 0\n\n            logger.info(f'enable metrics, with dp: {self.backend_config.dp} dp_rank: {dp_rank}')\n            self.stat_loggers = [\n                LoggingStatLogger(dp_rank=dp_rank),\n                PrometheusStatLogger(model_name=self.model_name, max_model_len=self.session_len, dp_rank=dp_rank)\n            ]\n\n            # set stats loggers of metrics processor\n            metrics_processor.stat_loggers = self.stat_loggers\n\n    def get_schedule_metrics(self):\n        return self.engine.get_schedule_metrics()\n\n    async def do_log_stats(self):\n        \"\"\"Loop through CLI logger and Prometheus logger and output the\n        metrics.\"\"\"\n        for stat_logger in self.stat_loggers:\n            stat_logger.log()\n\n    async def stop_all_session(self):\n        \"\"\"Stop all running sessions.\"\"\"\n        logger.info('stop all sessions')\n        self.epoch += 1\n        await self.session_mgr.async_abort_all()\n\n   
 def sleep(self, level: int = 1):\n        \"\"\"Sleep the model.\n\n        Args:\n            level (int): The sleep level. Level 1 sleep will offload the model\n                weights and discard the kv cache. Level 2 sleep will\n                discard both the model weights and the kv cache.\n        \"\"\"\n        self.engine.sleep(level)\n        self.sleeping_tags = {'weights', 'kv_cache'}\n        self.is_sleeping = True\n\n    def wakeup(self, tags: List[str] | None = None):\n        \"\"\"Wake up the model.\n\n        Args:\n            tags: An optional list of tags to reallocate the engine memory\n                for specific memory allocations. Values must be in\n                `(\"weights\", \"kv_cache\")`. If None, all memory is reallocated.\n                wake_up should be called with all tags (or None) before the\n                engine is used again.\n        \"\"\"\n        tags = tags or list(self.sleeping_tags)\n        if any(tag not in self.sleeping_tags for tag in tags):\n            logger.warning(f'some tag in {tags} not in sleeping tags {self.sleeping_tags}')\n            return\n        self.engine.wakeup(tags)\n        # for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instances\n        if self.backend == 'turbomind' and 'kv_cache' in tags:\n            self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size)\n        self.sleeping_tags = self.sleeping_tags - set(tags)\n        self.is_sleeping = bool(self.sleeping_tags)\n\n    def _determine_gen_config(self, session, input_ids, gen_config: GenerationConfig | None = None) -> GenerationConfig:\n        \"\"\"Determine the generation configuration.\"\"\"\n        gen_config = deepcopy(gen_config) or GenerationConfig()\n        gen_config.convert_stop_bad_words_to_ids(self.tokenizer)\n        gen_config.stop_token_ids = gen_config.stop_token_ids or self.stop_words\n        
gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id)\n        if not gen_config.do_sample:\n            # greedy decode\n            gen_config.top_k = 1\n            # avoid unnecessary process\n            gen_config.temperature = 1.0\n            gen_config.repetition_penalty = 1.0\n        # set random if it is not set and sequence_start is True\n        elif gen_config.random_seed is None and session.step == 0:\n            gen_config.random_seed = random.getrandbits(64)\n        if gen_config.n > 1:\n            logger.warning(f'n({gen_config.n}) > 1 hasn\\'t been supported yet. Fallback to 1')\n            gen_config.n = 1\n        if gen_config.max_new_tokens is None:\n            gen_config.max_new_tokens = max(0, self.session_len - session.step - len(input_ids))\n        return gen_config\n\n    @asynccontextmanager\n    async def safe_run(self, handle, session, **kwargs):\n        generator = handle.async_stream_infer(session.session_id, **kwargs)\n        try:\n            metrics_processor.increase_api_routed_requests()\n            yield generator\n        except (Exception, asyncio.CancelledError, GeneratorExit) as e:  # noqa\n            logger.exception(f'[safe_run] session {session.session_id} exception caught: {e}')\n            await session.async_abort()\n            if self.backend == 'pytorch':\n                await handle.async_end(session.session_id)\n            raise SafeRunException(f'Safe run exception for session {session.session_id}') from e\n        finally:\n            await generator.aclose()\n            metrics_processor.decrease_api_routed_requests()\n\n    async def generate(\n            self,\n            messages,\n            session_id: int | Session,\n            gen_config: GenerationConfig | None = None,\n            tools: List[object] | None = None,\n            reasoning_effort: Literal['low', 'medium', 'high'] | None = None,\n            stream_response: bool = True,\n            
sequence_start: bool = True,\n            sequence_end: bool = True,  # no interactive mode by default\n            step: int = 0,\n            do_preprocess: bool = True,\n            adapter_name: str | None = None,\n            rewind_stop_tokens: bool = False,\n            input_ids: List | None = None,\n            enable_thinking: bool | None = None,\n            chat_template_kwargs: Dict | None = None,\n            media_io_kwargs: Dict[str, Any] | None = None,\n            mm_processor_kwargs: Dict[str, Any] | None = None,\n            **kwargs):\n        \"\"\"Generate responses.\n\n        Args:\n            messages (str | List): chat history or prompt\n            session_id (int | Session): the session id or an instance of Session\n            gen_config (GenerationConfig | None): an instance of\n                GenerationConfig. Default to None.\n            stream_response (bool): whether to return responses in streaming mode\n            sequence_start (bool): indicator for starting a sequence\n            sequence_end (bool): indicator for ending a sequence\n            step (int): the offset of the k/v cache\n            do_preprocess (bool): whether to pre-process the messages. Default to\n                True, which means chat_template will be applied.\n        \"\"\"\n        epoch = self.epoch\n        if (messages is not None) ^ (input_ids is None):\n            raise ValueError('You must specify exactly one of messages or input_ids')\n        if isinstance(session_id, Session):\n            session = session_id\n        elif isinstance(session_id, int):\n            session = self.session_mgr.get(session_id, step=step)\n        else:\n            raise ValueError(f'Invalid session_id: {session_id}. 
It should be an instance of Session or an integer.')\n        session_id = session.session_id\n        chat_template_kwargs = chat_template_kwargs or {}\n        if enable_thinking is not None:\n            logger.warning('enable_thinking is deprecated, use chat_template_kwargs[\"enable_thinking\"] instead')\n            if chat_template_kwargs.get('enable_thinking') is None:\n                chat_template_kwargs['enable_thinking'] = enable_thinking\n            else:\n                logger.warning('chat_template_kwargs[\"enable_thinking\"] is already set, '\n                               'the value will not be overwritten by enable_thinking')\n        if messages:\n            prompt = messages\n            self.request_logger.log_prompt(session, prompt=prompt)\n            prompt_input = await self.prompt_processor.get_prompt_input(prompt=prompt,\n                                                                        do_preprocess=do_preprocess,\n                                                                        sequence_start=sequence_start,\n                                                                        adapter_name=adapter_name,\n                                                                        tools=tools,\n                                                                        reasoning_effort=reasoning_effort,\n                                                                        chat_template_kwargs=chat_template_kwargs,\n                                                                        media_io_kwargs=media_io_kwargs,\n                                                                        mm_processor_kwargs=mm_processor_kwargs,\n                                                                        **kwargs)\n            prompt = prompt_input['prompt']\n            input_ids = prompt_input['input_ids']\n            self.request_logger.log_inputs(session,\n                                           prompt=prompt,\n         
                                  prompt_token_ids=input_ids,\n                                           gen_config=gen_config,\n                                           adapter_name=adapter_name)\n        else:\n            # TODO(lvhan) VLM doesn't support input_ids as an argument.\n            # Figure out a graceful way to handle the invalid input\n            prompt_input = dict(input_ids=input_ids)\n\n        gen_config = self._determine_gen_config(session, input_ids, gen_config=gen_config)\n\n        if gen_config.max_new_tokens == 0:\n            logger.info(f'run out of tokens. session={session_id}.')\n            yield GenOut(response='',\n                         history_token_len=session.step,\n                         input_token_len=len(input_ids),\n                         generate_token_len=0,\n                         finish_reason='length',\n                         token_ids=[])\n            if sequence_end is True and sequence_start is False:\n                await session.async_close()\n            return\n\n        if self.backend_config.enable_prefix_caching and (gen_config.output_last_hidden_state == 'all'\n                                                          or gen_config.output_logits == 'all'):\n            errmsg = ('lmdeploy does not support outputting all token\\'s logits or last_hidden_state '\n                      'when prefix caching is ON')\n            yield GenOut(response=errmsg,\n                         history_token_len=session.step,\n                         input_token_len=len(input_ids),\n                         generate_token_len=0,\n                         finish_reason='error',\n                         token_ids=[])\n            return\n        logger.info(f'session={session_id}, '\n                    f'history_tokens={session.step}, '\n                    f'input_tokens={len(input_ids)}, '\n                    f'max_new_tokens={gen_config.max_new_tokens}, '\n                    
f'seq_start={sequence_start}, seq_end={sequence_end}, '\n                    f'step={step}, prep={do_preprocess}')\n\n        def is_error(status):\n            return status not in [ResponseType.SUCCESS, ResponseType.FINISH, ResponseType.CANCEL]\n\n        stop_ids = []\n        if not gen_config.ignore_eos:\n            stop_ids = gen_config.stop_token_ids or []\n\n        metrics_processor.increase_total_requests()\n        async with session.request_handle() as handle:\n            if epoch != self.epoch:\n                logger.info(f'[generate] session {session_id} got aborted before starting inference')\n                # TODO(lvhan): metrics_processor.increase_failed_requests('abort')\n                metrics_processor.increase_completed_requests()\n                yield GenOut(response='',\n                             history_token_len=0,\n                             input_token_len=len(input_ids),\n                             generate_token_len=0,\n                             finish_reason='abort',\n                             token_ids=[])\n                return\n            token_ids = input_ids.copy()\n            history_len = session.step\n            input_len = len(input_ids)\n            output_len, gen_len = 0, 0\n            state = DetokenizeState(input_len)\n            response = ''\n            finish_reason = None\n            async with self.safe_run(handle,\n                                     session=session,\n                                     **prompt_input,\n                                     gen_config=gen_config,\n                                     adapter_name=adapter_name,\n                                     stream_output=stream_response,\n                                     sequence_start=sequence_start,\n                                     sequence_end=sequence_end,\n                                     step=history_len) as gen:\n                logger.debug(f'[generate] session {session_id} started')\n          
      hit_stop_token = 0\n                req_stats = RequestStats(prompt_tokens=input_len)  # per-request stats\n\n                # We use this as default outputs in case the async_stream_infer of the Engine yields empty generator.\n                outputs = EngineOutput(ResponseType.INTERNAL_ENGINE_ERROR, [])\n\n                async for outputs in gen:\n                    iteration_stats = IterationStats()  # per-iteration stats\n                    specdecode_stats = SpeculativeDecodingStats(\n                        self.num_spec_token) if self.num_spec_token > 0 else None\n                    metrics_processor.queue_update((outputs, req_stats, iteration_stats, specdecode_stats))\n                    # decode res\n                    if is_error(outputs.status):\n                        break\n\n                    output_len = len(outputs.token_ids)\n                    if hit_stop_token or output_len == 0:\n                        continue\n\n                    # This assumes the engine will stop when stop token is hit\n                    if output_len and outputs.token_ids[-1] in stop_ids:\n                        hit_stop_token = 1\n\n                    token_ids += outputs.token_ids[:output_len - hit_stop_token]\n                    gen_len = len(token_ids) - input_len\n\n                    ids_offset = state.ids_offset\n                    response, state = self.tokenizer.detokenize_incrementally(\n                        token_ids,\n                        state,\n                        skip_special_tokens=gen_config.skip_special_tokens,\n                        spaces_between_special_tokens=gen_config.spaces_between_special_tokens)\n                    res = token_ids[ids_offset:]\n\n                    out = GenOut(response,\n                                 history_len,\n                                 input_len,\n                                 gen_len,\n                                 finish_reason,\n                                 
token_ids=res,\n                                 routed_experts=outputs.routed_experts,\n                                 cache_block_ids=outputs.cache_block_ids)\n                    if outputs.logprobs is not None:\n                        out.logprobs = (outputs.logprobs[:-hit_stop_token] if hit_stop_token else outputs.logprobs)\n                    if outputs.last_hidden_state is not None:\n                        out.last_hidden_state = (outputs.last_hidden_state[:-hit_stop_token]\n                                                 if hit_stop_token else outputs.last_hidden_state)\n                    if outputs.logits is not None:\n                        out.logits = (outputs.logits[:-hit_stop_token] if hit_stop_token else outputs.logits)\n                    yield out\n                # end of generator loop\n                metrics_processor.increase_completed_requests()\n\n                if not is_error(outputs.status):\n                    if outputs.status == ResponseType.CANCEL:\n                        finish_reason = 'abort'\n                    else:\n                        finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'\n\n                    # utf-8 char at the end means it's a potential unfinished byte sequence\n                    if not response.endswith('�'):\n                        # avoid returning the last response twice\n                        response = ''\n                    token_ids, logits, last_hidden_state, logprobs = [], None, None, None\n                    if gen_config.include_stop_str_in_output and finish_reason == 'stop':\n                        # return the eos token id (MUST be in a list), eos string, eos token's logits and so on\n                        token_ids = outputs.token_ids[-1:]\n                        response = self.tokenizer.decode(token_ids, skip_special_tokens=False)\n                        logits = outputs.logits[-1:] if outputs.logits is not None else None\n                    
    last_hidden_state = outputs.last_hidden_state[-1:] if outputs.last_hidden_state is not None else None\n                        logprobs = outputs.logprobs[-1:] if outputs.logprobs else None\n                        gen_len += 1\n\n                    # router replay\n                    routed_experts = outputs.routed_experts\n                    if routed_experts is not None and not isinstance(routed_experts, str) and (\n                            not gen_config.include_stop_str_in_output) and finish_reason == 'stop':\n                        routed_experts = routed_experts[:-1]\n\n                    logger.info(f'session {session_id} finished, reason '\n                                f'\"{finish_reason}\", input_tokens '\n                                f'{len(input_ids)}, output_tokens {gen_len}')\n                    yield GenOut(response,\n                                 session.step,\n                                 len(input_ids),\n                                 gen_len,\n                                 finish_reason,\n                                 token_ids=token_ids,\n                                 logprobs=logprobs,\n                                 logits=logits,\n                                 last_hidden_state=last_hidden_state,\n                                 routed_experts=routed_experts,\n                                 cache_block_ids=outputs.cache_block_ids)\n                    # Note: We remove the session step update here. 
Let the caller (e.g., pipeline.chat) take care of it.\n                else:\n                    logger.error(f'session {session_id} finished, {outputs.status}, '\n                                 'reason \"error\"')\n                    yield GenOut(response=f'internal error happened, status code {outputs.status}',\n                                 history_token_len=session.step,\n                                 input_token_len=len(input_ids),\n                                 generate_token_len=0,\n                                 finish_reason='error',\n                                 token_ids=[])\n            # end the session and remove it from the session manager\n            if sequence_end:\n                if self.backend == 'pytorch':\n                    # manually end pytorch session\n                    # note: Using session.async_abort() here results in deadlock\n                    # because it waits for session's _active event to be set, but the event won't be set\n                    # until the session is finished, i.e., session.request_handle() context exits.\n                    await handle.async_end(session.session_id)\n                self.session_mgr.remove(session)\n        # if sequence_end:\n        #     if self.backend == 'pytorch':\n        #         # manually end pytorch session. session cannot be ended until session.request_handle()\n        #         # context exits\n        #         await session.async_close()\n        #     self.session_mgr.remove(session)\n\n    def start_loop(self, loop, use_async_api=False):\n        \"\"\"Start engine loop.\n\n        When using pytorch backend with dp > 1, every dp_rank must receive at least one request before it can start\n        processing (warmup). Since the pytorch engine is bound to an event loop, the pipeline can only choose either the\n        synchronous APIs (__call__, stream_infer, etc.) 
or the asynchronous API (generate) during its lifetime.\n\n        The purpose of this function is to allow users to choose whether to use the synchronous interface or the\n        asynchronous interface for the pipeline.\n        \"\"\"\n        self.session_mgr.attach_event_loop(loop)\n        if hasattr(self.engine, 'start_loop'):\n            if use_async_api:\n                return self.engine.start_loop()\n            else:\n                fut = concurrent.futures.Future()\n\n                def _start_loop(fut):\n                    res = self.engine.start_loop()\n                    fut.set_result(res)\n\n                loop.call_soon_threadsafe(_start_loop, fut)\n                return fut.result()\n        else:\n            return True\n\n    \"\"\" DistServe Async Engine API Begin \"\"\"\n\n    def free_cache(self, session_id: int):\n        if self.engine.end_session(session_id):\n            logger.debug(f'successfully freed session {session_id}')\n        else:\n            logger.warning(f'Failed to free invalid session {session_id}.')\n\n    def p2p_initialize(self, init_request: DistServeInitRequest):\n        return self.engine.p2p_initialize(init_request)\n\n    def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):\n        return self.engine.p2p_connect(conn_request)\n\n    def p2p_drop_connect(self, drop_conn_request: List[DistServeDropConnectionRequest]):\n        return self.engine.p2p_drop_connect(drop_conn_request)\n\n    \"\"\" DistServe Async Engine API End \"\"\"\n\n    async def async_get_reward_score(self, input_ids: List) -> List[float]:\n        \"\"\"Async version of get_reward_score.\"\"\"\n        supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel']\n        if self.arch not in supported_reward_models:\n            raise ValueError(f'{self.arch} is not in reward model list: {supported_reward_models}')\n        assert isinstance(input_ids, List)\n        assert all(isinstance(x, int) for x in 
input_ids) or all(isinstance(x, List) for x in input_ids)\n        # Normalize input_ids to a list of token-id lists\n        input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids\n\n        logits = await self.async_get_logits(input_ids=input_ids)\n\n        logits = [x.squeeze() for x in logits]\n        scores = [x[-1].cpu().item() for x in logits]\n        return scores\n\n    async def async_get_logits(self,\n                               input_ids,\n                               sessions: List['Session'] | None = None,\n                               sequence_start: bool = True,\n                               sequence_end: bool = True) -> List[torch.Tensor]:\n        assert input_ids and all(isinstance(_, List) for _ in input_ids)\n        assert sessions is None or (len(sessions) == len(input_ids))\n\n        logits = [None] * len(input_ids)\n\n        async def _proc(session, i):\n            async with session.request_handle() as handle:\n                input_len = len(input_ids[i])\n                # TODO(lvhan): Fix the ugly code later on\n                max_new_tokens = 1 if self.backend == 'turbomind' else 0\n                # `top_k=1` is set because the pytorch engine crashes at the top_k sampling stage\n                # when performing inference on a reward model.\n                gen_config = GenerationConfig(max_new_tokens=max_new_tokens, output_logits='all', top_k=1)\n                async with self.safe_run(handle,\n                                         session=session,\n                                         input_ids=input_ids[i],\n                                         gen_config=gen_config,\n                                         stream_output=False,\n                                         sequence_start=sequence_start,\n                                         sequence_end=sequence_end,\n                                         step=session.step) as gen:\n                    async for outputs in gen:\n                   
     pass\n                    logits[i] = outputs.logits[:input_len, :]\n\n        create_sessions = False\n        if sessions is None:\n            create_sessions = True\n            sessions = [self.session_mgr.get() for _ in range(len(input_ids))]\n        tasks = [_proc(session, i) for i, session in enumerate(sessions)]\n        await asyncio.gather(*tasks)\n        if sequence_end and self.backend == 'pytorch':\n            for session in sessions:\n                await session.async_close()\n        if sequence_end and create_sessions:\n            for session in sessions:\n                self.session_mgr.remove(session)\n        return logits\n"
  },
  {
    "path": "lmdeploy/serve/core/exceptions.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\"\"\"Exceptions for the serve module.\"\"\"\n\n\nclass SafeRunException(Exception):\n    \"\"\"Exception raised by safe_run so that upper layers do not handle the\n    original exception again.\n\n    This exception wraps the original exception raised during safe_run execution.\n    \"\"\"\n"
  },
  {
    "path": "lmdeploy/serve/core/vl_async_engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Literal\n\nfrom lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig\nfrom lmdeploy.utils import get_logger\n\nfrom .async_engine import AsyncEngine\n\nlogger = get_logger('lmdeploy')\n\n\nclass VLAsyncEngine(AsyncEngine):\n    \"\"\"Vision-language async inference engine.\"\"\"\n\n    def __init__(self,\n                 model_path: str,\n                 backend: Literal['turbomind', 'pytorch'] = 'turbomind',\n                 backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None,\n                 vision_config: VisionConfig | None = None,\n                 **kwargs) -> None:\n        from lmdeploy.serve.processors import MultimodalProcessor\n        from lmdeploy.utils import try_import_deeplink\n        from lmdeploy.vl.engine import ImageEncoder\n\n        if backend == 'pytorch':\n            try_import_deeplink(backend_config.device_type)\n        if backend_config and backend_config.enable_prefix_caching:\n            backend_config.enable_prefix_caching = False\n            logger.warning('Prefix caching is disabled since LMDeploy doesn\\'t support it on VL models yet')\n        self.vl_encoder = ImageEncoder(model_path, backend, vision_config, backend_config=backend_config)\n        super().__init__(model_path, backend=backend, backend_config=backend_config, **kwargs)\n        # Update prompt_processor to support multimodal processing\n        self.prompt_processor = MultimodalProcessor(self.tokenizer,\n                                                    self.chat_template,\n                                                    vl_encoder=self.vl_encoder,\n                                                    backend=backend)\n        if self.model_name == 'base':\n            raise RuntimeError(\n                'please specify chat template as guided in 
https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template'  # noqa: E501\n            )\n\n    def close(self):\n        if hasattr(self, 'vl_encoder'):\n            del self.vl_encoder\n            super().close()\n"
  },
  {
    "path": "lmdeploy/serve/managers/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .session_manager import Session, SessionManager\n\n__all__ = ['Session', 'SessionManager']\n"
  },
  {
    "path": "lmdeploy/serve/managers/session_manager.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom __future__ import annotations\n\nimport asyncio\nimport itertools\nimport weakref\nfrom contextlib import asynccontextmanager\nfrom typing import Any, List, Tuple\n\nfrom lmdeploy.messages import GenerationConfig, Response\nfrom lmdeploy.serve.core.exceptions import SafeRunException\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass Session:\n    \"\"\"Session for the engine.\"\"\"\n\n    def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs):\n        self.session_id = session_id\n        self.prompt: Any = None\n        self.response: Response | None = None\n        self.history: List[Tuple[Any, str]] = []\n        self.gen_config: GenerationConfig | None = None\n        self.step: int = 0\n        # event to wait for the session to be active\n        self._active: asyncio.Event | None = None\n        self._handle = None  # inference instance\n        self._session_mgr: SessionManager = weakref.ref(session_mgr)\n        self.update(**kwargs)\n\n    def update(self, **kwargs):\n        \"\"\"Update the session.\"\"\"\n        self.prompt = kwargs.get('prompt', self.prompt)\n        self.gen_config = kwargs.get('gen_config', self.gen_config)\n        self.step = kwargs.get('step', self.step)\n\n    def __repr__(self) -> str:\n        \"\"\"Return a string representation of the Session object.\"\"\"\n        return (f'Session(session_id={self.session_id}, '\n                f'step={self.step}, history_len={len(self.history)}, '\n                f'has_response={self.response is not None}, '\n                f'has_gen_config={self.gen_config is not None})')\n\n    def __str__(self) -> str:\n        \"\"\"Return a human-readable string representation of the Session.\"\"\"\n        res = f'Session(id={self.session_id}, step={self.step})'\n        if self.history:\n            res += '\\nHistory:\\n'\n            for user, assistant in 
self.history:\n                if isinstance(user, list):\n                    user = str(user)\n                res += f'USER: \\n{user}\\nASSISTANT: \\n{assistant}\\n'\n        return res\n\n    def reset(self):\n        \"\"\"Reset the session to initial state.\n\n        This method resets all session data (prompt, response, history, etc.) but keeps the session_id.\n        \"\"\"\n        self.prompt = None\n        self.response = None\n        self.history = []\n        self.gen_config = None\n        self.step = 0\n        self._active = None\n        self._handle = None\n        self._session_mgr = None\n        logger.debug(f'Session {self.session_id} has been reset.')\n\n    @asynccontextmanager\n    async def request_handle(self):\n        if self._handle is not None:\n            raise RuntimeError(f'Session {self.session_id} already has an inference instance.')\n        logger.debug(f'[request_handle] session {self.session_id} acquiring an instance')\n\n        hnd_pool = self._session_mgr().request_handle_pool\n        self._handle = await hnd_pool.get()\n        self._active = asyncio.Event()\n        logger.debug(f'[request_handle] session {self.session_id} acquired an instance')\n        try:\n            yield self._handle\n        except SafeRunException:\n            pass\n        except (asyncio.CancelledError, GeneratorExit) as e:\n            logger.exception(f'[request_handle] session {self.session_id} exception caught: {e}')\n            await self._handle.async_cancel(self.session_id)\n        except Exception as e:\n            logger.exception(f'[request_handle] session {self.session_id} exception caught: {e}')\n            raise\n        finally:\n            logger.debug(f'[request_handle] session {self.session_id} releasing the instance')\n            # Return inference instance if it was acquired\n            if self._handle is not None:\n                hnd_pool.put(self._handle)\n                self._handle = None\n            # 
MUST set the signal after releasing the instance to avoid a race condition\n            # with `async_close`, which waits on this event before acquiring an instance\n            self._active.set()\n\n    async def async_abort(self):\n        \"\"\"Abort the session.\"\"\"\n        logger.info(f'[session] Aborting session {self.session_id}')\n        if self._handle is not None:\n            await self._handle.async_cancel(self.session_id)\n\n    async def async_close(self):\n        \"\"\"End the session.\"\"\"\n        logger.info(f'[session] Ending session {self.session_id}')\n        if self._handle is None and self.step == 0:\n            return\n        if self._handle is not None:\n            await self._active.wait()\n        async with self.request_handle() as handle:\n            try:\n                await handle.async_end(self.session_id)\n            except (Exception, asyncio.CancelledError, GeneratorExit) as e:\n                logger.exception(f'[async_close] exception caught: {e}')\n        self.reset()\n\n    def abort(self):\n        \"\"\"Abort the session in sync mode.\"\"\"\n        if self._session_mgr is not None:\n            self._run(self.async_abort()).result()\n\n    def close(self):\n        \"\"\"End the session in sync mode.\"\"\"\n        if self._session_mgr is not None:\n            self._run(self.async_close()).result()\n\n    def _run(self, coro):\n        assert self._session_mgr is not None, 'Session manager is not initialized'\n        return asyncio.run_coroutine_threadsafe(coro, self._session_mgr().loop)\n\n\nclass RequestHandlePool:\n    \"\"\"Manages a pool of request handles for concurrent request processing.\n\n    This class maintains a fixed-size pool of request handles that can be reused\n    across multiple inference requests. 
It implements a lazy-initialized queue-based\n    pool pattern to efficiently manage handle lifecycle and enable concurrent\n    request handling.\n\n    Each session or request should acquire a handle from the pool before inference and\n    return it after completion. The manager supports:\n    - Pool-based handle allocation and deallocation\n    - Lazy initialization of the async queue (required for asyncio.Queue)\n    - Handle rebuilding after engine wakeup (e.g., turbomind backend)\n    - Complete pool cleanup\n\n    Args:\n        engine (AsyncEngine): The async inference engine that creates handles.\n        size (int): The size of the handle pool, typically set to max_batch_size.\n\n    Note:\n        The pool queue is lazily initialized on first access via `get()` method,\n        as `asyncio.Queue` must be created within an async context.\n    \"\"\"\n\n    def __init__(self, engine, size: int):\n        self.size = size\n        self.handles = [engine.create_instance() for _ in range(size)]\n        # `asyncio.Queue` must be created in an async context, refer to `get` method\n        self.pool: asyncio.Queue = None\n\n    async def get(self):\n        \"\"\"Get a handle from pool.\"\"\"\n        # Lazy initialization: create pool on first use\n        if self.pool is None:\n            self.pool = asyncio.Queue()\n            for inst in self.handles:\n                self.pool.put_nowait(inst)\n\n        return await self.pool.get()\n\n    def put(self, handle):\n        \"\"\"Put a handle back to the pool.\"\"\"\n        if handle is not None and self.pool is not None:\n            self.pool.put_nowait(handle)\n\n    def clear(self):\n        \"\"\"Clear all handles.\"\"\"\n        self.handles = []\n        self.pool = None\n\n\nclass SessionManager:\n    \"\"\"Session manager.\"\"\"\n\n    def __init__(self):\n        \"\"\"Initialize the session manager.\"\"\"\n\n        self.sessions = {}\n        self.session_id_generator = itertools.count(1)\n     
   self.request_handle_pool = None\n        self.loop = None\n\n    def get(self, session_id: int | None = None, **kwargs) -> Session:\n        \"\"\"Get the session with ``session_id``, creating it if it does not exist.\"\"\"\n        session_id = session_id or next(self.session_id_generator)\n        if session_id in self.sessions:\n            logger.debug(f'[SessionManager] session {session_id} already exists. Updating...')\n            session = self.sessions[session_id]\n            session.update(**kwargs)\n            return session\n        else:\n            logger.info(f'[SessionManager] session {session_id} not found. Creating...')\n            session = Session(session_id, self, **kwargs)\n            self.sessions[session_id] = session\n            return session\n\n    async def async_abort_all(self):\n        \"\"\"Abort all sessions.\"\"\"\n        tasks = []\n        for session in list(self.sessions.values()):\n            tasks.append(session.async_abort())\n        await asyncio.gather(*tasks, return_exceptions=True)\n        # \"abort all\" is designed for async RL. The aborted sessions will no longer be used,\n        # so we clear the sessions here.\n        self.sessions.clear()\n\n    def has(self, session_id):\n        return session_id in self.sessions\n\n    def remove(self, session: Session):\n        self.sessions.pop(session.session_id, None)\n\n    def clear(self):\n        self.sessions.clear()\n        # reset the session id generator\n        self.session_id_generator = itertools.count(1)\n\n    def attach_event_loop(self, loop):\n        self.loop = loop\n\n    def build_request_handle_pool(self, engine, size):\n        \"\"\"Build the pool of request handles.\"\"\"\n        self.request_handle_pool = RequestHandlePool(engine, size)\n"
  },
  {
    "path": "lmdeploy/serve/openai/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/serve/openai/api_client.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nfrom typing import Any, Dict, List, Optional, Union\n\nimport requests\n\nfrom lmdeploy.utils import get_logger\n\n\ndef get_model_list(api_url: str, headers: dict = None):\n    \"\"\"Get model list from api server.\"\"\"\n    response = requests.get(api_url, headers=headers)\n    logger = get_logger('lmdeploy')\n    if not response.ok:\n        logger.error(f'Failed to get the model list: {api_url}'\n                     f' returns {response.status_code}')\n        return None\n    elif not hasattr(response, 'text'):\n        logger.warning('Failed to get the model list.')\n        return None\n    else:\n        model_list = response.json()\n        model_list = model_list.pop('data', [])\n        return [item['id'] for item in model_list]\n\n\ndef json_loads(content):\n    \"\"\"Loads content to json format.\"\"\"\n    try:\n        content = json.loads(content)\n        return content\n    except:  # noqa\n        logger = get_logger('lmdeploy')\n        logger.warning(f'weird json content {content}')\n        return ''\n\n\nclass APIClient:\n    \"\"\"Chatbot for LLaMA series models with turbomind as inference engine.\n\n    Args:\n        api_server_url (str): communicating address 'http://<ip>:<port>' of\n            api_server\n        api_key (str | None): api key. 
Defaults to None, which means no\n            api key is used.\n    \"\"\"\n\n    def __init__(self, api_server_url: str, api_key: Optional[str] = None, **kwargs):\n        self.api_server_url = api_server_url\n        self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions'\n        self.completions_v1_url = f'{api_server_url}/v1/completions'\n        self.models_v1_url = f'{api_server_url}/v1/models'\n        self.encode_v1_url = f'{api_server_url}/v1/encode'\n        self._available_models = None\n        self.api_key = api_key\n        self.headers = {'content-type': 'application/json'}\n        if api_key is not None:\n            self.headers['Authorization'] = f'Bearer {api_key}'\n\n    @property\n    def available_models(self):\n        \"\"\"Show available models.\"\"\"\n        if self._available_models is not None:\n            return self._available_models\n        self._available_models = get_model_list(self.models_v1_url, headers=self.headers)\n        return self._available_models\n\n    def encode(self,\n               input: Union[str, List[str]],\n               do_preprocess: Optional[bool] = False,\n               add_bos: Optional[bool] = True):\n        \"\"\"Encode prompts.\n\n        Args:\n            input: the prompt(s) to encode, as a str or List[str].\n            do_preprocess: whether to preprocess the input. Defaults to False.\n            add_bos: True when it is the beginning of a conversation. False\n                when it is not. 
Default to True.\n        Return: (input_ids, length)\n        \"\"\"\n        response = requests.post(self.encode_v1_url,\n                                 headers=self.headers,\n                                 json=dict(input=input, do_preprocess=do_preprocess, add_bos=add_bos),\n                                 stream=False)\n        if hasattr(response, 'text'):\n            output = json_loads(response.text)\n            return output['input_ids'], output['length']\n        return None, None\n\n    def chat_completions_v1(\n        self,\n        model: str,\n        messages: Union[str, List[Dict[str, str]]],\n        temperature: Optional[float] = 0.7,\n        top_p: Optional[float] = 1.0,\n        logprobs: Optional[bool] = False,\n        top_logprobs: Optional[int] = 0,\n        n: Optional[int] = 1,\n        max_completion_tokens: Optional[int] = None,\n        max_tokens: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: Optional[bool] = False,\n        presence_penalty: Optional[float] = 0.0,\n        frequency_penalty: Optional[float] = 0.0,\n        user: Optional[str] = None,\n        repetition_penalty: Optional[float] = 1.0,\n        ignore_eos: Optional[bool] = False,\n        skip_special_tokens: Optional[bool] = True,\n        spaces_between_special_tokens: Optional[bool] = True,\n        top_k: int = 40,\n        min_new_tokens: Optional[int] = None,\n        min_p: float = 0.0,\n        logit_bias: Optional[Dict[str, float]] = None,\n        stream_options: Optional[Dict] = None,\n        **kwargs,\n    ):\n        \"\"\"Chat completion v1.\n\n        Args:\n            model: model name. Available from self.available_models.\n            messages: string prompt or chat history in OpenAI format. 
Chat\n                history example: `[{\"role\": \"user\", \"content\": \"hi\"}]`.\n            temperature (float): to modulate the next token probability\n            top_p (float): If set to float < 1, only the smallest set of most\n                probable tokens with probabilities that add up to top_p or\n                higher are kept for generation.\n            n (int): How many chat completion choices to generate for each\n                input message. Only one is supported here.\n            stream: whether to stream the results or not. Defaults to False.\n            max_completion_tokens (int | None): maximum number of output tokens. Defaults to None.\n            max_tokens (int | None): maximum number of output tokens. Defaults to None.\n                Deprecated: Use max_completion_tokens instead.\n            stop (str | List[str] | None): To stop generating further\n              tokens. Only stop words that encode to a single token are accepted.\n            repetition_penalty (float): The parameter for repetition penalty.\n                1.0 means no penalty.\n            ignore_eos (bool): whether to ignore the EOS token.\n            skip_special_tokens (bool): Whether or not to remove special tokens\n                in the decoding. Defaults to True.\n            spaces_between_special_tokens (bool): Whether or not to add spaces\n                around special tokens. Fast tokenizers default this to False;\n                slow tokenizers set it to True.\n            top_k (int): The number of the highest probability vocabulary\n                tokens to keep for top-k-filtering.\n            min_new_tokens (int): The minimum number of tokens to generate.\n            min_p (float): Minimum token probability, which will be scaled by the\n                probability of the most likely token. It must be a value between\n                0 and 1. 
Typical values are in the 0.01-0.2 range, comparably\n                selective as setting `top_p` in the 0.99-0.8 range (use the\n                opposite of normal `top_p` values)\n            logit_bias (Dict): Bias to logits. Only supported in pytorch engine.\n            stream_options: Options for streaming response. Only set this when you\n                set stream: true.\n\n        Yields:\n            json objects in openai formats\n        \"\"\"\n        pload = {k: v for k, v in locals().copy().items() if k[:2] != '__' and k not in ['self']}\n        response = requests.post(self.chat_completions_v1_url, headers=self.headers, json=pload, stream=stream)\n        for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\\n'):\n            if chunk:\n                if stream:\n                    decoded = chunk.decode('utf-8')\n                    if decoded == 'data: [DONE]':\n                        continue\n                    if decoded[:6] == 'data: ':\n                        decoded = decoded[6:]\n                    output = json_loads(decoded)\n                    yield output\n                else:\n                    decoded = chunk.decode('utf-8')\n                    output = json_loads(decoded)\n                    yield output\n\n    def completions_v1(\n        self,\n        model: str,\n        prompt: Union[str, List[Any]],\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = 0.7,\n        n: Optional[int] = 1,\n        max_completion_tokens: Optional[int] = 16,\n        max_tokens: Optional[int] = 16,\n        stream: Optional[bool] = False,\n        stop: Optional[Union[str, List[str]]] = None,\n        top_p: Optional[float] = 1.0,\n        top_k: Optional[int] = 40,\n        user: Optional[str] = None,\n        # additional argument of lmdeploy\n        repetition_penalty: Optional[float] = 1.0,\n        ignore_eos: Optional[bool] = False,\n        skip_special_tokens: 
Optional[bool] = True,\n        spaces_between_special_tokens: Optional[bool] = True,\n        stream_options: Optional[Dict] = None,\n        **kwargs,\n    ):\n        \"\"\"Completion v1.\n\n        Args:\n            model (str): model name. Available from /v1/models.\n            prompt (str): the input prompt.\n            suffix (str): The suffix that comes after a completion of inserted\n                text.\n            max_completion_tokens (int | None): maximum number of output tokens. Defaults to 16.\n            max_tokens (int): maximum number of output tokens.\n                Deprecated: Use max_completion_tokens instead.\n            temperature (float): to modulate the next token probability\n            top_p (float): If set to float < 1, only the smallest set of most\n                probable tokens with probabilities that add up to top_p or\n                higher are kept for generation.\n            top_k (int): The number of the highest probability vocabulary\n                tokens to keep for top-k-filtering.\n            n (int): How many completion choices to generate for each\n                prompt. Only one is supported here.\n            stream: whether to stream the results or not. Defaults to False.\n            stop (str | List[str] | None): To stop generating further\n              tokens. Only stop words that encode to a single token are accepted.\n            repetition_penalty (float): The parameter for repetition penalty.\n                1.0 means no penalty.\n            user (str): A unique identifier representing your end-user.\n            ignore_eos (bool): whether to ignore the EOS token.\n            skip_special_tokens (bool): Whether or not to remove special tokens\n                in the decoding. Defaults to True.\n            spaces_between_special_tokens (bool): Whether or not to add spaces\n                around special tokens. The behavior of Fast tokenizers is to have\n                this to False. 
This is setup to True in slow tokenizers.\n            stream_options: Options for streaming response. Only set this when you\n                set stream: true.\n\n        Yields:\n            json objects in openai formats\n        \"\"\"\n        pload = {k: v for k, v in locals().copy().items() if k[:2] != '__' and k not in ['self']}\n        response = requests.post(self.completions_v1_url, headers=self.headers, json=pload, stream=stream)\n        for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\\n'):\n            if chunk:\n                if stream:\n                    decoded = chunk.decode('utf-8')\n                    if decoded == 'data: [DONE]':\n                        continue\n                    if decoded[:6] == 'data: ':\n                        decoded = decoded[6:]\n                    output = json_loads(decoded)\n                    yield output\n                else:\n                    decoded = chunk.decode('utf-8')\n                    output = json_loads(decoded)\n                    yield output\n"
  },
  {
    "path": "lmdeploy/serve/openai/api_server.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# yapf: disable\nimport asyncio\nimport copy\nimport json\nimport os\nimport re\nimport time\nfrom contextlib import asynccontextmanager\nfrom functools import partial\nfrom http import HTTPStatus\nfrom typing import AsyncGenerator, Literal\n\nimport uvicorn\nfrom fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status\nfrom fastapi.encoders import jsonable_encoder\nfrom fastapi.exceptions import RequestValidationError\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom fastapi.responses import JSONResponse, Response, StreamingResponse\nfrom starlette.middleware.base import BaseHTTPMiddleware\nfrom starlette.routing import Mount\n\nfrom lmdeploy.archs import get_task\nfrom lmdeploy.messages import (GenerationConfig, LogitsProcessor, PytorchEngineConfig, SpeculativeConfig,\n                               TurbomindEngineConfig)\nfrom lmdeploy.metrics.metrics_processor import metrics_processor\nfrom lmdeploy.model import ChatTemplateConfig\nfrom lmdeploy.pytorch.disagg.config import DistServeEngineConfig\nfrom lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,\n                                                   DistServeDropConnectionRequest, DistServeInitRequest,\n                                                   MigrationRequest)\nfrom lmdeploy.serve.core import AsyncEngine\nfrom lmdeploy.serve.openai.harmony_utils import GptOssChatParser\nfrom lmdeploy.serve.openai.protocol import ChatCompletionResponse  # noqa: E501\nfrom lmdeploy.serve.openai.protocol import (AbortRequest, ChatCompletionRequest, ChatCompletionResponseChoice,\n                                            ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,\n                                            ChatCompletionTokenLogprob, ChatMessage, ChoiceLogprobs, CompletionRequest,\n                                            CompletionResponse, 
CompletionResponseChoice,\n                                            CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,\n                                            EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse,\n                                            GenerateReqInput, GenerateReqMetaOutput, GenerateReqOutput, LogProbs,\n                                            ModelCard, ModelList, ModelPermission, PoolingRequest, PoolingResponse,\n                                            TopLogprob, UpdateParamsRequest, UsageInfo)\nfrom lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager\nfrom lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager\nfrom lmdeploy.serve.utils.server_utils import validate_json_request\nfrom lmdeploy.tokenizer import DetokenizeState, Tokenizer\nfrom lmdeploy.utils import get_logger\n\n# yapf: enable\n\nlogger = get_logger('lmdeploy')\n\n\nclass VariableInterface:\n    \"\"\"An IO interface that holds the server's shared variables.\"\"\"\n    async_engine: AsyncEngine = None\n    request_hosts = []\n    # the following are used for registering with the proxy server\n    proxy_url: str | None = None\n    api_server_url: str | None = None\n    # the following is used for reasoning parsers\n    reasoning_parser: ReasoningParser | None = None\n    # the following is used for tool parsers\n    tool_parser: ToolParser | None = None\n    allow_terminate_by_client: bool = False\n    enable_abort_handling: bool = False\n\n    @staticmethod\n    def get_session(session_id: int):\n        \"\"\"Return the session for ``session_id``; -1 creates a session with a new id.\"\"\"\n        session_mgr = VariableInterface.get_session_manager()\n        if session_id == -1:\n            return session_mgr.get()\n        else:\n            return session_mgr.get(session_id)\n\n    @staticmethod\n    def get_session_manager():\n        return VariableInterface.async_engine.session_mgr\n\n    @staticmethod\n    def get_engine_config():\n        return 
VariableInterface.async_engine.backend_config\n\n\nrouter = APIRouter()\nserver_context = VariableInterface()\n\n\ndef get_model_list():\n    \"\"\"Available models.\n\n    When LoRA adapters are served, the model list is [model_name, adapter_name1, adapter_name2, ...]\n    \"\"\"\n    model_names = [VariableInterface.async_engine.model_name]\n    cfg = VariableInterface.async_engine.backend_config\n    model_names += getattr(cfg, 'adapters', None) or []\n    return model_names\n\n\n@router.get('/v1/models')\ndef available_models():\n    \"\"\"Show available models.\"\"\"\n    model_cards = []\n    for model_name in get_model_list():\n        model_cards.append(ModelCard(id=model_name, root=model_name, permission=[ModelPermission()]))\n    return ModelList(data=model_cards)\n\n\ndef create_error_response(status: HTTPStatus, message: str, error_type='invalid_request_error'):\n    \"\"\"Create an error response from an HTTP status and a message.\n\n    Args:\n        status (HTTPStatus): HTTP status codes and reason phrases\n        message (str): error message\n        error_type (str): error type\n    \"\"\"\n    return JSONResponse(ErrorResponse(message=message, type=error_type, code=status.value).model_dump(),\n                        status_code=status.value)\n\n\ndef check_request(request) -> JSONResponse | None:\n    \"\"\"Check if a request is valid.\"\"\"\n    if hasattr(request, 'model') and request.model not in get_model_list():\n        return create_error_response(HTTPStatus.NOT_FOUND, f'The model {request.model!r} does not exist.')\n\n    # Import the appropriate check function based on request type\n    if isinstance(request, ChatCompletionRequest):\n        from .serving_chat_completion import check_request\n        check_func = check_request\n    elif isinstance(request, CompletionRequest):\n        from .serving_completion import check_request\n        check_func = check_request\n    elif isinstance(request, GenerateReqInput):\n        from 
.serving_generate import check_request\n        check_func = check_request\n    else:\n        # Define a function that always passes (returns an empty error message)\n        def always_success(req, server_context):\n            return ''\n\n        check_func = always_success\n\n    error_msg = check_func(request, server_context)\n    if error_msg:\n        return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)\n    return None\n\n\ndef _create_completion_logprobs(tokenizer: Tokenizer,\n                                token_ids: list[int] | None = None,\n                                logprobs: list[dict[int, float]] | None = None,\n                                skip_special_tokens: bool = True,\n                                offset: int = 0,\n                                all_token_ids: list[int] | None = None,\n                                state: DetokenizeState = None,\n                                spaces_between_special_tokens: bool = True):\n    \"\"\"Create openai LogProbs for completion.\n\n    Args:\n        tokenizer (Tokenizer): tokenizer.\n        token_ids (List[int]): output token ids.\n        logprobs (List[Dict[int, float]]): the top logprobs for each output\n            position.\n        skip_special_tokens (bool): Whether or not to remove special tokens\n            in the decoding. Defaults to True.\n        offset (int): text offset.\n        all_token_ids (List[int]): the previously generated output token ids.\n        state (DetokenizeState): tokenizer decode state.\n        spaces_between_special_tokens (bool): Whether or not to add spaces\n            around special tokens. The behavior of Fast tokenizers is to have\n            this to False. 
This is setup to True in slow tokenizers.\n    \"\"\"\n    if logprobs is None or len(logprobs) == 0:\n        return None, None, None, None\n\n    if all_token_ids is None:\n        all_token_ids = []\n    if state is None:\n        state = DetokenizeState()\n\n    out_logprobs = LogProbs()\n    out_logprobs.top_logprobs = []\n    for token_id, tops in zip(token_ids, logprobs):\n        out_logprobs.text_offset.append(offset)\n        out_logprobs.token_logprobs.append(tops[token_id])\n\n        res = {}\n        out_state = None\n        for top_id, prob in tops.items():\n            response, _state = tokenizer.detokenize_incrementally(\n                all_token_ids + [top_id],\n                copy.deepcopy(state),\n                skip_special_tokens=skip_special_tokens,\n                spaces_between_special_tokens=spaces_between_special_tokens)\n            res[response] = prob\n            if top_id == token_id:\n                out_state = _state\n                offset += len(response)\n                out_logprobs.tokens.append(response)\n\n        out_logprobs.top_logprobs.append(res)\n        state = out_state\n        all_token_ids.append(token_id)\n\n    return out_logprobs, offset, all_token_ids, state\n\n\ndef _create_chat_completion_logprobs(tokenizer: Tokenizer,\n                                     token_ids: list[int] | None = None,\n                                     logprobs: list[dict[int, float]] | None = None):\n    \"\"\"Create openai LogProbs for chat.completion.\n\n    Args:\n        tokenizer (Tokenizer): tokenizer.\n        token_ids (List[int]): output token ids.\n        logprobs (List[Dict[int, float]]): the top logprobs for each output\n            position.\n    Returns:\n        ChoiceLogprobs: logprob result.\n    \"\"\"\n    if token_ids is None or logprobs is None:\n        return None\n\n    content: list[ChatCompletionTokenLogprob] = []\n    for token_id, tops in zip(token_ids, logprobs):\n        item = 
ChatCompletionTokenLogprob(token='', bytes=[], logprob=0.0, top_logprobs=[])\n        for top_id, prob in tops.items():\n            token = tokenizer.model.model.convert_ids_to_tokens(top_id)\n            if isinstance(token, bytes):\n                _bytes = list(token)\n                token = token.decode('utf-8', errors='backslashreplace')\n            else:\n                _bytes = list(token.encode())  # token is str\n            if top_id == token_id:\n                item.token = token\n                item.bytes = _bytes\n                item.logprob = prob\n            else:\n                item.top_logprobs.append(TopLogprob(token=token, bytes=_bytes, logprob=prob))\n        content.append(item)\n    return ChoiceLogprobs(content=content)\n\n\n@router.get('/health')\nasync def health() -> Response:\n    \"\"\"Health check.\"\"\"\n    return Response(status_code=200)\n\n\n@router.get('/terminate')\nasync def terminate():\n    \"\"\"Terminate server.\"\"\"\n    import signal\n\n    if not VariableInterface.allow_terminate_by_client:\n        return create_error_response(\n            HTTPStatus.BAD_REQUEST,\n            'The server can not be terminated. 
Please add --allow-terminate-by-client when starting the server.')\n    os.kill(os.getpid(), signal.SIGTERM)\n    return Response(status_code=200)\n\n\n# modified from https://github.com/vllm-project/vllm/blob/v0.5.4/vllm/entrypoints/openai/logits_processors.py#L51  # noqa\ndef logit_bias_logits_processor(logit_bias: dict[int, float] | dict[str, float], tokenizer) -> LogitsProcessor:\n    try:\n        # Convert token_id to integer\n        # Clamp the bias between -100 and 100 per OpenAI API spec\n        clamped_logit_bias: dict[int, float] = {\n            int(token_id): min(100.0, max(-100.0, bias))\n            for token_id, bias in logit_bias.items()\n        }\n    except ValueError as exc:\n        raise ValueError('Found token_id in logit_bias that is not '\n                         'an integer or a string representing an integer') from exc\n\n    # Check if token_id is within the vocab size\n    for token_id, bias in clamped_logit_bias.items():\n        if token_id < 0 or token_id >= tokenizer.vocab_size:\n            raise ValueError(f'token_id {token_id} in logit_bias is '\n                             'an out-of-vocab token id')\n\n    def _logit_bias_processor(\n        logit_bias,\n        token_ids,\n        logits,\n    ):\n        for token_id, bias in logit_bias.items():\n            logits[token_id] = logits[token_id] + bias\n        return logits\n\n    return partial(_logit_bias_processor, clamped_logit_bias)\n\n\n@router.post('/v1/chat/completions', dependencies=[Depends(validate_json_request)])\nasync def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None):\n    \"\"\"Chat completion API similar to OpenAI's API.\n\n    Refer to https://platform.openai.com/docs/api-reference/chat/create\n    for the API specification.\n\n    The request should be a JSON object with the following fields:\n\n    - **model**: model name. Available from /v1/models.\n    - **messages**: string prompt or chat history in OpenAI format. 
Chat history example:\n      ``[{\"role\": \"user\", \"content\": \"hi\"}]``.\n    - **temperature** (float): to modulate the next token probability\n    - **top_p** (float): If set to float < 1, only the smallest set of most\n      probable tokens with probabilities that add up to top_p or higher\n      are kept for generation.\n    - **n** (int): How many chat completion choices to generate for each input\n      message. **Only one is supported here**.\n    - **stream**: whether to stream the results or not. Defaults to false.\n    - **stream_options**: Options for streaming response. Only set this when you\n      set stream: true.\n    - **max_completion_tokens** (int | None): maximum number of output tokens. Defaults to None.\n    - **max_tokens** (int | None): maximum number of output tokens. Defaults to None.\n      Deprecated: Use max_completion_tokens instead.\n    - **repetition_penalty** (float): The parameter for repetition penalty.\n      1.0 means no penalty.\n    - **stop** (str | List[str] | None): To stop generating further\n      tokens. Only stop words that are encoded as a single token id are accepted.\n    - **response_format** (dict | None): To generate the response according to the given\n      schema. Examples:\n\n      .. code-block:: json\n\n        {\n          \"type\": \"json_schema\",\n          \"json_schema\":{\n            \"name\": \"test\",\n            \"schema\":{\n              \"properties\":{\n                \"name\":{\"type\":\"string\"}\n              },\n              \"required\":[\"name\"],\n              \"type\":\"object\"\n            }\n          }\n        }\n\n      or ``{\"type\": \"regex_schema\", \"regex_schema\": \"call me [A-Za-z]{1,10}\"}``\n    - **logit_bias** (dict): Bias added to the logits of the given token ids.\n      Only supported by the PyTorch engine.\n    - **tools** (list): A list of tools the model may call. Currently, only\n      internlm2 functions are supported as a tool. 
Use this to specify a\n      list of functions for which the model can generate JSON inputs.\n    - **tool_choice** (str | object): Controls which (if any) tool is called by\n      the model. `none` means the model will not call any tool and instead\n      generates a message. Specifying a particular tool via\n      ``{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}``\n      forces the model to call that tool. `auto` or `required` passes\n      information about all the tools to the model.\n\n    Additional arguments supported by LMDeploy:\n\n    - **top_k** (int): The number of the highest probability vocabulary\n      tokens to keep for top-k filtering.\n    - **ignore_eos** (bool): whether to ignore the EOS token.\n    - **skip_special_tokens** (bool): Whether or not to remove special tokens\n      in the decoding. Defaults to True.\n    - **spaces_between_special_tokens** (bool): Whether or not to add spaces\n      around special tokens. Fast tokenizers default this to False; slow\n      tokenizers default it to True.\n    - **min_new_tokens** (int): Minimum number of tokens to generate.\n    - **min_p** (float): Minimum token probability, which will be scaled by the\n      probability of the most likely token. It must be a value between\n      0 and 1. 
Typical values are in the 0.01-0.2 range, comparable in\n      selectivity to setting `top_p` in the 0.99-0.8 range (use the\n      opposite of normal `top_p` values).\n\n    Currently we do not support the following features:\n\n    - **presence_penalty** (replaced with repetition_penalty)\n    - **frequency_penalty** (replaced with repetition_penalty)\n    \"\"\"\n    error_check_ret = check_request(request)\n    if error_check_ret is not None:\n        return error_check_ret\n    session = VariableInterface.get_session(request.session_id)\n\n    json_request = await raw_request.json()\n    migration_request = json_request.pop('migration_request', None)\n    with_cache = json_request.pop('with_cache', False)\n    preserve_cache = json_request.pop('preserve_cache', False)\n    if migration_request:\n        migration_request = MigrationRequest.model_validate(migration_request)\n\n    model_name = request.model\n    adapter_name = None\n    if model_name != VariableInterface.async_engine.model_name:\n        adapter_name = model_name  # got an adapter name\n    request_id = str(session.session_id)\n    created_time = int(time.time())\n    gpt_oss_parser = None\n    if VariableInterface.async_engine.arch == 'GptOssForCausalLM':\n        gpt_oss_parser = GptOssChatParser()\n\n    if isinstance(request.stop, str):\n        request.stop = [request.stop]\n\n    gen_logprobs, logits_processors = None, None\n    if request.logprobs and request.top_logprobs:\n        gen_logprobs = request.top_logprobs\n    response_format = None\n    if request.response_format and request.response_format.type != 'text':\n        response_format = request.response_format.model_dump()\n\n    if request.logit_bias is not None:\n        try:\n            logits_processors = [\n                logit_bias_logits_processor(request.logit_bias, VariableInterface.async_engine.tokenizer.model)\n            ]\n        except Exception as e:\n            return create_error_response(HTTPStatus.BAD_REQUEST, 
str(e))\n\n    random_seed = request.seed if request.seed else None\n    max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens)\n\n    gen_config = GenerationConfig(\n        max_new_tokens=max_new_tokens,\n        do_sample=True,\n        logprobs=gen_logprobs,\n        top_k=request.top_k,\n        top_p=request.top_p,\n        temperature=request.temperature,\n        repetition_penalty=request.repetition_penalty,\n        ignore_eos=request.ignore_eos,\n        stop_words=request.stop,\n        include_stop_str_in_output=request.include_stop_str_in_output,\n        skip_special_tokens=request.skip_special_tokens,\n        response_format=response_format,\n        logits_processors=logits_processors,\n        min_new_tokens=request.min_new_tokens,\n        min_p=request.min_p,\n        random_seed=random_seed,\n        spaces_between_special_tokens=request.spaces_between_special_tokens,\n        migration_request=migration_request,\n        with_cache=with_cache,\n        preserve_cache=preserve_cache,\n    )\n\n    tools = None\n    if request.tools and request.tool_choice != 'none':\n        gen_config.skip_special_tokens = False\n        # internlm2 only uses contents inside function regardless of 'type'\n        if not isinstance(request.tool_choice, str):\n            if gpt_oss_parser:\n                tools = [\n                    item.model_dump() for item in request.tools\n                    if item.function.name == request.tool_choice.function.name\n                ]\n            else:\n                tools = [\n                    item.function.model_dump() for item in request.tools\n                    if item.function.name == request.tool_choice.function.name\n                ]\n        else:\n            if gpt_oss_parser:\n                tools = [item.model_dump() for item in request.tools]\n            else:\n                tools = [item.function.model_dump() for item in request.tools]\n    
# text completion for string input\n    do_preprocess = False if isinstance(request.messages, str) else request.do_preprocess\n    chat_template_kwargs = request.chat_template_kwargs or {}\n    if request.enable_thinking is not None:\n        logger.warning('`enable_thinking` will be deprecated in the future, '\n                       'please use `chat_template_kwargs` instead.')\n        if chat_template_kwargs.get('enable_thinking') is None:\n            chat_template_kwargs['enable_thinking'] = request.enable_thinking\n        else:\n            logger.warning('`enable_thinking` in `chat_template_kwargs` will override the value in request.')\n    enable_thinking = chat_template_kwargs.get('enable_thinking', None)\n    result_generator = VariableInterface.async_engine.generate(\n        request.messages,\n        session,\n        gen_config=gen_config,\n        tools=tools,\n        reasoning_effort=request.reasoning_effort,\n        stream_response=True,  # always use stream to enable batching\n        sequence_start=True,\n        sequence_end=True,\n        do_preprocess=do_preprocess,\n        adapter_name=adapter_name,\n        chat_template_kwargs=chat_template_kwargs or None,\n        media_io_kwargs=request.media_io_kwargs,\n        mm_processor_kwargs=request.mm_processor_kwargs)\n\n    def create_stream_response_json(index: int,\n                                    delta_message: DeltaMessage,\n                                    finish_reason: str | None = None,\n                                    logprobs: LogProbs | None = None,\n                                    usage: UsageInfo | None = None) -> str:\n        choice_data = ChatCompletionResponseStreamChoice(index=index,\n                                                         delta=delta_message,\n                                                         finish_reason=finish_reason,\n                                                         logprobs=logprobs)\n        response = 
ChatCompletionStreamResponse(\n            id=request_id,\n            created=created_time,\n            model=model_name,\n            choices=[choice_data],\n            usage=usage,\n        )\n        response_json = response.model_dump_json()\n\n        return response_json\n\n    async def completion_stream_generator() -> AsyncGenerator[str, None]:\n        previous_text = ''\n        current_text = ''\n        previous_token_ids = []\n        current_token_ids = []\n        delta_token_ids = []\n        has_parser = VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None\n        streaming_tools = False\n        async for res in result_generator:\n            logprobs, usage = None, None\n            if gen_logprobs and res.logprobs:\n                logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, res.token_ids,\n                                                            res.logprobs)\n            # Only stream chunk `usage` in the final chunk according to OpenAI API spec\n            if (res.finish_reason and request.stream_options and request.stream_options.include_usage):\n                total_tokens = sum([res.input_token_len, res.generate_token_len])\n                usage = UsageInfo(\n                    prompt_tokens=res.input_token_len,\n                    completion_tokens=res.generate_token_len,\n                    total_tokens=total_tokens,\n                )\n\n            delta_token_ids = res.token_ids if res.token_ids is not None else []\n            if gpt_oss_parser:\n                delta_message = gpt_oss_parser.parse_streaming(res.token_ids)\n                if res.finish_reason == 'stop' and len(delta_message.tool_calls) > 0:\n                    res.finish_reason = 'tool_calls'\n            else:\n                delta_message = DeltaMessage(role='assistant', content=res.response)\n                if has_parser:\n                    current_text = 
current_text + res.response\n                    current_token_ids = current_token_ids + delta_token_ids\n                if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:\n                    if res.finish_reason == 'stop' and streaming_tools is True:\n                        res.finish_reason = 'tool_calls'\n                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(\n                        previous_text=previous_text,\n                        current_text=current_text,\n                        delta_text=delta_message.content,\n                        previous_token_ids=previous_token_ids,\n                        current_token_ids=current_token_ids,\n                        delta_token_ids=delta_token_ids,\n                        request=request)\n                    if tool_delta is not None:\n                        delta_message.tool_calls = tool_delta.tool_calls\n                        delta_message.content = tool_delta.content\n                        if isinstance(tool_delta.tool_calls, list) and len(tool_delta.tool_calls):\n                            streaming_tools = True\n                elif (request.tool_choice != 'none' and request.tools is not None\n                      and VariableInterface.tool_parser is None):\n                    logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')\n                if VariableInterface.reasoning_parser is not None and enable_thinking is not False:\n                    reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming(\n                        previous_text=previous_text,\n                        current_text=current_text,\n                        delta_text=delta_message.content or '',\n                        previous_token_ids=previous_token_ids,\n                        current_token_ids=current_token_ids,\n                        
delta_token_ids=delta_token_ids)\n                    if reasoning_delta is not None:\n                        delta_message.reasoning_content = reasoning_delta.reasoning_content\n                        delta_message.content = reasoning_delta.content\n                if has_parser:\n                    previous_text = current_text\n                    previous_token_ids = current_token_ids\n            if request.return_token_ids:\n                delta_message.gen_tokens = delta_token_ids\n            response_json = create_stream_response_json(index=0,\n                                                        delta_message=delta_message,\n                                                        finish_reason=res.finish_reason,\n                                                        logprobs=logprobs,\n                                                        usage=usage)\n            if res.cache_block_ids is not None:\n                # response_json is a JSON string here, so decode it before attaching the cache fields\n                response_dict = json.loads(response_json)\n                response_dict['cache_block_ids'] = res.cache_block_ids\n                response_dict['remote_token_ids'] = res.token_ids\n                response_json = json.dumps(response_dict)\n            yield f'data: {response_json}\\n\\n'\n        yield 'data: [DONE]\\n\\n'\n\n    # Streaming response\n    if request.stream:\n        return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')\n\n    # Non-streaming response\n    final_logprobs = []\n    final_token_ids = []\n    final_res = None\n    text = ''\n    cache_block_ids = []\n    remote_token_ids = []\n    async for res in result_generator:\n        if await raw_request.is_disconnected():\n            # Abort the request if the client disconnects.\n            await session.async_abort()\n            return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')\n        final_res = res\n        text += res.response\n        if res.token_ids:\n            final_token_ids.extend(res.token_ids)\n        if res.logprobs:\n            final_logprobs.extend(res.logprobs)\n        
cache_block_ids.append(res.cache_block_ids)\n        remote_token_ids.append(res.token_ids)\n\n    if gpt_oss_parser:\n        message = gpt_oss_parser.parse_full(final_token_ids)\n        if final_res.finish_reason == 'stop' and len(message.tool_calls) > 0:\n            final_res.finish_reason = 'tool_calls'\n    else:\n        tool_calls = None\n        reasoning_content = None\n        if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:\n            try:\n                tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)\n                text, tool_calls = tool_call_info.content, tool_call_info.tool_calls\n                if isinstance(tool_calls, list) and len(tool_calls):\n                    if final_res.finish_reason == 'stop':\n                        final_res.finish_reason = 'tool_calls'\n\n            except Exception as e:\n                logger.error(f'Failed to parse {text}. Exception: {e}.')\n                return create_error_response(HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!')\n        elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None:\n            logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')\n\n        if VariableInterface.reasoning_parser is not None and enable_thinking is not False:\n            reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request)\n\n        message = ChatMessage(role='assistant',\n                              content=text,\n                              tool_calls=tool_calls,\n                              reasoning_content=reasoning_content)\n\n    logprobs = None\n    if gen_logprobs and len(final_logprobs):\n        logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, final_token_ids,\n                                                    
final_logprobs)\n\n    assert final_res is not None\n    choices = []\n    if request.return_token_ids:\n        message.gen_tokens = final_token_ids\n    choice_data = ChatCompletionResponseChoice(\n        index=0,\n        message=message,\n        logprobs=logprobs,\n        finish_reason=final_res.finish_reason,\n    )\n    choices.append(choice_data)\n\n    if with_cache:\n        cache_block_ids = cache_block_ids[0]\n        remote_token_ids = [remote_token_ids[0][-1]]\n\n    total_tokens = sum([final_res.input_token_len, final_res.generate_token_len])\n    usage = UsageInfo(\n        prompt_tokens=final_res.input_token_len,\n        completion_tokens=final_res.generate_token_len,\n        total_tokens=total_tokens,\n    )\n    response = ChatCompletionResponse(\n        id=request_id,\n        created=created_time,\n        model=model_name,\n        choices=choices,\n        usage=usage,\n    ).model_dump()\n\n    if with_cache:\n        response['cache_block_ids'] = cache_block_ids\n        response['remote_token_ids'] = remote_token_ids\n\n    return response\n\n\n@router.post('/v1/completions', dependencies=[Depends(validate_json_request)])\nasync def completions_v1(request: CompletionRequest, raw_request: Request = None):\n    \"\"\"Completion API similar to OpenAI's API.\n\n    Refer to https://platform.openai.com/docs/api-reference/completions/create\n    for the API specification.\n\n    The request should be a JSON object with the following fields:\n\n    - **model** (str): model name. Available from /v1/models.\n    - **prompt** (str | list[str]): the input prompt or a list of prompts.\n    - **suffix** (str): The suffix that comes after a completion of inserted text.\n    - **max_completion_tokens** (int | None): maximum number of output tokens. Defaults to None.\n    - **max_tokens** (int | None): maximum number of output tokens. 
Defaults to 16.\n      Deprecated: Use max_completion_tokens instead.\n    - **temperature** (float): to modulate the next token probability\n    - **top_p** (float): If set to float < 1, only the smallest set of most\n      probable tokens with probabilities that add up to top_p or higher\n      are kept for generation.\n    - **n** (int): How many completion choices to generate for each\n      prompt. **Only one is supported here**.\n    - **stream**: whether to stream the results or not. Defaults to false.\n    - **stream_options**: Options for streaming response. Only set this when you\n      set stream: true.\n    - **repetition_penalty** (float): The parameter for repetition penalty.\n      1.0 means no penalty.\n    - **user** (str): A unique identifier representing your end-user.\n    - **stop** (str | list[str] | None): To stop generating further\n      tokens. Only stop words that are encoded as a single token id are accepted.\n\n    Additional arguments supported by LMDeploy:\n\n    - **ignore_eos** (bool): whether to ignore the EOS token.\n    - **skip_special_tokens** (bool): Whether or not to remove special tokens\n      in the decoding. Defaults to True.\n    - **spaces_between_special_tokens** (bool): Whether or not to add spaces\n      around special tokens. Fast tokenizers default this to False; slow\n      tokenizers default it to True.\n    - **top_k** (int): The number of the highest probability vocabulary\n      tokens to keep for top-k filtering.\n    - **min_p** (float): Minimum token probability, which will be scaled by the\n      probability of the most likely token. It must be a value between\n      0 and 1. 
Typical values are in the 0.01-0.2 range, comparable in\n      selectivity to setting `top_p` in the 0.99-0.8 range (use the\n      opposite of normal `top_p` values).\n\n    Currently we do not support the following features:\n\n    - **presence_penalty** (replaced with repetition_penalty)\n    - **frequency_penalty** (replaced with repetition_penalty)\n    \"\"\"\n    error_check_ret = check_request(request)\n    if error_check_ret is not None:\n        return error_check_ret\n\n    json_request = await raw_request.json()\n    migration_request = json_request.pop('migration_request', None)\n    with_cache = json_request.pop('with_cache', False)\n    preserve_cache = json_request.pop('preserve_cache', False)\n    if migration_request:\n        migration_request = MigrationRequest.model_validate(migration_request)\n\n    model_name = request.model\n    adapter_name = None\n    if model_name != VariableInterface.async_engine.model_name:\n        adapter_name = model_name  # got an adapter name\n    request_id = str(request.session_id)\n    created_time = int(time.time())\n    sessions = []\n    if isinstance(request.prompt, str):\n        request.prompt = [request.prompt]\n        sessions.append(VariableInterface.get_session(request.session_id))\n    elif isinstance(request.prompt, list):\n        for i in range(len(request.prompt)):\n            sessions.append(VariableInterface.get_session(i + 1))\n    if isinstance(request.stop, str):\n        request.stop = [request.stop]\n    random_seed = request.seed if request.seed else None\n    max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens)\n\n    gen_config = GenerationConfig(\n        max_new_tokens=max_new_tokens,\n        do_sample=True,\n        logprobs=request.logprobs,\n        top_k=request.top_k,\n        top_p=request.top_p,\n        temperature=request.temperature,\n        
repetition_penalty=request.repetition_penalty,\n        ignore_eos=request.ignore_eos,\n        stop_words=request.stop,\n        skip_special_tokens=request.skip_special_tokens,\n        min_p=request.min_p,\n        random_seed=random_seed,\n        spaces_between_special_tokens=request.spaces_between_special_tokens,\n        migration_request=migration_request,\n        with_cache=with_cache,\n        preserve_cache=preserve_cache,\n    )\n    generators = []\n    for prompt, session in zip(request.prompt, sessions):\n        result_generator = VariableInterface.async_engine.generate(\n            prompt,\n            session,\n            gen_config=gen_config,\n            stream_response=True,  # always use stream to enable batching\n            sequence_start=True,\n            sequence_end=True,\n            do_preprocess=False,\n            adapter_name=adapter_name)\n        generators.append(result_generator)\n\n    def create_stream_response_json(index: int,\n                                    text: str,\n                                    finish_reason: str | None = None,\n                                    logprobs: LogProbs | None = None,\n                                    gen_tokens: list[int] | None = None,\n                                    usage: UsageInfo | None = None) -> dict:\n        choice_data = CompletionResponseStreamChoice(index=index,\n                                                     text=text,\n                                                     gen_tokens=gen_tokens,\n                                                     finish_reason=finish_reason,\n                                                     logprobs=logprobs)\n        response = CompletionStreamResponse(\n            id=request_id,\n            created=created_time,\n            model=model_name,\n            choices=[choice_data],\n            usage=usage,\n        )\n        response_json = response.model_dump()\n        return response_json\n\n    async def 
completion_stream_generator() -> AsyncGenerator[str, None]:\n        # Stream each prompt's generator in turn; the prompt position is used as the choice index\n        for i, generator in enumerate(generators):\n            offset = 0\n            all_token_ids = []\n            state = DetokenizeState()\n            async for res in generator:\n                logprobs = None\n                usage = None\n                if request.logprobs and res.logprobs:\n                    logprobs, offset, all_token_ids, state = _create_completion_logprobs(  # noqa E501\n                        VariableInterface.async_engine.tokenizer, res.token_ids, res.logprobs,\n                        gen_config.skip_special_tokens, offset, all_token_ids, state,\n                        gen_config.spaces_between_special_tokens)\n                # Only stream chunk `usage` in the final chunk according to OpenAI API spec\n                if (res.finish_reason and request.stream_options and request.stream_options.include_usage):\n                    final_res = res\n                    total_tokens = sum([final_res.input_token_len, final_res.generate_token_len])\n                    usage = UsageInfo(\n                        prompt_tokens=final_res.input_token_len,\n                        completion_tokens=final_res.generate_token_len,\n                        total_tokens=total_tokens,\n                    )\n                gen_tokens = None\n                if request.return_token_ids:\n                    gen_tokens = res.token_ids or []\n                response_json = create_stream_response_json(index=i,\n                                                            text=res.response,\n                                                            gen_tokens=gen_tokens,\n                                                            finish_reason=res.finish_reason,\n                                                            logprobs=logprobs,\n                                                            usage=usage)\n                if res.cache_block_ids is not 
None:\n                    response_json['cache_block_ids'] = res.cache_block_ids\n                    response_json['remote_token_ids'] = res.token_ids\n                yield f'data: {json.dumps(response_json)}\\n\\n'\n        yield 'data: [DONE]\\n\\n'\n\n    # Streaming response\n    if request.stream:\n        return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')\n\n    # Non-streaming response\n    usage = UsageInfo()\n    choices = [None] * len(generators)\n    cache_block_ids = []\n    remote_token_ids = []\n\n    async def _inner_call(i, generator):\n        nonlocal cache_block_ids, remote_token_ids\n        final_logprobs = []\n        final_token_ids = []\n        final_res = None\n        text = ''\n        async for res in generator:\n            if await raw_request.is_disconnected():\n                # Abort the request if the client disconnects.\n                await VariableInterface.async_engine.stop_session(request.session_id)\n                return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')\n            final_res = res\n            text += res.response\n            cache_block_ids.append(res.cache_block_ids)\n            remote_token_ids.append(res.token_ids)\n            if res.token_ids:\n                final_token_ids.extend(res.token_ids)\n            if res.logprobs:\n                final_logprobs.extend(res.logprobs)\n\n        logprobs = None\n        if request.logprobs and len(final_logprobs):\n            logprobs, _, _, _ = _create_completion_logprobs(\n                VariableInterface.async_engine.tokenizer,\n                final_token_ids,\n                final_logprobs,\n                gen_config.skip_special_tokens,\n                spaces_between_special_tokens=gen_config.spaces_between_special_tokens)\n\n        assert final_res is not None\n        choice_data = CompletionResponseChoice(index=i,\n                                               text=text,\n       
                                        finish_reason=final_res.finish_reason,\n                                               logprobs=logprobs,\n                                               gen_tokens=final_token_ids if request.return_token_ids else None)\n        choices[i] = choice_data\n\n        if with_cache:\n            cache_block_ids = cache_block_ids[0]\n            remote_token_ids = [remote_token_ids[0][-1]]\n\n        total_tokens = sum([final_res.input_token_len, final_res.generate_token_len])\n        usage.prompt_tokens += final_res.input_token_len\n        usage.completion_tokens += final_res.generate_token_len\n        usage.total_tokens += total_tokens\n\n    results = await asyncio.gather(*[_inner_call(i, generators[i]) for i in range(len(generators))])\n    # _inner_call only returns a value (an error response) when the client disconnected\n    for result in results:\n        if result is not None:\n            return result\n\n    response = CompletionResponse(\n        id=request_id,\n        created=created_time,\n        model=model_name,\n        choices=choices,\n        usage=usage,\n    ).model_dump()\n\n    if with_cache:\n        response['cache_block_ids'] = cache_block_ids\n        response['remote_token_ids'] = remote_token_ids\n\n    return response\n\n\n@router.post('/generate', dependencies=[Depends(validate_json_request)])\nasync def generate(request: GenerateReqInput, raw_request: Request = None):\n    error_check_ret = check_request(request)\n    if error_check_ret is not None:\n        return error_check_ret\n    session = VariableInterface.get_session(request.session_id)\n\n    prompt = request.prompt\n    input_ids = request.input_ids\n    image_data = request.image_data\n    if image_data is not None:\n        # convert to openai format\n        image_input = []\n        if not isinstance(image_data, list):\n            image_data = [image_data]\n        for img in image_data:\n            if isinstance(img, str):\n                image_input.append(dict(type='image_url', image_url=dict(url=img)))\n            else:\n                image_input.append(dict(type='image_url', image_url=img))\n        
text_input = dict(type='text', text=prompt if prompt else input_ids)\n        prompt = [dict(role='user', content=[text_input] + image_input)]\n        input_ids = None\n\n    gen_config = GenerationConfig(\n        max_new_tokens=request.max_tokens,\n        do_sample=True,\n        logprobs=1 if request.return_logprob else None,\n        top_k=request.top_k,\n        top_p=request.top_p,\n        min_p=request.min_p,\n        temperature=request.temperature,\n        repetition_penalty=request.repetition_penalty,\n        ignore_eos=request.ignore_eos,\n        stop_words=request.stop,\n        stop_token_ids=request.stop_token_ids,\n        skip_special_tokens=request.skip_special_tokens,\n        spaces_between_special_tokens=request.spaces_between_special_tokens,\n        include_stop_str_in_output=request.include_stop_str_in_output,\n        return_routed_experts=request.return_routed_experts,\n        repetition_ngram_size=request.repetition_ngram_size,\n        repetition_ngram_threshold=request.repetition_ngram_threshold,\n    )\n\n    result_generator = VariableInterface.async_engine.generate(\n        messages=prompt,\n        session_id=session,\n        input_ids=input_ids,\n        gen_config=gen_config,\n        stream_response=True,  # always use stream to enable batching\n        sequence_start=True,\n        sequence_end=True,\n        do_preprocess=False,\n        media_io_kwargs=request.media_io_kwargs,\n        mm_processor_kwargs=request.mm_processor_kwargs)\n\n    def create_generate_response_json(res, text, output_ids, logprobs, finish_reason, routed_experts=None):\n        # only output routed experts in the last chunk\n        routed_experts = None if finish_reason is None else routed_experts\n        meta = GenerateReqMetaOutput(finish_reason=dict(type=finish_reason) if finish_reason else None,\n                                     output_token_logprobs=logprobs or None,\n                                     
prompt_tokens=res.input_token_len,\n                                     routed_experts=routed_experts,\n                                     completion_tokens=res.generate_token_len)\n\n        response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta, routed_experts=routed_experts)\n        return response.model_dump_json()\n\n    async def generate_stream_generator():\n        async for res in result_generator:\n            text = res.response or ''\n            output_ids = res.token_ids\n            routed_experts = res.routed_experts\n            logprobs = []\n            if res.logprobs:\n                for tok, tok_logprobs in zip(res.token_ids, res.logprobs):\n                    logprobs.append((tok_logprobs[tok], tok))\n            response_json = create_generate_response_json(res,\n                                                          text,\n                                                          output_ids,\n                                                          logprobs,\n                                                          res.finish_reason,\n                                                          routed_experts=routed_experts)\n            yield f'data: {response_json}\\n\\n'\n        yield 'data: [DONE]\\n\\n'\n\n    if request.stream:\n        return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')\n\n    response = None\n\n    async def _inner_call():\n        text = ''\n        output_ids = []\n        logprobs = []\n        async for res in result_generator:\n            if await raw_request.is_disconnected():\n                # Abort the request if the client disconnects.\n                await session.async_abort()\n                return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')\n            text += res.response or ''\n            output_ids.extend(res.token_ids or [])\n            if res.logprobs:\n                for tok, tok_logprobs in 
zip(res.token_ids, res.logprobs):\n                    logprobs.append((tok_logprobs[tok], tok))\n        nonlocal response\n        meta = GenerateReqMetaOutput(finish_reason=dict(type=res.finish_reason) if res.finish_reason else None,\n                                     output_token_logprobs=logprobs or None,\n                                     prompt_tokens=res.input_token_len,\n                                     routed_experts=res.routed_experts,\n                                     completion_tokens=res.generate_token_len)\n        response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta)\n\n    # `_inner_call` returns an error response if the client disconnects; propagate it to the caller.\n    error_response = await _inner_call()\n    if error_response is not None:\n        return error_response\n    return response\n\n\n@router.post('/v1/embeddings', tags=['unsupported'])\nasync def create_embeddings(request: EmbeddingsRequest, raw_request: Request = None):\n    \"\"\"Creates embeddings for the text.\"\"\"\n    return create_error_response(HTTPStatus.BAD_REQUEST, 'Unsupported by turbomind.')\n\n\n@router.post('/v1/encode', dependencies=[Depends(validate_json_request)])\nasync def encode(request: EncodeRequest, raw_request: Request = None):\n    \"\"\"Encode prompts.\n\n    The request should be a JSON object with the following fields:\n\n    - **input**: the prompt to be encoded. In str or list[str] format.\n    - **do_preprocess**: whether to apply the chat template to the prompt. Default to False.\n    - **add_bos**: True when it is the beginning of a conversation. False when it\n      is not. 
Default to True.\n    \"\"\"\n\n    def encode(prompt: str, do_preprocess: bool, add_bos: bool):\n        if do_preprocess:\n            prompt = VariableInterface.async_engine.chat_template.get_prompt(prompt, sequence_start=add_bos)\n        input_ids = VariableInterface.async_engine.tokenizer.encode(prompt, add_bos=add_bos)\n        return input_ids\n\n    if isinstance(request.input, str):\n        encoded = encode(request.input, request.do_preprocess, request.add_bos)\n        return EncodeResponse(input_ids=encoded, length=len(encoded))\n    else:\n        encoded, length = [], []\n        for prompt in request.input:\n            ids = encode(prompt, request.do_preprocess, request.add_bos)\n            encoded.append(ids)\n            length.append(len(ids))\n        return EncodeResponse(input_ids=encoded, length=length)\n\n\n@router.post('/pooling', dependencies=[Depends(validate_json_request)])\nasync def pooling(request: PoolingRequest, raw_request: Request = None):\n    \"\"\"Pooling prompts for reward model.\n\n    In vLLM documentation, https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#pooling-api_1,\n    the input format of Pooling API is the same as Embeddings API.\n\n    Go to https://platform.openai.com/docs/api-reference/embeddings/create\n    for the Embeddings API specification.\n\n    The request should be a JSON object with the following fields:\n\n    - **model** (str): model name. 
Available from /v1/models.\n    - **input** (list[int] | list[list[int]] | str | list[str]): input text or token ids to be embedded\n    \"\"\"\n\n    async_engine = VariableInterface.async_engine\n\n    request_input = request.input\n    model_name = request.model or async_engine.model_name\n\n    # Normalize all inputs to be a batch (List[List[int]])\n    if isinstance(request_input, str):\n        input_ids = [async_engine.tokenizer.encode(request_input)]\n    elif isinstance(request_input, list):\n        if not request_input:\n            return create_error_response(HTTPStatus.BAD_REQUEST, 'Input list cannot be empty.')\n        if isinstance(request_input[0], str):  # list[str]\n            input_ids = [async_engine.tokenizer.encode(p) for p in request_input]\n        elif isinstance(request_input[0], int):  # list[int]\n            input_ids = [request_input]\n        elif isinstance(request_input[0], list):  # list[list[int]]\n            input_ids = request_input\n        else:\n            return create_error_response(HTTPStatus.BAD_REQUEST, 'Input list contains an invalid type.')\n    else:\n        return create_error_response(HTTPStatus.BAD_REQUEST, 'Input must be a string or a list.')\n\n    batch_scores = await async_engine.async_get_reward_score(input_ids)\n    prompt_tokens = sum(len(ids) for ids in input_ids)\n    usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens)\n\n    data = []\n    for i, score in enumerate(batch_scores):\n        data.append({\n            'index': i,\n            'object': 'pooling',\n            'data': score,\n        })\n\n    response = PoolingResponse(model=model_name, data=data, usage=usage)\n    return response.model_dump()\n\n\n@router.post('/update_weights', dependencies=[Depends(validate_json_request)])\ndef update_params(request: UpdateParamsRequest, raw_request: Request = None):\n    \"\"\"Update weights for the model.\"\"\"\n    
VariableInterface.async_engine.engine.update_params(request)\n    return JSONResponse(content=None)\n\n\n@router.post('/sleep', dependencies=[Depends(validate_json_request)])\nasync def sleep(raw_request: Request = None):\n    level = raw_request.query_params.get('level', '1')\n    VariableInterface.async_engine.sleep(int(level))\n    return Response(status_code=200)\n\n\n@router.post('/wakeup', dependencies=[Depends(validate_json_request)])\nasync def wakeup(raw_request: Request = None):\n    tags = raw_request.query_params.getlist('tags')\n    tags = tags or None\n    VariableInterface.async_engine.wakeup(tags)\n    return Response(status_code=200)\n\n\n@router.get('/is_sleeping')\nasync def is_sleeping():\n    is_sleeping = VariableInterface.async_engine.is_sleeping\n    return JSONResponse(content={'is_sleeping': is_sleeping})\n\n\n\"\"\" PD Disaggregation API Begin \"\"\"\n\n\n@router.get('/distserve/engine_info')\nasync def engine_info():\n    engine_config = VariableInterface.async_engine.backend_config\n\n    response = DistServeEngineConfig(tp_size=engine_config.tp,\n                                     dp_size=engine_config.dp,\n                                     pp_size=None,\n                                     ep_size=engine_config.ep,\n                                     dp_rank=engine_config.dp_rank,\n                                     block_size=engine_config.block_size,\n                                     num_cpu_blocks=engine_config.num_cpu_blocks,\n                                     num_gpu_blocks=engine_config.num_gpu_blocks)\n\n    return response.model_dump_json()\n\n\n@router.post('/distserve/p2p_initialize')\nasync def p2p_initialize(init_request: DistServeInitRequest):\n    return VariableInterface.async_engine.p2p_initialize(init_request)\n\n\n@router.post('/distserve/p2p_connect')\nasync def p2p_connect(conn_request: DistServeConnectionRequest):\n    return 
VariableInterface.async_engine.p2p_connect(conn_request)\n\n\n@router.post('/distserve/p2p_drop_connect')\nasync def p2p_drop_connect(drop_conn_request: DistServeDropConnectionRequest):\n    return VariableInterface.async_engine.p2p_drop_connect(drop_conn_request)\n\n\n@router.post('/distserve/free_cache')\nasync def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONResponse:\n    session_id = cache_free_request.remote_session_id\n    VariableInterface.async_engine.free_cache(session_id)\n    return {'status': 'SUCCESS'}\n\n\n\"\"\" PD Disaggregation API End \"\"\"\n\n\n@router.post('/abort_request')\nasync def abort_request(request: AbortRequest, raw_request: Request = None):\n    \"\"\"Abort an ongoing request.\"\"\"\n    if not VariableInterface.enable_abort_handling:\n        return Response(\n            status_code=501,\n            content='This server does not support abort requests. Enable with --enable-abort-handling flag.')\n\n    if request.abort_all:\n        await VariableInterface.async_engine.stop_all_session()\n    else:\n        session = VariableInterface.get_session(request.session_id)\n        await session.async_abort()\n    return Response(status_code=200)\n\n\n@router.post('/v1/chat/interactive', dependencies=[Depends(validate_json_request)], include_in_schema=False)\nasync def chat_interactive_v1(request, raw_request: Request = None):\n    return create_error_response(\n        HTTPStatus.BAD_REQUEST, 'v1/chat/interactive is deprecated, please launch server with --enable-prefix-cache '\n        'and use /v1/chat/completions instead.')\n\n\ndef handle_torchrun():\n    \"\"\"To disable mmengine logging logic when using torchrun.\"\"\"\n\n    def dummy_get_device_id():\n        return 0\n\n    if int(os.environ.get('LOCAL_RANK', -1)) > 0:\n        from lmdeploy.vl.model.utils import _set_func\n\n        # the replacement can't be recovered\n        _set_func('mmengine.logging.logger._get_device_id', 
dummy_get_device_id)\n\n\n@router.on_event('startup')\nasync def startup_event():\n    async_engine = VariableInterface.async_engine\n    async_engine.start_loop(asyncio.get_running_loop(), use_async_api=True)\n\n    if VariableInterface.proxy_url is None:\n        return\n    elif getattr(async_engine.engine, 'is_dummy', False):\n        logger.info('Dummy node started')\n        return\n    try:\n        import requests\n        engine_config = VariableInterface.async_engine.backend_config\n        engine_role = engine_config.role.value if hasattr(engine_config, 'role') else 1\n        url = f'{VariableInterface.proxy_url}/nodes/add'\n        data = {'url': VariableInterface.api_server_url, 'status': {'models': get_model_list(), 'role': engine_role}}\n        headers = {'accept': 'application/json', 'Content-Type': 'application/json'}\n        response = requests.post(url, headers=headers, json=data)\n\n        if response.status_code != 200:\n            raise HTTPException(status_code=response.status_code, detail=response.text)\n    except Exception as e:\n        logger.error(f'Service registration failed: {e}')\n\n\n@router.on_event('shutdown')\nasync def shutdown_event():\n    async_engine = VariableInterface.async_engine\n    if async_engine is not None:\n        async_engine.close()\n\n\nasync def validation_exception_handler(request: Request, exc: RequestValidationError):\n    \"\"\"Handler for RequestValidationError.\"\"\"\n    return JSONResponse(\n        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,\n        content=jsonable_encoder({\n            'detail': exc.errors(),\n            'body': exc.body\n        }),\n    )\n\n\nclass ConcurrencyLimitMiddleware(BaseHTTPMiddleware):\n\n    def __init__(self, app: FastAPI, max_concurrent_requests: int):\n        super().__init__(app)\n        self.semaphore = asyncio.Semaphore(max_concurrent_requests)\n\n    async def dispatch(self, request: Request, call_next):\n        async with self.semaphore:\n     
       response = await call_next(request)\n            return response\n\n\ndef set_parsers(reasoning_parser: str | None = None, tool_parser: str | None = None):\n    \"\"\"Set tool parser and reasoning parsers.\"\"\"\n    # set reasoning parser\n    if reasoning_parser is not None:\n        if reasoning_parser in ReasoningParserManager.module_dict:\n            tokenizer = VariableInterface.async_engine.tokenizer\n            VariableInterface.reasoning_parser = ReasoningParserManager.get(reasoning_parser)(tokenizer)\n        else:\n            raise ValueError(\n                f'The reasoning parser {reasoning_parser} is not in the parser list: {ReasoningParserManager.module_dict.keys()}'  # noqa\n            )\n    # set tool parsers\n    if tool_parser is not None:\n        if tool_parser in ToolParserManager.module_dict:\n            tokenizer = VariableInterface.async_engine.tokenizer\n            VariableInterface.tool_parser = ToolParserManager.get(tool_parser)(tokenizer)\n        else:\n            raise ValueError(\n                f'The tool parser {tool_parser} is not in the parser list: {ToolParserManager.module_dict.keys()}'  # noqa\n            )\n\n\ndef mount_metrics(app: FastAPI, backend_config: PytorchEngineConfig | TurbomindEngineConfig):\n    if not getattr(backend_config, 'enable_metrics', False):\n        return\n\n    from prometheus_client import REGISTRY, make_asgi_app\n    registry = REGISTRY\n\n    # Add prometheus asgi middleware to route /metrics requests\n    metrics_route = Mount('/metrics', make_asgi_app(registry=registry))\n\n    # Workaround for 307 Redirect for /metrics\n    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')\n    app.routes.append(metrics_route)\n\n\ndef create_lifespan_handler(backend_config: PytorchEngineConfig | TurbomindEngineConfig, async_engine: AsyncEngine):\n    \"\"\"Factory function to create a lifespan handler.\"\"\"\n\n    @asynccontextmanager\n    async def lifespan_handler(app: 
FastAPI):\n        task = None\n        try:\n            if getattr(backend_config, 'enable_metrics', False):\n                metrics_processor.start_metrics_handler(enable_metrics=True)\n                log_interval = 10.\n\n                async def _force_log():\n                    while True:\n                        await asyncio.sleep(log_interval)\n\n                        # periodically update schedule metrics, as they change less frequently than iteration stats\n                        schedule_metrics = async_engine.get_schedule_metrics()\n                        await metrics_processor.update_schedule_stats(schedule_metrics)\n\n                        await async_engine.do_log_stats()\n\n                task = asyncio.create_task(_force_log())\n\n            yield\n        finally:\n            if task:\n                task.cancel()\n            await metrics_processor.stop_metrics_handler()\n\n    return lifespan_handler\n\n\ndef serve(model_path: str,\n          model_name: str | None = None,\n          backend: Literal['turbomind', 'pytorch'] = 'turbomind',\n          backend_config: PytorchEngineConfig | TurbomindEngineConfig | None = None,\n          chat_template_config: ChatTemplateConfig | None = None,\n          server_name: str = '0.0.0.0',\n          server_port: int = 23333,\n          allow_origins: list[str] = ['*'],\n          allow_credentials: bool = True,\n          allow_methods: list[str] = ['*'],\n          allow_headers: list[str] = ['*'],\n          log_level: str = 'ERROR',\n          api_keys: list[str] | str | None = None,\n          ssl: bool = False,\n          proxy_url: str | None = None,\n          max_log_len: int | None = None,\n          disable_fastapi_docs: bool = False,\n          max_concurrent_requests: int | None = None,\n          reasoning_parser: str | None = None,\n          tool_call_parser: str | None = None,\n          allow_terminate_by_client: bool = False,\n          enable_abort_handling: bool = 
False,\n          speculative_config: SpeculativeConfig | None = None,\n          **kwargs):\n    \"\"\"Serve a model through an OpenAI-compatible RESTful API server.\n\n    Args:\n        model_path (str): the path of a model.\n            It could be one of the following options:\n                - i) A local directory path of a turbomind model which is\n                    converted by `lmdeploy convert` command or downloaded from\n                    ii) and iii).\n                - ii) The model_id of a lmdeploy-quantized model hosted\n                    inside a model repo on huggingface.co, such as\n                    \"InternLM/internlm-chat-20b-4bit\",\n                    \"lmdeploy/llama2-chat-70b-4bit\", etc.\n                - iii) The model_id of a model hosted inside a model repo\n                    on huggingface.co, such as \"internlm/internlm-chat-7b\",\n                    \"Qwen/Qwen-7B-Chat\", \"baichuan-inc/Baichuan2-7B-Chat\"\n                    and so on.\n        model_name (str): the name of the served model. It can be accessed\n            by the RESTful API `/v1/models`. If it is not specified,\n            `model_path` will be adopted.\n        backend (str): either `turbomind` or `pytorch` backend. Default to\n            `turbomind` backend.\n        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend\n            config instance. 
Default to None.\n        chat_template_config (ChatTemplateConfig): chat template configuration.\n            Default to None.\n        server_name (str): host ip for serving\n        server_port (int): server port\n        allow_origins (List[str]): a list of allowed origins for CORS\n        allow_credentials (bool): whether to allow credentials for CORS\n        allow_methods (List[str]): a list of allowed HTTP methods for CORS\n        allow_headers (List[str]): a list of allowed HTTP headers for CORS\n        log_level (str): log level, one of [CRITICAL, ERROR, WARNING, INFO,\n            DEBUG]\n        api_keys (List[str] | str | None): Optional list of API keys. Accepts\n            a string as a single api_key. Default to None, which means no\n            API key is required.\n        ssl (bool): Enable SSL. Requires the environment variables\n            'SSL_KEYFILE' and 'SSL_CERTFILE'.\n        proxy_url (str): The proxy url to register the api_server with.\n        max_log_len (int): Max number of prompt characters or prompt tokens\n            being printed in log. Default: Unlimited\n        max_concurrent_requests (int): The maximum number of requests the\n            server processes concurrently. Once this limit is reached,\n            additional requests wait until an in-flight request finishes. 
Default to None.\n        disable_fastapi_docs (bool): Disable the FastAPI docs, ReDoc and\n            OpenAPI schema endpoints. Default to False.\n        reasoning_parser (str): The reasoning parser name.\n        tool_call_parser (str): The tool call parser name.\n        allow_terminate_by_client (bool): Allow requests from clients to terminate the server.\n        enable_abort_handling (bool): Enable the `/abort_request` endpoint.\n            Default to False.\n        speculative_config (SpeculativeConfig): speculative decoding\n            configuration. Default to None.\n    \"\"\"\n    if os.getenv('TM_LOG_LEVEL') is None:\n        os.environ['TM_LOG_LEVEL'] = log_level\n    logger.setLevel(log_level)\n\n    VariableInterface.allow_terminate_by_client = allow_terminate_by_client\n    VariableInterface.enable_abort_handling = enable_abort_handling\n\n    ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http'\n    if ssl:\n        ssl_keyfile = os.environ['SSL_KEYFILE']\n        ssl_certfile = os.environ['SSL_CERTFILE']\n        http_or_https = 'https'\n\n    handle_torchrun()\n    _, pipeline_class = get_task(backend, model_path)\n    if isinstance(backend_config, PytorchEngineConfig):\n        backend_config.enable_mp_engine = True\n        # router replay\n        if backend_config.enable_return_routed_experts:\n            backend_config.enable_transfer_obj_ref = True\n    VariableInterface.async_engine = pipeline_class(model_path=model_path,\n                                                    model_name=model_name,\n                                                    backend=backend,\n                                                    backend_config=backend_config,\n                                                    chat_template_config=chat_template_config,\n                                                    max_log_len=max_log_len,\n                                                    speculative_config=speculative_config,\n                                                    **kwargs)\n    # set reasoning parser and tool parser\n    set_parsers(reasoning_parser, tool_call_parser)\n\n    # create FastAPI lifespan events\n    lifespan = create_lifespan_handler(backend_config, VariableInterface.async_engine)\n\n    if disable_fastapi_docs:\n        app = FastAPI(docs_url=None, redoc_url=None, 
openapi_url=None, lifespan=lifespan)\n    else:\n        app = FastAPI(docs_url='/', lifespan=lifespan)\n\n    app.include_router(router)\n    app.add_exception_handler(RequestValidationError, validation_exception_handler)\n    mount_metrics(app, backend_config)\n\n    if allow_origins:\n        app.add_middleware(\n            CORSMiddleware,\n            allow_origins=allow_origins,\n            allow_credentials=allow_credentials,\n            allow_methods=allow_methods,\n            allow_headers=allow_headers,\n        )\n\n    if isinstance(api_keys, str):\n        # a single key passed as a string; avoid iterating over its characters\n        api_keys = [api_keys]\n    if api_keys is not None and (tokens := [key for key in api_keys if key]):\n        from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware\n\n        app.add_middleware(AuthenticationMiddleware, tokens=tokens)\n\n    # set the maximum number of concurrent requests\n    if max_concurrent_requests is not None:\n        app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests)\n\n    if proxy_url is not None:\n        VariableInterface.proxy_url = proxy_url\n        VariableInterface.api_server_url = f'{http_or_https}://{server_name}:{server_port}'  # noqa\n    for i in range(3):\n        print(f'HINT:    Please open \\033[93m\\033[1m{http_or_https}://'\n              f'{server_name}:{server_port}\\033[0m in a browser for detailed api'\n              ' usage!!!')\n    uvicorn.run(app=app,\n                host=server_name,\n                port=server_port,\n                log_level=os.getenv('UVICORN_LOG_LEVEL', 'info').lower(),\n                ssl_keyfile=ssl_keyfile,\n                ssl_certfile=ssl_certfile,\n                timeout_keep_alive=int(os.environ.get('UVICORN_TIMEOUT_KEEP_ALIVE', 5)))\n\n\nif __name__ == '__main__':\n    import fire\n\n    fire.Fire(serve)\n"
  },
  {
    "path": "lmdeploy/serve/openai/harmony_utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Modified from https://github.com/vllm-project/vllm/blob/v0.10.2rc1/vllm/entrypoints/harmony_utils.py\nfrom typing import List\n\nimport shortuuid\nfrom openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding\n\nfrom lmdeploy.serve.openai.protocol import (ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall,\n                                            ToolCall)\n\n_harmony_encoding = None\n\n\ndef get_encoding():\n    global _harmony_encoding\n    if _harmony_encoding is None:\n        _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n    return _harmony_encoding\n\n\ndef get_streamable_parser_for_assistant() -> 'StreamableParser':\n    return StreamableParser(get_encoding(), role=Role.ASSISTANT)\n\n\nclass GptOssChatParser:\n\n    def __init__(self):\n        self.parser = get_streamable_parser_for_assistant()\n\n    def parse_streaming(self, tokens: List[int]) -> DeltaMessage:\n        parser = self.parser\n        delta_message = DeltaMessage(role='assistant')\n        content = ''\n        reasoning_content = ''\n        tool_calls = []\n        delta_tool_call = None\n        for token in tokens:\n            prev_recipient = parser.current_recipient\n            parser.process(token)\n            cur_channel = parser.current_channel\n            cur_recipient = parser.current_recipient\n            delta_text = parser.last_content_delta or ''\n            if cur_channel == 'final':\n                content += delta_text\n            elif cur_channel == 'analysis':\n                reasoning_content += delta_text\n            elif cur_channel == 'commentary' and cur_recipient and cur_recipient.startswith('functions.'):\n                base_index = 0\n                for msg in parser.messages:\n                    if msg.channel == 'commentary' and msg.recipient and msg.recipient.startswith('functions.'):\n     
                   base_index += 1\n                if prev_recipient != cur_recipient:\n                    if delta_tool_call is not None:\n                        tool_calls.append(delta_tool_call)\n                    tool_name = cur_recipient.split('functions.', 1)[1]\n                    delta_tool_call = DeltaToolCall(id=f'chatcmpl-tool-{shortuuid.random()}',\n                                                    type='function',\n                                                    index=base_index,\n                                                    function=DeltaFunctionCall(name=tool_name, arguments=''))\n                elif delta_text:\n                    # Continuing the same tool call. Ensure we don't duplicate the\n                    # very first delta string in this chunk. Previously we initialized\n                    # with arguments=delta_text and then appended again, causing\n                    # duplicated content like \"locationlocation\".\n                    if delta_tool_call is None:\n                        # We are in the middle of a tool call carried over from the\n                        # previous chunk. 
Initialize an empty arguments buffer.\n                        delta_tool_call = DeltaToolCall(index=base_index, function=DeltaFunctionCall(arguments=''))\n                    delta_tool_call.function.arguments += delta_text\n\n        if delta_tool_call:\n            tool_calls.append(delta_tool_call)\n\n        delta_message.content = content if content else None\n        delta_message.reasoning_content = reasoning_content if reasoning_content else None\n        delta_message.tool_calls = tool_calls\n        return delta_message\n\n    def parse_full(self, tokens: List[int]) -> ChatMessage:\n        delta_message = self.parse_streaming(tokens)\n        tool_calls = []\n        for delta_tool_call in delta_message.tool_calls:\n            function = FunctionCall(**delta_tool_call.function.model_dump())\n            tool_calls.append(ToolCall(id=delta_tool_call.id, type=delta_tool_call.type, function=function))\n        chat_message = ChatMessage(role='assistant',\n                                   content=delta_message.content,\n                                   tool_calls=tool_calls,\n                                   reasoning_content=delta_message.reasoning_content)\n        return chat_message\n"
  },
  {
    "path": "lmdeploy/serve/openai/launch_server.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport copy\nimport multiprocessing as mp\nimport os\nimport random\nimport signal\nimport socket\nimport sys\nfrom typing import List, Union\n\nfrom lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig\nfrom lmdeploy.utils import get_logger\n\nfrom .api_server import serve\n\nlogger = get_logger('lmdeploy')\n\n\ndef find_available_ports(num: int) -> List[int]:\n    \"\"\"Find available port.\"\"\"\n\n    def __is_port_ok(port: int):\n        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n            try:\n                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n                s.bind(('127.0.0.1', port))\n                s.listen(1)\n                return True\n            except Exception:\n                return False\n\n    ports = []\n    test_port = 3000\n    while len(ports) < num:\n        test_port += random.randint(10, 500)\n        if __is_port_ok(test_port):\n            ports.append(test_port)\n\n    return ports\n\n\ndef get_host_ip():\n    \"\"\"Get host ip.\"\"\"\n    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:\n        s.connect(('8.8.8.8', 0))\n        ip = s.getsockname()[0]\n        return ip\n\n\ndef _run_server(gpu_ids: List[int], model_path: str, **kwargs):\n    \"\"\"Launch a server process.\"\"\"\n    cuda_visible_devices = ','.join([str(_) for _ in gpu_ids])\n    os.setpgrp()\n    if len(gpu_ids) > 0:\n        os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices\n    serve(model_path, **kwargs)\n\n\ndef cleanup_processes(processes: List[mp.Process]):\n    \"\"\"Clean up server process.\"\"\"\n    for process in processes:\n        logger.info(f'Terminating process group {process.pid}')\n        try:\n            os.killpg(process.pid, signal.SIGTERM)\n        except ProcessLookupError:\n            # Process group may already be terminated\n            pass\n\n    # Wait for processes to terminate\n    for 
process in processes:\n        process.join(timeout=15)\n        if process.is_alive():\n            logger.warning(f'Process {process.pid} did not terminate gracefully, forcing kill')\n            try:\n                os.killpg(process.pid, signal.SIGKILL)\n            except ProcessLookupError:\n                pass\n\n    logger.info('All processes terminated')\n    sys.exit(0)\n\n\ndef launch_server(num_nodes: int,\n                  node_rank: int,\n                  model_path: str,\n                  backend_config: Union[PytorchEngineConfig, TurbomindEngineConfig],\n                  proxy_url: str = None,\n                  **kwargs):\n    \"\"\"Run multiple server processes in dp mode.\"\"\"\n    assert proxy_url is not None, 'Please launch proxy server and pass proxy_url'\n    log_level = kwargs.get('log_level', 'ERROR')\n    logger.setLevel(log_level)\n\n    mp.set_start_method('spawn', force=True)\n    dp = backend_config.dp\n    tp = backend_config.tp\n    ep = backend_config.ep\n    assert dp > 1, f'only support dp > 1, but given dp={dp}'\n    assert tp > 1 or ep > 1, f'only support tp > 1 or ep > 1, but given tp={tp} ep={ep}'\n\n    num_devices = max(dp, tp, ep)\n    dp_per_node = dp // num_nodes\n    tp_per_dp = num_devices // dp\n    http_or_https = 'https' if kwargs.get('ssl', False) else 'http'\n    model_name = kwargs.get('model_name', None)\n    if model_name is None:\n        model_name = model_path\n    server_name = get_host_ip()\n    server_urls = []\n    processes = []\n\n    server_port_li = find_available_ports(dp_per_node)\n\n    for idx in range(dp_per_node):\n        backend_config_dp = copy.deepcopy(backend_config)\n        dp_rank = node_rank * dp_per_node + idx\n        gpu_ids_per_dp = [gid for gid in range(idx * tp_per_dp, (idx + 1) * tp_per_dp)]\n        backend_config_dp.dp_rank = dp_rank\n        server_port = server_port_li[idx]\n\n        cur_server_kwargs = dict()\n        cur_server_kwargs.update(kwargs)\n        
cur_server_kwargs['server_name'] = server_name\n        cur_server_kwargs['server_port'] = server_port\n        cur_server_kwargs['backend_config'] = backend_config_dp\n        cur_server_kwargs['proxy_url'] = proxy_url\n        url = f'{http_or_https}://{server_name}:{server_port}'\n        server_urls.append(url)\n        logger.info(f'create server with url={url}')\n        logger.info(f'backend_config_dp={backend_config_dp} gpus={gpu_ids_per_dp}')\n        proc = mp.Process(target=_run_server, args=(gpu_ids_per_dp, model_path), kwargs=cur_server_kwargs)\n        proc.start()\n        processes.append(proc)\n\n    # bind signal\n    signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(processes))\n    signal.signal(signal.SIGTERM, lambda sig, frame: cleanup_processes(processes))\n    signal.signal(signal.SIGQUIT, lambda sig, frame: cleanup_processes(processes))\n\n    for p in processes:\n        p.join()\n"
  },
  {
    "path": "lmdeploy/serve/openai/protocol.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Modified from\n# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py\nimport time\nfrom typing import Any, Dict, List, Literal, Optional, Union\n\nimport shortuuid\nfrom pydantic import BaseModel, ConfigDict, Field\n\n\nclass ErrorResponse(BaseModel):\n    \"\"\"Error responses.\"\"\"\n    message: str\n    type: str\n    code: int\n    param: Optional[str] = None\n    object: str = 'error'\n\n\nclass ModelPermission(BaseModel):\n    \"\"\"Model permissions.\"\"\"\n    id: str = Field(default_factory=lambda: f'modelperm-{shortuuid.random()}')\n    object: str = 'model_permission'\n    created: int = Field(default_factory=lambda: int(time.time()))\n    allow_create_engine: bool = False\n    allow_sampling: bool = True\n    allow_logprobs: bool = True\n    allow_search_indices: bool = True\n    allow_view: bool = True\n    allow_fine_tuning: bool = False\n    organization: str = '*'\n    group: Optional[str] = None\n    is_blocking: bool = False\n\n\nclass ModelCard(BaseModel):\n    \"\"\"Model cards.\"\"\"\n    id: str\n    object: str = 'model'\n    created: int = Field(default_factory=lambda: int(time.time()))\n    owned_by: str = 'lmdeploy'\n    root: Optional[str] = None\n    parent: Optional[str] = None\n    permission: List[ModelPermission] = []\n\n\nclass ModelList(BaseModel):\n    \"\"\"Model list consists of model cards.\"\"\"\n    object: str = 'list'\n    data: List[ModelCard] = []\n\n\nclass UsageInfo(BaseModel):\n    \"\"\"Usage information.\"\"\"\n    prompt_tokens: int = 0\n    total_tokens: int = 0\n    completion_tokens: Optional[int] = 0\n\n\nclass Function(BaseModel):\n    \"\"\"Function descriptions.\"\"\"\n    description: Optional[str] = Field(default=None, examples=[None])\n    name: str\n    parameters: Optional[Dict[str, Any]] = None\n\n\nclass Tool(BaseModel):\n    \"\"\"Function wrapper.\"\"\"\n    type: 
str = Field(default='function', examples=['function'])\n    function: Function\n\n\nclass ToolChoiceFuncName(BaseModel):\n    \"\"\"The name of tool choice function.\"\"\"\n    name: str\n\n\nclass ToolChoice(BaseModel):\n    \"\"\"The tool choice definition.\"\"\"\n    function: ToolChoiceFuncName\n    type: Literal['function'] = Field(default='function', examples=['function'])\n\n\nclass StreamOptions(BaseModel):\n    \"\"\"The stream options.\"\"\"\n    include_usage: Optional[bool] = False\n\n\nclass JsonSchema(BaseModel):\n    name: str\n    # description is not used since it depends on model\n    description: Optional[str] = None\n    # `schema` is a reserved field in Pydantic BaseModel\n    # use alias since pydantic does not support the OpenAI key `schema`\n    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema', examples=[None])\n    # strict is not used\n    strict: Optional[bool] = False\n    model_config = ConfigDict(serialize_by_alias=True)\n\n\nclass ResponseFormat(BaseModel):\n    # regex_schema is extended by lmdeploy to support regex output\n    type: Literal['text', 'json_object', 'json_schema', 'regex_schema']\n    json_schema: Optional[JsonSchema] = None\n    regex_schema: Optional[str] = None\n\n\nclass ChatCompletionRequest(BaseModel):\n    \"\"\"Chat completion request.\"\"\"\n    model: str\n\n    messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]])\n    temperature: Optional[float] = 0.7\n    top_p: Optional[float] = 1.0\n    tools: Optional[List[Tool]] = Field(default=None, examples=[None])\n    tool_choice: Union[ToolChoice, Literal['auto', 'required', 'none']] = Field(default='auto', examples=['none'])\n    logprobs: Optional[bool] = False\n    top_logprobs: Optional[int] = None\n    n: Optional[int] = 1\n    logit_bias: Optional[Dict[str, float]] = Field(default=None, examples=[None])\n    max_completion_tokens: Optional[int] = Field(\n        default=None,\n        
examples=[None],\n        description=('An upper bound for the number of tokens that can be generated for a completion, '\n                     'including visible output tokens and reasoning tokens'),\n    )\n    max_tokens: Optional[int] = Field(\n        default=None,\n        examples=[None],\n        deprecated='max_tokens is deprecated in favor of the max_completion_tokens field',\n    )\n    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])\n\n    stream: Optional[bool] = False\n    stream_options: Optional[StreamOptions] = Field(default=None, examples=[None])\n    presence_penalty: Optional[float] = 0.0\n    frequency_penalty: Optional[float] = 0.0\n    user: Optional[str] = None\n    reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None\n    response_format: Optional[ResponseFormat] = Field(default=None, examples=[None])\n    # additional argument of lmdeploy\n    do_preprocess: Optional[bool] = True\n    repetition_penalty: Optional[float] = 1.0\n    session_id: Optional[int] = -1\n    ignore_eos: Optional[bool] = False\n    skip_special_tokens: Optional[bool] = True\n    spaces_between_special_tokens: Optional[bool] = True\n    top_k: Optional[int] = 40\n    seed: Optional[int] = None\n    min_new_tokens: Optional[int] = Field(default=None, examples=[None])\n    min_p: float = 0.0\n    enable_thinking: Optional[bool] = None  # will be deprecated in the future\n    return_token_ids: Optional[bool] = False\n    include_stop_str_in_output: Optional[bool] = False\n    # kwargs for chat template renderer\n    chat_template_kwargs: dict[str, Any] | None = Field(\n        default=None,\n        description=('Additional keyword args to pass to the template renderer. 
'\n                     'Will be accessible by the chat template.'),\n    )\n    # kwargs for media IO\n    media_io_kwargs: Optional[dict[str, Any]] = Field(\n        default=None,\n        description=('Additional kwargs to pass to the media IO processing, keyed by modality.'),\n    )\n    # kwargs for hf processor\n    mm_processor_kwargs: Optional[dict[str, Any]] = Field(\n        default=None,\n        description=('Additional kwargs to pass to the HF processor'),\n    )\n\n\nclass FunctionCall(BaseModel):\n    \"\"\"Function response.\"\"\"\n    name: str\n    arguments: str\n\n\nclass ToolCall(BaseModel):\n    \"\"\"Tool call response.\"\"\"\n    id: str = Field(default_factory=lambda: f'chatcmpl-{shortuuid.random()}')\n    type: Literal['function'] = 'function'\n    function: FunctionCall\n\n\nclass ExtractedToolCallInformation(BaseModel):\n    # modified from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/openai/protocol.py#L1199\n    # indicate if tools were called\n    tools_called: bool\n    # extracted tool calls\n    tool_calls: List[ToolCall]\n    # content - per OpenAI spec, content AND tool calls can be returned rarely\n    # But some models will do this intentionally\n    content: Optional[str] = None\n\n\nclass ChatMessage(BaseModel):\n    \"\"\"Chat messages.\"\"\"\n    role: str\n    content: Optional[str] = None\n    gen_tokens: Optional[List[int]] = None\n    reasoning_content: Optional[str] = Field(default=None, examples=[None])\n    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])\n\n\nclass LogProbs(BaseModel):\n    text_offset: List[int] = Field(default_factory=list)\n    token_logprobs: List[Optional[float]] = Field(default_factory=list)\n    tokens: List[str] = Field(default_factory=list)\n    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None\n\n\nclass TopLogprob(BaseModel):\n    token: str\n    bytes: Optional[List[int]] = None\n    logprob: float\n\n\nclass 
ChatCompletionTokenLogprob(BaseModel):\n    token: str\n    bytes: Optional[List[int]] = None\n    logprob: float\n    top_logprobs: List[TopLogprob]\n\n\nclass ChoiceLogprobs(BaseModel):\n    content: Optional[List[ChatCompletionTokenLogprob]] = None\n\n\nclass ChatCompletionResponseChoice(BaseModel):\n    \"\"\"Chat completion response choices.\"\"\"\n    index: int\n    message: ChatMessage\n    logprobs: Optional[ChoiceLogprobs] = None\n    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None\n\n\nclass ChatCompletionResponse(BaseModel):\n    \"\"\"Chat completion response.\"\"\"\n    id: str = Field(default_factory=lambda: f'chatcmpl-{shortuuid.random()}')\n    object: str = 'chat.completion'\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[ChatCompletionResponseChoice]\n    usage: UsageInfo\n\n\nclass DeltaFunctionCall(BaseModel):\n    name: Optional[str] = None\n    arguments: Optional[str] = None\n\n\n# a tool call delta where everything is optional\nclass DeltaToolCall(BaseModel):\n    id: str = Field(default_factory=lambda: f'chatcmpl-tool-{shortuuid.random()}')\n    type: Literal['function'] = 'function'\n    index: int\n    function: Optional[DeltaFunctionCall] = None\n\n\nclass DeltaMessage(BaseModel):\n    \"\"\"Delta messages.\"\"\"\n    role: Optional[str] = None\n    content: Optional[str] = None\n    reasoning_content: Optional[str] = None\n    gen_tokens: Optional[List[int]] = None\n    tool_calls: List[DeltaToolCall] = Field(default_factory=list)\n\n\nclass ChatCompletionResponseStreamChoice(BaseModel):\n    \"\"\"Chat completion response stream choice.\"\"\"\n    index: int\n    delta: DeltaMessage\n    logprobs: Optional[ChoiceLogprobs] = None\n    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None\n\n\nclass ChatCompletionStreamResponse(BaseModel):\n    \"\"\"Chat completion stream response.\"\"\"\n    id: str = 
Field(default_factory=lambda: f'chatcmpl-{shortuuid.random()}')\n    object: str = 'chat.completion.chunk'\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[ChatCompletionResponseStreamChoice]\n    usage: Optional[UsageInfo] = None\n\n\nclass CompletionRequest(BaseModel):\n    \"\"\"Completion request.\"\"\"\n    model: str\n    prompt: Union[str, List[Any]]\n    suffix: Optional[str] = None\n    temperature: Optional[float] = 0.7\n    n: Optional[int] = 1\n    logprobs: Optional[int] = None\n    max_completion_tokens: Optional[int] = Field(\n        default=None,\n        examples=[None],\n        description=('An upper bound for the number of tokens that can be generated for a completion, '\n                     'including visible output tokens and reasoning tokens'),\n    )\n    max_tokens: Optional[int] = Field(\n        default=16,\n        examples=[16],\n        deprecated='max_tokens is deprecated in favor of the max_completion_tokens field',\n    )\n    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])\n    stream: Optional[bool] = False\n    stream_options: Optional[StreamOptions] = Field(default=None, examples=[None])\n    top_p: Optional[float] = 1.0\n    echo: Optional[bool] = False\n    presence_penalty: Optional[float] = 0.0\n    frequency_penalty: Optional[float] = 0.0\n    user: Optional[str] = None\n    # additional argument of lmdeploy\n    repetition_penalty: Optional[float] = 1.0\n    session_id: Optional[int] = -1\n    ignore_eos: Optional[bool] = False\n    skip_special_tokens: Optional[bool] = True\n    spaces_between_special_tokens: Optional[bool] = True\n    top_k: Optional[int] = 40  # for opencompass\n    seed: Optional[int] = None\n    min_p: float = 0.0\n    return_token_ids: Optional[bool] = False\n\n\nclass CompletionResponseChoice(BaseModel):\n    \"\"\"Completion response choices.\"\"\"\n    index: int\n    text: str\n    logprobs: Optional[LogProbs] = 
None\n    gen_tokens: Optional[List[int]] = None\n    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None\n\n\nclass CompletionResponse(BaseModel):\n    \"\"\"Completion response.\"\"\"\n    id: str = Field(default_factory=lambda: f'cmpl-{shortuuid.random()}')\n    object: str = 'text_completion'\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[CompletionResponseChoice]\n    usage: UsageInfo\n\n\nclass CompletionResponseStreamChoice(BaseModel):\n    \"\"\"Completion response stream choice.\"\"\"\n    index: int\n    text: str\n    logprobs: Optional[LogProbs] = None\n    gen_tokens: Optional[List[int]] = None\n    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None\n\n\nclass CompletionStreamResponse(BaseModel):\n    \"\"\"Completion stream response.\"\"\"\n    id: str = Field(default_factory=lambda: f'cmpl-{shortuuid.random()}')\n    object: str = 'text_completion'\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str\n    choices: List[CompletionResponseStreamChoice]\n    usage: Optional[UsageInfo] = None\n\n\nclass EmbeddingsRequest(BaseModel):\n    \"\"\"Embedding request.\"\"\"\n    model: str = None\n    input: Union[str, List[str]]\n    user: Optional[str] = None\n\n\nclass EmbeddingsResponse(BaseModel):\n    \"\"\"Embedding response.\"\"\"\n    object: str = 'list'\n    data: List[Dict[str, Any]]\n    model: str\n    usage: UsageInfo\n\n\nclass PoolingRequest(BaseModel):\n    \"\"\"Pooling request.\n\n    Currently we follow vLLM API protocol,\n    https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L1174\n\n    Notice that ideally we should reuse the input format of embedding API\n    https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L1174\n    
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py#L383\n    \"\"\"\n    model: Optional[str] = None\n    input: Union[List[int], List[List[int]], str, List[str]]\n    encoding_format: Literal['float', 'base64'] = 'float'\n    dimensions: Optional[int] = None\n    user: Optional[str] = None\n\n\nclass PoolingResponse(BaseModel):\n    \"\"\"Pooling response.\"\"\"\n    id: str = Field(default_factory=lambda: f'pool-{shortuuid.random()}')\n    object: str = 'list'\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: str = None\n    data: List[Dict[str, Any]]\n    usage: UsageInfo\n\n\nclass EncodeRequest(BaseModel):\n    \"\"\"Encode request.\"\"\"\n    input: Union[str, List[str]]\n    do_preprocess: Optional[bool] = False\n    add_bos: Optional[bool] = True\n\n\nclass EncodeResponse(BaseModel):\n    \"\"\"Encode response.\"\"\"\n    input_ids: Union[List[int], List[List[int]]]\n    length: Union[int, List[int]]\n\n\nclass GenerateResponse(BaseModel):\n    \"\"\"Generate response.\"\"\"\n    text: str\n    tokens: int\n    input_tokens: int\n    history_tokens: int\n    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None\n\n\nclass UpdateParamsRequest(BaseModel):\n    \"\"\"Update weights request.\"\"\"\n    serialized_named_tensors: Union[str, List[str], Dict]\n    load_format: Optional[str] = None  # 'flattened_bucket' or None\n    finished: bool = False\n\n\n# str for url/base64, base64 should be data:image/jpeg;base64, dict should be {'url': url/base64, 'options': ...}\nImageDataInputItem = Union[str, Dict]\nImageDataFormat = Union[ImageDataInputItem, List[ImageDataInputItem]]\n\n\n# /generate input\nclass GenerateReqInput(BaseModel):\n    session_id: Optional[int] = -1\n    prompt: Optional[str] = None\n    input_ids: Optional[List[int]] = None\n    image_data: Optional[ImageDataFormat] = None\n    return_logprob: Optional[bool] = None\n    max_tokens: int 
= 128\n    stop: Optional[Union[str, List[str]]] = None\n    stop_token_ids: Optional[List[int]] = None\n    stream: Optional[bool] = False\n    temperature: float = 1.0\n    repetition_penalty: Optional[float] = 1.0\n    ignore_eos: Optional[bool] = False\n    top_p: float = 1.0\n    top_k: int = 0\n    min_p: float = 0.0\n    skip_special_tokens: Optional[bool] = True\n    spaces_between_special_tokens: Optional[bool] = True\n    include_stop_str_in_output: Optional[bool] = False\n    return_routed_experts: Optional[bool] = False\n    repetition_ngram_size: int = 0\n    repetition_ngram_threshold: int = 0\n    # kwargs for media IO\n    media_io_kwargs: Optional[dict[str, Any]] = Field(\n        default=None,\n        description=('Additional kwargs to pass to the media IO processing, keyed by modality.'),\n    )\n    # kwargs for hf processor\n    mm_processor_kwargs: Optional[dict[str, Any]] = Field(\n        default=None,\n        description=('Additional kwargs to pass to the HF processor'),\n    )\n\n\nclass GenerateReqMetaOutput(BaseModel):\n    prompt_tokens: Optional[int] = None\n    completion_tokens: Optional[int] = None\n    finish_reason: Optional[Dict[str, Any]] = None\n    output_token_logprobs: Optional[List[tuple[float, int]]] = None  # (logprob, token_id)\n    routed_experts: Optional[Union[List[List[List[int]]], str]] = None  # (num_token, num_layer, topk_expert)\n\n\n# /generate output\nclass GenerateReqOutput(BaseModel):\n    text: str\n    output_ids: List[int]\n    meta_info: GenerateReqMetaOutput\n\n\nclass AbortRequest(BaseModel):\n    # Whether to abort all requests\n    abort_all: bool = False\n    # The finished reason data\n    finished_reason: Optional[Dict[str, Any]] = None\n    abort_message: Optional[str] = None\n    # The session ID to abort. If `abort_all` is True, this field is ignored.\n    session_id: Optional[int] = -1\n"
  },
  {
    "path": "lmdeploy/serve/openai/reasoning_parser/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser\nfrom .qwen_qwq_reasoning_parser import QwenQwQReasoningParser\nfrom .reasoning_parser import ReasoningParser, ReasoningParserManager\n\n__all__ = ['ReasoningParser', 'ReasoningParserManager', 'DeepSeekR1ReasoningParser', 'QwenQwQReasoningParser']\n"
  },
  {
    "path": "lmdeploy/serve/openai/reasoning_parser/deepseek_r1_reasoning_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/reasoning_parsers\nimport re\nfrom typing import Optional, Sequence, Tuple, Union\n\nfrom lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage\n\nfrom .reasoning_parser import ReasoningParser, ReasoningParserManager\n\n\n@ReasoningParserManager.register_module(name='deepseek-r1')\nclass DeepSeekR1ReasoningParser(ReasoningParser):\n    \"\"\"Reasoning parser for DeepSeek R1 model.\n\n    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning text. This parser extracts the reasoning\n    content from the model output.\n    \"\"\"\n\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n        self.think_start_token = '<think>'\n        self.think_end_token = '</think>'\n\n        self.reasoning_regex = re.compile(rf'{self.think_start_token}(.*?){self.think_end_token}', re.DOTALL)\n\n        if not self.model_tokenizer:\n            raise ValueError('The model tokenizer must be passed to the ReasoningParser '\n                             'constructor during construction.')\n\n        self.think_start_token_id = self.vocab.get(self.think_start_token)\n        self.think_end_token_id = self.vocab.get(self.think_end_token)\n        if (self.think_start_token_id is None or self.think_end_token_id is None):\n            raise RuntimeError('DeepSeek R1 reasoning parser could not locate think start/end '\n                               'tokens in the tokenizer!')\n\n    def extract_reasoning_content_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        **kwargs,\n    ) -> Union[DeltaMessage, None]:\n        \"\"\"Instance method that should be implemented for 
extracting reasoning\n        from an incomplete response; for use when handling reasoning calls and\n        streaming.\n\n        Has to be an instance method because it requires state - the current tokens/diffs, but also the information\n        about what has previously been parsed and extracted (see constructor)\n        \"\"\"\n        # Skip single special tokens\n        if len(delta_token_ids) == 1:\n            if delta_token_ids[0] == self.think_end_token_id:\n                return DeltaMessage(content='')\n            elif delta_token_ids[0] == self.think_start_token_id:\n                return None\n\n        # Check if <think> is present in previous or delta.\n        # Keep compatibility with models that don't generate <think> tokens.\n        if self.think_start_token_id in previous_token_ids:\n            if self.think_end_token_id in delta_token_ids:\n                # <think> in previous, </think> in delta,\n                # extract reasoning content\n                end_index = delta_text.find(self.think_end_token)\n                reasoning_content = delta_text[:end_index]\n                content = delta_text[end_index + len(self.think_end_token):]\n                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)\n            elif self.think_end_token_id in previous_token_ids:\n                # <think> in previous, </think> in previous,\n                return DeltaMessage(content=delta_text)\n            else:\n                # <think> in previous, no </think> in previous or delta,\n                # reasoning content continues\n                return DeltaMessage(reasoning_content=delta_text)\n        elif self.think_start_token_id in delta_token_ids:\n            if self.think_end_token_id in delta_token_ids:\n                # <think> in delta, </think> in delta, extract reasoning content\n                start_index = delta_text.find(self.think_start_token)\n                end_index = 
delta_text.find(self.think_end_token)\n                reasoning_content = delta_text[start_index + len(self.think_start_token):end_index]\n                content = delta_text[end_index + len(self.think_end_token):]\n                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)\n            else:\n                # <think> in delta, no </think> in delta,\n                # reasoning content continues\n                return DeltaMessage(reasoning_content=delta_text)\n        else:\n            # No <think> in previous or delta, also need to check for </think>.\n            # Because the model may have generated </think> without <think>\n            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f\n            if self.think_end_token_id in delta_token_ids:\n                # </think> in delta with more tokens,\n                # extract reasoning content and content\n                end_index = delta_text.find(self.think_end_token)\n                reasoning_content = delta_text[:end_index]\n                content = delta_text[end_index + len(self.think_end_token):]\n                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)\n            elif self.think_end_token_id in previous_token_ids:\n                # </think> in previous, thinking content ends\n                return DeltaMessage(content=delta_text)\n            else:\n                # no </think> in previous or delta, reasoning content continues\n                return DeltaMessage(reasoning_content=delta_text)\n\n    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,\n                                  **kwargs) -> Tuple[Optional[str], Optional[str]]:\n        \"\"\"Extract reasoning content from a complete model-generated string.\n\n        Used for non-streaming responses where we have the entire model response\n      
  available before sending to the client.\n\n        Args:\n            model_output (str): The model-generated string to extract reasoning content from.\n            request (ChatCompletionRequest): the request object that was used to generate the model_output.\n\n        Returns:\n            reasoning_content (str | None): The reasoning content.\n            final_output (str | None): The content.\n        \"\"\"\n        # DeepSeek R1 doesn't generate <think> now.\n        # Thus we assume the reasoning content is always at the start.\n        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f\n        if self.think_end_token not in model_output:\n            return model_output, None\n        else:\n            # Add a start token if it's missing to keep compatibility.\n            if self.think_start_token not in model_output:\n                model_output = f'{self.think_start_token}{model_output}'\n            # Use a regex to find the reasoning content\n            reasoning_content = self.reasoning_regex.findall(model_output)[0]\n\n            end_index = len(f'{self.think_start_token}{reasoning_content}{self.think_end_token}')\n            final_output = model_output[end_index:]\n\n            if len(final_output) == 0:\n                return reasoning_content, None\n\n            return reasoning_content, final_output\n"
  },
  {
    "path": "lmdeploy/serve/openai/reasoning_parser/qwen_qwq_reasoning_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport re\nfrom typing import Optional, Sequence, Tuple, Union\n\nfrom lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage\n\nfrom .reasoning_parser import ReasoningParser, ReasoningParserManager\n\n\n@ReasoningParserManager.register_module(name=['qwen-qwq', 'intern-s1'])\nclass QwenQwQReasoningParser(ReasoningParser):\n    \"\"\"Reasoning parser for Qwen QwQ model.\n\n    The Qwen QwQ model uses <think>...</think> tokens to denote reasoning text. This parser extracts the reasoning\n    content from the model output.\n    \"\"\"\n\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n        self.think_start_token = '<think>'\n        self.think_end_token = '</think>'\n\n        self.reasoning_regex = re.compile(rf'{self.think_start_token}(.*?){self.think_end_token}', re.DOTALL)\n\n        if not self.model_tokenizer:\n            raise ValueError('The model tokenizer must be passed to the ReasoningParser '\n                             'constructor during construction.')\n\n    def extract_reasoning_content_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        **kwargs,\n    ) -> Union[DeltaMessage, None]:\n        \"\"\"Instance method that should be implemented for extracting reasoning\n        from an incomplete response; for use when handling reasoning calls and\n        streaming.\n\n        Has to be an instance method because  it requires state - the current tokens/diffs, but also the information\n        about what has previously been parsed and extracted (see constructor)\n        \"\"\"\n        # Skip single special tokens\n        if delta_text == self.think_end_token or delta_text == self.think_start_token:\n            return 
DeltaMessage(content='')\n\n        # Check if <think> is present in previous or delta.\n        # Keep compatibility with models that don't generate <think> tokens.\n        if self.think_start_token in previous_text:\n            if self.think_end_token in delta_text:\n                # <think> in previous, </think> in delta,\n                # extract reasoning content\n                end_index = delta_text.find(self.think_end_token)\n                reasoning_content = delta_text[:end_index]\n                content = delta_text[end_index + len(self.think_end_token):]\n                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)\n            elif self.think_end_token in previous_text:\n                # <think> in previous, </think> in previous,\n                return DeltaMessage(content=delta_text)\n            else:\n                # <think> in previous, no </think> in previous or delta,\n                # reasoning content continues\n                return DeltaMessage(reasoning_content=delta_text)\n        elif self.think_start_token in delta_text:\n            if self.think_end_token in delta_text:\n                # <think> in delta, </think> in delta, extract reasoning content\n                start_index = delta_text.find(self.think_start_token)\n                end_index = delta_text.find(self.think_end_token)\n                reasoning_content = delta_text[start_index + len(self.think_start_token):end_index]\n                content = delta_text[end_index + len(self.think_end_token):]\n                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)\n            else:\n                # <think> in delta, no </think> in delta,\n                # reasoning content continues\n                return DeltaMessage(reasoning_content=delta_text)\n        else:\n            # No <think> in previous or delta, also need to check for </think>.\n            # 
Because the model may have generated </think> without <think>\n            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f\n            if self.think_end_token in delta_text:\n                # </think> in delta with more tokens,\n                # extract reasoning content and content\n                end_index = delta_text.find(self.think_end_token)\n                reasoning_content = delta_text[:end_index]\n                content = delta_text[end_index + len(self.think_end_token):]\n                return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)\n            elif self.think_end_token in previous_text:\n                # </think> in previous, thinking content ends\n                return DeltaMessage(content=delta_text)\n            else:\n                # no </think> in previous or delta, reasoning content continues\n                return DeltaMessage(reasoning_content=delta_text)\n\n    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,\n                                  **kwargs) -> Tuple[Optional[str], Optional[str]]:\n        \"\"\"Extract reasoning content from a complete model-generated string.\n\n        Used for non-streaming responses where we have the entire model response\n        available before sending to the client.\n\n        Args:\n            model_output (str): The model-generated string to extract reasoning content from.\n            request (ChatCompletionRequest): the request object that was used to generate the model_output.\n\n        Returns:\n            reasoning_content (str | None): The reasoning content.\n            final_output (str | None): The content.\n        \"\"\"\n        # DeepSeek R1 doesn't generate <think> now.\n        # Thus we assume the reasoning content is always at the start.\n        # Ref 
https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f\n        if self.think_end_token not in model_output:\n            # for qwen3 model, the reasoning content is wrapped by <think> </think> xml tags\n            return None, model_output\n        # Add a start token if it's missing to keep compatibility.\n        if self.think_start_token not in model_output:\n            model_output = f'{self.think_start_token}{model_output}'\n        # Use a regex to find the reasoning content\n        reasoning_content = self.reasoning_regex.findall(model_output)[0]\n\n        end_index = len(f'{self.think_start_token}{reasoning_content}{self.think_end_token}')\n        final_output = model_output[end_index:]\n        if reasoning_content.startswith('\\n'):\n            reasoning_content = reasoning_content[1:]\n        if reasoning_content.endswith('\\n'):\n            reasoning_content = reasoning_content[:-1]\n\n        if len(final_output) == 0:\n            return reasoning_content, None\n\n        return reasoning_content, final_output\n"
  },
  {
    "path": "lmdeploy/serve/openai/reasoning_parser/reasoning_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/reasoning_parsers\nfrom functools import cached_property\nfrom typing import Dict, Optional, Sequence, Tuple, Union\n\nfrom mmengine import Registry\n\nfrom lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage\n\nReasoningParserManager = Registry('reasoning_parser', locations=['lmdeploy.serve.openai.reasoning_parser'])\n\n\nclass ReasoningParser:\n\n    def __init__(self, tokenizer: object):\n        self.model_tokenizer = tokenizer\n\n    @cached_property\n    def vocab(self) -> Dict[str, int]:\n        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab\n        # whereas all tokenizers have .get_vocab()\n        return self.model_tokenizer.get_vocab()\n\n    def extract_reasoning_content_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        **kwargs,\n    ) -> Union[DeltaMessage, None]:\n        \"\"\"Instance method that should be implemented for extracting reasoning\n        from an incomplete response; for use when handling reasoning calls and\n        streaming.\n\n        Has to be an instance method because it requires state - the current tokens/diffs, but also the information\n        about what has previously been parsed and extracted (see constructor)\n        \"\"\"\n        raise NotImplementedError('ReasoningParser.extract_reasoning_content_streaming '\n                                  'has not been implemented!')\n\n    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest,\n                                  **kwargs) -> Tuple[Optional[str], Optional[str]]:\n        \"\"\"Extract reasoning content from a complete model-generated string.\n\n   
     Used for non-streaming responses where we have the entire model response\n        available before sending to the client.\n\n        Args:\n            model_output (str): The model-generated string to extract reasoning content from.\n            request (ChatCompletionRequest): The request object that was used to generate the model_output.\n\n        Returns:\n            reasoning_content (str | None): The reasoning content.\n            final_output (str | None): The content.\n        \"\"\"\n        raise NotImplementedError('ReasoningParser.extract_reasoning_content '\n                                  'has not been implemented!')\n"
  },
  {
    "path": "lmdeploy/serve/openai/serving_chat_completion.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import TYPE_CHECKING\n\nfrom .protocol import ChatCompletionRequest\n\nif TYPE_CHECKING:\n    from .api_server import VariableInterface\n\n\ndef check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str:\n    engine_config = server_context.get_engine_config()\n    session_manager = server_context.get_session_manager()\n    try:\n        # Check logprobs settings\n        logprobs_mode = engine_config.logprobs_mode\n        logprobs = request.logprobs\n        top_logprobs = request.top_logprobs or 0\n        if logprobs_mode is None and (logprobs or top_logprobs > 0):\n            return (f'Logprobs({logprobs})/top_logprobs({top_logprobs}) requested '\n                    'but not enabled logprobs_mode in engine configuration')\n        if logprobs_mode is not None and (top_logprobs < 0 or (not logprobs and top_logprobs > 0)):\n            return (f'Invalid logprobs({logprobs})/top_logprobs({top_logprobs}) requested '\n                    'when logprobs_mode is enabled in engine configuration.')\n    except AttributeError:\n        pass\n\n    if session_manager.has(request.session_id):\n        return f'The session_id {request.session_id!r} is occupied.'\n\n    # check sampling settings\n    if request.n <= 0:\n        return f'The n {request.n!r} must be a positive int.'\n    if not (0 < request.top_p <= 1):\n        return f'The top_p {request.top_p!r} must be in (0, 1].'\n    if request.top_k < 0:\n        return f'The top_k {request.top_k!r} cannot be a negative integer.'\n    if not (0 <= request.temperature <= 2):\n        return f'The temperature {request.temperature!r} must be in [0, 2]'\n\n    return ''\n"
  },
  {
    "path": "lmdeploy/serve/openai/serving_completion.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import TYPE_CHECKING\n\nfrom .protocol import CompletionRequest\n\nif TYPE_CHECKING:\n    from .api_server import VariableInterface\n\n\ndef check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str:\n    engine_config = server_context.get_engine_config()\n    session_manager = server_context.get_session_manager()\n    try:\n        # Check logprobs settings\n        logprobs_mode = engine_config.logprobs_mode\n        logprobs = request.logprobs or 0\n        if logprobs > 0 and logprobs_mode is None:\n            return f'logprobs({logprobs}) requested but not enabled logprobs_mode in engine configuration.'\n        if logprobs_mode is not None and logprobs < 0:\n            return 'logprobs must be non-negative when logprobs_mode is enabled in engine configuration.'\n    except AttributeError:\n        pass\n\n    if session_manager.has(request.session_id):\n        return f'The session_id {request.session_id!r} is occupied.'\n\n    # check sampling settings\n    if request.n <= 0:\n        return f'The n {request.n!r} must be a positive int.'\n    if not (0 < request.top_p <= 1):\n        return f'The top_p {request.top_p!r} must be in (0, 1].'\n    if request.top_k < 0:\n        return f'The top_k {request.top_k!r} cannot be a negative integer.'\n    if not (0 <= request.temperature <= 2):\n        return f'The temperature {request.temperature!r} must be in [0, 2]'\n\n    return ''\n"
  },
  {
    "path": "lmdeploy/serve/openai/serving_generate.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import TYPE_CHECKING\n\nfrom .protocol import GenerateReqInput\n\nif TYPE_CHECKING:\n    from .api_server import VariableInterface\n\n\ndef check_request(request: GenerateReqInput, server_context: 'VariableInterface') -> str:\n    engine_config = server_context.get_engine_config()\n    session_manager = server_context.get_session_manager()\n    try:\n        # Check logprobs settings\n        logprobs_mode = engine_config.logprobs_mode\n        return_logprob = request.return_logprob\n        if logprobs_mode is None and return_logprob:\n            return f'return_logprob({return_logprob}) requested but not enabled logprobs_mode in engine configuration.'\n    except AttributeError:\n        pass\n\n    if (request.prompt is not None) ^ (request.input_ids is None):\n        return 'You must specify exactly one of prompt or input_ids'\n\n    if request.prompt is not None and request.prompt == '':\n        return 'The prompt must not be an empty string'\n\n    if request.input_ids is not None and len(request.input_ids) == 0:\n        return 'The input_ids must not be an empty list'\n\n    if request.max_tokens is not None and request.max_tokens <= 0:\n        return f'The max_tokens {request.max_tokens!r} must be a positive integer.'\n\n    if session_manager.has(request.session_id):\n        return f'The session_id {request.session_id!r} is occupied.'\n\n    # check sampling settings\n    if not (0 < request.top_p <= 1):\n        return f'The top_p {request.top_p!r} must be in (0, 1].'\n    if request.top_k < 0:\n        return f'The top_k {request.top_k!r} cannot be a negative integer.'\n    if not (0 <= request.temperature <= 2):\n        return f'The temperature {request.temperature!r} must be in [0, 2]'\n\n    return ''\n"
  },
  {
    "path": "lmdeploy/serve/openai/tool_parser/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .internlm2_parser import Internlm2ToolParser\nfrom .llama3_parser import Llama3JsonToolParser\nfrom .qwen2d5_parser import Qwen2d5ToolParser\nfrom .qwen3_parser import Qwen3ToolParser\nfrom .qwen3coder_parser import Qwen3CoderToolParser\nfrom .tool_parser import ToolParser, ToolParserManager\n\n__all__ = [\n    'Internlm2ToolParser',\n    'Qwen2d5ToolParser',\n    'Qwen3ToolParser',\n    'Qwen3CoderToolParser',\n    'ToolParser',\n    'ToolParserManager',\n    'Llama3JsonToolParser',\n]\n"
  },
  {
    "path": "lmdeploy/serve/openai/tool_parser/internlm2_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/tool_parsers\nimport json\nfrom typing import Dict, Sequence, Union\n\nimport partial_json_parser\nimport shortuuid\nfrom partial_json_parser.core.options import Allow\n\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,\n                                            ExtractedToolCallInformation, FunctionCall, ToolCall)\nfrom lmdeploy.utils import get_logger\n\nfrom .tool_parser import ToolParser, ToolParserManager\nfrom .utils import extract_intermediate_diff\n\nlogger = get_logger('lmdeploy')\n\n\n@ToolParserManager.register_module(['internlm', 'intern-s1'])\nclass Internlm2ToolParser(ToolParser):\n\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n        self.position = 0\n\n    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:\n        if request.tools and request.tool_choice != 'none':\n            # do not skip special tokens because internlm uses special\n            # tokens to indicate the start and end of the tool call\n            # information.\n            request.skip_special_tokens = False\n        return request\n\n    def get_argments(self, obj):\n        if 'parameters' in obj:\n            return obj.get('parameters')\n        elif 'arguments' in obj:\n            return obj.get('arguments')\n        return None\n\n    def extract_tool_calls_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        request: ChatCompletionRequest,\n    ) -> Union[DeltaMessage, None]:\n        if '<|action_start|>' not in current_text:\n            self.position = len(current_text)\n            
return DeltaMessage(content=delta_text)\n        # if the tool call has already been sent, return an empty delta message\n        # to make sure the finish_reason will be sent correctly.\n        if self.current_tool_id > 0:\n            return DeltaMessage(content='')\n\n        last_pos = self.position\n        if '<|action_start|><|plugin|>\\n' not in current_text[last_pos:]:\n            return None\n\n        new_delta = current_text[last_pos:]\n        text, action = new_delta.split('<|action_start|><|plugin|>\\n')\n\n        if len(text) > 0:\n            self.position = self.position + len(text)\n            return DeltaMessage(content=text)\n\n        action = action.strip()\n        action = action.split('<|action_end|>'.strip())[0]\n\n        # bit mask flags for partial JSON parsing. If the name hasn't been\n        # sent yet, don't allow sending\n        # an incomplete string since OpenAI only ever (as far as I have\n        # seen) allows sending the entire tool/function name at once.\n        flags = Allow.ALL if self.current_tool_name_sent \\\n            else Allow.ALL & ~Allow.STR\n\n        try:\n            parsable_arr = action\n\n            # internlm2 generates each tool call as a single JSON object;\n            # parallel tool calls are not supported\n            try:\n                tool_call_arr: Dict = partial_json_parser.loads(parsable_arr, flags)\n            except partial_json_parser.core.exceptions.MalformedJSON:\n                logger.debug('not enough tokens to parse into JSON yet')\n                return None\n\n            # if the current tool name hasn't been sent, send if available\n            # - otherwise send nothing\n            if not self.current_tool_name_sent:\n                function_name = tool_call_arr.get('name')\n                if function_name:\n                    self.current_tool_id = self.current_tool_id + 1\n                    delta = DeltaMessage(tool_calls=[\n                        
DeltaToolCall(index=self.current_tool_id,\n                                      type='function',\n                                      id=f'chatcmpl-tool-{shortuuid.random()}',\n                                      function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True))\n                    ])\n                    self.current_tool_name_sent = True\n                    self.streamed_args_for_tool.append('')\n                else:\n                    delta = None\n            # now we know we're on the same tool call and we're streaming\n            # arguments\n            else:\n                prev_arguments = self.get_argments(self.prev_tool_call_arr[self.current_tool_id])\n                cur_arguments = self.get_argments(tool_call_arr)\n\n                # not arguments generated\n                if not cur_arguments and not prev_arguments:\n                    delta = None\n                # will never happen\n                elif not cur_arguments and prev_arguments:\n                    logger.error('INVARIANT - impossible to have arguments reset '\n                                 'mid-arguments')\n                    delta = None\n                # first time to get parameters\n                elif cur_arguments and not prev_arguments:\n                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)\n\n                    arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)]\n                    delta = DeltaMessage(tool_calls=[\n                        DeltaToolCall(index=self.current_tool_id,\n                                      function=DeltaFunctionCall(arguments=arguments_delta).model_dump(\n                                          exclude_none=True))\n                    ])\n                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta\n                # both prev and cur parameters, send the increase parameters\n                
elif cur_arguments and prev_arguments:\n                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)\n                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)\n\n                    argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json)\n\n                    delta = DeltaMessage(tool_calls=[\n                        DeltaToolCall(index=self.current_tool_id,\n                                      function=DeltaFunctionCall(arguments=argument_diff).model_dump(exclude_none=True))\n                    ])\n                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff\n\n            # check to see if the name is defined and has been sent. if so,\n            # stream the name - otherwise keep waiting\n            # finish by setting old and returning None as base case\n            tool_call_arr['arguments'] = self.get_argments(tool_call_arr)\n            self.prev_tool_call_arr = [tool_call_arr]\n            return delta\n        except Exception:\n            logger.exception('Error trying to handle streaming tool call.')\n            logger.debug('Skipping chunk as a result of tool streaming extraction '\n                         'error')\n            return None\n\n    def extract_tool_calls(\n        self,\n        model_output: str,\n        request: ChatCompletionRequest,\n    ) -> ExtractedToolCallInformation:\n        text = model_output\n        tools = request.tools\n        if '<|action_start|><|plugin|>' in text:\n            text, action = text.split('<|action_start|><|plugin|>')\n            action = action.split('<|action_end|>'.strip())[0]\n            action = action[action.find('{'):]\n            action_dict = json.loads(action)\n            name, parameters = action_dict['name'], json.dumps(action_dict.get('parameters',\n                                                                               action_dict.get('arguments', {})),\n                            
                                   ensure_ascii=False)\n\n            if not tools or name not in [t.function.name for t in tools]:\n                # unknown or unrequested tool: report that no tool was called\n                return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)\n\n            tool_calls = [ToolCall(function=FunctionCall(name=name, arguments=parameters))]\n            return ExtractedToolCallInformation(tools_called=True,\n                                                tool_calls=tool_calls,\n                                                content=text if len(text) > 0 else None)\n\n        return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)\n"
  },
  {
    "path": "lmdeploy/serve/openai/tool_parser/llama3_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport re\nfrom typing import Dict, List, Sequence, Union\n\nimport partial_json_parser\nimport shortuuid\nfrom partial_json_parser.core.options import Allow\n\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,\n                                            ExtractedToolCallInformation, FunctionCall, ToolCall)\nfrom lmdeploy.utils import get_logger\n\nfrom .tool_parser import ToolParser, ToolParserManager\nfrom .utils import find_common_prefix, is_complete_json, partial_json_loads\n\nlogger = get_logger('lmdeploy')\n\n\n@ToolParserManager.register_module('llama3')\nclass Llama3JsonToolParser(ToolParser):\n    \"\"\"Tool call parser for Llama 3.1 models intended for use with the\n    examples/tool_chat_template_llama.jinja template.\n\n    Used when --tool-call-parser llama3 is set.\n    \"\"\"\n\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n\n        # initialize properties used for state when parsing tool calls in\n        # streaming mode\n        self.prev_tool_call_arr: List[Dict] = []\n        self.current_tool_id: int = -1\n        self.current_tool_name_sent: bool = False\n        self.streamed_args_for_tool: List[str] = []  # map what has been streamed for each tool so far to a list\n        self.bot_token = '<|python_tag|>'\n        self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[0]\n        self.tool_call_regex = re.compile(r'\\[{.*?}\\]', re.DOTALL)\n\n    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:\n        \"\"\"Extract the tool calls from a complete model response.\"\"\"\n        try:\n            # split the `<function=NAME>{...}</function>` output into the\n            # function name and its JSON arguments, then build the ToolCall\n            action, _ = model_output.split('</function>')\n            parameters = 
action[action.find('{'):]\n            name = action.split('<function=')[1].split('>{')[0]\n            call_info_list = [(name, parameters)]\n\n            tool_calls: List[ToolCall] = [\n                ToolCall(type='function', function=FunctionCall(name=name, arguments=arguments))\n                for name, arguments in call_info_list\n            ]\n\n            # text outside the tool call is not returned as content\n            ret = ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content=None)\n            return ret\n\n        except Exception:\n            logger.exception('Error in extracting tool call from response.')\n            # fall back to returning the whole output as regular content\n            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)\n\n    def extract_tool_calls_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        request: ChatCompletionRequest,\n    ) -> Union[DeltaMessage, None]:\n\n        if not (current_text.startswith(self.bot_token) or current_text.startswith('{')):\n            return DeltaMessage(content=delta_text)\n\n        # bit mask flags for partial JSON parsing. 
If the name hasn't been\n        # sent yet, don't allow sending\n        # an incomplete string since OpenAI only ever (as far as I have\n        # seen) allows sending the entire tool/ function name at once.\n        flags = Allow.ALL if self.current_tool_name_sent \\\n            else Allow.ALL & ~Allow.STR\n        try:\n            tool_call_arr = []\n            is_complete = []\n            try:\n                # depending on the prompt format the Llama model may or may not\n                # prefix the output with the <|python_tag|> token\n                start_idx = len(self.bot_token) if current_text.startswith(self.bot_token) else 0\n                while start_idx < len(current_text):\n                    (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)\n                    is_complete.append(is_complete_json(current_text[start_idx:start_idx + end_idx]))\n                    start_idx += end_idx + len('; ')\n                    # depending on the prompt Llama can use\n                    # either arguments or parameters\n                    if 'parameters' in obj:\n                        assert 'arguments' not in obj, \\\n                            'model generated both parameters and arguments'\n                        obj['arguments'] = obj['parameters']\n                    tool_call_arr.append(obj)\n            except partial_json_parser.core.exceptions.MalformedJSON:\n                logger.debug('not enough tokens to parse into JSON yet')\n                return None\n\n            # select as the current tool call the one we're on the state at\n            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \\\n                if len(tool_call_arr) > 0 else {}\n\n            # case -- if no tokens have been streamed for the tool, e.g.\n            #   only the array brackets, stream nothing\n            if len(tool_call_arr) == 0:\n                return None\n\n            # case: we are starting a new tool in 
the array\n            #   -> array has > 0 length AND length has moved past cursor\n            elif (len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1):\n\n                # if we're moving on to a new call, first make sure we\n                # haven't missed anything in the previous one that was\n                # auto-generated due to JSON completions, but wasn't\n                # streamed to the client yet.\n                if self.current_tool_id >= 0:\n                    cur_arguments = current_tool_call.get('arguments')\n                    if cur_arguments:\n                        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)\n                        sent = len(self.streamed_args_for_tool[self.current_tool_id])\n                        argument_diff = cur_args_json[sent:]\n\n                        logger.debug('got arguments diff: %s', argument_diff)\n                        delta = DeltaMessage(tool_calls=[\n                            DeltaToolCall(index=self.current_tool_id,\n                                          function=DeltaFunctionCall(arguments=argument_diff).model_dump(\n                                              exclude_none=True))\n                        ])\n                        self.streamed_args_for_tool[self.current_tool_id] += argument_diff\n                    else:\n                        delta = None\n                else:\n                    delta = None\n                # re-set stuff pertaining to progress in the current tool\n                self.current_tool_id = len(tool_call_arr) - 1\n                self.current_tool_name_sent = False\n                self.streamed_args_for_tool.append('')\n                logger.debug('starting on new tool %d', self.current_tool_id)\n                return delta\n\n            # if the current tool name hasn't been sent, send if available\n            # - otherwise send nothing\n            elif not self.current_tool_name_sent:\n            
    function_name = current_tool_call.get('name')\n                if function_name:\n\n                    delta = DeltaMessage(tool_calls=[\n                        DeltaToolCall(index=self.current_tool_id,\n                                      type='function',\n                                      id=f'chatcmpl-tool-{shortuuid.random()}',\n                                      function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True))\n                    ])\n                    self.current_tool_name_sent = True\n                else:\n                    delta = None\n\n            # now we know we're on the same tool call and we're streaming\n            # arguments\n            else:\n                cur_arguments = current_tool_call.get('arguments')\n                delta = None\n\n                if cur_arguments:\n                    sent = len(self.streamed_args_for_tool[self.current_tool_id])\n                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)\n                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments')\n\n                    argument_diff = None\n                    if is_complete[self.current_tool_id]:\n                        argument_diff = cur_args_json[sent:]\n                    elif prev_arguments:\n                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)\n                        if cur_args_json != prev_args_json:\n\n                            prefix = find_common_prefix(prev_args_json, cur_args_json)\n                            argument_diff = prefix[sent:]\n\n                    if argument_diff is not None:\n                        delta = DeltaMessage(tool_calls=[\n                            DeltaToolCall(index=self.current_tool_id,\n                                          function=DeltaFunctionCall(arguments=argument_diff).model_dump(\n                                              exclude_none=True))\n             
           ])\n                        self.streamed_args_for_tool[self.current_tool_id] += argument_diff\n\n            self.prev_tool_call_arr = tool_call_arr\n            return delta\n\n        except Exception:\n            logger.exception('Error trying to handle streaming tool call.')\n            logger.debug('Skipping chunk as a result of tool streaming extraction '\n                         'error')\n            return None\n"
  },
  {
    "path": "lmdeploy/serve/openai/tool_parser/qwen2d5_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport re\nfrom typing import Dict, Sequence, Union\n\nimport partial_json_parser\nimport shortuuid\nfrom partial_json_parser.core.options import Allow\n\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,\n                                            ExtractedToolCallInformation, FunctionCall, ToolCall)\nfrom lmdeploy.utils import get_logger\n\nfrom .tool_parser import ToolParser, ToolParserManager\nfrom .utils import extract_intermediate_diff\n\nlogger = get_logger('lmdeploy')\n\n\n@ToolParserManager.register_module(['qwen2d5'])\nclass Qwen2d5ToolParser(ToolParser):\n\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n        self.position = 0\n        self.tool_start_token = '<tool_call>'\n        self.tool_end_token = '</tool_call>'\n        self.pattern = r'<tool_call>(.*?)</tool_call>'\n\n    def get_argments(self, obj):\n        if 'parameters' in obj:\n            return obj.get('parameters')\n        elif 'arguments' in obj:\n            return obj.get('arguments')\n        return None\n\n    def extract_tool_calls_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        request: ChatCompletionRequest,\n    ) -> Union[DeltaMessage, None]:\n        if self.tool_start_token not in current_text:\n            self.position = len(current_text)\n            return DeltaMessage(content=delta_text)\n        # if the tool call has already been sent, return an empty delta message\n        # to make sure the finish_reason will be sent correctly.\n        if self.current_tool_id > 0:\n            return DeltaMessage(content='')\n\n        last_pos = self.position\n        if self.tool_start_token not in 
current_text[last_pos:]:\n            return None\n\n        new_delta = current_text[last_pos:]\n        text, action = new_delta.split(self.tool_start_token)\n\n        if len(text) > 0:\n            self.position = self.position + len(text)\n            return DeltaMessage(content=text)\n\n        action = action.strip()\n        action = action.split(self.tool_end_token.strip())[0]\n\n        # bit mask flags for partial JSON parsing. If the name hasn't been\n        # sent yet, don't allow sending\n        # an incomplete string since OpenAI only ever (as far as I have\n        # seen) allows sending the entire tool/function name at once.\n        flags = Allow.ALL if self.current_tool_name_sent \\\n            else Allow.ALL & ~Allow.STR\n\n        try:\n            parsable_arr = action\n\n            # qwen2.5 generates each tool call as a single JSON object inside\n            # <tool_call> tags; parallel tool calls are not supported\n            try:\n                tool_call_arr: Dict = partial_json_parser.loads(parsable_arr, flags)\n            except partial_json_parser.core.exceptions.MalformedJSON:\n                logger.debug('not enough tokens to parse into JSON yet')\n                return None\n\n            # if the current tool name hasn't been sent, send if available\n            # - otherwise send nothing\n            if not self.current_tool_name_sent:\n                function_name = tool_call_arr.get('name')\n                if function_name:\n                    self.current_tool_id = self.current_tool_id + 1\n                    delta = DeltaMessage(tool_calls=[\n                        DeltaToolCall(index=self.current_tool_id,\n                                      type='function',\n                                      id=f'chatcmpl-tool-{shortuuid.random()}',\n                                      function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True))\n                    ])\n                    self.current_tool_name_sent = True\n          
          self.streamed_args_for_tool.append('')\n                else:\n                    delta = None\n            # now we know we're on the same tool call and we're streaming\n            # arguments\n            else:\n                prev_arguments = self.get_argments(self.prev_tool_call_arr[self.current_tool_id])\n                cur_arguments = self.get_argments(tool_call_arr)\n\n                # not arguments generated\n                if not cur_arguments and not prev_arguments:\n                    delta = None\n                # will never happen\n                elif not cur_arguments and prev_arguments:\n                    logger.error('INVARIANT - impossible to have arguments reset '\n                                 'mid-arguments')\n                    delta = None\n                # first time to get parameters\n                elif cur_arguments and not prev_arguments:\n                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)\n\n                    arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)]\n                    delta = DeltaMessage(tool_calls=[\n                        DeltaToolCall(index=self.current_tool_id,\n                                      function=DeltaFunctionCall(arguments=arguments_delta).model_dump(\n                                          exclude_none=True))\n                    ])\n                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta\n                # both prev and cur parameters, send the increase parameters\n                elif cur_arguments and prev_arguments:\n                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)\n                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)\n\n                    argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json)\n\n                    delta = DeltaMessage(tool_calls=[\n                        
DeltaToolCall(index=self.current_tool_id,\n                                      function=DeltaFunctionCall(arguments=argument_diff).model_dump(exclude_none=True))\n                    ])\n                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff\n\n            # check to see if the name is defined and has been sent. if so,\n            # stream the name - otherwise keep waiting\n            # finish by setting old and returning None as base case\n            tool_call_arr['arguments'] = self.get_argments(tool_call_arr)\n            self.prev_tool_call_arr = [tool_call_arr]\n            return delta\n        except Exception:\n            logger.exception('Error trying to handle streaming tool call.')\n            logger.debug('Skipping chunk as a result of tool streaming extraction '\n                         'error')\n            return None\n\n    def extract_tool_calls(\n        self,\n        model_output: str,\n        request: ChatCompletionRequest,\n    ) -> ExtractedToolCallInformation:\n        text = model_output\n        if self.tool_start_token in text:\n\n            # get tool_call in text\n            match_result_list = re.findall(self.pattern, text, re.DOTALL)\n            tool_calls = []\n            for match_result in match_result_list:\n                action = json.loads(match_result)\n                name, arguments = action['name'], json.dumps(action['arguments'], ensure_ascii=False)\n                tool_calls.append(ToolCall(function=FunctionCall(name=name, arguments=arguments)))\n\n            # get text outside of tags\n            if not text.startswith('<tool_call>'):\n                text = text[:text.find('<tool_call>')]\n            elif not text.endswith('</tool_call>'):\n                text = text[text.rfind('</tool_call>') + len('</tool_call>'):]\n            else:\n                text = ''\n            return ExtractedToolCallInformation(tools_called=True,\n                                         
       tool_calls=tool_calls,\n                                                content=text if len(text) > 0 else None)\n\n        return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=text)\n"
  },
  {
    "path": "lmdeploy/serve/openai/tool_parser/qwen3_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport re\nfrom dataclasses import dataclass\nfrom typing import Dict, Optional, Sequence, Union\n\nimport shortuuid\n\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,\n                                            ExtractedToolCallInformation, FunctionCall, ToolCall)\nfrom lmdeploy.utils import get_logger\n\nfrom .tool_parser import ToolParser, ToolParserManager\n\nlogger = get_logger('lmdeploy')\n\n\n@dataclass\nclass ParserState(object):\n    \"\"\"Maintains the state of parsing during tool call extraction.\"\"\"\n    position: int = 0  # Current position in the text being parsed\n    current_index: int = -1  # Index of the current tool call\n    parsing_reasoning: bool = False  # Whether currently parsing reasoning content\n\n    id: str = ''  # ID of the current tool call\n\n    def reset_tool_call(self):\n        \"\"\"Called when `</tool_call>` finish tag occurred.\"\"\"\n        self.id = ''\n\n\n@ToolParserManager.register_module(['qwen', 'qwen3'])\nclass Qwen3ToolParser(ToolParser):\n    \"\"\"Parser for Qwen3 model's tool call format.\n\n    Handles the extraction of tool calls from Qwen3's output format, which uses XML-like tags for tool calls and\n    reasoning.\n    \"\"\"\n\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n        self.tool_start_token = '<tool_call>'\n        self.tool_end_token = '</tool_call>'\n        self.tool_call_pat = re.compile(r'\\n*<tool_call>(.*?)</tool_call>', re.DOTALL)\n\n    def get_argments(self, obj):\n        \"\"\"Extract arguments from tool call object, handling different formats.\n\n        Supports both 'parameters' and 'arguments' keys in the tool call object.\n        \"\"\"\n        if 'parameters' in obj:\n            return obj.get('parameters')\n        elif 'arguments' in obj:\n            return obj.get('arguments')\n        return 
None\n\n    def _split(self, parser_state: ParserState, parsing_content: str):\n        \"\"\"Split content into tuple: (text_content, tool_content, has_tool_end)\n\n        This method parses the model output and separates it into regular text\n        and tool call content.\n        \"\"\"\n        # locate the tool call start tag\n        try:\n            start_idx = parsing_content.index(self.tool_start_token)\n            # move to the beginning of tool_start_token\n            parser_state.position += start_idx\n        except ValueError:\n            parser_state.position += len(parsing_content)\n            return parsing_content, '', False\n        try:\n            end_idx = parsing_content.index(self.tool_end_token)\n        except ValueError:\n            # position holds until tool_end_token is found\n            return parsing_content[:start_idx], '', False\n        # move position to the end of tool_end_token\n        parser_state.position += (end_idx - start_idx) + len(self.tool_end_token)\n        return parsing_content[:start_idx], parsing_content[start_idx + len(self.tool_start_token):end_idx], True\n\n    def _parse_delta_tool_call(self, parser_state: ParserState, tool_content: str) -> Optional[DeltaToolCall]:\n        \"\"\"Parse tool content into a DeltaToolCall object.\n\n        Returns None if the content does not yet parse as valid JSON or carries\n        neither a function name nor arguments.\n        \"\"\"\n        parsable_arr = tool_content.strip()\n        try:\n            tool_call_arr: Dict = json.loads(parsable_arr)\n        except json.JSONDecodeError:\n            logger.debug('cannot parse into JSON yet')\n            return\n\n        fcall = DeltaFunctionCall()\n        func_name = tool_call_arr.get('name')\n        if func_name:\n            fcall.name = func_name\n        args = self.get_argments(tool_call_arr)\n        if args and isinstance(args, dict):\n            fcall.arguments = json.dumps(args, ensure_ascii=False)\n        # Return None if no new information to send\n        if not
 fcall.name and not fcall.arguments:\n            return\n        if not parser_state.id:\n            # A new tool call parsed, allocate a new id & index\n            parser_state.id = f'chatcmpl-tool-{shortuuid.random()}'\n            parser_state.current_index += 1\n        # Create and return the DeltaToolCall object\n        return DeltaToolCall(\n            id=parser_state.id,\n            index=parser_state.current_index,\n            function=fcall.model_dump(exclude_none=True),\n        )\n\n    def extract_tool_calls_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        request: ChatCompletionRequest,\n    ) -> Union[DeltaMessage, None]:\n        \"\"\"Extract tool calls from streaming model output.\n\n        This method processes incremental model output to extract tool calls, reasoning content, and regular text\n        content in a streaming fashion. 
It maintains parser state between calls to handle partial outputs.\n        \"\"\"\n        parser_state = getattr(request, '_tool_parser_state', None)\n        if parser_state is None:\n            parser_state = ParserState()\n            setattr(request, '_tool_parser_state', parser_state)\n\n        # Split the new content into text and tool content\n        split_result = self._split(parser_state, current_text[parser_state.position:])\n        text_content, tool_content, has_tool_end = split_result\n        delta = DeltaMessage()\n\n        # Add each type of content to the delta message if present\n        if text_content:\n            delta.content = text_content\n        if tool_content:\n            # Parse tool content into a DeltaToolCall object\n            delta_tool_call = self._parse_delta_tool_call(parser_state, tool_content)\n            if delta_tool_call is not None:\n                delta.tool_calls = [delta_tool_call]\n            if has_tool_end:\n                parser_state.reset_tool_call()\n        return delta\n\n    def extract_tool_calls(\n        self,\n        model_output: str,\n        request: ChatCompletionRequest,\n    ) -> ExtractedToolCallInformation:\n        \"\"\"Extract tool calls from complete model output.\n\n        This method processes the full model output to extract tool calls, reasoning content, and regular text content.\n        Unlike the streaming version, this processes the entire output at once.\n        \"\"\"\n        text = model_output\n\n        # Extract tool calls (content inside <tool_call> tags)\n        buf = []\n        scan_pos = 0\n        tool_calls = []\n        for idx, match in enumerate(self.tool_call_pat.finditer(text)):\n            buf.append(text[scan_pos:match.start()])  # Add text before the <tool_call> tag\n            scan_pos = match.end()\n            action = json.loads(match.group(1))  # Parse the tool call JSON\n            name, arguments = action['name'], 
json.dumps(action['arguments'], ensure_ascii=False)\n            tool_calls.append(ToolCall(function=FunctionCall(name=name, arguments=arguments)))\n        if scan_pos < len(text):\n            buf.append(text[scan_pos:])  # Add remaining text\n        text = ''.join(buf)  # Reconstruct text without <tool_call> tags\n\n        return ExtractedToolCallInformation(\n            content=text,\n            tool_calls=tool_calls,\n            tools_called=bool(tool_calls),\n        )\n"
  },
  {
    "path": "lmdeploy/serve/openai/tool_parser/qwen3coder_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\n\nimport shortuuid\n\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall,\n                                            ExtractedToolCallInformation, FunctionCall, ToolCall)\nfrom lmdeploy.utils import get_logger\n\nfrom .tool_parser import ToolParser, ToolParserManager\n\nlogger = get_logger('lmdeploy')\n\n\n@dataclass\nclass ParserState(object):\n    \"\"\"Maintains the state of parsing during tool call extraction.\"\"\"\n    position: int = 0  # Current position in the text being parsed\n    current_index: int = -1  # Index of the current tool call\n\n    id: str = ''  # ID of the current tool call\n\n    def reset_tool_call(self):\n        \"\"\"Called when `</tool_call>` finish tag occurred.\"\"\"\n        self.id = ''\n\n\n@ToolParserManager.register_module(['qwen3coder'])\nclass Qwen3CoderToolParser(ToolParser):\n    \"\"\"Parser for Qwen3 Coder model's tool call format.\n\n    Handles the extraction of tool calls from Qwen3 Coder's output format, which uses purely XML tags for function names\n    and parameters, e.g., <tool_call> <function=func_name> <parameter=arg_name>arg_value</parameter> </function>\n    </tool_call>\n    \"\"\"\n\n    def __init__(self, tokenizer: object):\n        super().__init__(tokenizer)\n        self.tool_start_token = '<tool_call>'\n        self.tool_end_token = '</tool_call>'\n        self.func_prefix = '<function='\n        self.func_end_token = '</function>'\n        self.param_prefix = '<parameter='\n        self.param_end_token = '</parameter>'\n\n        self.tool_call_pat = re.compile(r'\\n*<tool_call>(.*?)</tool_call>', re.DOTALL)\n\n    def _split(self, parser_state: ParserState, parsing_content: str) -> Tuple[str, str, bool]:\n        \"\"\"Split content into tuple: 
(text_content, tool_content, has_tool_end)\"\"\"\n        try:\n            start_idx = parsing_content.index(self.tool_start_token)\n            parser_state.position += start_idx\n        except ValueError:\n            parser_state.position += len(parsing_content)\n            return parsing_content, '', False\n\n        try:\n            end_idx = parsing_content.index(self.tool_end_token)\n        except ValueError:\n            return parsing_content[:start_idx], parsing_content[start_idx:], False\n\n        rem = end_idx - start_idx\n        parser_state.position += rem + len(self.tool_end_token)\n        return parsing_content[:start_idx], parsing_content[start_idx:end_idx + len(self.tool_end_token)], True\n\n    def _extract_params(self, content: str) -> Tuple[Optional[str], Dict[str, Any], bool]:\n        \"\"\"Parse XML tool content into components.\"\"\"\n        content = content.replace(self.tool_start_token, '').replace(self.tool_end_token, '').strip()\n\n        func_name = None\n        func_start = content.find(self.func_prefix)\n        if func_start != -1:\n            name_start = func_start + len(self.func_prefix)\n            terminators = [idx for idx in (content.find('>', name_start), content.find('\\n', name_start)) if idx != -1]\n            if terminators:\n                func_name = content[name_start:min(terminators)].strip()\n\n        args_dict = {}\n        search_idx = 0\n        while True:\n            param_start = content.find(self.param_prefix, search_idx)\n            if param_start == -1:\n                break\n\n            name_start = param_start + len(self.param_prefix)\n            terminators = [idx for idx in (content.find('>', name_start), content.find('\\n', name_start)) if idx != -1]\n            if not terminators:\n                break\n\n            name_end = min(terminators)\n            param_name = content[name_start:name_end].strip()\n\n            val_start = name_end + 1\n            val_end = 
content.find(self.param_end_token, val_start)\n            if val_end == -1:\n                break\n\n            param_val_str = content[val_start:val_end].strip()\n\n            if param_val_str.lower() == 'null':\n                val = None\n            elif param_val_str.lower() == 'true':\n                val = True\n            elif param_val_str.lower() == 'false':\n                val = False\n            else:\n                try:\n                    val = json.loads(param_val_str)\n                except json.JSONDecodeError:\n                    val = param_val_str\n            args_dict[param_name] = val\n            search_idx = val_end + len(self.param_end_token)\n\n        is_func_closed = self.func_end_token in content\n        return func_name, args_dict, is_func_closed\n\n    def extract_tool_calls_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        request: ChatCompletionRequest,\n    ) -> Union[DeltaMessage, None]:\n\n        parser_state = getattr(request, '_tool_parser_state', None)\n        if parser_state is None:\n            parser_state = ParserState()\n            setattr(request, '_tool_parser_state', parser_state)\n\n        split_result = self._split(parser_state, current_text[parser_state.position:])\n        text_content, tool_content, has_tool_end = split_result\n\n        delta = DeltaMessage()\n        if text_content:\n            delta.content = text_content\n\n        if tool_content:\n            if not parser_state.id:\n                parser_state.id = f'chatcmpl-tool-{shortuuid.random()}'\n                parser_state.current_index += 1\n                parser_state.has_emitted_name = False\n                parser_state.has_emitted_json_start = False\n                parser_state.json_closed = False\n                
parser_state.emitted_params = set()\n\n            func_name, args_dict, is_func_closed = self._extract_params(tool_content)\n\n            fcall_delta = DeltaFunctionCall()\n            has_updates = False\n\n            if func_name and not getattr(parser_state, 'has_emitted_name', False):\n                fcall_delta.name = func_name\n                parser_state.has_emitted_name = True\n                has_updates = True\n\n            json_fragments = []\n            if not getattr(parser_state, 'has_emitted_json_start', False):\n                if args_dict or is_func_closed:\n                    json_fragments.append('{')\n                    parser_state.has_emitted_json_start = True\n\n            for k, v in args_dict.items():\n                if k not in parser_state.emitted_params:\n                    prefix = ', ' if len(parser_state.emitted_params) > 0 else ''\n                    serialized = json.dumps(v, ensure_ascii=False)\n                    json_fragments.append(f'{prefix}\"{k}\": {serialized}')\n                    parser_state.emitted_params.add(k)\n\n            if is_func_closed and not getattr(parser_state, 'json_closed', False):\n                if getattr(parser_state, 'has_emitted_json_start', False):\n                    json_fragments.append('}')\n                    parser_state.json_closed = True\n\n            joined_fragments = ''.join(json_fragments)\n            if joined_fragments:\n                fcall_delta.arguments = joined_fragments\n                has_updates = True\n\n            if has_updates:\n                parsed_delta = DeltaToolCall(\n                    id=parser_state.id,\n                    index=parser_state.current_index,\n                    function=fcall_delta,\n                )\n                delta.tool_calls = [parsed_delta]\n\n        if has_tool_end:\n            parser_state.reset_tool_call()\n            # Prepare for the next tool call\n            if hasattr(parser_state, 
'has_emitted_name'):\n                delattr(parser_state, 'has_emitted_name')\n                delattr(parser_state, 'has_emitted_json_start')\n                delattr(parser_state, 'json_closed')\n                delattr(parser_state, 'emitted_params')\n\n        return delta\n\n    def extract_tool_calls(\n        self,\n        model_output: str,\n        request: ChatCompletionRequest,\n    ) -> ExtractedToolCallInformation:\n        text = model_output\n        buf = []\n        scan_pos = 0\n        tool_calls = []\n\n        for idx, match in enumerate(self.tool_call_pat.finditer(text)):\n            buf.append(text[scan_pos:match.start()])\n            scan_pos = match.end()\n\n            tool_content = match.group(1)\n            func_name, args_dict, _ = self._extract_params(tool_content)\n\n            if func_name:\n                tool_calls.append(\n                    ToolCall(function=FunctionCall(\n                        name=func_name, arguments=json.dumps(args_dict, ensure_ascii=False) if args_dict else '{}')))\n\n        if scan_pos < len(text):\n            buf.append(text[scan_pos:])\n\n        text = ''.join(buf)\n\n        return ExtractedToolCallInformation(\n            content=text,\n            tool_calls=tool_calls,\n            tools_called=bool(tool_calls),\n        )\n"
  },
  {
    "path": "lmdeploy/serve/openai/tool_parser/tool_parser.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# modified from https://github.com/vllm-project/vllm/tree/v0.7.3/vllm/entrypoints/openai/tool_parsers\nfrom functools import cached_property\nfrom typing import Dict, List, Sequence, Union\n\nfrom mmengine import Registry\n\nfrom lmdeploy.serve.openai.protocol import ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\nToolParserManager = Registry('tool_parser', locations=['lmdeploy.serve.openai.tool_parser'])\n\n\nclass ToolParser:\n    \"\"\"Abstract ToolParser class that should not be used directly.\n\n    Provided properties and methods should be used in derived classes.\n    \"\"\"\n\n    def __init__(self, tokenizer: object):\n        self.prev_tool_call_arr: List[Dict] = []\n        # the index of the tool call that is currently being parsed\n        self.current_tool_id: int = -1\n        self.current_tool_name_sent: bool = False\n        self.streamed_args_for_tool: List[str] = []\n\n        self.model_tokenizer = tokenizer\n\n    @cached_property\n    def vocab(self) -> Dict[str, int]:\n        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab\n        # whereas all tokenizers have .get_vocab()\n        return self.model_tokenizer.get_vocab()\n\n    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:\n        \"\"\"Adjust the request parameters; the default implementation returns the request unchanged.\"\"\"\n        return request\n\n    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:\n        \"\"\"Method that should be implemented for extracting tool calls\n        from a complete model-generated string.\n\n        Used for non-streaming responses where we have the entire model response available before sending to the client.\n        It does not rely on the streaming state kept by the parser.\n        \"\"\"\n        raise
 NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!')\n\n    def extract_tool_calls_streaming(\n        self,\n        previous_text: str,\n        current_text: str,\n        delta_text: str,\n        previous_token_ids: Sequence[int],\n        current_token_ids: Sequence[int],\n        delta_token_ids: Sequence[int],\n        request: ChatCompletionRequest,\n    ) -> Union[DeltaMessage, None]:\n        \"\"\"Instance method that should be implemented for extracting tool calls\n        from an incomplete response; for use when handling tool calls and\n        streaming.\n\n        Has to be an instance method because it requires state - the current tokens/diffs, but also the information\n        about what has previously been parsed and extracted (see constructor).\n        \"\"\"\n        raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been '\n                                  'implemented!')\n"
  },
  {
    "path": "lmdeploy/serve/openai/tool_parser/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Copied from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/openai/tool_parsers/utils.py\n\nimport json\nfrom json import JSONDecodeError, JSONDecoder\nfrom typing import Any, List, Tuple\n\nimport partial_json_parser\nfrom partial_json_parser.core.options import Allow\n\n\ndef find_common_prefix(s1: str, s2: str) -> str:\n    \"\"\"Finds a common prefix that is shared between two strings, if there is\n    one. Order of arguments is NOT important.\n\n    This function is provided as a UTILITY for extracting information from JSON generated by partial_json_parser, to\n    help in ensuring that the right tokens are returned in streaming, so that close-quotes, close-brackets and close-\n    braces are not returned prematurely.\n\n    e.g. find_common_prefix('{\"fruit\": \"ap\"}', '{\"fruit\": \"apple\"}') -> '{\"fruit\": \"ap'\n    \"\"\"\n    prefix = ''\n    min_length = min(len(s1), len(s2))\n    for i in range(0, min_length):\n        if s1[i] == s2[i]:\n            prefix += s1[i]\n        else:\n            break\n    return prefix\n\n\ndef find_common_suffix(s1: str, s2: str) -> str:\n    \"\"\"Finds a common suffix shared between two strings, if there is one. Order\n    of arguments is NOT important. Stops when the suffix ends OR it hits an\n    alphanumeric character.\n\n    e.g. 
find_common_suffix('{\"fruit\": \"ap\"}', '{\"fruit\": \"apple\"}') -> '\"}'\n    \"\"\"\n    suffix = ''\n    min_length = min(len(s1), len(s2))\n    for i in range(1, min_length + 1):\n        if s1[-i] == s2[-i] and not s1[-i].isalnum():\n            suffix = s1[-i] + suffix\n        else:\n            break\n    return suffix\n\n\ndef extract_intermediate_diff(curr: str, old: str) -> str:\n    \"\"\"Given two strings, extract the difference in the middle between two\n    strings that are known to have a common prefix and/or suffix.\n\n    This function is provided as a UTILITY for extracting information from JSON\n    generated by partial_json_parser, to help in ensuring that the right tokens\n    are returned in streaming, so that close-quotes, close-brackets and\n    close-braces are not returned prematurely. The order of arguments IS\n    important - the new version of the partially-parsed JSON must be the first\n    argument, and the second argument must be from the previous generation.\n\n    It returns the tokens that should be streamed to the client.\n\n    e.g.
 extract_intermediate_diff('{\"fruit\": \"apple\"}', '{\"fruit\": \"ap\"}')\n        -> 'ple'\n    \"\"\"\n    suffix = find_common_suffix(curr, old)\n\n    old = old[::-1].replace(suffix[::-1], '', 1)[::-1]\n    prefix = find_common_prefix(curr, old)\n    diff = curr\n    if len(suffix):\n        diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]\n\n    if len(prefix):\n        # replace the prefix only once in case it's mirrored\n        diff = diff.replace(prefix, '', 1)\n\n    return diff\n\n\ndef find_all_indices(string: str, substring: str) -> List[int]:\n    \"\"\"Find all (starting) indices of a substring in a given string.\n\n    Useful for tool call extraction.\n    \"\"\"\n    indices = []\n    index = -1\n    while True:\n        index = string.find(substring, index + 1)\n        if index == -1:\n            break\n        indices.append(index)\n    return indices\n\n\n# partial_json_parser doesn't support extra data and\n# JSONDecoder.raw_decode doesn't support partial JSON\ndef partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:\n    try:\n        return (partial_json_parser.loads(input_str, flags), len(input_str))\n    except JSONDecodeError as e:\n        if 'Extra data' in e.msg:\n            dec = JSONDecoder()\n            return dec.raw_decode(input_str)\n        raise\n\n\ndef is_complete_json(input_str: str) -> bool:\n    try:\n        json.loads(input_str)\n        return True\n    except JSONDecodeError:\n        return False\n\n\ndef consume_space(i: int, s: str) -> int:\n    while i < len(s) and s[i].isspace():\n        i += 1\n    return i\n"
  },
  {
    "path": "lmdeploy/serve/processors/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .multimodal import MultimodalProcessor\n\n__all__ = ['MultimodalProcessor']\n"
  },
  {
    "path": "lmdeploy/serve/processors/multimodal.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nfrom typing import Any, Dict, List, Literal, Tuple\n\nimport PIL\n\nfrom lmdeploy.model import MODELS, BaseChatTemplate\nfrom lmdeploy.tokenizer import Tokenizer\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.constants import Modality\nfrom lmdeploy.vl.media.connection import load_from_url\nfrom lmdeploy.vl.media.image import ImageMediaIO\nfrom lmdeploy.vl.media.time_series import TimeSeriesMediaIO\nfrom lmdeploy.vl.media.video import VideoMediaIO\n\nlogger = get_logger('lmdeploy')\n\n\nclass MultimodalProcessor:\n    \"\"\"Processor for handling prompt preprocessing, message content merging,\n    and multimodal processing.\"\"\"\n\n    def __init__(self,\n                 tokenizer: Tokenizer,\n                 chat_template: BaseChatTemplate,\n                 vl_encoder=None,\n                 backend: str | None = None):\n        \"\"\"Initialize MultimodalProcessor.\n\n        Args:\n            tokenizer: Tokenizer instance for encoding prompts.\n            chat_template: Chat template instance for message processing.\n            vl_encoder: Optional ImageEncoder instance for multimodal processing.\n            backend: Optional backend name ('turbomind' or 'pytorch') for multimodal processing.\n        \"\"\"\n        self.tokenizer = tokenizer\n        self.chat_template = chat_template\n        self.vl_encoder = vl_encoder\n        self.backend = backend\n\n    @staticmethod\n    def merge_message_content(msg: Dict) -> Dict:\n        \"\"\"Merge multimodal content blocks and ensure content field exists.\n\n        This function normalizes message content to match vLLM's behavior:\n        1. Missing content field -> add content='' (empty string)\n        2. None content -> convert to content='' (empty string)\n        3. String content -> return as-is\n        4. 
List content (multimodal) -> merge all text blocks with newline separator\n\n        Args:\n            msg: A message dict with 'role' and optionally 'content' field\n\n        Returns:\n            A message dict with 'content' field guaranteed to exist\n\n        Note:\n            This implementation is based on vLLM's content processing logic.\n            vLLM uses \"\\n\".join() to merge multiple text blocks from multimodal content.\n\n        References:\n            - vLLM content normalization:\n              https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/chat_utils.py\n              See _parse_chat_message_content() and _parse_chat_message_content_parts()\n            - vLLM text merging logic:\n              text_prompt = \"\\n\".join(texts)\n        \"\"\"\n        # If content is missing or None, convert to empty string (matches vLLM behavior)\n        # This prevents Jinja2 template errors when rendering chat templates\n        if 'content' not in msg or msg['content'] is None:\n            result = dict(msg)\n            result['content'] = ''\n            return result\n\n        # If content is already a string, return as-is\n        if isinstance(msg['content'], str):\n            return msg\n\n        # If content is a list, merge all text blocks into a single string\n        # This matches vLLM's behavior: text_prompt = \"\\n\".join(texts)\n        content_parts = []\n        for block in msg['content']:\n            if isinstance(block, dict) and block.get('type') == 'text':\n                content_parts.append(block.get('text', ''))\n        merged_content = '\\n'.join(content_parts)\n\n        # Preserve all other fields in the message (e.g., tool_calls)\n        result = dict(msg)\n        result['content'] = merged_content\n        return result\n\n    @staticmethod\n    def _parse_multimodal_item(i: int, in_messages: List[Dict], out_messages: List[Dict], media_io_kwargs: Dict[str,\n                                        
                                                                        Any]):\n        \"\"\"Synchronous helper to parse a single multimodal message item.\"\"\"\n        role = in_messages[i]['role']\n        content = in_messages[i]['content']\n\n        if role != 'user' or isinstance(content, str):\n            out_messages[i] = in_messages[i]\n            return\n\n        assert isinstance(content, list)\n        out_message = dict(role=role, content=[])\n\n        for item in content:\n            item_type = item.get('type')\n            if item_type == 'text':\n                out_message['content'].append(item)\n                continue\n\n            item_params = item.get(item_type, {})\n            data_src = item_params.pop('url', None) or item_params.pop('data', None)\n\n            if item_type == 'image_data':\n                modality = Modality.IMAGE\n                data = data_src\n            elif item_type == 'image_url':\n                modality = Modality.IMAGE\n                img_io = ImageMediaIO(**media_io_kwargs.get('image', {}))\n                data = load_from_url(data_src, img_io)\n            elif item_type == 'video_url':\n                modality = Modality.VIDEO\n                vid_io = VideoMediaIO(image_io=ImageMediaIO(), **media_io_kwargs.get('video', {}))\n                data, metadata = load_from_url(data_src, vid_io)\n                item_params['video_metadata'] = metadata\n            elif item_type == 'time_series_url':\n                modality = Modality.TIME_SERIES\n                ts_io = TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {}))\n                data = load_from_url(data_src, ts_io)\n            else:\n                raise NotImplementedError(f'unknown type: {item_type}')\n\n            out_message['content'].append({'type': modality, 'data': data, **item_params})\n\n        out_messages[i] = out_message\n\n    @staticmethod\n    async def async_parse_multimodal_item(messages: List[Dict],\n   
                                       media_io_kwargs: Dict[str, Any] | None = None) -> List[Dict]:\n        \"\"\"Convert user-input multimodal data into GPT4V message format.\"\"\"\n        if isinstance(messages, dict):\n            messages = [messages]\n        assert isinstance(messages, list)\n\n        out_messages = [None] * len(messages)\n        media_io_kwargs = media_io_kwargs or {}\n        loop = asyncio.get_event_loop()\n\n        await asyncio.gather(*[\n            loop.run_in_executor(None, MultimodalProcessor._parse_multimodal_item, i, messages, out_messages,\n                                 media_io_kwargs) for i in range(len(messages))\n        ])\n        return out_messages\n\n    async def get_prompt_input(self,\n                               prompt: str | List[Dict],\n                               do_preprocess: bool,\n                               sequence_start: bool,\n                               adapter_name: str,\n                               tools: List[object] | None = None,\n                               reasoning_effort: Literal['low', 'medium', 'high'] | None = None,\n                               chat_template_kwargs: Dict | None = None,\n                               media_io_kwargs: Dict[str, Any] | None = None,\n                               mm_processor_kwargs: Dict[str, Any] | None = None,\n                               **kwargs):\n        \"\"\"Process prompt and return prompt string and input_ids.\n\n        Handles both text-only and multimodal prompts. 
If multimodal input is detected\n        and vl_encoder is available, processes images accordingly.\n\n        Args:\n            prompt: Input prompt as string or list of message dicts.\n            do_preprocess: Whether to apply chat template preprocessing.\n            sequence_start: Indicator for starting a sequence.\n            adapter_name: Adapter name for selecting chat template.\n            tools: Optional list of tools.\n            reasoning_effort: Optional reasoning effort level.\n            chat_template_kwargs: Optional kwargs for chat template.\n            media_io_kwargs: Optional kwargs for media IO operations.\n            mm_processor_kwargs: Optional kwargs for multimodal processor.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            Dict with 'prompt' (str) and 'input_ids' (List[int]) keys for text-only,\n            or dict with multimodal data for multimodal prompts.\n        \"\"\"\n        # Handle string input\n        if isinstance(prompt, str):\n            return await self._get_text_prompt_input(prompt=prompt,\n                                                     do_preprocess=do_preprocess,\n                                                     sequence_start=sequence_start,\n                                                     adapter_name=adapter_name,\n                                                     tools=tools,\n                                                     reasoning_effort=reasoning_effort,\n                                                     chat_template_kwargs=chat_template_kwargs,\n                                                     **kwargs)\n\n        # Handle list input\n        elif isinstance(prompt, list):\n            # Check if multimodal input exists\n            has_multimodal_input = self._has_multimodal_input(prompt)\n\n            # If no multimodal input or no vl_encoder, use text-only processing\n            if not has_multimodal_input or self.vl_encoder is 
None:\n                return await self._get_text_prompt_input(prompt=prompt,\n                                                         do_preprocess=do_preprocess,\n                                                         sequence_start=sequence_start,\n                                                         adapter_name=adapter_name,\n                                                         tools=tools,\n                                                         reasoning_effort=reasoning_effort,\n                                                         chat_template_kwargs=chat_template_kwargs,\n                                                         **kwargs)\n\n            # Process multimodal input\n            return await self._get_multimodal_prompt_input(messages=prompt,\n                                                           do_preprocess=do_preprocess,\n                                                           sequence_start=sequence_start,\n                                                           adapter_name=adapter_name,\n                                                           tools=tools,\n                                                           chat_template_kwargs=chat_template_kwargs,\n                                                           media_io_kwargs=media_io_kwargs,\n                                                           mm_processor_kwargs=mm_processor_kwargs,\n                                                           **kwargs)\n        else:\n            raise RuntimeError(f'unsupported prompt type: {type(prompt)}')\n\n    @staticmethod\n    def format_prompts(prompts: Any) -> List[Dict]:\n        \"\"\"Format prompts.\"\"\"\n        if not isinstance(prompts, list):\n            prompts = [prompts]\n        # str or batch of str\n        if all(isinstance(prompt, str) for prompt in prompts):\n            return prompts\n        if (MultimodalProcessor._is_openai_message(prompts)\n                or 
all(MultimodalProcessor._is_openai_message(prompt) for prompt in prompts)):\n            return prompts\n        if all(MultimodalProcessor._is_str_images_pair(prompt) for prompt in prompts):\n            # batch of (prompt, image or [images]) or (image or [images], prompt) ->\n            # [[openai_gpt4v_message], [openai_gpt4v_message], ...]\n            return [[MultimodalProcessor._re_format_prompt_images_pair(prompt)] for prompt in prompts]\n        raise ValueError(f'Unsupported prompts: {prompts}. Only support str, openai message format, '\n                         'or (prompt, image or [images]) or (image or [images], prompt) pair.')\n\n    @staticmethod\n    def _is_openai_message(message) -> bool:\n        \"\"\"Check if the message conforms to openai message format.\"\"\"\n        return isinstance(message, list) and all(isinstance(msg, dict) for msg in message)\n\n    @staticmethod\n    def _is_str_images_pair(message) -> bool:\n        \"\"\"Check if the message is a (prompt, image or [images]) or (image or\n        [images], prompt) pair.\"\"\"\n        if not (isinstance(message, tuple) and len(message) == 2):\n            return False\n        _1, _2 = message\n        if MultimodalProcessor._is_image(_1) or MultimodalProcessor._is_image_list(_1):\n            _1, _2 = _2, _1\n        return isinstance(_1, str) and (MultimodalProcessor._is_image(_2) or MultimodalProcessor._is_image_list(_2))\n\n    @staticmethod\n    def _is_image(obj) -> bool:\n        # image or image url or base64-encoded image data\n        return (isinstance(obj, PIL.Image.Image)\n                or isinstance(obj, str) and (obj.startswith('http') or obj.startswith('data:image')))\n\n    @staticmethod\n    def _is_image_list(obj) -> bool:\n        return isinstance(obj, list) and all(MultimodalProcessor._is_image(img) for img in obj)\n\n    @staticmethod\n    def _re_format_prompt_images_pair(prompt: Tuple) -> Dict:\n        \"\"\"Reformat the prompt to openai message 
format.\"\"\"\n        from lmdeploy.vl import load_image\n\n        messages = {'role': 'user', 'content': []}\n        prompt, images = prompt\n        prompt_first = True\n        if MultimodalProcessor._is_image(prompt) or MultimodalProcessor._is_image_list(prompt):\n            prompt, images = images, prompt\n            prompt_first = False\n        image_contents = []\n        images = images if isinstance(images, list) else [images]\n        for image in images:\n            # 'image_url': means url or local path to image.\n            # 'image_data': means PIL.Image.Image object.\n            if isinstance(image, str):\n                image = load_image(image)\n                item = {'type': 'image_data', 'image_data': {'data': image}}\n            elif isinstance(image, PIL.Image.Image):\n                item = {'type': 'image_data', 'image_data': {'data': image}}\n            else:\n                raise ValueError('image should be a str(url/path) or PIL.Image.Image')\n            image_contents.append(item)\n\n        if prompt_first:\n            messages['content'].append({'type': 'text', 'text': prompt})\n            messages['content'].extend(image_contents)\n        else:\n            messages['content'].extend(image_contents)\n            messages['content'].append({'type': 'text', 'text': prompt})\n        return messages\n\n    def _has_multimodal_input(self, messages: List[Dict]) -> bool:\n        \"\"\"Check if messages contain multimodal input (images).\"\"\"\n        multimodal_types = ['image_url', 'image_data', 'video_url', 'time_series_url']\n        return any(\n            isinstance(message.get('content'), list) and any(\n                item.get('type') in multimodal_types for item in message['content']) for message in messages)\n\n    async def _get_text_prompt_input(self,\n                                     prompt: str | List[Dict],\n                                     do_preprocess: bool,\n                                     
sequence_start: bool,\n                                     adapter_name: str,\n                                     tools: List[object] | None = None,\n                                     reasoning_effort: Literal['low', 'medium', 'high'] | None = None,\n                                     chat_template_kwargs: Dict | None = None,\n                                     **kwargs):\n        \"\"\"Process text-only prompt and return prompt string and input_ids.\"\"\"\n        # Change multimodal data to openai text messages\n        if isinstance(prompt, list):\n            prompt = [self.merge_message_content(msg) for msg in prompt]\n        if do_preprocess:\n            # use adapter's chat template if possible\n            chat_template = self.chat_template\n            if adapter_name in MODELS.module_dict:\n                chat_template = MODELS.module_dict[adapter_name]()\n        else:\n            chat_template = BaseChatTemplate()\n        chat_template_kwargs = chat_template_kwargs or {}\n        prompt = chat_template.messages2prompt(prompt,\n                                               sequence_start,\n                                               tools=tools,\n                                               reasoning_effort=reasoning_effort,\n                                               **chat_template_kwargs)\n        if prompt is None:\n            raise ValueError(\n                f'You are using base template to handle chat task. Please specify a `--chat-template` name chosen from `lmdeploy list` if you want to use OpenAI messages input.'  
# noqa\n            )\n        input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)\n        return {'prompt': prompt, 'input_ids': input_ids}\n\n    async def _get_multimodal_prompt_input(self,\n                                           messages: List[Dict],\n                                           do_preprocess: bool,\n                                           sequence_start: bool,\n                                           adapter_name: str,\n                                           tools: List[object] | None = None,\n                                           chat_template_kwargs: Dict | None = None,\n                                           media_io_kwargs: Dict[str, Any] | None = None,\n                                           mm_processor_kwargs: Dict[str, Any] | None = None,\n                                           **kwargs):\n        \"\"\"Process multimodal prompt and return processed data for inference\n        engines.\"\"\"\n        chat_template = self.chat_template if do_preprocess else BaseChatTemplate()\n        messages = await self.async_parse_multimodal_item(messages, media_io_kwargs)\n        results = await self.vl_encoder.preprocess(messages, mm_processor_kwargs)\n\n        if self.backend == 'turbomind':\n            # for tm engine, this module performs vision embedding after image\n            # preprocessing. It utilizes the hf model's vision embeddings\n            # functions and returns the input_ids, input_embeddings,\n            # embedding_ranges and so on. 
All the returned values are passed\n            # to the tm engine for token generation\n            results = await self.vl_encoder.async_infer(results)\n            results = await self.vl_encoder.wrap_for_turbomind(messages=results,\n                                                               chat_template=chat_template,\n                                                               tokenizer=self.tokenizer,\n                                                               sequence_start=sequence_start,\n                                                               tools=tools,\n                                                               chat_template_kwargs=chat_template_kwargs)\n        elif self.backend == 'pytorch':\n            # for pt engine, this module only conducts the image preprocessing.\n            # It leaves the vision embedding to the pt engine\n            results = await self.vl_encoder.wrap_for_pytorch(messages=results,\n                                                             chat_template=chat_template,\n                                                             tokenizer=self.tokenizer,\n                                                             sequence_start=sequence_start,\n                                                             tools=tools,\n                                                             chat_template_kwargs=chat_template_kwargs)\n        return results\n"
  },
  {
    "path": "lmdeploy/serve/proxy/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/serve/proxy/proxy.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport asyncio\nimport copy\nimport json\nimport os\nimport os.path as osp\nimport random\nimport threading\nimport time\nfrom collections import deque\nfrom http import HTTPStatus\nfrom typing import Deque, Literal\n\nimport aiohttp\nimport numpy as np\nimport requests\nimport uvicorn\nfrom fastapi import BackgroundTasks, Depends, FastAPI, Request\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom fastapi.responses import JSONResponse, StreamingResponse\nfrom pydantic import BaseModel, Field\n\nfrom lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, EngineRole, RDMALinkType, ServingStrategy\nfrom lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest\nfrom lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool\nfrom lmdeploy.pytorch.disagg.messages import PDConnectionMessage\nfrom lmdeploy.serve.openai.api_server import create_error_response\nfrom lmdeploy.serve.openai.protocol import ModelCard  # noqa: E501\nfrom lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, ModelList, ModelPermission\nfrom lmdeploy.serve.proxy.utils import AIOHTTP_TIMEOUT, LATENCY_DEQUE_LEN, ErrorCodes, RoutingStrategy, err_msg\nfrom lmdeploy.serve.utils.server_utils import validate_json_request\nfrom lmdeploy.utils import get_logger\n\nfrom .streaming_response import ProxyStreamingResponse\nfrom .utils import APIServerException\n\nlogger = get_logger('lmdeploy')\n\n\nclass Status(BaseModel):\n    \"\"\"Status protocol consists of models' information.\"\"\"\n    role: EngineRole = EngineRole.Hybrid\n    models: list[str] = Field(default=[], examples=[[]])\n    unfinished: int = 0\n    latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]])\n    speed: int | None = Field(default=None, examples=[None])\n\n\nclass Node(BaseModel):\n    \"\"\"Node protocol consists of url and status.\"\"\"\n    url: str\n    status: Status | 
None = None\n\n\nCONTROLLER_HEART_BEAT_EXPIRATION = int(os.getenv('LMDEPLOY_CONTROLLER_HEART_BEAT_EXPIRATION', 90))\n\n\ndef heart_beat_controller(proxy_controller):\n    while True:\n        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)\n        logger.info('Start heart beat check')\n        proxy_controller.remove_stale_nodes_by_expiration()\n\n\nclass NodeManager:\n    \"\"\"Manage all the sub nodes.\n\n    Args:\n        config_path (str): the path of the config file.\n        routing_strategy (str): the strategy to dispatch nodes to handle the requests.\n            - **random**: not fully random; a node is sampled with a probability\n                proportional to its speed.\n            - **min_expected_latency**: computes the expected latency of each\n                node (unfinished requests divided by speed) and dispatches the\n                request to the node with the lowest value.\n            - **min_observed_latency**: based on previously finished requests;\n                dispatches the request to the node with the lowest average\n                observed latency.\n        cache_status (bool): whether to load and save the node status from/to\n            the config file.\n    \"\"\"\n\n    def __init__(self,\n                 config_path: str | None = None,\n                 serving_strategy: str = 'Hybrid',\n                 routing_strategy: str = 'min_expected_latency',\n                 migration_protocol: str = 'RDMA',\n                 link_type: str = 'RoCE',\n                 with_gdr: bool = True,\n                 cache_status: bool = True) -> None:\n        self.nodes = dict()\n        self.serving_strategy = ServingStrategy[serving_strategy]\n        self.routing_strategy = RoutingStrategy.from_str(routing_strategy)\n\n        self.cache_status = cache_status\n        self.latencies = dict()\n        self.config_path = osp.join(osp.dirname(osp.realpath(__file__)), 'proxy_config.json')\n        if config_path is not None:\n            self.config_path = config_path\n        if osp.exists(self.config_path) and self.cache_status:\n            with open(self.config_path, 'r') as config_file:\n                if 
os.path.getsize(self.config_path) > 0:\n                    logger.info(f'loading node configuration: {self.config_path}')\n                    config = json.load(config_file)\n                    self.nodes = {\n                        node_url: Status.model_validate_json(node_status)\n                        for node_url, node_status in config.items()\n                    }\n        self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self, ), daemon=True)\n        self.heart_beat_thread.start()\n        self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT)\n\n        # For PD Disaggregation\n        self.migration_protocol = MigrationProtocol[migration_protocol]\n        self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type])\n        self.pd_connection_pool = PDConnectionPool()\n        self.dummy_prefill = False\n\n    def get_nodes(self, role: EngineRole) -> dict[str, Status]:\n        items = list(self.nodes.items())\n        return {node_url: node_status for (node_url, node_status) in items if node_status.role == role}\n\n    @property\n    def hybrid_nodes(self):\n        return self.get_nodes(EngineRole.Hybrid)\n\n    @property\n    def prefill_nodes(self):\n        return self.get_nodes(EngineRole.Prefill)\n\n    @property\n    def decode_nodes(self):\n        return self.get_nodes(EngineRole.Decode)\n\n    def update_config_file(self):\n        \"\"\"Update the config file.\"\"\"\n        nodes = copy.deepcopy(self.nodes)\n        for _, status in nodes.items():\n            status.latency = deque(list(status.latency)[-LATENCY_DEQUE_LEN:])\n        if self.cache_status:\n            with open(self.config_path, 'w') as config_file:  # update cfg yml\n                json.dump({\n                    node_url: node_status.model_dump_json()\n                    for node_url, node_status in nodes.items()\n                },\n                          config_file,\n                         
 indent=2)\n\n    def add(self, node_url: str, status: Status | None = None):\n        \"\"\"Add a node to the manager.\n\n        Args:\n            node_url (str): A http url. Can be the url generated by\n                `lmdeploy serve api_server`.\n            status (Status): The status of the node. An example:\n                {'models': ['internlm-chat-7b'], 'speed': 1}. The speed here\n                can be RPM or other metric. All the values of nodes should be\n                the same metric. If `models` is empty, the available models\n                are queried from the node itself.\n        \"\"\"\n        if status is None:\n            status = self.nodes.get(node_url, Status())\n        if status.models != []:  # force register directly\n            self.remove(node_url)\n            self.nodes[node_url] = status\n            self.update_config_file()\n            return\n        try:\n            from lmdeploy.serve.openai.api_client import APIClient\n            client = APIClient(api_server_url=node_url)\n            status.models = client.available_models\n            self.nodes[node_url] = status\n        except requests.exceptions.RequestException as e:  # noqa\n            logger.error(f'exception happened when adding node {node_url}, {e}')\n            return self.handle_api_timeout(node_url)\n        self.update_config_file()\n\n    def remove(self, node_url: str):\n        \"\"\"Remove a node.\"\"\"\n        if node_url in self.nodes.keys():\n            self.nodes.pop(node_url)\n            self.update_config_file()\n            self.pd_connection_pool.dereg_instance(node_url)\n\n    def terminate_node(self, node_url: str):\n        \"\"\"Terminate a node.\"\"\"\n        success = True\n        if node_url in self.nodes:\n            self.nodes.pop(node_url)\n            headers = {'accept': 'application/json'}\n            try:\n                response = requests.get(f'{node_url}/terminate', headers=headers)\n                if response.status_code != 200:\n                    success = False\n     
               logger.error(f'Failed to terminate node {node_url}, '\n                                 f'error_code={response.status_code}, '\n                                 f'error_msg={response.text}')\n            except Exception as e:  # noqa\n                logger.error(f'exception happened when terminating node {node_url}, {e}')\n                success = False\n        else:\n            logger.error(f'terminating node {node_url} failed since it does not exist. '\n                         'May try /nodes/status to check the node list')\n            success = False\n        self.update_config_file()\n        return success\n\n    def terminate_all_nodes(self):\n        \"\"\"Terminate all nodes.\"\"\"\n        node_url_li = list(self.nodes.keys())\n        all_success = True\n        for node_url in node_url_li:\n            if not self.terminate_node(node_url):\n                all_success = False\n        return all_success\n\n    def remove_stale_nodes_by_expiration(self):\n        \"\"\"Remove stale nodes.\"\"\"\n        to_be_deleted = []\n        node_urls = list(self.nodes.keys())\n        for node_url in node_urls:\n            url = f'{node_url}/health'\n            headers = {'accept': 'application/json'}\n            try:\n                response = requests.get(url, headers=headers)\n                if response.status_code != 200:\n                    to_be_deleted.append(node_url)\n            except:  # noqa\n                to_be_deleted.append(node_url)\n        for node_url in to_be_deleted:\n            self.remove(node_url)\n            logger.info(f'Removed node_url: {node_url} '\n                        'due to heart beat expiration')\n\n    @property\n    def model_list(self):\n        \"\"\"Supported model list.\"\"\"\n        model_names = []\n        items = list(self.nodes.items())\n        for _, status in items:\n            model_names.extend(status.models)\n        return model_names\n\n    @property\n    def status(self):\n  
      \"\"\"Return the status.\"\"\"\n        return self.nodes\n\n    def get_node_url(self, model_name: str, role: EngineRole = EngineRole.Hybrid):\n        \"\"\"Get the url of a node that serves the given model.\n\n        Args:\n            model_name (str): the name of the requested model.\n            role (EngineRole): the engine role of the candidate nodes.\n        Return:\n            A node url, or None if no node serves the model.\n        \"\"\"\n\n        def get_matched_urls():\n            urls_with_speeds, speeds, urls_without_speeds = [], [], []\n            for node_url, status in self.get_nodes(role).items():\n                if model_name in status.models:\n                    if status.speed is not None:\n                        urls_with_speeds.append(node_url)\n                        speeds.append(status.speed)\n                    else:\n                        urls_without_speeds.append(node_url)\n            all_matched_urls = urls_with_speeds + urls_without_speeds\n            if len(all_matched_urls) == 0:\n                # return empty lists so that callers can always unpack the result\n                return [], []\n            # some nodes do not report a speed,\n            # so assign them the average speed value\n            average_speed = sum(speeds) / len(speeds) if len(speeds) else 1\n            all_the_speeds = speeds + [average_speed] * len(urls_without_speeds)\n            return all_matched_urls, all_the_speeds\n\n        if self.routing_strategy == RoutingStrategy.RANDOM:\n            all_matched_urls, all_the_speeds = get_matched_urls()\n            if len(all_matched_urls) == 0:\n                return None\n            speed_sum = sum(all_the_speeds)\n            weights = [speed / speed_sum for speed in all_the_speeds]\n            index = random.choices(range(len(all_matched_urls)), weights=weights)[0]\n            url = all_matched_urls[index]\n            return url\n        elif self.routing_strategy == RoutingStrategy.MIN_EXPECTED_LATENCY:\n            all_matched_urls, all_the_speeds = get_matched_urls()\n            if len(all_matched_urls) == 0:\n             
   return None\n            min_latency = float('inf')\n            min_index = 0\n            # random traverse nodes for low concurrency situation\n            all_indexes = [i for i in range(len(all_the_speeds))]\n            random.shuffle(all_indexes)\n            for index in all_indexes:\n                latency = self.get_nodes(role)[all_matched_urls[index]].unfinished / all_the_speeds[index]\n                if min_latency > latency:\n                    min_latency = latency\n                    min_index = index\n            url = all_matched_urls[min_index]\n            return url\n        elif self.routing_strategy == RoutingStrategy.MIN_OBSERVED_LATENCY:\n            all_matched_urls, latencies = [], []\n            for node_url, node_status in self.get_nodes(role).items():\n                if model_name in node_status.models:\n                    if len(node_status.latency):\n                        latencies.append(np.mean(np.array(node_status.latency)))\n                    else:\n                        latencies.append(float('inf'))\n                    all_matched_urls.append(node_url)\n            if len(all_matched_urls) == 0:\n                return None\n            index = np.argmin(np.array(latencies))\n            return all_matched_urls[index]\n        else:\n            raise ValueError(f'Invalid strategy: {self.routing_strategy}')\n\n    async def check_request_model(self, model_name) -> JSONResponse | None:\n        \"\"\"Check if a request is valid.\"\"\"\n        if model_name in self.model_list:\n            return\n        ret = create_error_response(HTTPStatus.NOT_FOUND, f'The model {model_name!r} does not exist.')\n        return ret\n\n    def handle_unavailable_model(self, model_name):\n        \"\"\"Handle unavailable model.\n\n        Args:\n            model_name (str): the model in the request.\n        \"\"\"\n        logger.warning(f'no model name: {model_name}')\n        ret = {\n            'error_code': 
ErrorCodes.MODEL_NOT_FOUND.value,\n            'text': err_msg[ErrorCodes.MODEL_NOT_FOUND],\n        }\n        return json.dumps(ret).encode() + b'\\n'\n\n    def handle_api_timeout(self, node_url):\n        \"\"\"Handle the api timeout.\"\"\"\n        logger.warning(f'api timeout: {node_url}')\n        ret = {\n            'error_code': ErrorCodes.API_TIMEOUT.value,\n            'text': err_msg[ErrorCodes.API_TIMEOUT],\n        }\n        return json.dumps(ret).encode() + b'\\n'\n\n    async def stream_generate(self, request: dict, node_url: str, endpoint: str):\n        \"\"\"Return a generator to handle the input request.\n\n        Args:\n            request (Dict): the input request.\n            node_url (str): the node url.\n            endpoint (str): the endpoint. Such as `/v1/chat/completions`.\n        \"\"\"\n        try:\n            async with aiohttp.ClientSession() as session:\n                async with session.post(node_url + endpoint, json=request, timeout=self.aiotimeout) as response:\n                    async for line in response.content:\n                        if line.strip():\n                            yield line + b'\\n\\n'\n        except (Exception, GeneratorExit, aiohttp.ClientError) as e:  # noqa\n            logger.error(f'caught an exception: {e}')\n            # exception happened, reduce unfinished num\n            yield self.handle_api_timeout(node_url)\n\n    async def generate(self, request: dict, node_url: str, endpoint: str):\n        \"\"\"Return the response of the input request.\n\n        Args:\n            request (Dict): the input request.\n            node_url (str): the node url.\n            endpoint (str): the endpoint. 
Such as `/v1/chat/completions`.\n        \"\"\"\n        try:\n            async with aiohttp.ClientSession() as session:\n                async with session.post(node_url + endpoint, json=request, timeout=self.aiotimeout) as response:\n                    return await response.text()\n        except (Exception, GeneratorExit, aiohttp.ClientError, asyncio.CancelledError) as e:  # noqa  # yapf: disable\n            logger.error(f'caught an exception: {e}')\n            return self.handle_api_timeout(node_url)\n\n    async def forward_raw_request_stream_generate(self, raw_request: Request, node_url: str, endpoint: str):\n        try:\n            target_url = node_url.rstrip('/') + endpoint\n            headers = self._prepare_headers(raw_request)\n            body_bytes = await raw_request.body()\n            async with aiohttp.ClientSession() as session:\n                async with session.post(target_url, headers=headers, data=body_bytes,\n                                        timeout=self.aiotimeout) as response:\n                    if response.status != 200:\n                        error_body = await response.read()\n                        raise APIServerException(status_code=response.status, body=error_body)\n                    async for line in response.content:\n                        if line.strip():\n                            yield line + b'\\n\\n'\n        except APIServerException:\n            # raise APIServerException again to be caught by the outer layer\n            raise\n        except (Exception, GeneratorExit, aiohttp.ClientError) as e:  # noqa\n            logger.error(f'caught an exception: {e}')\n            # exception happened, reduce unfinished num\n            yield self.handle_api_timeout(node_url)\n\n    async def forward_raw_request_generate(self, raw_request: Request, node_url: str, endpoint: str):\n        try:\n            target_url = node_url.rstrip('/') + endpoint\n            headers = 
self._prepare_headers(raw_request)\n            body_bytes = await raw_request.body()\n            async with aiohttp.ClientSession() as session:\n                async with session.post(target_url, headers=headers, data=body_bytes,\n                                        timeout=self.aiotimeout) as response:\n                    return await response.text()\n        except (Exception, GeneratorExit, aiohttp.ClientError, asyncio.CancelledError) as e:  # noqa  # yapf: disable\n            logger.error(f'caught an exception: {e}')\n            return self.handle_api_timeout(node_url)\n\n    def pre_call(self, node_url):\n        \"\"\"Preprocess before the request gets processed.\n\n        Args:\n            node_url (str): the node url.\n        \"\"\"\n        self.nodes[node_url].unfinished += 1\n        return time.time()\n\n    def post_call(self, node_url: str, start: float):\n        \"\"\"Post-process after the response is finished.\n\n        Args:\n            node_url (str): the node url.\n            start (float): the start time point, as returned by time.time().\n        \"\"\"\n        if node_url in self.nodes:\n            self.nodes[node_url].unfinished -= 1\n            self.nodes[node_url].latency.append(time.time() - start)\n\n    def create_background_tasks(self, url: str, start: float):\n        \"\"\"Create a background task that calls `post_call`.\n\n        Args:\n            url (str): the node url.\n            start (float): the start time point. 
time.time()\n        \"\"\"\n        background_tasks = BackgroundTasks()\n        background_tasks.add_task(self.post_call, url, start)\n        return background_tasks\n\n    def _prepare_headers(self, raw_request: Request) -> dict[str, str]:\n        headers = dict((name, value) for name, value in raw_request.headers.items() if name.lower() != 'host')\n\n        client_ip = raw_request.client.host if raw_request.client else 'unknown'\n        headers.update({\n            'X-Forwarded-For': client_ip,\n            'X-Forwarded-Host': raw_request.headers.get('host', ''),\n            'X-Forwarded-Proto': raw_request.url.scheme,\n        })\n\n        return headers\n\n\napp = FastAPI(docs_url='/')\napp.add_middleware(\n    CORSMiddleware,\n    allow_origins=['*'],\n    allow_credentials=True,\n    allow_methods=['*'],\n    allow_headers=['*'],\n)\nnode_manager = NodeManager()\n\n\n@app.get('/v1/models')\ndef available_models():\n    \"\"\"Show available models.\"\"\"\n    model_cards = []\n    for model_name in node_manager.model_list:\n        model_cards.append(ModelCard(id=model_name, root=model_name, permission=[ModelPermission()]))\n    return ModelList(data=model_cards)\n\n\n@app.get('/nodes/status')\ndef node_status():\n    \"\"\"Show the status of all nodes.\"\"\"\n    try:\n        return node_manager.status\n    except:  # noqa\n        return False\n\n\n@app.post('/nodes/add', dependencies=[Depends(validate_json_request)])\ndef add_node(node: Node, raw_request: Request = None):\n    \"\"\"Add a node to the manager.\n\n    - **url** (str): An HTTP url. Can be the url generated by\n      `lmdeploy serve api_server`.\n    - **status** (Dict): The description of the node. An example:\n      ``{models: ['internlm-chat-7b'], speed: 1}``. The speed here can be\n      RPM or another metric. 
The values of all nodes should use the same metric.\n    \"\"\"\n    try:\n        res = node_manager.add(node.url, node.status)\n        if res is not None:\n            logger.error(f'add node {node.url} failed, {res}')\n            return res\n        logger.info(f'add node {node.url} successfully')\n        return 'Added successfully'\n    except:  # noqa\n        return 'Failed to add, please check the input url.'\n\n\n@app.post('/nodes/remove', dependencies=[Depends(validate_json_request)])\ndef remove_node(node: Node):\n    \"\"\"Remove a node from the manager.\"\"\"\n    try:\n        node_url = node.url\n        node_manager.remove(node_url)\n        logger.info(f'delete node {node_url} successfully')\n        return 'Deleted successfully'\n    except:  # noqa\n        logger.error(f'delete node {node.url} failed.')\n        return 'Failed to delete, please check the input url.'\n\n\n@app.post('/nodes/terminate', dependencies=[Depends(validate_json_request)])\ndef terminate_node(node: Node):\n    \"\"\"Terminate a node.\"\"\"\n    try:\n        node_url = node.url\n        success = node_manager.terminate_node(node_url)\n        if not success:\n            return f'Failed to terminate node {node_url}'\n        return 'Terminated successfully'\n    except:  # noqa\n        logger.error(f'Terminate node {node.url} failed.')\n        return f'Failed to terminate node {node.url}, please check the input url.'\n\n\n@app.get('/nodes/terminate_all', dependencies=[Depends(validate_json_request)])\ndef terminate_node_all():\n    \"\"\"Terminate all nodes.\"\"\"\n    try:\n        success = node_manager.terminate_all_nodes()\n        if not success:\n            return 'Failed to terminate all nodes'\n        return 'All nodes terminated successfully'\n    except:  # noqa\n        logger.error('Failed to terminate all nodes')\n        return 'Failed to terminate all nodes.'\n\n\n@app.post('/distserve/connection_warmup', dependencies=[Depends(validate_json_request)])\nasync def 
connection_warmup():\n    await asyncio.gather(*[\n        node_manager.pd_connection_pool.connect(\n            PDConnectionMessage(\n                p_url=p_url,\n                d_url=d_url,\n                protocol=node_manager.migration_protocol,\n                rdma_config=node_manager.rdma_config,\n            )) for p_url in node_manager.prefill_nodes for d_url in node_manager.decode_nodes\n    ])\n    return JSONResponse({'SUCCESS': True})\n\n\n@app.post('/distserve/gc', dependencies=[Depends(validate_json_request)])\nasync def cache_block_gc_to_be_migrated():\n    # TODO (JimyMa): add garbage collection of to be migrated request\n    raise NotImplementedError\n\n\n@app.post('/v1/chat/completions', dependencies=[Depends(validate_json_request)])\nasync def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None):\n    \"\"\"Completion API similar to OpenAI's API.\n\n    Refer to https://platform.openai.com/docs/api-reference/chat/create\n    for the API specification.\n\n    The request should be a JSON object with the following fields:\n\n    - **model**: model name. Available from /v1/models.\n    - **messages**: string prompt or chat history in OpenAI format. Chat history\n      example: `[{\"role\": \"user\", \"content\": \"hi\"}]`.\n    - **temperature** (float): to modulate the next token probability\n    - **top_p** (float): If set to float < 1, only the smallest set of most\n      probable tokens with probabilities that add up to top_p or higher\n      are kept for generation.\n    - **n** (int): How many chat completion choices to generate for each input\n      message. **Only support one here**.\n    - **stream**: whether to stream the results or not. Default to false.\n    - **max_completion_tokens** (int | None): output token nums. Default to None.\n    - **max_tokens** (int | None): output token nums. 
Default to None.\n      Deprecated: Use max_completion_tokens instead.\n    - **repetition_penalty** (float): The parameter for repetition penalty.\n      1.0 means no penalty\n    - **stop** (str | List[str] | None): To stop generating further\n      tokens. Only stop words that are encoded to a single token index are accepted.\n    - **response_format** (Dict | None): To generate a response according to the given\n      schema. Examples:\n\n      .. code-block:: json\n\n        {\n          \"type\": \"json_schema\",\n          \"json_schema\":{\n            \"name\": \"test\",\n            \"schema\":{\n              \"properties\":{\n                \"name\":{\"type\":\"string\"}\n              },\n              \"required\":[\"name\"],\n              \"type\":\"object\"\n            }\n          }\n        }\n\n      or\n      ``{\"type\": \"regex_schema\", \"regex_schema\": \"call me [A-Za-z]{1,10}\"}``\n    - **logit_bias** (Dict): Bias to logits. Only supported by the pytorch engine.\n    - **tools** (List): A list of tools the model may call. Currently, only\n      internlm2 functions are supported as a tool. Use this to specify a\n      list of functions for which the model can generate JSON inputs.\n    - **tool_choice** (str | object): Controls which (if any) tool is called by\n      the model. `none` means the model will not call any tool and instead\n      generates a message. Specifying a particular tool via\n      ``{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}``\n      forces the model to call that tool. `auto` or `required` will pass all\n      the tool information to the model.\n\n    Additional arguments supported by LMDeploy:\n\n    - **top_k** (int): The number of the highest probability vocabulary\n      tokens to keep for top-k-filtering\n    - **ignore_eos** (bool): Whether to ignore eos\n    - **skip_special_tokens** (bool): Whether or not to remove special tokens\n      in the decoding. 
Default to be True.\n    - **min_new_tokens** (int): To generate at least numbers of tokens.\n    - **min_p** (float): Minimum token probability, which will be scaled by the\n      probability of the most likely token. It must be a value between\n      0 and 1. Typical values are in the 0.01-0.2 range, comparably\n      selective as setting `top_p` in the 0.99-0.8 range (use the\n      opposite of normal `top_p` values)\n\n    Currently we do not support the following features:\n\n    - **presence_penalty** (replaced with repetition_penalty)\n    - **frequency_penalty** (replaced with repetition_penalty)\n    \"\"\"\n    check_response = await node_manager.check_request_model(request.model)\n    if check_response is not None:\n        return check_response\n\n    if node_manager.serving_strategy == ServingStrategy.Hybrid:\n        node_url = node_manager.get_node_url(request.model)\n        if not node_url:\n            return node_manager.handle_unavailable_model(request.model)\n\n        logger.info(f'A request is dispatched to {node_url}')\n        start = node_manager.pre_call(node_url)\n        if request.stream is True:\n            response = node_manager.forward_raw_request_stream_generate(raw_request, node_url, '/v1/chat/completions')\n            background_task = node_manager.create_background_tasks(node_url, start)\n            return ProxyStreamingResponse(response, background=background_task, media_type='text/event-stream')\n        else:\n            response = await node_manager.forward_raw_request_generate(raw_request, node_url, '/v1/chat/completions')\n            node_manager.post_call(node_url, start)\n            return JSONResponse(json.loads(response))\n    elif node_manager.serving_strategy == ServingStrategy.DistServe:\n        request_dict = request.model_dump()\n\n        # Prefill\n        prefill_request_dict = copy.deepcopy(request_dict)\n        prefill_request_dict['max_tokens'] = 1\n        
prefill_request_dict['max_completion_tokens'] = 1\n        prefill_request_dict['stream'] = False\n        prefill_request_dict['with_cache'] = True\n        prefill_request_dict['preserve_cache'] = True\n\n        prefill_info = {}\n        p_url = 'dummy:dummy'\n        if not node_manager.dummy_prefill:\n            p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)\n            if not p_url:\n                return node_manager.handle_unavailable_model(request.model)\n            logger.info(f'A Prefill request is dispatched to {p_url}')\n\n            start = node_manager.pre_call(p_url)\n            prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, '/v1/chat/completions'))\n            node_manager.post_call(p_url, start)\n\n        # # Decode\n        d_url = node_manager.get_node_url(request.model, EngineRole.Decode)\n        if not d_url:\n            return node_manager.handle_unavailable_model(request.model)\n        logger.info(f'A Decode request is dispatched to {d_url}')\n\n        if not node_manager.dummy_prefill:\n            if not node_manager.pd_connection_pool.is_connected(p_url, d_url):\n                await node_manager.pd_connection_pool.connect(\n                    PDConnectionMessage(\n                        p_url=p_url,\n                        d_url=d_url,\n                        protocol=node_manager.migration_protocol,\n                        rdma_config=node_manager.rdma_config,\n                    ))\n\n        remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0\n        remote_block_ids = prefill_info.get('cache_block_ids') or []\n        remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0\n\n        request_dict['migration_request'] = MigrationRequest(\n            protocol=node_manager.migration_protocol,\n            remote_engine_id=p_url,\n            remote_session_id=remote_session_id,\n  
          remote_block_ids=remote_block_ids,\n            remote_token_id=remote_token_id,\n            is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json')\n\n        start = node_manager.pre_call(d_url)\n        if not node_manager.dummy_prefill:\n            node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])\n        if request.stream is True:\n            response = node_manager.stream_generate(request_dict, d_url, '/v1/chat/completions')\n            background_task = node_manager.create_background_tasks(d_url, start)\n            resp = StreamingResponse(response, background=background_task, media_type='text/event-stream')\n        else:\n            response = await node_manager.generate(request_dict, d_url, '/v1/chat/completions')\n            node_manager.post_call(d_url, start)\n            resp = JSONResponse(json.loads(response))\n\n        if not node_manager.dummy_prefill:\n            node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info['id'])\n\n        return resp\n\n    else:\n        raise ValueError(f'No serving strategy named {node_manager.serving_strategy}')\n\n\n@app.post('/v1/completions', dependencies=[Depends(validate_json_request)])\nasync def completions_v1(request: CompletionRequest, raw_request: Request = None):\n    \"\"\"Completion API similar to OpenAI's API.\n\n    Go to https://platform.openai.com/docs/api-reference/completions/create\n    for the API specification.\n\n    The request should be a JSON object with the following fields:\n\n    - **model** (str): model name. Available from /v1/models.\n    - **prompt** (str): the input prompt.\n    - **suffix** (str): The suffix that comes after a completion of inserted text.\n    - **max_completion_tokens** (int | None): output token nums. Default to None.\n    - **max_tokens** (int): output token nums. 
Default to 16.\n      Deprecated: Use max_completion_tokens instead.\n    - **temperature** (float): to modulate the next token probability\n    - **top_p** (float): If set to float < 1, only the smallest set of most\n      probable tokens with probabilities that add up to top_p or higher\n      are kept for generation.\n    - **n** (int): How many completion choices to generate for each input\n      prompt. **Only one is supported here**.\n    - **stream**: whether to stream the results or not. Default to false.\n    - **repetition_penalty** (float): The parameter for repetition penalty.\n      1.0 means no penalty\n    - **user** (str): A unique identifier representing your end-user.\n    - **stop** (str | List[str] | None): To stop generating further\n      tokens. Only stop words that are encoded to a single token index are accepted.\n\n    Additional arguments supported by LMDeploy:\n\n    - **ignore_eos** (bool): Whether to ignore eos\n    - **skip_special_tokens** (bool): Whether or not to remove special tokens\n      in the decoding. 
Default to be True.\n    - **top_k** (int): The number of the highest probability vocabulary\n      tokens to keep for top-k-filtering\n\n    Currently we do not support the following features:\n\n    - **logprobs** (not supported yet)\n    - **presence_penalty** (replaced with repetition_penalty)\n    - **frequency_penalty** (replaced with repetition_penalty)\n    \"\"\"\n    check_response = await node_manager.check_request_model(request.model)\n    if check_response is not None:\n        return check_response\n    if node_manager.serving_strategy == ServingStrategy.Hybrid:\n        node_url = node_manager.get_node_url(request.model)\n        if not node_url:\n            return node_manager.handle_unavailable_model(request.model)\n\n        logger.info(f'A request is dispatched to {node_url}')\n        start = node_manager.pre_call(node_url)\n        if request.stream is True:\n            response = node_manager.forward_raw_request_stream_generate(raw_request, node_url, '/v1/completions')\n            background_task = node_manager.create_background_tasks(node_url, start)\n            return ProxyStreamingResponse(response, background=background_task, media_type='text/event-stream')\n        else:\n            response = await node_manager.forward_raw_request_generate(raw_request, node_url, '/v1/completions')\n            node_manager.post_call(node_url, start)\n            return JSONResponse(json.loads(response))\n    elif node_manager.serving_strategy == ServingStrategy.DistServe:\n        request_dict = request.model_dump()\n\n        # Prefill\n        prefill_request_dict = copy.deepcopy(request_dict)\n        prefill_request_dict['max_tokens'] = 1\n        prefill_request_dict['stream'] = False\n        prefill_request_dict['with_cache'] = True\n        prefill_request_dict['preserve_cache'] = True\n\n        if not node_manager.dummy_prefill:\n            try:\n                p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)\n        
    except Exception as e:\n                logger.error(f'error Msg: {str(e)}')\n                return {'status': 'Instance scheduling error, cannot find available p_url'}\n\n            if not p_url:\n                return node_manager.handle_unavailable_model(request.model)\n            logger.info(f'A Prefill request is dispatched to {p_url}')\n\n            start = node_manager.pre_call(p_url)\n            prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, '/v1/completions'))\n            node_manager.post_call(p_url, start)\n        else:\n            p_url = 'dummy:dummy'\n            prefill_info = {}\n\n        # Decode\n        try:\n            d_url = node_manager.get_node_url(request.model, EngineRole.Decode)\n        except Exception as e:\n            logger.error(f'error Msg: {str(e)}')\n            return {'status': 'Instance scheduling error, cannot find available d_url'}\n\n        if not d_url:\n            return node_manager.handle_unavailable_model(request.model)\n        logger.info(f'A Decode request is dispatched to {d_url}')\n\n        if not node_manager.dummy_prefill:\n            if not node_manager.pd_connection_pool.is_connected(p_url, d_url):\n                try:\n                    await node_manager.pd_connection_pool.connect(\n                        PDConnectionMessage(\n                            p_url=p_url,\n                            d_url=d_url,\n                            protocol=node_manager.migration_protocol,\n                            rdma_config=node_manager.rdma_config,\n                        ))\n                except Exception as e:\n                    logger.error(f'error Msg: {str(e)}')\n                    return {'status': f'Connection error, cannot establish connection {(p_url, d_url)}'}\n\n        remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 
0\n        remote_block_ids = prefill_info.get('cache_block_ids') or []\n        remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0\n        request_dict['migration_request'] = MigrationRequest(\n            protocol=node_manager.migration_protocol,\n            remote_engine_id=p_url,\n            remote_session_id=remote_session_id,\n            remote_block_ids=remote_block_ids,\n            remote_token_id=remote_token_id,\n            is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json')\n\n        start = node_manager.pre_call(d_url)\n        if not node_manager.dummy_prefill:\n            node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])\n        if request.stream is True:\n            response = node_manager.stream_generate(request_dict, d_url, '/v1/completions')\n            background_task = node_manager.create_background_tasks(d_url, start)\n            resp = StreamingResponse(response, background=background_task, media_type='text/event-stream')\n        else:\n            response = await node_manager.generate(request_dict, d_url, '/v1/completions')\n            node_manager.post_call(d_url, start)\n            resp = JSONResponse(json.loads(response))\n        if not node_manager.dummy_prefill:\n            node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info.get('id'))\n        return resp\n    else:\n        raise ValueError(f'No serving strategy named {node_manager.serving_strategy}')\n\n\ndef proxy(server_name: str = '0.0.0.0',\n          server_port: int = 8000,\n          serving_strategy: Literal['Hybrid', 'DistServe'] = 'Hybrid',\n          routing_strategy: Literal['random', 'min_expected_latency', 'min_observed_latency'] = 'min_expected_latency',\n          api_keys: list[str] | str | None = None,\n          ssl: bool = False,\n          log_level: str = 'INFO',\n          disable_cache_status: bool = 
False,\n          link_type: Literal['RoCE', 'IB'] = 'RoCE',\n          migration_protocol: Literal['RDMA'] = 'RDMA',\n          dummy_prefill: bool = False,\n          **kwargs):\n    \"\"\"Launch the proxy server.\n\n    Args:\n        server_name (str): the server name of the proxy. Default to '0.0.0.0'.\n        server_port (int): the server port. Default to 8000.\n        serving_strategy ('Hybrid' | 'DistServe'): the serving strategy. Default to 'Hybrid'.\n            Use 'DistServe' for PD disaggregation.\n        routing_strategy ('random' | 'min_expected_latency' | 'min_observed_latency'):\n            the strategy to dispatch requests to nodes. Default to\n            'min_expected_latency'.\n        api_keys (List[str] | str | None): Optional list of API keys. A string is accepted as\n            a single api_key. Default to None, which means no api key is applied.\n        ssl (bool): Enable SSL. Requires the environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.\n        log_level (str): Set the log level. Default to INFO.\n        disable_cache_status (bool): Whether to disable caching the proxy status to\n            proxy_config.yml. Default to False.\n        link_type ('RoCE' | 'IB'): the RDMA link type used for PD disaggregation. Default to 'RoCE'.\n        dummy_prefill (bool): Whether to skip dispatching requests to prefill nodes\n            and use a dummy prefill instead. Default to False.\n        migration_protocol ('RDMA'): the migration protocol for PD disaggregation. 
Default to 'RDMA'.\n    \"\"\"  # noqa\n    node_manager.serving_strategy = ServingStrategy[serving_strategy]\n    node_manager.routing_strategy = RoutingStrategy.from_str(routing_strategy)\n    node_manager.migration_protocol = MigrationProtocol[migration_protocol]\n    node_manager.dummy_prefill = dummy_prefill\n\n    node_manager.rdma_config = DistServeRDMAConfig(\n        link_type=RDMALinkType[link_type],\n        with_gdr=True,\n    )\n    node_manager.cache_status = not disable_cache_status\n    if isinstance(api_keys, str):\n        # a single key passed as a string; avoid iterating over its characters\n        api_keys = [api_keys]\n    if api_keys is not None and (tokens := [key for key in api_keys if key]):\n        from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware\n\n        app.add_middleware(AuthenticationMiddleware, tokens=tokens)\n\n    ssl_keyfile, ssl_certfile = None, None\n    if ssl:\n        ssl_keyfile = os.environ['SSL_KEYFILE']\n        ssl_certfile = os.environ['SSL_CERTFILE']\n\n    logger.setLevel(log_level)\n    uvicorn_log_level = os.getenv('UVICORN_LOG_LEVEL', 'info').lower()\n    uvicorn.run(app=app,\n                host=server_name,\n                port=server_port,\n                log_level=uvicorn_log_level,\n                ssl_keyfile=ssl_keyfile,\n                ssl_certfile=ssl_certfile)\n\n\nif __name__ == '__main__':\n    import fire\n\n    fire.Fire(proxy)\n"
  },
  {
    "path": "lmdeploy/serve/proxy/streaming_response.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport json\n\nfrom fastapi.responses import StreamingResponse\n\nfrom .utils import APIServerException\n\n\nclass ProxyStreamingResponse(StreamingResponse):\n    \"\"\"StreamingResponse that can handle exceptions thrown by the generator.\"\"\"\n\n    def __init__(self, content, **kwargs):\n        super().__init__(content, **kwargs)\n\n    async def stream_response(self, send) -> None:\n        iterator = self.body_iterator.__aiter__()\n        try:\n            # get the first chunk(stream_generate's first yield)\n            first_chunk = await iterator.__anext__()\n\n        except APIServerException as e:\n            headers = self._convert_headers_to_asgi(e.headers) if e.headers else self.raw_headers\n            await send({'type': 'http.response.start', 'status': e.status_code, 'headers': headers})\n            await send({\n                'type': 'http.response.body',\n                'body': e.body,\n                'more_body': False,\n            })\n            return\n\n        # normal case, send the header first\n        await send({\n            'type': 'http.response.start',\n            'status': self.status_code,\n            'headers': self.raw_headers,\n        })\n\n        # send body with the first chunk\n        await send({\n            'type': 'http.response.body',\n            'body': first_chunk,\n            'more_body': True,\n        })\n\n        # continue streaming output\n        try:\n            async for chunk in iterator:\n                await send({\n                    'type': 'http.response.body',\n                    'body': chunk,\n                    'more_body': True,\n                })\n        except Exception:\n            error_data = {'error': True, 'status': 500, 'message': 'Internal streaming error'}\n            await send({\n                'type': 'http.response.body',\n                'body': json.dumps(error_data).encode('utf-8'),\n     
           'more_body': False,\n            })\n            return\n\n        await send({\n            'type': 'http.response.body',\n            'body': b'',\n            'more_body': False,\n        })\n\n    def _convert_headers_to_asgi(self, headers: dict) -> list[tuple[bytes, bytes]]:\n        \"\"\"Convert dict headers to ASGI raw header tuples.\"\"\"\n        return [(name.lower().encode('latin-1'), str(value).encode('latin-1')) for name, value in headers.items()]\n"
  },
  {
    "path": "lmdeploy/serve/proxy/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport enum\nimport os\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\nLATENCY_DEQUE_LEN = 15\nAIOHTTP_TIMEOUT = os.getenv('AIOHTTP_TIMEOUT', None)\nif AIOHTTP_TIMEOUT is not None:\n    AIOHTTP_TIMEOUT = int(AIOHTTP_TIMEOUT)\nlogger.info(f'AIOHTTP_TIMEOUT set to {AIOHTTP_TIMEOUT}. It can be modified before launching the proxy server '\n            'through env variable AIOHTTP_TIMEOUT')\n\n\nclass RoutingStrategy(enum.Enum):\n    \"\"\"Strategy to dispatch requests to nodes.\"\"\"\n    RANDOM = enum.auto()\n    MIN_EXPECTED_LATENCY = enum.auto()\n    MIN_OBSERVED_LATENCY = enum.auto()\n\n    @classmethod\n    def from_str(cls, name):\n        \"\"\"Get strategy from string.\"\"\"\n        if name == 'random':\n            return cls.RANDOM\n        elif name == 'min_expected_latency':\n            return cls.MIN_EXPECTED_LATENCY\n        elif name == 'min_observed_latency':\n            return cls.MIN_OBSERVED_LATENCY\n        else:\n            raise ValueError(f'Invalid strategy: {name}. Supported: random, '\n                             f'min_expected_latency, min_observed_latency.')\n\n\nclass ErrorCodes(enum.Enum):\n    \"\"\"Error codes.\"\"\"\n    MODEL_NOT_FOUND = 10400\n    SERVICE_UNAVAILABLE = 10401\n    API_TIMEOUT = 10402\n\n\nerr_msg = {\n    ErrorCodes.MODEL_NOT_FOUND: 'The request model name does not exist in the model list.',\n    ErrorCodes.SERVICE_UNAVAILABLE: 'The service is unavailable now. May retry later.',\n    ErrorCodes.API_TIMEOUT: 'Failed to get response after a period of time'\n}\n\n\nclass APIServerException(Exception):\n\n    def __init__(self, status_code: int, body: bytes, headers: dict | None = None):\n        self.status_code = status_code\n        self.body = body\n        self.headers = headers or {}\n        if 'content-type' not in self.headers:\n            self.headers['content-type'] = 'application/json'\n"
  },
  {
    "path": "lmdeploy/serve/utils/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/serve/utils/server_utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/server_utils.py\nimport hashlib\nimport secrets\nfrom collections.abc import Awaitable\n\nfrom fastapi import Request\nfrom fastapi.exceptions import RequestValidationError\nfrom fastapi.responses import JSONResponse\nfrom starlette.datastructures import URL, Headers\nfrom starlette.types import ASGIApp, Receive, Scope, Send\n\n\ndef validate_json_request(raw_request: Request):\n    content_type = raw_request.headers.get('content-type', '').lower()\n    media_type = content_type.split(';', maxsplit=1)[0]\n    if media_type != 'application/json':\n        raise RequestValidationError(errors=[\"Unsupported Media Type: Only 'application/json' is allowed\"])\n\n\nclass AuthenticationMiddleware:\n    \"\"\"Pure ASGI middleware that authenticates each request by checking if the\n    Authorization Bearer token exists and matches any of the configured API keys.\n\n    Notes\n    -----\n    There are two cases in which authentication is skipped:\n        1. The HTTP method is OPTIONS.\n        2. The request path starts with one of the skipped prefixes (e.g. 
/health).\n    \"\"\"\n\n    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:\n        self.app = app\n        self.api_tokens = [hashlib.sha256(t.encode('utf-8')).digest() for t in tokens]\n        # Path prefixes that bypass authentication\n        self.skip_prefixes = [\n            '/health',  # Health check endpoints\n            '/docs',  # Swagger UI documentation\n            '/redoc',  # ReDoc documentation\n            '/nodes',  # Endpoints about node operation between proxy and api_server\n        ]\n\n    def verify_token(self, headers: Headers) -> bool:\n        authorization_header_value = headers.get('Authorization')\n        if not authorization_header_value:\n            return False\n\n        scheme, _, param = authorization_header_value.partition(' ')\n        if scheme.lower() != 'bearer':\n            return False\n\n        param_hash = hashlib.sha256(param.encode('utf-8')).digest()\n\n        token_match = False\n        for token_hash in self.api_tokens:\n            token_match |= secrets.compare_digest(param_hash, token_hash)\n\n        return token_match\n\n    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:\n        if scope['type'] not in ('http', 'websocket'):\n            # scope[\"type\"] can be \"lifespan\" or \"startup\" for example,\n            # in which case we don't need to do anything\n            return self.app(scope, receive, send)\n        if scope['type'] == 'http' and scope['method'] == 'OPTIONS':\n            return self.app(scope, receive, send)\n\n        root_path = scope.get('root_path', '')\n        url_path = URL(scope=scope).path.removeprefix(root_path)\n        headers = Headers(scope=scope)\n        if not any(url_path.startswith(path) for path in self.skip_prefixes) and not self.verify_token(headers):\n            response = JSONResponse(content={'error': 'Unauthorized'}, status_code=401)\n            return response(scope, receive, send)\n        return 
self.app(scope, receive, send)\n"
  },
  {
    "path": "lmdeploy/tokenizer.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport os.path as osp\nfrom collections import deque\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import List, Optional, Sequence, Tuple, Union\n\nfrom lmdeploy.utils import get_logger\n\n# this file will be copied to triton server, make sure all\n# imports start from the package root lmdeploy\n\n\n@dataclass\nclass DetokenizeState:\n    \"\"\"A collection of states for incremental detokenization.\n\n    Args:\n        ids_offset (int): offset to all input ids. In LMDeploy, the number of\n            output ids per step is not always one; it can vary from round to round.\n        prev_tokens (List[str] | None): for incremental decoding.\n            Default to None, which means the first round.\n        prefix_offset (int): the start index of tokens to be converted to\n            string (prev + new tokens). Default to 0 for the first round.\n        read_offset (int): the end index of tokens to be converted to\n            string (prev token). 
Default to 0 for the first round.\n    \"\"\"\n    ids_offset: int = 0\n    prev_tokens: Optional[List[str]] = None\n    prefix_offset: int = 0\n    read_offset: int = 0\n\n    def as_tuple(self) -> Tuple:\n        \"\"\"Return a tuple of states.\"\"\"\n        return (self.ids_offset, self.prev_tokens, self.prefix_offset, self.read_offset)\n\n\nclass HuggingFaceTokenizer:\n    \"\"\"A wrapper of transformers' AutoTokenizer.\n\n    Args:\n        model_dir (str): the directory of the tokenizer model\n    \"\"\"\n\n    def __init__(self, model_dir: str):\n        self._check_transformers_version(model_dir)\n        from transformers import AutoTokenizer\n        self.logger = get_logger('lmdeploy')\n        self.model = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)\n        self._prefix_space_tokens = None\n\n        if self.model.eos_token_id is None:\n            generation_config_file = osp.join(model_dir, 'generation_config.json')\n            if osp.exists(generation_config_file):\n                with open(generation_config_file, 'r') as f:\n                    cfg = json.load(f)\n                    self.model.eos_token_id = cfg['eos_token_id']\n            elif hasattr(self.model, 'eod_id'):  # Qwen remote\n                self.model.eos_token_id = self.model.eod_id\n\n        # for stop words\n        self._vocab_size_with_added: int = None\n        self._maybe_decode_bytes: bool = None\n        # TODO maybe lack a constant.py\n        self._indexes_tokens_deque = deque(maxlen=10)\n        self.max_indexes_num = 5\n        self.token2id = {}\n\n    def _check_transformers_version(self, model_dir: str):\n        import transformers\n        from packaging import version\n\n        from lmdeploy.archs import get_model_arch\n\n        logger = get_logger('lmdeploy')\n\n        current_transformers_version = version.parse(transformers.__version__)\n        cfg = get_model_arch(model_dir)[1]\n        cfg_ver = getattr(cfg, 
'transformers_version', None)\n        if cfg_ver is None:\n            llm_config = getattr(cfg, 'llm_config', None)\n            if llm_config:\n                cfg_ver = getattr(llm_config, 'transformers_version', None)\n        if cfg_ver is None:\n            return\n        required_transformers_version = version.parse(cfg_ver)\n        if current_transformers_version < required_transformers_version:\n            logger.warning(\n                f'The current version of `transformers` is transformers=={current_transformers_version}, '  # noqa: E501\n                f'which is lower than the required version transformers=={required_transformers_version}. '  # noqa: E501\n                'Please upgrade to the required version.')\n\n    def get_vocab(self):\n        \"\"\"Get vocab.\"\"\"\n        return self.model.get_vocab()\n\n    @property\n    def vocab_size(self):\n        \"\"\"Vocabulary size.\"\"\"\n        return self.model.vocab_size\n\n    @property\n    def vocab_size_with_added(self):\n        \"\"\"Vocabulary size including added tokens.\"\"\"\n        if self._vocab_size_with_added is not None:\n            return self._vocab_size_with_added\n        self._vocab_size_with_added = len(self.model.get_vocab())\n        return self._vocab_size_with_added\n\n    @property\n    def bos_token_id(self):\n        \"\"\"Beginning-of-sentence token id.\"\"\"\n        return self.model.bos_token_id\n\n    @property\n    def eos_token_id(self):\n        \"\"\"End-of-sentence token id.\"\"\"\n        return self.model.eos_token_id\n\n    @property\n    def prefix_space_tokens(self):\n        \"\"\"Ids of tokens that start with a prefix space.\"\"\"\n        if self._prefix_space_tokens is None:\n            vocab = self.model.convert_ids_to_tokens(list(range(self.vocab_size)))\n            self._prefix_space_tokens = {\n                i\n                for i, tok in enumerate(vocab) if tok.startswith('▁' if isinstance(tok, str) else b' ')\n            }\n        return 
self._prefix_space_tokens\n\n    def _maybe_add_prefix_space(self, tokens: List[int], decoded: str):\n        \"\"\"Maybe add prefix space for incremental decoding.\"\"\"\n        if len(tokens) and not decoded.startswith(' ') and\\\n                tokens[0] in self.prefix_space_tokens:\n            return ' ' + decoded\n        else:\n            return decoded\n\n    @property\n    def maybe_decode_bytes(self):\n        \"\"\"Check whether self.model.convert_ids_to_tokens returns non-str values.\"\"\"\n        if self._maybe_decode_bytes is None:\n            self._maybe_decode_bytes = False\n            vocab = self.model.convert_ids_to_tokens(list(range(self.vocab_size)))\n            for tok in vocab:\n                if not isinstance(tok, str):\n                    self._maybe_decode_bytes = True\n                    break\n        return self._maybe_decode_bytes\n\n    def indexes_containing_token(self, token: str):\n        \"\"\"Return all the possible indexes whose decoded output may contain\n        the input token.\"\"\"\n        # traversing the vocab is time-consuming and cannot be accelerated with\n        # multiple threads (computation-bound) or processes (the tokenizer can't\n        # be pickled), so we cache the latest 10 stop words and return directly on a match\n        for _token, _indexes in self._indexes_tokens_deque:\n            if token == _token:\n                return _indexes\n\n        if self.token2id == {}:\n            # decode is slower than convert_ids_to_tokens\n            if self.maybe_decode_bytes:\n                for i in range(self.vocab_size):\n                    try:\n                        self.token2id[self.model.decode(i)] = i\n                    except:  # noqa: E722\n                        # some tokens just can't be decoded by `decode`\n                        pass\n            else:\n                self.token2id = {self.model.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}\n        if token == ' ':  # ' ' is 
special\n            token = '▁'\n        indexes = [i for _token, i in self.token2id.items() if token in _token]\n        if len(indexes) > self.max_indexes_num:\n            # multiple ids may decode to the same token\n            indexes = [i for i in indexes if self.decode([i]) == token]\n            indexes = indexes[:self.max_indexes_num]\n            self.logger.warning(f'There are too many (>{self.max_indexes_num}) possible '\n                                f'indexes that may decode to {token}; only {indexes} will be used')\n        # there might be token ids that exceed self.vocab_size\n        if len(indexes) == 0:\n            indexes = self.encode(token, False)\n            if len(indexes) != 1:\n                self.logger.warning(f'The token {token} is encoded as {indexes}, whose length is '\n                                    'not 1. Currently, it cannot be used as a stop word')\n                indexes = []\n        self._indexes_tokens_deque.append((token, indexes))\n        return indexes\n\n    def encode(self, s: str, add_bos: bool = True, add_special_tokens: bool = True, **kwargs):\n        \"\"\"Tokenize a prompt.\n\n        Args:\n            s (str): a prompt\n            add_bos (bool): Whether to add `bos` token id when encoding\n                the prompt\n            add_special_tokens (bool): Whether or not to add special tokens\n                when encoding the prompt\n        Returns:\n            list[int]: token ids\n        \"\"\"\n        encoded = self.model.encode(s, add_special_tokens=add_special_tokens, **kwargs)\n        if not add_bos:\n            # in the middle of a session\n            if len(encoded) and encoded[0] == self.bos_token_id:\n                encoded = encoded[1:]\n        return encoded\n\n    def decode(self, t: Sequence[int], offset: Optional[int] = None, skip_special_tokens: bool = True):\n        \"\"\"De-tokenize.\n\n        Args:\n            t (List[int]): a list of token ids\n            offset (int): for incremental 
decoding. Default to None, which\n                means no offset is applied.\n            skip_special_tokens (bool): Whether or not to remove special\n                tokens in the decoding.\n        Returns:\n            str: the decoded text\n        \"\"\"\n        t = t[offset:]\n        out_string = self.model.decode(t, skip_special_tokens=skip_special_tokens)\n        if offset:\n            logger = get_logger('lmdeploy')\n            logger.warning('For incremental detokenization, please use the '\n                           'detokenize_incrementally function instead.')\n            out_string = self._maybe_add_prefix_space(t, out_string)\n        return out_string\n\n    @staticmethod\n    def _convert_tokens_to_string_with_added_encoders(\n        tokenizer,\n        output_tokens: List[str],\n        skip_special_tokens: bool,\n        spaces_between_special_tokens: bool,\n    ) -> str:\n        if tokenizer.is_fast or not tokenizer.get_added_vocab():\n            return tokenizer.convert_tokens_to_string(output_tokens)\n        # Adapted from\n        # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99\n        sub_texts = []\n        current_sub_text = []\n        all_special_tokens = set(tokenizer.all_special_tokens)\n        for token in output_tokens:\n            if skip_special_tokens and token in all_special_tokens:\n                continue\n            if token in tokenizer.get_added_vocab():\n                if current_sub_text:\n                    sub_text = tokenizer.convert_tokens_to_string(current_sub_text)\n                    sub_texts.append(sub_text)\n                    current_sub_text = []\n                sub_texts.append(token)\n            else:\n                current_sub_text.append(token)\n        if current_sub_text:\n            sub_text = tokenizer.convert_tokens_to_string(current_sub_text)\n            sub_texts.append(sub_text)\n        if spaces_between_special_tokens:\n        
    return ' '.join(sub_texts)\n        else:\n            return ''.join(sub_texts)\n\n    # Based on\n    # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165\n    def detokenize_incrementally(self,\n                                 all_input_ids: Sequence[int],\n                                 state: DetokenizeState,\n                                 skip_special_tokens: bool = True,\n                                 spaces_between_special_tokens: bool = True):\n        \"\"\"Incrementally detokenize the input indexes.\n\n        Args:\n            all_input_ids (List[int]): a list of token ids. Expected to be\n                different sections of a long sequence.\n            state (DetokenizeState): an instance of DetokenizeState. Consists\n                of incrementally decoding states.\n            skip_special_tokens (bool): Whether or not to remove special tokens\n                in the decoding. Default to be True.\n            spaces_between_special_tokens (bool): Whether or not to add spaces\n                between special tokens. Default to be True.\n        Returns:\n            str: decoding output string of the current round.\n            state (DetokenizeState): an instance of DetokenizeState. 
Holds\n                the incremental decoding states.\n        \"\"\"\n        tokenizer = self.model\n        ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()\n        # Convert the ids generated since the last round into tokens\n        new_tokens = tokenizer.convert_ids_to_tokens(all_input_ids[ids_offset:],\n                                                     skip_special_tokens=skip_special_tokens)\n        # `convert_ids_to_tokens` returns None for out-of-range token_id\n        new_tokens = new_tokens or []\n        new_tokens = [x for x in new_tokens if x is not None] if None in new_tokens else new_tokens\n        if prev_tokens is None:\n            # First iteration for this sequence. Note that vLLM detokenizes ids\n            # one by one, while in LMDeploy the number of ids detokenized per\n            # round can vary.\n            prev_tokens = tokenizer.convert_ids_to_tokens(all_input_ids[:ids_offset],\n                                                          skip_special_tokens=skip_special_tokens)\n            # `convert_ids_to_tokens` returns None for out-of-range token_id\n            prev_tokens = prev_tokens or []\n            prev_tokens = [x for x in prev_tokens if x is not None] if None in prev_tokens else prev_tokens\n            read_offset = len(prev_tokens)\n            if skip_special_tokens and new_tokens and new_tokens[0] in tokenizer.all_special_ids:\n                read_offset = read_offset + 1  # skip special token\n\n        output_tokens = prev_tokens + new_tokens\n        prev_tokens += new_tokens\n        prefix_text = self._convert_tokens_to_string_with_added_encoders(\n            tokenizer,\n            output_tokens[prefix_offset:read_offset],\n            skip_special_tokens=skip_special_tokens,\n            spaces_between_special_tokens=spaces_between_special_tokens,\n        )\n        new_text = self._convert_tokens_to_string_with_added_encoders(\n            tokenizer,\n            
output_tokens[prefix_offset:],\n            skip_special_tokens=skip_special_tokens,\n            spaces_between_special_tokens=spaces_between_special_tokens,\n        )\n\n        # update state and get final decoded output\n        if len(new_text) > len(prefix_text) and not new_text.endswith('�'):\n            # A trailing '�' (U+FFFD) means it's a potential unfinished byte\n            # sequence from byte-fallback tokenization.\n            # If it's in the middle, it's probably a real invalid id generated\n            # by the model\n            prefix_offset = read_offset\n            read_offset = len(output_tokens)\n            new_text = new_text[len(prefix_text):]\n        else:\n            new_text = ''\n\n        return new_text, DetokenizeState(len(all_input_ids), prev_tokens, prefix_offset, read_offset)\n\n    def __call__(self, s: Union[str, Sequence[str]]):\n        \"\"\"Tokenize prompts.\n\n        Args:\n            s (str): prompts\n        Returns:\n            list[int]: token ids\n        \"\"\"\n        add_special_tokens = False\n        return self.model(s, add_special_tokens=add_special_tokens)\n\n\nclass ChatGLM4Tokenizer(HuggingFaceTokenizer):\n    \"\"\"Tokenizer of GLM4.\"\"\"\n\n    def __init__(self, model_path):\n        super(ChatGLM4Tokenizer, self).__init__(model_path)\n        original_pad = self.model._pad\n\n        def __pad(*args, **kwargs):\n            if 'padding_side' in kwargs:\n                kwargs.pop('padding_side')\n            return original_pad(*args, **kwargs)\n\n        # fix for transformers>4.45.0\n        self.model._pad = __pad\n\n    def encode(self, s: str, add_bos: bool = True, add_special_tokens: bool = True, **kwargs):\n        \"\"\"Tokenize a prompt.\"\"\"\n        # ChatGLM4Tokenizer hardcodes `add_special_tokens=False` when tokenizing\n        # a prompt. 
Refer to https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/tokenization_chatglm.py#L227 # noqa E501\n        return super(ChatGLM4Tokenizer, self).encode(s, add_bos, add_special_tokens=False, **kwargs)\n\n\nclass ChatGLMTokenizer(HuggingFaceTokenizer):\n    \"\"\"Tokenizer of GLM2.\"\"\"\n\n    def __init__(self, model_path):\n        super(ChatGLMTokenizer, self).__init__(model_path)\n        original_pad = self.model._pad\n\n        def __pad(*args, **kwargs):\n            if 'padding_side' in kwargs:\n                kwargs.pop('padding_side')\n            return original_pad(*args, **kwargs)\n\n        # fix for transformers>4.45.0\n        self.model._pad = __pad\n\n\nclass GptOssTokenizer(HuggingFaceTokenizer):\n    \"\"\"Tokenizer of GPT-OSS.\"\"\"\n\n    def __init__(self, model_dir: str):\n        super(GptOssTokenizer, self).__init__(model_dir)\n        from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding\n        encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n        self.role = Role.ASSISTANT\n        self.parser = partial(StreamableParser, encoding, role=Role.ASSISTANT)\n\n    def detokenize_incrementally(self,\n                                 all_input_ids: Sequence[int],\n                                 state: DetokenizeState,\n                                 skip_special_tokens: bool = True,\n                                 spaces_between_special_tokens: bool = True):\n        if not hasattr(state, 'stream'):\n            state.stream = self.parser()\n\n        response = ''\n        stream = state.stream\n        for token_id in all_input_ids[state.ids_offset:]:\n            stream.process(token_id)\n            if stream.current_channel in ['final', 'analysis'] and stream.current_role == self.role:\n                response += stream.last_content_delta or ''\n\n        state.ids_offset = len(all_input_ids)\n        return response, state\n\n\nclass Tokenizer:\n    
\"\"\"Tokenize prompts or de-tokenize tokens into texts.\n\n    Args:\n        model_path (str): the path of the tokenizer model\n    \"\"\"\n\n    def __init__(self, model_path: str):\n        from transformers import AutoConfig, PretrainedConfig\n        try:\n            model_cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n        except Exception as e:  # noqa\n            model_cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)\n        is_gpt_oss = getattr(model_cfg, 'model_type', '') == 'gpt_oss'\n        from transformers.models.auto.tokenization_auto import get_tokenizer_config\n        tokenizer_config = get_tokenizer_config(model_path, trust_remote_code=True)\n        config_tokenizer_class = tokenizer_config.get('tokenizer_class')\n        if config_tokenizer_class == 'ChatGLM4Tokenizer':\n            self.model = ChatGLM4Tokenizer(model_path)\n        elif config_tokenizer_class == 'ChatGLMTokenizer':\n            self.model = ChatGLMTokenizer(model_path)\n        elif is_gpt_oss:\n            self.model = GptOssTokenizer(model_path)\n        else:\n            self.model = HuggingFaceTokenizer(model_path)\n        self.logger = get_logger('lmdeploy')\n\n    @property\n    def vocab_size(self):\n        \"\"\"Vocabulary size.\"\"\"\n        return self.model.vocab_size\n\n    @property\n    def bos_token_id(self):\n        \"\"\"Begin of the sentence token id.\"\"\"\n        return self.model.bos_token_id\n\n    @property\n    def eos_token_id(self):\n        \"\"\"End of the sentence token id.\"\"\"\n        return self.model.eos_token_id\n\n    def get_vocab(self):\n        \"\"\"Get vocab.\"\"\"\n        return self.model.get_vocab()\n\n    def encode(self, s: str, add_bos: bool = True, add_special_tokens: bool = True, **kwargs):\n        \"\"\"Tokenize a prompt.\n\n        Args:\n            s (str): a prompt\n            add_bos (bool): Whether to add `bos` token id when encoding\n                the 
prompt\n            add_special_tokens (bool): Whether or not to add special tokens\n                when encoding the prompt\n        Returns:\n            list[int]: token ids\n        \"\"\"\n        encoded = self.model.encode(s, add_bos, add_special_tokens, **kwargs)\n        if encoded[:2] == [self.bos_token_id] * 2:\n            self.logger.warning(f'Detected duplicate bos token {self.bos_token_id} in prompt, '\n                                'which will likely reduce response quality; one of them will be '\n                                'removed')\n            encoded = encoded[1:]\n        return encoded\n\n    def decode(\n        self,\n        t: Sequence[int],\n        offset: Optional[int] = None,\n        skip_special_tokens: bool = True,\n    ):\n        \"\"\"De-tokenize.\n\n        Args:\n            t (List[int]): a list of token ids\n            offset (int): for incremental decoding. Default to None, which\n                means no offset is applied.\n            skip_special_tokens (bool): Whether or not to remove special\n                tokens in the decoding.\n        Returns:\n            str: the decoded text\n        \"\"\"\n        return self.model.decode(t, offset, skip_special_tokens)\n\n    def detokenize_incrementally(self,\n                                 all_input_ids: Sequence[int],\n                                 state: DetokenizeState,\n                                 skip_special_tokens: bool = True,\n                                 spaces_between_special_tokens: bool = True):\n        \"\"\"Incrementally detokenize the input indexes.\n\n        Args:\n            all_input_ids (List[int]): a list of token ids. Expected to be\n                different sections of a long sequence.\n            state (DetokenizeState): an instance of DetokenizeState. 
Holds\n                the incremental decoding states.\n            skip_special_tokens (bool): Whether or not to remove special tokens\n                in the decoding. Defaults to True.\n            spaces_between_special_tokens (bool): Whether or not to add spaces\n                between special tokens. Defaults to True.\n        Returns:\n            str: decoding output string of the current round.\n            state (DetokenizeState): an instance of DetokenizeState. Holds\n                the incremental decoding states.\n        \"\"\"\n        return self.model.detokenize_incrementally(all_input_ids,\n                                                   state=state,\n                                                   skip_special_tokens=skip_special_tokens,\n                                                   spaces_between_special_tokens=spaces_between_special_tokens)\n\n    def __call__(self, s: Union[str, Sequence[str]]):\n        \"\"\"Tokenize prompts.\n\n        Args:\n            s (str): prompts\n        Returns:\n            list[int]: token ids\n        \"\"\"\n        return self.model(s)\n\n    def indexes_containing_token(self, token):\n        \"\"\"Return all the possible indexes whose decoded output may contain\n        the input token.\"\"\"\n        encoded = self.encode(token, add_bos=False)\n        if len(encoded) > 1:\n            self.logger.warning(f'The token {token} is encoded as {encoded}, which has more '\n                                'than 1 id. Currently, it cannot be used as a stop word')\n            return []\n        return self.model.indexes_containing_token(token)\n"
  },
  {
    "path": "lmdeploy/turbomind/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\n\ndef bootstrap():\n    import os\n    import sys\n\n    has_turbomind = False\n    pwd = os.path.dirname(__file__)\n    if os.path.exists(os.path.join(pwd, '..', 'lib')):\n        has_turbomind = True\n    if os.name == 'nt' and has_turbomind:\n        if sys.version_info[:2] >= (3, 8):\n            CUDA_PATH = os.getenv('CUDA_PATH')\n            assert CUDA_PATH is not None, 'Cannot find $env:CUDA_PATH'\n            dll_path = os.path.join(CUDA_PATH, 'bin')\n            print(f'Adding DLL path {dll_path}; please note that the CUDA version '\n                  'should be >= 11.3 when compiled with CUDA 11')\n            os.add_dll_directory(dll_path)\n\n\nbootstrap()\n\nfrom .turbomind import TurboMind, update_parallel_config  # noqa: E402\n\n__all__ = ['TurboMind', 'update_parallel_config']\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/config.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport inspect\nimport json\nfrom dataclasses import asdict, field, fields\nfrom typing import List\n\n# use pydantic.dataclasses.dataclass to check data type\nfrom pydantic.dataclasses import dataclass\n\nfrom lmdeploy.messages import TurbomindEngineConfig\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\ndef config_from_dict(cls, env):\n    \"\"\"Instantiate a config class from a dict.\"\"\"\n    params = inspect.signature(cls).parameters\n    used = {k: v for k, v in env.items() if k in params and v is not None}\n\n    def _remove_none(d: dict):\n        for k, v in d.items():\n            if isinstance(v, dict):\n                d[k] = _remove_none(v)\n        return {k: v for k, v in d.items() if v is not None}\n\n    used = _remove_none(used)\n    return cls(**used)\n\n\ndef config_to_dict(config):\n    \"\"\"Export config to a dict.\"\"\"\n    if not config:\n        return dict()\n    assert isinstance(config, (ModelConfig, AttentionConfig, LoraConfig)), \\\n        f'A dataclass is expected, but got {type(config)}'\n\n    return asdict(config)\n\n\n@dataclass\nclass ModelConfig:\n    model_name: str = ''\n    chat_template: str = ''\n    model_arch: str = None\n    head_num: int = None\n    kv_head_num: int = None\n    hidden_units: int = None\n    vocab_size: int = None\n    # Turbomind used to assume token_embedding and lm_head have the same size\n    # at vocab dim, i.e. 
`vocab_size`\n    # But in molmo, embedding.shape is [vocab_size + 128, hidden_units]\n    # while lm_head shape is [hidden_units, vocab_size].\n    # Therefore, we add a new attr \"embedding_size\" to represent the vocab dim\n    # of token_embedding\n    embedding_size: int = 0\n    num_layer: int = None\n    inter_size: List[int] = None\n    norm_eps: float = None\n    attn_bias: int = 0\n    mlp_bias: bool = False\n    window_size: List[int] = field(default_factory=list)\n    attn_sink: bool = False\n    qk_norm: bool = False\n    size_per_head: int = 128\n    group_size: int = 32\n    data_type: str = None\n    weight_type: str = None\n    expert_weight_type: str = None\n    ffn_weight_type: str = None\n    session_len: int = None\n    attn_tp_size: int = 1\n    attn_cp_size: int = 1\n    mlp_tp_size: int = 1\n    model_format: str = 'hf'\n    expert_num: List[int] = ()\n    expert_router_bias: bool = False\n    expert_inter_size: int = 0\n    experts_per_token: int = 0\n    activation_type: str = ''\n    moe_shared_gate: bool = False\n    norm_topk_prob: bool = False\n    routed_scale: float = 1.0\n    topk_group: int = 1\n    topk_method: str = 'greedy'\n    moe_group_num: int = 1\n    scoring_func: str = 'softmax'\n    router_n_groups: int = -1\n    # MLA\n    q_lora_rank: int = 0\n    kv_lora_rank: int = 0\n    qk_rope_dim: int = 0\n    v_head_dim: int = 0\n    # Qwen 3.5\n    layer_types: List[str] = field(default_factory=list)\n    linear_key_head_dim: int = 0\n    linear_value_head_dim: int = 0\n    linear_conv_kernel_dim: int = 0\n    linear_num_key_heads: int = 0\n    linear_num_value_heads: int = 0\n    attn_output_gate: bool = False\n    # Per-layer expert weight type override: layer indices whose\n    # MoE experts are unquantized (fp16) despite expert_weight_type=int4.\n    # Populated from modules_to_not_convert patterns like 'model.layers.0.'.\n    unquantized_expert_layers: List[int] = field(default_factory=list)\n    # tuning\n    
tune_layer_num: int = 1\n\n    def verify(self):\n        invalid = {}\n        for k, v in self.__dict__.items():\n            if v is None:\n                invalid[k] = v\n        assert not invalid, f'incomplete model config: {invalid}'\n\n\n@dataclass\nclass RopeParam:\n    type: str\n    base: float\n    dim: int\n    factor: float = 1.0\n    max_position_embeddings: int = None\n    attention_factor: float = 1.0\n    beta_fast: float = 32\n    beta_slow: float = 1\n    low_freq_factor: float = None\n    high_freq_factor: float = None\n    original_max_position_embeddings: int = None\n    mrope_section: List[int] = None\n\n\n@dataclass\nclass AttentionConfig:\n    softmax_scale: float = 0\n    cache_block_seq_len: int = 64\n    use_logn_attn: int = 0\n    max_position_embeddings: int = 0\n    rope_param: RopeParam = None\n\n\n@dataclass\nclass LoraConfig:\n    lora_policy: str = ''\n    lora_r: int = 0\n    lora_scale: float = 0.0\n    lora_max_wo_r: int = 0\n    lora_rank_pattern: str = ''\n    lora_scale_pattern: str = ''\n\n\n@dataclass\nclass TurbomindModelConfig:\n    \"\"\"Config for turbomind model.\"\"\"\n    model_config: ModelConfig = None\n    attention_config: AttentionConfig = None\n    lora_config: LoraConfig = None\n\n    def update_from_engine_config(self, config: TurbomindEngineConfig):\n        \"\"\"Update the attributes of this instance with the attributes from\n        TurbomindEngineConfig.\n\n        Args:\n            config (TurbomindEngineConfig): The turbomind engine config\n        \"\"\"\n        if config is None:\n            return\n        for key, value in asdict(config).items():\n            if not value:\n                continue\n\n            if hasattr(self.model_config, key):\n                setattr(self.model_config, key, value)\n            if hasattr(self.attention_config, key):\n                setattr(self.attention_config, key, value)\n\n        # update from hf_overrides\n        if hasattr(config, 
'hf_overrides') and config.hf_overrides:\n            hf_overrides = config.hf_overrides\n\n            if hf_overrides.get('rope_scaling'):\n                override_params = hf_overrides.get('rope_scaling')\n\n                rope_param = self.attention_config.rope_param or RopeParam(type='', base=0, dim=0)\n                rope_param.type = override_params.get('rope_type', '')\n                if rope_param.type == 'yarn' and 'original_max_position_embeddings' in override_params:\n                    rope_param.factor = self.attention_config.max_position_embeddings / override_params[\n                        'original_max_position_embeddings']\n                    rope_param.max_position_embeddings = override_params['original_max_position_embeddings']\n                else:\n                    rope_param.factor = override_params.get('factor', 1.0)\n                    rope_param.max_position_embeddings = override_params.get('original_max_position_embeddings', None)\n\n                self.attention_config.rope_param = rope_param\n            logger.warning(f'Overriding HF config with {hf_overrides}')\n\n        # use dynamic ntk\n        if config.rope_scaling_factor:\n            # some ut will create empty RopeParam, will check base/dim in src code\n            rope_param = self.attention_config.rope_param or RopeParam(type='', base=0, dim=0)\n            rope_param.type = 'dynamic'\n            rope_param.factor = config.rope_scaling_factor\n            rope_param.max_position_embeddings = self.attention_config.max_position_embeddings\n\n            self.attention_config.rope_param = rope_param\n            logger.warning(\n                '`--rope-scaling-factor` will be removed in a future release. 
Please instead use `--hf-overrides`.')\n\n    @classmethod\n    def from_dict(cls, config: dict = {}):\n        \"\"\"Construct TurbomindModelConfig instance from config in a dict.\"\"\"\n        _cfg = {field.name: config.get(field.name, {}) for field in fields(TurbomindModelConfig)}\n\n        return TurbomindModelConfig(model_config=config_from_dict(ModelConfig, _cfg['model_config']),\n                                    attention_config=config_from_dict(AttentionConfig, _cfg['attention_config']),\n                                    lora_config=config_from_dict(LoraConfig, _cfg['lora_config']))\n\n    def to_dict(self):\n        \"\"\"Export to a dict.\"\"\"\n        return dict(model_config=config_to_dict(self.model_config),\n                    attention_config=config_to_dict(self.attention_config),\n                    lora_config=config_to_dict(self.lora_config))\n\n    @property\n    def session_len(self):\n        return self.model_config.session_len\n\n    @property\n    def weight_type(self):\n        return self.model_config.weight_type\n\n    @property\n    def group_size(self):\n        return self.model_config.group_size\n\n    @property\n    def vocab_size(self):\n        return self.model_config.vocab_size\n\n    def __str__(self):\n        return json.dumps(self.to_dict(), indent=2)\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/converter.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom lmdeploy.archs import get_model_arch, search_nested_config\nfrom lmdeploy.messages import TurbomindEngineConfig\nfrom lmdeploy.utils import get_logger\n\nfrom ...utils import _get_and_verify_max_len, is_bf16_supported\nfrom ..supported_models import SUPPORTED_ARCHS\nfrom .config import TurbomindModelConfig\nfrom .module import Transformer\nfrom .policy import get_input_policy\nfrom .source_model.base import INPUT_MODELS\nfrom .target_model.base import OUTPUT_MODELS, BaseOutputModel\n\nSUPPORTED_FORMATS = ['hf', 'awq', 'gptq', 'fp8', None]\nlogger = get_logger('lmdeploy')\n\n\ndef get_input_model_registered_name(model_path: str, model_format: str):\n    \"\"\"Get the registered name of a model. The name will be used to access the\n    INPUT_MODELS registry.\n\n    Args:\n        model_path (str): the path of the input model\n        model_format (str): the format of the model, which can be one of\n            ['hf', 'awq', 'gptq']\n    \"\"\"\n    arch = get_model_arch(model_path)[0]\n    register_name = SUPPORTED_ARCHS[arch]\n    return register_name\n\n\ndef get_output_model_registered_name_and_config(model_path: str, model_format: str, dtype: str, group_size: int):\n    \"\"\"Get the registered name of the turbomind model and its configuration\n    according to the input model path, format and user-input config. 
The name\n    will be used to access the OUTPUT_MODELS registry.\n\n    Args:\n        model_path (str): the path of the input model\n        model_format (str): the format of the model, which can be one of\n            ['hf', 'awq', 'gptq']\n        dtype (str): the data type of the model's weights and activations\n        group_size (int): the size of group used by awq model\n    \"\"\"\n    register_name = 'tm'\n\n    has_bf16 = is_bf16_supported()\n\n    model_arch, model_config = get_model_arch(model_path)\n\n    # infer dtype from device and model config\n    if dtype == 'auto':\n        # pick dtype by device as default\n        dtype = 'bfloat16' if has_bf16 else 'float16'\n        # dtype from model (prefer `dtype` over deprecated `torch_dtype`)\n        torch_dtype = getattr(model_config, 'dtype', None)\n        if torch_dtype is None:\n            torch_dtype = getattr(model_config, 'torch_dtype', None)\n        if not torch_dtype:\n            if model_arch in ['QWenLMHeadModel', 'GptOssForCausalLM']:\n                torch_dtype = torch.bfloat16\n        TORCH_DTYPE_MAP = {torch.bfloat16: 'bfloat16', torch.float16: 'float16'}\n        dtype = TORCH_DTYPE_MAP.get(torch_dtype, dtype)\n\n    if dtype == 'bfloat16' and not has_bf16:\n        logger.warning('data type fallback to float16 since '\n                       'torch.cuda.is_bf16_supported is False')\n        dtype = 'float16'\n\n    weight_type = dtype\n\n    config = TurbomindModelConfig.from_dict()\n\n    session_len = _get_and_verify_max_len(model_config, None)\n\n    if model_format in ['awq', 'gptq', 'compressed-tensors']:\n        weight_type = 'int4'\n        dtype = 'float16'  # force float16 for int4 quantized weights\n        group_size = 128 if group_size == 0 else group_size\n        if model_format == 'compressed-tensors':\n            model_format = 'awq'\n    elif model_format == 'fp8':\n        weight_type = 'fp8'\n        group_size = 128\n    elif model_format == 'mxfp4':\n       
 weight_type = 'e2m1'\n        group_size = 32\n\n    expert_weight_type = weight_type\n\n    # ONLY experts are in mxfp4\n    if model_arch == 'GptOssForCausalLM':\n        weight_type = dtype\n\n    # Three weight types control allocation for mixed quantization:\n    #   weight_type        - attention weights\n    #   ffn_weight_type    - dense FFN / shared expert weights\n    #   expert_weight_type - MoE routed expert weights\n    #\n    # The assignment order matters:\n    #   1. expert_weight_type = original weight_type (before any overrides)\n    #   2. GptOss override:   weight_type -> dtype  (attn + shared experts are fp16)\n    #   3. ffn_weight_type  = weight_type           (captures post-GptOss value)\n    #   4. Mixed AWQ override: weight_type -> dtype  (only attn becomes fp16)\n    #\n    #                  weight_type   ffn_weight_type   expert_weight_type\n    #  Pure fp16       float16       float16           float16\n    #  Full AWQ        int4          int4              int4\n    #  Mixed AWQ       float16       int4              int4\n    #  GptOss mxfp4    bfloat16      bfloat16          e2m1\n    ffn_weight_type = weight_type\n\n    # When attention weights are not quantized (e.g. 
AWQ with self_attn in\n    # modules_to_not_convert), weight_type becomes fp16 for attention.\n    # ffn_weight_type and expert_weight_type retain int4.\n    if model_format in ['awq', 'gptq'] and weight_type != dtype:\n        quant_config = getattr(model_config, 'quantization_config', None)\n        if quant_config is None:\n            quant_config = {}\n        if isinstance(quant_config, dict):\n            modules_to_not_convert = quant_config.get('modules_to_not_convert') or []\n        else:\n            modules_to_not_convert = getattr(quant_config, 'modules_to_not_convert', None) or []\n        if any('self_attn' in m for m in modules_to_not_convert):\n            weight_type = dtype\n        if any('shared_expert' in m for m in modules_to_not_convert):\n            ffn_weight_type = dtype\n        # Detect per-layer exclusions like 'model.layers.0.' which mean\n        # ALL weights in that layer (including MoE experts) are fp16.\n        import re as _re\n        unquantized_expert_layers = []\n        for m in modules_to_not_convert:\n            _m = _re.match(r'model\\.layers\\.(\\d+)\\.?$', m)\n            if _m:\n                unquantized_expert_layers.append(int(_m.group(1)))\n        config.model_config.unquantized_expert_layers = unquantized_expert_layers\n\n    config.model_config.model_arch = model_arch\n    config.model_config.data_type = dtype\n    config.model_config.weight_type = weight_type\n    config.model_config.expert_weight_type = expert_weight_type\n    config.model_config.ffn_weight_type = ffn_weight_type\n    config.model_config.model_format = model_format\n    config.model_config.group_size = group_size\n    config.model_config.session_len = session_len\n\n    return register_name, config\n\n\ndef get_tm_model(model_path,\n                 model_name,\n                 chat_template_name,\n                 engine_config: TurbomindEngineConfig,\n                 group_size: int = None,\n                 out_dir: str = None) -> 
BaseOutputModel:\n    \"\"\"Create turbomind model.\n\n    Args:\n        model_path (str): the path of the input model, which is supposed\n            to be a local path, or huggingface hub repo_id, or modelscope\n            hub repo_id\n        model_name (str): user customized model name\n        chat_template_name (str): the name of the chat template of\n            the input model\n        engine_config(TurbomindEngineConfig): user input engine config\n        group_size(int): refers to the group_size if the input model\n            is a w4a16(awq or gptq) quantized model\n        out_dir(str): the output directory where to save to turbomind model.\n            If it is None, the turbomind model won't be saved\n    \"\"\"\n    _, cfg = get_model_arch(model_path)\n    quant_config = search_nested_config(cfg.to_dict(), 'quantization_config')\n    mixed_awq = False\n    if quant_config:\n        quant_method = quant_config.get('quant_method')\n        _group_size = int(quant_config.get('group_size', 0))\n        version = quant_config.get('version')\n        assert engine_config.model_format is None or engine_config.model_format == quant_method, (\n            f'mismatched quant method: user input \"{engine_config.model_format}\" '\n            f'vs model quant_config \"{quant_method}\"')\n        assert not group_size or group_size == _group_size, (f'mismatched quant group size: user input \"{group_size}\" '\n                                                             f'vs model quant_config \"{_group_size}\"')\n\n        if quant_method == 'awq':\n            assert version == 'gemm', f'unsupported quant config: {quant_config}'\n            modules_to_not_convert = quant_config.get('modules_to_not_convert') or []\n            if any('self_attn' in name for name in modules_to_not_convert):\n                mixed_awq = True\n        elif quant_method == 'gptq':\n            assert not quant_config.get('desc_act', False) and quant_config.get(\n                
'sym', True), f'unsupported quant config: {quant_config}'\n        elif quant_method == 'fp8':\n            pass\n        elif quant_method == 'mxfp4':\n            _group_size = 32\n        elif quant_method == 'compressed-tensors':\n            _format = quant_config['config_groups']['group_0']['format']\n            assert _format == 'pack-quantized', ('compressed-tensors only supports pack-quantized format, '\n                                                 f'but got {_format}')\n            _weights = quant_config['config_groups']['group_0']['weights']\n            _group_size = _weights['group_size']\n            _num_bits = _weights['num_bits']\n            _type = _weights['type']\n            assert _num_bits == 4 and _type == 'int', ('pack-quantized requires 4-bit int, '\n                                                       f'but got {_num_bits}-bit {_type}')\n        else:\n            assert 0, f'unsupported quant_config: {quant_config}'\n\n        engine_config.model_format = quant_method\n        group_size = _group_size\n\n    if engine_config.model_format in ['awq', 'gptq', 'compressed-tensors']:\n        # Compatible with awq models that are quantized by lmdeploy (<=v0.3.0)\n        if not group_size:\n            group_size = 128\n        assert group_size == 128, (f'model format is \"{engine_config.model_format}\" '\n                                   f'but group_size is {group_size}. 
Currently, only 128 '\n                                   'is supported')\n\n    input_model_name = get_input_model_registered_name(model_path, engine_config.model_format)\n\n    fp8_quant = (engine_config.model_format == 'fp8' and not quant_config)\n    input_policy = get_input_policy(engine_config.model_format)\n    input_model = INPUT_MODELS.get(input_model_name)(model_path=model_path,\n                                                     tokenizer_path=model_path,\n                                                     input_policy=input_policy,\n                                                     fp8_quant=fp8_quant)\n\n    output_model_name, tm_cfg = get_output_model_registered_name_and_config(model_path=model_path,\n                                                                            model_format=engine_config.model_format,\n                                                                            dtype=engine_config.dtype,\n                                                                            group_size=group_size)\n\n    if mixed_awq:\n        # Mixed-precision AWQ: attention weights are fp16 (not quantized),\n        # but expert weights remain as int4 AWQ for efficient inference.\n        tm_cfg.model_config.weight_type = tm_cfg.model_config.data_type\n        # expert_weight_type stays as 'int4' (set by get_output_model_registered_name_and_config)\n\n    tm_cfg.model_config.chat_template = chat_template_name\n    tm_cfg.model_config.model_name = model_name\n\n    if engine_config.attn_tp_size is not None:\n        tm_cfg.model_config.attn_tp_size = engine_config.attn_tp_size\n    if engine_config.attn_cp_size is not None:\n        tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size\n    if engine_config.mlp_tp_size is not None:\n        tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size\n\n    output_model = OUTPUT_MODELS.get(output_model_name)(input_model=input_model,\n                                                    
    cfg=tm_cfg,\n                                                        model_cls=Transformer,\n                                                        out_dir=out_dir)\n\n    return output_model\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/loader.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport os.path as osp\nimport re\nfrom abc import ABC, abstractmethod\nfrom collections import defaultdict\nfrom functools import partial\nfrom glob import glob\nfrom queue import Queue\nfrom typing import Iterator, Tuple, Union\n\nimport torch\nfrom safetensors import safe_open\n\n# https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/modeling_utils.py#L372\nWEIGHT_INDEX_NAME = 'pytorch_model.bin.index.json'\nWEIGHT_PATTERN = 'pytorch_model*.bin'\nSAFE_WEIGHT_INDEX_NAME = 'model.safetensors.index.json'\nSAFE_WEIGHT_PATTERN = 'model*.safetensors'\nEXTRA_WEIGHT_PATTERNS = ['*.pt', '*.bin']\nEXTRA_SAFE_WEIGHT_PATTERN = '*.safetensors'\n\n\nclass BaseLoader(ABC):\n\n    def __init__(self, model_path: str, pattern, mappings: list):\n        self.model_path = model_path\n        self.pattern = pattern\n        self.item_count = defaultdict(int)\n        self.mappings = mappings\n\n    def get_index(self, index_name: str, file_pattern: str) -> Tuple[dict, list]:\n        \"\"\"Get shards and weight map (if possible) for the model.\"\"\"\n        get_path = partial(osp.join, self.model_path)\n        shards = []\n        if index_name:\n            with open(get_path(index_name), 'r') as f:\n                index = json.load(f)\n            index = index['weight_map']\n            shards = list(map(get_path, set(index.values())))\n        else:\n            index = {}\n            shards = glob(get_path(file_pattern))\n        if not shards:\n            raise RuntimeError(f'failed to locate weight files for {self.model_path}')\n        return sorted(shards), index\n\n    def map_key(self, key: str):\n        if self.mappings:\n            k = str(key)\n            for f in self.mappings:\n                k = f(k)\n            return k\n        else:\n            return key\n\n    @abstractmethod\n    def items(self) -> Iterator[Tuple[int, 
dict]]:\n        pass\n\n\nclass SafetensorsLoader(BaseLoader):\n\n    def __init__(self, model_path: str, pattern: str, mappings: list, index_name=None, file_pattern=None):\n        super().__init__(model_path, pattern, mappings)\n        self.shards, index = self.get_index(index_name, file_pattern)\n        if not index:\n            # there is no model.safetensors.index.json in the model_path,\n            # read tensor form the safetensor file and update the index\n            for shard in self.shards:\n                filename = osp.basename(shard)\n                with safe_open(shard, 'pt') as f:\n                    index.update({k: filename for k in f.keys()})\n        # self.index maps weight names to their corresponding safetensors file name\n        self.index = index\n        # count layer-wise parameters\n        for k in index.keys():\n            match = re.findall(self.pattern, k)\n            if match:\n                self.item_count[int(match[0])] += 1\n\n    def items(self):\n        params = defaultdict(dict)\n        for shard in self.shards:\n            with safe_open(shard, 'pt') as f:\n                misc = []\n                filename = osp.basename(shard)\n                for k in f.keys():\n                    # Filtering logic:\n                    # - Exclude weights not found in the mapping\n                    # - Exclude duplicated weights (present in multiple files)\n                    if k not in self.index or self.index[k] != filename:\n                        continue\n                    match = re.findall(self.pattern, k)\n                    if not match:\n                        misc.append(k)\n                    else:\n                        idx = int(match[0])\n                        param = params[idx]\n                        param[self.map_key(k)] = f.get_tensor(k)\n                        if len(param) == self.item_count[idx]:\n                            yield (idx, params.pop(idx))\n                if misc:\n  
                  yield (-1, {k: f.get_tensor(k) for k in misc})\n        assert not params\n\n\nclass PytorchLoader(BaseLoader):\n\n    def __init__(self, model_path: str, pattern: str, mappings: list, index_name=None, file_pattern=None):\n        super().__init__(model_path, pattern, mappings)\n        self.shards, index = self.get_index(index_name, file_pattern)\n        for k in index.keys():\n            match = re.findall(self.pattern, k)\n            if match:\n                self.item_count[int(match[0])] += 1\n\n    def items(self):\n        params = defaultdict(dict)\n        for shard in self.shards:\n            misc = {}\n            tmp = torch.load(shard, map_location='cpu', weights_only=True)\n            for k, v in tmp.items():\n                match = re.findall(self.pattern, k)\n                if not match:\n                    misc[k] = v\n                else:\n                    idx = int(match[0])\n                    params[idx][k] = v\n            del tmp\n            if misc:\n                yield (-1, misc)\n                misc.clear()\n            ready = []\n            if self.item_count:\n                for idx, param in params.items():\n                    if len(param) == self.item_count[idx]:\n                        ready.append(idx)\n            else:\n                ready = sorted(params.keys())[:-1]\n            for idx in ready:\n                yield (idx, params.pop(idx))\n        idxs = sorted(params.keys())\n        for idx in idxs:\n            yield (idx, params.pop(idx))\n\n\nclass StateDictLoader:\n    \"\"\"This loader is used for `update_params`.\n\n    Currently, the item in the queue should be full state dict of a decoder layer or the meta of the model (embedding,\n    lm_head, norm).\n    \"\"\"\n\n    def __init__(self, queue: Queue, pattern: str, mappings: list):\n        self.que = queue\n        self.pattern = pattern\n\n    def items(self):\n        for data in iter(self.que.get, None):\n            # 
If data is state dict of a decoder layer, any key will match the pattern.\n            # Otherwise, none of the keys will match the pattern.\n            for k in data.keys():\n                match = re.findall(self.pattern, k)\n                break\n\n            if not match:\n                yield (-1, data)\n            else:\n                idx = int(match[0])\n                yield (idx, data)\n\n            torch.cuda.empty_cache()\n            self.que.task_done()\n\n\ndef create_loader(model_path: Union[str, Queue], pattern: str, mappings: list) -> BaseLoader:\n    args = (model_path, pattern, mappings)\n\n    if isinstance(model_path, Queue):\n        # used for `update_params`\n        return StateDictLoader(*args)\n\n    if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):\n        return SafetensorsLoader(*args, index_name=SAFE_WEIGHT_INDEX_NAME)\n\n    if glob(osp.join(model_path, SAFE_WEIGHT_PATTERN)):\n        return SafetensorsLoader(*args, file_pattern=SAFE_WEIGHT_PATTERN)\n\n    if osp.exists(osp.join(model_path, WEIGHT_INDEX_NAME)):\n        return PytorchLoader(*args, index_name=WEIGHT_INDEX_NAME)\n\n    if glob(osp.join(model_path, WEIGHT_PATTERN)):\n        return PytorchLoader(*args, file_pattern=WEIGHT_PATTERN)\n\n    # non-standard safetensors model (*.safetensors)\n    if glob(osp.join(model_path, EXTRA_SAFE_WEIGHT_PATTERN)):\n        return SafetensorsLoader(*args, file_pattern=EXTRA_SAFE_WEIGHT_PATTERN)\n\n    # non-standard pytorch model (*.bin, *.pt)\n    for p in EXTRA_WEIGHT_PATTERNS:\n        if glob(osp.join(model_path, p)):\n            return PytorchLoader(*args, file_pattern=p)\n\n    raise RuntimeError(f'Failed to find valid loader for {model_path}')\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/module.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom functools import partial\n\nimport torch\n\nfrom .parameter import get_params\nfrom .source_model.base import BaseReader\nfrom .target_model.base import BaseOutputModel\n\n\ndef permute_v2(x: torch.Tensor, size_per_head: int = 128):\n    \"\"\"\n        Contract: x.size(-1) is output dims\n    \"\"\"\n\n    assert x.size(-1) > 1\n\n    output_dims = x.size(-1)\n    head_num = output_dims // size_per_head\n\n    return x.view(-1, head_num, 2, size_per_head // 2).transpose(2, 3).reshape(x.shape)\n\n\ndef permute_v2_partial(x: torch.Tensor, size_per_head: int, rotary_dim: int):\n    \"\"\"Permute only the first rotary_dim elements of each head.\n\n    Used when partial_rotary_factor < 1.0: only the rotary portion needs interleaving for TurboMind's RoPE kernel\n    layout.\n    \"\"\"\n    assert x.size(-1) > 1\n    assert rotary_dim % 2 == 0, f'rotary_dim must be even, got {rotary_dim}'\n    assert rotary_dim <= size_per_head, f'rotary_dim ({rotary_dim}) must be <= size_per_head ({size_per_head})'\n    output_dims = x.size(-1)\n    assert output_dims % size_per_head == 0, (f'output_dims ({output_dims}) must be divisible by '\n                                              f'size_per_head ({size_per_head})')\n    head_num = output_dims // size_per_head\n    orig_shape = x.shape\n    if x.dim() == 1:\n        x = x.unsqueeze(0)\n    x = x.view(x.size(0), head_num, size_per_head)\n    rotary = x[:, :, :rotary_dim]\n    passthrough = x[:, :, rotary_dim:]\n    # Interleave rotary part: [2, rotary_dim//2] -> [rotary_dim//2, 2]\n    rotary = rotary.view(x.size(0), head_num, 2, rotary_dim // 2).transpose(2, 3).contiguous()\n    rotary = rotary.view(x.size(0), head_num, rotary_dim)\n    x = torch.cat([rotary, passthrough], dim=-1)\n    return x.reshape(orig_shape)\n\n\ndef merge_qkv_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int):\n    \"\"\"\n        
Contract: x.size(-1) is output dims\n    \"\"\"\n\n    def reshape(x):\n        return x.view(x.size(0), tp, -1) if q.dim() == 2 else x.view(tp, -1)\n\n    qkv = torch.cat(tuple(map(reshape, (q, k, v))), dim=-1)\n\n    qkv = qkv.view(-1, qkv.size(-1) * tp)\n    if q.dim() == 1:\n        qkv.squeeze_()\n\n    return qkv\n\n\ndef merge_qkvg_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Tensor, tp: int):\n    \"\"\"Merge Q, K, V, and Gate with gate appended after V.\n\n    Layout per tp-shard: [Q | K | V | Gate].\n    \"\"\"\n\n    def reshape(x):\n        return x.view(x.size(0), tp, -1) if q.dim() == 2 else x.view(tp, -1)\n\n    qkvg = torch.cat(tuple(map(reshape, (q, k, v, gate))), dim=-1)\n\n    qkvg = qkvg.view(-1, qkvg.size(-1) * tp)\n    if q.dim() == 1:\n        qkvg.squeeze_()\n\n    return qkvg\n\n\ndef transpose(x):\n    return x.t() if x is not None else x\n\n\ndef pad_out_dims(x: torch.Tensor, dims: int):\n    pad = dims - x.size(-1)\n    assert pad >= 0\n    return torch.nn.functional.pad(x, (0, pad), 'constant', 0)\n\n\ndef pad_in_dims(x: torch.Tensor, dims: int):\n    if x.dim() == 1:  # 1-dim object does not have input dim (e.g. 
bias)\n        return x\n    pad = dims - x.size(0)\n    assert x.dim() == 2\n    assert pad >= 0\n    return torch.nn.functional.pad(x, (0, 0, 0, pad), 'constant', 0)\n\n\n# split out dims -> copy A, split-out-dims B (qkv, w1, w3)\n# split  in dims -> split-in-dims A,  copy B (  o, w2)\ndef get_lora_flags(kind: str):\n    return ('lora_a' in kind, 'lora_b' in kind)\n\n\nclass Module(ABC):\n\n    def __init__(self, model: BaseOutputModel):\n        self.model = model\n\n    def __call__(self, *args, **kwargs):\n        return self.apply(*args, **kwargs)\n\n    @abstractmethod\n    def apply(self, idx: int, r: BaseReader):\n        pass\n\n\nclass LayerNorm(Module):\n\n    def apply(self, i: int, r: BaseReader):\n        attn_norm = r.attn_norm(i)\n        ffn_norm = r.ffn_norm(i)\n        self.model.save_split(attn_norm, f'layers.{i}.attention_norm.weight')\n        self.model.save_split(ffn_norm, f'layers.{i}.ffn_norm.weight')\n\n\nclass Ffn(Module):\n    \"\"\"\n    requires:\n        r.ffn(i, kind)\n    \"\"\"\n\n    _ffn = 'layers.{0}.feed_forward.{1}.{2}'\n\n    def __init__(self, model: BaseOutputModel):\n        self.model = model\n        self.tp = model.mlp_tp_size\n        # inter_sizes in config are padded and may be different from what's\n        # in the weights\n        self.inter_size = model.model_config.inter_size\n        self.group_size = max(1, model.model_config.group_size)\n\n    def _export(self, inter_size: int, fmt: str, idx: int, w123, kind: str, pack_fn, apply_gs=[], **kwargs):\n        is_lora_a, is_lora_b = get_lora_flags(kind)\n        w1, w2, w3 = map(transpose, w123)\n\n        gs1 = self.group_size if 'w1' in apply_gs else 1\n        w1 = pad_out_dims(w1, inter_size // gs1)\n\n        gs3 = self.group_size if 'w3' in apply_gs else 1\n        w3 = pad_out_dims(w3, inter_size // gs3)\n\n        gs2 = self.group_size if 'w2' in apply_gs else 1\n        w2 = pad_in_dims(w2, inter_size // gs2)\n\n        w1, w2, w3 = map(pack_fn, (w1, 
w2, w3))\n        self.model.save_split(w1, fmt.format(idx, 'w1', kind), split_dim=-1, split_num=self.tp, copy=is_lora_a)\n        self.model.save_split(w3, fmt.format(idx, 'w3', kind), split_dim=-1, split_num=self.tp, copy=is_lora_a)\n        self.model.save_split(w2, fmt.format(idx, 'w2', kind), split_dim=0, split_num=self.tp, copy=is_lora_b)\n\n    def apply(self, i: int, r: BaseReader):\n        if i >= len(self.inter_size) or not self.inter_size[i]:\n            return\n        keys = r.ffn(i, None)\n\n        for e in get_params(keys):\n            e(partial(self._export, self.inter_size[i], self._ffn), partial(r.ffn, i), i)\n\n\nclass MoeFfn(Ffn):\n    \"\"\"\n    requires:\n        r.moe_ffn_expert(e, i, kind)\n        r.moe_ffn_gate(i)\n        r.moe_ffn_shared_gate(i)\n    \"\"\"\n\n    _moe_ffn_expert = 'layers.{0}.moe_ffn.experts.E.{1}.{2}'\n    _moe_ffn_gate = 'layers.{0}.moe_ffn.gate.{1}'\n    _moe_ffn_shared_gate = 'layers.{0}.moe_ffn.shared_gate.weight'\n\n    def __init__(self, model: BaseOutputModel):\n        super().__init__(model)\n        self.expert_num = model.model_config.expert_num\n        self.inter_size = model.model_config.expert_inter_size\n        self.shared_gate = model.model_config.moe_shared_gate\n\n    def apply(self, i: int, r: BaseReader):\n        if i >= len(self.expert_num) or self.expert_num[i] == 0:\n            return\n\n        # Export expert weights with outer loop over experts (not params)\n        # to ensure each expert's full weight set is grouped together\n        for e in range(self.expert_num[i]):\n            for p in get_params(r.moe_ffn_expert(), 1):\n                fmt = self._moe_ffn_expert.replace('E', str(e))\n                p(partial(self._export, self.inter_size, fmt), partial(r.moe_ffn_expert, e, i), i)\n\n        # router\n        gate = transpose(r.moe_ffn_gate(i, 'weight'))\n        self.model.save_split(gate, self._moe_ffn_gate.format(i, 'weight'))\n        bias = r.moe_ffn_gate(i, 'bias')\n     
   if bias is not None:\n            self.model.save_split(bias, self._moe_ffn_gate.format(i, 'bias'))\n\n        # Export score_correction_bias for noaux_tc routing (GLM 4.7 Flash)\n        correction_bias = getattr(r, 'moe_ffn_gate_correction_bias', None)\n        if callable(correction_bias):\n            correction = correction_bias(i)\n            if correction is not None:\n                self.model.save_split(correction, self._moe_ffn_gate.format(i, 'score_correction_bias'))\n\n        if self.shared_gate:\n            shared_gate = transpose(r.moe_ffn_shared_gate(i))\n            self.model.save_split(shared_gate, self._moe_ffn_shared_gate.format(i))\n\n\nclass Attn(Module):\n    \"\"\"\n    requires:\n        r.attn(i, kind)\n    \"\"\"\n\n    _attn = 'layers.{0}.attention.{1}.{2}'\n\n    def __init__(self, model: BaseOutputModel):\n        self.model = model\n        self.tp = model.attn_tp_size\n        self.head_dim = model.model_config.size_per_head\n        self.attn_bias = model.model_config.attn_bias\n        self.qk_norm = model.model_config.qk_norm\n        self.attn_sink = model.model_config.attn_sink\n        self.group_size = max(1, model.model_config.group_size)\n        self.attn_output_gate = model.model_config.attn_output_gate\n        rope_param = model.attention_config.rope_param\n        self.rope_dim = rope_param.dim if rope_param else self.head_dim\n        self.head_num = model.model_config.head_num\n\n    def _split_q_gate(self, q):\n        \"\"\"Split interleaved Q+gate tensor into separate Q and gate.\n\n        HF layout: [Q_head0, Gate_head0, Q_head1, Gate_head1, ...]\n        Returns: (q_real, gate) each with shape [..., num_heads * head_dim]\n        \"\"\"\n        output_dims = q.size(-1)\n        head_num = output_dims // (self.head_dim * 2)\n        orig_shape = list(q.shape)\n        if q.dim() == 1:\n            q = q.unsqueeze(0)\n        q = q.view(q.size(0), head_num, 2, self.head_dim)\n        q_real = q[:, :, 0, 
:].contiguous()\n        gate = q[:, :, 1, :].contiguous()\n        new_last_dim = head_num * self.head_dim\n        q_real = q_real.reshape(-1, new_last_dim)\n        gate = gate.reshape(-1, new_last_dim)\n        if len(orig_shape) == 1:\n            q_real = q_real.squeeze(0)\n            gate = gate.squeeze(0)\n        return q_real, gate\n\n    def _reorder_and_merge(self, qkvo, gs: int):\n        q, k, v, o = qkvo\n        gate = None\n        # When attn_output_gate, Q is interleaved [Q0, G0, Q1, G1, ...]\n        # Split into separate Q and gate before permuting\n        if self.attn_output_gate and q is not None:\n            q, gate = self._split_q_gate(q)\n        # reorder output dim for tm's rotary embedding layout\n        if self.model.permute_qk:\n            if gs == 1:\n                if self.rope_dim < self.head_dim:\n                    q = permute_v2_partial(q, self.head_dim, self.rope_dim)\n                    k = permute_v2_partial(k, self.head_dim, self.rope_dim)\n                else:\n                    q = permute_v2(q, self.head_dim)\n                    k = permute_v2(k, self.head_dim)\n            else:\n                assert gs % self.head_dim == 0\n        # Merge QKV with gate appended at end if present\n        if gate is not None:\n            qkv = merge_qkvg_v2(q, k, v, gate, self.tp)\n        else:\n            qkv = merge_qkv_v2(q, k, v, self.tp)\n        # zero bias for `wo` when `w_qkv` has bias but `wo` doesn't\n        if o is None and q.dim() == 1:\n            o = torch.zeros_like(q)\n        return qkv, o\n\n    def _repeat_kv(self, qkvo, gs: int, kind: str):\n        \"\"\"Replicate kv.\"\"\"\n        q, k, v, o = qkvo\n        head_dim = self.model.model_config.size_per_head // gs\n        kv_head_num = self.model.model_config.kv_head_num // self.model.repeat_kv\n        hidden_dim = self.model.model_config.hidden_units\n\n        def _repeat(x):\n            n = self.model.repeat_kv\n\n            x = 
x.reshape(-1, kv_head_num, head_dim)\n            x = x.repeat(1, 1, n)\n            x = x.reshape(-1, kv_head_num * n * head_dim)\n\n            return x\n\n        k, v = map(_repeat, (k, v))\n\n        if kind == 'bias':\n            if o is None:\n                o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)\n            q, k, v, o = map(torch.squeeze, (q, k, v, o))\n\n        return (q, k, v, o)\n\n    def _export(self, idx: int, qkvo, kind: str, pack_fn, apply_gs=[], **kwargs):\n        if all(x is None for x in qkvo):\n            return\n        is_lora_a, is_lora_b = get_lora_flags(kind)\n        assert not (is_lora_a or is_lora_b)\n\n        qkvo = tuple(map(transpose, qkvo))\n\n        gs = self.group_size if ('w1' in apply_gs) else 1\n\n        if self.model.repeat_kv:\n            qkvo = self._repeat_kv(qkvo, gs, kind)\n\n        qkv, o = self._reorder_and_merge(qkvo, gs)\n\n        self.model.save_split(pack_fn(qkv),\n                              self._attn.format(idx, 'w_qkv', kind),\n                              split_dim=-1,\n                              split_num=self.tp,\n                              copy=is_lora_a)\n        self.model.save_split(pack_fn(o),\n                              self._attn.format(idx, 'wo', kind),\n                              split_dim=0,\n                              split_num=self.tp,\n                              copy=is_lora_b)\n\n    def apply(self, i: int, r: BaseReader):\n        for e in get_params(r.attn(i, None), bias=self.attn_bias):\n            e(self._export, partial(r.attn, i), i)\n        if self.qk_norm:\n            q, k = r.qk_norm(i)\n            if q is not None and k is not None:\n                if self.model.permute_qk:\n                    if self.rope_dim < self.head_dim:\n                        q = permute_v2_partial(q, self.head_dim, self.rope_dim)\n                        k = permute_v2_partial(k, self.head_dim, self.rope_dim)\n                    else:\n               
         q = permute_v2(q, self.head_dim)\n                        k = permute_v2(k, self.head_dim)\n                self.model.save_split(q, self._attn.format(i, 'q_norm', '')[:-1])\n                self.model.save_split(k, self._attn.format(i, 'k_norm', '')[:-1])\n        if self.attn_sink:\n            sinks = r.attn_sinks(i)\n            self.model.save_split(sinks, self._attn.format(i, 'sinks', '')[:-1], split_dim=-1, split_num=self.tp)\n\n\nclass MLA(Module):\n    \"\"\"\n    requires:\n        r.mla(i, kind)\n        r.mla_norm(i)\n    \"\"\"\n\n    _mla = 'layers.{0}.attention.{1}.{2}'\n\n    def __init__(self, model: BaseOutputModel):\n        self.model = model\n\n    def _export(self, idx: int, xs, kind: str, pack_fn, **kwargs):\n        if all(x is None for x in xs):\n            return\n        q_a, q_b, q, kv_a, kv_b, o = xs\n\n        cfg = self.model.model_config\n        head_num = cfg.head_num\n        kv_lora_rank = cfg.kv_lora_rank\n        qk_rope_dim = cfg.qk_rope_dim\n        size_per_head = cfg.size_per_head\n        v_head_dim = cfg.v_head_dim\n\n        # ========== MLA Weight Folding for Dimension Mismatch ==========\n        # When kv_lora_rank != qk_nope_dim (e.g., GLM 4.7 Flash: 512 != 512+64=576),\n        # fold the kc/vc compression/decompression BMMs into q_b_proj/o_proj weights\n        # at conversion time to avoid runtime overhead.\n        if kind == 'weight' and kv_lora_rank and q is None and q_b is not None and kv_b is not None and o is not None:\n            if not (torch.is_floating_point(q_b) and torch.is_floating_point(kv_b) and torch.is_floating_point(o)):\n                raise ValueError('MLA weight folding requires floating-point attention weights.')\n\n            orig_q_head_dim = q_b.size(0) // head_num\n            orig_qk_nope_dim = orig_q_head_dim - qk_rope_dim\n            orig_kv_dim_total = kv_b.size(0) // head_num\n            orig_v_head_dim = o.size(1) // head_num\n            actual_orig_qk_nope_dim = 
orig_kv_dim_total - orig_v_head_dim\n\n            if abs(orig_qk_nope_dim - actual_orig_qk_nope_dim) > 1:\n                raise ValueError(f'Dimension mismatch: inferred qk_nope from q_b ({orig_qk_nope_dim}) != '\n                                 f'inferred from kv_b ({actual_orig_qk_nope_dim})')\n\n            orig_qk_nope_dim = actual_orig_qk_nope_dim\n            target_nope_dim = size_per_head - qk_rope_dim\n            target_v_head_dim = v_head_dim\n\n            if orig_qk_nope_dim != target_nope_dim or orig_v_head_dim != target_v_head_dim:\n                if target_nope_dim != kv_lora_rank or target_v_head_dim != kv_lora_rank:\n                    raise ValueError(f'MLA folding expects v_head_dim and nope_dim to equal kv_lora_rank, '\n                                     f'got nope={target_nope_dim}, v_head={target_v_head_dim}, rank={kv_lora_rank}')\n\n                if kv_b.size(1) != kv_lora_rank:\n                    raise ValueError(f'kv_b_proj second dim must equal kv_lora_rank for MLA folding, '\n                                     f'got {kv_b.size(1)} != {kv_lora_rank}')\n\n                # Split kv_b into kc and vc\n                kv_b_per_head = kv_b.reshape(head_num, orig_qk_nope_dim + orig_v_head_dim, kv_lora_rank)\n                kc_w = kv_b_per_head[:, :orig_qk_nope_dim, :]\n                vc_w = kv_b_per_head[:, orig_qk_nope_dim:, :]\n\n                # Fold kc into q_b_proj\n                q_b_per_head = q_b.reshape(head_num, orig_q_head_dim, q_b.size(1))\n                q_nope_w = q_b_per_head[:, :orig_qk_nope_dim, :]\n                q_rope_w = q_b_per_head[:, orig_qk_nope_dim:, :]\n                q_nope_expanded = torch.bmm(kc_w.transpose(1, 2), q_nope_w)\n                q_b_folded = torch.cat([q_nope_expanded, q_rope_w], dim=1)\n                q_b = q_b_folded.reshape(head_num * size_per_head, q_b.size(1))\n\n                # Fold vc into o_proj\n                o_per_head = o.reshape(o.size(0), head_num, 
orig_v_head_dim)\n                o_folded = torch.bmm(o_per_head.permute(1, 0, 2), vc_w)\n                o = o_folded.permute(1, 0, 2).reshape(o.size(0), head_num * kv_lora_rank)\n\n                # Set kv_b to identity (kc/vc are now absorbed)\n                eye = torch.eye(kv_lora_rank, dtype=kv_b.dtype, device=kv_b.device)\n                kv_b = torch.cat([eye, eye], dim=0).repeat(head_num, 1)\n        # ========== End MLA Weight Folding ==========\n\n        # Transpose after folding\n        q_a, q_b, q, kv_a, kv_b, o = map(transpose, (q_a, q_b, q, kv_a, kv_b, o))\n\n        if q is not None:\n            q_b = q\n\n        # Pad o_proj to size_per_head if present\n        if o is not None:\n            o = o.reshape(head_num, v_head_dim, -1)\n            o = torch.nn.functional.pad(o, (0, 0, size_per_head - v_head_dim, 0, 0, 0))\n            o = o.view(head_num * size_per_head, cfg.hidden_units)\n\n        tp = self.model.attn_tp_size\n\n        # Export MLA weights (handle None for folded-away tensors)\n        if q_a is not None:\n            self.model.save_split(pack_fn(q_a), self._mla.format(idx, 'q_a_proj', kind))\n        q_b_name = 'q_proj' if q_a is None else 'q_b_proj'\n        if q_b is not None:\n            self.model.save_split(pack_fn(q_b), self._mla.format(idx, q_b_name, kind), split_dim=-1, split_num=tp)\n        if kv_a is not None:\n            self.model.save_split(pack_fn(kv_a), self._mla.format(idx, 'kv_a_proj', kind))\n        # if kv_b is not None:\n        #     self.model.save_split(pack_fn(kv_b), self._mla.format(idx, 'kv_b_proj', kind), split_dim=-1, split_num=tp)\n        if o is not None:\n            self.model.save_split(pack_fn(o), self._mla.format(idx, 'wo', kind), split_dim=0, split_num=tp)\n\n    _layernorm = 'layers.{0}.attention.{1}_a_layernorm'\n\n    def apply(self, i: int, r: BaseReader):\n\n        for f in get_params(r.attn(i, None), bias=False):\n            f(self._export, partial(r.mla, i), i)\n\n        q, 
k = r.mla_norm(i)\n        if q is not None:\n            self.model.save_split(q, self._layernorm.format(i, 'q'))\n        self.model.save_split(k, self._layernorm.format(i, 'kv'))\n\n\nclass LinearAttn(Module):\n    _linear_attn = 'layers.{0}.linear_attn.{1}.{2}'\n\n    def __init__(self, model: BaseOutputModel):\n        self.model = model\n        self.tp = model.attn_tp_size\n        cfg = model.model_config\n        self.key_dim = cfg.linear_num_key_heads * cfg.linear_key_head_dim\n        self.value_dim = cfg.linear_num_value_heads * cfg.linear_value_head_dim\n\n    def _tp_interleave_qkv(self, tensor, dim):\n        \"\"\"Split a concatenated [Q, K, V] tensor into components, reshape each\n        for TP interleaving, and re-concatenate.\n\n        in_proj_qkv layout along ``dim``: Q(key_dim) | K(key_dim) | V(value_dim).\n        A naive split doesn't respect component boundaries when key_dim and\n        value_dim differ.  This method splits Q/K/V, reshapes each to\n        ``(tp, -1)`` along ``dim``, concatenates per-TP-shard, then flattens\n        so that a subsequent ``save_split(split_dim=dim)`` gives each rank the\n        correct portion.\n        \"\"\"\n        if dim < 0:\n            dim = tensor.dim() + dim\n        q, k, v = torch.split(tensor, [self.key_dim, self.key_dim, self.value_dim], dim=dim)\n\n        def reshape(x):\n            # Move TP axis to a new dimension right after ``dim``\n            shape = list(x.shape)\n            d = shape[dim]\n            new_shape = shape[:dim] + [self.tp, d // self.tp] + shape[dim + 1:]\n            return x.view(new_shape)\n\n        parts = torch.cat([reshape(q), reshape(k), reshape(v)], dim=dim + 1)\n        # Collapse tp and per-shard dims back\n        shape = list(parts.shape)\n        final_shape = shape[:dim] + [shape[dim] * shape[dim + 1]] + shape[dim + 2:]\n        return parts.reshape(final_shape)\n\n    def apply(self, i: int, r: BaseReader):\n        layer_types = 
getattr(self.model.model_config, 'layer_types', [])\n        if i >= len(layer_types) or layer_types[i] != 'linear_attention':\n            return\n\n        for kind in ['weight', 'bias']:\n            weights = r.linear_attn(i, kind)\n            if not weights:\n                continue\n\n            names = ['conv1d', 'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a', 'out_proj', 'A_log', 'dt_bias']\n            for name, tensor in zip(names, weights):\n                if tensor is None:\n                    continue\n                if name == 'conv1d':\n                    # conv1d shape: (conv_dim, 1, d_conv) where\n                    # conv_dim = key_dim*2 + value_dim.  Interleave Q/K/V\n                    # portions along dim 0 before splitting for TP.\n                    tensor = self._tp_interleave_qkv(tensor, dim=0)\n                    self.model.save_split(tensor,\n                                          self._linear_attn.format(i, name, kind),\n                                          split_dim=0,\n                                          split_num=self.tp)\n                elif name in ['A_log', 'dt_bias']:\n                    # Split per-head params across TP ranks (use -1 to\n                    # avoid the 1-D copy shortcut in save_split).\n                    self.model.save_split(tensor,\n                                          self._linear_attn.format(i, name, kind),\n                                          split_dim=-1,\n                                          split_num=self.tp)\n                elif name == 'out_proj':\n                    self.model.save_split(transpose(tensor),\n                                          self._linear_attn.format(i, name, kind),\n                                          split_dim=0,\n                                          split_num=self.tp)\n                elif name == 'in_proj_qkv':\n                    # in_proj_qkv: (conv_dim, hidden) where conv_dim =\n                    # 
key_dim*2 + value_dim.  After transpose the QKV\n                    # components are along dim -1.  Interleave for TP so\n                    # each shard gets the correct Q/K/V slice.\n                    t = transpose(tensor)\n                    t = self._tp_interleave_qkv(t, dim=-1)\n                    self.model.save_split(t, self._linear_attn.format(i, name, kind), split_dim=-1, split_num=self.tp)\n                else:\n                    self.model.save_split(transpose(tensor),\n                                          self._linear_attn.format(i, name, kind),\n                                          split_dim=-1,\n                                          split_num=self.tp)\n\n        norm = r.linear_norm(i, 'weight')\n        if norm is not None:\n            self.model.export_weight(norm, f'layers.{i}.linear_attn.norm.weight')\n\n\nclass Misc(Module):\n    \"\"\"\n    requires:\n        r.tok_embeddings()\n        r.norm_weight()\n        r.output_weight()\n    \"\"\"\n\n    def apply(self, i: int, r: BaseReader):\n        \"\"\"Export embedding, norm, output weight.\"\"\"\n        emb = r.tok_embeddings()\n        norm_weight = r.norm_weight()\n        output_weight = r.output_weight()\n\n        def pad_weight(tensor: torch.Tensor, tp: int):\n            pad_size = None\n            vocab_size = self.model.model_config.vocab_size\n            if vocab_size % tp != 0:\n                pad_size = (vocab_size + tp - 1) // tp * tp - vocab_size\n            if pad_size is None:\n                return tensor\n            return torch.nn.functional.pad(tensor, (0, 0, 0, pad_size), 'constant', 0)\n\n        tp = self.model.attn_tp_size * self.model.attn_cp_size\n        if emb is not None:\n            emb = pad_weight(emb, tp=tp)\n            self.model.save_split(emb, 'tok_embeddings.weight', split_dim=1, split_num=tp)\n        if norm_weight is not None:\n            self.model.export_weight(norm_weight, 'norm.weight')\n        if output_weight is not 
None:\n            output_weight = pad_weight(output_weight, tp=tp)\n            # transpose\n            self.model.save_split(output_weight.t(), 'output.weight', split_dim=1, split_num=tp)\n\n\nclass Transformer:\n\n    def __init__(self, model: BaseOutputModel):\n        self.model = model\n        modules = [LayerNorm]\n        if model.model_config.kv_lora_rank:\n            modules.append(MLA)\n        else:\n            modules.append(Attn)\n        if getattr(model.model_config, 'layer_types', []):\n            modules.append(LinearAttn)\n        if model.model_config.inter_size:\n            modules.append(Ffn)\n        if model.model_config.expert_num:\n            modules.append(MoeFfn)\n        self.modules = [c(model) for c in modules]\n        self.misc = Misc(model)\n\n    def __call__(self, i: int, r: BaseReader):\n        if i >= 0:\n            for m in self.modules:\n                m(i, r)\n            return 1\n        else:\n            self.misc(i, r)\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/parameter.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import abstractmethod\n\nimport torch\n\n\ndef identity(x):\n    return x\n\n\ndef to_half(x: torch.Tensor):\n    return x.to(torch.half)\n\n\ndef to_float(x: torch.Tensor):\n    return x.to(torch.float)\n\n\ndef to_fp8(x: torch.Tensor):\n    assert x.dtype == torch.uint8\n    return x.view(dtype=torch.float8_e4m3fn)\n\n\ndef pack_u4_row(x: torch.Tensor) -> torch.Tensor:\n    assert x.dtype == torch.uint8, f'x.dtype: {x.dtype}'\n    xs = x.view(*x.shape[:-1], -1, 8).split(1, dim=-1)\n    a = torch.zeros(xs[0].shape, dtype=torch.int32, device=x.device)\n    for t in reversed(xs):\n        a = (a << 4) | t\n    return a.squeeze(dim=-1)\n\n\ndef generate_zero_point(g):\n    weight_shapes = g('weight_shape')\n    result = []\n    for weight_shape in weight_shapes:\n        row, col = weight_shape\n        tensor = torch.full((row, col // 128), 8, dtype=torch.uint8)\n        result.append(tensor)\n    return (*result, )\n\n\nclass Parameter:\n    KEYS = ()\n\n    @classmethod\n    def take(cls, keys: list[str]):\n        if not any(k.endswith(cls.KEYS[0]) for k in keys):\n            return False\n        xs = []\n        for k in keys:\n            if any(k.endswith(p) for p in cls.KEYS):\n                xs.append(k)\n        for x in xs:\n            keys.remove(x)\n        return xs\n\n    @abstractmethod\n    def __call__(self, f, g, i):\n        pass\n\n\nclass QuantWeightOnly(Parameter):\n    KEYS = '.qweight', '.scales', '.qzeros'\n\n    def __call__(self, f, g, i):\n        f(i, g('qweight'), 'qweight', pack_u4_row)\n        f(i, g('scales'), 'scales', to_half, apply_gs=['w2'])\n        f(i, g('qzeros'), 'zeros', to_half, apply_gs=['w2'])\n\n\nclass WeightScaleInv(Parameter):\n    KEYS = '.weight_scale_inv', '.weight'\n\n    # TODO: flag any operations crossing the quant blocks as illegal\n    def __call__(self, f, g, i):\n        f(i, g('weight_scale_inv'), 'scales', to_float, 
apply_gs=['w1', 'w3', 'w2'])\n        f(i, g('weight'), 'weight', identity)\n\n\nclass CompressedWeight(Parameter):\n    KEYS = '.weight_packed', '.weight_scale', '.weight_zero_point'\n\n    def __init__(self, xs):\n        self.has_zero_point = False\n        if any(key.endswith(self.KEYS[2]) for key in xs):\n            self.has_zero_point = True\n\n    def __call__(self, f, g, i):\n        f(i, g('weight_packed'), 'qweight', pack_u4_row)\n        f(i, g('weight_scale'), 'scales', to_half, apply_gs=['w2'])\n        if self.has_zero_point:\n            f(i, g('weight_zero_point'), 'zeros', to_half, apply_gs=['w2'])\n        else:\n            f(i, generate_zero_point(g), 'zeros', to_half, apply_gs=['w2'])\n\n\nclass Mxfp4Weight(Parameter):\n    KEYS = '.blocks', '.scales'\n\n    def __call__(self, f, g, i):\n        f(i, g('blocks'), 'weight', pack_u4_row)\n        f(i, g('scales'), 'scales', identity, apply_gs=['w2'])\n\n\nclass Weight(Parameter):\n    KEYS = '.weight',\n\n    def __call__(self, f, g, i):\n        f(i, g('weight'), 'weight', identity)\n\n\nclass Bias(Parameter):\n    KEYS = '.bias',\n\n    def __call__(self, f, g, i):\n        f(i, g('bias'), 'bias', identity)\n\n\nclass PLora(Parameter):\n    KEYS = '.Plora_A.weight', '.Plora_B.weight'\n\n    def __call__(self, f, g, i):\n        f(i, g('Plora_A.weight'), 'lora_a.weight', identity)\n        f(i, g('Plora_B.weight'), 'lora_b.weight', identity)\n\n\ndef get_params(keys: list[str], bias=0):\n    ps = []\n    if PLora.take(keys):\n        ps.append(PLora())\n    if QuantWeightOnly.take(keys):\n        ps.append(QuantWeightOnly())\n    if WeightScaleInv.take(keys):\n        ps.append(WeightScaleInv())\n    xs = CompressedWeight.take(keys)\n    if xs:\n        ps.append(CompressedWeight(xs))\n    if Mxfp4Weight.take(keys):\n        ps.append(Mxfp4Weight())\n    if Weight.take(keys):\n        ps.append(Weight())\n    if bias and Bias.take(keys):\n        ps.append(Bias())\n    return ps\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/policy.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import List\n\nimport torch.cuda\n\n\ndef to_cuda(x: torch.Tensor, *args):\n    return x.cuda()\n\n\ndef get_u4_slices(x: torch.Tensor, dtype: torch.dtype) -> List[torch.Tensor]:\n    MAP = {torch.int32: 8, torch.uint8: 2}\n    xs = []\n    for _ in range(MAP[x.dtype]):\n        xs.append((x & 15).to(dtype))\n        x = x >> 4\n    return xs\n\n\ndef unpack_awq_gemm(x: torch.Tensor) -> torch.Tensor:\n    xs = get_u4_slices(x, torch.uint8)\n    order = [0, 4, 1, 5, 2, 6, 3, 7]\n    ys = [xs[i] for i in order]\n    return torch.stack(ys, dim=-1).view(*x.shape[:-1], -1)\n\n\ndef process_awq_gemm(x: torch.Tensor, kind: str):\n    x = x.cuda()\n    if x.dtype == torch.int32:\n        x = unpack_awq_gemm(x)\n    if kind in ['qweight', 'qzeros', 'scales']:\n        x = x.t()\n    return x\n\n\ndef process_gptq(x: torch.Tensor, kind: str):\n    x = x.cuda()\n    if x.dtype == torch.int32:\n        xs = get_u4_slices(x, torch.uint8)\n        if kind == 'qweight':  # (k/8,n)\n            x = torch.stack(xs, dim=1).view(-1, x.size(-1))\n        else:  # 'qzeros' (k/g,n/8)\n            x = torch.stack(xs, dim=-1).view(x.size(0), -1) + 1\n    if kind in ['qweight', 'qzeros', 'scales']:\n        x = x.t()\n    return x\n\n\ndef process_mxfp4(x: torch.Tensor, kind: str):\n    # print(x.shape, x.dtype, kind)\n    x = x.cuda()\n    if kind == 'blocks':\n        xs = get_u4_slices(torch.flatten(x, start_dim=-2), torch.uint8)\n        x = torch.flatten(torch.stack(xs, dim=-1), start_dim=-2)\n    if kind == 'scales':\n        pass\n    return x\n\n\ndef process_fp8(x: torch.Tensor, kind: str):\n    x = x.cuda()\n    if x.dtype == torch.float8_e4m3fn:\n        # some ops (e.g. 
torch.cat) for fp8 is not implemented in pytorch\n        return x.view(dtype=torch.uint8)\n    elif kind != 'weight_scale_inv' and x.dtype == torch.float:\n        return x.to(dtype=torch.bfloat16)\n    else:\n        return x.to(dtype=torch.bfloat16)\n\n\ndef process_compressed_tensor(x: torch.Tensor, kind: str):\n    x = x.cuda()\n    if x.dtype == torch.int32:\n        xs = get_u4_slices(x, torch.uint8)\n        if kind == 'weight_packed':  # (out_channels, in_channels // 8)\n            x = torch.stack(xs, dim=-1).view(*x.shape[:-1], -1)\n        elif kind == 'weight_zero_point':  # (out_channels // 8, in_channels // group_size)\n            x = torch.stack(xs, dim=1).view(-1, x.size(-1))\n    return x\n\n\ndef get_input_policy(model_format):\n    if model_format == 'awq':\n        return process_awq_gemm\n    elif model_format == 'gptq':\n        return process_gptq\n    elif model_format == 'mxfp4':\n        return process_mxfp4\n    elif model_format == 'fp8':\n        return process_fp8\n    elif model_format == 'compressed-tensors':\n        return process_compressed_tensor\n    else:\n        return to_cuda\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .baichuan import Baichuan2Model, BaichuanModel  # noqa: F401\nfrom .deepseek2 import DeepSeek2Model  # noqa: F401\nfrom .deepseek_vl import DeepSeekVLModel  # noqa: F401\nfrom .glm4 import Glm4Model  # noqa: F401\nfrom .glm4_moe_lite import Glm4MoeLiteModel  # noqa: F401\nfrom .gpt_oss import GptOssModel  # noqa: F401\nfrom .internlm2 import InternLM2Model  # noqa: F401\nfrom .internvl import InternVLModel  # noqa: F401\nfrom .llama import LlamaModel  # noqa: F401\nfrom .llava import LlavaModel  # noqa: F401\nfrom .minicpmv import MiniCPMVModel  # noqa: F401\nfrom .mixtral import MixtralModel  # noqa: F401\nfrom .molmo import MolmoModel  # noqa: F401\nfrom .qwen import QwenModel  # noqa: F401\nfrom .xcomposer2 import Xcomposer2Model  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/baichuan.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport torch\n\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass BaichuanReader(LlamaReader):\n    \"\"\"BaichuanReader.\"\"\"\n\n    def _attn(self, i: int, kind: str):\n        \"\"\"Get q, k, v, o kind for layer i.\"\"\"\n        q, k, v, o = (None, ) * 4\n        pack_key = f'model.layers.{i}.self_attn.W_pack.{kind}'\n        qkv = self.transform(self.params.get(pack_key), kind)\n        if qkv is not None:\n            q, k, v = torch.split(qkv, qkv.shape[0] // 3, dim=0)\n        o = self.params.get(f'model.layers.{i}.self_attn.o_proj.{kind}')\n        o = self.transform(o, kind)\n        return q, k, v, o\n\n\n@INPUT_MODELS.register_module(name='baichuan')\nclass BaichuanModel(LlamaModel):\n    \"\"\"Llama model in baichuan format.\"\"\"\n\n    Reader = BaichuanReader\n\n\nclass Baichuan2Reader(BaichuanReader):\n    \"\"\"Baichuan2Reader.\"\"\"\n\n    def output_weight(self):\n        \"\"\"Get output.\"\"\"\n        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507\n        tensor = self.params.get('lm_head.weight', None)\n        if tensor is not None:\n            tensor = tensor.cuda()\n            tensor = torch.nn.functional.normalize(tensor)\n        return tensor\n\n\n@INPUT_MODELS.register_module(name='baichuan2')\nclass Baichuan2Model(LlamaModel):\n    \"\"\"Llama model in baichuan format.\"\"\"\n\n    Reader = Baichuan2Reader\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, Iterator, Union\n\nimport torch\nfrom mmengine import Registry\n\nINPUT_MODELS = Registry('source model', locations=['lmdeploy.turbomind.deploy.source_model.base'])\n\n\nclass BaseReader(ABC):\n    \"\"\"Mapping between TM modules and source modules.\"\"\"\n\n    def __init__(self):\n        pass\n\n    def transform(self, x: Union[torch.Tensor, None], kind: str) -> Union[torch.Tensor, None]:\n        return None if x is None else self._transform(x, kind)\n\n    @abstractmethod\n    def _transform(self, x: torch.Tensor, kind: str):\n        \"\"\"Transform x.\"\"\"\n        pass\n\n\nclass BaseInputModel(ABC):\n    \"\"\"Base class for input model.\"\"\"\n\n    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):\n        \"\"\"Constructor for BaseInputModel.\n\n        Args:\n            model_path (str): the path of the model.\n            tokenizer_path (str): the path of the tokenizer model.\n        \"\"\"\n        self.model_path = model_path\n        self.tokenizer_path = tokenizer_path\n\n    @abstractmethod\n    def model_info(self) -> Dict:\n        \"\"\"Read model info.\"\"\"\n        pass\n\n    @abstractmethod\n    def readers(self) -> Iterator[BaseReader]:\n        pass\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/deepseek2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport math\nimport os\n\nfrom ..config import RopeParam\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass DeepSeek2Reader(LlamaReader):\n\n    def moe_ffn_gate(self, i, kind):\n        return self.params.get(f'model.layers.{i}.mlp.gate.{kind}')\n\n    def moe_ffn_expert(self, e=None, i=None, kind=None):\n        if not kind:\n            return self.filter(r'experts', i)\n        result = []\n        for key in ['gate', 'down', 'up']:\n            name = f'model.layers.{i}.mlp.experts.{e}.{key}_proj.{kind}'\n            tensor = self.params.get(name)\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def _ffn(self, i: int, kind: str):\n        \"\"\"Get ffn kind for layer i.\"\"\"\n        if not kind:\n            # Filter by layer number to get only keys for this specific layer\n            if i == 0:\n                pattern = rf'model\\.layers\\.{i}\\.mlp\\.'\n            else:\n                pattern = rf'model\\.layers\\.{i}\\.mlp\\.shared_experts\\.'\n            return self.filter(pattern, None)\n        result = []\n        for key in ['gate', 'down', 'up']:\n            name = f'model.layers.{i}.mlp.shared_experts.{key}_proj.{kind}'\n            if i == 0:\n                name = name.replace('shared_experts.', '')\n            tensor = self.params.get(name)\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def ffn(self, i: int, kind: str):\n        return self._ffn(i, kind)\n\n    def mla(self, i: int, kind: str):\n        if not kind:\n            return self.filter(r'self_attn.*proj', i)\n        result = []\n        for key in ['q_a_proj', 'q_b_proj', 'q_proj', 'kv_a_proj_with_mqa', 'kv_b_proj', 'o_proj']:\n            tensor = self.params.get(f'{self.attn_layer_prefix}.{i}.self_attn.{key}.{kind}')\n            
tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def mla_norm(self, i: int):\n        result = []\n        for k in ['q', 'kv']:\n            name = f'{self.attn_layer_prefix}.{i}.self_attn.{k}_a_layernorm.weight'  # noqa: E501\n            result.append(self.params.get(name))\n        return (*result, )\n\n\ndef get_yarn_params(rope_scaling: dict):\n\n    scaling_factor = float(rope_scaling['factor'])\n    mscale = rope_scaling['mscale']\n    mscale_all_dim = rope_scaling['mscale_all_dim']\n\n    def yarn_get_mscale(scale=1, mscale=1):\n        if scale <= 1:\n            return 1.0\n        return 0.1 * mscale * math.log(scale) + 1.0\n\n    _mscale = float(yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(scaling_factor, mscale_all_dim))\n\n    softmax_scale = 0\n    if mscale_all_dim:\n        scale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n        softmax_scale = scale * scale\n\n    return _mscale, softmax_scale\n\n\n@INPUT_MODELS.register_module(name='deepseek2')\nclass DeepSeek2Model(LlamaModel):\n\n    Reader = DeepSeek2Reader\n\n    def model_info(self):\n        cfg = self.model_config\n        info = super().model_info()\n        qk_nope_dim = cfg['qk_nope_head_dim']\n        qk_rope_dim = cfg['qk_rope_head_dim']\n        kv_lora_rank = cfg['kv_lora_rank']\n        q_head_dim = qk_nope_dim + qk_rope_dim\n        num_layer = cfg['num_hidden_layers']\n        expert_num = cfg['n_routed_experts']\n        expert_num = [expert_num] * num_layer\n        expert_num[0] = 0\n        n_shared_experts = cfg['n_shared_experts']\n        expert_inter_size = cfg['moe_intermediate_size']\n        experts_per_token = cfg['num_experts_per_tok']\n        inter_size = [n_shared_experts * expert_inter_size] * num_layer\n        inter_size[0] = cfg['intermediate_size']\n        norm_topk_prob = cfg['norm_topk_prob']\n        size_per_head = qk_rope_dim + qk_nope_dim\n        v_head_dim = 
cfg['v_head_dim']\n        softmax_scale = 0.0\n        disable_mla_fold = os.getenv('LMDEPLOY_MLA_FOLD', '1').lower() in ('0', 'false', 'no')\n        if kv_lora_rank and kv_lora_rank != qk_nope_dim and not disable_mla_fold:\n            # MLA folding: remap to kv_lora_rank-based head dims and fold\n            # kc/vc BMMs into q_b_proj/o_proj at conversion time.\n            size_per_head = kv_lora_rank + qk_rope_dim\n            v_head_dim = kv_lora_rank\n            softmax_scale = q_head_dim**(-0.5)\n        elif kv_lora_rank and kv_lora_rank != qk_nope_dim:\n            softmax_scale = q_head_dim**(-0.5)\n\n        info.update(kv_lora_rank=kv_lora_rank,\n                    q_lora_rank=cfg['q_lora_rank'] or 0,\n                    qk_rope_dim=qk_rope_dim,\n                    v_head_dim=v_head_dim,\n                    size_per_head=size_per_head,\n                    kv_head_num=1,\n                    expert_num=expert_num,\n                    expert_inter_size=expert_inter_size,\n                    experts_per_token=experts_per_token,\n                    inter_size=inter_size,\n                    norm_topk_prob=norm_topk_prob,\n                    routed_scale=cfg['routed_scaling_factor'],\n                    topk_method=cfg['topk_method'],\n                    topk_group=cfg['topk_group'],\n                    moe_group_num=cfg['n_group'],\n                    scoring_func=cfg.get('scoring_func', 'softmax'),\n                    tune_layer_num=2)\n        if 'router_n_groups' in cfg and cfg['router_n_groups'] > 0:\n            info['router_n_groups'] = cfg['router_n_groups']\n        rope_param: RopeParam = info['rope_param']\n        rope_param.dim = qk_rope_dim\n        if 'rope_parameters' in cfg:\n            # transformers v5.0.0 aggregates all rope-related parameters into 'rope_parameters'\n            rope_scaling = cfg['rope_parameters']\n        else:\n            rope_scaling = cfg.get('rope_scaling')\n        if rope_scaling and 
rope_scaling.get('type') == 'yarn':\n            attention_factor, yarn_scale = get_yarn_params(rope_scaling)\n            yarn_scale *= q_head_dim**(-0.5)\n            rope_param.max_position_embeddings = rope_scaling['original_max_position_embeddings']\n            rope_param.attention_factor = attention_factor\n            info.update(rope_param=rope_param, softmax_scale=yarn_scale)\n        elif softmax_scale:\n            info.update(softmax_scale=softmax_scale)\n        return info\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/deepseek_vl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport os.path as osp\n\nfrom ..config import RopeParam\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass DeepSeekVLReader(LlamaReader):\n    \"\"\"DeepSeekVL model reader.\"\"\"\n\n    attn_layer_prefix = 'language_model.model.layers'\n    attn_layer_patten = r'language_model\\.model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'language_model.model.embed_tokens.weight'\n    norm_weight_key = 'language_model.model.norm.weight'\n    output_weight_key = 'language_model.lm_head.weight'\n\n    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):\n        model_cfg = model_cfg['language_config']\n        super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)\n\n    def attn_norm(self, i: int):\n        \"\"\"Get attn norm for layer i.\"\"\"\n        return self.params[f'language_model.model.layers.{i}.input_layernorm.weight']\n\n    def ffn_norm(self, i: int):\n        \"\"\"Get ffn norm for layer i.\"\"\"\n        return self.params[f'language_model.model.layers.{i}.post_attention_layernorm.weight']\n\n\n@INPUT_MODELS.register_module(name='deepseekvl')\nclass DeepSeekVLModel(LlamaModel):\n    \"\"\"DeepSeekVL model in hf format.\"\"\"\n\n    Reader = DeepSeekVLReader\n\n    def model_info(self):\n        \"\"\"Read model info.\"\"\"\n        params_path = osp.join(self.model_path, 'config.json')\n        with open(params_path) as f:\n            model_arg = json.load(f)\n            if 'language_config' in model_arg and model_arg['language_config'].get('model_type', None) == 'llama':\n                model_arg = model_arg['language_config']  # deepseek-vl\n            num_layer = model_arg['num_hidden_layers']\n            hidden_units = model_arg.get('hidden_size', 4096)\n            inter_size = model_arg.get('intermediate_size', 11008)\n            vocab_size = model_arg.get('vocab_size', 102400)\n            norm_eps = model_arg.get('rms_norm_eps', 1e-06)\n            attn_head_num = model_arg.get('num_attention_heads', 32)\n            if 'num_key_value_heads' in model_arg:\n                kv_head_num = model_arg['num_key_value_heads']\n            else:\n                kv_head_num = model_arg.get('num_attention_heads', 32)\n            rope_theta = float(model_arg.get('rope_theta', 10000.0))\n            max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))\n            rope_scaling = model_arg.get('rope_scaling', None)\n            scaling_factor = 0.0\n            scaling_type = 'default'\n            if isinstance(rope_scaling, dict):\n                scaling_type = model_arg['rope_scaling'].get('type', 'default')\n                scaling_factor = model_arg['rope_scaling'].get('factor', 0.0)\n            head_dim = model_arg.get('head_dim', hidden_units // attn_head_num)\n            rope_param = RopeParam(type=scaling_type,\n                                   base=rope_theta,\n                                   dim=head_dim,\n                                   max_position_embeddings=max_position_embeddings,\n                                   factor=scaling_factor)\n\n        return dict(num_layer=num_layer,\n                    norm_eps=norm_eps,\n                    head_num=attn_head_num,\n                    kv_head_num=kv_head_num,\n                    hidden_units=hidden_units,\n                    inter_size=inter_size,\n                    vocab_size=vocab_size,\n                    max_position_embeddings=max_position_embeddings,\n                    rope_param=rope_param)\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/glm4.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport os.path as osp\n\nimport torch\n\nfrom ..config import RopeParam\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass Glm4Reader(LlamaReader):\n    \"\"\"Glm4Reader.\"\"\"\n\n    attn_layer_patten = r'transformer\\.encoder\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'transformer.embedding.word_embeddings.weight'\n    norm_weight_key = 'transformer.encoder.final_layernorm.weight'\n    output_weight_key = 'transformer.output_layer.weight'\n\n    attn_pattern = r'self_attention'\n\n    def _attn(self, i: int, kind: str):\n        \"\"\"Get q, k, v, o kind for layer i.\"\"\"\n        qkv = self.params[f'transformer.encoder.layers.{i}'\n                          f'.self_attention.query_key_value.{kind}']\n        qkv = self.transform(qkv, kind)\n        attn_head_num = self.model_cfg['num_attention_heads']\n        kv_head_num = attn_head_num\n        if self.model_cfg.get('multi_query_attention', False):\n            kv_head_num = self.model_cfg['multi_query_group_num']\n        HEAD_DIM = 128\n        q, k, v = torch.split(qkv, [attn_head_num * HEAD_DIM, kv_head_num * HEAD_DIM, kv_head_num * HEAD_DIM], dim=0)\n        o = self.params.get(f'transformer.encoder.layers.{i}.self_attention.dense.{kind}')\n        o = self.transform(o, kind)\n        if o is None:  # handle the case when qkv has bias but o doesn't\n            o = torch.zeros_like(q)\n        return q, k, v, o\n\n    def attn_norm(self, i: int):\n        \"\"\"Get attn norm for layer i.\"\"\"\n        return self.params[f'transformer.encoder.layers.{i}.input_layernorm.weight']\n\n    def _ffn(self, i: int, kind: str):\n        \"\"\"Get ffn kind for layer i.\"\"\"\n        up_and_gate = self.params[f'transformer.encoder.layers.{i}.mlp.dense_h_to_4h.{kind}']\n        up_and_gate = self.transform(up_and_gate, kind)\n        up, gate = up_and_gate.chunk(2, dim=0)\n        down = 
self.params[f'transformer.encoder.layers.{i}.mlp.dense_4h_to_h.{kind}']\n        down = self.transform(down, kind)\n        return (up, down, gate)\n\n    def ffn_norm(self, i: int):\n        \"\"\"Get ffn norm for layer i.\"\"\"\n        return self.params[f'transformer.encoder.layers.{i}.post_attention_layernorm.weight']\n\n\n@INPUT_MODELS.register_module(name='glm4')\nclass Glm4Model(LlamaModel):\n    \"\"\"Glm2/3/4 model in hf format.\"\"\"\n\n    Reader = Glm4Reader\n\n    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):\n        super().__init__(model_path, tokenizer_path, **kwargs)\n        config_path = osp.join(self.model_path, 'config.json')\n        with open(config_path) as f:\n            self.config = json.load(f)\n\n    def model_info(self):\n        \"\"\"Read model info.\"\"\"\n        config = self.config\n        hidden_units = config.get('hidden_size', None)\n        num_layer = config.get('num_hidden_layers', None)\n        num_layer = config.get('num_layers', num_layer)\n        norm_eps = config['layernorm_epsilon']\n        rope_theta = float(config.get('rotary_emb_base', 10000.0))\n        rope_ratio = float(config.get('rope_ratio', 1.0))\n        rope_theta *= rope_ratio\n        attn_head_num = config['num_attention_heads']\n        kv_head_num = attn_head_num\n        inter_size = config['ffn_hidden_size']\n        vocab_size = config['padded_vocab_size']\n        attn_bias = config['add_qkv_bias']\n        if config['multi_query_attention']:\n            kv_head_num = config['multi_query_group_num']\n        seq_length = config['seq_length']\n        rope_param = RopeParam(type='default', base=rope_theta, dim=64)\n        return dict(num_layer=num_layer,\n                    norm_eps=norm_eps,\n                    head_num=attn_head_num,\n                    kv_head_num=kv_head_num,\n                    hidden_units=hidden_units,\n                    attn_bias=int(attn_bias),\n                    
inter_size=inter_size,\n                    vocab_size=vocab_size,\n                    rope_param=rope_param,\n                    max_position_embeddings=seq_length,\n                    permute_qk=False)  # head layout is same as TM\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/glm4_moe_lite.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\"\"\"GLM-4 MoE Lite (e.g. GLM-4.7-Flash) source model for TurboMind.\n\nArchitecture: MLA (Multi-head Latent Attention) + MoE with dense first layer.\nWeight layout follows HuggingFace checkpoint with model.layers.* (same family as DeepSeek2).\n\"\"\"\n\nfrom .base import INPUT_MODELS\nfrom .deepseek2 import DeepSeek2Model, DeepSeek2Reader\n\n\nclass Glm4MoeLiteReader(DeepSeek2Reader):\n    \"\"\"Reader for Glm4MoeLiteForCausalLM (GLM-4.7-Flash).\n\n    Uses same key layout as DeepSeek2: model.layers.{i}.self_attn.*, model.layers.{i}.mlp.*\n    Supports noaux_tc via e_score_correction_bias.\n    \"\"\"\n\n    attn_layer_prefix = 'model.layers'\n    attn_layer_patten = r'model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'model.embed_tokens.weight'\n    norm_weight_key = 'model.norm.weight'\n    output_weight_key = 'lm_head.weight'\n\n    def moe_ffn_gate_correction_bias(self, i: int):\n        \"\"\"Per-expert score correction bias for noaux_tc routing.\"\"\"\n        return self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.gate.e_score_correction_bias')\n\n\n@INPUT_MODELS.register_module(name='glm4-moe-lite')\nclass Glm4MoeLiteModel(DeepSeek2Model):\n    \"\"\"GLM-4 MoE Lite (e.g. 
GLM-4.7-Flash) in HF format.\n\n    MLA + MoE with first_k_dense_replace; config mapping aligned to DeepSeek2.\n    \"\"\"\n\n    Reader = Glm4MoeLiteReader\n\n    def model_info(self):\n        cfg = self.model_config\n        # Set default MoE routing config for GLM-4 MoE Lite if not in HF config\n        if 'topk_method' not in cfg:\n            cfg['topk_method'] = 'noaux_tc'\n        if 'topk_group' not in cfg:\n            cfg['topk_group'] = 1\n        if 'n_group' not in cfg:\n            cfg['n_group'] = 1\n        if 'scoring_func' not in cfg:\n            cfg['scoring_func'] = 'sigmoid'\n\n        info = super().model_info()\n        # GLM4 MoE Lite uses noaux_tc routing with sigmoid scoring\n        info['topk_method'] = 'noaux_tc'\n        info['scoring_func'] = 'sigmoid'\n        if 'router_n_groups' in cfg and cfg['router_n_groups'] > 0:\n            info['router_n_groups'] = cfg['router_n_groups']\n\n        return info\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/gpt_oss.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport re\n\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\ndef map_experts(str):\n    s = re.sub(r'(experts.*proj)$', r'\\1.weight', str)\n    s = re.sub(r'(experts.*proj)_bias$', r'\\1.bias', s)\n    s = re.sub(r'(experts.*proj)_blocks$', r'\\1.blocks', s)\n    s = re.sub(r'(experts.*proj)_scales$', r'\\1.scales', s)\n    return s\n\n\nclass GptOssReader(LlamaReader):\n\n    mappings = [map_experts]\n\n    def moe_ffn_expert(self, e=None, i=None, kind=None):\n        if not kind:\n            return self.filter(r'experts', i)\n        result = []\n        for key in ['gate_up', 'down']:\n            name = f'{self.attn_layer_prefix}.{i}.mlp.experts.{key}_proj.{kind}'\n            tensor = self.params.get(name)[e]\n            if kind == 'weight':  # experts in BF16 models are in M-major\n                tensor = tensor.cuda().t()\n            if key == 'gate_up':\n                gate, up = tensor[::2], tensor[1::2]\n                result.append(self.transform(gate, kind))\n                result.append(self.transform(up, kind))\n            else:\n                result.append(self.transform(tensor, kind))\n        return (result[0], result[2], result[1])\n\n    def moe_ffn_gate(self, i, kind):\n        return self.transform(self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.router.{kind}'), kind)\n\n    def attn_sinks(self, i):\n        return self.params.get(f'{self.attn_layer_prefix}.{i}.self_attn.sinks')\n\n\n@INPUT_MODELS.register_module(name='gpt-oss')\nclass GptOssModel(LlamaModel):\n\n    Reader = GptOssReader\n\n    def model_info(self):\n        cfg = self.model_config\n        types = cfg['layer_types']\n        sliding_window = cfg['sliding_window']\n        info = super().model_info()\n        info.update(attn_bias=int(cfg['attention_bias']),\n                    mlp_bias=True,\n                    expert_router_bias=True,\n                    
expert_num=cfg['num_local_experts'],\n                    expert_inter_size=cfg['intermediate_size'],\n                    experts_per_token=cfg['experts_per_token'],\n                    norm_topk_prob=True,\n                    inter_size=0,\n                    window_size=[sliding_window if x == 'sliding_attention' else 0 for x in types],\n                    attn_sink=True,\n                    activation_type='gpt-oss')\n        return info\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/internlm2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport re\n\nimport torch\n\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass InternLM2Reader(LlamaReader):\n    \"\"\"InternLM2 model reader.\"\"\"\n\n    attn_layer_prefix = 'model.layers'\n    attn_layer_patten = r'model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'model.tok_embeddings.weight'\n    norm_weight_key = 'model.norm.weight'\n    output_weight_key = 'output.weight'\n\n    attn_pattern = r'attention'\n    ffn_pattern = r'feed_forward'\n\n    proj_pattern = 'w'\n\n    def filter(self, pattern: str, i: int | None):\n        params = []\n        for k in self.params.keys():\n            if re.search(pattern, k):\n                params.append(k)\n\n        if self.fp8_quant and pattern == self.attn_pattern:\n            from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8\n            q, k, v = (None, ) * 3\n            kv_head_num = self.model_cfg['num_key_value_heads']\n            gs = int(self.model_cfg['num_attention_heads'] / kv_head_num)\n            qkv = self.params.get(f'{self.attn_layer_prefix}.{i}.attention.wqkv.weight')\n\n            if qkv is not None:\n                qkv = qkv.view(kv_head_num, gs + 2, 128, -1)\n                hidden_dim = qkv.shape[-1]\n                q, k, v = torch.split(qkv, [gs, 1, 1], dim=1)\n\n                tensors = [q.reshape(-1, hidden_dim), k.reshape(-1, hidden_dim), v.reshape(-1, hidden_dim)]\n                split_sizes = [gs, 1, 1]\n                keys = ['q', 'k', 'v']\n                qkv_weight = []\n                for tensor, split_size, key in zip(tensors, split_sizes, keys):\n                    qweight, scale = quant_blocked_fp8(tensor, torch.float8_e4m3fn, block_size=128)\n                    qweight = qweight.reshape(kv_head_num, split_size, 128, -1)\n                    qkv_weight.append(qweight)\n\n                    
self.params[f'{self.attn_layer_prefix}.{i}.{self.attn_pattern}.w{key}.weight_scale_inv'] = scale\n                    params.append(f'{self.attn_layer_prefix}.{i}.{self.attn_pattern}.w{key}.weight_scale_inv')\n\n                qkv_weight = torch.cat(qkv_weight, dim=1)\n                qkv_weight = qkv_weight.reshape(-1, hidden_dim)\n                self.params[f'{self.attn_layer_prefix}.{i}.{self.attn_pattern}.wqkv.weight'] = qkv_weight\n\n            return params\n        else:\n            return params\n\n    def _attn(self, i: int, kind: str):\n        \"\"\"Get q, k, v, o kind for layer i.\"\"\"\n        if self.fp8_quant and kind == 'weight_scale_inv':\n            result = []\n            for key in ['q', 'k', 'v', 'o']:\n                tensor = self.params.get(f'{self.attn_layer_prefix}.{i}.{self.attn_pattern}.w{key}.{kind}')\n                tensor = self.transform(tensor, kind)\n                result.append(tensor)\n            return (*result, )\n        q, k, v = (None, ) * 3\n        kv_head_num = self.model_cfg['num_key_value_heads']\n        gs = int(self.model_cfg['num_attention_heads'] / kv_head_num)\n        qkv = self.params.get(f'{self.attn_layer_prefix}.{i}.attention.wqkv.{kind}')\n        qkv = self.transform(qkv, kind)\n        if qkv is not None:\n            qkv = qkv.view(kv_head_num, gs + 2, 128, -1)\n            hidden_dim = qkv.shape[-1]\n            q, k, v = torch.split(qkv, [gs, 1, 1], dim=1)\n            q = q.reshape(-1, hidden_dim)\n            k = k.reshape(-1, hidden_dim)\n            v = v.reshape(-1, hidden_dim)\n        o = self.params.get(f'{self.attn_layer_prefix}.{i}.attention.wo.{kind}')\n        o = self.transform(o, kind)\n        return (q, k, v, o)\n\n    def attn_norm(self, i: int):\n        \"\"\"Get attn norm for layer i.\"\"\"\n        return self.params[f'{self.attn_layer_prefix}.{i}.attention_norm.weight']\n\n    def _ffn(self, i: int, kind: str):\n        \"\"\"Get ffn kind for layer i.\"\"\"\n        if 
not kind:\n            return self.filter(self.ffn_pattern, i)\n        result = []\n        for key in ['w1', 'w2', 'w3']:\n            tensor = self.params[f'{self.attn_layer_prefix}.{i}.feed_forward.{key}.{kind}']\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def ffn_norm(self, i: int):\n        \"\"\"Get ffn norm for layer i.\"\"\"\n        return self.params[f'{self.attn_layer_prefix}.{i}.ffn_norm.weight']\n\n\n@INPUT_MODELS.register_module(name='internlm2')\nclass InternLM2Model(LlamaModel):\n    \"\"\"InternLM2 model in hf format.\"\"\"\n\n    Reader = InternLM2Reader\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/internvl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .base import INPUT_MODELS\nfrom .gpt_oss import GptOssReader\nfrom .internlm2 import InternLM2Reader\nfrom .llama import LlamaModel, LlamaReader\nfrom .qwen import Qwen3MoeReader, Qwen3Reader\n\n\nclass InternVLReader(LlamaReader):\n    \"\"\"InternVLReader for llama model.\"\"\"\n\n    attn_layer_prefix = 'language_model.model.layers'\n    attn_layer_patten = r'language_model\\.model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'language_model.model.embed_tokens.weight'\n    norm_weight_key = 'language_model.model.norm.weight'\n    output_weight_key = 'language_model.lm_head.weight'\n\n\n# Note the subtle difference in keys\nclass InternVL2Reader(InternLM2Reader):\n    \"\"\"InternVLReader for InternLM2 model.\"\"\"\n\n    attn_layer_prefix = 'language_model.model.layers'\n    attn_layer_patten = r'language_model\\.model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'language_model.model.tok_embeddings.weight'\n    norm_weight_key = 'language_model.model.norm.weight'\n    output_weight_key = 'language_model.output.weight'\n\n\nclass InternVL3d5Reader(Qwen3Reader):\n    attn_layer_prefix = 'language_model.model.layers'\n    attn_layer_patten = r'language_model\\.model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'language_model.model.embed_tokens.weight'\n    norm_weight_key = 'language_model.model.norm.weight'\n    output_weight_key = 'language_model.lm_head.weight'\n\n\nclass InternVL3d5Qwen3MoEReader(Qwen3MoeReader):\n    attn_layer_prefix = 'language_model.model.layers'\n    attn_layer_patten = r'language_model\\.model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'language_model.model.embed_tokens.weight'\n    norm_weight_key = 'language_model.model.norm.weight'\n    output_weight_key = 'language_model.lm_head.weight'\n\n\nclass InternVL3d5GptOSSReader(GptOssReader):\n    attn_layer_prefix = 'language_model.model.layers'\n    attn_layer_patten = 
r'language_model\\.model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'language_model.model.embed_tokens.weight'\n    norm_weight_key = 'language_model.model.norm.weight'\n    output_weight_key = 'language_model.lm_head.weight'\n\n\nclass InternS1Reader(Qwen3MoeReader):\n    \"\"\"InternS1Reader for internlm/InternS1 model.\"\"\"\n\n    attn_layer_prefix = 'model.language_model.layers'\n    attn_layer_patten = r'model\\.language_model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'model.language_model.embed_tokens.weight'\n    norm_weight_key = 'model.language_model.norm.weight'\n    output_weight_key = 'lm_head.weight'\n\n\nclass InternS1MiniReader(Qwen3Reader):\n\n    attn_layer_prefix = 'model.language_model.layers'\n    attn_layer_patten = r'model\\.language_model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'model.language_model.embed_tokens.weight'\n    norm_weight_key = 'model.language_model.norm.weight'\n    output_weight_key = 'lm_head.weight'\n\n\n@INPUT_MODELS.register_module(name='internvl')\nclass InternVLModel(LlamaModel):\n    \"\"\"InternVL model in hf format.\"\"\"\n\n    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):\n        super().__init__(model_path, tokenizer_path, **kwargs)\n        from transformers import AutoConfig\n        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n\n        arch = config.architectures[0]\n        if arch == 'InternVLChatModel' or arch == 'InternVLForConditionalGeneration':\n            relations = dict(InternLM2ForCausalLM=('internlm2', InternVL2Reader),\n                             LlamaForCausalLM=('llama', InternVLReader),\n                             Qwen2ForCausalLM=('qwen2', InternVLReader),\n                             Qwen3MoeForCausalLM=('qwen3-moe', InternVL3d5Qwen3MoEReader),\n                             Qwen3ForCausalLM=('qwen3', InternVL3d5Reader),\n                             GptOssForCausalLM=('gpt-oss', InternVL3d5GptOSSReader))\n        
elif arch == 'InternS1ForConditionalGeneration':\n            relations = dict(Qwen3MoeForCausalLM=('qwen3-moe', InternS1Reader),\n                             Qwen3ForCausalLM=('qwen3', InternS1MiniReader))\n        else:\n            raise ValueError(f'unsupported model arch {arch}')\n        self.llm_config = getattr(config, 'llm_config', None) or getattr(config, 'text_config', None)\n        arch = self.llm_config.architectures[0]\n        llm_model, self.Reader = relations[arch]\n        self.llm_model = INPUT_MODELS.get(llm_model)(model_path=model_path, tokenizer_path=tokenizer_path, **kwargs)\n\n    def model_info(self):\n        \"\"\"Read model info.\"\"\"\n        self.llm_model.model_config = self.llm_config.to_dict()\n        return self.llm_model.model_info()\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/llama.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport math\nimport re\n\nimport torch\n\nfrom lmdeploy.archs import get_model_arch\n\nfrom ..config import RopeParam\nfrom ..loader import create_loader\nfrom .base import INPUT_MODELS, BaseInputModel, BaseReader\n\n\nclass LlamaReader(BaseReader):\n    \"\"\"LlamaReader.\"\"\"\n\n    attn_layer_prefix = 'model.layers'\n    attn_layer_patten = r'model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'model.embed_tokens.weight'\n    norm_weight_key = 'model.norm.weight'\n    output_weight_key = 'lm_head.weight'\n\n    attn_pattern = r'self_attn'\n    ffn_pattern = r'mlp'\n\n    proj_pattern = 'proj'\n    scale_inv_suffix = '_scale_inv'\n\n    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, policy, fp8_quant=False):\n        super().__init__()\n        self.params = unused_params\n        self.params.update(new_params)\n        self.last_bin = last_bin\n        self.model_cfg = model_cfg\n        tie_word_embeddings = self.model_cfg.get('tie_word_embeddings', False)\n        if tie_word_embeddings:\n            self.output_weight_key = self.tok_embeddings_key\n        self.processor = policy\n        self.fp8_quant = fp8_quant\n        if self.fp8_quant:\n            quant_params = self.quant_weight_fp8()\n            self.params.update(quant_params)\n\n    def quant_weight_fp8(self):\n        from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8\n        pattern_str = fr'({self.attn_pattern}|{self.ffn_pattern}).*{self.proj_pattern}.*\\.weight'\n        target_pattern = re.compile(pattern_str)\n\n        if self.__class__.__name__ == 'InternLM2Reader':\n            skip_pattern = re.compile(r'wqkv.*\\.weight')\n        else:\n            skip_pattern = None\n\n        quant_params = {}\n        for name, weight in self.params.items():\n            if target_pattern.search(name) and name.endswith('.weight'):\n                if skip_pattern 
and skip_pattern.search(name):\n                    continue\n                q_weight, scale = quant_blocked_fp8(weight, torch.float8_e4m3fn, block_size=128)\n                quant_params[name] = q_weight\n                quant_params[f'{name}{self.scale_inv_suffix}'] = scale.to(weight.dtype)\n\n        return quant_params\n\n    def filter(self, pattern: str, i: int | None):\n        params = []\n        for k in self.params.keys():\n            if re.search(pattern, k):\n                params.append(k)\n        return params\n\n    def tok_embeddings(self):\n        \"\"\"Get embeddings.\"\"\"\n        return self.transform(self.params.get(self.tok_embeddings_key, None), 'weight')\n\n    def norm_weight(self):\n        \"\"\"Get norm.\"\"\"\n        return self.transform(self.params.get(self.norm_weight_key, None), 'weight')\n\n    def output_weight(self):\n        \"\"\"Get output.\"\"\"\n        return self.transform(self.params.get(self.output_weight_key, None), 'weight')\n\n    def _transform(self, x: torch.Tensor, kind: str):\n        return self.processor(x, kind)\n\n    def _attn(self, i: int, kind: str):\n        \"\"\"Get q, k, v, o kind for layer i.\"\"\"\n        result = []\n        for key in ['q', 'k', 'v', 'o']:\n            tensor = self.params.get(f'{self.attn_layer_prefix}.{i}.self_attn.{key}_proj.{kind}')\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def attn(self, i: int, kind: str):\n        if not kind:\n            return self.filter(self.attn_pattern, i)\n        return self._attn(i, kind)\n\n    def attn_norm(self, i: int):\n        \"\"\"Get attn norm for layer i.\"\"\"\n        return self.transform(self.params[f'{self.attn_layer_prefix}.{i}.input_layernorm.weight'], 'weight')\n\n    def _ffn(self, i: int, kind: str):\n        \"\"\"Get ffn kind for layer i.\"\"\"\n        if not kind:\n            return self.filter(self.ffn_pattern, i)\n        result = []\n  
      for key in ['gate', 'down', 'up']:\n            tensor = self.params[f'{self.attn_layer_prefix}.{i}.mlp.{key}_proj.{kind}']\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def ffn(self, i: int, kind: str):\n        if not kind:\n            return self.filter(self.ffn_pattern, i)\n        return self._ffn(i, kind)\n\n    def ffn_norm(self, i: int):\n        \"\"\"Get ffn norm for layer i.\"\"\"\n        return self.transform(self.params[f'{self.attn_layer_prefix}.{i}.post_attention_layernorm.weight'], 'weight')\n\n\n@INPUT_MODELS.register_module(name='llama')\nclass LlamaModel(BaseInputModel):\n    \"\"\"Llama model in hf format.\"\"\"\n\n    Reader = LlamaReader\n\n    def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):\n        super().__init__(model_path, tokenizer_path)\n        self.policy = kwargs.get('input_policy')\n        _, model_config = get_model_arch(model_path)\n        if hasattr(model_config, 'text_config'):\n            model_config = model_config.text_config\n        elif hasattr(model_config, 'llm_config'):\n            model_config = model_config.llm_config\n        if hasattr(model_config, 'to_dict'):\n            self.model_config = model_config.to_dict()\n        else:\n            self.model_config = model_config\n        self.fp8_quant = kwargs.get('fp8_quant', False)\n\n    def readers(self):\n        mappings = getattr(self.Reader, 'mappings', [])\n        loader = create_loader(self.model_path, self.Reader.attn_layer_patten, mappings)\n        for i, param in loader.items():\n            reader = self.Reader(param, {}, False, self.model_config, policy=self.policy, fp8_quant=self.fp8_quant)\n            yield i, reader\n        torch.cuda.empty_cache()\n\n    def model_info(self):\n        \"\"\"Read model info.\"\"\"\n        model_arg = self.model_config\n        num_layer = model_arg['num_hidden_layers']\n        norm_eps = 
model_arg['rms_norm_eps']\n        attn_head_num = model_arg['num_attention_heads']\n        vocab_size = model_arg['vocab_size']\n        inter_size = model_arg.get('intermediate_size', 0)\n        if 'num_key_value_heads' in model_arg:\n            kv_head_num = model_arg['num_key_value_heads']\n        else:\n            kv_head_num = model_arg['num_attention_heads']\n        hidden_units = model_arg['hidden_size']\n        # head_dim could be none in config\n        head_dim = model_arg.get('head_dim', None)\n        head_dim = head_dim or hidden_units // attn_head_num\n        # compute rope param\n        if 'rope_parameters' in model_arg:\n            # transformers v5.0.0 aggregates rope settings into rope_parameters\n            rope_scaling = model_arg['rope_parameters']\n            rope_theta = float(rope_scaling.get('rope_theta', 10000.0))\n        else:\n            rope_theta = float(model_arg.get('rope_theta', 10000.0))\n            rope_scaling = model_arg.get('rope_scaling', None)\n        max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))\n        rope_param = RopeParam(type='default', base=rope_theta, dim=head_dim)\n        if isinstance(rope_scaling, dict):\n            rope_type = rope_scaling.get('rope_type', '') or rope_scaling.get('type', '')\n            if rope_scaling.get('mrope_section') is not None:\n                # TODO: treat mrope as an option to the common rope functions\n                rope_type = 'mrope'\n            scaling_factor = rope_scaling.get('factor', 0.0)\n            if rope_type == 'default':\n                pass\n            elif rope_type == 'dynamic':\n                rope_param.type = 'dynamic'\n                rope_param.factor = scaling_factor\n                rope_param.max_position_embeddings = max_position_embeddings\n            elif rope_type == 'linear':\n                rope_param.type = 'linear'\n                rope_param.factor = scaling_factor\n            elif rope_type == 
'llama3':\n                low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)\n                high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)\n                original_max_position_embeddings = rope_scaling.get('original_max_position_embeddings', 0)\n                rope_param.type = 'llama3'\n                rope_param.factor = scaling_factor\n                rope_param.low_freq_factor = low_freq_factor\n                rope_param.high_freq_factor = high_freq_factor\n                rope_param.original_max_position_embeddings = original_max_position_embeddings\n            elif rope_type == 'yarn':\n                attention_factor = rope_scaling.get('attention_factor', None)\n                if attention_factor is None:\n                    attention_factor = 0.1 * math.log(scaling_factor) + 1.0\n                beta_fast = rope_scaling.get('beta_fast', 32.0)\n                beta_slow = rope_scaling.get('beta_slow', 1.0)\n                rope_param.type = 'yarn'\n                if 'original_max_position_embeddings' in rope_scaling:\n                    original_max_position_embeddings = rope_scaling['original_max_position_embeddings']\n                    scaling_factor = max_position_embeddings / original_max_position_embeddings\n                else:\n                    original_max_position_embeddings = max_position_embeddings\n                rope_param.factor = scaling_factor\n                rope_param.max_position_embeddings = original_max_position_embeddings\n                rope_param.attention_factor = attention_factor\n                rope_param.beta_fast = beta_fast\n                rope_param.beta_slow = beta_slow\n            elif rope_type == 'mrope':\n                mrope_section = rope_scaling.get('mrope_section')\n                rope_param.type = 'mrope'\n                rope_param.mrope_section = mrope_section\n            else:\n                raise RuntimeError(f'Unsupported rope type: {rope_type}')\n\n        
return dict(size_per_head=head_dim,\n                    num_layer=num_layer,\n                    norm_eps=norm_eps,\n                    head_num=attn_head_num,\n                    kv_head_num=kv_head_num,\n                    hidden_units=hidden_units,\n                    inter_size=inter_size,\n                    vocab_size=vocab_size,\n                    max_position_embeddings=max_position_embeddings,\n                    rope_param=rope_param)\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/llava.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport os.path as osp\n\nfrom ..config import RopeParam\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass LlavaReader(LlamaReader):\n    \"\"\"LlavaReader for llama model.\"\"\"\n\n    attn_layer_prefix = 'language_model.model.layers'\n    attn_layer_patten = r'language_model\\.model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'language_model.model.embed_tokens.weight'\n    norm_weight_key = 'language_model.model.norm.weight'\n    output_weight_key = 'language_model.lm_head.weight'\n\n    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, **kwargs):\n        # LlamaModel may already have unwrapped text_config, so fall back to model_cfg itself\n        model_cfg = model_cfg.get('text_config', model_cfg)\n        super().__init__(new_params, unused_params, last_bin, model_cfg, **kwargs)\n\n\n@INPUT_MODELS.register_module(name='llava')\nclass LlavaModel(LlamaModel):\n    \"\"\"LlavaModel model in hf format.\"\"\"\n\n    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):\n        super().__init__(model_path, tokenizer_path, **kwargs)\n        from transformers import AutoConfig\n        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n        config = getattr(config, 'text_config', config)\n        arch = config.architectures[0]\n        _readers = dict(Qwen2ForCausalLM=LlavaReader, LlamaForCausalLM=LlavaReader)\n        self.Reader = _readers[arch]\n        self.arch = arch\n\n    def model_info(self):\n        \"\"\"Read model info for LlavaForConditionalGeneration.\n\n        https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf\n        \"\"\"\n        params_path = osp.join(self.model_path, 'config.json')\n        with open(params_path) as f:\n            model_arg = json.load(f)['text_config']\n            num_layer = model_arg.get('num_hidden_layers', 32)\n            norm_eps = model_arg.get('rms_norm_eps', 1e-6)\n            attn_head_num = 
model_arg.get('num_attention_heads', 32)\n            if 'num_key_value_heads' in model_arg:\n                kv_head_num = model_arg.get('num_key_value_heads', 32)\n            else:\n                kv_head_num = model_arg.get('num_attention_heads', 32)\n            rope_theta = float(model_arg.get('rope_theta', 10000.0))\n            max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))\n            rope_scaling = model_arg.get('rope_scaling', None)\n            scaling_factor = 0.0\n            scaling_type = 'default'\n\n            # special for the model: llava-hf/llava-interleave-qwen-7b-hf\n            hidden_units = model_arg.get('hidden_size', 4096)\n            vocab_size = model_arg.get('vocab_size', 152000)\n            intermediate_size = model_arg.get('intermediate_size', 11008)\n            attn_bias = 1 if model_arg['architectures'][0] \\\n                == 'Qwen2ForCausalLM' else 0\n            attn_bias = int(model_arg.get('attn_bias', attn_bias))\n            use_logn_attn = int(model_arg.get('use_logn_attn', 0))\n\n            if isinstance(rope_scaling, dict):\n                scaling_type = model_arg['rope_scaling'].get('type', '')\n                scaling_factor = model_arg['rope_scaling'].get('factor', '')\n\n            rope_param = RopeParam(type=scaling_type,\n                                   base=rope_theta,\n                                   dim=hidden_units // attn_head_num,\n                                   max_position_embeddings=max_position_embeddings,\n                                   factor=scaling_factor)\n\n        return dict(num_layer=num_layer,\n                    norm_eps=norm_eps,\n                    head_num=attn_head_num,\n                    hidden_units=hidden_units,\n                    kv_head_num=kv_head_num,\n                    rope_param=rope_param,\n                    max_position_embeddings=max_position_embeddings,\n                    inter_size=intermediate_size,\n            
        use_logn_attn=use_logn_attn,\n                    attn_bias=attn_bias,\n                    vocab_size=vocab_size)\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/minicpmv.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport json\nimport os.path as osp\n\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass MiniCPMVReader(LlamaReader):\n    \"\"\"MiniCPMVReader for llama model.\"\"\"\n\n    attn_layer_prefix = 'llm.model.layers'\n    attn_layer_patten = r'llm\\.model\\.layers\\.([0-9]+).'\n    tok_embeddings_key = 'llm.model.embed_tokens.weight'\n    norm_weight_key = 'llm.model.norm.weight'\n    output_weight_key = 'llm.lm_head.weight'\n\n\n@INPUT_MODELS.register_module(name='minicpmv')\nclass MiniCPMVModel(LlamaModel):\n    \"\"\"MiniCPMV model in hf format.\"\"\"\n    Reader = MiniCPMVReader\n\n    def model_info(self):\n        info = super().model_info()\n        with open(osp.join(self.model_path, 'config.json')) as f:\n            config = json.load(f)\n            if str(config.get('version')) == '2.6':\n                info['attn_bias'] = True\n        return info\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/mixtral.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass MixtralReader(LlamaReader):\n\n    def moe_ffn_expert(self, e=None, i=None, kind=None):\n        if not kind:\n            return self.filter(r'experts', i)\n        result = []\n        for x in ['w1', 'w2', 'w3']:\n            name = f'model.layers.{i}.block_sparse_moe.experts.{e}.{x}.{kind}'\n            tensor = self.params.get(name)\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def moe_ffn_gate(self, i, kind):\n        return self.params.get(f'model.layers.{i}.block_sparse_moe.gate.{kind}')\n\n\n@INPUT_MODELS.register_module(name='mixtral')\nclass MixtralModel(LlamaModel):\n\n    Reader = MixtralReader\n\n    def model_info(self):\n        cfg = self.model_config\n        info = super().model_info()\n        info['expert_num'] = cfg['num_local_experts']\n        info['expert_inter_size'] = cfg['intermediate_size']\n        info['experts_per_token'] = cfg['num_experts_per_tok']\n        info['norm_topk_prob'] = True\n        info['inter_size'] = 0\n        return info\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/molmo.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport os.path as osp\n\nimport torch\n\nfrom ..config import RopeParam\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass MolmoReader(LlamaReader):\n    attn_layer_prefix = 'model.transformer.blocks'\n    attn_layer_patten = r'model\\.transformer\\.blocks\\.([0-9]+).'\n    norm_weight_key = 'model.transformer.ln_f.weight'\n    output_weight_key = 'model.transformer.ff_out.weight'\n\n    # In molmo, names of attention parameters are \"att_proj.bias\",\n    # \"att_proj.weight\", \"attn_norm.weight\", \"attn_out.weight\", and names\n    # of ffn parameters are \"ff_norm\", \"ff_out\", \"ff_proj\", so we\n    # set the patterns to r'att' and r'ff_', respectively.\n    attn_pattern = r'att'\n    ffn_pattern = r'ff_'\n\n    def tok_embeddings(self):\n        embed1 = self.params.get('model.transformer.wte.embedding', None)\n        embed2 = self.params.get('model.transformer.wte.new_embedding', None)\n        if embed1 is not None and embed2 is not None:\n            return torch.cat((embed1, embed2), dim=0)\n        else:\n            assert embed1 is None and embed2 is None\n            return None\n\n    def attn_norm(self, i: int):\n        \"\"\"Get attn norm for layer i.\"\"\"\n        return self.params[f'{self.attn_layer_prefix}.{i}.attn_norm.weight']\n\n    def _attn(self, i: int, kind: str):\n        \"\"\"Get q, k, v, o kind(weight, bias, qweight) for layer i.\n\n        Args:\n            i (int): layer id\n            kind (str): can be one of [\"weight\", \"bias\", \"qweight\"]\n        \"\"\"\n        q, k, v = (None, ) * 3\n        hidden_size = self.model_cfg['hidden_size']\n        head_num = self.model_cfg['num_attention_heads']\n        kv_head_num = self.model_cfg['num_key_value_heads']\n        head_dim = hidden_size // head_num\n        assert head_dim == 128\n        fused_dims = (hidden_size, kv_head_num * head_dim, kv_head_num * 
head_dim)\n        qkv = self.params.get(f'{self.attn_layer_prefix}.{i}.att_proj.{kind}')\n        qkv = self.transform(qkv, kind)\n        if qkv is not None:\n            q, k, v = qkv.split(fused_dims, dim=0)\n        o = self.params.get(f'{self.attn_layer_prefix}.{i}.attn_out.{kind}')\n        o = self.transform(o, kind)\n        if o is None:  # handle the case when qkv has bias but o doesn't\n            o = torch.zeros_like(q)\n        return (q, k, v, o)\n\n    def _ffn(self, i: int, kind: str):\n        \"\"\"Get ffn kind(weight, qweight) for layer i.\"\"\"\n        up_and_gate = self.params[f'{self.attn_layer_prefix}.{i}.ff_proj.{kind}']\n        up_and_gate = self.transform(up_and_gate, kind)\n        gate, up = up_and_gate.chunk(2, dim=0)\n        down = self.params[f'{self.attn_layer_prefix}.{i}.ff_out.{kind}']\n        down = self.transform(down, kind)\n        return (up, down, gate)\n\n    def ffn_norm(self, i: int):\n        \"\"\"Get ffn norm for layer i.\"\"\"\n        return self.params[f'{self.attn_layer_prefix}.{i}.ff_norm.weight']\n\n\n@INPUT_MODELS.register_module(name='molmo')\nclass MolmoModel(LlamaModel):\n\n    Reader = MolmoReader\n\n    def __init__(self, model_path: str, tokenizer_path: str, **kwargs):\n        super().__init__(model_path, tokenizer_path, **kwargs)\n        config_path = osp.join(self.model_path, 'config.json')\n        with open(config_path) as f:\n            self.config = json.load(f)\n\n    def model_info(self):\n        config = self.config\n        num_layer = config['num_hidden_layers']\n        norm_eps = config['layer_norm_eps']\n        attn_head_num = config['num_attention_heads']\n        kv_head_num = config['num_key_value_heads']\n        hidden_units = config['hidden_size']\n        rope_theta = config['rope_theta']\n        max_position_embeddings = config['max_position_embeddings']\n        vocab_size = config['vocab_size']\n        # 
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/modeling_molmo.py#L2041\n        additional_vocab_size = 128\n        inter_size = config['intermediate_size'] // 2\n        attn_bias = config['qkv_bias']\n        rope_param = RopeParam(type='default', base=rope_theta, dim=hidden_units // attn_head_num)\n        return dict(\n            num_layer=num_layer,\n            norm_eps=norm_eps,\n            head_num=attn_head_num,\n            kv_head_num=kv_head_num,\n            hidden_units=hidden_units,\n            attn_bias=int(attn_bias),\n            inter_size=inter_size,\n            vocab_size=vocab_size,\n            # https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/modeling_molmo.py#L564\n            embedding_size=vocab_size + additional_vocab_size,\n            rope_param=rope_param,\n            max_position_embeddings=max_position_embeddings,\n        )\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/qwen.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport json\nimport os.path as osp\nimport re\n\nimport torch\n\nfrom ..config import RopeParam\nfrom ..loader import create_loader\nfrom .base import INPUT_MODELS\nfrom .llama import LlamaModel, LlamaReader\n\n\nclass QwenReader(LlamaReader):\n    \"\"\"QwenReader.\"\"\"\n\n    attn_layer_patten = r'transformer\\.h\\.([0-9]+).'\n    tok_embeddings_key = 'transformer.wte.weight'\n    norm_weight_key = 'transformer.ln_f.weight'\n    output_weight_key = 'lm_head.weight'\n\n    attn_pattern = r'attn'\n    ffn_pattern = r'mlp'\n\n    def _attn(self, i: int, kind: str):\n        \"\"\"Get q, k, v, o kind for layer i.\"\"\"\n        q, k, v, o = (None, ) * 4\n        qkv = self.params[f'transformer.h.{i}.attn.c_attn.{kind}']\n        qkv = self.transform(qkv, kind)\n        if qkv is not None:\n            q, k, v = torch.split(qkv, qkv.size(0) // 3, dim=0)\n        o = self.params.get(f'transformer.h.{i}.attn.c_proj.{kind}')\n        o = self.transform(o, kind)\n        if o is None:\n            o = torch.zeros_like(q)\n        return q, k, v, o\n\n    def attn_norm(self, i: int):\n        \"\"\"Get attn norm for layer i.\"\"\"\n        return self.params[f'transformer.h.{i}.ln_1.weight']\n\n    def _ffn(self, i: int, kind: str):\n        \"\"\"Get ffn kind for layer i.\"\"\"\n        result = []\n        for key in ['w2', 'c_proj', 'w1']:\n            tensor = self.params[f'transformer.h.{i}.mlp.{key}.{kind}']\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def ffn_norm(self, i: int):\n        \"\"\"Get ffn norm for layer i.\"\"\"\n        return self.params[f'transformer.h.{i}.ln_2.weight']\n\n\n@INPUT_MODELS.register_module(name='qwen')\nclass QwenModel(LlamaModel):\n    \"\"\"Qwen model in hf format.\"\"\"\n\n    Reader = QwenReader\n\n    def model_info(self):\n        \"\"\"Read model info.\"\"\"\n        params_path = 
osp.join(self.model_path, 'config.json')\n        with open(params_path) as f:\n            config = json.load(f)\n            hidden_units = config['hidden_size']\n            num_layer = config['num_hidden_layers']\n            norm_eps = config['layer_norm_epsilon']\n            kv_channels = config['kv_channels']\n            rope_theta = float(config.get('rotary_emb_base', 10000.0))\n            if 'num_key_value_heads' in config:\n                kv_head_num = config['num_key_value_heads']\n            else:\n                kv_head_num = config['num_attention_heads']\n            attn_head_num = config['num_attention_heads']\n            seq_length = config['seq_length']\n            use_dynamic_ntk = int(config['use_dynamic_ntk'])\n            use_logn_attn = int(config['use_logn_attn'])\n            vocab_size = config['vocab_size']\n            inter_size = config['intermediate_size']\n            scaling_type = 'dynamic' if use_dynamic_ntk else 'default'\n            # rope_scaling_factor must be set in TurbomindEngineConfig when scaling_type is 'dynamic'\n            rope_param = RopeParam(type=scaling_type,\n                                   base=rope_theta,\n                                   dim=kv_channels,\n                                   max_position_embeddings=seq_length,\n                                   factor=0)\n\n        return dict(size_per_head=kv_channels,\n                    num_layer=num_layer,\n                    norm_eps=norm_eps,\n                    hidden_units=hidden_units,\n                    head_num=attn_head_num,\n                    kv_head_num=kv_head_num,\n                    vocab_size=vocab_size,\n                    inter_size=inter_size,\n                    attn_bias=1,\n                    rope_param=rope_param,\n                    max_position_embeddings=seq_length,\n                    use_dynamic_ntk=int(use_dynamic_ntk),\n                    
use_logn_attn=use_logn_attn)\n\n\n@INPUT_MODELS.register_module(name='qwen2')\nclass Qwen2Model(LlamaModel):\n    \"\"\"Qwen model in hf format.\n\n    The weight of qwen2 model is similar to Llama, except its attention bias doesn't include o_proj bias.\n    \"\"\"\n\n    Reader = LlamaReader\n\n    def model_info(self):\n        cfg = super().model_info()\n        cfg['attn_bias'] = 1\n        return cfg\n\n\nclass Qwen2MoeReader(LlamaReader):\n\n    def moe_ffn_expert(self, e=None, i=None, kind=None):\n        if not kind:\n            return self.filter(r'experts', i)\n        result = []\n        for key in ['gate', 'down', 'up']:\n            name = f'{self.attn_layer_prefix}.{i}.mlp.experts.{e}.{key}_proj.{kind}'\n            tensor = self.params.get(name)\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def moe_ffn_gate(self, i, kind):\n        return self.transform(self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.gate.{kind}'), kind)\n\n    def _ffn(self, i: int, kind: str):\n        \"\"\"Get ffn kind for layer i.\"\"\"\n        if not kind:\n            return self.filter(r'shared_expert\\.', i)\n        result = []\n        for key in ['gate', 'down', 'up']:\n            tensor = self.params[f'{self.attn_layer_prefix}.{i}.mlp.shared_expert.{key}_proj.{kind}']\n            tensor = self.transform(tensor, kind)\n            result.append(tensor)\n        return (*result, )\n\n    def ffn(self, i: int, kind: str):\n        if not kind:\n            return self.filter(r'shared_expert\\.', i)\n        return self._ffn(i, kind)\n\n    def moe_ffn_shared_gate(self, i):\n        return self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.shared_expert_gate.weight')\n\n\n@INPUT_MODELS.register_module(name='qwen2-moe')\nclass Qwen2MoeModel(LlamaModel):\n\n    Reader = Qwen2MoeReader\n\n    def model_info(self):\n        cfg = self.model_config\n        info = super().model_info()\n        
info['expert_num'] = cfg['num_experts']\n        info['expert_inter_size'] = cfg['moe_intermediate_size']\n        info['experts_per_token'] = cfg['num_experts_per_tok']\n        info['inter_size'] = cfg['shared_expert_intermediate_size']\n        info['moe_shared_gate'] = True\n        info['norm_topk_prob'] = cfg['norm_topk_prob']\n        info['attn_bias'] = cfg.get('qkv_bias', 1)\n        return info\n\n\nclass Qwen3Reader(LlamaReader):\n\n    def qk_norm(self, i: int):\n        result = []\n        for x in ['q', 'k']:\n            name = f'{self.attn_layer_prefix}.{i}.self_attn.{x}_norm.weight'\n            result.append(self.transform(self.params.get(name), 'weight'))\n        return (*result, )\n\n\n@INPUT_MODELS.register_module(name='qwen3')\nclass Qwen3Model(LlamaModel):\n    Reader = Qwen3Reader\n\n    def model_info(self):\n        cfg = self.model_config\n        info = super().model_info()\n        info.update(qk_norm=True, attn_bias=cfg.get('attention_bias', 0))\n        return info\n\n\nclass Qwen3MoeReader(Qwen2MoeReader):\n\n    def qk_norm(self, i: int):\n        result = []\n        for x in ['q', 'k']:\n            name = f'{self.attn_layer_prefix}.{i}.self_attn.{x}_norm.weight'\n            result.append(self.transform(self.params.get(name), 'weight'))\n        return (*result, )\n\n\n@INPUT_MODELS.register_module(name='qwen3-moe')\nclass Qwen3MoeModel(LlamaModel):\n    Reader = Qwen3MoeReader\n\n    def model_info(self):\n        cfg = self.model_config\n        info = super().model_info()\n        info.update(\n            qk_norm=True,\n            expert_num=cfg.get('num_experts', 128),\n            experts_per_token=cfg.get('num_experts_per_tok', 8),\n            expert_inter_size=cfg.get('moe_intermediate_size', 768),\n            attn_bias=cfg.get('attention_bias', 0),\n            inter_size=0,  # no shared expert\n            norm_topk_prob=cfg.get('norm_topk_prob', False))\n        return info\n\n\nclass Qwen3_5ReaderMixin:\n    
\"\"\"Mixin providing linear attention weight reading for Qwen3.5 models.\n\n    Qwen3.5 uses a zero-centered RMSNorm: ``output = norm(x) * (1 + weight)``\n    where weight is initialized to zeros.  TurboMind's RMSNorm kernel computes\n    ``norm(x) * weight`` (standard LLaMA style), so we add 1 to every\n    RMSNorm weight during export.  The GDN-internal norm\n    (``Qwen3_5MoeRMSNormGated``) uses standard weight and is NOT affected.\n    \"\"\"\n\n    attn_layer_pattern = r'(?:model\\.language_model\\.|model\\.)layers\\.([0-9]+)\\.'\n\n    _LINEAR_ATTN_KEYS = ['conv1d', 'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a', 'out_proj', 'A_log', 'dt_bias']\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        if any(k.startswith('model.language_model.') for k in self.params.keys()):\n            self.attn_layer_prefix = 'model.language_model.layers'\n            self.tok_embeddings_key = 'model.language_model.embed_tokens.weight'\n            self.norm_weight_key = 'model.language_model.norm.weight'\n        tie_word_embeddings = self.model_cfg.get('tie_word_embeddings', False)\n        if tie_word_embeddings:\n            self.output_weight_key = self.tok_embeddings_key\n\n    # ---- zero-centered RMSNorm: add 1 to weights during export ----\n    def attn_norm(self, i: int):\n        w = super().attn_norm(i)\n        if w is not None:\n            w = w.float() + 1.0\n        return w\n\n    def ffn_norm(self, i: int):\n        w = super().ffn_norm(i)\n        if w is not None:\n            w = w.float() + 1.0\n        return w\n\n    def norm_weight(self):\n        w = super().norm_weight()\n        if w is not None:\n            w = w.float() + 1.0\n        return w\n\n    def qk_norm(self, i: int):\n        result = super().qk_norm(i)\n        return tuple(w.float() + 1.0 if w is not None else w for w in result)\n\n    # ---- handle mixed QKV(fp16) + O(AWQ) attention layers -------\n\n    def _attn(self, i: int, kind: 
str):\n        \"\"\"Override to handle mixed QKV(fp16) + O(AWQ) attention layers.\n\n        Some AWQ-quantized Qwen3.5 models keep QKV in fp16 while quantizing only the O projection.  TurboMind requires\n        uniform weight types per layer, so we dequantize O to fp16 at export time.\n        \"\"\"\n        prefix = f'{self.attn_layer_prefix}.{i}.self_attn'\n        q_is_fp16 = f'{prefix}.q_proj.weight' in self.params\n        o_is_awq = f'{prefix}.o_proj.qweight' in self.params\n\n        if not (q_is_fp16 and o_is_awq):\n            # Not a mixed-format layer, use standard behaviour.\n            return super()._attn(i, kind)\n\n        # Mixed format detected: QKV are fp16 but O is AWQ.\n        if kind == 'weight':\n            # Get fp16 QKV the normal way, then dequantize O.\n            q, k, v, _ = super()._attn(i, kind)\n            o = self._awq_dequant(f'{prefix}.o_proj')\n            o = self.transform(o, kind)\n            return (q, k, v, o)\n\n        # For any quant kind (qweight/scales/qzeros), return all None\n        # so that the AWQ handler skips this layer entirely — the O\n        # weight is already handled via dequantization above.\n        return (None, None, None, None)\n\n    def _awq_dequant(self, prefix: str):\n        \"\"\"Dequantize an AWQ-quantized linear layer to fp16.\n\n        AWQ stores weights in transposed form relative to PyTorch's\n        convention ([in, out] vs [out, in]), so we transpose here to\n        match the fp16 ``.weight`` layout that downstream export\n        expects.\n        \"\"\"\n        from lmdeploy.pytorch.backends.default.awq_modules import dequantize_gemm\n        qweight = self.params[f'{prefix}.qweight']\n        scales = self.params[f'{prefix}.scales']\n        qzeros = self.params[f'{prefix}.qzeros']\n        group_size = qweight.shape[0] // scales.shape[0]\n        w = dequantize_gemm(qweight, qzeros, scales, 4, group_size)\n        return w.t()  # [in, out] → [out, in] (PyTorch 
convention)\n\n    def linear_attn(self, i: int, kind: str):\n        if not kind:\n            return self.filter(r'linear_attn\\.', i)\n        # Always return a fixed-length tuple with None placeholders to\n        # preserve positional alignment with the name list in module.py.\n        result = []\n        for key in self._LINEAR_ATTN_KEYS:\n            prefix = f'{self.attn_layer_prefix}.{i}.linear_attn.{key}'\n            tensor = self.params.get(f'{prefix}.{kind}')\n            # A_log and dt_bias are bare nn.Parameter (no .weight suffix)\n            if tensor is None:\n                tensor = self.params.get(prefix)\n            # If requesting weight but only AWQ qweight exists,\n            # dequantize on the fly so LinearAttn gets fp16 tensors.\n            if tensor is None and kind == 'weight':\n                if f'{prefix}.qweight' in self.params:\n                    tensor = self._awq_dequant(prefix)\n            if tensor is not None:\n                tensor = self.transform(tensor, kind)\n            result.append(tensor)  # keep None to preserve alignment\n        if all(t is None for t in result):\n            return tuple()\n        return tuple(result)\n\n    def linear_norm(self, i: int, kind: str = 'weight'):\n        tensor = self.params.get(f'{self.attn_layer_prefix}.{i}.linear_attn.norm.{kind}')\n        if tensor is not None:\n            return self.transform(tensor, kind)\n        return None\n\n\nclass Qwen3_5Reader(Qwen3_5ReaderMixin, Qwen3Reader):\n    pass\n\n\n@INPUT_MODELS.register_module(name='qwen3_5')\nclass Qwen3_5Model(Qwen3Model):\n    Reader = Qwen3_5Reader\n\n    def model_info(self):\n        if 'text_config' in self.model_config:\n            self.model_config = self.model_config['text_config']\n        cfg = self.model_config\n        info = super().model_info()\n        # MoE parameters (same as Qwen2MoeModel / Qwen3MoeModel)\n        info['expert_num'] = cfg.get('num_experts', 0)\n        
info['expert_inter_size'] = cfg.get('moe_intermediate_size', 0)\n        info['experts_per_token'] = cfg.get('num_experts_per_tok', 0)\n        # For MoE models, inter_size is the shared expert intermediate size;\n        # for dense models, keep the value from super() (intermediate_size).\n        shared_expert_size = cfg.get('shared_expert_intermediate_size')\n        if shared_expert_size is not None:\n            info['inter_size'] = shared_expert_size\n        info['moe_shared_gate'] = True\n        # Qwen3.5 uses softmax MoE routing\n        info['scoring_func'] = 'softmax'\n        info['norm_topk_prob'] = True\n        # Fix RoPE dim for partial_rotary_factor\n        rope_params = cfg.get('rope_parameters', {})\n        partial_rotary_factor = rope_params.get('partial_rotary_factor', cfg.get('partial_rotary_factor', 1.0))\n        if partial_rotary_factor < 1.0:\n            info['rope_param'].dim = int(info['size_per_head'] * partial_rotary_factor)\n        # Linear attention parameters\n        info['layer_types'] = cfg.get('layer_types', [])\n        info['linear_key_head_dim'] = cfg.get('linear_key_head_dim', 0)\n        info['linear_value_head_dim'] = cfg.get('linear_value_head_dim', 0)\n        info['linear_conv_kernel_dim'] = cfg.get('linear_conv_kernel_dim', 0)\n        info['linear_num_key_heads'] = cfg.get('linear_num_key_heads', 0)\n        info['linear_num_value_heads'] = cfg.get('linear_num_value_heads', 0)\n        # attn_output_gate doubles Q projection for full-attention layers\n        info['attn_output_gate'] = cfg.get('attn_output_gate', False)\n        return info\n\n\nclass Qwen3_5MoeReader(Qwen3_5ReaderMixin, Qwen3MoeReader):\n\n    def _unpacked_moe_expert(self, e: int, i: int, kind: str):\n        prefix = f'{self.attn_layer_prefix}.{i}.mlp.experts'\n        gate_up = self.params.get(f'{prefix}.gate_up_proj.{kind}')\n        down = self.params.get(f'{prefix}.down_proj.{kind}')\n        if gate_up is None or down is 
None:\n            return None\n\n        # Packed Qwen3.5 MoE checkpoints store all experts in the first\n        # dimension. Slice one expert before transform so quantized policies\n        # still see a 2D tensor.\n        gate_up = self.transform(gate_up[e], kind)\n        down = self.transform(down[e], kind)\n        gate, up = gate_up.chunk(2, dim=0)\n        return (gate, down, up)\n\n    def moe_ffn_expert(self, e=None, i=None, kind=None):\n        if not kind:\n            return self.filter(r'experts', i)\n        unpacked = self._unpacked_moe_expert(e, i, kind)\n        if unpacked is not None:\n            return unpacked\n\n        return super().moe_ffn_expert(e, i, kind)\n\n\n@INPUT_MODELS.register_module(name='qwen3_5-moe')\nclass Qwen3_5MoeModel(Qwen3MoeModel):\n    Reader = Qwen3_5MoeReader\n\n    @staticmethod\n    def map_packed_qwen35_experts(name: str):\n        \"\"\"Map packed expert names to weight names, i.e.,\n        \"mlp.experts.gate_up_proj\" -> \"mlp.experts.gate_up_proj.weight\" so that\n        class Weight in parameter.py can classify them.\"\"\"\n        s = re.sub(r'(mlp\\.experts\\.(?:gate_up|down)_proj)$', r'\\1.weight', name)\n        return s\n\n    def readers(self):\n        pattern = getattr(self.Reader, 'attn_layer_pattern', self.Reader.attn_layer_patten)\n        loader = create_loader(self.model_path, pattern, [])\n\n        has_packed_gate_up = any('mlp.experts.gate_up_proj' in k for k in loader.index.keys())\n        has_packed_down = any('mlp.experts.down_proj' in k for k in loader.index.keys())\n        if has_packed_gate_up and has_packed_down:\n            loader.mappings = [self.map_packed_qwen35_experts]\n\n        for i, param in loader.items():\n            reader = self.Reader(param, {}, False, self.model_config, policy=self.policy, fp8_quant=self.fp8_quant)\n            yield i, reader\n        torch.cuda.empty_cache()\n\n    def model_info(self):\n        if 'text_config' in self.model_config:\n           
 self.model_config = self.model_config['text_config']\n        cfg = self.model_config\n        info = super().model_info()\n        # Shared expert params (missing from Qwen3MoeModel base)\n        info['inter_size'] = cfg.get('shared_expert_intermediate_size', 0)\n        info['moe_shared_gate'] = True\n        # Qwen3.5 uses softmax MoE routing\n        info['scoring_func'] = 'softmax'\n        info['norm_topk_prob'] = True\n        # Fix RoPE dim for partial_rotary_factor\n        rope_params = cfg.get('rope_parameters', {})\n        partial_rotary_factor = rope_params.get('partial_rotary_factor', cfg.get('partial_rotary_factor', 1.0))\n        if partial_rotary_factor < 1.0:\n            info['rope_param'].dim = int(info['size_per_head'] * partial_rotary_factor)\n        # Linear attention parameters\n        info['layer_types'] = cfg.get('layer_types', [])\n        info['linear_key_head_dim'] = cfg.get('linear_key_head_dim', 0)\n        info['linear_value_head_dim'] = cfg.get('linear_value_head_dim', 0)\n        info['linear_conv_kernel_dim'] = cfg.get('linear_conv_kernel_dim', 0)\n        info['linear_num_key_heads'] = cfg.get('linear_num_key_heads', 0)\n        info['linear_num_value_heads'] = cfg.get('linear_num_value_heads', 0)\n        # attn_output_gate doubles Q projection for full-attention layers\n        info['attn_output_gate'] = cfg.get('attn_output_gate', False)\n        return info\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/source_model/xcomposer2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .base import INPUT_MODELS\nfrom .internlm2 import InternLM2Model, InternLM2Reader\n\n\nclass Xcomposer2Reader(InternLM2Reader):\n    \"\"\"Xcomposer2 model reader.\"\"\"\n\n    # include only Plora and ignore other lora weights\n    attn_pattern = r'attention.\\w+(.Plora_[AB])?.\\w+$'\n    ffn_pattern = r'feed_forward.\\w+(.Plora_[AB])?.\\w+$'\n\n    def _attn(self, i, kind):\n        if 'Plora_A' in kind:\n            qkv = self.params[f'model.layers.{i}.attention.wqkv.Plora_A.weight']\n            o = self.params[f'model.layers.{i}.attention.wo.Plora_A.weight']\n            return qkv, o\n        return super()._attn(i, kind)\n\n\n@INPUT_MODELS.register_module(name='xcomposer2')\nclass Xcomposer2Model(InternLM2Model):\n    \"\"\"Xcomposer2 model in hf format.\"\"\"\n\n    Reader = Xcomposer2Reader\n\n    def _lora_cfg_7b(self):\n        \"\"\"Lora config for internlm-xcomposer2-7b.\"\"\"\n        return dict(lora_r=256, lora_scale=1.0, lora_policy='plora', lora_max_wo_r=256)\n\n    def _lora_cfg_4khd_7b(self, model_info: dict):\n        \"\"\"Lora config for internlm-xcomposer2-4khd-7b.\"\"\"\n        rank_pattern = ['attention.w_qkv:8', 'attention.wo:256']\n        scale_pattern = ['attention.w_qkv:2.0', 'attention.wo:1.0']\n        rank_pattern = ','.join(rank_pattern)\n        scale_pattern = ','.join(scale_pattern)\n        return dict(lora_r=256,\n                    lora_scale=1.0,\n                    lora_max_wo_r=256,\n                    lora_policy='plora',\n                    lora_rank_pattern=rank_pattern,\n                    lora_scale_pattern=scale_pattern)\n\n    def model_info(self):\n        out = super().model_info()\n        from lmdeploy.vl.model.xcomposer2 import ModelType, get_xcomposer_type\n        model_type, _ = get_xcomposer_type(self.model_path)\n        if model_type == ModelType.XCOMPOSER2_4KHD:\n            out.update(self._lora_cfg_4khd_7b(out))\n        
else:\n            out.update(self._lora_cfg_7b())\n        return out\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/target_model/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .fp import TurbomindModel  # noqa: F401\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/target_model/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport os.path as osp\nfrom abc import ABC\nfrom collections.abc import Sequence\n\nimport torch\nimport tqdm\nimport yaml\nfrom mmengine import Registry\n\nfrom ..config import AttentionConfig, LoraConfig, ModelConfig, TurbomindModelConfig, config_from_dict, config_to_dict\nfrom ..source_model.base import BaseInputModel\n\nOUTPUT_MODELS = Registry('target model', locations=['lmdeploy.turbomind.deploy.target_model.base'])\n\n\ndef tprint(*args, **kwargs):\n    to_file = kwargs.pop('to_file', False)\n    if not to_file:\n        return\n    from io import StringIO\n    s = StringIO()\n    print(*args, **kwargs, file=s, end='')\n    tqdm.tqdm.write(s.getvalue())\n\n\ndef _weight_dtype_map(weight_type: str, default=None):\n    \"\"\"Map literal data type to torch dtype.\"\"\"\n\n    _WEIGHT_DTYPE_MAP = dict(int4=torch.float16, float16=torch.float16, float32=torch.float16, bfloat16=torch.bfloat16)\n\n    return _WEIGHT_DTYPE_MAP.get(weight_type, default)\n\n\ndef _pad_inter_size(inter_size: int, group_size: int, tp: int):\n    group_size = max(1, group_size)\n    group_num = (inter_size + group_size - 1) // group_size\n    groups_per_rank = (group_num + tp - 1) // tp\n    inter_size_padded = groups_per_rank * group_size * tp\n    return inter_size_padded\n\n\nclass BaseOutputModel(ABC):\n    \"\"\"Base output model.\"\"\"\n\n    def __init__(self, input_model: BaseInputModel, cfg: TurbomindModelConfig, model_cls, out_dir: str = ''):\n        super().__init__()\n        self.input_model = input_model\n        self.model_config = cfg.model_config\n        self.attention_config = cfg.attention_config\n        self.lora_config = cfg.lora_config\n        self.attn_tp_size = self.model_config.attn_tp_size\n        self.attn_cp_size = self.model_config.attn_cp_size\n        self.mlp_tp_size = self.model_config.mlp_tp_size\n        self.out_dir = out_dir\n        self.to_file = True if out_dir else False\n      
  self.tm_params = dict()\n\n        # get `model_info` first, which will be updated to `self.model_config` and `self.attention_config`\n        self.input_model_info = self.input_model.model_info()\n        self.input_model_info = self.single_to_list(self.input_model_info, keys=['inter_size', 'expert_num'])\n        self.permute_qk = self.input_model_info.get('permute_qk', True)\n        self.update_model_config()\n        for i, v in enumerate(self.model_config.inter_size):\n            self.model_config.inter_size[i] = _pad_inter_size(v, self.model_config.group_size, self.mlp_tp_size)\n        if self.model_config.expert_num:\n            self.model_config.expert_inter_size = _pad_inter_size(self.model_config.expert_inter_size,\n                                                                  self.model_config.group_size, self.mlp_tp_size)\n\n        # head_num must be divisible by tp. If kv_head_num is not, but tp is\n        # divisible by kv_head_num, the kv heads are repeated to match tp\n        assert self.model_config.head_num % self.attn_tp_size == 0\n        self.repeat_kv = 0\n        if (self.attn_tp_size > self.model_config.kv_head_num\n                and self.attn_tp_size % self.model_config.kv_head_num == 0):\n            self.repeat_kv = (self.attn_tp_size // self.model_config.kv_head_num)\n            self.model_config.kv_head_num = self.attn_tp_size\n\n        self.model_config.verify()\n        assert self.model_config.kv_head_num % self.attn_tp_size == 0\n\n        # print(self.model_config)\n\n        self.update_attention_config()\n        self.update_lora_config()\n        # ! 
Dependency on `self`\n        self.model = model_cls(self)\n\n    def single_to_list(self, config: dict, keys):\n        num_layer = int(config['num_layer'])\n        for k in keys:\n            v = config.get(k, None)\n            if v is not None and not isinstance(v, Sequence):\n                config[k] = [v] * num_layer\n        return config\n\n    def update_model_config(self):\n        \"\"\"Update `self.model_config` according to the input_model's\n        `model_info`\"\"\"\n        final_cfg = config_to_dict(self.model_config)\n        final_cfg.update(self.input_model_info)\n        if 'embedding_size' not in self.input_model_info.keys():\n            final_cfg.update(embedding_size=self.input_model_info['vocab_size'])\n\n        self.model_config = config_from_dict(ModelConfig, final_cfg)\n\n    def update_attention_config(self):\n        \"\"\"Update attention config according to input model's model info.\"\"\"\n        final_cfg = config_to_dict(self.attention_config)\n        final_cfg.update(self.input_model_info)\n        self.attention_config = config_from_dict(AttentionConfig, final_cfg)\n\n    def update_lora_config(self):\n        \"\"\"Update lora config according to input model's model info.\"\"\"\n        final_cfg = config_to_dict(self.lora_config)\n        final_cfg.update(self.input_model_info)\n        self.lora_config = config_from_dict(LoraConfig, final_cfg)\n\n    def export_config(self) -> None:\n        \"\"\"Export turbomind config.\"\"\"\n        if self.to_file:\n            config_path = osp.join(self.out_dir, 'config.yaml')\n            with open(config_path, 'w') as f:\n                yaml.safe_dump(self.tm_config.to_dict(), f)\n\n    def export_weight(self, param: torch.Tensor, name: str) -> None:\n        \"\"\"Export turbomind weight.\"\"\"\n\n        def _tofile(tensor, path):\n            \"\"\"To file.\"\"\"\n            if tensor.dtype == torch.bfloat16:\n                tensor = tensor.view(torch.half)\n            
tensor.contiguous().cpu().numpy().tofile(path)\n\n        if self.to_file:\n            if torch.is_floating_point(param):\n                torch_type = _weight_dtype_map(self.model_config.weight_type, torch.float16)\n                param = param.to(torch_type)\n            tprint(name, param.shape)\n            _tofile(param, osp.join(self.out_dir, name))\n        elif len(self.tm_params) > 0:\n            tm_params = self.tm_params\n            weight_type = self.model_config.weight_type\n            data_type = self.model_config.data_type\n            assert weight_type in ['float16', 'bfloat16', 'int4', 'fp8']\n\n            # currently, the tensor type should be one of\n            # [torch.float, torch.half, torch.bfloat16, torch.int32, torch.uint8]\n            torch_tensor = param if param.is_contiguous() else param.contiguous()\n            torch_tensor = torch_tensor.cuda()\n            assert torch_tensor.dtype in [torch.int32, torch.float, torch.half, torch.bfloat16, torch.uint8]\n            FLOAT_TYPES = [torch.float, torch.half, torch.bfloat16]\n            if weight_type == 'fp8':\n                # avoid casting float scales to half\n                if torch_tensor.dtype == torch.bfloat16 and data_type == 'float16':\n                    torch_tensor = torch_tensor.half()\n            elif torch_tensor.dtype in FLOAT_TYPES:\n                if weight_type in ['float16', 'int4']:\n                    torch_tensor = torch_tensor.half()\n                elif weight_type == 'bfloat16':\n                    torch_tensor = torch_tensor.bfloat16()\n                else:\n                    torch_tensor = torch_tensor.half()\n            if name in tm_params:\n                try:\n                    import _turbomind as _tm\n                except ImportError:\n                    _tm = None\n                for tm_tensor in tm_params[name]:\n                    # Match TurboMind tensor dtype to avoid byte_size mismatch (e.g. 
f32 256b vs f16 128b)\n                    if _tm is not None:\n                        if tm_tensor.type == _tm.DataType.TYPE_FP32 and torch_tensor.dtype in [\n                                torch.float16, torch.bfloat16\n                        ]:\n                            torch_tensor = torch_tensor.float()\n                        elif tm_tensor.type == _tm.DataType.TYPE_FP16 and torch_tensor.dtype == torch.float32:\n                            torch_tensor = torch_tensor.half()\n                    tm_tensor.copy_from(torch_tensor)\n                tm_params.pop(name)\n        else:\n            tprint('skip export', name, param.shape)\n\n    def save_split(self, tensor: torch.Tensor, name: str, split_dim=None, split_num=1, copy=False) -> None:\n        \"\"\"Save split.\n\n        - 2D input\n            shape must be (input_dims, output_dims)\n        - 1D input (bias)\n            shape must be (output_dims)\n            split is skipped when split_dim == 0\n        \"\"\"\n\n        if copy or (tensor.dim() == 1 and split_dim == 0):\n            split_dim = None\n            copy = True\n\n        if split_dim is not None:\n            tprint(f'*** splitting {name}, shape={tensor.shape}, '\n                   f'split_dim={split_dim}, split_num={split_num}',\n                   to_file=self.to_file)\n            if tensor.shape[split_dim] % split_num != 0:\n                raise RuntimeError(f'{name}: shape={list(tensor.shape)}, split_num={split_num}')\n            split_size = tensor.shape[split_dim] // split_num\n            splits = torch.split(tensor, split_size, dim=split_dim)\n            for i, split in enumerate(splits):\n                prefix, ext = osp.splitext(name)\n                self.export_weight(split, f'{prefix}.{i}{ext}')\n        elif copy:\n            tprint(f'### copying {name}, shape={tensor.shape}', to_file=self.to_file)\n            copies = [tensor] * split_num\n            for i, copy in enumerate(copies):\n                
prefix, ext = osp.splitext(name)\n                self.export_weight(copy, f'{prefix}.{i}{ext}')\n        else:\n            self.export_weight(tensor, name)\n\n    def export(self) -> None:\n        \"\"\"Export to turbomind model format.\"\"\"\n        num_layer = self.model_config.num_layer\n        from tqdm import tqdm\n        pbar = tqdm(total=num_layer, desc='Convert to turbomind format', leave=self.to_file)\n        self.export_config()\n        for i, reader in self.input_model.readers():\n            if self.model(i, reader):\n                pbar.update(1)\n        pbar.close()\n\n    def export_iter(self):\n        self.export_config()\n        for i, reader in self.input_model.readers():\n            self.model(i, reader)\n            yield i\n\n    @property\n    def tm_config(self):\n        return TurbomindModelConfig(model_config=self.model_config,\n                                    attention_config=self.attention_config,\n                                    lora_config=self.lora_config)\n"
  },
  {
    "path": "lmdeploy/turbomind/deploy/target_model/fp.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom .base import OUTPUT_MODELS, BaseOutputModel\n\n\n@OUTPUT_MODELS.register_module(name='tm')\nclass TurbomindModel(BaseOutputModel):\n    \"\"\"Export to turbomind fp16 format.\"\"\"\n    pass\n"
  },
  {
    "path": "lmdeploy/turbomind/supported_models.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom lmdeploy.archs import get_model_arch, search_nested_config\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\nSUPPORTED_ARCHS = dict(\n    # baichuan-7b\n    BaiChuanForCausalLM='baichuan',\n    # baichuan2-7b, baichuan-13b, baichuan2-13b\n    BaichuanForCausalLM='baichuan2',\n    # gpt-oss\n    GptOssForCausalLM='gpt-oss',\n    # internlm\n    InternLMForCausalLM='llama',\n    # internlm2\n    InternLM2ForCausalLM='internlm2',\n    # internlm3\n    InternLM3ForCausalLM='llama',\n    # llama, llama2, alpaca, vicuna, codellama, ultracm, yi,\n    # deepseek-coder, deepseek-llm\n    LlamaForCausalLM='llama',\n    # Qwen 7B-72B, Qwen-VL-7B\n    QWenLMHeadModel='qwen',\n    # Qwen2\n    Qwen2ForCausalLM='qwen2',\n    Qwen2MoeForCausalLM='qwen2-moe',\n    # Qwen2-VL\n    Qwen2VLForConditionalGeneration='qwen2',\n    # Qwen2.5-VL\n    Qwen2_5_VLForConditionalGeneration='qwen2',\n    # Qwen3\n    Qwen3ForCausalLM='qwen3',\n    Qwen3MoeForCausalLM='qwen3-moe',\n    # Qwen 3.5\n    Qwen3_5ForConditionalGeneration='qwen3_5',\n    Qwen3_5MoeForConditionalGeneration='qwen3_5-moe',\n    # mistral\n    MistralForCausalLM='llama',\n    # llava\n    LlavaLlamaForCausalLM='llama',\n    LlavaMistralForCausalLM='llama',\n    LlavaForConditionalGeneration='llava',\n    # xcomposer2\n    InternLMXComposer2ForCausalLM='xcomposer2',\n    # internvl\n    InternVLChatModel='internvl',\n    # internvl3\n    InternVLForConditionalGeneration='internvl',\n    InternS1ForConditionalGeneration='internvl',\n    # deepseek-vl\n    MultiModalityCausalLM='deepseekvl',\n    DeepseekV2ForCausalLM='deepseek2',\n    # MiniCPMV\n    MiniCPMV='minicpmv',\n    # chatglm2/3, glm4\n    ChatGLMModel='glm4',\n    ChatGLMForConditionalGeneration='glm4',\n    # glm4-moe-lite (e.g. 
GLM-4.7-Flash)\n    Glm4MoeLiteForCausalLM='glm4-moe-lite',\n    # mixtral\n    MixtralForCausalLM='mixtral',\n    MolmoForCausalLM='molmo',\n)\n\n\ndef is_supported(model_path: str):\n    \"\"\"Check whether a model is supported by the turbomind engine.\n\n    Args:\n        model_path (str): the path of a model.\n            It could be one of the following options:\n                - i) A local directory path of a turbomind model which is\n                    converted by the `lmdeploy convert` command or downloaded from\n                    ii) and iii).\n                - ii) The model_id of an lmdeploy-quantized model hosted\n                    inside a model repo on huggingface.co, such as\n                    \"InternLM/internlm-chat-20b-4bit\",\n                    \"lmdeploy/llama2-chat-70b-4bit\", etc.\n                - iii) The model_id of a model hosted inside a model repo\n                    on huggingface.co, such as \"internlm/internlm-chat-7b\",\n                    \"Qwen/Qwen-7B-Chat\", \"baichuan-inc/Baichuan2-7B-Chat\"\n                    and so on.\n    Returns:\n        support_by_turbomind (bool): Whether the input model is supported by the turbomind engine\n    \"\"\"  # noqa: E501\n    import os\n\n    def _is_head_dim_supported(cfg):\n        head_dim = cfg.head_dim if hasattr(cfg, 'head_dim') else cfg.hidden_size // cfg.num_attention_heads\n        return head_dim in [128, 64]\n\n    support_by_turbomind = False\n    triton_model_path = os.path.join(model_path, 'triton_models')\n    if os.path.exists(triton_model_path):\n        support_by_turbomind = True\n    else:\n\n        arch, cfg = get_model_arch(model_path)\n        quant_method = search_nested_config(cfg.to_dict(), 'quant_method')\n        if quant_method and quant_method in ['smooth_quant']:\n            # turbomind doesn't support models quantized with smoothquant\n            return False\n\n        if arch in SUPPORTED_ARCHS.keys():\n            support_by_turbomind = True\n            # special 
cases\n            if arch == 'BaichuanForCausalLM':\n                num_attn_head = cfg.num_attention_heads\n                if num_attn_head == 40:\n                    # baichuan-13B, baichuan2-13B not supported by turbomind\n                    support_by_turbomind = False\n            elif arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:\n                support_by_turbomind = _is_head_dim_supported(cfg)\n            elif arch in ('ChatGLMModel', 'ChatGLMForConditionalGeneration'):\n                # chatglm1/2/3 is not working yet\n                support_by_turbomind = cfg.num_layers == 40\n                if getattr(cfg, 'vision_config', None) is not None:\n                    # glm-4v-9b not supported\n                    support_by_turbomind = False\n            elif arch == 'InternVLChatModel':\n                llm_arch = cfg.llm_config.architectures[0]\n                support_by_turbomind = (llm_arch in SUPPORTED_ARCHS and _is_head_dim_supported(cfg.llm_config))\n            elif arch in ['LlavaForConditionalGeneration', 'InternVLForConditionalGeneration']:\n                llm_arch = cfg.text_config.architectures[0]\n                if llm_arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:\n                    support_by_turbomind = _is_head_dim_supported(cfg.text_config)\n            elif arch == 'MolmoForCausalLM':\n                kv_heads = cfg.num_key_value_heads\n                # TM hasn't supported allenai/Molmo-7B-O-0924 yet\n                support_by_turbomind = kv_heads is not None\n            elif arch == 'DeepseekV2ForCausalLM':\n                if getattr(cfg, 'vision_config', None) is not None:\n                    support_by_turbomind = False\n            elif arch == 'Glm4MoeLiteForCausalLM':\n                if getattr(cfg, 'vision_config', None) is not None:\n                    support_by_turbomind = False\n\n    return support_by_turbomind\n"
  },
  {
    "path": "lmdeploy/turbomind/tokenizer_info.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Borrowed from xgrammar's TokenizerInfo\n\"\"\"This module provides the tokenizer info class to handle the tokenizer\ninformation.\"\"\"\n\nimport json\nimport logging\nfrom enum import Enum\nfrom typing import List, Optional, Union\n\nimport _xgrammar as _xgr  # noqa: E402\n\ntry:\n    import sentencepiece\nexcept ImportError:\n    sentencepiece = None\ntry:\n    import tiktoken\nexcept ImportError:\n    tiktoken = None\n\nfrom transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast\n\nlogger = logging.getLogger(__name__)\n\n\nclass VocabType(Enum):\n    \"\"\"The type of the vocabulary.\n\n    Used in TokenizerInfo. XGrammar supports three types of\n    vocabularies: RAW, BYTE_FALLBACK, BYTE_LEVEL.\n    \"\"\"\n\n    RAW = 0\n    \"\"\"The vocabulary is in the raw format.\n\n    The tokens in the vocabulary are kept in their original form without any processing. This kind of tokenizer includes\n    the tiktoken tokenizer, e.g. microsoft/Phi-3-small-8k-instruct, Qwen/Qwen-7B-Chat, etc.\n    \"\"\"\n\n    BYTE_FALLBACK = 1\n    r\"\"\"The vocabulary used in the byte fallback BPE tokenizer.\n\n    The tokens are encoded through the byte-fallback conversion. E.g. \"\\u001b\" -> \"<0x1B>\", \" apple\" -> \"▁apple\". 
This\n    kind of tokenizer includes meta-llama/Llama-2-7b-chat, microsoft/Phi-3.5-mini-instruct, etc.\n    \"\"\"\n\n    BYTE_LEVEL = 2\n    \"\"\"The vocabulary used in the byte level BPE tokenizer.\n\n    The tokens are encoded through the byte-to-unicode conversion, as in\n    https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59\n\n    This kind of tokenizer includes meta-llama/Meta-Llama-3-8B-Instruct,\n    meta-llama/Meta-Llama-3.1-8B-Instruct, etc.\n    \"\"\"\n\n\nclass TokenizerInfo(_xgr.TokenizerInfo):\n    \"\"\"The tokenizer info contains the vocabulary, the type of the vocabulary,\n    and necessary information for the grammar-guided generation.\n\n    Note that although some tokenizers will encode the tokens in a special format, e.g. \"<0x1B>\" for \"\\u001b\" in the\n    ByteFallback tokenizer, and \"Ġ\" for \" \" in the Byte-Level BPE tokenizer, TokenizerInfo always decodes the vocabulary\n    to the original format (e.g. \"\\u001b\" and \" \").\n\n    Also note that some models (e.g. Phi-3 and Deepseek-V2) may pad the vocabulary to a multiple of 32. In this case,\n    the model's vocab_size is larger than the tokenizer's vocabulary size. 
Please pass the model's vocab_size to the\n    vocab_size parameter in the constructor, because this information is used to determine the size of the token mask.\n    \"\"\"\n\n    def __init__(\n        self,\n        encoded_vocab: Union[List[bytes], List[str]],\n        vocab_type: VocabType = VocabType.RAW,\n        *,\n        vocab_size: Optional[int] = None,\n        stop_token_ids: Optional[Union[List[int], int]] = None,\n        add_prefix_space: bool = False,\n    ) -> None:\n        \"\"\"Construct the tokenizer info.\n\n        Parameters\n        ----------\n        encoded_vocab : Union[List[bytes], List[str]]\n            The encoded vocabulary of the tokenizer.\n\n        vocab_type : VocabType, default: VocabType.RAW\n            The type of the vocabulary. See also VocabType.\n\n        vocab_size : Optional[int], default: None\n            The size of the vocabulary. If not provided, the vocabulary size will be len(encoded_vocab).\n\n        stop_token_ids : Optional[List[int]], default: None\n            The stop token ids. 
If not provided, the stop token ids will be auto detected (but may not\n            be correct).\n\n        add_prefix_space : bool, default: False\n            Whether the tokenizer will prepend a space before the text in the tokenization process.\n        \"\"\"\n        if isinstance(stop_token_ids, int):\n            stop_token_ids = [stop_token_ids]\n\n        super().__init__(encoded_vocab, vocab_type.value, vocab_size, stop_token_ids, add_prefix_space)\n\n    @staticmethod\n    def _is_tiktoken_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:\n        if tiktoken is None:\n            return False\n\n        # helper to check if tokenizer is a tiktoken tokenizer\n        has_tiktoken_encoding = hasattr(tokenizer, 'tokenizer') and isinstance(tokenizer.tokenizer, tiktoken.Encoding)\n\n        filename_pattern = (hasattr(tokenizer, 'vocab_files_names') and 'vocab_file' in tokenizer.vocab_files_names\n                            and 'tiktoken' in tokenizer.vocab_files_names['vocab_file'])\n\n        return has_tiktoken_encoding or filename_pattern\n\n    @staticmethod\n    def _is_sentencepiece_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:\n        if sentencepiece is None:\n            return False\n\n        # helper to check if tokenizer is a sentence piece tokenizer\n        has_sp_model_attr = hasattr(tokenizer, 'sp_model') and isinstance(tokenizer.sp_model,\n                                                                          sentencepiece.SentencePieceProcessor)\n\n        has_nested_sp_model_attr = (hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model')\n                                    and isinstance(tokenizer.tokenizer.sp_model, sentencepiece.SentencePieceProcessor))\n\n        return has_sp_model_attr or has_nested_sp_model_attr\n\n    @staticmethod\n    def from_huggingface(\n        tokenizer: PreTrainedTokenizerBase,\n        *,\n        vocab_size: Optional[int] = None,\n        stop_token_ids: 
Optional[Union[List[int], int]] = None,\n    ) -> 'TokenizerInfo':\n        \"\"\"Construct the tokenizer info from the huggingface tokenizer. This\n        constructor supports various tokenizer backends, including the\n        huggingface fast tokenizer and tiktoken tokenizer. Necessary\n        information is automatically detected from the tokenizer.\n\n        The vocab_size parameter is introduced to handle the misalignment between the model's\n        vocab_size and the tokenizer's vocabulary size. Users should pass the model's vocab_size\n        (typically found in the model config) here. See the docs of vocab_size for more details.\n\n        The stop token ids default to the eos_token_id of the tokenizer. If there are other\n        stop tokens, you can specify them manually.\n\n        Parameters\n        ----------\n        tokenizer : PreTrainedTokenizerBase\n            The huggingface tokenizer.\n\n        vocab_size : Optional[int], default: None\n            The vocabulary size **defined by the model** (**not the tokenizer**). This equals the\n            vocab dimension of the model's lm_head. This is the size of the token mask.\n\n            It can be:\n\n            1. the same as the tokenizer's vocabulary size. This is the most common case.\n            2. larger than the tokenizer's vocabulary size. This happens when the model has padding\n               to lm_head, possibly due to aligning lm_head to a power of 2.\n               E.g. Phi-3 and Deepseek-V2.\n            3. smaller than the tokenizer's vocabulary size. This happens when the tokenizer has\n               some added tokens that are not supported by the model. E.g.\n               Llama-3.2 Vision and Molmo-72B-0924 have padded `<|image|>` tokens, but they will not\n               be considered in lm_head or generated by the model.\n\n            vocab_size must be provided for cases 2 and 3. 
If not provided, it will be\n            set to the tokenizer's vocabulary size.\n\n        stop_token_ids : Optional[List[int]], default: None\n            The stop token ids. If not provided, the eos_token_id of the tokenizer will be used.\n\n        Returns\n        -------\n        tokenizer_info : TokenizerInfo\n            The tokenizer info.\n        \"\"\"\n        if isinstance(stop_token_ids, int):\n            stop_token_ids = [stop_token_ids]\n        if isinstance(stop_token_ids, list) and len(stop_token_ids) == 0:\n            raise ValueError('stop_token_ids cannot be empty')\n\n        try:\n            vocab_dict = tokenizer.get_vocab()\n        except AttributeError as e:\n            msg = (f'Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer '\n                   'should have a get_vocab method.')\n            raise ValueError(msg) from e\n\n        # Some tokenizers don't have token id 0, 1 or 2, so max_id + 1 could be larger than the\n        # number of tokens.\n        max_id = max(vocab_dict.values())\n        tokenizer_vocab_size = max(len(vocab_dict), max_id + 1)\n\n        vocab_size = vocab_size or tokenizer_vocab_size\n\n        # maintain tokenizer's indexing\n        encoded_vocab = [''] * vocab_size\n        for token, idx in vocab_dict.items():\n            if idx < vocab_size:\n                encoded_vocab[idx] = token\n\n        if isinstance(tokenizer, PreTrainedTokenizerFast):\n            # huggingface fast tokenizer\n            # - the vocabulary is directly obtained from tokenizer.get_vocab()\n            #   (tokenizer.backend_tokenizer.to_str() may not contain the full vocab, special\n            #   tokens may be omitted)\n            # - the vocab size is obtained from len(tokenizer.get_vocab()) or provided by user\n            # - the vocab type and add_prefix_space are obtained from\n            #   tokenizer.backend_tokenizer.to_str()\n            # - stop token id is provided by user, or 
auto detected.\n            backend_str = tokenizer.backend_tokenizer.to_str()\n            if stop_token_ids is None:\n                if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:\n                    stop_token_ids = [tokenizer.eos_token_id]\n                else:\n                    logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, '\n                                   'stop_token_ids is neither provided by user nor found from the tokenizer. '\n                                   'It will be automatically detected.')\n            metadata = json.loads(TokenizerInfo._detect_metadata_from_hf(backend_str))\n            return TokenizerInfo(\n                encoded_vocab,\n                vocab_type=VocabType(metadata['vocab_type']),\n                vocab_size=vocab_size,\n                stop_token_ids=stop_token_ids,\n                add_prefix_space=metadata['add_prefix_space'],\n            )\n\n        elif TokenizerInfo._is_tiktoken_tokenizer(tokenizer):\n            # tiktoken tokenizer\n            # e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously)\n            if stop_token_ids is None:\n                if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:\n                    stop_token_ids = [tokenizer.eos_token_id]\n                else:\n                    logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, '\n                                   'stop_token_ids is neither provided by user nor found from the tokenizer. 
'\n                                   'It will be automatically detected.')\n            return TokenizerInfo(\n                encoded_vocab,\n                VocabType.RAW,\n                vocab_size=vocab_size,\n                stop_token_ids=stop_token_ids,\n                add_prefix_space=False,\n            )\n\n        elif TokenizerInfo._is_sentencepiece_tokenizer(tokenizer):\n            # sentencepiece tokenizer\n            # e.g. Chatglm3-6b\n            if hasattr(tokenizer, 'sp_model'):\n                sp_model = tokenizer.sp_model\n            elif hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model'):\n                sp_model = tokenizer.tokenizer.sp_model\n\n            if stop_token_ids is None:\n                if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:\n                    stop_token_ids = [tokenizer.eos_token_id]\n                else:\n                    eos_id = sp_model.eos_id()\n                    if eos_id != -1:\n                        stop_token_ids = [eos_id]\n                    else:\n                        logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, '\n                                       'stop_token_ids is neither provided by user nor found from the tokenizer. '\n                                       'It will be automatically detected.')\n            # detect vocab_type of tokenizer\n            if '<0x0A>' in vocab_dict:\n                vocab_type = VocabType.BYTE_FALLBACK\n            else:\n                vocab_type = VocabType.RAW\n\n            return TokenizerInfo(\n                encoded_vocab,\n                vocab_type=vocab_type,\n                vocab_size=vocab_size,\n                stop_token_ids=stop_token_ids,\n                add_prefix_space=True,\n            )\n\n        else:\n            # TODO(yixin): unsupported tokenizer\n            raise ValueError(f'Unsupported tokenizer type: {type(tokenizer)}')\n"
  },
  {
    "path": "lmdeploy/turbomind/turbomind.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport asyncio\nimport copy\nimport json\nimport math\nimport os\nimport os.path as osp\nimport sys\nfrom collections.abc import Sequence\nfrom concurrent.futures import ThreadPoolExecutor\nfrom dataclasses import asdict\nfrom functools import partial\nfrom multiprocessing.reduction import ForkingPickler\nfrom queue import Queue\nfrom typing import Any, Dict, List, Optional\n\nimport pybase64\nimport torch\nimport yaml\n\nimport lmdeploy\nfrom lmdeploy.messages import EngineOutput, GenerationConfig, ResponseType, ScheduleMetrics, TurbomindEngineConfig\nfrom lmdeploy.serve.openai.protocol import UpdateParamsRequest\nfrom lmdeploy.tokenizer import Tokenizer\nfrom lmdeploy.utils import get_logger, get_max_batch_size, get_model\n\nfrom .deploy.config import TurbomindModelConfig\nfrom .supported_models import is_supported\n\n# TODO: find another way import _turbomind\nlmdeploy_dir = osp.split(lmdeploy.__file__)[0]\nsys.path.append(osp.join(lmdeploy_dir, 'lib'))\nimport _turbomind as _tm  # noqa: E402\nimport _xgrammar as _xgr  # noqa: E402\n\nfrom .tokenizer_info import TokenizerInfo  # noqa: E402\n\nlogger = get_logger('lmdeploy')\n\nMAX_LOGPROBS = 1024\n\n\ndef _construct_stop_or_bad_words(words: List[int] = None):\n    if words is None or len(words) == 0:\n        return None\n    offsets = list(range(1, len(words) + 1))\n    combined = [words, offsets]\n    return combined\n\n\ndef _np_dict_to_tm_dict(np_dict: dict):\n    \"\"\"Map numpy.ndarray to turbomind's tensor.\"\"\"\n    ret = _tm.TensorMap()\n    for k, v in np_dict.items():\n        ret[k] = _tm.from_dlpack(v)\n\n    return ret\n\n\ndef _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):\n    \"\"\"Map turbomind's tensor to torch's tensor.\"\"\"\n    ret = dict()\n    for k, v in tm_dict.items():\n        if v.type == _tm.DataType.TYPE_UINT32:\n            v = v.view(_tm.DataType.TYPE_INT32)\n        ret[k] = torch.from_dlpack(v)\n\n    return 
ret\n\n\ndef complete_parallel_config(cfg: TurbomindEngineConfig):\n    if any((cfg.attn_dp_size, cfg.attn_tp_size, cfg.mlp_dp_size, cfg.mlp_tp_size, cfg.outer_dp_size)):\n        cfg.attn_dp_size = cfg.attn_dp_size or 1\n        cfg.attn_tp_size = cfg.attn_tp_size or 1\n        cfg.mlp_dp_size = cfg.mlp_dp_size or 1\n        cfg.mlp_tp_size = cfg.mlp_tp_size or 1\n        cfg.outer_dp_size = cfg.outer_dp_size or 1\n        gcd = math.gcd(cfg.mlp_dp_size, cfg.attn_dp_size)\n        cfg.outer_dp_size *= gcd\n        cfg.mlp_dp_size //= gcd\n        cfg.attn_dp_size //= gcd\n        return True\n    return False\n\n\ndef update_parallel_config(cfg: TurbomindEngineConfig):\n    cfg.device_num = len(cfg.devices) * cfg.nnodes if cfg.devices else cfg.device_num\n    if not complete_parallel_config(cfg):\n        total = cfg.dp * cfg.tp\n        if not cfg.device_num:\n            count = torch.cuda.device_count() * cfg.nnodes\n            if total < count:\n                count = total\n            cfg.device_num = count\n        assert total % cfg.device_num == 0\n        overlap = total // cfg.device_num\n        attn_dp_size = overlap\n        mlp_tp_size = overlap\n        inner_tp_size = cfg.tp // mlp_tp_size\n        cfg.outer_dp_size = cfg.dp // attn_dp_size\n        cfg.attn_dp_size = attn_dp_size\n        cfg.attn_tp_size = inner_tp_size // cfg.cp\n        cfg.attn_cp_size = cfg.cp\n        cfg.mlp_dp_size = 1\n        cfg.mlp_tp_size = mlp_tp_size * inner_tp_size\n    assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size == cfg.mlp_dp_size * cfg.mlp_tp_size\n    assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size * cfg.outer_dp_size == cfg.device_num\n    # update devices\n    cfg.devices = cfg.devices or list(range(cfg.device_num // cfg.nnodes))\n    cfg.devices = cfg.devices[:cfg.device_num // cfg.nnodes]\n    assert len(cfg.devices) == cfg.device_num // cfg.nnodes\n\n\nclass TurboMind:\n    \"\"\"LMDeploy's inference engine.\n\n    Args:\n 
       model_path (str): the local path or the model id of the model\n        model_name (str): the name of the served model\n        chat_template_name (str): the name of the chat template, which is\n            supposed to be a builtin chat template defined in\n            `lmdeploy/model.py`\n        engine_config (TurbomindEngineConfig): the config of the inference\n            engine\n    \"\"\"\n\n    def __init__(self,\n                 model_path: str,\n                 model_name: str = None,\n                 chat_template_name: str = None,\n                 engine_config: TurbomindEngineConfig = None,\n                 **kwargs):\n        self.model_name = model_name\n        self.chat_template_name = chat_template_name\n\n        _engine_config = copy.deepcopy(engine_config)\n        if _engine_config is None:\n            _engine_config = TurbomindEngineConfig()\n        if _engine_config.max_batch_size is None:\n            _engine_config.max_batch_size = get_max_batch_size('cuda')\n        assert _engine_config.max_batch_size > 0, 'max_batch_size should be' \\\n            f' greater than 0, but got {_engine_config.max_batch_size}'\n\n        update_parallel_config(_engine_config)\n        if _engine_config.nnodes > 1:\n            logger.info(f'dist_init_addr={_engine_config.dist_init_addr}')\n            assert _engine_config.dist_init_addr is not None\n            hostname, port = _engine_config.dist_init_addr.split(':')\n            os.environ['LMDEPLOY_DIST_INIT_ADDR'] = hostname\n            os.environ['LMDEPLOY_DIST_INIT_PORT'] = port\n            # this will block the process and ignore signals until all ranks are done\n            from torch.distributed import TCPStore\n            self.store = TCPStore(host_name=hostname,\n                                  port=int(port),\n                                  
world_size=_engine_config.nnodes,\n                                  is_master=_engine_config.node_rank == 0)\n\n        self.gpu_count = len(_engine_config.devices)\n        self.devices = _engine_config.devices\n        self._engine_created = False\n\n        if not osp.exists(model_path):\n            model_path = get_model(model_path, _engine_config.download_dir, _engine_config.revision)\n        self.model_comm = self._from_hf(model_path=model_path, engine_config=_engine_config)\n        self.is_dummy = self.model_comm.is_dummy_node()\n        self.tokenizer = Tokenizer(model_path)\n        if not _engine_config.empty_init:\n            self._load_weights()\n            self._process_weights()\n            self._create_engine()\n\n        self.session_len = self.config.session_len\n\n    def _check_unloaded_tm_params(self):\n        tm_params = self._tm_model.tm_params\n        if len(tm_params) > 0:\n            uninitialized = list(tm_params.keys())\n            logger.warning('the model may not be loaded successfully '\n                           f'with {len(tm_params)} uninitialized params:\\n{uninitialized}')\n\n    def _load_weights(self):\n        \"\"\"Load weights.\"\"\"\n        self._get_model_params()\n\n        with torch.cuda.device(self.devices[0]):\n            self._tm_model.export()\n\n        self._check_unloaded_tm_params()\n\n    def _process_weights(self):\n        \"\"\"Process weight.\"\"\"\n        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:\n            for _ in e.map(self.model_comm.process_weight, range(self.gpu_count)):\n                pass\n\n    def _create_engine(self):\n        \"\"\"Create engine.\"\"\"\n        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:\n            for _ in e.map(self.model_comm.create_engine, range(self.gpu_count)):\n                pass\n        self._engine_created = True\n\n    def _create_weight(self, model_comm):\n        \"\"\"Allocate weight buffer, load params if 
from_workspace.\"\"\"\n\n        # create weight\n        def _create_weight_func(device_id):\n            model_comm.create_weights(device_id)\n\n        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:\n            futures = []\n            for device_id in range(self.gpu_count):\n                futures.append(executor.submit(_create_weight_func, device_id))\n            for future in futures:\n                future.result()\n\n    def _get_model_params(self):\n        \"\"\"Get turbomind model params when loading from hf.\"\"\"\n\n        model_comm = self.model_comm\n        tm_params = self._tm_model.tm_params\n        tm_params.clear()\n\n        def _get_params(device_id, que):\n            out = model_comm.get_weights(device_id)\n            que.put(out)\n\n        que = Queue()\n        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:\n            futures = []\n            for device_id in range(self.gpu_count):\n                futures.append(executor.submit(_get_params, device_id, que))\n            for future in futures:\n                future.result()\n\n        for _ in range(self.gpu_count):\n            tensor_map = que.get()\n            for k, v in tensor_map.items():\n                if k not in tm_params:\n                    tm_params[k] = [v]\n                else:\n                    tm_params[k].append(v)\n        logger.warning(f'get {len(tm_params)} model params')\n\n    def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: TurbomindEngineConfig):\n        \"\"\"Postprocess the turbomind config with values from the engine config.\"\"\"\n        import copy\n        self.config = copy.deepcopy(tm_config)\n        # Update the attribute values in `self.config` with the valid values\n        # from the corresponding attributes in `engine_config`, such as\n        # `session_len`, `quant_policy`, `rope_scaling_factor`, etc.\n        self.config.update_from_engine_config(engine_config)\n\n        # update some attributes 
of `engine_config` that depend on\n        # `session_len`\n        self.engine_config = engine_config\n\n        # pack `self.config` and `self.engine_config` into a dict\n        self.config_dict = self.config.to_dict()\n        self.config_dict.update(dict(engine_config=asdict(self.engine_config)))\n        logger.info(f'turbomind model config:\\n\\n'\n                    f'{json.dumps(self.config_dict, indent=2)}')\n\n    def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):\n        \"\"\"Load a model in HF format.\"\"\"\n        assert is_supported(model_path), (f'turbomind does not support {model_path}. '\n                                          'Please try the pytorch engine instead.')\n\n        # convert transformers model into turbomind model\n        from .deploy.converter import get_tm_model\n        tm_model = get_tm_model(model_path, self.model_name, self.chat_template_name, engine_config)\n\n        self._postprocess_config(tm_model.tm_config, engine_config)\n\n        model_comm = _tm.TurboMind.create(model_dir='',\n                                          config=yaml.safe_dump(self.config_dict),\n                                          weight_type=self.config.model_config.weight_type)\n\n        # create empty weight\n        self._create_weight(model_comm)\n        # output model\n        self._tm_model = tm_model\n        return model_comm\n\n    def sleep(self, level: int = 1):\n        \"\"\"Sleep the model.\"\"\"\n        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:\n            for _ in e.map(self.model_comm.sleep, range(self.gpu_count), [level] * self.gpu_count):\n                pass\n\n    def wakeup(self, tags: Optional[list[str]] = None):\n        \"\"\"Wake up the model.\"\"\"\n        if tags is None:\n            tags = ['weights', 'kv_cache']\n        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:\n            for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * 
self.gpu_count):\n                pass\n\n    def update_params(self, request: UpdateParamsRequest):\n        \"\"\"Update model parameters.\n\n        When using this function, you need to set empty_init=True when creating the engine.\n\n        For each request, the serialized_named_tensors should be the full weights of a decoder layer or the misc weights\n        (embedding, norm, lm_head). You should set finished=True when you call this function for the last time.\n        \"\"\"\n\n        def _construct(item):\n            \"\"\"Deserialize a torch.Tensor.\n\n            Args:\n                item (Tuple[Callable, Tuple]): the return value of reduce_tensor\n            \"\"\"\n            func, args = item\n            args = list(args)\n            args[6] = torch.cuda.current_device()  # device id: rebuild the tensor on the current device\n            return func(*args).clone()\n\n        if not hasattr(self, '_export_iter'):\n            self._get_model_params()\n            que = Queue()\n            tm_model = self._tm_model\n            tm_model.input_model.model_path = que\n            self._update_params_que = que\n            self._export_iter = tm_model.export_iter()\n\n        with torch.cuda.device(self.devices[0]):\n            if isinstance(request.serialized_named_tensors, str):\n                weights = ForkingPickler.loads(pybase64.b64decode(request.serialized_named_tensors))\n                weights = {k: _construct(v) for k, v in weights}\n            else:\n                weights = request.serialized_named_tensors\n            self._update_params_que.put(weights)\n            next(self._export_iter)\n\n        if request.finished:\n            self._check_unloaded_tm_params()\n            self._process_weights()\n            if self._engine_created is False:\n                self._create_engine()\n\n    @classmethod\n    def from_pretrained(cls,\n                        pretrained_model_name_or_path: str,\n                        model_name: str = None,\n                        
chat_template_name: str = None,\n                        engine_config: TurbomindEngineConfig = None,\n                        **kwargs):\n        \"\"\"LMDeploy's turbomind inference engine.\n\n        Args:\n            pretrained_model_name_or_path (str):\n                It could be one of the following options:\n                    - i) A local directory path of a turbomind model which is\n                      converted by the `lmdeploy convert` command or downloaded\n                      from ii) and iii)\n                    - ii) The model_id of an lmdeploy-quantized model hosted\n                      inside a model repo on huggingface.co, such as\n                      \"InternLM/internlm-chat-20b-4bit\",\n                      \"lmdeploy/llama2-chat-70b-4bit\", etc.\n                    - iii) The model_id of a model hosted inside a model repo\n                      on huggingface.co, such as \"internlm/internlm-chat-7b\",\n                      \"Qwen/Qwen-7B-Chat\", \"baichuan-inc/Baichuan2-7B-Chat\"\n                      and so on.\n            kwargs (remaining dictionary of keyword arguments, *optional*):\n                Can be used to update the configuration when initializing the engine.\n        \"\"\"\n        return cls(model_path=pretrained_model_name_or_path,\n                   model_name=model_name,\n                   chat_template_name=chat_template_name,\n                   engine_config=engine_config,\n                   **kwargs)\n\n    def close(self):\n        if hasattr(self, '_tm_model'):\n            # also handles closing right after initializing the engine with empty_init=True\n            self._tm_model.tm_params.clear()\n        if hasattr(self, '_export_iter'):\n            del self._export_iter\n        if self.model_comm is not None:\n            self.model_comm = None\n        self._engine_created = False\n        if hasattr(self, 'store'):\n            del self.store\n\n    def create_instance(self, cuda_stream_id=0):\n        \"\"\"Create a 
turbomind instance.\n\n        Args:\n            cuda_stream_id(int): identity of a cuda stream\n        Returns:\n            TurboMindInstance: an instance of turbomind\n        \"\"\"\n        return TurboMindInstance(self, self.config, cuda_stream_id)\n\n    def get_schedule_metrics(self):\n        # TODO: support dp\n        tm_metrics = self.model_comm.get_schedule_metrics(0)\n        return ScheduleMetrics(active_seqs=tm_metrics.active_seqs,\n                               waiting_seqs=tm_metrics.waiting_seqs,\n                               total_blocks=tm_metrics.total_blocks,\n                               active_blocks=tm_metrics.active_blocks,\n                               free_blocks=tm_metrics.free_blocks)\n\n\ndef _get_logits(outputs, offset: int):\n    logits = outputs['logits']\n\n    def _func(out: EngineOutput, step: int, **kwargs):\n        out.logits = logits[:step - offset - 1, :]\n\n    return _func\n\n\ndef _get_last_hidden_state(outputs, offset: int):\n    last_hidden_state = outputs['last_hidden_state']\n\n    def _func(out: EngineOutput, step: int, **kwargs):\n        out.last_hidden_state = last_hidden_state[:step - offset - 1, :]\n\n    return _func\n\n\ndef _get_logprobs_impl(logprob_vals: torch.Tensor, logprob_idxs: torch.Tensor, logprob_nums: torch.Tensor,\n                       output_ids: List[int], logprobs: int, offset: int):\n    \"\"\"Get logprob of each generated token.\n\n    Args:\n        logprob_vals (torch.Tensor): shape (max_new_tokens, 1024),\n            1024 is the max_logprobs that turbomind engine can output\n        logprob_idxs (torch.Tensor): shape (max_new_tokens, 1024)\n        logprob_nums (torch.Tensor): shape (max_new_tokens,)\n        output_ids (List[int]): new generated token ids\n        logprobs (int): top n logprobs to return\n        offset (int): offset to index logprob_vals, logprob_idxs and logprob_nums.\n            It indicates where to start getting logprobs for the current generated tokens 
`output_ids`\n    \"\"\"\n    out_logprobs = []\n    # the total generated token number until now\n    length = len(output_ids) + offset\n    for (pos, idx, val, n) in zip(range(len(output_ids)), logprob_idxs[offset:length], logprob_vals[offset:length],\n                                  logprob_nums[offset:length]):\n        topn = min(n.item(), logprobs)\n        tok_res = {idx[i].item(): val[i].item() for i in range(topn)}\n        token_id = output_ids[pos]\n        if token_id not in tok_res:\n            valid_n = n.item()\n            tok_res[token_id] = \\\n                val[:valid_n][idx[:valid_n] == token_id].item()\n        ids = list(tok_res.keys())\n        for k in ids:\n            if tok_res[k] == float('-inf'):\n                tok_res.pop(k)\n        out_logprobs.append(tok_res)\n    return out_logprobs\n\n\ndef _get_logprobs(outputs, output_logprobs: int):\n    logprob_vals = outputs['logprob_vals']  # shape {max_new_tokens, 1024}\n    logprob_idxs = outputs['logprob_indexes']  # shape {max_new_tokens, 1024}\n    logprob_nums = outputs['logprob_nums']  # shape {max_new_tokens,}\n    offset = 0  # offset to index logprob_vals, logprob_idxs and logprob_nums\n\n    def _func(out: EngineOutput, step: int, **kwargs):\n        nonlocal offset\n        out.logprobs = _get_logprobs_impl(logprob_vals, logprob_idxs, logprob_nums, out.token_ids, output_logprobs,\n                                          offset)\n        offset += len(out.token_ids)\n\n    return _func\n\n\ndef _get_metrics(metrics):\n    import time\n\n    from lmdeploy.messages import EngineEvent, EventType, RequestMetrics\n\n    is_first = True\n\n    def _func(out: EngineOutput, step: int, **kwargs):\n        nonlocal is_first\n        if not is_first:\n            out.req_metrics = RequestMetrics(token_timestamp=time.time())\n        else:\n            events = [\n                EngineEvent(EventType.QUEUED, metrics.enqueue_time / 1000000),\n                
EngineEvent(EventType.SCHEDULED, metrics.scheduled_time / 1000000),\n            ]\n            out.req_metrics = RequestMetrics(token_timestamp=time.time(), engine_events=events)\n            is_first = False\n\n    return _func\n\n\nclass StreamingSemaphore:\n    \"\"\"A binary semaphore that wakes up an asyncio consumer.\n\n    `release` must run on the event loop thread; callers on other threads\n    should schedule it with `loop.call_soon_threadsafe`.\n    \"\"\"\n\n    def __init__(self):\n        self.loop = asyncio.get_running_loop()\n        self.fut = None\n        self.val = 0\n\n    async def acquire(self):\n        if self.val:\n            self.val = 0\n            return\n        self.fut = self.loop.create_future()\n        await self.fut\n        self.fut = None\n        self.val = 0\n\n    def release(self):\n        if not self.val:\n            self.val = 1\n            if self.fut and not self.fut.done():\n                self.fut.set_result(None)\n\n\nclass TurboMindInstance:\n    \"\"\"Instance of TurboMind.\n\n    Args:\n        tm_model (TurboMind): the turbomind engine that owns this instance\n        config (TurbomindModelConfig): the turbomind model config\n        cuda_stream_id (int): identity of a cuda stream\n    \"\"\"\n\n    def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_stream_id: int = 0):\n        self.tm_model = tm_model\n        self.cuda_stream_id = cuda_stream_id\n\n        # create the model instance now, or lazily on first use if `empty_init` is set\n        lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False)\n        self._model_inst = None if lazy_init else self._create_model_instance()\n\n        self.config = config\n        self.lock = None\n        # error code map from csrc (refer to `struct Request` in src/turbomind/engine/request.h)\n        # to lmdeploy.messages.ResponseType\n        self.errcode_map = {\n            0: ResponseType.SUCCESS,\n            1: ResponseType.SESSION_NOT_EXIST,\n            2: ResponseType.SESSION_REPEAT,\n            3: ResponseType.SESSION_REPEAT,\n            4: ResponseType.INTERNAL_ENGINE_ERROR,\n            5: ResponseType.INTERNAL_ENGINE_ERROR,\n            6: ResponseType.INPUT_LENGTH_ERROR,\n            7: ResponseType.FINISH,\n            8: ResponseType.CANCEL,\n            
9: ResponseType.PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE,\n            10: ResponseType.NO_QUEUE,\n            -1: ResponseType.INTERNAL_ENGINE_ERROR,\n        }\n\n    @property\n    def model_inst(self):\n        if self._model_inst is None:\n            self._model_inst = self._create_model_instance()\n        return self._model_inst\n\n    def _create_model_instance(self):\n        model_inst = self.tm_model.model_comm.create_request()\n        return model_inst\n\n    def _get_extra_output_processors(self, outputs: Dict[str, torch.Tensor], gen_config: GenerationConfig,\n                                     input_len: int, metrics: '_tm.RequestMetrics'):\n\n        def _get_offset(type):\n            return input_len - 1 if type == 'generation' else 0\n\n        fs = []\n        if gen_config.output_logits:\n            offset = _get_offset(gen_config.output_logits)\n            fs.append(_get_logits(outputs, offset))\n        if gen_config.output_last_hidden_state:\n            offset = _get_offset(gen_config.output_last_hidden_state)\n            fs.append(_get_last_hidden_state(outputs, offset))\n        if gen_config.logprobs:\n            fs.append(_get_logprobs(outputs, gen_config.logprobs))\n        if self.tm_model.engine_config.enable_metrics:\n            fs.append(_get_metrics(metrics))\n        return fs\n\n    def prepare_embeddings(self, input_embeddings=None, input_embedding_ranges=None):\n        \"\"\"Convert embeddings.\"\"\"\n        if not input_embeddings:\n            return None, None\n\n        assert isinstance(input_embeddings, List)\n        assert isinstance(input_embedding_ranges, List)\n        assert len(input_embeddings) == len(input_embedding_ranges)\n\n        length = sum([x.shape[0] for x in input_embeddings])\n\n        _MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16)\n        dtype = _MAP[self.tm_model.config.model_config.data_type]\n\n        values = torch.empty((length, input_embeddings[0].shape[-1]), 
dtype=dtype, device='cpu')\n        ranges = torch.tensor(input_embedding_ranges, dtype=torch.int32, device='cpu')\n\n        offset = 0\n        for embeds in input_embeddings:\n            values[offset:offset + embeds.shape[0]].copy_(embeds)\n            offset += embeds.shape[0]\n\n        return values, ranges\n\n    def prepare_mrope(self, input_meta: Dict[str, Any], input_len: int):\n        mrope_position_ids = input_meta['mrope_position_ids']\n        mrope_position_delta = input_meta['mrope_position_delta']\n        assert mrope_position_ids.size(-1) == input_len\n        mrope_position_ids = mrope_position_ids.t().contiguous()\n        return mrope_position_ids, mrope_position_delta\n\n    def prepare_inputs(self,\n                       input_ids,\n                       gen_config: GenerationConfig,\n                       input_embeddings=None,\n                       input_embedding_ranges=None,\n                       input_meta: Dict[str, Any] = None):\n        \"\"\"Convert inputs format.\"\"\"\n        assert isinstance(input_ids, Sequence)\n\n        input_ids = torch.IntTensor(input_ids)\n        input_len = len(input_ids)\n\n        inputs = dict(input_ids=input_ids, )\n\n        input_embeddings, input_embedding_ranges = self.prepare_embeddings(input_embeddings, input_embedding_ranges)\n        if input_embeddings is not None:\n            inputs['input_embeddings'] = input_embeddings.cpu()\n            inputs['input_embedding_ranges'] = input_embedding_ranges\n\n        if input_meta and 'mrope_position_ids' in input_meta:\n            mrope_position_ids, mrope_position_delta = self.prepare_mrope(input_meta, input_len)\n            inputs['mrope_position_ids'] = mrope_position_ids.type(torch.int32)\n            inputs['mrope_position_delta'] = mrope_position_delta.type(torch.int32)\n            inputs['mrope_length'] = torch.IntTensor([mrope_position_ids.shape[0]])\n\n        return inputs, input_len\n\n    async def async_cancel(self, 
session_id: int = None):\n        self.model_inst.cancel()\n\n    def async_end_cb(self, fut: asyncio.Future, status: int):\n        \"\"\"Executed on the engine's signaling thread.\"\"\"\n        logger.info(f'[async_end_cb] session ended, status = {status}')\n        fut.get_loop().call_soon_threadsafe(fut.set_result, status)\n\n    async def async_end(self, session_id):\n        fut = asyncio.get_running_loop().create_future()\n        self.model_inst.end(partial(self.async_end_cb, fut), session_id)\n        await fut\n\n    def async_signal_cb(self, s: StreamingSemaphore):\n        \"\"\"Executed on the engine's signaling thread.\"\"\"\n        s.loop.call_soon_threadsafe(s.release)\n\n    async def async_stream_infer(self,\n                                 session_id,\n                                 input_ids,\n                                 input_embeddings=None,\n                                 input_embedding_ranges=None,\n                                 input_meta: Dict[str, Any] = None,\n                                 sequence_start: bool = True,\n                                 sequence_end: bool = False,\n                                 step=0,\n                                 gen_config: GenerationConfig = None,\n                                 stream_output=False,\n                                 **kwargs):\n        \"\"\"Perform model inference.\n\n        Args:\n            session_id (int): the id of a session\n            input_ids (Sequence[int]): the token ids of a prompt\n            input_embeddings (List[torch.Tensor]): embeddings features\n            input_embedding_ranges (List[Tuple[int,int]]): the begin/end\n              offsets of input_embeddings to input_ids\n            input_meta (Dict[str, Any]): extra input info, e.g. mrope position ids\n            sequence_start (bool): indicator for starting a sequence\n            sequence_end (bool): indicator for ending a sequence\n            step (int): the offset of the k/v cache\n            
gen_config (GenerationConfig): generation config\n            stream_output (bool): indicator for stream output\n            kwargs (dict): kwargs for backward compatibility\n        \"\"\"\n        logger.info(f'[async_stream_infer] session {session_id} start')\n        gen_cfg = self._get_generation_config(gen_config)\n\n        inputs, input_len = self.prepare_inputs(input_ids=input_ids,\n                                                input_embeddings=input_embeddings,\n                                                input_embedding_ranges=input_embedding_ranges,\n                                                input_meta=input_meta,\n                                                gen_config=gen_config)\n\n        if gen_config.response_format is not None:\n            tokenizer = self.tm_model.tokenizer\n            vocab_size = self.tm_model.config.model_config.vocab_size\n\n            try:\n                tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size)\n                decode_grammar_type = gen_config.response_format['type']\n                if decode_grammar_type == 'json_schema':\n                    decode_grammar = gen_config.response_format[decode_grammar_type]['schema']\n                elif decode_grammar_type == 'regex_schema':\n                    decode_grammar = gen_config.response_format[decode_grammar_type]\n                elif decode_grammar_type == 'json_object':\n                    decode_grammar = '{\"type\" : \"object\", \"additionalProperties\": true}'\n\n                compiler = _xgr.GrammarCompiler(tokenizer_info)\n\n                if decode_grammar_type == 'json_schema':\n                    decode_grammar = json.dumps(decode_grammar)\n                    grammar = compiler.compile_json_schema(decode_grammar)\n                elif decode_grammar_type == 'regex_schema':\n                    decode_grammar = str(decode_grammar)\n                    grammar = 
compiler.compile_regex(decode_grammar)\n                elif decode_grammar_type == 'json_object':\n                    decode_grammar = str(decode_grammar)\n                    grammar = compiler.compile_json_schema(decode_grammar)\n                else:\n                    assert False, f'Decode grammar type {decode_grammar_type} should be in ' \\\n                                   '[\"json_schema\", \"regex_schema\", \"json_object\"]'\n\n                self.model_inst.set_grammar(grammar)\n            except ValueError as e:\n                logger.warning(f'Failed to initialize guided decoding for tokenizer {tokenizer}, '\n                               f'disable guided decoding: {e}')\n                gen_config.response_format = None\n\n        session = _tm.SessionParam(id=session_id, step=step, start=sequence_start, end=sequence_end)\n\n        inputs = _np_dict_to_tm_dict(inputs)\n\n        sem = StreamingSemaphore()\n        signal_cb = partial(self.async_signal_cb, sem)\n\n        outputs, shared_state, metrics = self.model_inst.forward(inputs, session, gen_cfg, stream_output,\n                                                                 self.tm_model.engine_config.enable_metrics, signal_cb)\n\n        outputs = _tm_dict_to_torch_dict(outputs)\n\n        extra_fs = self._get_extra_output_processors(outputs, gen_config, input_len, metrics)\n\n        output_ids_buf = outputs['output_ids']\n\n        finish = False\n        state = None\n\n        output_ids = []\n        prev_len = step + input_len\n        try:\n            while True:\n                await sem.acquire()\n                state = shared_state.consume()\n\n                status, seq_len = state.status, state.seq_len\n                ret_status = ResponseType.SUCCESS\n\n                if status in [7, 8]:  # finish / canceled\n                    finish = True\n                    ret_status = ResponseType.FINISH if status == 7 else ResponseType.CANCEL\n                elif 
status:\n                    logger.error(f'internal error. status_code {status}')\n                    yield self._get_error_output(status)\n                    break\n\n                if seq_len == prev_len and not finish:\n                    continue\n\n                output_ids = output_ids_buf[prev_len:seq_len].tolist()\n                output = EngineOutput(ret_status, output_ids)\n\n                for f in extra_fs:\n                    f(output, seq_len)\n\n                prev_len = seq_len\n\n                yield output\n\n                if finish:\n                    break\n\n        except (GeneratorExit, asyncio.CancelledError) as e:\n            logger.info(f'[async_stream_infer] {type(e).__name__}')\n            self.model_inst.cancel()\n        except Exception as e:\n            logger.error(f'[async_stream_infer] {type(e).__name__} {e}')\n            self.model_inst.cancel()\n            yield self._get_error_output(-1)\n        finally:\n            # Contract: `cb` won't be called again if status is non-zero\n            # wait for status to be set as `finish` or `error`\n            while not state or state.status == 0:\n                await sem.acquire()\n                state = shared_state.consume()\n            logger.info(f'[async_stream_infer] session {session_id} done')\n\n    def _get_error_output(self, status):\n        return EngineOutput(status=self.errcode_map[status], token_ids=[])\n\n    def _get_generation_config(self, cfg: GenerationConfig):\n        c = _tm.GenerationConfig()\n        c.max_new_tokens = cfg.max_new_tokens\n        c.top_k = cfg.top_k\n        c.top_p = cfg.top_p\n        c.min_p = cfg.min_p\n        c.temperature = cfg.temperature\n        if cfg.stop_token_ids:\n            c.eos_ids = cfg.stop_token_ids\n        if cfg.bad_token_ids:\n            c.bad_ids = _construct_stop_or_bad_words(cfg.bad_token_ids)\n        if not cfg.ignore_eos and cfg.stop_token_ids:\n            c.stop_ids = 
_construct_stop_or_bad_words(cfg.stop_token_ids)\n        c.repetition_penalty = cfg.repetition_penalty\n        if cfg.min_new_tokens:\n            c.min_new_tokens = cfg.min_new_tokens\n        output_type = dict(all=1, generation=2)\n        if cfg.output_last_hidden_state:\n            c.output_last_hidden_state = output_type[cfg.output_last_hidden_state]\n        if cfg.output_logits:\n            c.output_logits = output_type[cfg.output_logits]\n        if cfg.logprobs:\n            if cfg.logprobs > MAX_LOGPROBS:\n                cfg.logprobs = MAX_LOGPROBS\n                logger.warning(f'logprobs should be in range [1, {MAX_LOGPROBS}], '\n                               f'clamping to logprobs={cfg.logprobs}')\n            c.output_logprobs = cfg.logprobs\n        if cfg.random_seed is not None:\n            c.random_seed = cfg.random_seed\n        return c\n"
  },
  {
    "path": "lmdeploy/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport asyncio\nimport functools\nimport logging\nimport os\nimport sys\nimport time\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom logging import Logger, LogRecord\n\nimport torch\nfrom transformers import PretrainedConfig\n\nlogger_initialized = {}\n\n\nclass _ASNI_COLOR:\n    BRIGHT_RED = '\\033[91m'\n    RED = '\\033[31m'\n    YELLOW = '\\033[33m'\n    WHITE = '\\033[37m'\n    GREEN = '\\033[32m'\n\n\n# copy from: https://github.com/termcolor/termcolor\n@functools.cache\ndef can_colorize(*, no_color: bool | None = None, force_color: bool | None = None) -> bool:\n    \"\"\"Check env vars and for tty/dumb terminal.\"\"\"\n    import io\n    if no_color is not None and no_color:\n        return False\n    if force_color is not None and force_color:\n        return True\n\n    # Then check env vars:\n    if os.environ.get('ANSI_COLORS_DISABLED'):\n        return False\n    if os.environ.get('NO_COLOR'):\n        return False\n    if os.environ.get('FORCE_COLOR'):\n        return True\n\n    # Then check system:\n    if os.environ.get('TERM') == 'dumb':\n        return False\n    if not hasattr(sys.stdout, 'fileno'):\n        return False\n\n    try:\n        return os.isatty(sys.stdout.fileno())\n    except io.UnsupportedOperation:\n        return sys.stdout.isatty()\n\n\nclass ColorFormatter(logging.Formatter):\n\n    _LEVELNAME_COLOR_MAP = dict(CRITICAL=_ASNI_COLOR.BRIGHT_RED,\n                                ERROR=_ASNI_COLOR.RED,\n                                WARN=_ASNI_COLOR.YELLOW,\n                                WARNING=_ASNI_COLOR.YELLOW,\n                                INFO=_ASNI_COLOR.WHITE,\n                                DEBUG=_ASNI_COLOR.GREEN)\n\n    _RESET_COLOR = '\\033[0m'\n\n    def format(self, record: LogRecord):\n        \"\"\"format.\"\"\"\n        if not can_colorize():\n            # windows does not support ASNI color\n            return 
super().format(record)\n        levelname = record.levelname\n        level_color = self._LEVELNAME_COLOR_MAP.get(levelname, self._RESET_COLOR)\n        levelname = f'{level_color}{levelname}{self._RESET_COLOR}'\n        record.levelname = levelname\n        return super().format(record)\n\n\nclass FilterDuplicateWarning(logging.Filter):\n    \"\"\"Filter the repeated warning message.\n\n    Args:\n        name (str): name of the filter.\n    \"\"\"\n\n    def __init__(self, name: str = 'lmdeploy'):\n        super().__init__(name)\n        self.seen: set = set()\n\n    def filter(self, record: LogRecord) -> bool:\n        \"\"\"Filter the repeated warning message.\n\n        Args:\n            record (LogRecord): The log record.\n\n        Returns:\n            bool: Whether to output the log record.\n        \"\"\"\n        if record.levelno != logging.WARNING:\n            return True\n\n        if record.msg not in self.seen:\n            self.seen.add(record.msg)\n            return True\n        return False\n\n\n_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d' \\\n          ' - %(message)s'\n\n\ndef get_logger(name: str | None = None,\n               log_file: str | None = None,\n               log_level: int = logging.INFO,\n               file_mode: str = 'a',\n               log_formatter: str = _FORMAT) -> Logger:\n    \"\"\"Initialize and get a logger by name.\n\n    If the logger has not been initialized, this method will initialize the\n    logger by adding one or two handlers, otherwise the initialized logger will\n    be directly returned. During initialization, a StreamHandler will always be\n    added. If `log_file` is specified, a FileHandler will also be added.\n    Args:\n        name (str): Logger name.\n        log_file (str | None): The log filename. 
If specified, a FileHandler\n            will be added to the logger.\n        log_level (int): The logger level.\n        file_mode (str): The file mode used in opening log file.\n            Defaults to 'a'.\n        log_formatter (str): The logger output format.\n    Returns:\n        logging.Logger: The expected logger.\n    \"\"\"\n    logger = logging.getLogger(name)\n    if name in logger_initialized:\n        return logger\n    # handle hierarchical names\n    # e.g., logger \"a\" is initialized, then logger \"a.b\" will skip the\n    # initialization since it is a child of \"a\".\n    for logger_name in logger_initialized:\n        if name.startswith(logger_name):\n            return logger\n\n    # handle duplicate logs to the console\n    for handler in logger.root.handlers:\n        if type(handler) is logging.StreamHandler:\n            handler.setLevel(logging.ERROR)\n\n    stream_handler = logging.StreamHandler(stream=sys.stdout)\n    handlers = [stream_handler]\n\n    # set log_file from env\n    log_file = log_file or os.getenv('LMDEPLOY_LOG_FILE')\n\n    if log_file is not None:\n        log_file = os.path.expanduser(log_file)\n        log_dir = os.path.dirname(log_file)\n        if log_dir:\n            os.makedirs(log_dir, exist_ok=True)\n        # Here, the default behaviour of the official logger is 'a'. 
Thus, we\n        # expose `file_mode` so callers can override this default\n        # behaviour.\n        file_handler = logging.FileHandler(log_file, file_mode)\n        handlers.append(file_handler)\n\n    formatter = ColorFormatter(log_formatter)\n    for handler in handlers:\n        handler.setFormatter(formatter)\n        handler.setLevel(logging.DEBUG)\n        handler.addFilter(FilterDuplicateWarning(name))\n        logger.addHandler(handler)\n\n    logger.setLevel(log_level)\n    logger.propagate = False\n    logger_initialized[name] = True\n\n    return logger\n\n\ndef filter_suffix(response: str, suffixes: list[str] | None = None) -> str:\n    \"\"\"Remove the given suffixes from the end of a response.\n\n    Args:\n        response (str): the response generated by the LLM.\n        suffixes (list[str] | None): suffixes to remove, checked in order.\n\n    Returns:\n        str: the cleaned response.\n    \"\"\"\n    if suffixes is None:\n        return response\n    for item in suffixes:\n        if response.endswith(item):\n            response = response[:len(response) - len(item)]\n    return response\n\n\n# TODO remove stop_word_offsets stuff and make it clean\ndef _stop_words(stop_words: list[int | str], tokenizer: object):\n    \"\"\"Convert a list of stop words to a numpy.ndarray.\"\"\"\n    import numpy as np\n    if stop_words is None:\n        return None\n    assert isinstance(stop_words, list) and \\\n        all(isinstance(elem, (str, int)) for elem in stop_words), \\\n        f'stop_words must be a list of str or int, but got {stop_words!r}'\n    stop_indexes = []\n    for stop_word in stop_words:\n        if isinstance(stop_word, str):\n            stop_indexes += tokenizer.indexes_containing_token(stop_word)\n        elif isinstance(stop_word, int):\n            stop_indexes.append(stop_word)\n    assert isinstance(stop_indexes, list) and all(isinstance(elem, int) for elem in stop_indexes), 'invalid stop_words'\n    # each id in stop_indexes represents a stop word\n    # refer to 
https://github.com/fauxpilot/fauxpilot/discussions/165 for\n    # a detailed explanation of fastertransformer's stop_indexes\n    stop_word_offsets = range(1, len(stop_indexes) + 1)\n    stop_words = np.array([[stop_indexes, stop_word_offsets]]).astype(np.int32)\n    return stop_words\n\n\ndef get_hf_gen_cfg(path: str):\n    from transformers import GenerationConfig\n    try:\n        cfg = GenerationConfig.from_pretrained(path, trust_remote_code=True)\n        return cfg.to_dict()\n    except OSError:\n        return {}\n\n\ndef get_model(pretrained_model_name_or_path: str, download_dir: str = None, revision: str = None, token: str = None):\n    \"\"\"Download a model snapshot from Hugging Face, ModelScope or openmind_hub and return its local path.\"\"\"\n    import os\n    if os.getenv('LMDEPLOY_USE_MODELSCOPE', 'False').lower() == 'true':\n        from modelscope import snapshot_download\n    elif os.getenv('LMDEPLOY_USE_OPENMIND_HUB', 'False').lower() == 'true':\n        from openmind_hub import snapshot_download\n    else:\n        from huggingface_hub import snapshot_download\n\n    download_kwargs = {}\n    if download_dir is not None:\n        download_kwargs['cache_dir'] = download_dir\n    if revision is not None:\n        download_kwargs['revision'] = revision\n    if token is not None:\n        download_kwargs['token'] = token\n\n    model_path = snapshot_download(pretrained_model_name_or_path, ignore_patterns=['*.pth'], **download_kwargs)\n    return model_path\n\n\ndef logging_timer(op_name: str, logger: Logger, level: int = logging.DEBUG):\n    \"\"\"Decorator that logs the execution time of a sync or async function.\"\"\"\n\n    @contextmanager\n    def __timer():\n        \"\"\"Time the enclosed block and log the duration in milliseconds.\"\"\"\n        start = time.perf_counter()\n        yield\n        end = time.perf_counter()\n        duration = (end - start) * 1000\n        logger.log(level, f'<{op_name}> took {duration:.2f} ms')\n\n    def __inner(func):\n        \"\"\"inner.\"\"\"\n\n        @functools.wraps(func)\n        def __func_warpper(*args, **kwargs):\n            \"\"\"Func 
warpper.\"\"\"\n            if logger.level > level:\n                return func(*args, **kwargs)\n            with __timer():\n                return func(*args, **kwargs)\n\n        @functools.wraps(func)\n        def __async_warpper(*args, **kwargs):\n            \"\"\"Async warpper.\"\"\"\n\n            async def __tmp():\n                if logger.level > level:\n                    return (await func(*args, **kwargs))\n                with __timer():\n                    return (await func(*args, **kwargs))\n\n            return __tmp()\n\n        if asyncio.iscoroutinefunction(func):\n            return __async_warpper\n        else:\n            return __func_warpper\n\n    return __inner\n\n\n# modified from https://github.com/vllm-project/vllm/blob/0650e5935b0f6af35fb2acf71769982c47b804d7/vllm/config.py#L1082-L1150  # noqa\ndef _get_and_verify_max_len(\n    hf_config: PretrainedConfig,\n    max_model_len: int | None,\n) -> int:\n    \"\"\"Get and verify the model's maximum length.\"\"\"\n\n    # vl configs hide session-len inside llm configs\n    llm_keys = ['language_config', 'llm_config', 'text_config']\n    for key in llm_keys:\n        hf_config = getattr(hf_config, key, hf_config)\n\n    logger = get_logger('lmdeploy')\n    derived_max_model_len = float('inf')\n    possible_keys = [\n        # OPT\n        'max_position_embeddings',\n        # GPT-2\n        'n_positions',\n        # MPT\n        'max_seq_len',\n        # ChatGLM2\n        'seq_length',\n        # Command-R\n        'model_max_length',\n        # Others\n        'max_sequence_length',\n        'max_seq_length',\n        'seq_len',\n    ]\n    max_len_key = None\n    for key in possible_keys:\n        max_len = None\n        if hasattr(hf_config, key):\n            max_len = getattr(hf_config, key)\n        elif key in hf_config:\n            max_len = hf_config[key]\n        if max_len is not None:\n            max_len_key = key if max_len < derived_max_model_len \\\n                
else max_len_key\n            derived_max_model_len = min(derived_max_model_len, max_len)\n    if derived_max_model_len == float('inf'):\n        if max_model_len is not None:\n            # If max_model_len is specified, we use it.\n            return max_model_len\n\n        default_max_len = 2048\n        logger.warning(\"The model's config.json does not contain any of the following \"\n                       'keys to determine the original maximum length of the model: '\n                       f\"{possible_keys}. Assuming the model's maximum length is \"\n                       f'{default_max_len}.')\n        derived_max_model_len = default_max_len\n\n    if max_model_len is None:\n        max_model_len = int(derived_max_model_len)\n    elif max_model_len > derived_max_model_len:\n        # Some models might have a separate key for specifying model_max_length\n        # that will be bigger than derived_max_model_len. We compare user input\n        # with model_max_length and allow this override when it's smaller.\n        model_max_length = getattr(hf_config, 'model_max_length', None)\n        if model_max_length is not None and max_model_len <= model_max_length:\n            pass\n        else:\n            logger.warning(f'User-specified max_model_len ({max_model_len}) is greater '\n                           'than the derived max_model_len '\n                           f'({max_len_key}={derived_max_model_len} or model_max_length='\n                           f\"{model_max_length} in model's config.json).\")\n    return int(max_model_len)\n\n\ndef get_max_batch_size(device_type: str):\n    \"\"\"Get the max inference batch size for LLM models according to the device\n    type.\n\n    Args:\n        device_type (str): the type of device\n    \"\"\"\n    assert device_type in ['cuda', 'ascend', 'maca', 'camb']\n    if device_type == 'cuda':\n        max_batch_size_map = {'a100': 384, 'a800': 384, 'h100': 1024, 'h800': 1024, 'l20y': 1024, 'h200': 1024}\n        
import torch\n        device_name = torch.cuda.get_device_name(0).lower()\n        for name, size in max_batch_size_map.items():\n            if name in device_name:\n                return size\n        # devices not listed in `max_batch_size_map` fall back to a\n        # max_batch_size of 128\n        return 128\n    elif device_type == 'ascend':\n        return 256\n    elif device_type == 'maca':\n        return 256\n    elif device_type == 'camb':\n        return 256\n\n\ndef is_bf16_supported(device_type: str = 'cuda'):\n    \"\"\"Check if the device supports bfloat16.\n\n    Args:\n        device_type (str): the type of device\n    \"\"\"\n\n    if device_type == 'cuda':\n        import torch\n        device = torch.cuda.current_device()\n\n        # Fast check: bfloat16 needs CUDA >= 11 and a GPU with compute\n        # capability >= 8.0 (Ampere or newer).\n        cuda_version = torch.version.cuda\n        if (cuda_version is not None and int(cuda_version.split('.')[0]) >= 11\n                and torch.cuda.get_device_properties(device).major >= 8):\n            return True\n        else:\n            return False\n    elif device_type == 'ascend':\n        # The device-name query below does not work reliably on multi-NPU devices. 
Since\n        # `ascend910` devices support bfloat16, we return True as a\n        # workaround.\n        return True\n        # import torch_npu\n        # device_name = torch_npu.npu.get_device_name(0)[:10]\n        # device_name = device_name.lower()\n        # if device_name.startswith('ascend910'):\n        #     return True\n        # else:\n        #     return False\n    elif device_type == 'maca':\n        return True\n    elif device_type == 'camb':\n        return True\n    elif device_type == 'rocm':\n        return True\n    else:\n        return False\n\n\ndef try_import_deeplink(device_type: str):\n    deeplink_device_type_list = [\n        'ascend',\n        'npu',\n        'maca',\n        'camb',\n    ]\n    if device_type in deeplink_device_type_list:\n        try:\n            import dlinfer.framework.lmdeploy_ext  # noqa: F401\n        except Exception as e:\n            logger = get_logger('lmdeploy')\n            logger.error(f'{type(e).__name__}: {e}')\n            sys.exit(1)\n\n\ndef serialize_state_dict(state_dict: dict) -> str:\n    \"\"\"Serialize state dict to str.\n\n    The result must be consumed on the same node. 
As the producer and consumer may\n    have different GPU visibility, we use reduce_tensor instead of ForkingPickler.dumps\n    to fix the device_id when loading the serialized tensor.\n\n    Args:\n        state_dict (dict[str, torch.Tensor]): state dict to serialize.\n    Returns:\n        str: serialized state dict.\n    \"\"\"\n    from io import BytesIO\n    from multiprocessing.reduction import ForkingPickler\n\n    import pybase64\n    from torch.multiprocessing.reductions import reduce_tensor\n\n    # FlattenedTensorBucket payload: only the flattened tensor needs reducing\n    if 'metadata' in state_dict and 'flattened_tensor' in state_dict:\n        # shallow copy so the caller's dict is not mutated\n        data = dict(state_dict)\n        if isinstance(data['flattened_tensor'], torch.Tensor):\n            data['flattened_tensor'] = reduce_tensor(state_dict['flattened_tensor'])\n    else:\n        data = [(k, reduce_tensor(v)) for k, v in state_dict.items()]\n\n    buf = BytesIO()\n    ForkingPickler(buf).dump(data)\n    buf.seek(0)\n    return pybase64.b64encode(buf.read()).decode('utf-8')\n\n\ndef is_dlblas_installed():\n    is_dlblas_installed = True\n    try:\n        import dlblas  # noqa: F401\n    except Exception:\n        is_dlblas_installed = False\n    return is_dlblas_installed\n\n\n# from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/weight_sync/tensor_bucket.py\n\n\n@dataclass\nclass FlattenedTensorMetadata:\n    \"\"\"Metadata for one tensor packed in a FlattenedTensorBucket.\"\"\"\n    name: str\n    shape: torch.Size\n    dtype: torch.dtype\n    start_idx: int\n    end_idx: int\n    numel: int\n\n\nclass FlattenedTensorBucket:\n    \"\"\"Pack multiple tensors into one flattened tensor for efficient transfer.\"\"\"\n\n    def __init__(\n        self,\n        named_tensors: list[tuple[str, torch.Tensor]] | None = None,\n        flattened_tensor: torch.Tensor = None,\n        metadata: list[FlattenedTensorMetadata] | None = None,\n    ):\n        \"\"\"Initialize a tensor bucket from a list of named tensors or from\n        pre-flattened data.\n\n        Args:\n        
    named_tensors: List of (name, tensor) tuples (for creating new bucket)\n            flattened_tensor: Pre-flattened tensor (for reconstruction)\n            metadata: Pre-computed metadata (for reconstruction)\n        \"\"\"\n        if named_tensors is not None:\n            num_tensors = len(named_tensors)\n            self.metadata = [None] * num_tensors\n            self.flattened_tensor = [None] * num_tensors\n            if num_tensors > 0:\n                if num_tensors > 1:\n                    dtypes = [t.dtype for _, t in named_tensors]\n                    if not all([d == dtypes[0] for d in dtypes[1:]]):\n                        raise ValueError(f'All tensors should have same dtype, but given {dtypes}')\n\n                current_idx = 0\n                for idx, (name, tensor) in enumerate(named_tensors):\n                    self.flattened_tensor[idx] = tensor.flatten()\n                    numel = tensor.numel()\n                    self.metadata[idx] = FlattenedTensorMetadata(name=name,\n                                                                 shape=tensor.shape,\n                                                                 dtype=tensor.dtype,\n                                                                 start_idx=current_idx,\n                                                                 end_idx=current_idx + numel,\n                                                                 numel=numel)\n                    current_idx += numel\n\n                self.flattened_tensor = torch.cat(self.flattened_tensor, dim=0)\n        else:\n            if flattened_tensor is None or metadata is None:\n                raise ValueError('Must provide either named_tensors or both flattened_tensor and metadata')\n            self.metadata = metadata\n            self.flattened_tensor = flattened_tensor\n\n    def get_flattened_tensor(self) -> torch.Tensor:\n        \"\"\"Get the flattened tensor containing multiple tensors.\"\"\"\n       
 return self.flattened_tensor\n\n    def get_metadata(self) -> list[FlattenedTensorMetadata]:\n        \"\"\"Get the metadata of all tensors in the bucket.\"\"\"\n        return self.metadata\n\n    def reconstruct_tensors(self) -> list[tuple[str, torch.Tensor]]:\n        \"\"\"Reconstruct the original (name, tensor) pairs.\"\"\"\n        # preallocate the result list\n        reconstructed = [None] * len(self.metadata)\n\n        for i, meta in enumerate(self.metadata):\n            tensor = self.flattened_tensor[meta.start_idx:meta.end_idx].reshape(meta.shape)\n\n            # convert back to the recorded dtype if needed\n            if tensor.dtype != meta.dtype:\n                tensor = tensor.to(meta.dtype)\n\n            reconstructed[i] = (meta.name, tensor)\n\n        return reconstructed\n"
  },
  {
    "path": "lmdeploy/version.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Tuple\n\n__version__ = '0.12.2'\nshort_version = __version__\n\n\ndef parse_version_info(version_str: str) -> Tuple:\n    \"\"\"Parse version from a string.\n\n    Args:\n        version_str (str): A version string, e.g. '0.12.2' or '0.12.2rc1'.\n\n    Returns:\n        tuple: The integer components, followed by an 'rcN' string for\n            release candidates.\n    \"\"\"\n    _version_info = []\n    for x in version_str.split('.'):\n        if x.isdigit():\n            _version_info.append(int(x))\n        elif x.find('rc') != -1:\n            patch_version = x.split('rc')\n            _version_info.append(int(patch_version[0]))\n            _version_info.append(f'rc{patch_version[1]}')\n    return tuple(_version_info)\n\n\nversion_info = parse_version_info(__version__)\n"
  },
  {
    "path": "lmdeploy/vl/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom .utils import (encode_image_base64, encode_time_series_base64, encode_video_base64, load_image, load_time_series,\n                    load_video)\n\n__all__ = [\n    'load_image',\n    'load_video',\n    'load_time_series',\n    'encode_image_base64',\n    'encode_video_base64',\n    'encode_time_series_base64',\n]\n"
  },
  {
    "path": "lmdeploy/vl/constants.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom enum import Enum\n\nIMAGE_TOKEN = '<IMAGE_TOKEN>'\n\n\nclass Modality(str, Enum):\n    IMAGE = 'image'\n    VIDEO = 'video'\n    AUDIO = 'audio'\n    TIME_SERIES = 'time_series'\n"
  },
  {
    "path": "lmdeploy/vl/engine.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport asyncio\nimport inspect\nfrom concurrent.futures import ThreadPoolExecutor\nfrom typing import Any, Dict, List, Optional, Union\n\nimport torch\n\nfrom lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.builder import load_vl_model\n\nlogger = get_logger('lmdeploy')\n\n\ndef _raise_exception_on_finish(task: asyncio.Task) -> None:\n    \"\"\"Raise exception on finish.\"\"\"\n    try:\n        task.result()\n    except asyncio.CancelledError:\n        return\n    except Exception as e:\n        raise e\n\n\ndef _accepts_arg(func, arg_name: str) -> bool:\n    \"\"\"Check if a function accepts a specific keyword argument.\"\"\"\n    return arg_name in inspect.signature(func).parameters\n\n\nclass ImageEncoder:\n    \"\"\"Image encoder.\"\"\"\n\n    def __init__(\n        self,\n        model_path: str,\n        backend: str,\n        vision_config: VisionConfig = None,\n        backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None,\n    ):\n        self.model = load_vl_model(model_path, backend, backend_config=backend_config)\n        if vision_config is None:\n            vision_config = VisionConfig()\n        self.vision_config = vision_config\n        self.max_batch_size = vision_config.max_batch_size\n        self.executor = ThreadPoolExecutor(max_workers=1)\n        torch.cuda.empty_cache()\n\n    async def preprocess(self,\n                         messages: List[Dict],\n                         mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> List[Dict]:\n        \"\"\"Preprocess multimodal data in the messages.\"\"\"\n        if _accepts_arg(self.model.preprocess, 'mm_processor_kwargs'):\n            future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages,\n                                                              
mm_processor_kwargs)\n        else:\n            future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages)\n        future.add_done_callback(_raise_exception_on_finish)\n        outputs = await future\n        return outputs\n\n    async def async_infer(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Compute multimodal embeddings.\n\n        Args:\n            messages (List[Dict]): a list of messages, i.e. the output\n                of `preprocess()`\n        \"\"\"\n        future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.forward, messages,\n                                                          self.max_batch_size)\n        future.add_done_callback(_raise_exception_on_finish)\n        outputs = await future\n        return outputs\n\n    async def wrap_for_pytorch(\n        self,\n        messages: List[Dict],\n        chat_template,\n        tokenizer,\n        sequence_start,\n        tools: Optional[List[object]] = None,\n        chat_template_kwargs: Optional[Dict] = None,\n    ) -> List[Dict]:\n        \"\"\"\n        Args:\n            messages (List[Dict]): a list of messages, expected to be\n                the output of `preprocess`\n        Returns:\n            a dict to be passed to the pytorch engine_instance's forward.\n            The dict looks like the following:\n            Dict(\n                'prompt': 'the prompt after applying chat template',\n                'input_ids': [],\n                'multimodal': {\n                    'pixel_values': torch.Tensor,\n                    ...\n                }\n            )\n        \"\"\"\n        has_input_ids = self.model.has_input_ids(messages)\n        if not has_input_ids:\n            result = self.model.to_pytorch(messages,\n                                           chat_template,\n                                           tokenizer,\n                                           sequence_start,\n           
                                tools=tools,\n                                           chat_template_kwargs=chat_template_kwargs)\n        else:\n            result = self.model.to_pytorch_with_input_ids(messages)\n        # clear data\n        for i, message in enumerate(messages):\n            if isinstance(message['content'], List):\n                messages[i]['preprocess'] = None\n        return result\n\n    async def wrap_for_turbomind(\n        self,\n        messages: List[Dict],\n        chat_template,\n        tokenizer,\n        sequence_start,\n        tools: Optional[List[object]] = None,\n        chat_template_kwargs: Optional[Dict] = None,\n    ) -> Dict:\n        \"\"\"\n        Args:\n            messages (List[Dict]): a list of messages, expected to be\n                the output of `async_infer`\n        Returns:\n            a dict to be passed to the turbomind engine_instance's forward.\n            The dict looks like the following:\n            Dict(\n                'prompt': 'the prompt after applying chat template',\n                'input_ids': [],\n                'input_embeddings': list[torch.Tensor],\n                'input_embedding_ranges': list[torch.Tensor],\n                ...\n            )\n        \"\"\"\n        result = self.model.to_turbomind(messages,\n                                         chat_template,\n                                         tokenizer,\n                                         sequence_start,\n                                         tools=tools,\n                                         chat_template_kwargs=chat_template_kwargs)\n        # clear data\n        for i, message in enumerate(messages):\n            if isinstance(message['content'], List):\n                messages[i]['preprocess'] = None\n                messages[i]['forward'] = None\n        return result\n"
  },
  {
    "path": "lmdeploy/vl/media/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/vl/media/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/base.py\n\nfrom abc import ABC, abstractmethod\nfrom pathlib import Path\nfrom typing import Generic, TypeVar\n\n_T = TypeVar('_T')\n\n\nclass MediaIO(ABC, Generic[_T]):\n\n    @abstractmethod\n    def load_bytes(self, data: bytes) -> _T:\n        raise NotImplementedError\n\n    @abstractmethod\n    def load_base64(self, media_type: str, data: str) -> _T:\n        raise NotImplementedError\n\n    @abstractmethod\n    def load_file(self, filepath: Path) -> _T:\n        raise NotImplementedError\n"
  },
  {
    "path": "lmdeploy/vl/media/connection.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nfrom pathlib import Path\nfrom typing import TypeVar\nfrom urllib.parse import ParseResult, urlparse\nfrom urllib.request import url2pathname\n\nimport requests\n\nfrom .base import MediaIO\nfrom .image import ImageMediaIO\nfrom .video import VideoMediaIO\n\n_M = TypeVar('_M')\n\nheaders = {\n    'User-Agent':\n    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '\n    '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'\n}\n\n\ndef _load_http_url(url_spec: ParseResult, media_io: MediaIO[_M]) -> _M:\n    if url_spec.scheme not in ('http', 'https'):\n        raise ValueError(f'Unsupported URL scheme: {url_spec.scheme}')\n\n    fetch_timeout = 10\n    if isinstance(media_io, ImageMediaIO):\n        fetch_timeout = int(os.environ.get('LMDEPLOY_IMAGE_FETCH_TIMEOUT', 10))\n    elif isinstance(media_io, VideoMediaIO):\n        fetch_timeout = int(os.environ.get('LMDEPLOY_VIDEO_FETCH_TIMEOUT', 30))\n\n    client = requests.Session()\n    response = client.get(url_spec.geturl(), headers=headers, timeout=fetch_timeout)\n    response.raise_for_status()\n\n    return media_io.load_bytes(response.content)\n\n\ndef _load_data_url(url_spec: ParseResult, media_io: MediaIO[_M]) -> _M:\n    url_spec_path = url_spec.path or ''\n    data_spec, data = url_spec_path.split(',', 1)\n    media_type, data_type = data_spec.split(';', 1)\n    # media_type starts with a leading \"/\" (e.g., \"/video/jpeg\")\n    media_type = media_type.lstrip('/')\n\n    if data_type != 'base64':\n        msg = 'Only base64 data URLs are supported for now.'\n        raise NotImplementedError(msg)\n\n    return media_io.load_base64(media_type, data)\n\n\ndef _load_file_url(url_spec: ParseResult, media_io: MediaIO[_M]) -> _M:\n    url_spec_path = url_spec.path or ''\n    url_spec_netloc = url_spec.netloc or ''\n    filepath = Path(url2pathname(url_spec_netloc + url_spec_path))\n    return 
media_io.load_file(filepath)\n\n\ndef load_from_url(url: str, media_io: MediaIO[_M]) -> _M:\n    \"\"\"Load media from an HTTP(S) URL, a data URL, or a file URL or path.\"\"\"\n    url_spec = urlparse(url)\n\n    if url_spec.scheme and url_spec.scheme.startswith('http'):\n        return _load_http_url(url_spec, media_io)\n\n    if url_spec.scheme == 'data':\n        return _load_data_url(url_spec, media_io)\n\n    # file url or raw file path (absolute or relative)\n    if url_spec.scheme == 'file' or os.path.exists(url) or os.path.exists(url_spec.path):\n        return _load_file_url(url_spec, media_io)\n\n    msg = 'The URL must be an HTTP(S), data or file URL, or an existing file path.'\n    raise ValueError(msg)\n"
  },
  {
    "path": "lmdeploy/vl/media/image.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/image.py\n\nfrom io import BytesIO\nfrom pathlib import Path\n\nimport pybase64\nfrom PIL import Image, ImageFile\n\nfrom .base import MediaIO\n\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\nclass ImageMediaIO(MediaIO[Image.Image]):\n\n    def __init__(self, image_mode: str = 'RGB', **kwargs) -> None:\n        super().__init__()\n        self.image_mode = image_mode\n\n        # for potential custom arguments from --media-io-kwargs\n        self.kwargs = kwargs\n\n    def load_bytes(self, data: bytes) -> Image.Image:\n        image = Image.open(BytesIO(data))\n        return image.convert(self.image_mode)\n\n    def load_base64(self, media_type: str, data: str) -> Image.Image:\n        return self.load_bytes(pybase64.b64decode(data))\n\n    def load_file(self, file_path: Path) -> Image.Image:\n        with open(file_path, 'rb') as f:\n            data = f.read()\n        image = Image.open(BytesIO(data))\n        return image.convert(self.image_mode)\n\n    def encode_base64(self, image: Image.Image, image_format: str = 'PNG') -> str:\n        with BytesIO() as buffer:\n            image = image.convert(self.image_mode)\n            image.save(buffer, image_format)\n            data = buffer.getvalue()\n\n        return pybase64.b64encode(data).decode('utf-8')\n"
  },
  {
    "path": "lmdeploy/vl/media/time_series.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom io import BytesIO\nfrom pathlib import Path\n\nimport numpy as np\nimport numpy.typing as npt\nimport pybase64\n\nfrom lmdeploy.utils import get_logger\n\nfrom .base import MediaIO\n\nlogger = get_logger('lmdeploy')\n\n\nclass TimeSeriesMediaIO(MediaIO[npt.NDArray]):\n\n    def __init__(self, **kwargs):\n        super().__init__()\n\n        # for potential custom arguments from --media-io-kwargs\n        self.kwargs = kwargs\n\n    def load_bytes(self, data: bytes) -> npt.NDArray:\n        ts_array = np.load(BytesIO(data), allow_pickle=False)\n        return ts_array\n\n    def load_base64(self, media_type: str, data: str) -> npt.NDArray:\n        return self.load_bytes(pybase64.b64decode(data))\n\n    def load_file(self, filepath: Path) -> npt.NDArray:\n        suffix = filepath.suffix.lower()\n\n        if suffix == '.npy':\n            return np.load(filepath, allow_pickle=False)\n        elif suffix == '.csv':\n            try:\n                ts_array = np.genfromtxt(filepath, delimiter=',', dtype=np.float32)\n                if ts_array.size == 0:\n                    raise ValueError(f'CSV file {filepath} yielded no data.')\n                return ts_array\n            except Exception as e:\n                logger.error(f'Failed to load CSV {filepath}: {e}')\n                raise\n        elif suffix in ['.wav', '.mp3', '.flac']:\n            try:\n                import soundfile as sf\n            except ImportError:\n                raise ImportError('Please install soundfile via `pip install soundfile`.')\n\n            ts_array, _ = sf.read(filepath)\n            return ts_array\n\n        raise ValueError(f'Unsupported file format: {suffix}')\n\n    def encode_base64(self, data: npt.NDArray) -> str:\n        \"\"\"Encode numpy array to base64 string using NPY format.\"\"\"\n        buffer = BytesIO()\n        np.save(buffer, data, allow_pickle=False)\n        return 
pybase64.b64encode(buffer.getvalue()).decode('utf-8')\n"
  },
  {
    "path": "lmdeploy/vl/media/video.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/video.py\n\nimport base64\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport numpy.typing as npt\nfrom PIL import Image\n\nfrom lmdeploy.utils import get_logger\n\nfrom .base import MediaIO\nfrom .image import ImageMediaIO\nfrom .video_loader import (DecordVideoLoader, OpenCVVideoLoader, TorchCodecVideoLoader, TorchVisionVideoLoader,\n                           VideoLoader)\n\nlogger = get_logger('lmdeploy')\n\n\nclass VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):\n\n    def __init__(\n        self,\n        image_io: ImageMediaIO,\n        num_frames: int = 32,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n\n        self.image_io = image_io\n        self.num_frames = num_frames\n\n        # for potential custom arguments from --media-io-kwargs\n        self.kwargs = kwargs\n        self.video_loader = self._get_video_loader_backend()\n\n    def _get_video_loader_backend(self) -> VideoLoader:\n        \"\"\"Determines the best available video loader backend.\"\"\"\n        # vLLM:          OpenCV\n        # SGLang:        Decord\n        # qwen-vl-utils: TorchCodec -> Decord -> TorchVision (deprecated soon)\n        backends = [\n            ('cv2', OpenCVVideoLoader),\n            ('decord', DecordVideoLoader),\n            ('torchcodec', TorchCodecVideoLoader),\n            ('torchvision', TorchVisionVideoLoader),\n        ]\n\n        for module_name, loader_cls in backends:\n            try:\n                __import__(module_name)\n                return loader_cls()\n            except (ImportError, RuntimeError):\n                logger.warning(f\"Video backend '{module_name}' not found. Trying next backend...\")\n                continue\n\n        raise ImportError(\n            'No video backend found. 
Install either opencv-python-headless, decord, torchcodec, or torchvision.')\n\n    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:\n        return self.video_loader.load_bytes(data, num_frames=self.num_frames, **self.kwargs)\n\n    def load_base64(self, media_type: str, data: str) -> tuple[npt.NDArray, dict[str, Any]]:\n        if media_type.lower() == 'video/jpeg':\n            load_frame = partial(\n                self.image_io.load_base64,\n                'image/jpeg',\n            )\n\n            # NOTE: known issue in https://github.com/QwenLM/Qwen3-VL/issues/1643\n            # When passing a video as a sequence of JPEG frames, we cannot obtain the video metadata,\n            # so we construct a default metadata dictionary with common values.\n            frames = np.stack([np.asarray(load_frame(frame_data)) for frame_data in data.split(',')])\n\n            total_frames_num = int(frames.shape[0])\n            fps = float(self.kwargs.get('fps', 2))  # default to 2 fps if not specified\n            duration = (total_frames_num / fps) if fps > 0 else 0\n            frame_idx = list(range(total_frames_num))\n\n            metadata = {\n                'total_num_frames': total_frames_num,\n                'fps': fps,\n                'duration': duration,\n                'video_backend': 'jpeg_sequence',\n                'frames_indices': frame_idx,\n            }\n\n            logger.info('Video metadata is unavailable for base64-encoded JPEG frames. '\n                        f'Falling back to default metadata values:\\n{metadata}')\n            return frames, metadata\n\n        return self.load_bytes(base64.b64decode(data))\n\n    def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:\n        return self.video_loader.load_file(filepath, num_frames=self.num_frames, **self.kwargs)\n\n    def encode_base64(\n        self,\n        media: npt.NDArray,\n        *,\n        video_format: str = 'JPEG',\n    ) -> str:\n        video = media\n\n        if video_format == 'JPEG':\n            encode_frame = partial(\n                self.image_io.encode_base64,\n                image_format=video_format,\n            )\n\n            return ','.join(encode_frame(Image.fromarray(frame)) for frame in video)\n\n        msg = 'Only JPEG format is supported for now.'\n        raise NotImplementedError(msg)\n"
  },
  {
    "path": "lmdeploy/vl/media/video_loader.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/video.py\n# adapted from https://github.com/QwenLM/Qwen3-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py\n\nimport math\nimport os\nimport tempfile\nfrom abc import abstractmethod\nfrom io import BytesIO\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport numpy.typing as npt\n\nfrom lmdeploy.utils import get_logger\n\nlogger = get_logger('lmdeploy')\n\n\nclass VideoLoader:\n\n    @classmethod\n    @abstractmethod\n    def load_bytes(self, data: bytes, num_frames: int = -1, **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:\n        raise NotImplementedError\n\n    @classmethod\n    def smart_nframes(self, total_frames_num: int, num_frames: int, fps: int, duration: int) -> tuple[int, list[int]]:\n        # resample video to target num_frames and fps\n        # - the minimum of the two will be used\n        num_frames_to_sample = total_frames_num\n        if num_frames > 0:\n            num_frames_to_sample = min(num_frames, total_frames_num)\n        if fps > 0:\n            num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))\n        num_frames_to_sample = max(1, num_frames_to_sample)  # at least one sample\n\n        if num_frames_to_sample == total_frames_num:\n            frame_idx = list(range(0, num_frames_to_sample))\n        else:\n            uniform_sampled_frames = np.linspace(0, total_frames_num - 1, num_frames_to_sample, dtype=int)\n            frame_idx = uniform_sampled_frames.tolist()\n        return num_frames_to_sample, frame_idx\n\n\nclass OpenCVVideoLoader(VideoLoader):\n\n    def get_cv2_video_api(self):\n        import cv2.videoio_registry as vr\n\n        api_pref = None\n        for backend in vr.getStreamBufferedBackends():\n            if not vr.hasBackend(backend):\n                continue\n            if not 
vr.isBackendBuiltIn(backend):\n                _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)\n                if abi < 1 or (abi == 1 and api < 2):\n                    continue\n            api_pref = backend\n            break\n        return api_pref\n\n    @staticmethod\n    def _read_frames(\n        cap,\n        frame_indices: set[int],\n        num_expected_frames: int,\n        max_frame_idx: int,\n    ) -> tuple[npt.NDArray, int, list[int]]:\n        import cv2\n\n        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n        frames = np.empty((num_expected_frames, height, width, 3), dtype=np.uint8)  # THWC\n\n        i = 0\n        valid_frame_indices = []\n        for idx in range(max_frame_idx + 1):\n            ok = cap.grab()\n            if not ok:\n                # Frame is broken/unreadable, log warning\n                if idx in frame_indices:\n                    logger.warning(\n                        'Failed to grab frame %d during video loading. '\n                        'This frame will be skipped.',\n                        idx,\n                    )\n                continue\n            if idx in frame_indices:\n                ret, frame = cap.retrieve()\n                if ret:\n                    frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n                    valid_frame_indices.append(idx)\n                    i += 1\n                else:\n                    # retrieve() failed even though grab() succeeded\n                    logger.warning(\n                        'Failed to retrieve frame %d during video loading. '\n                        'This frame will be skipped.',\n                        idx,\n                    )\n\n        valid_num_frames = len(valid_frame_indices)\n        if valid_num_frames < num_expected_frames:\n            logger.warning(\n                'Video loading completed with %d broken/unreadable frames. 
'\n                'Expected %d frames but only loaded %d frames.',\n                num_expected_frames - valid_num_frames,\n                num_expected_frames,\n                valid_num_frames,\n            )\n\n        return frames[:valid_num_frames], valid_num_frames, valid_frame_indices\n\n    @classmethod\n    def load_file(\n        self,\n        filepath: Path,\n        num_frames: int = -1,\n        fps: int = -1,\n        max_duration: int = 300,\n        **kwargs,\n    ) -> tuple[npt.NDArray, dict[str, Any]]:\n        with open(filepath, 'rb') as f:\n            data = f.read()\n        return self.load_bytes(data, num_frames=num_frames, fps=fps, max_duration=max_duration, **kwargs)\n\n    @classmethod\n    def load_bytes(\n        cls,\n        data: bytes,\n        num_frames: int = -1,\n        fps: int = -1,\n        max_duration: int = 300,\n        **kwargs,\n    ) -> tuple[npt.NDArray, dict[str, Any]]:\n        \"\"\"Load video frames from bytes.\n\n        Args:\n            data: Raw video bytes\n            num_frames: Target number of frames to sample (-1 for all)\n            fps: Target FPS for sampling (-1 for original)\n            max_duration: Maximum duration (unused in base backend)\n\n        Returns:\n            Tuple of (frames_array, metadata_dict)\n        \"\"\"\n        import cv2\n\n        backend = cls().get_cv2_video_api()\n        cap = cv2.VideoCapture(BytesIO(data), backend, [])\n        if not cap.isOpened():\n            raise ValueError('Could not open video stream')\n\n        total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n        original_fps = cap.get(cv2.CAP_PROP_FPS)\n        duration = total_frames_num / original_fps if original_fps > 0 else 0\n\n        num_frames_to_sample, frame_idx = cls.smart_nframes(total_frames_num, num_frames, fps, duration)\n\n        frame_idx_set = set(frame_idx)\n        frames, valid_num_frames, valid_frame_indices = cls._read_frames(cap, frame_idx_set, 
num_frames_to_sample,\n                                                                         max(frame_idx))\n\n        # Use the transformers.video_utils.VideoMetadata format.\n        # NOTE: for models like Qwen3-VL/GLM4.5V, this metadata\n        # can lead to incorrect timestamp calculation unless num_frames=-1.\n        # TODO: zhouxinyu, support per-request do_sample_frames\n        metadata = {\n            'total_num_frames': total_frames_num,\n            'fps': original_fps,\n            'duration': duration,\n            'video_backend': 'opencv',\n            'frames_indices': valid_frame_indices,\n            # extra field used to control hf processor's video\n            # sampling behavior\n            # \"do_sample_frames\": valid_num_frames == total_frames_num,\n        }\n        return frames, metadata\n\n\nclass DecordVideoLoader(VideoLoader):\n\n    @classmethod\n    def load_file(self,\n                  filepath: Path,\n                  num_frames: int = -1,\n                  fps: int = -1,\n                  max_duration: int = 300,\n                  **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:\n        import decord\n        vr = decord.VideoReader(str(filepath))\n        total_frames_num = len(vr)\n        original_fps = vr.get_avg_fps()\n        duration = total_frames_num / original_fps if original_fps > 0 else 0\n\n        num_frames_to_sample, frame_idx = self.smart_nframes(total_frames_num, num_frames, fps, duration)\n\n        video = vr.get_batch(frame_idx).asnumpy()  # THWC\n        metadata = {\n            'total_num_frames': total_frames_num,\n            'fps': original_fps,\n            'duration': duration,\n            'video_backend': 'decord',\n            'frames_indices': frame_idx,\n        }\n        return video, metadata\n\n    @classmethod\n    def load_bytes(self,\n                   data: bytes,\n                   num_frames: int = -1,\n                   fps: int = -1,\n                   
max_duration: int = 300,\n                   **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:\n        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')\n        try:\n            tmp_file.write(data)\n            tmp_file.close()\n            return self.load_file(Path(tmp_file.name),\n                                  num_frames=num_frames,\n                                  fps=fps,\n                                  max_duration=max_duration,\n                                  **kwargs)\n        finally:\n            # always cleanup, even if load_file crashes\n            try:\n                os.unlink(tmp_file.name)\n            except OSError:\n                pass  # file might not exist if write failed\n\n\nclass TorchCodecVideoLoader(VideoLoader):\n\n    @classmethod\n    def load_file(self,\n                  filepath: Path,\n                  num_frames: int = -1,\n                  fps: int = -1,\n                  max_duration: int = 300,\n                  **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:\n        # torchcodec requires matched ffmpeg, torchcodec, and torch versions\n        # ffmpeg 5.1.2, torch 2.8.0, torchcodec 0.7.0 are verified to work together\n        from torchcodec.decoders import VideoDecoder\n\n        torch_codec_num_threads = 8\n        decoder = VideoDecoder(str(filepath), num_ffmpeg_threads=torch_codec_num_threads)\n        total_frames_num = decoder.metadata.num_frames\n        original_fps = decoder.metadata.average_fps\n        duration = total_frames_num / original_fps if original_fps > 0 else 0\n\n        num_frames_to_sample, frame_idx = self.smart_nframes(total_frames_num, num_frames, fps, duration)\n\n        video = decoder.get_frames_at(frame_idx).data\n        metadata = {\n            'total_num_frames': total_frames_num,\n            'fps': original_fps,\n            'duration': duration,\n            'video_backend': 'torchcodec',\n            'frames_indices': frame_idx,\n        }\n     
   return video, metadata\n\n    @classmethod\n    def load_bytes(self,\n                   data: bytes,\n                   num_frames: int = -1,\n                   fps: int = -1,\n                   max_duration: int = 300,\n                   **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:\n        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')\n        try:\n            tmp_file.write(data)\n            tmp_file.close()\n            return self.load_file(Path(tmp_file.name),\n                                  num_frames=num_frames,\n                                  fps=fps,\n                                  max_duration=max_duration,\n                                  **kwargs)\n        finally:\n            # always cleanup, even if load_file crashes\n            try:\n                os.unlink(tmp_file.name)\n            except OSError:\n                pass  # file might not exist if write failed\n\n\nclass TorchVisionVideoLoader(VideoLoader):\n\n    @classmethod\n    def load_file(self,\n                  filepath: Path,\n                  num_frames: int = -1,\n                  fps: int = -1,\n                  max_duration: int = 300,\n                  **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:\n        import torchvision\n\n        video, audio, info = torchvision.io.read_video(\n            filepath,\n            pts_unit='sec',\n            output_format='THWC',\n        )\n        total_frames_num = video.size(0)\n        original_fps = info['video_fps']\n        duration = total_frames_num / original_fps if original_fps > 0 else 0\n\n        num_frames_to_sample, frame_idx = self.smart_nframes(total_frames_num, num_frames, fps, duration)\n\n        video = video[frame_idx]\n        metadata = {\n            'total_num_frames': total_frames_num,\n            'fps': original_fps,\n            'duration': duration,\n            'video_backend': 'torchvision',\n            'frames_indices': frame_idx,\n        }\n        
return video, metadata\n\n    @classmethod\n    def load_bytes(self,\n                   data: bytes,\n                   num_frames: int = -1,\n                   fps: int = -1,\n                   max_duration: int = 300,\n                   **kwargs) -> tuple[npt.NDArray, dict[str, Any]]:\n        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')\n        try:\n            tmp_file.write(data)\n            tmp_file.close()\n            return self.load_file(Path(tmp_file.name),\n                                  num_frames=num_frames,\n                                  fps=fps,\n                                  max_duration=max_duration,\n                                  **kwargs)\n        finally:\n            # always cleanup, even if load_file crashes\n            try:\n                os.unlink(tmp_file.name)\n            except OSError:\n                pass  # file might not exist if write failed\n"
  },
  {
    "path": "lmdeploy/vl/model/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/vl/model/base.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom abc import ABC, abstractmethod\nfrom itertools import groupby\nfrom typing import Dict, List, Union\n\nimport numpy as np\nfrom mmengine import Registry\nfrom transformers import AutoConfig, AutoTokenizer\n\nfrom lmdeploy.archs import get_model_arch\n\nVISION_MODELS = Registry('vision_model')\n\n\nclass VisionModel(ABC):\n    \"\"\"Visual model which extract image feature.\"\"\"\n    _arch: Union[str, List[str]] = None\n\n    def __init__(self,\n                 model_path: str,\n                 with_llm: bool = False,\n                 max_memory: Dict[int, int] = None,\n                 hf_config: AutoConfig = None,\n                 backend: str = ''):\n        \"\"\"init.\"\"\"\n        self.model_path = model_path\n        self.with_llm = with_llm\n        self.max_memory = max_memory\n        self.backend = backend\n        if hf_config is None:\n            _, hf_config = get_model_arch(model_path)\n        self.hf_config = hf_config\n        self.image_token_id = self.get_pad_token_id(model_path, hf_config) or 0\n\n    def get_pad_token_id(self, model_path, hf_config):\n        \"\"\"Get pad_token_id from hf_config or tokenizer.\"\"\"\n        pad_token_id = getattr(hf_config, 'pad_token_id', None)\n        if pad_token_id is None:\n            try:\n                tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n                pad_token_id = getattr(tokenizer, 'pad_token_id', None)\n            except Exception as e:\n                print(e)\n                pass\n        return pad_token_id\n\n    @abstractmethod\n    def build_preprocessor(self, ):\n        \"\"\"Build the preprocessor.\n\n        NOTE: When the derived class implements this method, try not to\n        introduce the upper stream model repo as a thirdparty package\n        \"\"\"\n        raise NotImplementedError()\n\n    def build_model(self, ):\n        \"\"\"Build the vision part of 
a VLM model when backend is turbomind.\n\n        But when `with_llm=True`, load the whole VLM model.\n        \"\"\"\n        if self.backend == 'turbomind' or self.with_llm:\n            raise NotImplementedError()\n\n    @abstractmethod\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Preprocess multimodal data in the messages.\n\n        The derived class,\n        i.e., a specific vision model, is responsible for image preprocessing\n        and for managing the results.\n        It can integrate the result into the messages list, or insert it into\n        the individual image item.\n\n        Args:\n            messages(List[Dict]): a list of messages with multimodal data, as follows:\n            [\n                {'role': 'user', 'content': 'user prompt'},\n                {'role': 'assistant', 'content': 'AI response'},\n                {\n                    'role': 'user',\n                    'content': [\n                        {\n                            'type': 'text',\n                            'text': 'string',\n                        },\n                        {\n                            'type': 'image',\n                            'image': pillow.Image,\n                            'key1': value1,\n                            ...\n                        },\n                        {\n                            'type': 'image',\n                            'image': pillow.Image,\n                            'key1': value1,\n                            ...\n                        },\n                        ...\n                    ]\n                },\n                {....}\n            ]\n        Returns:\n            the message list with preprocessing results included, which is\n            determined by the derived classes\n        \"\"\"  # noqa\n        raise NotImplementedError()\n\n    def has_input_ids(self, messages: List[Dict]) -> bool:\n        \"\"\"Check whether the messages contain input_ids directly.\n\n        Args:\n            messages (List[Dict]): a list of messages, which is supposed to be\n                the output of `preprocess`\n        Returns:\n            bool: whether the messages contain input_ids directly\n        \"\"\"\n        users = [x['content'] for x in messages if x['role'] == 'user']\n        return len(users) == 1 and isinstance(users[0], List) and isinstance(users[0][0].get('text', ''), List)\n\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image features. ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding the vision\n                model\n        Returns:\n            the message list with forwarding results included, which is\n            determined by the derived classes\n        \"\"\"\n        if self.backend == 'turbomind':\n            raise NotImplementedError()\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):\n        \"\"\"Pack the preprocessing results in a format compatible with what is\n        required by pytorch engine. ONLY implement it when the backend is\n        pytorch engine.\n\n        Args:\n            messages(List[Dict]): the output of `preprocess`\n            chat_template: the chat template defined in `lmdeploy/model.py`\n            tokenizer: the tokenizer model\n            sequence_start: starting flag of a sequence\n            chat_template_kwargs: additional arguments for chat template\n                processing, such as `add_vision_id` and `enable_thinking`\n        \"\"\"\n        if self.backend == 'pytorch':\n            raise NotImplementedError()\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):\n        \"\"\"Pack the forwarding results in a format compatible with what is\n        required by turbomind engine. ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the output of `forward`\n            chat_template: the chat template defined in `lmdeploy/model.py`\n            tokenizer: the tokenizer model\n            sequence_start: starting flag of a sequence\n            chat_template_kwargs: additional arguments for chat template\n                processing, such as `add_vision_id` and `enable_thinking`\n        \"\"\"\n        if self.backend == 'turbomind':\n            raise NotImplementedError()\n\n    @staticmethod\n    def collect_multimodal_items(messages):\n        \"\"\"Gather all multimodal items along with their respective parameters\n        from the messages and compile them into a single list.\n\n        Args:\n            messages (List[Dict]): a list of messages\n        Returns:\n            List[Tuple[Modality, Any, Dict]]: a list of (modality, data, params) for each multimodal item\n        \"\"\"\n        multimodal_items = []\n        for message in messages:\n            content = message['content']\n            if not isinstance(content, list):\n                continue\n\n            for x in content:\n                if not isinstance(x, dict):\n                    continue\n\n                modality = x.get('type')\n                if modality is None or modality == 'text':\n                    continue\n\n                data = x.get('data')\n                params = {k: v for k, v in x.items() if k not in ['type', 'data']}\n                multimodal_items.append((modality, data, params))\n\n        return multimodal_items\n\n    @staticmethod\n    def IMAGE_TOKEN_included(messages):\n        \"\"\"Check whether the IMAGE_TOKEN is included in the messages.\n\n        Args:\n            messages (List[Dict]): a list of messages\n        Returns:\n            bool: whether the IMAGE_TOKEN is included in the messages\n        \"\"\"\n        for message in messages:\n            role, content = message['role'], message['content']\n            if role != 'user':\n                continue\n            if isinstance(content, str) and '<IMAGE_TOKEN>' in content:\n                return True\n            elif isinstance(content, List):\n                content = [x['text'] for x in content if x['type'] == 'text']\n                if any('<IMAGE_TOKEN>' in x for x in content):\n                    return True\n        return False\n\n    def to_pytorch_with_input_ids(self, messages):\n        \"\"\"Pack the preprocessing results in a format compatible with what is\n        required by pytorch engine when input_ids are provided directly.\n\n        Args:\n            messages(List[Dict]): the output of `preprocess`\n        \"\"\"\n        # collect all preprocessing results from messages\n        preps = [x['content'] for x in messages if x['role'] == 'preprocess']\n        assert len(preps) == 1\n        preps = preps[0]\n\n        _input_ids = messages[0]['content'][0]['text']\n        segs = []\n        for k, g in groupby(_input_ids, lambda x: x == self.image_token_id):\n            if not k:\n                segs.append(list(g))\n            else:\n                segs.extend([[]] * (len(list(g)) - 1))\n        if _input_ids[0] == self.image_token_id:\n            segs = [[]] + segs\n        if _input_ids[-1] == self.image_token_id:\n            segs = segs + [[]]\n\n        assert self.image_token_id == preps[0]['image_token_id']\n        assert len(segs) == len(preps) + 1, (f'the number of image token id {self.image_token_id} is not equal '\n                                             f'to input images, {len(segs) - 1} vs {len(preps)}')\n        input_ids = []\n        for i, seg in enumerate(segs):\n            if i > 0 and i <= len(preps):\n                preps[i - 1].update(offset=len(input_ids))\n                image_tokens = preps[i - 1]['image_tokens']\n                input_ids.extend([self.image_token_id] * image_tokens)\n            input_ids.extend(seg)\n\n        return dict(prompt=None, input_ids=input_ids, multimodal=preps)\n\n    def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):\n        \"\"\"Auxiliary function to pack the preprocessing results in a format\n        compatible with what is required by pytorch engine.\n\n        Args:\n            messages(List[Dict]): the output of `preprocess`\n            prompt(str): the prompt after applying chat template\n            IMAGE_TOKEN(str): a placeholder where image tokens will be\n                inserted\n            tokenizer: the tokenizer model\n            sequence_start: starting flag of a sequence\n        \"\"\"\n        # collect all preprocessing results from messages\n        preps = [x['content'] for x in messages if x['role'] == 'preprocess']\n        assert len(preps) == 1\n        preps = preps[0]\n\n        # split prompt into segments and validate data\n        segs = prompt.split(IMAGE_TOKEN)\n        assert len(segs) == len(preps) + 1, (f'the number of {IMAGE_TOKEN} is not equal '\n                                             f'to input images, {len(segs) - 1} vs {len(preps)}')\n\n        # calculate the image token offset for each image\n        input_ids = []\n        for i, seg in enumerate(segs):\n            if i > 0 and i <= len(preps):\n                preps[i - 1].update(offset=len(input_ids))\n                image_tokens = preps[i - 1]['image_tokens']\n                assert self.image_token_id == preps[i - 1]['image_token_id']\n                input_ids.extend([self.image_token_id] * image_tokens)\n            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))\n            input_ids.extend(token_ids)\n\n        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)\n\n    def to_turbomind_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):\n        \"\"\"Auxiliary function to pack the forwarding results in a format\n        compatible with what is required by turbomind engine.\n\n        Args:\n            messages(List[Dict]): the output of `forward`\n            prompt(str): the prompt after applying chat template\n            IMAGE_TOKEN(str): a placeholder where image tokens will be\n                inserted\n            tokenizer: the tokenizer model\n            sequence_start: starting flag of a sequence\n        \"\"\"\n        # collect image features from messages\n        features = [x['content'] for x in messages if x['role'] == 'forward']\n        features = features[0]\n        features = [x.cpu() for x in features]\n        # split prompt into segments and validate data\n        segs = prompt.split(IMAGE_TOKEN)\n        assert len(segs) == len(features) + 1, (f'the number of {IMAGE_TOKEN} is not equal '\n                                                f'to input images, {len(segs) - 1} vs {len(features)}')\n\n        # tokenize the prompt and compute input_embeddings and input_embedding_ranges\n        input_ids = []\n        begins = []\n        ends = []\n        for i, seg in enumerate(segs):\n            if i > 0 and i <= len(features):\n                image_dim = features[i - 1].shape[0]\n                begins.append(len(input_ids))\n                ends.append(begins[-1] + image_dim)\n                input_ids.extend([self.image_token_id] * image_dim)\n            seg_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))\n            input_ids.extend(seg_ids)\n        ranges = np.stack([begins, ends], axis=1).tolist()\n        return dict(prompt=prompt, input_ids=input_ids, input_embeddings=features, input_embedding_ranges=ranges)\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config matches the model.\"\"\"\n        arch = config.architectures[0] if config.architectures else None\n        if arch and (arch == cls._arch or arch in cls._arch):\n            return True\n        return False\n"
  },
  {
    "path": "lmdeploy/vl/model/builder.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nfrom typing import Optional, Union\n\nimport torch\n\nfrom lmdeploy.archs import get_model_arch\nfrom lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig\nfrom lmdeploy.utils import get_logger, get_model\nfrom lmdeploy.vl.model.base import VISION_MODELS\n\nfrom .cogvlm import CogVLMVisionModel  # noqa F401\nfrom .deepseek import DeepSeekVisionModel  # noqa F401\nfrom .deepseek_vl2 import DeepSeek2VisionModel  # noqa F401\nfrom .gemma3_vl import Gemma3VisionModel  # noqa F401\nfrom .glm4_1v import GLM4_1_VisionModel  # noqa F401\nfrom .glm4_v import GLM4VisionModel  # noqa F401\nfrom .interns1_pro import InternS1ProVisionModel  # noqa F401\nfrom .internvl import InternVLVisionModel  # noqa F401\nfrom .internvl3_hf import InternVL3VisionModel  # noqa F401\nfrom .internvl_llava import InternVLLlavaVisionModel  # noqa F401\nfrom .llama4 import LLama4VisionModel  # noqa F401\nfrom .llava import LlavaVisionModel  # noqa F401\nfrom .llava_hf import LlavaHfVisionModel  # noqa F401\nfrom .llava_next import LlavaNextVisionModel  # noqa F401\nfrom .minicpmv import MiniCPMVModel  # noqa F401\nfrom .mllama import MllamaVLModel  # noqa F401\nfrom .molmo import MolmoVisionModel  # noqa F401\nfrom .phi3_vision import Phi3VisionModel  # noqa F401\nfrom .qwen import QwenVisionModel  # noqa F401\nfrom .qwen2 import Qwen2VLModel  # noqa F401\nfrom .qwen3 import Qwen3VLModel  # noqa F401\nfrom .qwen3_5 import Qwen3_5Model  # noqa F401\nfrom .xcomposer2 import Xcomposer2VisionModel  # noqa F401\nfrom .yi import YiVisionModel  # noqa F401\n\nlogger = get_logger('lmdeploy')\n\n\ndef load_vl_model(model_path: str,\n                  backend: str,\n                  with_llm: bool = False,\n                  backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None):\n    \"\"\"Load visual model.\n\n    Args:\n        model_path(str): the path or repo_id from model hub of the 
model\n        backend(str): the name of inference backend\n        with_llm(bool): load LLM model or not. Set it to False for VLM\n            inference scenarios and True for VLM quantization\n        backend_config: the config of the inference engine\n    \"\"\"\n    if not os.path.exists(model_path):\n        revision = getattr(backend_config, 'revision', None)\n        download_dir = getattr(backend_config, 'download_dir', None)\n        model_path = get_model(model_path, revision=revision, download_dir=download_dir)\n\n    max_memory = None\n    if not with_llm:\n        tp = getattr(backend_config, 'tp', 1)\n        max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)} if backend == 'turbomind' else None\n\n    _, hf_config = get_model_arch(model_path)\n    kwargs = dict(model_path=model_path, with_llm=with_llm, max_memory=max_memory, hf_config=hf_config, backend=backend)\n\n    for name, module in VISION_MODELS.module_dict.items():\n        try:\n            if module.match(hf_config):\n                logger.info(f'matching vision model: {name}')\n                model = module(**kwargs)\n                model.build_preprocessor()\n                # build the vision part of a VLM model when backend is\n                # turbomind, or load the whole VLM model when `with_llm==True`\n                if backend == 'turbomind' or with_llm:\n                    model.build_model()\n                return model\n        except Exception as e:\n            logger.error(f'build vision model {name} failed, {e}')\n            raise\n\n    raise ValueError(f'unsupported vl model with config {hf_config}')\n"
  },
  {
    "path": "lmdeploy/vl/model/cogvlm.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\nlogger = get_logger('lmdeploy')\n\n\n@VISION_MODELS.register_module()\nclass CogVLMVisionModel(VisionModel):\n    \"\"\"CogVLM vision model.\"\"\"\n\n    _arch = 'CogVLMForCausalLM'\n\n    def build_preprocessor(self):\n        from torchvision import transforms\n        self.image_transform = transforms.Compose([\n            transforms.Resize((self.hf_config.vision_config['image_size'], ) * 2,\n                              interpolation=transforms.InterpolationMode.BICUBIC),\n            transforms.ToTensor(),\n            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n        ])\n        image_size = self.hf_config.vision_config['image_size']\n        patch_size = self.hf_config.vision_config['patch_size']\n        if self.hf_config.vision_config['num_positions'] == 1226:\n            # cogvlm-chat-hf, https://huggingface.co/THUDM/cogvlm-chat-hf/blob/e29dc3ba206d524bf8efbfc60d80fc4556ab0e3c/modeling_cogvlm.py#L820 # noqa E501\n            self.n_token_per_image = 2 + (image_size // patch_size)**2\n        else:\n            # cogvlm2, https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B/blob/2c2226281325649d49b8aa237a932367c7da4f26/modeling_cogvlm.py#L819 # noqa E501\n            self.n_token_per_image = 2 + (image_size // patch_size // 2)**2\n\n    def build_model(self):\n        if self.with_llm:\n            from transformers import AutoModelForCausalLM\n            self.vl_model = AutoModelForCausalLM.from_pretrained(self.model_path,\n                                                                 device_map='cpu',\n                                                                 trust_remote_code=True)\n        else:\n            raise NotImplementedError('turbomind has not supported cogvlm 
yet')\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to the spec of `super().preprocess`\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, _ in images:\n            image = image.convert('RGB')\n            pixel_values = self.image_transform(image)\n            outputs.append(\n                dict(pixel_values=pixel_values,\n                     image_size=image.size,\n                     image_tokens=self.n_token_per_image,\n                     image_token_id=self.image_token_id))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['images', 'preprocess', 'forward']:\n                continue\n            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']\n            n_images = len([1 for x in message['content'] if x['type'] == 'image'])\n\n            prompt_messages.append(dict(role='user', content=content[0], num_images=n_images))\n\n        from lmdeploy.model import Vicuna\n        llm_chat_template = Vicuna(eoa='</s>', stop_words=chat_template.stop_words)\n        prompt = ''\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for i, msg in enumerate(prompt_messages):\n            num_images = msg.pop('num_images', 0)\n            if num_images == 0:\n                role = msg['role']\n                msg = llm_chat_template.messages2prompt([msg], sequence_start and i == 0)\n                msg = dict(role=role, content=msg)\n            prompt_i = chat_template.messages2prompt([msg], sequence_start and i == 0)\n    
        if num_images > 0:\n                prompt_i = (IMAGE_TOKEN * num_images) + prompt_i\n            prompt += prompt_i\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/deepseek.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport warnings\nfrom typing import Dict, List\n\nimport torch\nfrom transformers import AutoModelForCausalLM\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_deepseek_vl_install():\n    \"\"\"Check deepseek_vl install.\"\"\"\n    try:\n        import deepseek_vl  # noqa: F401\n    except ImportError:\n        raise ImportError('To use DeepSeekVLModel, please install deepseek_vl by '\n                          '`pip install git+https://github.com/deepseek-ai/DeepSeek-VL.git'\n                          ' --no-deps`')\n\n\n@VISION_MODELS.register_module()\nclass DeepSeekVisionModel(VisionModel):\n    \"\"\"DeepSeek-VL vision model.\"\"\"\n\n    _arch = 'MultiModalityCausalLM'\n\n    def build_preprocessor(self):\n        check_deepseek_vl_install()\n        from deepseek_vl.models import VLChatProcessor\n        vl_chat_processor = VLChatProcessor.from_pretrained(self.model_path)\n        tokenizer = vl_chat_processor.tokenizer\n        self.image_token_id = tokenizer.vocab.get(vl_chat_processor.image_tag)\n        self.image_processor = vl_chat_processor.image_processor\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import init_empty_weights\n        with init_empty_weights():\n            warnings.simplefilter('ignore')\n            model = AutoModelForCausalLM.from_pretrained(self.model_path)\n            self.vl_model = model\n            if not self.with_llm:\n                del model.language_model\n\n        from accelerate.utils import get_balanced_memory, infer_auto_device_map\n        max_memory = get_balanced_memory(model,\n                                         
max_memory=self.max_memory,\n                                         dtype=torch.half,\n                                         no_split_module_classes=['Block'])\n        device_map = infer_auto_device_map(model,\n                                           no_split_module_classes=['Block'],\n                                           max_memory=max_memory,\n                                           dtype=torch.half)\n        same_device_keys = [('vision_model.vision_tower_high.vision_tower.pos_embed',\n                             'vision_model.vision_tower_high.vision_tower.patch_embed'),\n                            ('vision_model.vision_tower_low.vision_tower.pos_embed',\n                             'vision_model.vision_tower_low.vision_tower.patch_embed')]\n        for (a, b) in same_device_keys:\n            if a in device_map and b in device_map:\n                device_map[b] = device_map[a]\n        downsamples = []\n        ka = 'vision_model.vision_tower_high.vision_tower.downsamples'\n        kb = 'vision_model.vision_tower_high.vision_tower.hd_alpha_downsamples'  # noqa: E501\n        for k in device_map:\n            if k.startswith(ka):\n                downsamples.append(k)\n        if len(downsamples) == 1:\n            device_map[ka] = device_map[kb]\n        elif len(downsamples) > 1:\n            numbers = [int(x[len(ka) + 1:]) for x in downsamples]\n            device_map[f'{ka}.{numbers[-1]}'] = device_map[kb]\n\n        from accelerate import load_checkpoint_and_dispatch\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                         checkpoint=self.model_path,\n                                         device_map=device_map if not self.with_llm else {'': 'cpu'},\n                                         dtype=torch.half)\n\n        self.model = model.eval()\n        self.vision_model = model.vision_model.eval()\n        self.aligner = model.aligner.eval()\n\n    def 
preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to the spec of `super().preprocess()`\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, _ in images:\n            image = image.convert('RGB')\n            pixel_values = self.image_processor([image], return_tensors='pt').pixel_values\n            outputs.append(\n                dict(\n                    pixel_values=pixel_values,\n                    image_size=image.size,\n                    # refer to https://github.com/deepseek-ai/DeepSeek-VL/blob/main/deepseek_vl/models/processing_vlm.py  # noqa\n                    # where the number of image tokens is hardcoded to 576\n                    image_tokens=576,\n                    image_token_id=self.image_token_id))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            pixel_values = torch.cat(pixel_values, dim=0)\n            pixel_values = pixel_values.to(device=next(self.vision_model.parameters()).device, dtype=torch.float16)\n            # [b x n_images, T2, D]\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            feats = self.aligner(self.vision_model(pixel_values))\n            feats = torch.split(feats, 1, dim=0)\n            outputs.extend([x.squeeze() for x in feats])\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        # apply chat template to get the prompt\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['images', 'preprocess', 'forward']:\n                continue\n            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']\n            content = content[0]\n            n_image = sum([1 for x in message['content'] if x['type'] == 'image'])\n            n_placeholder = content.count(IMAGE_TOKEN)\n            if n_placeholder == 0:\n                logger.warning(f\"\"\"for deepseek-vl model, 
the user should insert the {IMAGE_TOKEN}\n                    to user prompt manually, please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html\n                    for more details.\"\"\")  # noqa\n            if n_placeholder != 0 and n_placeholder != n_image:\n                logger.error(f'unmatched placeholder and image: {n_placeholder} vs '\n                             f'{n_image}. Ignore the placeholder')\n                content = content.replace(IMAGE_TOKEN, '')\n                n_placeholder = 0\n            if n_placeholder == 0:\n                if n_image == 1:\n                    content = f'{IMAGE_TOKEN}{content}'\n                else:\n                    content = ''.join([f'{IMAGE_TOKEN} is Figure {str(i)}.\\n' for i in range(n_image)]) + content\n            prompt_messages.append(dict(role='user', content=content))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/deepseek_vl2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nfrom contextlib import redirect_stdout\nfrom typing import Dict, List\n\nimport torch\nfrom transformers import AutoConfig\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_deepseek_vl2_install():\n    \"\"\"Check deepseek_vl2 install.\"\"\"\n    try:\n        import deepseek_vl2  # noqa: F401\n    except ImportError:\n        raise ImportError('To use DeepSeek-VL2, please install deepseek_vl2 by '\n                          '`pip install git+https://github.com/deepseek-ai/DeepSeek-VL2.git'\n                          ' --no-deps`')\n\n\ndef check_trans_version():\n    \"\"\"Check if the installed version of the 'transformers' library is smaller\n    than the specified version.\"\"\"\n    import transformers\n    from packaging import version\n\n    max_version = '4.48.0'\n    installed_version = transformers.__version__\n    assert version.parse(installed_version) < version.parse(\n        max_version\n    ), f'deepseek_vl2 requires transformers version < 4.48.0, but found version: {installed_version}. 
Please downgrade.'\n\n\n@VISION_MODELS.register_module()\nclass DeepSeek2VisionModel(VisionModel):\n    \"\"\"DeepSeek2 vision model.\"\"\"\n\n    _arch = 'DeepseekV2ForCausalLM'\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config matches the model.\"\"\"\n        if hasattr(config, 'language_config') and hasattr(config, 'vision_config'):\n            arch = config.language_config.get('architectures', [None])[0]\n            return arch == cls._arch\n        return False\n\n    def build_preprocessor(self):\n        check_trans_version()\n        check_deepseek_vl2_install()\n        from deepseek_vl2.models.processing_deepseek_vl_v2 import DeepseekVLV2Processor\n\n        # suppress deepseek-vl2 processor initialization print logs\n        with open(os.devnull, 'w') as devnull:\n            with redirect_stdout(devnull):\n                self.image_processor = DeepseekVLV2Processor.from_pretrained(self.model_path,\n                                                                             image_token='<IMAGE_TOKEN>')\n                self.image_token_id = self.image_processor.image_token_id\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        # TODO: implement for turbomind engine\n        raise NotImplementedError()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to the spec of `super().preprocess()`\"\"\"\n        images = self.collect_multimodal_items(messages)\n\n        # convert to upstream api formats\n        images = [item[1] for item in images]\n        formatted_messages = []\n        for message in messages:\n            text_content = DeepSeek2VisionModel.proc_single_message(message)\n            image_content = [x['image'] for x in message['content'] if x['type'] == 'image']\n            
formatted_messages.append(dict(role=message['role'], content=text_content, images=image_content))\n\n        # NOTE: DeepseekVLV2Processor inputs\n        # conversations (List[Dict]): conversations with a list of messages;\n        # images (List[ImageType]): the list of images;\n        # force_batchify (bool): force batchify the inputs;\n        # inference_mode (bool): if True, then remove the last eos token;\n        prepare = self.image_processor(conversations=formatted_messages,\n                                       images=images,\n                                       force_batchify=False,\n                                       inference_mode=False)\n\n        messages.append(\n            dict(role='preprocess',\n                 content=[\n                     dict(\n                         pixel_values=prepare.images,\n                         image_tokens=prepare.num_image_tokens[0],\n                         image_token_id=self.image_processor.image_token_id,\n                         image_size=self.image_processor.image_size,\n                         images_spatial_crop=prepare.images_spatial_crop,\n                     )\n                 ]))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        # TODO, implement for turbomind engine\n        raise NotImplementedError()\n\n    @staticmethod\n    def proc_single_message(message):\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n\n        if isinstance(message['content'], str):\n            return message\n        elif message['role'] in ['images', 'preprocess', 'forward']:\n            return None\n\n        content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']\n        content = content[0]\n        n_image = sum([1 for x in message['content'] if x['type'] == 'image'])\n        n_placeholder = content.count(IMAGE_TOKEN)\n        if n_placeholder == 0:\n            logger.warning(f\"\"\"for deepseek-vl2 model, the user should insert the {IMAGE_TOKEN}\n                to user prompt manually, please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html\n                for more details.\"\"\")  # noqa\n        if n_placeholder != 0 and n_placeholder != n_image:\n            logger.error(f'unmatched placeholder and image: {n_placeholder} vs '\n                         f'{n_image}. 
Ignore the placeholder')\n            content = content.replace(IMAGE_TOKEN, '')\n            n_placeholder = 0\n        if n_placeholder == 0:\n            if n_image == 1:\n                content = f'{IMAGE_TOKEN}{content}'\n            else:\n                content = ''.join([f'{IMAGE_TOKEN} is Figure {str(i)}.\\n' for i in range(n_image)]) + content\n        return content\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for message in messages:\n            content = DeepSeek2VisionModel.proc_single_message(message)\n            if content is None:\n                continue\n            prompt_messages.append(dict(role='user', content=content))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/gemma3_vl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List, Optional\n\nimport torch\nfrom transformers import AutoConfig, AutoProcessor\nfrom transformers.processing_utils import ImagesKwargs, ProcessingKwargs\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\nlogger = get_logger('lmdeploy')\n\n\nclass Gemma3ImagesKwargs(ImagesKwargs):\n    do_pan_and_scan: Optional[bool]\n    pan_and_scan_min_crop_size: Optional[int]\n    pan_and_scan_max_num_crops: Optional[int]\n    pan_and_scan_min_ratio_to_activate: Optional[float]\n    do_convert_rgb: Optional[bool]\n\n\nclass Gemma3ProcessorKwargs(ProcessingKwargs, total=False):\n    images_kwargs: Gemma3ImagesKwargs\n    _defaults = {\n        'text_kwargs': {\n            'padding': False,\n        },\n        'images_kwargs': {\n            'do_pan_and_scan': False,\n            'pan_and_scan_min_crop_size': 256,\n            'pan_and_scan_max_num_crops': 4,\n            'pan_and_scan_min_ratio_to_activate': 1.2,\n        },\n    }\n\n\n@VISION_MODELS.register_module()\nclass Gemma3VisionModel(VisionModel):\n    \"\"\"Gemma3 vision model.\"\"\"\n\n    _arch = 'Gemma3ForConditionalGeneration'\n\n    def __init__(self,\n                 model_path: str,\n                 with_llm: bool = False,\n                 max_memory: Dict[int, int] = None,\n                 hf_config: AutoConfig = None,\n                 backend: str = ''):\n        super().__init__(model_path, with_llm, max_memory, hf_config, backend)\n\n    def build_preprocessor(self):\n        self.processor = AutoProcessor.from_pretrained(self.model_path)\n        tokenizer = self.processor.tokenizer\n        self.image_token_id = tokenizer.encode(tokenizer.image_token)[-1]\n        self.image_tokens = self.processor.image_seq_length\n        self.tokenizer_init_kwargs = tokenizer.init_kwargs\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM 
model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        # TODO: implement for turbomind engine\n        raise NotImplementedError()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        from transformers.image_utils import make_nested_list_of_images\n        output_kwargs = self.processor._merge_kwargs(\n            Gemma3ProcessorKwargs,\n            tokenizer_init_kwargs=self.tokenizer_init_kwargs,\n            **{\n                'return_tensors': 'pt',\n                'add_special_tokens': False\n            },\n        )\n        images = self.collect_multimodal_items(messages)\n        images = [image.convert('RGB') for modality, image, _ in images]\n        num_image = len(images)\n        images = make_nested_list_of_images(images)\n        image_inputs = self.processor.image_processor(images, **output_kwargs['images_kwargs'])\n        outputs = []\n        for idx in range(num_image):\n            pixel_values = image_inputs['pixel_values'][idx:idx + 1, ...]\n            num_crops = image_inputs['num_crops'][idx:idx + 1]\n            data = dict(pixel_values=pixel_values,\n                        num_crops=num_crops,\n                        image_tokens=self.image_tokens,\n                        image_token_id=self.image_token_id)\n            outputs.append(data)\n\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        # TODO, implement for turbomind engine\n        raise NotImplementedError()\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['images', 'preprocess', 'forward']:\n                continue\n            n_images = len([1 for x in message['content'] if x['type'] == 'image'])\n            content = [item['text'] for item in message['content'] if item['type'] == 'text']\n            prompt = ('\\n\\n' + IMAGE_TOKEN + '\\n\\n') * n_images + content[0]\n            prompt_messages.append(dict(role='user', content=prompt))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/glm4_1v.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List\n\nfrom transformers import AutoConfig\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\nlogger = get_logger('lmdeploy')\n\n\n@VISION_MODELS.register_module()\nclass GLM4_1_VisionModel(VisionModel):\n    \"\"\"GLM-4.1V-9B-Thinking model.\"\"\"\n\n    _arch = ['Glm4vForConditionalGeneration']\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config match the model.\"\"\"\n        arch = config.architectures[0] if config.architectures else None\n        if arch in cls._arch and hasattr(config, 'vision_config'):\n            return True\n        return False\n\n    def build_preprocessor(self):\n        from transformers import AutoProcessor\n        self.processor = AutoProcessor.from_pretrained(self.model_path)\n        tokenizer = self.processor.tokenizer\n        image_token = self.processor.image_token\n        self.image_token_id = tokenizer.encode(image_token)[-1]\n\n    def build_model(self):\n        raise NotImplementedError('turbomind has not supported glm4v yet')\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        optional_keys = {'resized_height', 'resized_width', 'min_pixels', 'max_pixels'}\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n\n            item = dict(type='image', image=image)\n            item.update({key: params[key] for key in params.keys() if key in optional_keys})\n            result = self.processor.image_processor(images=image, videos=None, return_tensors='pt')\n            merge_length = self.processor.image_processor.merge_size**2\n            image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length\n            
result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))\n            outputs.append(result)\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['images', 'preprocess', 'forward']:\n                continue\n            n_images = len([1 for x in message['content'] if x['type'] == 'image'])\n            content = [item['text'] for item in message['content'] if item['type'] == 'text']\n            prompt = content[0]\n            if IMAGE_TOKEN in prompt and '<|begin_of_image|>' not in prompt:\n                prompt = prompt.replace(IMAGE_TOKEN, f'<|begin_of_image|>{IMAGE_TOKEN}<|end_of_image|>')\n            else:\n                prompt = f'<|begin_of_image|>{IMAGE_TOKEN}<|end_of_image|>' * \\\n                    n_images + prompt\n            prompt_messages.append(dict(role=message['role'], content=prompt))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/glm4_v.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List\n\nfrom transformers import AutoConfig\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\nlogger = get_logger('lmdeploy')\n\n\n@VISION_MODELS.register_module()\nclass GLM4VisionModel(VisionModel):\n    \"\"\"Glm-4v-9b vision model.\"\"\"\n\n    _arch = ['ChatGLMModel', 'ChatGLMForConditionalGeneration']\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config match the model.\"\"\"\n        arch = config.architectures[0] if config.architectures else None\n        if arch in cls._arch and hasattr(config, 'vision_config'):\n            return True\n        return False\n\n    def build_preprocessor(self):\n        from torchvision import transforms\n        self.image_transform = transforms.Compose([\n            transforms.Resize((self.hf_config.vision_config['image_size'], ) * 2,\n                              interpolation=transforms.InterpolationMode.BICUBIC),\n            transforms.ToTensor(),\n            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n        ])\n        image_size = self.hf_config.vision_config['image_size']\n        patch_size = self.hf_config.vision_config['patch_size']\n        self.n_token_per_image = 2 + (image_size // patch_size // 2)**2\n\n    def build_model(self):\n        if self.with_llm:\n            from transformers import AutoModelForCausalLM\n            self.vl_model = AutoModelForCausalLM.from_pretrained(self.model_path,\n                                                                 device_map='cpu',\n                                                                 trust_remote_code=True)\n        else:\n            raise NotImplementedError('turbomind has not supported glm4v yet')\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refers to the spec 
of `super().preprocess()`.\"\"\"\n        outputs = []\n        for message in messages:\n            if not isinstance(message['content'], List):\n                continue\n            images = [x['image'] for x in message['content'] if x['type'] == 'image']\n            if len(images) > 1:\n                logger.warning(f'glm4v does not support the input of multiple images'\n                               f' in a single chat round, but got {len(images)} images.')\n            # we still pass all the images to the model and let the\n            # model decide what to do\n            images = [x.convert('RGB') for x in images]\n            pixel_values = [self.image_transform(x) for x in images]\n            outputs.extend([\n                dict(pixel_values=_2,\n                     image_size=_1.size,\n                     image_tokens=self.n_token_per_image,\n                     image_token_id=self.image_token_id) for _1, _2 in zip(images, pixel_values)\n            ])\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for message in messages:\n            content = message['content']\n            if isinstance(content, str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['preprocess', 'forward']:\n                continue\n            prompt = [x['text'] for x in content if x['type'] == 'text']\n            n_images = len([1 for x in content if x['type'] == 'image'])\n            prompt = ''.join([f'{IMAGE_TOKEN}\\n'] * n_images) + prompt[0]\n            prompt_messages.append(dict(role='user', content=prompt))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/interns1_pro.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\nimport torch\nfrom transformers import AutoProcessor\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.constants import Modality\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_transformers():\n    try:\n        from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration  # noqa: F401\n    except ImportError:\n        raise ImportError('please install latest transformers by '\n                          'pip install git+https://github.com/huggingface/transformers.git')\n\n\n@VISION_MODELS.register_module()\nclass InternS1ProVisionModel(VisionModel):\n    \"\"\"InternS1Pro model.\n\n    Basically the same preprocessing as Qwen3VL, but with Time Series support.\n    \"\"\"\n\n    _arch = ['InternS1ProForConditionalGeneration', 'InternS1_1_ForConditionalGeneration']\n\n    def build_preprocessor(self):\n        check_transformers()\n        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)\n\n        # image tokens\n        self.image_token = self.processor.image_token\n        self.image_token_id = self.processor.image_token_id\n\n        # video tokens\n        self.video_token = self.processor.video_token\n        self.video_token_id = self.processor.video_token_id\n\n        # time series tokens\n        self.ts_token = getattr(self.processor, 'ts_token', None)\n        self.ts_token_id = getattr(self.processor, 'ts_token_id', None)\n\n        # vision start and end tokens\n        self.vision_start_token = self.processor.vision_start_token\n        self.vision_end_token = self.processor.vision_end_token\n\n    def get_processor_args(self, mm_processor_kwargs: Optional[Dict[str, Any]] = None):\n        min_pixels = self.processor.image_processor.size['shortest_edge']\n        max_pixels = 
self.processor.image_processor.size['longest_edge']\n\n        if mm_processor_kwargs is None:\n            return min_pixels, max_pixels\n\n        input_min_pixels = mm_processor_kwargs.get('min_pixels', None)\n        input_max_pixels = mm_processor_kwargs.get('max_pixels', None)\n\n        # boundary check for min_pixels and max_pixels\n        if input_min_pixels is None:\n            if input_max_pixels is not None:\n                # only max_pixels is given in the input\n                if input_max_pixels < min_pixels:\n                    logger.warning(\n                        f'input max_pixels {input_max_pixels} < default min_pixels {min_pixels}, fall back to default.')\n                    return min_pixels, max_pixels\n                max_pixels = input_max_pixels\n        else:\n            if input_max_pixels is None:\n                # only min_pixels is given in the input\n                if input_min_pixels > max_pixels:\n                    logger.warning(\n                        f'input min_pixels {input_min_pixels} > default max_pixels {max_pixels}, fall back to default.')\n                    return min_pixels, max_pixels\n            else:\n                if input_min_pixels > input_max_pixels:\n                    logger.warning(\n                        f'input min_pixels {input_min_pixels} > max_pixels {input_max_pixels}, fall back to default.')\n                    return min_pixels, max_pixels\n                max_pixels = input_max_pixels\n            min_pixels = input_min_pixels\n\n        return min_pixels, max_pixels\n\n    def check_time_series_input(self, messages):\n        has_time_series_input = any(\n            isinstance(message['content'], list) and any(item['type'] == 'time_series' for item in message['content'])\n            for message in messages)\n        self.has_time_series_input = has_time_series_input\n\n    def _preprocess_image(self,\n                          data: List[Any],\n                          
params: Dict[str, Any],\n                          mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]:\n\n        image = data.convert('RGB')\n        min_pixels, max_pixels = self.get_processor_args(mm_processor_kwargs)\n\n        result = self.processor.image_processor(images=image,\n                                                size={\n                                                    'shortest_edge': min_pixels,\n                                                    'longest_edge': max_pixels\n                                                },\n                                                return_tensors='pt')\n        merge_length = self.processor.image_processor.merge_size**2\n        image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length\n        result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))\n        return result\n\n    def _preprocess_video(self,\n                          data: List[Any],\n                          params: Dict[str, Any],\n                          mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]:\n\n        # TODO: zhouxinyu, apply transformers smart_resize using per-request kwargs\n        metadata = params['video_metadata']\n        video_kwargs = dict(return_metadata=True,\n                            do_resize=True,\n                            do_sample_frames=False,\n                            video_metadata=metadata,\n                            return_tensors='pt')\n        result = self.processor.video_processor(videos=data, **video_kwargs)\n        video_grid_thw = result['video_grid_thw']\n\n        merge_length = self.processor.video_processor.merge_size**2\n        if metadata.get('fps') is None:\n            logger.warning_once('Qwen3VL: fps not found, defaulting to 24.')\n            metadata['fps'] = metadata['fps'] or 24\n\n        # if timestamps are not provided, calculate them\n        curr_timestamp = 
self.processor._calculate_timestamps(\n            metadata['frames_indices'],\n            metadata['fps'],\n            self.processor.video_processor.merge_size,\n        )\n\n        frame_seqlen = video_grid_thw[0][1:].prod() // merge_length\n        result.update(curr_timestamp=curr_timestamp, frame_seqlen=frame_seqlen, video_token_id=self.video_token_id)\n        return result\n\n    def _preprocess_time_series(self,\n                                data: List[Any],\n                                params: Dict[str, Any],\n                                mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]:\n\n        ts_input = data\n        sr = params.get('sampling_rate') if params is not None else None\n\n        if not isinstance(ts_input, np.ndarray):\n            ts_input = np.array(ts_input, dtype=np.float32)\n\n        mean = ts_input.mean(axis=0, keepdims=True)\n        std = ts_input.std(axis=0, keepdims=True)\n        ts_input = (ts_input - mean) / (std + 1e-8)\n\n        # truncate to 240k to avoid OOM\n        max_ts_len = 240000\n        if len(ts_input) > max_ts_len:\n            ts_input = ts_input[:max_ts_len]\n\n        if ts_input.ndim == 1:\n            ts_input = ts_input[:, None]  # [T,C]\n\n        ts_len = ts_input.shape[0]\n\n        # set the default value to ts_len / 4 if sr is not provided or invalid\n        if sr is None or sr <= 0:\n            sr = max(ts_len / 4, 1.0)\n\n        # compute num ts tokens\n        stride = np.floor(160 / ((1 + np.exp(-sr / 100))**6))\n        patch_size = stride * 2\n        embed_length = (np.ceil((ts_len - patch_size) / stride) + 1)\n        ts_tokens = int((embed_length // 2 + 1) // 2)\n\n        return dict(ts_values=[ts_input],\n                    ts_sr=[sr],\n                    ts_lens=[ts_len],\n                    ts_tokens=[ts_tokens],\n                    ts_token_id=self.ts_token_id)\n\n    def preprocess(self, messages: List[Dict], mm_processor_kwargs: Dict[str, Any] | 
None = None) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        outputs = []\n        self.contains_video_input = False\n        self.contains_ts_input = False\n\n        mm_items = self.collect_multimodal_items(messages)\n        for modality, data, params in mm_items:\n            result = {}\n            if modality == Modality.IMAGE:\n                result = self._preprocess_image(data, params, mm_processor_kwargs)\n            elif modality == Modality.VIDEO:\n                self.contains_video_input = True\n                result = self._preprocess_video(data, params, mm_processor_kwargs)\n            elif modality == Modality.TIME_SERIES:\n                self.contains_ts_input = True\n                result = self._preprocess_time_series(data, params, mm_processor_kwargs)\n\n            result.update(modality=modality)\n            outputs.append(result)\n\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    def proc_messages(self,\n                      messages,\n                      chat_template,\n                      sequence_start,\n                      tools: Optional[List[object]] = None,\n                      chat_template_kwargs=None):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        chat_template_kwargs = chat_template_kwargs or {}\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]\n\n        if VisionModel.IMAGE_TOKEN_included(messages):\n            # backward compatibility\n            for message in messages:\n                role, content = message['role'], message['content']\n                if role != 'user' or isinstance(content, str):\n                    prompt_messages.append(message)\n                    continue\n                content = [x['text'] for x in content if x['type'] == 'text']\n                prompt = 
''.join(content)\n                prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{self.image_token}<|vision_end|>')\n                prompt_messages.append(dict(role='user', content=prompt))\n        else:\n            prompt_messages = messages\n\n        # time series input requires enable_thinking = False\n        if self.contains_ts_input:\n            chat_template_kwargs['enable_thinking'] = False\n\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start, tools=tools, **chat_template_kwargs)\n        return prompt, None\n\n    def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequence_start):\n        \"\"\"Pack the video input into the format expected by the pytorch\n        engine.\"\"\"\n\n        # collect all preprocessing results from messages\n        preps = [x['content'] for x in messages if x['role'] == 'preprocess']\n        assert len(preps) == 1\n        preps = preps[0]\n\n        # split prompt into segments and validate data\n        segs = prompt.split(self.vision_start_token + self.video_token + self.vision_end_token)\n        assert len(segs) == len(preps) + 1, (f'the number of {self.video_token} is not equal '\n                                             f'to input videos, {len(segs) - 1} vs {len(preps)}')\n\n        # calculate the video token offset for each video\n        input_ids = []\n        for i, seg in enumerate(segs):\n            if i > 0 and i <= len(preps):\n                preps[i - 1].update(offset=len(input_ids))\n                frame_seqlen = preps[i - 1]['frame_seqlen']\n                assert self.video_token_id == preps[i - 1]['video_token_id']\n\n                video_grid_thw = preps[i - 1]['video_grid_thw']\n                curr_timestamp = preps[i - 1]['curr_timestamp']\n\n                # update prompt with timestamp index tokens and video pad tokens\n                video_placeholder = ''\n                for frame_idx in range(video_grid_thw[0][0]):\n                    curr_time = curr_timestamp[frame_idx]\n                    video_placeholder += f'<{curr_time:.1f} seconds>'\n                    video_placeholder += (self.vision_start_token + '<|placeholder|>' * frame_seqlen +\n                                          self.vision_end_token)\n\n                video_placeholder = video_placeholder.replace('<|placeholder|>', self.video_token)\n                video_token_ids = tokenizer.encode(video_placeholder)\n                input_ids.extend(video_token_ids)\n\n                preps[i - 1].update(video_tokens=len(video_token_ids))\n\n            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))\n            input_ids.extend(token_ids)\n\n        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)\n\n    def to_pytorch_aux_ts(self, messages, prompt, TS_TOKEN, tokenizer, sequence_start):\n        \"\"\"Pack the time series input into the format expected by the pytorch\n        engine.\"\"\"\n        # collect all preprocessing results from messages\n        preps = [x['content'] for x in messages if x['role'] == 'preprocess']\n        assert len(preps) == 1\n        preps = preps[0]\n\n        # split prompt into segments and validate data\n        segs = prompt.split(TS_TOKEN)\n        assert len(segs) == len(preps) + 1, (f'the number of {TS_TOKEN} is not equal '\n                                             f'to input time series data, {len(segs) - 1} vs {len(preps)}')\n\n        input_ids = []\n        for i, seg in enumerate(segs):\n            if i > 0 and i <= len(preps):\n                preps[i - 1].update(offset=len(input_ids))\n                ts_tokens = preps[i - 1]['ts_tokens']\n\n                ts_tokens = ts_tokens[0]\n                ts_array = np.array(preps[i - 1]['ts_values'])\n\n                preps[i - 1].update(ts_tokens=ts_tokens)\n                preps[i - 1].update(ts_values=torch.from_numpy(ts_array).to(dtype=torch.bfloat16))\n                preps[i - 
1].update(ts_lens=torch.tensor(preps[i - 1]['ts_lens']))\n                preps[i - 1].update(ts_sr=torch.tensor(preps[i - 1]['ts_sr']))\n\n                assert self.ts_token_id == preps[i - 1]['ts_token_id']\n                input_ids.extend([self.ts_token_id] * ts_tokens)\n            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))\n            input_ids.extend(token_ids)\n\n        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)\n\n    def to_pytorch(self,\n                   messages,\n                   chat_template,\n                   tokenizer,\n                   sequence_start,\n                   tools: Optional[List[object]] = None,\n                   chat_template_kwargs: Optional[Dict] = None,\n                   **kwargs):\n        \"\"\"Return the information needed by the pytorch engine.\"\"\"\n        prompt, _ = self.proc_messages(messages,\n                                       chat_template,\n                                       sequence_start,\n                                       tools=tools,\n                                       chat_template_kwargs=chat_template_kwargs)\n\n        if self.contains_video_input:\n            return self.to_pytorch_aux_video(messages, prompt, self.video_token, tokenizer, sequence_start)\n        elif self.contains_ts_input:\n            return self.to_pytorch_aux_ts(messages, prompt, self.ts_token, tokenizer, sequence_start)\n        else:\n            return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start)\n\n    def build_model(self):\n        # TODO: implement for turbomind\n        pass\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        # TODO: implement for turbomind\n        pass\n\n    def to_turbomind(self,\n                     messages,\n                     chat_template,\n                     tokenizer,\n                     sequence_start,\n                     chat_template_kwargs: Optional[Dict] = None,\n                     **kwargs):\n        # TODO: implement for turbomind\n        pass\n"
  },
  {
    "path": "lmdeploy/vl/model/internvl.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List, Optional\n\nimport torch\nfrom transformers import AutoConfig, AutoModel, AutoTokenizer, CLIPImageProcessor\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\nlogger = get_logger('lmdeploy')\n\n\ndef find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n    \"\"\"copy from https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5.\"\"\"\n    best_ratio_diff = float('inf')\n    best_ratio = (1, 1)\n    area = width * height\n    for ratio in target_ratios:\n        target_aspect_ratio = ratio[0] / ratio[1]\n        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n        if ratio_diff < best_ratio_diff:\n            best_ratio_diff = ratio_diff\n            best_ratio = ratio\n        elif ratio_diff == best_ratio_diff:\n            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n                best_ratio = ratio\n    return best_ratio\n\n\ndef dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):\n    \"\"\"copy from https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5.\"\"\"\n    orig_width, orig_height = image.size\n    aspect_ratio = orig_width / orig_height\n\n    # calculate the existing image aspect ratio\n    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)\n                        if i * j <= max_num and i * j >= min_num)\n    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n\n    # find the closest aspect ratio to the target\n    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n\n    # calculate the target width and height\n    target_width = image_size * target_aspect_ratio[0]\n    target_height = image_size * 
target_aspect_ratio[1]\n    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n\n    # resize the image\n    resized_img = image.resize((target_width, target_height))\n    processed_images = []\n    for i in range(blocks):\n        box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size,\n               ((i % (target_width // image_size)) + 1) * image_size,\n               ((i // (target_width // image_size)) + 1) * image_size)\n        # split the image\n        split_img = resized_img.crop(box)\n        processed_images.append(split_img)\n    assert len(processed_images) == blocks\n    if use_thumbnail and len(processed_images) != 1:\n        thumbnail_img = image.resize((image_size, image_size))\n        processed_images.append(thumbnail_img)\n    return processed_images\n\n\n@VISION_MODELS.register_module()\nclass InternVLVisionModel(VisionModel):\n    \"\"\"InternVL vision model.\"\"\"\n\n    _arch = 'InternVLChatModel'\n\n    def __init__(self,\n                 model_path: str,\n                 with_llm: bool = False,\n                 max_memory: Dict[int, int] = None,\n                 hf_config: AutoConfig = None,\n                 backend: str = ''):\n        super().__init__(model_path, with_llm, max_memory, hf_config, backend)\n        self.image_token = '<IMG_CONTEXT>'\n        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)\n        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)\n\n    def build_preprocessor(self):\n        self.config = self.hf_config\n        dynamic_image_size = getattr(self.config, 'dynamic_image_size', False)\n        image_processor = None\n        try:\n            image_processor = CLIPImageProcessor.from_pretrained(self.model_path)\n        except OSError:\n            pass\n\n        if dynamic_image_size or image_processor is None:\n            logger.info('using InternVL-Chat-V1-5 vision 
preprocess')\n            MEAN = (0.485, 0.456, 0.406)\n            STD = (0.229, 0.224, 0.225)\n            import torchvision.transforms as T\n            from torchvision.transforms.functional import InterpolationMode\n            input_size = self.config.vision_config.image_size\n            self.transform = T.Compose([\n                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),\n                T.ToTensor(),\n                T.Normalize(mean=MEAN, std=STD)\n            ])\n            self.processor = self._preprocess_v1_5\n            self._forward_func = self._forward_v1_5\n        else:\n            self.processor = self._preprocess\n            self.image_processor = image_processor\n            self._forward_func = self._forward\n\n        force_image_size = self.hf_config.force_image_size\n        patch_size = self.hf_config.vision_config.patch_size\n        downsample_ratio = self.hf_config.downsample_ratio\n        self.image_tokens_per_patch = int((force_image_size // patch_size)**2 * (downsample_ratio**2))\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import init_empty_weights\n        with init_empty_weights():\n            # transformers below 4.37.0 may raise error about flash_attn\n            self.config.llm_config.attn_implementation = 'eager'\n            model = AutoModel.from_config(self.config, trust_remote_code=True)\n            self.vl_model = model\n            if not self.with_llm:\n                del model.language_model\n\n        model.half()\n        from accelerate import load_checkpoint_and_dispatch\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                         
checkpoint=self.model_path,\n                                         device_map='auto' if not self.with_llm else {'': 'cpu'},\n                                         max_memory=self.max_memory,\n                                         no_split_module_classes=['InternVisionEncoderLayer'],\n                                         dtype=torch.half)\n\n        # Switch to eval mode so that layers such as dropout behave\n        # deterministically, avoiding randomness in inference.\n        self.model = model.eval()\n\n    def _preprocess_v1_5(self, image, params=None):\n        params = params or {}\n        image_res = {'low': 6, 'medium': 12, 'high': 24}\n        max_num = params.get('max_dynamic_patch')\n        if max_num is None or not isinstance(max_num, int):\n            res_key = params.get('detail', 'default')\n            max_num = image_res.get(res_key, self.config.max_dynamic_patch)\n        out = dynamic_preprocess(image,\n                                 min_num=self.config.min_dynamic_patch,\n                                 max_num=max_num,\n                                 image_size=self.config.vision_config.image_size,\n                                 use_thumbnail=self.config.use_thumbnail)\n        pixel_values = [self.transform(x) for x in out]\n        # (patch) x c x h x w\n        pixel_values = torch.stack(pixel_values)\n        return pixel_values\n\n    def _forward_v1_5(self, inputs, max_batch_size):\n        \"\"\"Forward for internvl-chat-v1-5.\"\"\"\n        assert all(x.get('pixel_values') is not None for x in inputs)\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            split = [x.shape[0] for x in pixel_values]\n            pixel_values = torch.cat(pixel_values, dim=0)\n            pixel_values = pixel_values.to(self.model.device, dtype=torch.float16)\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            feats = 
self.model.extract_feature(pixel_values)\n            feats = torch.split(feats, split, dim=0)\n            outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])\n        return outputs\n\n    def _preprocess(self, image, params=None):\n        \"\"\"Preprocess for internvl-chat-v1-1, internvl-chat-v1-2.\"\"\"\n        pixel_values = self.image_processor(images=image, return_tensors='pt').pixel_values\n        return pixel_values\n\n    def _forward(self, inputs, max_batch_size):\n        \"\"\"Forward for internvl-chat-v1-1, internvl-chat-v1-2.\"\"\"\n        assert all(x.get('pixel_values') is not None for x in inputs)\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            pixel_values = torch.cat(pixel_values, dim=0)\n            pixel_values = pixel_values.to(self.model.device, dtype=torch.float16)\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            feats = self.model.extract_feature(pixel_values)\n            feats = torch.split(feats, 1, dim=0)\n            outputs.extend([x.squeeze() for x in feats])\n        return outputs\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n            pixel_values = self.processor(image, params)\n            image_tokens = (pixel_values.shape[0] * self.image_tokens_per_patch)\n            outputs.append(\n                dict(pixel_values=pixel_values,\n                     image_tokens=image_tokens,\n                     image_token_id=self.image_token_id,\n                     image_size=image.size))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image features. ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        outputs = self._forward_func(inputs, max_batch_size)\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n\n    def proc_messages(\n        self,\n        messages,\n        chat_template,\n        sequence_start,\n        tools: Optional[List[object]] = None,\n        chat_template_kwargs: Optional[Dict] = None,\n    ):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        chat_template_kwargs = chat_template_kwargs or {}\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]\n        if VisionModel.IMAGE_TOKEN_included(messages):\n            # backward compatibility\n            for message in messages:\n                role, content = message['role'], message['content']\n                if role != 'user' or isinstance(content, str):\n                    prompt_messages.append(message)\n                    continue\n                content = [x['text'] for x in content if x['type'] == 'text']\n                prompt = ''.join(content)\n                prompt = prompt.replace(f'{IMAGE_TOKEN}', f'<img>{self.image_token}</img>')\n                prompt_messages.append(dict(role='user', content=prompt))\n        else:\n            for message in messages:\n                role, content = message['role'], message['content']\n                if role != 
'user' or isinstance(content, str):\n                    prompt_messages.append(message)\n                    continue\n                _content = []\n                for item in content:\n                    item_type = item['type']\n                    if item_type == 'text':\n                        _content.append(item['text'])\n                    elif item_type in ['image', 'image_url']:\n                        _content.append(f'<img>{self.image_token}</img>\\n')\n                    else:\n                        raise ValueError(f'Unsupported message type: {item[\"type\"]}')\n                prompt_messages.append(dict(role='user', content=''.join(_content)))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start, tools=tools, **chat_template_kwargs)\n        return prompt, self.image_token\n\n    def to_pytorch(self,\n                   messages,\n                   chat_template,\n                   tokenizer,\n                   sequence_start,\n                   tools: Optional[List[object]] = None,\n                   chat_template_kwargs: Optional[Dict] = None,\n                   **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages,\n                                                 chat_template,\n                                                 sequence_start,\n                                                 tools=tools,\n                                                 chat_template_kwargs=chat_template_kwargs)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self,\n                     messages,\n                     chat_template,\n                     tokenizer,\n                     sequence_start,\n                     tools: Optional[List[object]] = None,\n                     chat_template_kwargs: Optional[Dict] = None,\n                     **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages,\n                  
                               chat_template,\n                                                 sequence_start,\n                                                 tools=tools,\n                                                 chat_template_kwargs=chat_template_kwargs)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/internvl3_hf.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List, Optional\n\nimport torch\nfrom transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoProcessor\nfrom transformers.processing_utils import ImagesKwargs, ProcessingKwargs\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.internvl import VISION_MODELS, InternVLVisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\nlogger = get_logger('lmdeploy')\n\n\nclass InternVLImagesKwargs(ImagesKwargs, total=False):\n    crop_to_patches: Optional[bool]\n    min_patches: Optional[int]\n    max_patches: Optional[int]\n\n\nclass InternVLProcessorKwargs(ProcessingKwargs, total=False):\n    images_kwargs: InternVLImagesKwargs\n    _defaults = {\n        'text_kwargs': {\n            'padding': False,\n        },\n        'images_kwargs': {\n            'crop_to_patches': True,\n        },\n        'videos_kwargs': {},\n    }\n\n\n@VISION_MODELS.register_module()\nclass InternVL3VisionModel(InternVLVisionModel):\n    \"\"\"Internvl3 vision model.\"\"\"\n\n    _arch = ['InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration']\n\n    def __init__(self,\n                 model_path: str,\n                 with_llm: bool = False,\n                 max_memory: Dict[int, int] = None,\n                 hf_config: AutoConfig = None,\n                 backend: str = ''):\n        super().__init__(model_path, with_llm, max_memory, hf_config, backend)\n        self.arch = self.hf_config.architectures[0]\n\n    def build_preprocessor(self):\n        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)\n        tokenizer = self.processor.tokenizer\n        self.image_token = self.processor.image_token\n        self.image_token_id = tokenizer.context_image_token_id\n        self.image_tokens_per_patch = self.processor.image_seq_length\n        self.tokenizer_init_kwargs = tokenizer.init_kwargs\n\n    def 
build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import init_empty_weights\n        with init_empty_weights():\n            if self.arch == 'InternVLForConditionalGeneration':\n                model = AutoModel.from_config(self.hf_config, trust_remote_code=True)\n                if not self.with_llm:\n                    del model.language_model\n            elif self.arch == 'InternS1ForConditionalGeneration':\n                model = AutoModelForCausalLM.from_config(self.hf_config, trust_remote_code=True)\n                if not self.with_llm:\n                    del model.model.language_model\n            else:\n                raise ValueError(f'unsupported model arch {self.arch}')\n\n        model.half()\n        from accelerate import load_checkpoint_and_dispatch\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                         checkpoint=self.model_path,\n                                         device_map='auto' if not self.with_llm else {'': 'cpu'},\n                                         max_memory=self.max_memory,\n                                         no_split_module_classes=['InternVLVisionLayer', 'InternS1VisionLayer'],\n                                         dtype=torch.half)\n        # Use eval mode so that layers such as dropout behave\n        # deterministically during inference.\n        self.model = model.eval()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        from transformers.image_utils import make_flat_list_of_images\n        output_kwargs = self.processor._merge_kwargs(\n            InternVLProcessorKwargs,\n            tokenizer_init_kwargs=self.tokenizer_init_kwargs,\n            **{\n                'return_tensors': 'pt',\n                
'add_special_tokens': False\n            },\n        )\n        images = self.collect_multimodal_items(messages)\n        images = [image.convert('RGB') for modality, image, _ in images]\n        num_image = len(images)\n        images = make_flat_list_of_images(images)\n        image_inputs = self.processor.image_processor(images, **output_kwargs['images_kwargs'])\n        image_num_patches = image_inputs.pop('num_patches').cpu().numpy().tolist()\n        image_pixel_values = image_inputs.pop('pixel_values')\n        outputs = []\n        cum_num_patches = 0\n        for idx in range(num_image):\n            cur_num_patches = image_num_patches[idx]\n            pixel_values = image_pixel_values[cum_num_patches:cum_num_patches + cur_num_patches, ...]\n            cum_num_patches += cur_num_patches\n            data = dict(pixel_values=pixel_values,\n                        image_tokens=self.image_tokens_per_patch * cur_num_patches,\n                        image_token_id=self.image_token_id)\n            outputs.append(data)\n\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        assert all(x.get('pixel_values') is not None for x in inputs)\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            split = [x.shape[0] for x in pixel_values]\n            pixel_values = torch.cat(pixel_values, dim=0)\n            pixel_values = pixel_values.to(self.model.device, dtype=torch.float16)\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            feats = self.model.get_image_features(\n                pixel_values,\n                vision_feature_layer=self.hf_config.vision_feature_layer,\n                vision_feature_select_strategy=self.hf_config.vision_feature_select_strategy,\n            )\n            feats = torch.split(feats, split, dim=0)\n            outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n"
  },
  {
    "path": "lmdeploy/vl/model/internvl_llava.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport warnings\nfrom contextlib import contextmanager\nfrom typing import Dict, List\n\nimport torch\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.llava import VISION_MODELS, LlavaVisionModel\nfrom lmdeploy.vl.model.utils import rewrite_ctx\n\nfrom .utils import disable_logging, disable_transformers_logging\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_llava_install():\n    try:\n        from llava.model.multimodal_encoder.clip_encoder import InternVisionModel  # noqa: F401\n    except ImportError:\n        raise ImportError(\n            'To use LlavaVLModel, please install llava by '\n            '`pip install git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava --no-deps`')\n\n\ndef _intern_vision_model__from_pretrained(vision_tower_name: str):\n    logger.info(f'init empty InternVisionModel: {vision_tower_name}')\n    from llava.model.multimodal_encoder.intern_vit_6b.modeling_intern_vit import InternVisionConfig, InternVisionModel\n    config = InternVisionConfig.from_pretrained(vision_tower_name)\n    model = InternVisionModel._from_config(config)\n    model.requires_grad_(False)\n    return model\n\n\ndef _intern_vl_model__from_pretrained(vision_tower_name: str):\n    logger.info(f'init empty InternVLModel: {vision_tower_name}')\n\n    from llava.model.multimodal_encoder.internvl_14b.modeling_internvl import InternVLConfig, InternVLModel\n\n    config = InternVLConfig.from_pretrained(vision_tower_name)\n    model = InternVLModel._from_config(config)\n    model.requires_grad_(False)\n    return model\n\n\n@contextmanager\ndef init_empty_vit():\n    \"\"\"Skip download vision model if possible.\"\"\"\n    origin_func_path = [\n        'llava.model.multimodal_encoder.intern_vit_6b.modeling_intern_vit.InternVisionModel.from_pretrained',  # noqa: E501\n        
'llava.model.multimodal_encoder.internvl_14b.modeling_internvl.InternVLModel.from_pretrained',  # noqa: E501\n    ]\n    rewrite_func = [_intern_vision_model__from_pretrained, _intern_vl_model__from_pretrained]\n    with rewrite_ctx(origin_func_path, rewrite_func):\n        yield\n\n\n@VISION_MODELS.register_module()\nclass InternVLLlavaVisionModel(LlavaVisionModel):\n    \"\"\"Llava visual model.\"\"\"\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config match the model.\"\"\"\n        arch = config.architectures[0] if config.architectures else None\n        if arch == 'LlavaLlamaForCausalLM':\n            mm_vision_tower = getattr(config, 'mm_vision_tower', '')\n            if 'OpenGVLab' in mm_vision_tower:\n                return True\n        return False\n\n    def build_preprocessor(self):\n        return super().build_preprocessor()\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        check_llava_install()\n        # currently, only support llava llama\n        from llava.model.language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM  # noqa\n        self.config = LlavaConfig.from_pretrained(self.model_path)\n        assert self.config.model_type in ['llava', 'llava_llama'], \\\n            'currently, only support llava llama'\n\n        # init empty model, skip layer initialization\n        from accelerate import init_empty_weights\n        with init_empty_weights(), warnings.catch_warnings(), \\\n                disable_transformers_logging():\n            warnings.simplefilter('ignore')\n            self.config.quantization_config = {}  # disable vision part quantization\n            model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)\n            self.vl_model = model\n            if not self.with_llm:\n                del 
model.lm_head\n                del model.model.embed_tokens\n                del model.model.layers\n                del model.model.norm\n\n            with init_empty_vit():\n                vision_tower = model.get_vision_tower()\n                vision_tower.is_loaded = False\n                vision_tower.load_model()\n            crop_size = vision_tower.image_processor.crop_size['height']\n            image_size = vision_tower.config.image_size\n            patch_size = vision_tower.config.patch_size\n            if crop_size != image_size:\n                vision_tower.vision_tower.resize_pos_embeddings(image_size, crop_size, patch_size)\n                vision_tower.vision_tower.embeddings.image_size = crop_size\n                vision_tower.config.image_size = crop_size\n                vision_tower.image_processor.crop_size = dict(height=crop_size, width=crop_size)\n                vision_tower.image_processor.size = dict(shortest_edge=crop_size)\n\n        from accelerate import load_checkpoint_and_dispatch\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                         max_memory=self.max_memory,\n                                         checkpoint=self.model_path,\n                                         device_map='auto' if not self.with_llm else {'': 'cpu'},\n                                         no_split_module_classes=['InternVisionEncoderLayer'],\n                                         dtype=torch.half)\n\n        self.model = model.model.eval()\n        self.vision_tower = model.model.vision_tower.eval()\n        self.mm_projector = model.model.mm_projector.eval()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        return super().preprocess(messages)\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image 
feature. ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            split_sizes = [x.shape[0] for x in pixel_values]\n            pixel_values = torch.cat(pixel_values, dim=0)\n            pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=torch.float16)\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            if pixel_values.ndim == 5:\n                feats = self.encode_images(pixel_values)\n                feats = torch.split(feats, split_sizes, dim=0)\n                feats = [x.flatten(0, 1) for x in feats]\n            else:\n                feats = self.encode_images(pixel_values)\n                feats = [x for x in feats]\n            outputs.extend(feats)\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n"
  },
  {
    "path": "lmdeploy/vl/model/llama4.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Dict, List\n\nimport torch\nfrom transformers import AutoConfig\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_trans_version():\n    \"\"\"Check if the installed version of the 'transformers' library is smaller\n    than the specified version.\"\"\"\n    import transformers\n    from packaging import version\n\n    min_version = '4.51.0'\n    installed_version = transformers.__version__\n    assert version.parse(installed_version) >= version.parse(min_version), (\n        f'llama4 requires transformers version >= {min_version}, '\n        f'but found version: {installed_version}. Please upgrade.')\n\n\n@VISION_MODELS.register_module()\nclass LLama4VisionModel(VisionModel):\n    \"\"\"Llama4 vision model.\"\"\"\n\n    _arch = 'Llama4ForConditionalGeneration'\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config match the model.\"\"\"\n        arch = config.architectures[0]\n        return arch == cls._arch\n\n    def build_preprocessor(self):\n        check_trans_version()\n        from transformers.models.llama4 import Llama4Processor\n        from transformers.models.llama4.processing_llama4 import Llama4ProcessorKwargs\n        self.processor = Llama4Processor.from_pretrained(\n            self.model_path,\n            padding_side='left',\n        )\n        img_patch_token = self.processor.img_patch_token\n        self.image_token_id = self.processor.tokenizer.encode(img_patch_token, add_special_tokens=False)[0]\n        self.images_kwargs = self.processor._merge_kwargs(\n            Llama4ProcessorKwargs,\n            tokenizer_init_kwargs=self.processor.tokenizer.init_kwargs,\n            return_tensors='pt',\n            add_special_tokens=False,\n        )['images_kwargs']\n\n    def build_model(self):\n        \"\"\"Build 
the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        # TODO, implement for turbomind engine\n        raise NotImplementedError()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        processor = self.processor\n        patch_size = processor.patch_size\n        downsample_ratio = processor.downsample_ratio\n        images_kwargs = self.images_kwargs\n        for modality, image, params in images:\n            image_inputs = processor.image_processor(images=[image], **images_kwargs)\n            pixel_values = image_inputs['pixel_values']\n            image_height, image_width = image_inputs['pixel_values'][0].shape[-2:]\n            num_patches_per_chunk = int((image_height // patch_size) * (image_width // patch_size) // downsample_ratio)\n            aspect_ratios = image_inputs.pop('aspect_ratios')\n            image_prompts = processor._prompt_split_image(aspect_ratios[0], num_patches_per_chunk)\n            image_tokens = image_prompts.count('<|') - 2\n            outputs.append(\n                dict(pixel_values=pixel_values,\n                     image_tokens=image_tokens,\n                     image_token_id=self.image_token_id,\n                     image_size=image.size,\n                     image_prompts=image_prompts))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        # TODO, implement for turbomind engine\n        raise NotImplementedError()\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['preprocess', 'forward']:\n                continue\n            n_images = len([1 for x in message['content'] if x['type'] == 'image'])\n            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']\n            prompt = content[0]\n            if IMAGE_TOKEN not in prompt:\n                prompt = f'{IMAGE_TOKEN * n_images}' + prompt\n            prompt_messages.append(dict(role='user', content=prompt))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):\n        \"\"\"Auxiliary function to pack the preprocessing results in a format\n        compatible with what is required by pytorch engine.\n\n        Args:\n            messages(List[Dict]): the output of `preprocess`\n            prompt(str): the prompt after applying chat template\n            IMAGE_TOKEN(str): a placeholder where image tokens will be\n                inserted\n            tokenizer: the tokenizer model\n            sequence_start: starting flag of a sequence\n        \"\"\"\n        # 
collect all preprocessing result from messages\n        preps = [x['content'] for x in messages if x['role'] == 'preprocess']\n        assert len(preps) == 1\n        preps = preps[0]\n\n        # split prompt into segments and validate data\n        segs = prompt.split(IMAGE_TOKEN)\n        assert len(segs) == len(preps) + 1, (f'the number of {IMAGE_TOKEN} is not equal '\n                                             f'to input images, {len(segs) - 1} vs {len(preps)}')\n\n        # calculate the image token offset for each image\n        input_ids = []\n        for i, seg in enumerate(segs):\n            if i > 0 and i <= len(preps):\n                prep = preps[i - 1]\n                image_prompts = prep.pop('image_prompts', '')\n                prep.update(offset=len(input_ids) + 1)\n                assert self.image_token_id == prep['image_token_id']\n                seg = image_prompts + seg\n            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))\n            input_ids.extend(token_ids)\n        return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/llava.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Modified from https://github.com/haotian-liu/LLaVA.git\n\nimport ast\nimport math\nimport warnings\nfrom contextlib import contextmanager\nfrom typing import Dict, List\n\nimport torch\nfrom PIL import Image\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel\nfrom lmdeploy.vl.model.utils import disable_logging, rewrite_ctx\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_llava_install():\n    \"\"\"Check llava install.\"\"\"\n    try:\n        import llava  # noqa: F401\n    except ImportError:\n        raise ImportError('To use LlavaVLModel, please install llava by '\n                          '`pip install git+https://github.com/haotian-liu/LLaVA.git --no-deps`'  # noqa: E501\n                          )\n\n\ndef _clip_vision_tower_load_model(self, **kwargs):\n    logger.info(f'CLIPVisionTower.load_model: {self.vision_tower_name}')\n    from transformers import CLIPVisionConfig, CLIPVisionModel\n\n    config = CLIPVisionConfig.from_pretrained(self.vision_tower_name)\n    self.vision_tower = CLIPVisionModel._from_config(config=config)\n    self.vision_tower.requires_grad_(False)\n    self.is_loaded = True\n\n\n@contextmanager\ndef init_llava_vision_tower(config):\n    \"\"\"Skip download vision model if possible.\"\"\"\n    if getattr(config, 'unfreeze_mm_vision_tower', False):\n        origin_func_path = [\n            'llava.model.multimodal_encoder.clip_encoder.CLIPVisionTower.load_model'  # noqa: E501\n        ]\n        rewrite_func = [_clip_vision_tower_load_model]\n        with rewrite_ctx(origin_func_path, rewrite_func):\n            yield\n    else:\n        yield\n\n\ndef select_best_resolution(original_size, possible_resolutions):\n    \"\"\"Selects the best resolution from a list of possible resolutions based on\n    the original size.\n\n    Args:\n        
original_size (tuple): The original size of the image in the format (width, height).\n        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].\n\n    Returns:\n        tuple: The best fit resolution in the format (width, height).\n    \"\"\"  # noqa\n    original_width, original_height = original_size\n    best_fit = None\n    max_effective_resolution = 0\n    min_wasted_resolution = float('inf')\n\n    for width, height in possible_resolutions:\n        scale = min(width / original_width, height / original_height)\n        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)\n        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)\n        wasted_resolution = (width * height) - effective_resolution\n\n        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution\n                                                               and wasted_resolution < min_wasted_resolution):\n            max_effective_resolution = effective_resolution\n            min_wasted_resolution = wasted_resolution\n            best_fit = (width, height)\n\n    return best_fit\n\n\ndef resize_and_pad_image(image, target_resolution):\n    \"\"\"Resize and pad an image to a target resolution while maintaining aspect\n    ratio.\n\n    Args:\n        image (PIL.Image.Image): The input image.\n        target_resolution (tuple): The target resolution (width, height) of the image.\n\n    Returns:\n        PIL.Image.Image: The resized and padded image.\n    \"\"\"  # noqa\n    original_width, original_height = image.size\n    target_width, target_height = target_resolution\n\n    scale_w = target_width / original_width\n    scale_h = target_height / original_height\n\n    if scale_w < scale_h:\n        new_width = target_width\n        new_height = min(math.ceil(original_height * scale_w), 
target_height)\n    else:\n        new_height = target_height\n        new_width = min(math.ceil(original_width * scale_h), target_width)\n\n    # Resize the image\n    resized_image = image.resize((new_width, new_height))\n\n    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))\n    paste_x = (target_width - new_width) // 2\n    paste_y = (target_height - new_height) // 2\n    new_image.paste(resized_image, (paste_x, paste_y))\n\n    return new_image\n\n\ndef divide_to_patches(image, patch_size):\n    \"\"\"Divides an image into patches of a specified size.\n\n    Args:\n        image (PIL.Image.Image): The input image.\n        patch_size (int): The size of each patch.\n\n    Returns:\n        list: A list of PIL.Image.Image objects representing the patches.\n    \"\"\"\n    patches = []\n    width, height = image.size\n    for i in range(0, height, patch_size):\n        for j in range(0, width, patch_size):\n            box = (j, i, j + patch_size, i + patch_size)\n            patch = image.crop(box)\n            patches.append(patch)\n\n    return patches\n\n\ndef process_anyres_image(image, processor, grid_pinpoints):\n    \"\"\"Process an image with variable resolutions.\n\n    Args:\n        image (PIL.Image.Image): The input image to be processed.\n        processor: The image processor object.\n        grid_pinpoints (str): A string representation of a list of possible resolutions.\n\n    Returns:\n        torch.Tensor: A tensor containing the processed image patches.\n    \"\"\"  # noqa\n    if type(grid_pinpoints) is list:\n        possible_resolutions = grid_pinpoints\n    else:\n        possible_resolutions = ast.literal_eval(grid_pinpoints)\n    best_resolution = select_best_resolution(image.size, possible_resolutions)\n    image_padded = resize_and_pad_image(image, best_resolution)\n\n    patches = divide_to_patches(image_padded, processor.crop_size['height'])\n\n    image_original_resize = 
image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))\n\n    image_patches = [image_original_resize] + patches\n    image_patches = [\n        processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] for image_patch in image_patches\n    ]\n    return torch.stack(image_patches, dim=0)\n\n\ndef expand2square(pil_img, background_color):\n    width, height = pil_img.size\n    if width == height:\n        return pil_img\n    elif width > height:\n        result = Image.new(pil_img.mode, (width, width), background_color)\n        result.paste(pil_img, (0, (width - height) // 2))\n        return result\n    else:\n        result = Image.new(pil_img.mode, (height, height), background_color)\n        result.paste(pil_img, ((height - width) // 2, 0))\n        return result\n\n\ndef process_images(images, image_processor, model_cfg):\n    image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)\n    new_images = []\n    if image_aspect_ratio == 'pad':\n        for image in images:\n            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))\n            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n            new_images.append(image)\n    elif image_aspect_ratio == 'anyres':\n        for image in images:\n            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)\n            new_images.append(image)\n    else:\n        return image_processor(images, return_tensors='pt')['pixel_values']\n    if all(x.shape == new_images[0].shape for x in new_images):\n        new_images = torch.stack(new_images, dim=0)\n    return new_images\n\n\n@VISION_MODELS.register_module()\nclass LlavaVisionModel(LlavaHfVisionModel):\n    \"\"\"Llava visual model.\"\"\"\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config match the model.\"\"\"\n        arch = config.architectures[0] if 
config.architectures else None\n        if arch in ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM']:\n            # internvl-llava has vision_tower of OpenGVLab/xxx\n            mm_vision_tower = getattr(config, 'mm_vision_tower', '')\n            # yi-vl has projector type of xxx_Norm\n            projector_type = getattr(config, 'mm_projector_type', 'linear')\n            if '_Norm' in projector_type:\n                return False\n            if 'OpenGVLab' in mm_vision_tower:\n                return False\n            return True\n        return False\n\n    def build_preprocessor(self):\n        from transformers import CLIPImageProcessor\n        self.image_processor = CLIPImageProcessor.from_pretrained(self.hf_config.mm_vision_tower)\n        config = AutoConfig.from_pretrained(self.hf_config.mm_vision_tower)\n        image_size = config.vision_config.image_size\n        patch_size = config.vision_config.patch_size\n        self.n_token_per_image = (image_size // patch_size)**2\n        if self.hf_config.mm_vision_select_feature == 'cls_patch':\n            self.n_token_per_image += 1\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        check_llava_install()\n\n        self.arch = self.hf_config.architectures[0]\n        model = None\n        if self.arch == 'LlavaLlamaForCausalLM':\n            from llava.model.language_model.llava_llama import LlavaConfig\n            self.config = LlavaConfig.from_pretrained(self.model_path)\n            assert self.config.model_type in ['llava', 'llava_llama'], \\\n                f'expect model_type llava and llava_llama '\\\n                f'but got {self.config.model_type}'\n        elif self.arch == 'LlavaMistralForCausalLM':\n            from llava.model.language_model.llava_mistral import LlavaMistralConfig\n            self.config = 
LlavaMistralConfig.from_pretrained(self.model_path)\n        else:\n            assert 0, f'unsupported arch {self.arch}'\n\n        from accelerate import init_empty_weights\n\n        # init empty model, skip layer initialization\n        with init_empty_weights(), warnings.catch_warnings(), \\\n                init_llava_vision_tower(self.config):\n            warnings.simplefilter('ignore')\n            self.config.quantization_config = {}  # disable vision part quantization\n            model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)\n\n        self.vl_model = model\n        if not self.with_llm:\n            # remove the LLM part from llava model.\n            del model.lm_head\n            del model.model.embed_tokens\n            del model.model.layers\n            del model.model.norm\n\n        # init empty vision_tower, the embedding layer in CLIPVisionModel\n        # can't init right under init_empty_weights\n        with init_llava_vision_tower(self.config):\n            vision_tower = model.get_vision_tower()\n            vision_tower.is_loaded = False\n            vision_tower.load_model()\n            # for llava-v1.5, the vit is not in llm ckpt\n            vision_tower.to(dtype=torch.half)\n\n        from accelerate import load_checkpoint_and_dispatch\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                         max_memory=self.max_memory,\n                                         checkpoint=self.model_path,\n                                         device_map='auto' if not self.with_llm else {'': 'cpu'},\n                                         no_split_module_classes=['CLIPEncoderLayer'],\n                                         dtype=torch.half)\n\n        self.model = model.model.eval()\n        self.vision_tower = model.model.vision_tower.half().eval()\n        self.mm_projector = model.model.mm_projector.half().eval()\n\n    def 
encode_images(self, images: torch.Tensor) -> torch.Tensor:\n        \"\"\"Encode images.\"\"\"\n        image_features = self.vision_tower(images)\n        image_features = self.mm_projector(image_features)\n        return image_features\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n            pixel_values = process_images([image], self.image_processor, self.config)\n            outputs.append(\n                dict(pixel_values=pixel_values,\n                     image_size=image.size,\n                     image_tokens=self.n_token_per_image,\n                     image_token_id=self.image_token_id))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n\n        from llava.model.llava_arch import get_anyres_image_grid_shape, unpad_image\n\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            image_sizes = [x['image_size'] for x in inputs[idx:idx + max_batch_size]]\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            if pixel_values[0].ndim == 5:\n                split_sizes = [x.shape[1] for x in pixel_values]\n                pixel_values = torch.cat([x for x in pixel_values], dim=1)\n                logger.info(f'vision forward shape: {pixel_values.shape}')\n                pixel_values = pixel_values.squeeze(0)\n                pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=torch.float16)\n                feats = self.encode_images(pixel_values)\n                feats = torch.split(feats, split_sizes, dim=0)\n                mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')\n                image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')\n                if mm_patch_merge_type == 'flat':\n                    outputs.extend([x.flatten(0, 1) for x in feats])\n                elif mm_patch_merge_type.startswith('spatial'):\n                    for img_idx, feat in enumerate(feats):\n                        if feat.shape[0] > 1:\n                            base_feat = feat[0]\n                            feat = feat[1:]\n                            height = self.vision_tower.num_patches_per_side\n                            width = self.vision_tower.num_patches_per_side\n                            assert height * width == base_feat.shape[0]\n                            if image_aspect_ratio == 'anyres':\n                                num_patch_width, num_patch_height = \\\n                                    get_anyres_image_grid_shape(\n                                        image_sizes[img_idx],\n                                        self.config.image_grid_pinpoints,\n                                        self.vision_tower.config.image_size)\n                                feat = feat.view(num_patch_height, num_patch_width, height, width, -1)\n                            else:\n                                raise NotImplementedError\n                            if 'unpad' in mm_patch_merge_type:\n                                feat = feat.permute(4, 0, 2, 1, 3).contiguous()\n                                feat = feat.flatten(1, 2).flatten(2, 3)\n                                feat = unpad_image(feat, image_sizes[img_idx])\n                                feat = torch.cat((feat, self.model.image_newline[:, None, None].expand(\n                                    *feat.shape[:-1], 1).to(feat.device)),\n                                                 dim=-1)\n                                feat = feat.flatten(1, 2).transpose(0, 1)\n                            else:\n                                feat = feat.permute(0, 2, 1, 3, 4).contiguous()\n                                feat = feat.flatten(0, 3)\n                            feat = torch.cat((base_feat, feat), dim=0)\n                        else:\n                            feat = feat[0]\n                            if 'unpad' in mm_patch_merge_type:\n                                feat = torch.cat((feat, self.model.image_newline[None].to(feat.device)), dim=0)\n                        outputs.append(feat)\n                else:\n                    raise ValueError('Unexpected mm_patch_merge_type: '\n                                     f'{self.config.mm_patch_merge_type}')\n            else:\n                pixel_values = torch.cat(pixel_values, dim=0)\n                pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=torch.float16)\n                logger.info(f'vision forward shape: {pixel_values.shape}')\n                feats = self.encode_images(pixel_values)\n                outputs.extend([x for x in feats])\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n"
  },
  {
    "path": "lmdeploy/vl/model/llava_hf.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport warnings\nfrom typing import Dict, List\n\nimport torch\nfrom transformers import AutoProcessor\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\nlogger = get_logger('lmdeploy')\n\n\n@VISION_MODELS.register_module()\nclass LlavaHfVisionModel(VisionModel):\n    \"\"\"Llava hf vision model.\"\"\"\n\n    _arch = 'LlavaForConditionalGeneration'\n\n    def build_preprocessor(self):\n        processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)\n        if hasattr(processor, 'tokenizer'):\n            del processor.tokenizer\n            processor.tokenizer = None\n        self.processor = processor.image_processor\n        image_size = self.hf_config.vision_config.image_size\n        patch_size = self.hf_config.vision_config.patch_size\n        self.n_token_per_image = (image_size // patch_size)**2\n        if self.hf_config.vision_feature_select_strategy == 'full':\n            self.n_token_per_image += 1\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n\n        with init_empty_weights(), warnings.catch_warnings():\n            warnings.simplefilter('ignore')\n            from transformers import LlavaForConditionalGeneration\n            model = LlavaForConditionalGeneration._from_config(self.hf_config)\n            self.vl_model = model\n            if not self.with_llm:\n                del model.language_model\n\n        # fix for llava-hf/llava-interleave-qwen-7b-hf\n        setattr(model.config, 'tie_word_embeddings', False)\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                  
       max_memory=self.max_memory,\n                                         checkpoint=self.model_path,\n                                         device_map='auto' if not self.with_llm else {'': 'cpu'},\n                                         no_split_module_classes=['CLIPEncoderLayer', 'SiglipEncoderLayer'],\n                                         dtype=torch.half)\n        model.eval()\n        self.model = model\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n            pixel_values = self.processor(image, return_tensors='pt', input_data_format='channels_last').pixel_values\n            outputs.append(\n                dict(pixel_values=pixel_values,\n                     image_size=image.size,\n                     image_tokens=self.n_token_per_image,\n                     image_token_id=self.image_token_id))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            pixel_values = torch.cat(pixel_values, dim=0)\n            pixel_values = pixel_values.to(device=self.model.device, dtype=self.model.dtype)\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            image_outputs = self.model.vision_tower.forward(pixel_values, output_hidden_states=True)\n            image_features = image_outputs.hidden_states[self.hf_config.vision_feature_layer]\n            if self.hf_config.vision_feature_select_strategy == 'default':\n                image_features = image_features[:, 1:]\n            elif self.hf_config.vision_feature_select_strategy == 'full':\n                image_features = image_features\n            else:\n                raise ValueError('Unexpected select feature strategy: '\n                                 f'{self.hf_config.vision_feature_select_strategy}')\n            image_features = self.model.multi_modal_projector(image_features)\n            image_features = torch.split(image_features, 1, dim=0)\n            outputs.extend([x.squeeze() for x in image_features])\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n      
  for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['images', 'preprocess', 'forward']:\n                continue\n            n_images = len([1 for x in message['content'] if x['type'] == 'image'])\n            content = [item['text'] for item in message['content'] if item['type'] == 'text']\n            prompt = (IMAGE_TOKEN + '\\n') * n_images + content[0]\n            prompt_messages.append(dict(role='user', content=prompt))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/llava_next.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport itertools\nimport warnings\nfrom typing import Dict, List\n\nimport torch\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\nlogger = get_logger('lmdeploy')\n\n\n@VISION_MODELS.register_module()\nclass LlavaNextVisionModel(LlavaHfVisionModel):\n    \"\"\"Llava hf vision model.\"\"\"\n\n    _arch = 'LlavaNextForConditionalGeneration'\n\n    def build_preprocessor(self):\n        super().build_preprocessor()\n        # build the model with empty weights. The model will be used in\n        # `preprocess` to get the image token number\n        from accelerate import init_empty_weights\n        with init_empty_weights(), warnings.catch_warnings():\n            warnings.simplefilter('ignore')\n            from transformers import LlavaNextForConditionalGeneration\n            self.model = LlavaNextForConditionalGeneration._from_config(self.hf_config)\n            self.vl_model = self.model\n            if not self.with_llm:\n                del self.model.language_model\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import load_checkpoint_and_dispatch\n        from accelerate.utils import get_balanced_memory, infer_auto_device_map\n\n        no_split_module_classes = ['CLIPEncoderLayer']\n        max_memory = get_balanced_memory(self.model,\n                                         max_memory=self.max_memory,\n                                         dtype=torch.half,\n                                         no_split_module_classes=no_split_module_classes)\n        device_map = infer_auto_device_map(self.model,\n                                           no_split_module_classes=no_split_module_classes,\n                         
                  max_memory=max_memory,\n                                           dtype=torch.half)\n\n        same_device_keys = [('multi_modal_projector', 'image_newline')]\n        for keys in same_device_keys:\n            keys = [k for k in keys if k in device_map]\n            if len(keys) <= 1:\n                continue\n            for k in keys[1:]:\n                device_map[k] = device_map[keys[0]]\n\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=self.model,\n                                         checkpoint=self.model_path,\n                                         device_map=device_map if not self.with_llm else {'': 'cpu'},\n                                         no_split_module_classes=no_split_module_classes,\n                                         dtype=torch.half)\n        self.model.eval()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to the spec of `super().preprocess()`.\"\"\"\n        from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n            result = self.processor(image, return_tensors='pt', input_data_format='channels_last')\n            # ! 
infer image_num_patches from image_sizes\n            image_num_patches = [\n                image_size_to_num_patches(\n                    image_size=imsize,\n                    grid_pinpoints=self.hf_config.image_grid_pinpoints,\n                    patch_size=self.hf_config.vision_config.image_size,\n                ) for imsize in result['image_sizes']\n            ]\n\n            hidden_size = self.hf_config.text_config.hidden_size\n            fake_image_features = torch.zeros([image_num_patches[0], self.n_token_per_image, hidden_size])\n            image_sizes = result['image_sizes']\n            image_newline = torch.randn(self.hf_config.text_config.hidden_size)\n            strategy = self.hf_config.vision_feature_select_strategy\n            _, image_tokens = self.model.pack_image_features([fake_image_features],\n                                                             image_sizes,\n                                                             vision_feature_select_strategy=strategy,\n                                                             image_newline=image_newline)\n            result.update(\n                dict(image_size=image.size,\n                     image_patches=image_num_patches,\n                     image_tokens=image_tokens,\n                     image_token_id=self.image_token_id))\n            outputs.append(result)\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [\n                x['pixel_values'].to(device=self.model.device, dtype=self.model.dtype)\n                for x in inputs[idx:idx + max_batch_size]\n            ]\n            pixel_values = torch.cat(pixel_values, dim=0)\n            image_sizes = [\n                x['image_sizes'].to(device=self.model.device, dtype=self.model.dtype)\n                for x in inputs[idx:idx + max_batch_size]\n            ]\n            image_sizes = torch.cat(image_sizes, dim=0)\n            # `preprocess` stores the per-image patch counts under 'image_patches'\n            image_num_patches = [x['image_patches'] for x in inputs[idx:idx + max_batch_size]]\n            image_num_patches = list(itertools.chain(*image_num_patches))\n            # figure out if pixel_values is concatenated or stacked\n            if pixel_values.dim() == 5:\n                # stacking when input is\n                # (batch_size, num_patches, num_channels, height, width)\n                _pixel_values_list = [\n                    pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)\n                ]\n                pixel_values = torch.cat(_pixel_values_list, dim=0)\n            elif pixel_values.dim() != 4:\n                # otherwise has to be stacked from list of\n                # (num_patches, num_channels, height, width)\n                raise ValueError(f'pixel_values of shape {pixel_values.shape}, '\n                                 'expected to have 4 or 5 dimensions')\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            image_outputs = self.model.vision_tower.forward(pixel_values, output_hidden_states=True)\n            image_features = image_outputs.hidden_states[self.hf_config.vision_feature_layer]\n            strategy = self.hf_config.vision_feature_select_strategy\n            if strategy == 'default':\n                image_features = image_features[:, 1:]\n            elif strategy == 'full':\n                image_features = image_features\n            else:\n                raise ValueError('Unexpected select feature strategy: '\n                                 f'{strategy}')\n            image_features = self.model.multi_modal_projector(image_features)\n            image_features = torch.split(image_features, image_num_patches, dim=0)\n            image_features, feature_lens = self.model.pack_image_features(\n                image_features,\n                image_sizes,\n                vision_feature_select_strategy=strategy,\n                image_newline=self.model.image_newline,\n            )\n            image_features = torch.split(image_features, feature_lens.cpu().numpy().tolist(), dim=0)\n            outputs.extend(image_features)\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n"
  },
  {
    "path": "lmdeploy/vl/model/minicpmv.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport itertools\nimport warnings\nfrom typing import Dict, List\n\nimport torch\nfrom PIL.Image import Image\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\nlogger = get_logger('lmdeploy')\n\n\n@VISION_MODELS.register_module()\nclass MiniCPMVModel(VisionModel):\n    \"\"\"MiniCPMV vision model.\"\"\"\n\n    _arch = 'MiniCPMV'\n\n    def __init__(self,\n                 model_path: str,\n                 with_llm: bool = False,\n                 max_memory: Dict[int, int] = None,\n                 hf_config: AutoConfig = None,\n                 backend: str = ''):\n        super().__init__(model_path, with_llm, max_memory, hf_config, backend)\n        if not hasattr(self.hf_config, 'version'):\n            raise ValueError('Can not find `version` in config.json. 
'\n                             'Please checkout the latest model')\n        version = str(self.hf_config.version)\n        if version not in ['2.5', '2.6']:\n            raise ValueError(f'Only support v2.5 and v2.6, but got version {version}')\n        self.version = version\n\n    def build_preprocessor(self):\n        from transformers import AutoProcessor\n        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)\n        self.image_processor = self.processor.image_processor\n        self._preprocess_func = (self._preprocess_v2_5 if self.version == '2.5' else self._preprocess_v2_6)\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import init_empty_weights\n        with init_empty_weights(), warnings.catch_warnings():\n            warnings.simplefilter('ignore')\n            config = self.hf_config\n            assert config.slice_mode is True, 'only support slice mode'\n            config.quantization_config = {}  # disable vision part quantization\n            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)\n        self.vl_model = model\n        if not self.with_llm:\n            del model.llm\n\n        from accelerate import load_checkpoint_and_dispatch\n        with disable_logging():\n            load_checkpoint_and_dispatch(\n                model=model,\n                max_memory=self.max_memory,\n                checkpoint=self.model_path,\n                device_map='auto' if not self.with_llm else {'': 'cpu'},\n                no_split_module_classes=['Idefics2EncoderLayer', 'Resampler', 'SiglipEncoderLayer'],\n                dtype=torch.half)\n\n        model.resampler.pos_embed = model.resampler.pos_embed.to(device=model.resampler.proj.device)\n        self.config = config\n        self.model = model.eval()\n\n    def 
_get_slice_image(self, image: Image):\n        slice_images = []\n        source_image, patches, best_grid = self.image_processor.slice_image(image)\n        slice_images.append(source_image)\n        if len(patches) > 0:\n            for i in range(len(patches)):\n                for j in range(len(patches[0])):\n                    slice_images.append(patches[i][j])\n        return slice_images, best_grid\n\n    def _reshape_by_patch(self, slice_images):\n        tgt_sizes = []\n        patches = []\n        for slice_image in slice_images:\n            slice_image = self.model.transform(slice_image)\n            H, W = slice_image.shape[1:]\n            slice_image = slice_image.numpy()\n            slice_image = self.image_processor.reshape_by_patch(slice_image)\n            slice_image = torch.from_numpy(slice_image)\n            patches.append(slice_image)\n            H //= self.config.patch_size\n            W //= self.config.patch_size\n            tgt_sizes.append(torch.Tensor([H, W]).type(torch.int32))\n        return patches, tgt_sizes\n\n    def _preprocess_v2_5(self, image: Image, params: Dict = None) -> Dict:\n        \"\"\"Image preprocessing for MiniCPM-Llama3-V-2_5.\"\"\"\n        slice_images, best_grid = self._get_slice_image(image)\n        # pixel_values, tgt_sizes are list of torch tensors\n        pixel_values, tgt_sizes = self._reshape_by_patch(slice_images)\n        num_patches = len(pixel_values)\n        return dict(\n            pixel_values=pixel_values,  # a list\n            tgt_sizes=tgt_sizes,  # a list\n            best_grid=best_grid,\n            num_patches=num_patches,\n            image_tokens=1,\n            image_token_id=self.image_token_id)\n\n    def _preprocess_v2_6(self, image: Image, params: Dict = None) -> Dict:\n        \"\"\"Image preprocessing for MiniCPM-V-2_6.\"\"\"\n        max_slice_nums = self.image_processor.max_slice_nums\n        use_image_id = self.image_processor.use_image_id\n        max_slice_nums = 
params.get('max_slice_nums', max_slice_nums)\n        use_image_id = params.get('use_image_id', use_image_id)\n        outputs = self.image_processor(image, max_slice_nums=max_slice_nums)\n        pixel_values = outputs['pixel_values'][0]\n        num_patches = len(pixel_values)\n        pixel_values = [torch.as_tensor(x) for x in pixel_values]\n        tgt_sizes = outputs['tgt_sizes'][0]\n        tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes]\n        grid = self.image_processor.get_sliced_grid(image_size=image.size, max_slice_nums=max_slice_nums)\n        return dict(\n            pixel_values=pixel_values,  # a list\n            tgt_sizes=tgt_sizes,  # a list\n            best_grid=grid,\n            num_patches=num_patches,\n            image_tokens=1,\n            image_token_id=self.image_token_id,\n            use_image_id=use_image_id)\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        for i, message in enumerate(messages):\n            if message['role'] != 'user' or not isinstance(message['content'], List):\n                continue\n            # collect results per message so that images from earlier turns\n            # are not shared with, or duplicated into, later messages\n            outputs = []\n            for item in message['content']:\n                if item['type'] == 'image':\n                    image = item['image'].convert('RGB')\n                    params = {k: v for k, v in item.items() if k not in {'type', 'image'}}\n                    result = self._preprocess_func(image, params)\n                    outputs.append(result)\n            messages[i].update(dict(preprocess=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        # collect preprocess results into a list\n        inputs = []\n        inputs = [x['preprocess'] for x in messages if 'preprocess' in x.keys()]\n        # flatten the list\n        inputs = list(itertools.chain(*inputs))\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            tgt_sizes = [x['tgt_sizes'] for x in inputs[idx:idx + max_batch_size]]\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            num_patches = [x['num_patches'] for x in inputs[idx:idx + max_batch_size]]\n            # flatten the list\n            tgt_sizes = list(itertools.chain(*tgt_sizes))\n            pixel_values = list(itertools.chain(*pixel_values))\n            pixel_values = [x.to(dtype=torch.half, device=self.model.device) for x in pixel_values]\n            pixel_values = [x.flatten(end_dim=1).permute(1, 0) for x in pixel_values]\n            pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True, padding_value=0.0)\n            B, L, _ = pixel_values.shape\n            pixel_values = pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)\n            tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)\n            max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])\n            patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=self.model.device)\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            if self.version == '2.5':\n                for j in range(B):\n                    patch_attn_mask[j, :tgt_sizes[j][0] * tgt_sizes[j][1]] = True\n                embeddings = 
self.model.vpm(pixel_values.type(torch.half),\n                                            patch_attention_mask=patch_attn_mask).last_hidden_state\n            else:\n                for j in range(B):\n                    patch_attn_mask[j, 0, :tgt_sizes[j][0] * tgt_sizes[j][1]] = True\n                embeddings = self.model.vpm(pixel_values.type(torch.half),\n                                            patch_attention_mask=patch_attn_mask,\n                                            tgt_sizes=tgt_sizes).last_hidden_state\n\n            embeddings = self.model.resampler(embeddings, tgt_sizes)\n            embeddings = torch.split(embeddings, num_patches, 0)\n            for embedding in embeddings:\n                embedding = embedding.split(1, dim=0)\n                outputs.extend([x.squeeze() for x in embedding])\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n\n    def proc_messages(self, messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        idx = 0\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            if 'preprocess' not in message.keys():\n                continue\n            prompts = []\n            for x in message['preprocess']:\n                prompt = f'<image>{IMAGE_TOKEN}</image>'\n                if x.get('use_image_id', False):\n                    prompt = f'<image_id>{idx}</image_id>' + prompt\n                    idx += 1\n                grid = x['best_grid']\n                if grid is not None:\n                    if self.version == '2.5':\n                        slice = '\\n'.join([f'<image>{IMAGE_TOKEN}</image>' * grid[0]] * grid[1])\n                        prompt = f'{prompt}<slice>{slice}</slice>\\n'\n                    elif self.version == 
'2.6':\n                        slice = '\\n'.join([f'<slice>{IMAGE_TOKEN}</slice>' * grid[0]] * grid[1])\n                        prompt = prompt + slice\n                        prompt += '\\n'\n                else:\n                    prompt = (prompt + '\\n' if self.version == '2.6' else prompt)\n                prompts.append(prompt)\n            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']\n            prompt = ''.join(prompts) + content[0]\n            prompt_messages.append(dict(role='user', content=prompt))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/mllama.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Dict, List\n\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\n\ndef check_transformers():\n    try:\n        from transformers import MllamaForConditionalGeneration  # noqa: F401\n    except ImportError:\n        raise ImportError('please install latest transformers by '\n                          'pip install git+https://github.com/huggingface/transformers.git')\n\n\n@VISION_MODELS.register_module()\nclass MllamaVLModel(VisionModel):\n    \"\"\"llama3.2 model.\"\"\"\n\n    _arch = 'MllamaForConditionalGeneration'\n\n    def build_preprocessor(self):\n        from transformers import AutoProcessor\n        self.processor = AutoProcessor.from_pretrained(self.model_path)\n        self.image_token_id = 128256\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to the spec of `super().preprocess`\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n            results = self.processor.image_processor(images=image, return_tensors='pt')\n            results.update(image_size=image.size, image_tokens=1, image_token_id=self.image_token_id)\n            outputs.append(results)\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    def build_model(self):\n        check_transformers()\n        if self.with_llm:\n            from transformers import MllamaForConditionalGeneration\n            model = MllamaForConditionalGeneration.from_pretrained(self.model_path, device_map='cpu')\n            self.vl_model = model\n        else:\n            raise NotImplementedError('turbomind has not supported mllama yet')\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages 
= []\n        IMAGE_TOKEN = '<|image|>'\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['images', 'preprocess', 'forward']:\n                continue\n            n_images = len([1 for x in message['content'] if x['type'] == 'image'])\n            content = [item['text'] for item in message['content'] if item['type'] == 'text']\n            prompt = (IMAGE_TOKEN) * n_images + content[0]\n            prompt_messages.append(dict(role='user', content=prompt))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/molmo.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Dict, List\n\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoProcessor\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\nlogger = get_logger('lmdeploy')\n\n\n@VISION_MODELS.register_module()\nclass MolmoVisionModel(VisionModel):\n    \"\"\"Molmo's vision model.\"\"\"\n\n    _arch = 'MolmoForCausalLM'\n\n    def build_preprocessor(self):\n        self.processor = AutoProcessor.from_pretrained(self.model_path,\n                                                       trust_remote_code=True,\n                                                       torch_dtype=torch.half,\n                                                       device_map='auto')\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n        with init_empty_weights():\n            model = AutoModelForCausalLM.from_config(self.hf_config, trust_remote_code=True)\n\n            self.vl_model = model\n            if not self.with_llm:\n                # Remove nn modules other than embedding from the LLM model\n                for key in ['emb_drop', 'ln_f', 'blocks', 'ff_out']:\n                    del model.model.transformer[key]\n            self.token_embedding = model.model.transformer.wte\n\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                         checkpoint=self.model_path,\n                                         device_map='auto' if not self.with_llm else {'': 'cpu'},\n                                         max_memory=self.max_memory,\n                                         
no_split_module_classes=['ResidualAttentionBlock', 'Embedding'],\n                                         dtype=torch.half)\n\n        # We need eval mode to freeze the weights in model, thus,\n        # avoid randomness in inference.\n        self.model = model.eval()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        for i, message in enumerate(messages):\n            if not isinstance(message['content'], List):\n                continue\n            images = [x['image'] for x in message['content'] if x['type'] == 'image']\n            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']\n            prompt = f' User: {content[0]}'\n            tokens = self.processor.tokenizer.encode(prompt, add_special_tokens=False)\n            # preprocess images. The output is a dict, which is\n            # {\n            #     'input_ids': torch.Tensor,\n            #     'images': torch.Tensor, # (n_patch, d_model)\n            #     'image_input_idx': torch.Tensor, # (n_patch, d_model)\n            #     'image_masks': torch.Tensor,  # (n_patch, d_model)\n            # }\n            result = self.processor.process(images=images, tokens=tokens)\n            # remove the bos from input_ids which is prepended by molmo's\n            # processor\n            input_ids = result['input_ids'][1:]\n            result.update(input_ids=input_ids)\n            messages[i].update(preprocess=result)\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        for i, message in enumerate(messages):\n            if 'preprocess' not in message.keys():\n                continue\n            inputs = message['preprocess']\n            # get input_ids of embedding\n            inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}\n            input_ids = inputs['input_ids']\n            # (batch_size, num_image, num_patch, d_model)\n            images = inputs['images']\n            # (batch_size, num_image, num_patch)\n            image_input_idx = inputs['image_input_idx']\n            image_masks = inputs['image_masks']\n            batch_size, seq_len = input_ids.size()\n            assert batch_size == 1\n            input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)\n            embeddings = self.model.model.transformer.wte(input_ids)\n            images = images.to(self.model.dtype)\n            image_masks = image_masks.to(self.model.dtype)\n            logger.info(f'vision forward shape: {images.shape}')\n            image_features, _ = self.model.model.vision_backbone(images, image_masks)\n            num_image, num_patch = image_features.shape[1:3]\n            assert image_input_idx.shape == (batch_size, num_image, num_patch)\n\n            # insert the image feature into the embedding.\n            image_features = image_features.view(batch_size, num_image * num_patch, -1)\n            image_input_idx = image_input_idx.view(batch_size, num_image * num_patch)\n            valid = image_input_idx >= 0\n            batch_idx = torch.arange(batch_size, device=embeddings.device)\n            batch_idx = torch.tile(batch_idx[:, None], [1, 
image_features.shape[1]])\n            image_features = image_features.to(embeddings.device)\n            # Since we remove bos_id from input_ids during `preprocess`,\n            # the index `image_input_idx[valid]` should be shifted left\n            # by subtracting 1\n            index = image_input_idx[valid] - 1\n            embeddings[batch_idx[valid], index] += image_features[valid]\n            assert embeddings.shape[:2] == (batch_size, seq_len)\n            messages[i].update(dict(forward=dict(input_ids=input_ids.flatten(), embeddings=embeddings)))\n        return messages\n\n    @staticmethod\n    def proc_messages(messages):\n        prompt = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for message in messages:\n            role, content = message['role'], message['content']\n            if isinstance(content, List):\n                n_images = len([1 for x in content if x['type'] == 'image'])\n                content = [x['text'] for x in content if x['type'] == 'text']\n                prompt.append(' User: ' + (IMAGE_TOKEN + '\\n') * n_images + content[0])\n            else:\n                if role == 'user':\n                    prompt.append(f' User: {content}')\n                elif role == 'assistant':\n                    prompt.append(f' Assistant:{content}')\n                else:\n                    assert 0, f'molmo does not support role {role}, message is {message}'  # noqa\n        prompt.append(' Assistant:')\n        return ''.join(prompt)\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        assert 0, 'molmo is not supported by pytorch engine'\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        # results is a list of tuple(input_ids, embeddings)\n        results = []\n        # Prepend BOS\n        # qwen2 and olmo do not have a BOS, and instead use EOS as a generic\n        # separator token.\n        bos = 
(self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id)\n        results.append(([bos], None))\n\n        for i, message in enumerate(messages):\n            prompt = ''\n            role, content = message['role'], message['content']\n            if isinstance(content, List):\n                forward_result = message.pop('forward')\n                input_ids = forward_result['input_ids']\n                embeddings = forward_result['embeddings']\n                results.append((input_ids.tolist(), embeddings))\n            else:\n                if role == 'user':\n                    prompt = f' User: {content}'\n                elif role == 'assistant':\n                    prompt = f' Assistant:{content}'\n                else:\n                    assert 0, f'molmo does not support role {role}, message is {message}'  # noqa\n            if i == len(messages) - 1:\n                # the last message\n                assert role == 'user', f'the role of last message is expected to be user, but got {role}'  # noqa\n                prompt += ' Assistant:'\n            if prompt:\n                input_ids = self.processor.tokenizer.encode(prompt, add_special_tokens=False)\n                results.append((input_ids, None))\n\n        # concat input_ids from results, calculate the range in the input_ids\n        # where embeddings will be copied to\n        input_ids = []\n        input_embeddings = []\n        input_embedding_ranges = []\n        start = 0\n        for _input_ids, _embeddings in results:\n            if _embeddings is not None:\n                input_embeddings.append(_embeddings.cpu())\n                end = start + len(_input_ids)\n                input_embedding_ranges.append((start, end))\n            input_ids += _input_ids\n            start += len(_input_ids)\n\n        prompt = self.proc_messages(messages)\n        return dict(prompt=prompt,\n                    input_ids=input_ids,\n                    
input_embeddings=input_embeddings,\n                    input_embedding_ranges=input_embedding_ranges)\n"
  },
  {
    "path": "lmdeploy/vl/model/phi3_vision.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Dict, List\n\nfrom transformers import AutoProcessor\n\nfrom lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel\n\n\n@VISION_MODELS.register_module()\nclass Phi3VisionModel(LlavaHfVisionModel):\n    \"\"\"Phi3-vision model.\"\"\"\n\n    _arch = 'Phi3VForCausalLM'\n\n    def build_preprocessor(self):\n        processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)\n        if hasattr(processor, 'tokenizer'):\n            del processor.tokenizer\n            processor.tokenizer = None\n        self.processor = processor\n\n    def build_model(self):\n        if self.with_llm:\n            from transformers import AutoModelForCausalLM\n            self.vl_model = AutoModelForCausalLM.from_pretrained(self.model_path,\n                                                                 device_map='cpu',\n                                                                 trust_remote_code=True)\n        else:\n            raise NotImplementedError('turbomind has not supported phi3v yet')\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            result = self.processor.image_processor([image], return_tensors='pt')\n            image_tokens = result['num_img_tokens']\n            result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))\n            outputs.append(result)\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n"
  },
  {
    "path": "lmdeploy/vl/model/qwen.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nfrom typing import Dict, List\n\nimport torch\nfrom transformers import AutoModelForCausalLM\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\nlogger = get_logger('lmdeploy')\n\n\n@VISION_MODELS.register_module()\nclass QwenVisionModel(VisionModel):\n    \"\"\"Qwen vision model.\"\"\"\n\n    _arch = 'QWenLMHeadModel'\n\n    def build_preprocessor(self):\n        from torchvision import transforms\n        from torchvision.transforms import InterpolationMode\n        mean = (0.48145466, 0.4578275, 0.40821073)\n        std = (0.26862954, 0.26130258, 0.27577711)\n        image_size = self.hf_config.visual['image_size']\n        self.image_transform = transforms.Compose([\n            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=mean, std=std),\n        ])\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import init_empty_weights\n        with init_empty_weights():\n            config = self.hf_config\n            config.quantization_config = {}  # disable vision part quantization\n            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)\n            self.vl_model = model\n            if not self.with_llm:\n                del model.lm_head\n                for key in ['wte', 'h', 'ln_f']:\n                    setattr(model.transformer, key, None)\n\n        from accelerate.utils import get_balanced_memory, infer_auto_device_map\n        max_memory = get_balanced_memory(model,\n                                         max_memory=self.max_memory,\n                                         
dtype=torch.half,\n                                         no_split_module_classes=['VisualAttentionBlock', 'Resampler'])\n        device_map = infer_auto_device_map(model,\n                                           no_split_module_classes=['VisualAttentionBlock', 'Resampler'],\n                                           max_memory=max_memory,\n                                           dtype=torch.half)\n        same_device_keys = [('transformer.visual.conv1', 'transformer.visual.positional_embedding'),\n                            ('transformer.visual.ln_post', 'transformer.visual.proj')]\n        for (a, b) in same_device_keys:\n            if a in device_map and b in device_map:\n                device_map[b] = device_map[a]\n\n        from accelerate import load_checkpoint_and_dispatch\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                         checkpoint=self.model_path,\n                                         device_map=device_map if not self.with_llm else {'': 'cpu'},\n                                         no_split_module_classes=['VisualAttentionBlock'],\n                                         dtype=torch.half)\n\n        self.model = model.transformer.visual.eval()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n            pixel_values = self.image_transform(image)\n            outputs.append(\n                dict(pixel_values=pixel_values,\n                     image_size=image.size,\n                     image_tokens=256,\n                     image_token_id=self.image_token_id))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: 
List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n            pixel_values = torch.stack(pixel_values, dim=0)\n            logger.info(f'vision forward shape: {pixel_values.shape}')\n            feats = self.model(pixel_values)\n            feats = torch.split(feats, 1, dim=0)\n            outputs.extend([x.squeeze() for x in feats])\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n                continue\n            elif message['role'] in ['images', 'preprocess', 'forward']:\n                continue\n            n_images = len([1 for x in message['content'] if x['type'] == 'image'])\n            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']\n            prompt = content[0]\n            if IMAGE_TOKEN in prompt:\n                pass\n            else:\n                prompt = ''.join([f'Picture {str(i)}:{IMAGE_TOKEN}\\n' for i in range(n_images)]) + prompt\n            
prompt_messages.append(dict(role='user', content=prompt))\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/qwen2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport torch\n\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\nfrom lmdeploy.vl.model.utils import disable_logging\n\n\ndef check_qwen_vl_deps_install():\n    \"\"\"Check qwen_vl_utils.\"\"\"\n    try:\n        import qwen_vl_utils  # noqa: F401\n    except ImportError:\n        raise ImportError('please install qwen_vl_utils by `pip install qwen_vl_utils`'  # noqa: E501\n                          )\n    try:\n        from transformers import Qwen2VLForConditionalGeneration  # noqa: F401\n    except ImportError:\n        raise ImportError('please install latest transformers by '\n                          'pip install git+https://github.com/huggingface/transformers.git')\n\n\n@VISION_MODELS.register_module()\nclass Qwen2VLModel(VisionModel):\n    \"\"\"Qwen2VL model.\"\"\"\n\n    _arch = ['Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration']\n\n    def build_preprocessor(self):\n        check_qwen_vl_deps_install()\n        from transformers import AutoProcessor\n        self.processor = AutoProcessor.from_pretrained(self.model_path)\n        tokenizer = self.processor.tokenizer\n        self.image_token = self.processor.image_token\n        self.image_token_id = tokenizer.encode(self.image_token)[-1]\n\n    def preprocess(self, messages: list[dict]) -> list[dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        from qwen_vl_utils import process_vision_info\n\n        images = self.collect_multimodal_items(messages)\n        optional_keys = {'resized_height', 'resized_width', 'min_pixels', 'max_pixels'}\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n\n            item = dict(type='image', image=image)\n            item.update({key: params[key] for key in params.keys() if key in optional_keys})\n            image_inputs, _ = process_vision_info([dict(content=[item])])\n           
 result = self.processor.image_processor(images=image_inputs, return_tensors='pt')\n            merge_length = self.processor.image_processor.merge_size**2\n            image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length\n            result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))\n            outputs.append(result)\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    def build_model(self):\n        check_qwen_vl_deps_install()\n        arch = self.hf_config.architectures[0]\n        if arch == 'Qwen2VLForConditionalGeneration':\n            from transformers import Qwen2VLForConditionalGeneration as AutoModelCls\n        elif arch == 'Qwen2_5_VLForConditionalGeneration':\n            from transformers import Qwen2_5_VLForConditionalGeneration as AutoModelCls\n        else:\n            raise ValueError(f'Unsupported arch={arch}')\n\n        if self.with_llm:\n            self.vl_model = AutoModelCls.from_pretrained(self.model_path, device_map='cpu')\n        else:\n            from accelerate import init_empty_weights\n            with init_empty_weights():\n                config = self.hf_config\n                # disable accelerate check_tied_parameters_in_config for Qwen2-VL-2B-Instruct\n                config.tie_word_embeddings = False\n                if hasattr(config, 'text_config'):\n                    config.text_config.tie_word_embeddings = False\n                model = AutoModelCls._from_config(config)\n                model.visual = model.model.visual\n                del model.model\n                del model.lm_head\n                model.half()\n\n            from accelerate import load_checkpoint_and_dispatch\n            with disable_logging():\n                load_checkpoint_and_dispatch(model=model,\n                                             checkpoint=self.model_path,\n                                             
device_map='auto' if not self.with_llm else {'': 'cpu'},\n                                             max_memory=self.max_memory,\n                                             no_split_module_classes=['Qwen2VLVisionBlock', 'Qwen2_5_VLVisionBlock'],\n                                             dtype=torch.half)\n            self.model = model.eval()\n\n    @torch.no_grad()\n    def forward(self, messages: list[dict], max_batch_size: int = 1) -> list[dict]:\n        \"\"\"Extract image feature. ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(list[dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0]\n        dtype = torch.half\n        device = next(self.model.visual.parameters()).device\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            pixel_values = [x['pixel_values'].type(dtype) for x in inputs[idx:idx + max_batch_size]]\n            image_grid_thw = [x['image_grid_thw'] for x in inputs[idx:idx + max_batch_size]]\n            pixel_values = torch.cat(pixel_values, dim=0).to(device)\n            image_grid_thw = torch.cat(image_grid_thw, dim=0).to(device)\n            image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)\n            if hasattr(image_embeds, 'pooler_output'):\n                # in transformers >= 5.0.0, the type of image_embeds is `BaseModelOutputWithPooling`\n                # rather than torch.Tensor\n                image_embeds = image_embeds.pooler_output\n            merge_length = self.processor.image_processor.merge_size**2\n            split_size = image_grid_thw.prod(dim=1) // merge_length\n            image_embeds = image_embeds.split(split_size.tolist())\n            
outputs.extend(image_embeds)\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n\n    def proc_messages(self, messages, chat_template, sequence_start, chat_template_kwargs=None):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        chat_template_kwargs = chat_template_kwargs or {}\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]\n        if VisionModel.IMAGE_TOKEN_included(messages):\n            # backward compatibility\n            for message in messages:\n                role, content = message['role'], message['content']\n                if role != 'user' or isinstance(content, str):\n                    prompt_messages.append(message)\n                    continue\n                content = [x['text'] for x in content if x['type'] == 'text']\n                prompt = ''.join(content)\n                prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{self.image_token}<|vision_end|>')\n                prompt_messages.append(dict(role='user', content=prompt))\n        else:\n            for message in messages:\n                role, content = message['role'], message['content']\n                if role != 'user' or isinstance(content, str):\n                    prompt_messages.append(message)\n                    continue\n                _content = []\n                for item in content:\n                    if item['type'] == 'text':\n                        _content.append(item['text'])\n                    elif item['type'] in ['image', 'image_url']:\n                        _content.append(f'<|vision_start|>{self.image_token}<|vision_end|>')\n                    else:\n                        raise ValueError(f'Unsupported message type: {item[\"type\"]}')\n                message = dict(role=role, content=''.join(_content))\n                prompt_messages.append(message)\n        prompt = 
chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, self.image_token\n\n    @staticmethod\n    def get_mrope_info(seq_len: int,\n                       grid_thws: list[tuple[int, int, int]] = None,\n                       ranges: list[tuple[int, int]] = None):\n        mrope_position_ids = [torch.arange(ranges[0][0]).expand(3, -1)]\n        st_idx = ranges[0][0]\n        for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)):\n            llm_grid_t, llm_grid_h, llm_grid_w = grid_thw\n            llm_grid_h //= 2\n            llm_grid_w //= 2\n            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()\n            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n            mrope_position_ids.append(torch.stack([t_index, h_index, w_index]) + st_idx)\n            st_idx += max(llm_grid_h, llm_grid_w)\n            if i < len(ranges) - 1:\n                text_len = ranges[i + 1][0] - ranges[i][1]\n            else:\n                text_len = seq_len - embedding_range[1]\n            mrope_position_ids.append(torch.arange(text_len).expand(3, -1) + st_idx)\n            st_idx += text_len\n        mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)\n        mrope_position_delta = torch.tensor([st_idx - seq_len], dtype=torch.long)\n        return mrope_position_ids, mrope_position_delta\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):\n        \"\"\"Return the information needed by the pytorch engine.\"\"\"\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, 
messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs)\n        info = super().to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0]\n        grid_thws = [x['image_grid_thw'].tolist()[0] for x in inputs]\n        seq_len = len(info['input_ids'])\n        ranges = info['input_embedding_ranges']\n        mrope_position_ids, mrope_position_delta = self.get_mrope_info(seq_len, grid_thws, ranges)\n        meta = dict(mrope_position_ids=mrope_position_ids, mrope_position_delta=mrope_position_delta)\n        info.update(dict(input_meta=meta))\n        return info\n"
  },
  {
    "path": "lmdeploy/vl/model/qwen3.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, List\n\nimport torch\nfrom transformers import AutoProcessor\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.constants import Modality\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_transformers():\n    try:\n        from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration  # noqa: F401\n    except ImportError:\n        raise ImportError('please install latest transformers by '\n                          'pip install git+https://github.com/huggingface/transformers.git')\n\n\n@VISION_MODELS.register_module()\nclass Qwen3VLModel(VisionModel):\n    \"\"\"Qwen3VL model.\"\"\"\n\n    _arch = ['Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration']\n\n    def build_preprocessor(self):\n        check_transformers()\n        self.processor = AutoProcessor.from_pretrained(self.model_path)\n\n        # image tokens\n        self.image_token = self.processor.image_token\n        self.image_token_id = self.processor.image_token_id\n\n        # video tokens\n        self.video_token = self.processor.video_token\n        self.video_token_id = self.processor.video_token_id\n\n        # vision start and end tokens\n        self.vision_start_token = self.processor.vision_start_token\n        self.vision_end_token = self.processor.vision_end_token\n\n    def get_processor_args(self, mm_processor_kwargs: Dict[str, Any] | None = None):\n        min_pixels = self.processor.image_processor.size['shortest_edge']\n        max_pixels = self.processor.image_processor.size['longest_edge']\n\n        if mm_processor_kwargs is None:\n            return min_pixels, max_pixels\n\n        input_min_pixels = mm_processor_kwargs.get('min_pixels', None)\n        input_max_pixels = mm_processor_kwargs.get('max_pixels', None)\n\n        # boundary check for min_pixels and 
max_pixels\n        if input_min_pixels is None:\n            if input_max_pixels is not None:\n                # only max_pixels is given in the input\n                if input_max_pixels < min_pixels:\n                    logger.warning(\n                        f'input max_pixels {input_max_pixels} < default min_pixels {min_pixels}, fall back to default.')\n                    return min_pixels, max_pixels\n                max_pixels = input_max_pixels\n        else:\n            if input_max_pixels is None:\n                # only min_pixels is given in the input\n                if input_min_pixels > max_pixels:\n                    logger.warning(\n                        f'input min_pixels {input_min_pixels} > default max_pixels {max_pixels}, fall back to default.')\n                    return min_pixels, max_pixels\n            else:\n                if input_min_pixels > input_max_pixels:\n                    logger.warning(\n                        f'input min_pixels {input_min_pixels} > max_pixels {input_max_pixels}, fall back to default.')\n                    return min_pixels, max_pixels\n                max_pixels = input_max_pixels\n            min_pixels = input_min_pixels\n\n        return min_pixels, max_pixels\n\n    def _preprocess_image(self,\n                          data: List[Any],\n                          params: Dict[str, Any],\n                          mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]:\n\n        image = data.convert('RGB')\n        min_pixels, max_pixels = self.get_processor_args(mm_processor_kwargs)\n\n        result = self.processor.image_processor(images=image,\n                                                size={\n                                                    'shortest_edge': min_pixels,\n                                                    'longest_edge': max_pixels\n                                                },\n                                                return_tensors='pt')\n   
     merge_length = self.processor.image_processor.merge_size**2\n        image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length\n        result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))\n        return result\n\n    def _preprocess_video(self,\n                          data: List[Any],\n                          params: Dict[str, Any],\n                          mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]:\n\n        # TODO: zhouxinyu, apply transformers smart_resize using per-request kwargs\n        metadata = params['video_metadata']\n        video_kwargs = dict(return_metadata=True,\n                            do_resize=True,\n                            do_sample_frames=False,\n                            video_metadata=metadata,\n                            return_tensors='pt')\n        result = self.processor.video_processor(videos=data, **video_kwargs)\n        video_grid_thw = result['video_grid_thw']\n\n        merge_length = self.processor.video_processor.merge_size**2\n        if metadata.get('fps') is None:\n            logger.warning_once('Qwen3VL: fps not found, defaulting to 24.')\n            # assign directly: indexing metadata['fps'] here would raise KeyError when the key is absent\n            metadata['fps'] = 24\n\n        # if timestamps are not provided, calculate them\n        curr_timestamp = self.processor._calculate_timestamps(\n            metadata['frames_indices'],\n            metadata['fps'],\n            self.processor.video_processor.merge_size,\n        )\n\n        frame_seqlen = video_grid_thw[0][1:].prod() // merge_length\n        result.update(curr_timestamp=curr_timestamp, frame_seqlen=frame_seqlen, video_token_id=self.video_token_id)\n        return result\n\n    def preprocess(self, messages: List[Dict], mm_processor_kwargs: Dict[str, Any] | None = None) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        outputs = []\n        self.contains_video_input = False\n\n        mm_items = 
self.collect_multimodal_items(messages)\n        for modality, data, params in mm_items:\n            result = {}\n            if modality == Modality.IMAGE:\n                result = self._preprocess_image(data, params, mm_processor_kwargs)\n            elif modality == Modality.VIDEO:\n                self.contains_video_input = True\n                result = self._preprocess_video(data, params, mm_processor_kwargs)\n\n            result.update(modality=modality)\n            outputs.append(result)\n\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    def proc_messages(self, messages, chat_template, sequence_start, chat_template_kwargs=None):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        chat_template_kwargs = chat_template_kwargs or {}\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]\n        if VisionModel.IMAGE_TOKEN_included(messages):\n            # backward compatibility\n            for message in messages:\n                role, content = message['role'], message['content']\n                if role != 'user' or isinstance(content, str):\n                    prompt_messages.append(message)\n                    continue\n                content = [x['text'] for x in content if x['type'] == 'text']\n                prompt = ''.join(content)\n                prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{self.image_token}<|vision_end|>')\n                prompt_messages.append(dict(role='user', content=prompt))\n        else:\n            prompt_messages = messages\n        prompt = chat_template.messages2prompt(prompt_messages, sequence_start, **chat_template_kwargs)\n        return prompt, None\n\n    def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequence_start):\n        \"\"\"Pack the video input to the compatible format with pytorch\n        
engine.\"\"\"\n\n        # collect all preprocessing result from messages\n        preps = [x['content'] for x in messages if x['role'] == 'preprocess']\n        assert len(preps) == 1\n        preps = preps[0]\n\n        # split prompt into segments and validate data\n        segs = prompt.split(self.vision_start_token + self.video_token + self.vision_end_token)\n        assert len(segs) == len(preps) + 1, (f'the number of {self.video_token} is not equal '\n                                             f'to input videos, {len(segs) - 1} vs {len(preps)}')\n\n        # calculate the video token offset for each video\n        input_ids = []\n        for i, seg in enumerate(segs):\n            if i > 0 and i <= len(preps):\n                preps[i - 1].update(offset=len(input_ids))\n                frame_seqlen = preps[i - 1]['frame_seqlen']\n                assert self.video_token_id == preps[i - 1]['video_token_id']\n\n                video_grid_thw = preps[i - 1]['video_grid_thw']\n                curr_timestamp = preps[i - 1]['curr_timestamp']\n\n                # update prompt with timestamp index tokens and video pad tokens\n                video_placeholder = ''\n                for frame_idx in range(video_grid_thw[0][0]):\n                    curr_time = curr_timestamp[frame_idx]\n                    video_placeholder += f'<{curr_time:.1f} seconds>'\n                    video_placeholder += (self.vision_start_token + '<|placeholder|>' * frame_seqlen +\n                                          self.vision_end_token)\n\n                video_placeholder = video_placeholder.replace('<|placeholder|>', self.video_token)\n                video_token_ids = tokenizer.encode(video_placeholder)\n                input_ids.extend(video_token_ids)\n\n                preps[i - 1].update(video_tokens=len(video_token_ids))\n\n            token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))\n            input_ids.extend(token_ids)\n\n        return 
dict(prompt=prompt, input_ids=input_ids, multimodal=preps)\n\n    def to_pytorch(self,\n                   messages,\n                   chat_template,\n                   tokenizer,\n                   sequence_start,\n                   chat_template_kwargs: Dict | None = None,\n                   **kwargs):\n        \"\"\"Return to the information needed by pytorch engine.\"\"\"\n        prompt, _ = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs)\n\n        if self.contains_video_input:\n            return self.to_pytorch_aux_video(messages, prompt, self.video_token, tokenizer, sequence_start)\n        else:\n            return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start)\n\n    def build_model(self):\n        # TODO: implement for turbomind\n        pass\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        # TODO: implement for turbomind\n        pass\n\n    def to_turbomind(self,\n                     messages,\n                     chat_template,\n                     tokenizer,\n                     sequence_start,\n                     chat_template_kwargs: Dict | None = None,\n                     **kwargs):\n        # TODO: implement for turbomind\n        pass\n"
  },
  {
    "path": "lmdeploy/vl/model/qwen3_5.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom transformers import AutoProcessor\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS\n\nfrom .qwen3 import Qwen3VLModel\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_transformers():\n    try:\n        from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5MoeForConditionalGeneration  # noqa: F401\n    except ImportError:\n        raise ImportError('please install latest transformers by '\n                          'pip install git+https://github.com/huggingface/transformers.git')\n\n\n@VISION_MODELS.register_module()\nclass Qwen3_5Model(Qwen3VLModel):\n    \"\"\"Qwen3_5 model.\"\"\"\n\n    _arch = ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration']\n\n    def build_preprocessor(self):\n        check_transformers()\n\n        self.processor = AutoProcessor.from_pretrained(self.model_path)\n\n        # image tokens\n        self.image_token = self.processor.image_token\n        self.image_token_id = self.processor.image_token_id\n\n        # video tokens\n        self.video_token = self.processor.video_token\n        self.video_token_id = self.processor.video_token_id\n\n        # vision start and end tokens\n        self.vision_start_token = self.processor.vision_start_token\n        self.vision_end_token = self.processor.vision_end_token\n"
  },
  {
    "path": "lmdeploy/vl/model/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport inspect\nfrom contextlib import contextmanager\nfrom typing import Callable, MutableSequence\n\nimport torch\n\n\n@contextmanager\ndef disable_transformers_logging():\n    import transformers\n    from transformers.utils import logging\n    previous_level = logging.get_verbosity()\n    logging.set_verbosity(transformers.logging.ERROR)\n    yield\n    logging.set_verbosity(previous_level)\n\n\n@contextmanager\ndef disable_logging():\n    import logging\n    previous_level = logging.root.manager.disable\n    logging.disable(logging.ERROR)\n    yield\n    logging.disable(previous_level)\n\n\ndef _set_func(origin_func_path: str | None, rewrite_func: Callable, origin_func: Callable = None):\n    \"\"\"Replace old function with the new function.\n\n    Args:\n        origin_func_path (str): original function path\n        rewrite_func (Callable): function to replace with\n        origin_func (Callable): function to replace\n    \"\"\"\n    # import module\n    if isinstance(origin_func_path, str):\n        split_path = origin_func_path.split('.')\n        for i in range(len(split_path), 0, -1):\n            try:\n                exec('import {}'.format('.'.join(split_path[:i])))\n                break\n            except Exception:\n                continue\n\n        origin_func = eval(origin_func_path) \\\n            if origin_func is None else origin_func\n\n    method_class = inspect.ismethod(origin_func)\n\n    # replace method\n    if not method_class:\n        import gc\n        refs = gc.get_referrers(origin_func)\n        obj_id = id(origin_func)\n        for ref in refs:\n            if isinstance(ref, dict):\n                for x, y in ref.items():\n                    if id(y) == obj_id:\n                        ref[x] = rewrite_func\n            elif isinstance(ref, MutableSequence):\n                for i, v in enumerate(ref):\n                    if id(v) == obj_id:\n               
         ref[i] = rewrite_func\n    if isinstance(origin_func_path, str):\n        exec(f'{origin_func_path} = rewrite_func')\n    elif method_class:\n        raise NotImplementedError\n\n    return origin_func\n\n\n@contextmanager\ndef rewrite_ctx(origin_func_path: list[str | Callable], rewrite_func: list[Callable]):\n    \"\"\"Rewrite context.\"\"\"\n    assert len(origin_func_path) == len(rewrite_func)\n    origin_func_list = []\n    for (func_path, dst_func) in zip(origin_func_path, rewrite_func):\n        if isinstance(func_path, Callable):\n            origin_func = _set_func(None, dst_func, func_path)\n        else:\n            origin_func = _set_func(func_path, dst_func)\n        origin_func_list.append(origin_func)\n    yield\n    for (func_path, dst_func, origin_func) in zip(origin_func_path, rewrite_func, origin_func_list):\n        if isinstance(func_path, Callable):\n            _set_func(None, origin_func, dst_func)\n        else:\n            _set_func(func_path, origin_func, dst_func)\n\n\ndef add_device_hook(module: torch.nn.Module, device: torch.device, fn: Callable = None):\n    \"\"\"Add device hook.\"\"\"\n    from accelerate.hooks import ModelHook, add_hook_to_module\n\n    class ToDevice(ModelHook):\n        \"\"\"ToDevice hook.\"\"\"\n\n        def __init__(self, device):\n            self.device = device\n\n        def post_forward(self, module, output):\n            if fn is not None:\n                output = fn(output)\n            else:\n                output = output.to(device=self.device)\n            return output\n\n    add_hook_to_module(module=module, hook=ToDevice(device=device), append=True)\n"
  },
  {
    "path": "lmdeploy/vl/model/xcomposer2.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport enum\nimport os\nimport sys\nimport warnings\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Tuple\n\nimport torch\nfrom PIL.Image import Image\nfrom transformers import AutoConfig, AutoModelForCausalLM\n\nfrom lmdeploy.utils import get_logger\nfrom lmdeploy.vl.model.base import VISION_MODELS, VisionModel\nfrom lmdeploy.vl.model.utils import add_device_hook, disable_logging, rewrite_ctx\n\nlogger = get_logger('lmdeploy')\n\n\ndef check_xcomposer_install():\n    try:\n        # WARNING! we have to do this otherwise the model_type is wrong for\n        # xcomposer2d5\n        import decord  # noqa: F401\n    except ImportError:\n        raise ImportError(\"No module named 'decord'. Please install decord by `pip install decord`\"  # noqa\n                          )\n\n\nclass ModelType(enum.Enum):\n    \"\"\"Request type.\"\"\"\n    XCOMPOSER2 = enum.auto()\n    XCOMPOSER2_4KHD = enum.auto()\n    XCOMPOSER2D5 = enum.auto()\n\n\ndef get_xcomposer_type(model_path: str) -> Tuple[ModelType, Any]:\n    \"\"\"Get xcomposer type.\"\"\"\n    from transformers.dynamic_module_utils import get_class_from_dynamic_module\n    match_modules = {\n        'ixc_utils.Image_transform': ModelType.XCOMPOSER2D5,\n        'ixc_utils.HD_transform': ModelType.XCOMPOSER2_4KHD\n    }\n    for key, value in match_modules.items():\n        try:\n            module = get_class_from_dynamic_module(key, model_path)\n            return value, module\n        except Exception:\n            pass\n    return ModelType.XCOMPOSER2, None\n\n\ndef _CLIPVisionModel_from_pretrained(vision_tower_name):\n    from transformers import CLIPVisionConfig, CLIPVisionModel\n    config = CLIPVisionConfig.from_pretrained(vision_tower_name)\n    model = CLIPVisionModel._from_config(config)\n    return model\n\n\n@contextmanager\ndef init_empty_vit(model_path):\n    \"\"\"Skip download vision model.\"\"\"\n    
origin_func_path = [\n        'transformers.CLIPVisionModel.from_pretrained',\n    ]\n    rewrite_func = [\n        _CLIPVisionModel_from_pretrained,\n    ]\n\n    model_type, _ = get_xcomposer_type(model_path)\n    if model_type == ModelType.XCOMPOSER2D5:\n        from transformers.dynamic_module_utils import get_class_from_dynamic_module\n        from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME\n        _ = get_class_from_dynamic_module('modeling_internlm_xcomposer2.get_font', model_path)\n        folder = model_path.rstrip(os.sep).split(os.sep)[-1]\n        module_path = '.'.join([TRANSFORMERS_DYNAMIC_MODULE_NAME, folder, 'modeling_internlm_xcomposer2'])\n        origin_get_font_func = getattr(sys.modules[module_path], 'get_font')\n        origin_func_path.append(origin_get_font_func)\n        rewrite_func.append(lambda: None)\n\n    with rewrite_ctx(origin_func_path, rewrite_func):\n        yield\n\n\n@VISION_MODELS.register_module()\nclass Xcomposer2VisionModel(VisionModel):\n    \"\"\"InternLM-Xcomposer2 vision model.\"\"\"\n\n    def __init__(self,\n                 model_path: str,\n                 with_llm: bool = False,\n                 max_memory: Dict[int, int] = None,\n                 hf_config: AutoConfig = None,\n                 backend: str = ''):\n        model_path = model_path.rstrip(os.sep)\n        super().__init__(model_path, with_llm, max_memory, hf_config, backend)\n        check_xcomposer_install()\n        self.model_type, self.module = get_xcomposer_type(self.model_path)\n        logger.info(f'matching type of {self.model_type}')\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config match the model.\"\"\"\n        arch = config.architectures[0] if config.architectures else None\n        target = 'InternLMXComposer2ForCausalLM'\n        if arch == target:\n            return True\n        for _, v in getattr(config, 'auto_map', {}).items():\n            if target in v:\n     
           return True\n        return False\n\n    def build_preprocessor(self):\n\n        import torchvision.transforms as transforms\n        from torchvision.transforms.functional import InterpolationMode\n\n        if self.model_type in [ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD]:\n            self.HD_transform = self.module\n            self.vis_processor = transforms.Compose([\n                transforms.ToTensor(),\n                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n            ])\n            self.preprocess_func = (self._preprocess_2d5\n                                    if self.model_type == ModelType.XCOMPOSER2D5 else self._preprocess_4khd_7b)\n        else:\n            self.vis_processor = transforms.Compose([\n                transforms.Resize((self.hf_config.img_size, self.hf_config.img_size),\n                                  interpolation=InterpolationMode.BICUBIC),\n                transforms.ToTensor(),\n                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n            ])\n            self.preprocess_func = self._preprocess_7b\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`\"\"\"\n        from accelerate import init_empty_weights\n        with init_empty_weights(), warnings.catch_warnings(), \\\n                init_empty_vit(self.model_path):\n            warnings.simplefilter('ignore')\n            config = self.hf_config\n            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)\n            model.vit.load_model()\n            model.vit.resize_pos()\n            if hasattr(self.hf_config, 'img_size'):\n                model.vit.vision_tower.vision_model.embeddings.image_size = \\\n                    self.hf_config.img_size\n            
model.vit.vision_tower.vision_model.post_layernorm.to_empty(device='cpu').half()\n            self.vl_model = model\n            if not self.with_llm:\n                del model.model\n                del model.output\n\n        from accelerate.utils import get_balanced_memory, infer_auto_device_map\n        max_memory = get_balanced_memory(model,\n                                         max_memory=self.max_memory,\n                                         dtype=torch.half,\n                                         no_split_module_classes=['CLIPEncoderLayer'])\n        device_map = infer_auto_device_map(model,\n                                           no_split_module_classes=['CLIPEncoderLayer'],\n                                           max_memory=max_memory,\n                                           dtype=torch.half)\n        # make all tensor on same device for postprocess\n        if 'plora_glb_GN' in device_map:\n            device_map['plora_sub_GN'] = device_map['plora_glb_GN']\n\n        from accelerate import load_checkpoint_and_dispatch\n        with disable_logging():\n            load_checkpoint_and_dispatch(model=model,\n                                         checkpoint=self.model_path,\n                                         device_map=device_map if not self.with_llm else {'': 'cpu'},\n                                         no_split_module_classes=['CLIPEncoderLayer'],\n                                         dtype=torch.half)\n\n        if 'plora_glb_GN' in device_map:\n            add_device_hook(model.vit.vision_tower.vision_model.encoder.layers[-1], device_map['plora_glb_GN'],\n                            lambda x: (x[0].to(device=device_map['plora_glb_GN']), ))\n\n        self.model = model.eval()\n\n    def _preprocess_2d5(self, image: Image, params: Dict) -> Dict:\n        \"\"\"Image preprocessing for internlm-xcomposer2d5-7b.\"\"\"\n        hd_num = params.get('hd_num', 24)\n        image = self.HD_transform(image, 
hd_num=hd_num)\n        pixel_values = self.vis_processor(image).unsqueeze(0).half()\n        w, h = image.size\n        w, h = w // 560, h // 560\n        n_token_per_image = int((h * w + 1) * 400 + 1 + (h + 1) * 20)\n        return pixel_values, n_token_per_image\n\n    def _preprocess_7b(self, image: Image, params: Dict) -> Dict:\n        \"\"\"Image preprocessing for internlm-xcomposer2-7b.\"\"\"\n        pixel_values = self.vis_processor(image).unsqueeze(0).half()\n        return pixel_values, 256\n\n    def _preprocess_4khd_7b(self, image: Image, params: Dict) -> Dict:\n        \"\"\"Image preprocessing for internlm-xcomposer2-4khd-7b.\"\"\"\n        image = self.HD_transform(image, hd_num=25)\n        pixel_values = self.vis_processor(image).unsqueeze(0).half()\n        w, h = image.size\n        w, h = w // 336, h // 336\n        n_token_per_image = int((h * w + 1) * 144 + 1 + (h + 1) * 12)\n        return pixel_values, n_token_per_image\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess() for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n            pixel_values, n_token = self.preprocess_func(image, params)\n            outputs.append(\n                dict(pixel_values=pixel_values,\n                     image_size=image.size,\n                     image_tokens=n_token,\n                     image_token_id=self.image_token_id))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n\n    @torch.no_grad()\n    def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:\n        \"\"\"Extract image feature. 
ONLY implement it when the backend is\n        turbomind engine.\n\n        Args:\n            messages(List[Dict]): the outputs of `preprocess`\n            max_batch_size(int): the max batch size when forwarding vision\n                model\n        Return:\n            the message list with forwarding results included\n        \"\"\"\n        inputs = [x['content'] for x in messages if x['role'] == 'preprocess']\n        inputs = inputs[0]\n        outputs = []\n        for idx in range(0, len(inputs), max_batch_size):\n            if self.model_type in [ModelType.XCOMPOSER2D5, ModelType.XCOMPOSER2_4KHD]:\n                pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n                embeds, split = self.model.vit(pixel_values, self.model.plora_glb_GN, self.model.plora_sub_GN)\n                embeds = self.model.vision_proj(embeds)\n                embeds = torch.split(embeds, split, dim=1)\n                embeds = [x.squeeze() for x in embeds]\n            else:\n                pixel_values = [x['pixel_values'] for x in inputs[idx:idx + max_batch_size]]\n                pixel_values = torch.cat(pixel_values, dim=0)\n                logger.info(f'vision forward shape: {pixel_values.shape}')\n                embeds = self.model.vit(pixel_values)\n                embeds = self.model.vision_proj(embeds)\n                embeds = torch.split(embeds, 1, dim=0)\n                embeds = [x.squeeze() for x in embeds]\n            outputs.extend(embeds)\n        messages.append(dict(role='forward', content=outputs))\n        return messages\n\n    @staticmethod\n    def proc_messages(messages, chat_template, sequence_start, model_type):\n        \"\"\"Apply chat template to get the prompt.\"\"\"\n        prompt_messages = []\n        IMAGE_TOKEN = '<IMAGE_TOKEN>'\n        prefix_image_token = ''\n        for message in messages:\n            if isinstance(message['content'], str):\n                prompt_messages.append(message)\n      
          continue\n            elif message['role'] in ['images', 'preprocess', 'forward']:\n                continue\n            n_images = len([1 for x in message['content'] if x['type'] == 'image'])\n            content = [item['text'] for item in message['content'] if item['type'] == 'text']\n            if IMAGE_TOKEN not in content[0]:\n                if model_type == ModelType.XCOMPOSER2D5:\n                    if n_images == 1:\n                        prefix_image_token, prompt = IMAGE_TOKEN, content[0]\n                    else:\n                        prompt = ''.join([f'Image{i+1} {IMAGE_TOKEN}; ' for i in range(n_images)]) + content[0]\n                else:\n                    prompt = ''.join([IMAGE_TOKEN] * n_images) + content[0]\n            else:\n                prompt = content[0]\n            prompt_messages.append(dict(role='user', content=prompt))\n        prompt = prefix_image_token + chat_template.messages2prompt(prompt_messages, sequence_start)\n        return prompt, IMAGE_TOKEN\n\n    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, self.model_type)\n        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n\n    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):\n        prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, self.model_type)\n        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)\n"
  },
  {
    "path": "lmdeploy/vl/model/yi.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nimport os\nfrom contextlib import contextmanager\nfrom os import path as osp\nfrom typing import Dict, List\n\nimport torch.nn as nn\nfrom transformers import AutoConfig\n\nfrom lmdeploy.vl.model.base import VISION_MODELS\nfrom lmdeploy.vl.model.llava import LlavaVisionModel, check_llava_install, process_images\n\nfrom .utils import disable_transformers_logging, rewrite_ctx\n\n_model_path = None\n\n\ndef _build_vision_projector(config, delay_load=False, **kwargs):\n    \"\"\"Build yi projector.\"\"\"\n    # copy from https://github.com/01-ai/Yi/blob/main/VL/llava/model/multimodal_projector/builder.py # noqa: E501\n    projector_type = getattr(config, 'mm_projector_type', 'linear')\n\n    if projector_type == 'linear':\n        return nn.Linear(config.mm_hidden_size, config.hidden_size)\n\n    import re\n    use_norm = False\n    if '_Norm' in projector_type:\n        use_norm = True\n        projector_type = projector_type.replace('_Norm', '')\n    mlp_gelu_match = re.match(r'^mlp(\\d+)x_gelu$', projector_type)\n    if mlp_gelu_match:\n        mlp_depth = int(mlp_gelu_match.group(1))\n        if use_norm:\n            modules = [\n                nn.Linear(config.mm_hidden_size, config.hidden_size),\n                nn.LayerNorm(config.hidden_size),\n            ]\n        else:\n            modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]\n        for _ in range(1, mlp_depth):\n            modules.append(nn.GELU())\n            if use_norm:\n                modules.append(nn.Linear(config.hidden_size, config.hidden_size))\n                modules.append(nn.LayerNorm(config.hidden_size))\n            else:\n                modules.append(nn.Linear(config.hidden_size, config.hidden_size))\n        return nn.Sequential(*modules)\n\n    if projector_type == 'identity':\n        return nn.Identity()\n\n    raise ValueError(f'Unknown projector type: {projector_type}')\n\n\ndef 
_build_vision_tower(vision_tower_cfg, **kwargs):\n    \"\"\"Build yi vision tower.\"\"\"\n    cfg = vision_tower_cfg\n    vision_tower = getattr(cfg, 'mm_vision_tower', getattr(cfg, 'vision_tower', None))\n    if os.path.exists(os.path.join(_model_path, vision_tower)):\n        vision_tower = os.path.join(_model_path, vision_tower)\n\n    from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower\n    is_absolute_path_exists = os.path.exists(vision_tower)\n    if is_absolute_path_exists or vision_tower.startswith('openai') or vision_tower.startswith(\n            'laion') or 'ShareGPT4V' in vision_tower:\n        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)\n\n    raise ValueError(f'Unknown vision tower: {vision_tower}')\n\n\n@contextmanager\ndef init_yi_model():\n    origin_func_path = [\n        'llava.model.multimodal_projector.builder.build_vision_projector',\n        'llava.model.multimodal_encoder.builder.build_vision_tower'\n    ]\n    rewrite_func = [_build_vision_projector, _build_vision_tower]\n    with rewrite_ctx(origin_func_path, rewrite_func):\n        yield\n\n\n@VISION_MODELS.register_module()\nclass YiVisionModel(LlavaVisionModel):\n    \"\"\"Yi visual model.\"\"\"\n\n    @classmethod\n    def match(cls, config: AutoConfig):\n        \"\"\"Check whether the config match the model.\"\"\"\n        arch = config.architectures[0] if config.architectures else None\n        if arch == 'LlavaLlamaForCausalLM':\n            projector_type = getattr(config, 'mm_projector_type', 'linear')\n            if '_Norm' in projector_type:\n                return True\n        return False\n\n    def build_preprocessor(self):\n        from transformers import CLIPImageProcessor\n        vision_tower_name = osp.join(self.model_path, self.hf_config.mm_vision_tower)\n        self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)\n        config = AutoConfig.from_pretrained(vision_tower_name)\n        image_size 
= config.image_size\n        patch_size = config.patch_size\n        self.n_token_per_image = (image_size // patch_size)**2\n        if self.hf_config.mm_vision_select_feature == 'cls_patch':\n            self.n_token_per_image += 1\n\n    def build_model(self):\n        \"\"\"Build the vision part of a VLM model when the backend is turbomind, or\n        load the whole VLM model when `self.with_llm==True`.\"\"\"\n        check_llava_install()\n\n        global _model_path\n        _model_path = self.model_path\n\n        with init_yi_model(), disable_transformers_logging():\n            super().build_model()\n\n    def preprocess(self, messages: List[Dict]) -> List[Dict]:\n        \"\"\"Refer to `super().preprocess()` for spec.\"\"\"\n        images = self.collect_multimodal_items(messages)\n        outputs = []\n        for modality, image, params in images:\n            image = image.convert('RGB')\n            pixel_values = process_images([image], self.image_processor, self.config)\n            outputs.append(\n                dict(pixel_values=pixel_values,\n                     image_size=image.size,\n                     image_tokens=self.n_token_per_image,\n                     image_token_id=self.image_token_id))\n        messages.append(dict(role='preprocess', content=outputs))\n        return messages\n"
  },
  {
    "path": "lmdeploy/vl/tools/__init__.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n"
  },
  {
    "path": "lmdeploy/vl/tools/merge_xcomposer2d5_task.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nimport os\nimport shutil\n\nimport fire\nimport torch\nfrom tqdm import tqdm\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\n\ndef main(src_path: str, dst_path: str, task: str):\n    \"\"\"Merge internlm-xcomposer2d5-7b LoRA model weights.\n\n    Args:\n        src_path (str): the source model path of internlm-xcomposer2d5-7b\n        dst_path (str): the target model path of merged model\n        task (str): the task of source model, should choose from\n            ['web', 'write']\n    \"\"\"\n    if os.path.exists(dst_path):\n        shutil.rmtree(dst_path)\n\n    to_merged = dict(web=['lora_web'], write=['lora_sft', 'lora_dpo'])\n    keys = to_merged[task]\n\n    # load model\n    model = AutoModelForCausalLM.from_pretrained(src_path, trust_remote_code=True)\n    tokenizer = AutoTokenizer.from_pretrained(src_path, trust_remote_code=True)\n\n    # merge lora weight to base model\n    @torch.inference_mode\n    def _merge(module: torch.nn.Module, lora_weights):\n        # merge lora weight first to reduce precision loss\n        mw = None\n        for wa, wb in lora_weights:\n            if mw is None:\n                mw = (wb.float() @ wa.float())\n            else:\n                mw += (wb.float() @ wa.float())\n        ow = module.weight\n        mw += ow.float()\n        module.weight.data = mw.half()\n\n    def _extract_lora(module: torch.nn.Module, keys: list[str]):\n        lora_weights = []\n        for key in keys:\n            lora_a_key = f'{key}_A'\n            lora_b_key = f'{key}_B'\n            wa = getattr(module, lora_a_key).weight\n            wb = getattr(module, lora_b_key).weight\n            lora_weights.append((wa, wb))\n        return lora_weights\n\n    for _, module in tqdm(model.named_modules()):\n        if type(module).__name__ == 'PLoRA':\n            lora_weights = _extract_lora(module, keys)\n            _merge(module, lora_weights)\n\n    # save model\n    model.save_pretrained(dst_path, torch_dtype=torch.half)\n    tokenizer.save_pretrained(dst_path)\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n"
  },
  {
    "path": "lmdeploy/vl/utils.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\nfrom typing import Any, Dict, Tuple\n\nimport numpy.typing as npt\nfrom PIL import Image\n\nfrom .media.connection import load_from_url\nfrom .media.image import ImageMediaIO\nfrom .media.time_series import TimeSeriesMediaIO\nfrom .media.video import VideoMediaIO\n\n\ndef load_image(image_url: str, **kwargs) -> Image.Image:\n    \"\"\"Fetch and decode an image from a URL, path, or base64 string.\"\"\"\n    image_io = ImageMediaIO(**kwargs)\n    return load_from_url(image_url, image_io)\n\n\ndef load_video(video_url: str, **kwargs) -> Tuple[npt.NDArray, Dict[str, Any]]:\n    \"\"\"Fetch and decode video frames from a URL, path, or base64 string.\"\"\"\n    image_io = ImageMediaIO()\n    video_io = VideoMediaIO(image_io=image_io, **kwargs)\n    return load_from_url(video_url, video_io)\n\n\ndef load_time_series(ts_url: str, **kwargs) -> npt.NDArray:\n    \"\"\"Fetch and decode time-series data from a URL, path, or base64 string.\"\"\"\n    ts_io = TimeSeriesMediaIO(**kwargs)\n    return load_from_url(ts_url, ts_io)\n\n\ndef encode_image_base64(image: str | Image.Image, format: str = 'PNG', **kwargs) -> str:\n    \"\"\"Encode image (path or PIL image) to a base64 string.\"\"\"\n    if isinstance(image, str):\n        image = load_image(image, **kwargs)\n    image_io = ImageMediaIO(**kwargs)\n    return image_io.encode_base64(image, image_format=format)\n\n\ndef encode_video_base64(video: str | npt.NDArray, format: str = 'JPEG', **kwargs) -> str:\n    \"\"\"Encode video (path or frames) to a base64 string.\"\"\"\n    if isinstance(video, str):\n        video, _ = load_video(video, **kwargs)\n    image_io = ImageMediaIO()\n    video_io = VideoMediaIO(image_io=image_io, **kwargs)\n    return video_io.encode_base64(video, video_format=format)\n\n\ndef encode_time_series_base64(data: str | npt.NDArray, **kwargs) -> str:\n    \"\"\"Encode time-series (path or numpy array) to a base64 string.\"\"\"\n    if isinstance(data, str):\n        data = load_time_series(data, **kwargs)\n    ts_io = TimeSeriesMediaIO(**kwargs)\n    return ts_io.encode_base64(data)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\n    \"cmake_build_extension\",\n]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "setup.py",
    "content": "import os\nimport re\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nfrom setuptools import find_packages, setup\n\npwd = os.path.dirname(__file__)\nversion_file = 'lmdeploy/version.py'\n\n\ndef get_target_device():\n    return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda')\n\n\ndef readme():\n    with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f:\n        content = f.read()\n    return content\n\n\ndef get_version():\n    file_path = os.path.join(pwd, version_file)\n    pattern = re.compile(r\"\\s*__version__\\s*=\\s*'([0-9A-Za-z.-]+)'\")\n    with open(file_path, 'r') as f:\n        for line in f:\n            m = pattern.match(line)\n            if m:\n                return m.group(1)\n        else:\n            assert False, f'No version found {file_path}'\n\n\ndef get_turbomind_deps():\n    if os.name == 'nt':\n        return []\n\n    CUDA_COMPILER = os.getenv('CUDACXX', os.getenv('CMAKE_CUDA_COMPILER', 'nvcc'))\n    nvcc_output = subprocess.check_output([CUDA_COMPILER, '--version'], stderr=subprocess.DEVNULL).decode()\n    CUDAVER, = re.search(r'release\\s+(\\d+).', nvcc_output).groups()\n    if int(CUDAVER) >= 13:\n        return [\n            f'nvidia-nccl-cu{CUDAVER}',\n            'nvidia-cuda-runtime',\n            'nvidia-cublas',\n            'nvidia-curand',\n        ]\n    else:\n        return [\n            f'nvidia-nccl-cu{CUDAVER}',\n            f'nvidia-cuda-runtime-cu{CUDAVER}',\n            f'nvidia-cublas-cu{CUDAVER}',\n            f'nvidia-curand-cu{CUDAVER}',\n        ]\n\n\ndef parse_requirements(fname='requirements.txt', with_version=True):\n    \"\"\"Parse the package dependencies listed in a requirements file,\n    optionally stripping version information.\n\n    Args:\n        fname (str): path to the file\n        with_version (bool, default=True): if True, include version specs\n\n    Returns:\n        List[str]: list of requirement items\n\n    CommandLine:\n        python -c \"import setup; 
print(setup.parse_requirements())\"\n    \"\"\"\n    require_fpath = fname\n\n    def parse_line(line):\n        \"\"\"Parse information from a line in a requirements text file.\"\"\"\n        if line.startswith('-r '):\n            # Allow specifying requirements in other files\n            target = line.split(' ')[1]\n            for info in parse_require_file(target):\n                yield info\n        else:\n            info = {'line': line}\n            if line.startswith('-e '):\n                info['package'] = line.split('#egg=')[1]\n            elif '@git+' in line:\n                info['package'] = line\n            else:\n                # Remove versioning from the package\n                pat = '(' + '|'.join(['>=', '==', '>']) + ')'\n                parts = re.split(pat, line, maxsplit=1)\n                parts = [p.strip() for p in parts]\n\n                info['package'] = parts[0]\n                if len(parts) > 1:\n                    op, rest = parts[1:]\n                    if ';' in rest:\n                        # Handle platform specific dependencies\n                        # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies\n                        version, platform_deps = map(str.strip, rest.split(';'))\n                        info['platform_deps'] = platform_deps\n                    else:\n                        version = rest  # NOQA\n                    info['version'] = (op, version)\n            yield info\n\n    def parse_require_file(fpath):\n        with open(fpath, 'r') as f:\n            for line in f.readlines():\n                line = line.strip()\n                if line and not line.startswith('#'):\n                    for info in parse_line(line):\n                        yield info\n\n    def gen_packages_items():\n        if os.path.exists(require_fpath):\n            for info in parse_require_file(require_fpath):\n                parts = [info['package']]\n      
          if with_version and 'version' in info:\n                    parts.extend(info['version'])\n                if not sys.version.startswith('3.4'):\n                    # apparently package_deps are broken in 3.4\n                    platform_deps = info.get('platform_deps')\n                    if platform_deps is not None:\n                        parts.append(';' + platform_deps)\n                item = ''.join(parts)\n                yield item\n\n    packages = list(gen_packages_items())\n\n    return packages\n\n\nif get_target_device() == 'cuda' and not os.getenv('DISABLE_TURBOMIND', '').lower() in ('yes', 'true', 'on', 't', '1'):\n    import cmake_build_extension\n\n    ext_modules = [\n        cmake_build_extension.CMakeExtension(\n            name='_turbomind',\n            install_prefix='lmdeploy/lib',\n            cmake_depends_on=['pybind11'],\n            source_dir=str(Path(__file__).parent.absolute()),\n            cmake_generator=None if os.name == 'nt' else 'Ninja',\n            cmake_build_type=os.getenv('CMAKE_BUILD_TYPE', 'RelWithDebInfo'),\n            cmake_configure_options=[\n                f'-DPython3_ROOT_DIR={Path(sys.prefix)}',\n                f'-DPYTHON_EXECUTABLE={Path(sys.executable)}',\n                '-DCALL_FROM_SETUP_PY:BOOL=ON',\n                '-DBUILD_SHARED_LIBS:BOOL=OFF',\n                # Select the bindings implementation\n                '-DBUILD_PY_FFI=ON',\n                '-DBUILD_MULTI_GPU=' + ('OFF' if os.name == 'nt' else 'ON'),\n                '-DUSE_NVTX=' + ('OFF' if os.name == 'nt' else 'ON'),\n            ],\n        ),\n    ]\n    extra_deps = get_turbomind_deps()\n    cmdclass = dict(build_ext=cmake_build_extension.BuildExtension, )\nelse:\n    ext_modules = []\n    cmdclass = {}\n    extra_deps = []\n\nif __name__ == '__main__':\n    setup(\n        name='lmdeploy',\n        version=get_version(),\n        description='A toolset for compressing, deploying and serving LLM',\n        
long_description=readme(),\n        long_description_content_type='text/markdown',\n        author='OpenMMLab',\n        author_email='openmmlab@gmail.com',\n        packages=find_packages(exclude=()),\n        include_package_data=True,\n        setup_requires=parse_requirements('requirements/build.txt'),\n        tests_require=parse_requirements('requirements/test.txt'),\n        install_requires=parse_requirements(f'requirements/runtime_{get_target_device()}.txt') + extra_deps,\n        extras_require={\n            'all': parse_requirements(f'requirements_{get_target_device()}.txt'),\n            'lite': parse_requirements('requirements/lite.txt'),\n            'serve': parse_requirements('requirements/serve.txt'),\n        },\n        classifiers=[\n            'Programming Language :: Python :: 3.10',\n            'Programming Language :: Python :: 3.11',\n            'Programming Language :: Python :: 3.12',\n            'Programming Language :: Python :: 3.13',\n            'Intended Audience :: Developers',\n            'Intended Audience :: Education',\n            'Intended Audience :: Science/Research',\n        ],\n        entry_points={'console_scripts': ['lmdeploy = lmdeploy.cli:run']},\n        ext_modules=ext_modules,\n        cmdclass=cmdclass,\n    )\n"
  },
  {
    "path": "src/CMakeLists.txt",
    "content": "# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nadd_subdirectory(turbomind)\n"
  },
  {
    "path": "src/turbomind/CMakeLists.txt",
    "content": "# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nadd_subdirectory(utils)\nadd_subdirectory(core)\nadd_subdirectory(kernels)\nadd_subdirectory(comm)\nadd_subdirectory(generation)\nadd_subdirectory(models)\nadd_subdirectory(engine)\n\nif(BUILD_PY_FFI)\n    add_subdirectory(python)\nendif()\n\nadd_library(turbomind STATIC turbomind.cc)\nset_property(TARGET turbomind PROPERTY POSITION_INDEPENDENT_CODE ON)\ntarget_link_libraries(turbomind PUBLIC\n        engine\n        models\n        device_comm\n        host_comm\n        core\n        memory_utils\n        nvtx_utils\n        CUDA::cublasLt\n        CUDA::cudart\n        yaml-cpp::yaml-cpp)\n"
  },
  {
    "path": "src/turbomind/comm/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\ncmake_minimum_required(VERSION 3.11)\n\nfind_package(Threads)\n\nadd_library(host_comm STATIC host_comm.cc thread_comm.cc)\ntarget_link_libraries(host_comm PRIVATE core logger Threads::Threads)\nset_property(TARGET host_comm PROPERTY POSITION_INDEPENDENT_CODE ON)\n\nadd_library(device_comm STATIC device_comm.cc)\ntarget_link_libraries(device_comm PRIVATE core logger)\nset_property(TARGET device_comm PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET device_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\n\nif (BUILD_MULTI_GPU)\n    add_subdirectory(cuda_ipc)\n    target_link_libraries(device_comm INTERFACE cuda_ipc_comm)\n\n    if (USE_NCCL)\n        add_subdirectory(nccl)\n        target_link_libraries(device_comm INTERFACE nccl_comm)\n    endif ()\n\n    add_subdirectory(gloo)\n    target_link_libraries(host_comm INTERFACE gloo_comm)\n\n    if (BUILD_TEST)\n        add_executable(test_comm test_comm.cu)\n        target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils)\n        target_compile_options(test_comm PRIVATE -march=native -mtune=native)\n\n        add_executable(test_host_comm test_host_comm.cc)\n        target_link_libraries(test_host_comm PRIVATE host_comm core Threads::Threads)\n    endif ()\nendif ()\n"
  },
  {
    "path": "src/turbomind/comm/barrier.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#if defined(_MSC_VER) && !defined(__clang__)\n\n#include <condition_variable>\n#include <cstdint>\n#include <mutex>\n\nnamespace turbomind::comm {\n\nclass Barrier {\npublic:\n    explicit Barrier(int count): threshold_{count}, count_{count} {}\n\n    void arrive_and_wait()\n    {\n        std::unique_lock lock{mutex_};\n        auto             phase = phase_;\n        if (--count_ == 0) {\n            ++phase_;\n            count_ = threshold_;\n            cv_.notify_all();\n        }\n        else {\n            cv_.wait(lock, [this, phase] { return phase_ != phase; });\n        }\n    }\n\nprivate:\n    std::mutex              mutex_;\n    std::condition_variable cv_;\n\n    int threshold_;\n    int count_;\n\n    uint32_t phase_{};\n};\n\n}  // namespace turbomind::comm\n\n#else\n\n#include <pthread.h>\n\nnamespace turbomind::comm {\n\nclass Barrier {\npublic:\n    explicit Barrier(int count): barrier_{}\n    {\n        pthread_barrier_init(&barrier_, {}, count);\n    }\n\n    ~Barrier()\n    {\n        pthread_barrier_destroy(&barrier_);\n    }\n\n    void arrive_and_wait()\n    {\n        pthread_barrier_wait(&barrier_);\n    }\n\nprivate:\n    pthread_barrier_t barrier_;\n};\n\n}  // namespace turbomind::comm\n\n#endif\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\ncmake_minimum_required(VERSION 3.11)\n\nadd_library(cuda_ipc_comm STATIC\n        cuda_ipc_comm.cu\n        allreduce.cu\n        allgather.cu\n        fused_allreduce.cu\n        fused_allreduce_ex.cu\n        broadcast.cu)\n\ntarget_link_libraries(cuda_ipc_comm PRIVATE\n        rms_norm\n        host_comm\n        core\n        cuda_utils\n        CUDA::cuda_driver\n        logger)\n\nset_property(TARGET cuda_ipc_comm PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET cuda_ipc_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/allgather.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cstdint>\n\n#include \"src/turbomind/comm/cuda_ipc/common.h\"\n#include \"src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h\"\n#include \"src/turbomind/comm/cuda_ipc/multimem.cuh\"\n#include \"src/turbomind/comm/cuda_ipc/semaphore.cuh\"\n\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::comm {\n\n__global__ void Barrier_V2(SystemSemaphoreInfo* semaphores, int ranks)\n{\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n    sem.Signal(true);\n    sem.Wait(true);\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\nvoid CudaIpcCommImpl::Barrier(int group, cudaStream_t stream)\n{\n    const int ranks = n_ranks(group);\n    Barrier_V2<<<1, ranks, 0, stream>>>(groups_.at(group).semaphore.handle(), ranks);\n}\n\ntemplate<class T, class Relaxed>\n__global__ void __launch_bounds__(1024, 1) Allgather_Simple_Pull(\n    Array<T*, kMaxRanks> uc, SystemSemaphoreInfo* semaphores, int rank, int ranks, int64_t slice, Relaxed relaxed)\n{\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    auto local = uc[rank];\n\n    for (int i = 1; i < ranks; ++i) {\n        const int p  = rank + i < ranks ? 
rank + i : rank + i - ranks;\n        const T*  ch = cvta_generic_to_global(uc[p]);\n        for (int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < slice; idx += blockDim.x * gridDim.x) {\n            local[slice * p + idx] = ch[slice * p + idx];\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\ntemplate<class T, class Relaxed>\n__global__ void __launch_bounds__(1024, 1) Allgather_NVLS_V2(\n    T* uc, T* mc, SystemSemaphoreInfo* semaphores, int rank, int ranks, int64_t slice, Relaxed relaxed)\n{\n#if TURBOMIND_ARCH_SM90\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    for (int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < slice; idx += blockDim.x * gridDim.x) {\n        multimem_st(&mc[slice * rank + idx], uc[slice * rank + idx]);\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n#endif\n}\n\nvoid CudaIpcCommImpl::AllGather(\n    const void* sendbuff, void* recvbuff, size_t sendcount, DataType type, int group, cudaStream_t stream)\n{\n    const size_t bytesize = turbomind::byte_size(type) * sendcount;\n\n    const int ranks = this->n_ranks(group);\n    const int rank  = this->rank(group);\n\n    auto semaphore = groups_.at(group).semaphore.handle();\n\n    auto invoke = [&](auto t) {\n        using T               = decltype(t);\n        const auto   symm_ptr = get_symmetric_v2((T*)recvbuff, group);\n        const size_t slice    = bytesize / sizeof(T);\n        const int    threads  = 1024;\n        if (symm_ptr.mc) {\n            const int blocks = std::min<int>(4, (slice + threads - 1) / threads);\n            Allgather_NVLS_V2<T><<<blocks, threads, 0, stream>>>(\n                symm_ptr.uc[rank], symm_ptr.mc, semaphore, rank, ranks, 
slice, std::false_type{});\n        }\n        else {\n            const int blocks = std::min<int>(max_ctas_.apply(32), (slice + threads - 1) / threads);\n            Allgather_Simple_Pull<T>\n                <<<blocks, threads, 0, stream>>>(symm_ptr.uc, semaphore, rank, ranks, slice, std::false_type{});\n        }\n    };\n\n    auto invoke_copy_engine = [&] {\n        auto symm_ptr = get_symmetric_v2((char*)recvbuff, group);\n\n        Barrier(group, stream);\n\n        for (int i = 1; i < ranks; ++i) {\n            const int p = (rank + i) % ranks;\n            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[p] + rank * bytesize,  //\n                                             (char*)recvbuff + rank * bytesize,\n                                             bytesize,\n                                             cudaMemcpyDefault,\n                                             stream));\n        }\n\n        Barrier(group, stream);\n    };\n\n    if (bytesize < copy_threshold_) {\n        if (bytesize % sizeof(uint4) == 0) {\n            invoke(uint4{});\n        }\n        else if (bytesize % sizeof(uint2) == 0) {\n            invoke(uint2{});\n        }\n        else if (bytesize % sizeof(uint) == 0) {\n            invoke(uint{});\n        }\n        else {\n            TM_CHECK(0) << \"not implemented\";\n        }\n    }\n    else {\n        invoke_copy_engine();\n    }\n}\n\ntemplate<class T, int log2_block_dim, class Relaxed>\n__global__ void __launch_bounds__(1024, 1) Allgather2D_Simple_Pull(T*                   local,\n                                                                   Array<T*, kMaxRanks> uc,\n                                                                   SystemSemaphoreInfo* semaphores,\n                                                                   int                  rank,\n                                                                   int                  ranks,\n                                                             
      int64_t              pitch,\n                                                                   int64_t              stride,\n                                                                   int                  width,\n                                                                   int                  height,\n                                                                   int                  log2_groups,\n                                                                   constant<log2_block_dim>,\n                                                                   Relaxed relaxed)\n{\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n\n    const int log2_threads = log2_block_dim - log2_groups;\n    const int threads      = 1 << log2_threads;\n    const int groups       = 1 << log2_groups;\n\n    const int gi = threadIdx.x >> log2_threads;\n    const int di = (threadIdx.x & (threads - 1));\n    const int bi = blockIdx.x * groups + gi;\n    const int bn = gridDim.x * groups;\n\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    for (int i = 1; i < ranks; ++i) {\n        const int     p      = rank + i < ranks ? 
rank + i : rank + i - ranks;\n        const T*      ch     = cvta_generic_to_global(uc[p]);\n        const int64_t offset = stride * p;\n        for (int x = di; x < width; x += threads) {\n            for (int y = bi; y < height; y += bn) {\n                local[offset + y * pitch + x] = ch[offset + y * pitch + x];\n            }\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\ntemplate<class T, int log2_block_dim, class Relaxed>\n__global__ void __launch_bounds__(1024, 1) Allgather2D_NVLS_V2(T*                   uc_buf,\n                                                               T*                   mc_buf,\n                                                               SystemSemaphoreInfo* semaphores,\n                                                               int                  rank,\n                                                               int                  ranks,\n                                                               int64_t              pitch,\n                                                               int64_t              stride,\n                                                               int                  width,\n                                                               int                  height,\n                                                               int                  log2_groups,\n                                                               constant<log2_block_dim>,\n                                                               Relaxed relaxed)\n{\n\n#if TURBOMIND_ARCH_SM90\n\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    const int log2_threads = log2_block_dim - log2_groups;\n    const int threads      = 1 << log2_threads;\n    const int groups       = 1 << log2_groups;\n\n    const int gi = 
threadIdx.x >> log2_threads;\n    const int di = (threadIdx.x & (threads - 1));\n    const int bi = blockIdx.x * groups + gi;\n    const int bn = gridDim.x * groups;\n\n    __syncthreads();\n\n    const int64_t offset = stride * rank;\n    for (int y = bi; y < height; y += bn) {\n        for (int x = di; x < width; x += threads) {\n            const int64_t idx = offset + y * pitch + x;\n            multimem_st(&mc_buf[idx], uc_buf[idx]);\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n#endif\n}\n\nvoid CudaIpcCommImpl::AllGather2D(const void*  sendbuff,\n                                  void*        recvbuff,\n                                  size_t       pitch,\n                                  size_t       stride,\n                                  int          width,\n                                  int          height,\n                                  DataType     type,\n                                  int2         flags,\n                                  int          group,\n                                  cudaStream_t stream)\n{\n    const size_t byte_width  = byte_size(type, width);\n    const size_t byte_pitch  = byte_size(type, pitch);\n    const size_t byte_stride = byte_size(type, stride);\n\n    const size_t nbytes = byte_width * height;\n\n    const int ranks = this->n_ranks(group);\n    const int rank  = this->rank(group);\n\n    TM_CHECK_EQ((char*)sendbuff, (char*)recvbuff + rank * byte_stride);\n\n    auto semaphore = groups_.at(group).semaphore.handle();\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n\n        const int threads     = 1024;\n        int       log2_groups = 0;\n        while ((threads * sizeof(T) >> log2_groups) > byte_width * 2) {\n            ++log2_groups;\n        }\n        const int groups = 1 << log2_groups;\n\n        auto symm_ptr = get_symmetric_v2((T*)recvbuff, group);\n\n        
if (symm_ptr.mc) {\n            const int blocks = std::min<int>(4, (height + groups - 1) >> log2_groups);\n            Allgather2D_NVLS_V2<T><<<blocks, threads, 0, stream>>>((T*)recvbuff,\n                                                                   symm_ptr.mc,\n                                                                   semaphore,\n                                                                   rank,\n                                                                   this->n_ranks(group),\n                                                                   byte_pitch / sizeof(T),\n                                                                   byte_stride / sizeof(T),\n                                                                   byte_width / sizeof(T),\n                                                                   height,\n                                                                   log2_groups,\n                                                                   constant<10>{},\n                                                                   std::true_type{});\n        }\n        else {\n            const int blocks = std::min<int>(max_ctas_.apply(48), (height + groups - 1) >> log2_groups);\n            Allgather2D_Simple_Pull<T><<<blocks, threads, 0, stream>>>((T*)recvbuff,  //\n                                                                       symm_ptr.uc,\n                                                                       semaphore,\n                                                                       rank,\n                                                                       ranks,\n                                                                       byte_pitch / sizeof(T),\n                                                                       byte_stride / sizeof(T),\n                                                                       byte_width / sizeof(T),\n                                            
                           height,\n                                                                       log2_groups,\n                                                                       constant<10>{},\n                                                                       std::true_type{});\n        }\n    };\n\n    auto invoke_copy_engine = [&] {\n        auto symm_ptr = get_symmetric_v2((char*)recvbuff, group);\n\n        Barrier(group, stream);\n\n        for (int i = 1; i < ranks; ++i) {\n            const int p = (rank + i) % ranks;\n            check_cuda_error(cudaMemcpy2DAsync(symm_ptr.uc[p] + rank * byte_stride,\n                                               byte_pitch,\n                                               (char*)recvbuff + rank * byte_stride,\n                                               byte_pitch,\n                                               byte_width,\n                                               height,\n                                               cudaMemcpyDefault,\n                                               stream));\n        }\n\n        Barrier(group, stream);\n    };\n\n    if (nbytes < copy_threshold_) {\n        if (byte_width % sizeof(uint4) == 0) {\n            invoke(uint4{});\n        }\n        else if (byte_width % sizeof(uint2) == 0) {\n            invoke(uint2{});\n        }\n        else if (byte_width % sizeof(uint) == 0) {\n            invoke(uint{});\n        }\n        else {\n            TM_CHECK(0) << \"not implemented\";\n        }\n    }\n    else {\n        invoke_copy_engine();\n    }\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/allreduce.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/comm/cuda_ipc/common.h\"\n#include \"src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h\"\n#include \"src/turbomind/comm/cuda_ipc/mscclpp.h\"\n#include \"src/turbomind/comm/cuda_ipc/multimem.cuh\"\n#include \"src/turbomind/comm/cuda_ipc/semaphore.cuh\"\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::comm {\n\nusing mscclpp::LLPacket;\n\n// reduce-scatter + allgather using LL16Packet\ntemplate<class T, class CtasPerPeer>\n__global__ void __launch_bounds__(1024, 1) Allreduce_LL16_V2(T*                          dst,\n                                                             const T*                    src,\n                                                             LLPacket*                   incoming,\n                                                             Array<LLPacket*, kMaxRanks> outgoing,\n                                                             int                         rank,\n                                                             int                         ranks,\n                                                             int                         slice,  // padded slice\n                                                             int                         count,  // actual count\n                                                             uint32_t                    flag,\n                                                             CtasPerPeer                 ctas_per_peer)\n{\n\n    constexpr int vec_size = sizeof(uint2) / sizeof(T);\n\n    using Vec = Array<T, vec_size>;\n\n    const int bi = blockIdx.x % ctas_per_peer;\n    const int p  = [&, i = blockIdx.x / ctas_per_peer + 1] { return rank + i < ranks ? 
rank + i : rank + i - ranks; }();\n    const int n  = min(count, p * slice + slice) - p * slice;\n\n    {  // send slice of `src` to peers  (src -> packet0)\n        auto chn = outgoing[p] + rank * slice;\n        for (int idx = threadIdx.x + bi * blockDim.x; idx < n; idx += ctas_per_peer * blockDim.x) {\n            chn[idx].write(*((const uint2*)src + p * slice + idx), flag);\n        }\n    }\n\n    // device-wide barrier not required as what we are sending is not what we are going to modify\n\n    {  // recv data | reduce | send results (src -> packet0 -> packet1)\n        using namespace ops;\n        const int n = min(count, rank * slice + slice) - rank * slice;\n        for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < n; idx += blockDim.x * gridDim.x) {\n            Vec vec;\n            Load(vec, src + (rank * slice + idx) * vec_size);\n            for (int i = 1; i < ranks; ++i) {\n                const int p    = rank + i < ranks ? rank + i : rank + i - ranks;\n                uint2     data = incoming[p * slice + idx].read(flag);\n                vec            = vec + (Vec&)data;\n            }\n            Store(dst + (rank * slice + idx) * vec_size, vec);\n            for (int i = 1; i < ranks; ++i) {\n                const int p = rank + i < ranks ? rank + i : rank + i - ranks;\n                outgoing[p][(ranks + rank) * slice + idx].write((uint2&)vec, flag);\n            }\n        }\n    }\n\n    {  // recv results (packet1 -> dst)\n        incoming += (ranks + p) * slice;\n        dst += p * slice * vec_size;\n        // ! 
note that `dst` MUST have same partition as we are sending `src`\n        for (int idx = threadIdx.x + bi * blockDim.x; idx < n; idx += ctas_per_peer * blockDim.x) {\n            uint2 data = incoming[idx].read(flag);\n            Store(dst + idx * vec_size, (Vec&)data);\n        }\n    }\n}\n\n// Modified from\n// https://github.com/microsoft/mscclpp/blob/591276f9d07d2df8e2a45a16738e27867e468ca3/test/mscclpp-test/allreduce_test.cu#L963\ntemplate<class T, int vec_size, class Relaxed>\n__global__ void Allreduce_Simple_Pull(T*                   buf,\n                                      Array<T*, kMaxRanks> chns,\n                                      SystemSemaphoreInfo* semaphores,\n                                      int                  rank,\n                                      int                  ranks,\n                                      int                  slice,\n                                      int                  count,\n                                      constant<vec_size>,\n                                      Relaxed relaxed)\n{\n    const int block_num  = gridDim.x;\n    const int thread_num = blockDim.x * block_num;\n    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;\n\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    using Vec = Array<T, vec_size>;\n\n    using namespace ops;\n\n    const int first = rank * slice;\n    const int last  = min(count, first + slice);\n\n    for (int i = 1; i < ranks; ++i) {\n        const int p   = rank + i < ranks ? 
rank + i : rank + i - ranks;\n        auto      chn = cvta_generic_to_global(chns[p]);\n        for (int idx = first + thread_idx; idx < last; idx += thread_num) {\n            Vec acc, tmp;\n            Load(tmp, chn + idx * vec_size);\n            Load(acc, buf + idx * vec_size);\n            acc = acc + tmp;\n            Store(buf + idx * vec_size, acc);\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    for (int i = 1; i < ranks; ++i) {\n        const int p     = rank + i < ranks ? rank + i : rank + i - ranks;\n        const int first = p * slice;\n        const int last  = min(count, first + slice);\n        auto      chn   = cvta_generic_to_global(chns[p]);\n        for (int idx = first + thread_idx; idx < last; idx += thread_num) {\n            Vec vec;\n            Load(vec, chn + idx * vec_size);\n            Store(buf + idx * vec_size, vec);\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(true);\n    sem.Wait(true);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\ntemplate<class T, int vec_size, class Relaxed>\n__global__ void Allreduce_Simple_Push_v3(T*                   buf,\n                                         T*                   scratch,\n                                         Array<T*, kMaxRanks> symm_buf,\n                                         Array<T*, kMaxRanks> symm_scratch,\n                                         SystemSemaphoreInfo* semaphores,\n                                         int                  rank,\n                                         int                  ranks,\n                                         int                  slice,  // in vec\n                                         int                  count,  // in vec\n                                         constant<vec_size>,\n                                         Relaxed relaxed)\n{\n    const int thread_idx = threadIdx.x + blockIdx.x * 
blockDim.x;\n    const int thread_num = blockDim.x * gridDim.x;\n\n    using Vec = Array<T, vec_size>;\n\n    for (int i = 1; i < ranks; ++i) {\n        const int p = rank + i < ranks ? rank + i : rank + i - ranks;\n        const int n = min(count, p * slice + slice) - p * slice;\n        for (int idx = thread_idx; idx < n; idx += thread_num) {\n            Vec vec;\n            Load(vec, buf + (p * slice + idx) * vec_size);\n            Store(symm_scratch[p] + (rank * slice + idx) * vec_size, vec);\n        }\n    }\n\n    __syncthreads();\n\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    using namespace ops;\n    const int n = min(count, rank * slice + slice) - rank * slice;\n    for (int idx = thread_idx; idx < n; idx += thread_num) {\n        Vec acc;\n        Load(acc, buf + (rank * slice + idx) * vec_size);\n        for (int i = 1; i < ranks; ++i) {\n            const int p = rank + i < ranks ? rank + i : rank + i - ranks;\n            Vec       tmp;\n            Load(tmp, scratch + (p * slice + idx) * vec_size);\n            acc = acc + tmp;\n        }\n        Store(buf + (rank * slice + idx) * vec_size, acc);\n        for (int i = 1; i < ranks; ++i) {\n            const int p = rank + i < ranks ? 
rank + i : rank + i - ranks;\n            Store(symm_buf[p] + (rank * slice + idx) * vec_size, acc);\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(true);\n    sem.Wait(true);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\ntemplate<class T, int vec_size, class Relaxed>\n__global__ void Allreduce_NVLS_V2(\n    T* mc_buf, SystemSemaphoreInfo* semaphores, int ranks, int first, int last, constant<vec_size>, Relaxed relaxed)\n{\n#if TURBOMIND_ARCH_SM90\n    const int block_num  = gridDim.x;\n    const int thread_num = blockDim.x * block_num;\n    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;\n\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    using Vec = Array<T, vec_size>;\n\n    using namespace ops;\n\n    for (int idx = first + thread_idx; idx < last; idx += thread_num) {\n        Vec vsum = multimem_ld_reduce_sum((const Vec*)(mc_buf + idx * vec_size));\n        multimem_st(mc_buf + idx * vec_size, vsum);\n    }\n\n    __syncthreads();\n\n    sem.Signal(true);\n    sem.Wait(true);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n#endif\n}\n\nvoid CudaIpcCommImpl::AllReduceSum(\n    const void* sendbuff, void* recvbuff, size_t count, DataType type, int group, cudaStream_t stream)\n{\n    FT_CHECK(sendbuff == recvbuff);\n\n    void* data = recvbuff;\n\n    const int n_ranks = this->n_ranks(group);\n    const int rank    = this->rank(group);\n\n    auto semaphore = groups_.at(group).semaphore.handle();\n\n    auto invoke = [&](auto t) {\n        using T               = decltype(t);\n        const size_t bytesize = sizeof(T) * count;\n\n        auto symm_ptr = get_symmetric_v2((T*)data, group);\n\n        if (symm_ptr.mc) {\n            constexpr int vec_size = sizeof(uint4) / sizeof(T);\n            constexpr int threads  = 1024;\n            const int     slice    = (count / vec_size + n_ranks - 1) / 
n_ranks;\n            const int     first    = rank * slice;\n            const int     last     = std::min<int>(count / vec_size, first + slice);\n            const int     max_ctas = max_ctas_.apply(8);\n            const int     blocks   = std::min(max_ctas, (slice + threads - 1) / threads);\n            Allreduce_NVLS_V2<<<blocks, threads, 0, stream>>>(symm_ptr.mc,  //\n                                                              semaphore,\n                                                              n_ranks,\n                                                              first,\n                                                              last,\n                                                              constant<vec_size>{},\n                                                              std::false_type{});\n        }\n#if 1\n        else if (round_up(bytesize, 2 * n_ranks * sizeof(LLPacket)) <= std::min<size_t>(1 << 20, kPacketBuffSize)) {\n            constexpr int vec_size      = sizeof(uint2) / sizeof(T);\n            const int     slice         = (count / vec_size + n_ranks - 1) / n_ranks;\n            constexpr int ctas_per_peer = 4;\n            constexpr int threads       = 1024;\n            const int     blocks        = (n_ranks - 1) * ctas_per_peer;\n            auto          incoming      = (LLPacket*)packet_buff_;\n            auto          outgoing      = get_symmetric_v2(incoming, group).uc;\n            Allreduce_LL16_V2<<<blocks, threads, 0, stream>>>((T*)data,  //\n                                                              (T*)data,\n                                                              incoming,\n                                                              outgoing,\n                                                              rank,\n                                                              n_ranks,\n                                                              slice,\n                                                    
          count / vec_size,\n                                                              flag_++,\n                                                              constant<ctas_per_peer>{});\n        }\n#endif\n        else if (round_up(bytesize, n_ranks * sizeof(uint4)) <= std::min<size_t>(6 << 20, kScratchBuffSize)) {\n            constexpr int vec_size = sizeof(uint4) / sizeof(T);\n            constexpr int threads  = 1024;\n            const int     slice    = (count / vec_size + n_ranks - 1) / n_ranks;\n            const int     max_ctas = max_ctas_.apply(48);\n            const int     blocks   = std::min(max_ctas, (slice + threads - 1) / threads);\n            Allreduce_Simple_Push_v3<<<blocks, threads, 0, stream>>>((T*)data,\n                                                                     (T*)scratch_buff_,\n                                                                     symm_ptr.uc,\n                                                                     get_symmetric_v2((T*)scratch_buff_, group).uc,\n                                                                     semaphore,\n                                                                     rank,\n                                                                     n_ranks,\n                                                                     slice,\n                                                                     count / vec_size,\n                                                                     constant<vec_size>{},\n                                                                     std::false_type{});\n        }\n        else {\n            constexpr int vec_size = sizeof(uint4) / sizeof(T);\n            constexpr int threads  = 1024;\n            const int     slice    = (count / vec_size + n_ranks - 1) / n_ranks;\n            const int     max_ctas = max_ctas_.apply(48);\n            const int     blocks   = std::min(max_ctas, (slice + threads - 1) / threads);\n            
Allreduce_Simple_Pull<<<blocks, threads, 0, stream>>>((T*)data,\n                                                                  symm_ptr.uc,\n                                                                  semaphore,\n                                                                  rank,\n                                                                  n_ranks,\n                                                                  slice,\n                                                                  count / vec_size,\n                                                                  constant<vec_size>{},\n                                                                  std::false_type{});\n        }\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(type, invoke);\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/bootstrap.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <mutex>\n#include <queue>\n#include <thread>\n\n#include \"src/turbomind/comm/barrier.h\"\n#include \"src/turbomind/comm/device_comm.h\"\n\nnamespace turbomind::comm {\n\n// Inspired by\n// https://github.com/microsoft/mscclpp/blob/591276f9d07d2df8e2a45a16738e27867e468ca3/include/mscclpp/core.hpp#L31\nclass LocalBootstrap {\npublic:\n    struct State {\n\n        explicit State(int n): num(n), barrier(n), ptrs(n), queues(n * n)\n        {\n            for (int i = 0; i < n; ++i) {\n                mutexes.emplace_back();\n            }\n        }\n\n        using Queue = std::queue<std::vector<uint8_t>>;\n\n        Queue& get_que(int from, int to)\n        {\n            return queues[from * num + to];\n        }\n\n        int num;\n\n        comm::Barrier barrier;\n\n        std::vector<void*>     ptrs;\n        std::deque<std::mutex> mutexes;\n        std::vector<Queue>     queues;\n    };\n\n    LocalBootstrap(int world_size, int rank, std::shared_ptr<State> state):\n        world_size_{world_size}, rank_{rank}, state_{state}\n    {\n    }\n\n    int getRank()\n    {\n        return rank_;\n    }\n\n    int getNranks()\n    {\n        return world_size_;\n    }\n\n    int getNranksPerNode()\n    {\n        return world_size_;\n    }\n\n    void send(void* data, int size, int peer, int tag)\n    {\n        // std::cerr << \"send \" << size << \" \" << rank_ << \" -> \" << peer << \" \" << tag << \"\\n\";\n        std::lock_guard lock{state_->mutexes[peer]};\n        auto&           que = state_->get_que(rank_, peer);\n        que.push(std::vector<uint8_t>((uint8_t*)data, (uint8_t*)data + size));\n    }\n\n    void recv(void* data, int size, int peer, int tag)\n    {\n        // std::cerr << \"recv \" << size << \" \" << rank_ << \" <- \" << peer << \" \" << tag << \"\\n\";\n        auto& que = state_->get_que(peer, rank_);\n        while (true) {\n            {\n         
       std::lock_guard lock{state_->mutexes[rank_]};\n                if (!que.empty()) {\n                    FT_CHECK(que.front().size() == (size_t)size);\n                    std::copy_n(que.front().begin(), size, (uint8_t*)data);\n                    que.pop();\n                    return;\n                }\n            }\n            std::this_thread::yield();\n        }\n    }\n\n    void allGather(void* allData, int size)\n    {\n        barrier();\n\n        state_->ptrs[rank_] = allData;\n\n        barrier();\n\n        for (int i = 0; i < world_size_; ++i) {\n            if (i == rank_) {\n                continue;\n            }\n            const auto offset = i * (size_t)size;\n            std::copy_n((uint8_t*)state_->ptrs[i] + offset, size, (uint8_t*)allData + offset);\n        }\n\n        barrier();\n    }\n\n    void barrier()\n    {\n        state_->barrier.arrive_and_wait();\n    }\n\nprivate:\n    int world_size_;\n    int rank_;\n\n    std::shared_ptr<State> state_;\n};\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/broadcast.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cstdint>\n\n#include \"src/turbomind/comm/cuda_ipc/common.h\"\n#include \"src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h\"\n#include \"src/turbomind/comm/cuda_ipc/multimem.cuh\"\n#include \"src/turbomind/comm/cuda_ipc/semaphore.cuh\"\n\n#include \"src/turbomind/comm/cuda_ipc/semaphore.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::comm {\n\ntemplate<class T, class Relaxed>\n__global__ void __launch_bounds__(1024, 1) Broadcast_NVLS_V2(const T*             uc,\n                                                             T*                   mc,\n                                                             SystemSemaphoreInfo* semaphores,\n                                                             int                  rank,\n                                                             int                  ranks,\n                                                             int                  root,\n                                                             int64_t              slice,\n                                                             int64_t              count,\n                                                             Relaxed              relaxed)\n{\n\n#if TURBOMIND_ARCH_SM90\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    int64_t first = rank * slice;\n    int64_t last  = min(first + slice, count);\n\n    for (int64_t idx = first + threadIdx.x + blockIdx.x * blockDim.x; idx < last; idx += blockDim.x * gridDim.x) {\n        multimem_st(&mc[idx], uc[idx]);\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n#endif\n}\n\ntemplate<class T, class 
Relaxed>\n__global__ void __launch_bounds__(1024, 1) Broadcast_Simple_Pull(Array<T*, kMaxRanks> uc,\n                                                                 SystemSemaphoreInfo* semaphores,\n                                                                 int                  rank,\n                                                                 int                  ranks,\n                                                                 int                  root,\n                                                                 int64_t              slice,\n                                                                 Relaxed              relaxed)\n{\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    auto dst = uc[rank];\n    auto src = uc[root];\n\n    if (rank != root) {\n        for (int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < slice; idx += blockDim.x * gridDim.x) {\n            dst[idx] = src[idx];\n        }\n        __syncthreads();\n    }\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\ntemplate<class T, class Relaxed>\n__global__ void __launch_bounds__(1024, 1) Broadcast_Simple_V2(Array<T*, kMaxRanks> uc,\n                                                               SystemSemaphoreInfo* semaphores,\n                                                               int                  index,\n                                                               int                  rank,\n                                                               int                  ranks,\n                                                               int                  root,\n                                                               int64_t              slice,\n                                                               int64_t              count,\n                       
                                        Relaxed              relaxed)\n{\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    auto dst = uc[rank];\n    auto src = uc[root];\n\n    int64_t first = index * slice;\n    int64_t last  = min(first + slice, count);\n\n    if (rank != root) {\n        for (int64_t idx = first + threadIdx.x + blockIdx.x * blockDim.x; idx < last; idx += blockDim.x * gridDim.x) {\n            dst[idx] = src[idx];\n            for (int i = 0; i < ranks; ++i) {\n                int p = rank + i < ranks ? rank + i : rank + i - ranks;\n                if (p != root) {\n                    uc[p][idx] = dst[idx];\n                }\n            }\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\nvoid CudaIpcCommImpl::Broadcast(const void*  sendbuff,  //\n                                void*        recvbuff,\n                                size_t       count,\n                                DataType     type,\n                                int          root,\n                                int          group,\n                                cudaStream_t stream)\n{\n\n    const int rank  = this->rank(group);\n    const int ranks = this->n_ranks(group);\n\n    const size_t bytesize = turbomind::byte_size(type, count);\n\n    auto semaphore = groups_.at(group).semaphore.handle();\n\n    const int algo = 5;\n\n    if (algo == 0) {\n        Barrier(group, stream);\n        if (rank != root) {\n            SymmetricPtr_V2<char> symm_ptr = get_symmetric_v2((char*)recvbuff, group);\n            check_cuda_error(cudaMemcpyAsync((char*)recvbuff, symm_ptr.uc[root], bytesize, cudaMemcpyDefault, stream));\n        }\n        Barrier(group, stream);\n    }\n    else if (algo == 1) {\n        const int    slices = 16;\n        const size_t slice  = 
bytesize / slices;\n        TM_CHECK(bytesize % slices == 0);\n        TM_CHECK_EQ(root, 0);\n        SymmetricPtr_V2<char> symm_ptr = get_symmetric_v2((char*)recvbuff, group);\n        for (int i = 1; i <= ranks + slices - 2; ++i) {\n            Barrier(group, stream);\n            int s = i - rank;\n            if (0 <= s && s < slices && rank != root) {\n                check_cuda_error(cudaMemcpyAsync(\n                    (char*)recvbuff + s * slice, symm_ptr.uc[rank - 1] + s * slice, slice, cudaMemcpyDefault, stream));\n            }\n        }\n        Barrier(group, stream);\n    }\n    else if (algo == 2) {\n        SymmetricPtr_V2<char> symm_ptr = get_symmetric_v2((char*)recvbuff, group);\n        TM_CHECK_EQ(ranks, 8);\n        TM_CHECK_EQ(root, 0);\n        Barrier(group, stream);\n        if (rank == 4) {\n            check_cuda_error(\n                cudaMemcpyAsync((char*)recvbuff, symm_ptr.uc[rank - 4], bytesize, cudaMemcpyDefault, stream));\n        }\n        Barrier(group, stream);\n        if (rank == 2 || rank == 6) {\n            check_cuda_error(\n                cudaMemcpyAsync((char*)recvbuff, symm_ptr.uc[rank - 2], bytesize, cudaMemcpyDefault, stream));\n        }\n        Barrier(group, stream);\n        if (rank & 1) {\n            check_cuda_error(\n                cudaMemcpyAsync((char*)recvbuff, symm_ptr.uc[rank - 1], bytesize, cudaMemcpyDefault, stream));\n        }\n        Barrier(group, stream);\n    }\n    else if (algo == 3) {\n        using T               = uint4;\n        const auto   symm_ptr = get_symmetric_v2((T*)recvbuff, group);\n        const size_t count    = bytesize / sizeof(T);\n        const size_t slice    = cdiv<size_t>(count, ranks);\n        const int    threads  = 1024;\n        const int    blocks   = std::min<int>(2, (slice + threads - 1) / threads);\n        Broadcast_NVLS_V2<T><<<blocks, threads, 0, stream>>>(\n            symm_ptr.uc[root], symm_ptr.mc, semaphore, rank, ranks, root, slice, count, 
std::true_type{});\n    }\n    else if (algo == 4) {\n        using T               = uint4;\n        const auto   symm_ptr = get_symmetric_v2((T*)recvbuff, group);\n        const size_t slice    = bytesize / sizeof(T);\n        const int    threads  = 1024;\n        const int    blocks   = std::min<int>(32, (slice + threads - 1) / threads);\n        Broadcast_Simple_Pull<T>\n            <<<blocks, threads, 0, stream>>>(symm_ptr.uc, semaphore, rank, ranks, root, slice, std::false_type{});\n    }\n    else if (algo == 5) {\n        using T               = uint4;\n        const auto   symm_ptr = get_symmetric_v2((T*)recvbuff, group);\n        const size_t count    = bytesize / sizeof(T);\n        const int    peers    = ranks - 1;\n        const size_t slice    = (count + peers - 1) / peers;\n        const int    threads  = 1024;\n        const int    blocks   = std::min<int>(32, (slice + threads - 1) / threads);\n        const int    index    = rank >= root ? rank - 1 : rank;\n        Broadcast_Simple_V2<T><<<blocks, threads, 0, stream>>>(\n            symm_ptr.uc, semaphore, index, rank, ranks, root, slice, count, std::false_type{});\n    }\n    else if (algo == 6) {\n        TM_CHECK_EQ(ranks, 8);\n        TM_CHECK_EQ(root, 0);\n        const auto   symm_ptr = get_symmetric_v2((char*)recvbuff, group);\n        const size_t count    = bytesize;\n        const size_t slice    = cdiv<size_t>(count, ranks);\n\n        // 0->4\n        // 0->2,       4->6\n        // 0->1, 2->3, 4->5, 6->7\n        Barrier(group, stream);\n        if (rank == 0) {\n            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[rank + 4] + slice * (rank + 4),\n                                             symm_ptr.uc[rank + 0] + slice * (rank + 4),\n                                             slice * 4,\n                                             cudaMemcpyDefault,\n                                             stream));\n        }\n        Barrier(group, stream);\n        if (rank == 0 || 
rank == 4) {\n            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[rank + 2] + slice * (rank + 2),\n                                             symm_ptr.uc[rank + 0] + slice * (rank + 2),\n                                             slice * 2,\n                                             cudaMemcpyDefault,\n                                             stream));\n        }\n        Barrier(group, stream);\n        if (rank % 2 == 0) {\n            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[rank + 1] + slice * (rank + 1),\n                                             symm_ptr.uc[rank + 0] + slice * (rank + 1),\n                                             slice * 1,\n                                             cudaMemcpyDefault,\n                                             stream));\n        }\n        Barrier(group, stream);\n        for (int i = 1; i < ranks; ++i) {\n            const int p = (rank + i) % ranks;\n            check_cuda_error(cudaMemcpyAsync(symm_ptr.uc[p] + rank * slice,  //\n                                             (char*)recvbuff + rank * slice,\n                                             slice,\n                                             cudaMemcpyDefault,\n                                             stream));\n        }\n        Barrier(group, stream);\n    }\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/common.h",
    "content": "#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n\nnamespace turbomind::comm {\n\ninline constexpr int kMaxRanks        = 8;\nstatic constexpr int kPacketBuffSize  = 8 << 20;  // 8 MB\nstatic constexpr int kScratchBuffSize = 8 << 20;  // 8 MB\nstatic constexpr int kMaxChannels     = 64;\n\ntemplate<class T>\nstruct SymmetricPtr_V2 {\n    Array<T*, kMaxRanks> uc;\n    T*                   mc;\n};\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cstdlib>\n#include <memory>\n#include <numeric>\n#include <vector>\n\n#include <cuda.h>\n\n#include \"src/turbomind/comm/cuda_ipc/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\n#include \"src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h\"\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/comm/env.h\"\n#include \"src/turbomind/comm/host_comm.h\"\n\n#include \"src/turbomind/comm/cuda_ipc/semaphore.h\"\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind::comm {\n\nTM_ENV_VAR(COMM, MAX_CTAS, 0);\nTM_ENV_VAR(COMM, NVLS_ENABLE, 1);\n// per-rank send size threshold to use copy engine instead of p2p for all-gather colls\nTM_ENV_VAR(COMM, COPY_THRESHOLD, INT64_MAX);\n\nint CudaIpcCommImpl::Split(int color, int key, int group)\n{\n    FT_CHECK(color >= 0);\n    FT_CHECK(rank(group) >= 0);\n\n    auto& parent = groups_.at(group);\n\n    auto vec = comm::AllGather(h_comm_, std::make_tuple(color, key, parent.g2l[global_rank_]));\n\n    auto last = std::stable_partition(vec.begin(), vec.end(), [&](auto x) {  //\n        return std::get<0>(x) == color;\n    });\n    vec.erase(last, vec.end());\n    std::stable_sort(vec.begin(), vec.end(), [](auto& a, auto& b) {  //\n        return a < b;\n    });\n\n    std::vector<int> l2g;\n    std::vector<int> g2l(parent.g2l.size(), -1);\n\n    for (size_t local = 0; local < vec.size(); ++local) {\n        const auto r      = std::get<2>(vec[local]);\n        int        global = parent.l2g.at(r);\n        l2g.push_back(global);\n        g2l[global] = local;\n    }\n\n    int index = groups_.size();\n\n    auto& g = groups_.emplace_back(Group{l2g, g2l});\n\n    for (auto& a : allocation_) {\n        Register(a, index);\n    }\n\n    g.semaphore.Allocate(l2g.size(), g2l[global_rank_], [&](size_t size) {\n        auto buf = (uint64_t*)Allocate(size);\n        
check_cuda_error(cudaMemsetAsync(buf, 0, size));\n        check_cuda_error(cudaStreamSynchronize(0));\n        Register(buf, size);\n        return get_symmetric_v2(buf, index);\n    });\n\n    return index;\n};\n\nCudaIpcCommImpl::CudaIpcCommImpl(HostComm h_comm):\n    h_comm_{h_comm}, global_n_ranks_{h_comm->n_ranks()}, global_rank_{h_comm->rank()}\n{\n    h_comm_ = h_comm;\n\n    const int n_ranks = global_n_ranks_;\n    const int rank    = global_rank_;\n\n    // Exchange device ordinals\n    ordinals_.resize(n_ranks);\n    check_cuda_error(cudaGetDevice(&ordinals_[rank]));\n    comm::AllGather(h_comm_, ordinals_.data(), 1);\n\n    max_ctas_ = {std::min(getSMCount(), kMaxChannels)};\n    if (auto v = GetEnv<COMM_MAX_CTAS>()) {\n        max_ctas_.set_value(std::min(v, max_ctas_.value()));\n    }\n    auto minval = comm::AllReduce(h_comm_, max_ctas_.value(), RedOp::kMin);\n    TM_CHECK_EQ(max_ctas_.value(), minval) << \"MAX_CTAS set to different values\";\n\n#if __CUDACC_VER_MAJOR__ >= 12\n    if (global_n_ranks_ >= 4 && GetEnv<COMM_NVLS_ENABLE>()) {  // solve 2n-2>n+1 -> n>3\n        CUDRVCHECK(\n            cuDeviceGetAttribute(&multicast_capability_, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, ordinals_[rank]));\n        multicast_capability_ = comm::AllReduce(h_comm_, multicast_capability_, RedOp::kMin);\n    }\n#endif\n\n    copy_threshold_ = GetEnv<COMM_COPY_THRESHOLD>();\n\n    // Prepare access descriptors\n    alloc_access_descs_.resize(n_ranks);\n    for (int r = 0; r < n_ranks; ++r) {\n        alloc_access_descs_[r].location.id   = ordinals_[r];\n        alloc_access_descs_[r].location.type = CU_MEM_LOCATION_TYPE_DEVICE;\n        alloc_access_descs_[r].flags         = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;\n    }\n\n    // Initialize group mapping\n    std::vector<int> idxs(n_ranks);\n    std::iota(idxs.begin(), idxs.end(), 0);\n    auto& g = groups_.emplace_back();\n    g.l2g = g.g2l = idxs;\n\n    // Prepare packet buffer\n    packet_buff_ = 
Allocate(kPacketBuffSize);\n    check_cuda_error(cudaMemsetAsync(packet_buff_, 0, kPacketBuffSize));\n\n    // Prepare scratch buffer\n    scratch_buff_ = Allocate(kScratchBuffSize);\n    check_cuda_error(cudaMemsetAsync(scratch_buff_, 0, kScratchBuffSize));\n\n    /// TODO: release\n    g.semaphore.Allocate(global_n_ranks_, global_rank_, [this](size_t size) {\n        auto buf = (uint64_t*)Allocate(size);\n        check_cuda_error(cudaMemsetAsync(buf, 0, size));\n        check_cuda_error(cudaStreamSynchronize(0));\n        Register(buf, size);\n        return get_symmetric_v2(buf, 0);\n    });\n\n    check_cuda_error(cudaStreamSynchronize(0));\n\n    Register(packet_buff_, kPacketBuffSize);\n    Register(scratch_buff_, kScratchBuffSize);\n}\n\nCudaIpcCommImpl::~CudaIpcCommImpl()\n{\n    Deregister(scratch_buff_);\n    Deregister(packet_buff_);\n\n    Free(scratch_buff_);\n    Free(packet_buff_);\n\n    for (auto i = (int)groups_.size() - 1; i >= 0; --i) {\n        groups_[i].semaphore.Free([this](void* ptr) {\n            Deregister(ptr);\n            Free(ptr);\n        });\n    }\n\n    for (const auto& a : allocation_) {\n        TM_LOG_WARNING(\"[COMM][%d] Allocation (%p, %lu) is not freed\", global_rank_, a.uc_beg, a.size);\n    }\n\n    cudaStreamSynchronize(0);\n}\n\nvoid* CudaIpcCommImpl::Allocate(size_t size)\n{\n    size_t              granularity{};\n    CUmemAllocationProp prop{};\n\n    prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;\n    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;\n    prop.location.id   = ordinals_[global_rank_];\n\n    if (multicast_capability_) {\n#if __CUDACC_VER_MAJOR__ >= 12\n        CUmulticastObjectProp prop{};\n        prop.numDevices = alloc_access_descs_.size();\n        prop.size       = size;\n        CUDRVCHECK(cuMulticastGetGranularity(&granularity, &prop, CU_MULTICAST_GRANULARITY_MINIMUM));\n#else\n        TM_CHECK(0);\n#endif\n    }\n    else {\n        CUDRVCHECK(cuMemGetAllocationGranularity(&granularity, 
&prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));\n    }\n\n    size = round_up(size, granularity);\n\n    CUmemGenericAllocationHandle handle{};\n    CUDRVCHECK(cuMemCreate(&handle, size, &prop, 0));\n\n    CUdeviceptr ptr{};\n    CUDRVCHECK(cuMemAddressReserve(&ptr, size, granularity, 0, 0));\n    CUDRVCHECK(cuMemMap(ptr, size, 0, handle, 0));\n    CUDRVCHECK(cuMemSetAccess(ptr, size, alloc_access_descs_.data(), alloc_access_descs_.size()));\n\n    Allocation a{};\n    a.handle    = handle;\n    a.size      = size;\n    a.uc_beg    = reinterpret_cast<void*>(ptr);\n    a.uc_end    = (char*)a.uc_beg + size;\n    a.alignment = granularity;\n\n    a.uc_ptrs = comm::AllGather(h_comm_, a.uc_beg);\n\n    allocation_.emplace(a);\n\n    return a.uc_beg;\n}\n\nvoid CudaIpcCommImpl::Free(void* ptr)\n{\n    if (auto it = allocation_.find(ptr); it != allocation_.end()) {\n        auto& a    = *it;\n        auto  dptr = reinterpret_cast<CUdeviceptr>(ptr);\n        CUDRVCHECK(cuMemUnmap(dptr, a.size));\n        CUDRVCHECK(cuMemRelease(a.handle));\n        CUDRVCHECK(cuMemAddressFree(dptr, a.size));\n        allocation_.erase(it);\n    }\n    else {\n        TM_LOG_WARNING(\"[TM][COMM][%d] Freeing %p which is not allocated by this module\", global_rank_, ptr);\n    }\n}\n\nvoid CudaIpcCommImpl::Register(void* ptr, size_t size)\n{\n    // register for all groups\n    auto& symm = groups_.at(0).symmetric;\n\n    if (symm.find(ptr) != symm.end()) {\n        TM_LOG_WARNING(\"[TM][COMM][%d] Duplicated registration on (%p, %lu)\", global_rank_, ptr, size);\n        return;\n    }\n\n    auto alloc = allocation_.find(ptr);\n    TM_CHECK(alloc != allocation_.end());\n\n    for (size_t i = 0; i < groups_.size(); ++i) {\n        Register(*alloc, i);\n    }\n}\n\nvoid CudaIpcCommImpl::Register(const Allocation& alloc, int group)\n{\n    auto size = alloc.size;\n\n    auto& g = groups_.at(group);\n\n    Symmetric s{};\n    s.size   = size;\n    s.uc_beg = alloc.uc_beg;\n    s.uc_end = 
alloc.uc_end;\n\n    for (auto r : g.l2g) {\n        s.uc_ptrs.push_back(alloc.uc_ptrs[r]);\n    }\n\n    const int ranks = n_ranks(group);\n    const int rank  = this->rank(group);\n\n    if (multicast_capability_ && ranks > 1) {  // ! `cuMulticastCreate` fails for `ranks == 1`\n#if __CUDACC_VER_MAJOR__ >= 12\n        CUmulticastObjectProp mc_prop{};\n        mc_prop.numDevices = ranks;\n        mc_prop.size       = size;\n        if (rank == 0) {\n            CUDRVCHECK(cuMulticastCreate(&s.mc_handle, &mc_prop));\n        }\n        auto handles = comm::AllGather(h_comm_, s.mc_handle);\n        s.mc_handle  = handles.at(g.l2g[0]);\n        CUDRVCHECK(cuMulticastAddDevice(s.mc_handle, ordinals_[global_rank_]));\n        CUDRVCHECK(cuMulticastBindMem(s.mc_handle, 0, alloc.handle, 0, size, 0));\n        CUdeviceptr mc_ptr{};\n        CUDRVCHECK(cuMemAddressReserve(&mc_ptr, size, alloc.alignment, 0, 0));\n        CUDRVCHECK(cuMemMap(mc_ptr, size, 0, s.mc_handle, 0));\n        CUDRVCHECK(cuMemSetAccess(mc_ptr, size, &alloc_access_descs_[global_rank_], 1));\n        s.mc_ptr = reinterpret_cast<void*>(mc_ptr);\n        if (rank != 0) {\n            // Increase reference count to the original handle so that all handles can be released\n            // without explicit synchronization\n            CUDRVCHECK(cuMemRetainAllocationHandle(&s.mc_handle, s.mc_ptr));\n        }\n#else\n        TM_CHECK(0);\n#endif\n    }\n\n    g.symmetric.insert(std::move(s));\n}\n\nvoid CudaIpcCommImpl::Deregister(Symmetric& s)\n{\n    if (s.mc_handle) {\n#if __CUDACC_VER_MAJOR__ >= 12\n        auto deviceptr = reinterpret_cast<CUdeviceptr>(s.mc_ptr);\n        CUDRVCHECK(cuMemUnmap(deviceptr, s.size));\n        CUDRVCHECK(cuMemAddressFree(deviceptr, s.size));\n        CUDRVCHECK(cuMulticastUnbind(s.mc_handle, ordinals_.at(global_rank_), 0, s.size));\n        CUDRVCHECK(cuMemRelease(s.mc_handle));\n        s.mc_handle = {};\n        s.mc_ptr    = {};\n#else\n        TM_CHECK(0);\n#endif\n    
}\n}\n\nvoid CudaIpcCommImpl::Deregister(void* ptr)\n{\n    for (size_t i = 0; i < groups_.size(); ++i) {\n        auto& s = groups_[i].symmetric;\n        if (auto it = s.find(ptr); it != s.end()) {\n            Deregister(s.extract(it).value());\n        }\n        else {\n            TM_LOG_WARNING(\"[TM][COMM][%d] Deregistering non-registered address %p\", global_rank_, ptr);\n        }\n    }\n}\n\nint CudaIpcCommImpl::Query(QueryAttr attr) const noexcept\n{\n    if (attr == kHasAllGather2D) {\n        return 1;\n    }\n    return 0;\n}\n\nauto CudaIpcCommImpl::get_symmetric_v2_impl(void* ptr, int group) -> SymmetricPtr_V2<void>\n{\n    auto& g = groups_.at(group);\n\n    auto symm = g.symmetric.find(ptr);\n    TM_CHECK(symm != g.symmetric.end());\n\n    auto offset = (char*)ptr - (char*)symm->uc_beg;\n\n    SymmetricPtr_V2<void> p{};\n\n    TM_CHECK_LE((int)symm->uc_ptrs.size(), p.uc.size());\n\n    for (size_t i = 0; i < symm->uc_ptrs.size(); ++i) {\n        p.uc[i] = (char*)symm->uc_ptrs[i] + offset;\n    }\n\n    if (symm->mc_ptr) {\n        p.mc = (char*)symm->mc_ptr + offset;\n    }\n\n    return p;\n}\n\nDeviceComm CreateCudaIpcCommunicator(int n_ranks, int rank, HostComm h_comm)\n{\n    auto comm = std::make_unique<CudaIpcCommImpl>(h_comm);\n\n    return DeviceComm{std::move(comm)};\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cuda.h>\n#include <set>\n\n#include \"src/turbomind/comm/cuda_ipc/common.h\"\n#include \"src/turbomind/comm/cuda_ipc/semaphore.h\"\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/comm/host_comm.h\"\n\n#include \"src/turbomind/kernels/core/array.h\"\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::comm {\n\nclass MaxCtas {\npublic:\n    MaxCtas(int value = 0): is_set_{}, value_{value} {}\n\n    void set_value(int value)\n    {\n        value_  = value;\n        is_set_ = true;\n    }\n\n    int value()\n    {\n        return value_;\n    }\n\n    int apply(int _default)\n    {\n        if (!is_set_) {  // `value_` is max possible value in this case\n            return std::min(_default, value_);\n        }\n        else {\n            return value_;\n        }\n    }\n\nprivate:\n    bool is_set_;\n    int  value_;\n};\n\nclass CudaIpcCommImpl: public DeviceCommImpl {\n    struct Allocation;\n    struct Symmetric;\n\npublic:\n    ~CudaIpcCommImpl() override;\n\n    explicit CudaIpcCommImpl(HostComm h_comm);\n\n    int n_ranks(int group) const override\n    {\n        return groups_.at(group).l2g.size();\n    }\n\n    int rank(int group) const override\n    {\n        return groups_.at(group).g2l.at(global_rank_);\n    }\n\n    void* Allocate(size_t size) override;\n\n    void Free(void* ptr) override;\n\n    void Register(void* ptr, size_t size) override;\n\n    void Deregister(void* ptr) override;\n\n    int Split(int color, int key, int group) override;\n\n    int Query(QueryAttr attr) const noexcept override;\n\n    void AllReduceSum(\n        const void* sendbuff, void* recvbuff, size_t count, DataType type, int group, cudaStream_t stream) override;\n\n    void AllGather(\n        const void* sendbuff, void* recvbuff, size_t sendcount, DataType type, int group, cudaStream_t stream) override;\n\n    void Broadcast(const 
void*  sendbuff,\n                   void*        recvbuff,\n                   size_t       count,\n                   DataType     type,\n                   int          root,\n                   int          group,\n                   cudaStream_t stream) override;\n\n    void\n    Gather(const void* sendbuff, void* recvbuff, size_t count, DataType type, int root, int group, cudaStream_t stream);\n\n    void Barrier(int group, cudaStream_t stream);\n\n    void AllreduceResidualBiasRMSnorm(void*        hidden,\n                                      void*        residual,\n                                      const void*  bias,\n                                      const void*  weights,\n                                      float        eps,\n                                      int          dim,\n                                      int          token_num,\n                                      DataType     dtype,\n                                      int          group,\n                                      cudaStream_t stream) override;\n\n    void AllreduceResidualBiasRMSnormEx(void*        hidden,\n                                        void*        residual,\n                                        const void*  bias,\n                                        const void*  weights,\n                                        float        eps,\n                                        int          dim,\n                                        DataType     type,\n                                        int          group0,\n                                        int          group1,\n                                        const int*   local_token_nums,\n                                        cudaStream_t stream) override;\n\n    void AllGather2D(const void*  sendbuff,\n                     void*        recvbuff,\n                     size_t       pitch,\n                     size_t       stride,\n                     int          width,\n                   
  int          height,\n                     DataType     type,\n                     int2         flags,\n                     int          group,\n                     cudaStream_t stream) override;\n\nprivate:\n    template<class T>\n    inline SymmetricPtr_V2<T> get_symmetric_v2(T* ptr, int group)\n    {\n        auto               tmp = get_symmetric_v2_impl(ptr, group);\n        SymmetricPtr_V2<T> ret{};\n        ret.mc = static_cast<T*>(tmp.mc);\n        for (int i = 0; i < ret.uc.size(); ++i) {\n            ret.uc[i] = static_cast<T*>(tmp.uc[i]);\n        }\n        return ret;\n    }\n\n    SymmetricPtr_V2<void> get_symmetric_v2_impl(void* ptr, int group);\n\n    void Register(const Allocation& alloc, int group);\n\n    void Deregister(Symmetric& s);\n\nprivate:\n    HostComm h_comm_;\n\n    int global_n_ranks_;\n    int global_rank_;\n\n    std::vector<int> ordinals_;\n\n    struct Symmetric {\n        void*              uc_beg;\n        void*              uc_end;\n        size_t             size;\n        std::vector<void*> uc_ptrs;  // peers\n        void*              mc_ptr;\n\n        CUmemGenericAllocationHandle mc_handle;\n\n        friend bool operator<(const Symmetric& a, const Symmetric& b)\n        {\n            return (char*)a.uc_beg < (char*)b.uc_beg;\n        }\n        friend bool operator<(const Symmetric& a, void* b)\n        {\n            return (char*)a.uc_end <= (char*)b;\n        }\n        friend bool operator<(void* a, const Symmetric& b)\n        {\n            return (char*)a < (char*)b.uc_beg;\n        }\n    };\n\n    void*    packet_buff_{};\n    void*    scratch_buff_{};\n    uint32_t flag_{1};\n\n    struct Allocation {\n        void*                        uc_beg;\n        void*                        uc_end;\n        size_t                       size;\n        size_t                       alignment;\n        std::vector<void*>           uc_ptrs;  // ranks\n        CUmemGenericAllocationHandle handle;\n\n        friend 
bool operator<(const Allocation& a, const Allocation& b)\n        {\n            return (char*)a.uc_beg < (char*)b.uc_beg;\n        }\n        friend bool operator<(const Allocation& a, void* b)\n        {\n            return (char*)a.uc_end <= (char*)b;\n        }\n        friend bool operator<(void* a, const Allocation& b)\n        {\n            return (char*)a < (char*)b.uc_beg;\n        }\n    };\n\n    std::vector<CUmemAccessDesc> alloc_access_descs_{};\n\n    int multicast_capability_{false};\n\n    std::set<Allocation, std::less<>> allocation_;\n\n    struct Group {\n        std::vector<int> l2g;  // local -> global\n        std::vector<int> g2l;  // global -> local\n\n        SystemSemaphoreStorage semaphore;\n\n        std::set<Symmetric, std::less<>> symmetric;\n    };\n\n    std::vector<Group> groups_;\n\n    MaxCtas max_ctas_;\n    size_t  copy_threshold_{INT64_MAX};\n};\n\nstruct Rank {\n    int                     rank;\n    int                     peers;\n    __host__ __device__ int get_next_peer(int i)\n    {\n        return i + rank < peers ? i + rank : i + rank - peers;\n    }\n    __host__ __device__ int get_prev_peer(int i)\n    {\n        return get_next_peer(peers - 1 - i);\n    }\n    __host__ __device__ int get_peer_rank(int p)  // rank of `p`\n    {\n        return p < rank ? p : p + 1;\n    }\n    __host__ __device__ int inverse_peer(int p)  // peer idx of `rank` on peer `p`\n    {\n        return p < rank ? rank - 1 : rank;\n    }\n};\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/fused_allreduce.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cuda_bf16.h>\n#include <type_traits>\n\n#include \"cub/block/block_reduce.cuh\"\n\n#include \"src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h\"\n#include \"src/turbomind/comm/cuda_ipc/group_sum.h\"\n#include \"src/turbomind/comm/cuda_ipc/multimem.cuh\"\n#include \"src/turbomind/comm/cuda_ipc/semaphore.cuh\"\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n\n#include \"src/turbomind/kernels/norm/rms_norm.h\"\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::comm {\n\ntemplate<class T, int vec_size, int block_dim, int groups, class Relaxed>\n__global__ void AllreduceResidualBiasRMSnorm_Simple_Pull(T*                   buf,\n                                                         T*                   res,\n                                                         const T*             bias,\n                                                         const T*             weights,\n                                                         Array<T*, kMaxRanks> symm,\n                                                         SystemSemaphoreInfo* semaphores,\n                                                         int                  rank,\n                                                         int                  ranks,\n                                                         int                  slice,\n                                                         int                  count,\n                                                         int                  vdim,\n                                                         float                inv_dim,\n                                                         float                eps,\n                                                         constant<vec_size>,\n            
                                             constant<block_dim>,\n                                                         constant<groups>,\n                                                         Relaxed relaxed)\n{\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    using Vec = Array<T, vec_size>;\n\n    using namespace ops;\n\n    static_assert(block_dim % groups == 0);\n    constexpr int threads = block_dim / groups;\n\n    static_assert(threads % WARP_SIZE == 0);\n    constexpr int warps = threads / WARP_SIZE;\n\n    const int xi = threadIdx.x / threads;\n    const int di = threadIdx.x % threads;\n    const int bi = blockIdx.x * groups + xi;\n    const int bn = gridDim.x * groups;\n\n    auto syncgroup = [&] {  //\n        asm volatile(\"bar.sync %0, %1;\" : : \"r\"(15 - xi), \"r\"(threads) : \"memory\");\n    };\n\n    const int first = rank * slice;\n    const int last  = min(count, first + slice);\n\n    for (int i = 1; i < ranks - 1; ++i) {\n        const int  p   = rank + i < ranks ? rank + i : rank + i - ranks;\n        const auto src = cvta_generic_to_global(symm[p]);\n        Vec        acc, tmp;\n        for (int ti = first + bi; ti < last; ti += bn) {\n            const int idx = (ti * vdim + di) * vec_size;\n            if (di < vdim) {\n                Load(tmp, src + idx);\n                Load(acc, buf + idx);\n                acc = acc + tmp;\n                Store(buf + idx, acc);\n            }\n        }\n    }\n\n    Vec b_vec{};\n    if (bias && di < vdim) {\n        Ldg(b_vec, bias + di * vec_size);\n    }\n\n    Vec w_vec;\n    if (di < vdim) {\n        Ldg(w_vec, weights + di * vec_size);\n    }\n\n    {\n        const int p   = rank > 0 ? 
rank - 1 : ranks - 1;  // last peer\n        auto      chn = cvta_generic_to_global(symm[p]);\n        for (int ti = first + bi; ti < last; ti += bn) {\n            const int idx = (ti * vdim + di) * vec_size;\n            Vec       acc, tmp;\n            Vec       r_vec{};\n            float     sum{};\n            if (di < vdim) {\n                Load(tmp, chn + idx);\n                Load(acc, buf + idx);\n                acc = acc + tmp;\n                Load(r_vec, res + idx);\n                r_vec = r_vec + acc;\n                if (bias) {\n                    r_vec = r_vec + b_vec;\n                }\n                Store(res + idx, r_vec);\n                PRAGMA_UNROLL\n                for (int i = 0; i < vec_size; ++i) {\n                    sum += (float)r_vec[i] * (float)r_vec[i];\n                }\n            }\n            sum = detail::GroupSum(sum, warps, syncgroup);\n            __shared__ float shared_sum[groups];\n            if (di == 0) {\n                shared_sum[xi] = rsqrtf(sum * inv_dim + eps);\n            }\n            syncgroup();\n            sum = shared_sum[xi];\n            if (di < vdim) {\n                PRAGMA_UNROLL\n                for (int i = 0; i < vec_size; ++i) {\n                    r_vec[i] = static_cast<T>(((float)r_vec[i] * sum)) * w_vec[i];\n                }\n                Store(buf + idx, r_vec);\n            }\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    for (int i = 1; i < ranks; ++i) {\n        const int p     = rank + i < ranks ? 
rank + i : rank + i - ranks;\n        const int first = slice * p;\n        const int last  = min(count, first + slice);\n        auto      src   = cvta_generic_to_global(symm[p]);\n        for (int ti = first + bi; ti < last; ti += bn) {\n            const int idx = (ti * vdim + di) * vec_size;\n            if (di < vdim) {\n                Vec vec;\n                Load(vec, src + idx);\n                Store(buf + idx, vec);\n            }\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(true);\n    sem.Wait(true);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\ntemplate<class T, int vec_size, int block_dim, int groups, class Relaxed>\n__global__ void AllreduceResidualBiasRMSnorm_NVLS(T*                   mc_buf,\n                                                  T*                   uc_buf,\n                                                  T*                   res,\n                                                  const T*             bias,\n                                                  const T*             weights,\n                                                  SystemSemaphoreInfo* semaphores,\n                                                  int                  rank,\n                                                  int                  ranks,\n                                                  int                  slice,\n                                                  int                  count,\n                                                  int                  vdim,\n                                                  float                inv_dim,\n                                                  float                eps,\n                                                  constant<vec_size>,\n                                                  constant<block_dim>,\n                                                  constant<groups>,\n                                                  Relaxed relaxed)\n{\n\n#if 
TURBOMIND_ARCH_SM90\n\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n\n    static_assert(block_dim % groups == 0);\n    constexpr int threads = block_dim / groups;\n\n    static_assert(threads % WARP_SIZE == 0);\n    constexpr int warps = threads / WARP_SIZE;\n\n    const int xi = threadIdx.x / threads;\n    const int di = threadIdx.x % threads;\n\n    using Vec = Array<T, vec_size>;\n\n    Vec b_vec{};\n    if (bias && di < vdim) {\n        Ldg(b_vec, bias + di * vec_size);\n    }\n\n    Vec w_vec;\n    if (di < vdim) {\n        Ldg(w_vec, weights + di * vec_size);\n    }\n\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    using namespace ops;\n\n    const int bi = blockIdx.x * groups + xi;\n    const int bn = gridDim.x * groups;\n\n    auto syncgroup = [&] {  //\n        asm volatile(\"bar.sync %0, %1;\" : : \"r\"(15 - xi), \"r\"(threads) : \"memory\");\n    };\n\n    const int first = rank * slice;\n    const int last  = min(count, first + slice);\n\n    for (int ti = first + bi; ti < last; ti += bn) {\n        const int idx = (ti * vdim + di) * vec_size;\n        float     sum{};\n        Vec       vec;\n        if (di < vdim) {\n            Vec acc = multimem_ld_reduce_sum((const Vec*)(mc_buf + idx));\n            Load(vec, res + idx);\n            vec = vec + acc;\n            if (bias) {\n                vec = vec + b_vec;\n            }\n            Store(res + idx, vec);\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                sum += (float)vec[i] * (float)vec[i];\n            }\n        }\n        sum = detail::GroupSum(sum, warps, syncgroup);\n        __shared__ float shared_sum[groups];\n        if (di == 0) {\n            shared_sum[xi] = rsqrtf(sum * inv_dim + eps);\n        }\n        syncgroup();\n        sum = shared_sum[xi];\n        if (di < vdim) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                vec[i] = 
static_cast<T>(((float)vec[i] * sum)) * w_vec[i];\n            }\n            multimem_st(mc_buf + idx, vec);\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(true);\n    sem.Wait(true);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n#endif\n}\n\ntemplate<class T, int vec_size, int block_dim, int groups, class Relaxed>\n__global__ void AllreduceResidualBiasRMSnorm_Simple_Push(T*                   buf,\n                                                         T*                   res,\n                                                         const T*             bias,\n                                                         const T*             weights,\n                                                         T*                   scratch,\n                                                         Array<T*, kMaxRanks> symm_buf,\n                                                         Array<T*, kMaxRanks> symm_scratch,\n                                                         SystemSemaphoreInfo* semaphores,\n                                                         int                  rank,\n                                                         int                  ranks,\n                                                         int                  slice,\n                                                         int                  count,\n                                                         int                  vdim,\n                                                         float                inv_dim,\n                                                         float                eps,\n                                                         constant<vec_size>,\n                                                         constant<block_dim>,\n                                                         constant<groups>,\n                                                         Relaxed relaxed)\n{\n    using Vec = Array<T, 
vec_size>;\n\n    using namespace ops;\n\n    static_assert(block_dim % groups == 0);\n    constexpr int threads = block_dim / groups;\n\n    static_assert(threads % WARP_SIZE == 0);\n    constexpr int warps = threads / WARP_SIZE;\n\n    const int xi = threadIdx.x / threads;\n    const int di = threadIdx.x % threads;\n    const int bi = blockIdx.x * groups + xi;\n    const int bn = gridDim.x * groups;\n\n    auto syncgroup = [&] {  //\n        asm volatile(\"bar.sync %0, %1;\" : : \"r\"(15 - xi), \"r\"(threads) : \"memory\");\n    };\n\n    for (int i = 1; i < ranks; ++i) {\n        const int  p   = rank + i < ranks ? rank + i : rank + i - ranks;\n        const int  n   = min(count, p * slice + slice) - p * slice;\n        const auto src = buf + p * slice * vdim * vec_size;\n        const auto dst = symm_scratch[p] + rank * slice * vdim * vec_size;\n        for (int ti = bi; ti < n; ti += bn) {\n            if (di < vdim) {\n                Vec vec;\n                Load(vec, src + (ti * vdim + di) * vec_size);\n                Store(dst + (ti * vdim + di) * vec_size, vec);\n            }\n        }\n    }\n\n    __syncthreads();\n\n    SystemSemaphore sem(semaphores, ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    Vec b_vec{};\n    if (bias && di < vdim) {\n        Ldg(b_vec, bias + di * vec_size);\n    }\n\n    Vec w_vec;\n    if (di < vdim) {\n        Ldg(w_vec, weights + di * vec_size);\n    }\n\n    const int n = min(count, rank * slice + slice) - rank * slice;\n\n    for (int ti = bi; ti < n; ti += bn) {\n        const int idx = ((rank * slice + ti) * vdim + di) * vec_size;  // idx into local buffers\n        Vec       r_vec{};\n        float     sum{};\n        if (di < vdim) {\n            Vec acc;\n            Load(acc, buf + idx);\n            for (int i = 1; i < ranks; ++i) {\n                const int p = rank + i < ranks ? 
rank + i : rank + i - ranks;\n                Vec       tmp;\n                Load(tmp, scratch + ((p * slice + ti) * vdim + di) * vec_size);\n                acc = acc + tmp;\n            }\n            Load(r_vec, res + idx);\n            r_vec = r_vec + acc;\n            if (bias) {\n                r_vec = r_vec + b_vec;\n            }\n            Store(res + idx, r_vec);\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                sum += (float)r_vec[i] * (float)r_vec[i];\n            }\n        }\n\n        sum = detail::GroupSum(sum, warps, syncgroup);\n        __shared__ float shared_sum[groups];\n        if (di == 0) {\n            shared_sum[xi] = rsqrtf(sum * inv_dim + eps);\n        }\n        syncgroup();\n        sum = shared_sum[xi];\n\n        if (di < vdim) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                r_vec[i] = static_cast<T>(((float)r_vec[i] * sum)) * w_vec[i];\n            }\n            Store(buf + idx, r_vec);\n            for (int i = 1; i < ranks; ++i) {\n                const int p = rank + i < ranks ? 
rank + i : rank + i - ranks;\n                Store(symm_buf[p] + ((rank * slice + ti) * vdim + di) * vec_size, r_vec);\n            }\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(true);\n    sem.Wait(true);\n\n    sem.Update(semaphores, ranks, blockIdx.x, threadIdx.x);\n}\n\nvoid CudaIpcCommImpl::AllreduceResidualBiasRMSnorm(void*        hidden,\n                                                   void*        residual,\n                                                   const void*  bias,\n                                                   const void*  weights,\n                                                   float        eps,\n                                                   int          dim,\n                                                   int          token_num,\n                                                   DataType     dtype,\n                                                   int          group,\n                                                   cudaStream_t stream)\n{\n\n    const size_t elemsize = byte_size(dtype);\n    const size_t bytesize = elemsize * token_num * dim;\n\n    const int n_ranks = this->n_ranks(group);\n    const int rank    = this->rank(group);\n\n    auto semaphore = groups_.at(group).semaphore.handle();\n\n    auto invoke = [&](auto t, auto groups) {\n        using T                = decltype(t);\n        auto          symm_ptr = get_symmetric_v2((T*)hidden, group);\n        constexpr int vec_size = sizeof(uint4) / sizeof(T);\n        const int     slice    = (token_num + n_ranks - 1) / n_ranks;\n        const int     count    = token_num;\n\n        if (symm_ptr.mc) {\n            constexpr int block_dim = 1024;\n            const int     max_ctas  = max_ctas_.apply(8);\n            const int     blocks    = std::min((slice + groups - 1) / groups, max_ctas);\n            AllreduceResidualBiasRMSnorm_NVLS<<<blocks, block_dim, 0, stream>>>(symm_ptr.mc,\n                                                         
                       (T*)hidden,\n                                                                                (T*)residual,\n                                                                                (const T*)bias,\n                                                                                (const T*)weights,\n                                                                                semaphore,\n                                                                                rank,\n                                                                                n_ranks,\n                                                                                slice,\n                                                                                count,\n                                                                                dim / vec_size,\n                                                                                1.f / dim,\n                                                                                eps,\n                                                                                constant<vec_size>{},\n                                                                                constant<block_dim>{},\n                                                                                groups,\n                                                                                std::false_type{});\n        }\n#if 1\n        else if (bytesize <= 1 << 19) {\n            return false;\n        }\n#endif\n        else if (bytesize <= kScratchBuffSize && bytesize <= 6 << 20) {\n            constexpr int block_dim    = 1024;\n            const int     max_ctas     = max_ctas_.apply(48);\n            const int     blocks       = std::min((slice + groups - 1) / groups, max_ctas);\n            auto          symm_scratch = get_symmetric_v2((T*)scratch_buff_, group).uc;\n            AllreduceResidualBiasRMSnorm_Simple_Push<<<blocks, block_dim, 0, 
stream>>>((T*)hidden,\n                                                                                       (T*)residual,\n                                                                                       (const T*)bias,\n                                                                                       (const T*)weights,\n                                                                                       (T*)scratch_buff_,\n                                                                                       symm_ptr.uc,\n                                                                                       symm_scratch,\n                                                                                       semaphore,\n                                                                                       rank,\n                                                                                       n_ranks,\n                                                                                       slice,\n                                                                                       count,\n                                                                                       dim / vec_size,\n                                                                                       1.f / dim,\n                                                                                       eps,\n                                                                                       constant<vec_size>{},\n                                                                                       constant<block_dim>{},\n                                                                                       groups,\n                                                                                       std::false_type{});\n        }\n        else {\n            constexpr int block_dim = 1024;\n            const int     max_ctas  = max_ctas_.apply(48);\n            const 
int     blocks    = std::min((slice + groups - 1) / groups, max_ctas);\n            AllreduceResidualBiasRMSnorm_Simple_Pull<<<blocks, block_dim, 0, stream>>>((T*)hidden,\n                                                                                       (T*)residual,\n                                                                                       (const T*)bias,\n                                                                                       (const T*)weights,\n                                                                                       symm_ptr.uc,\n                                                                                       semaphore,\n                                                                                       rank,\n                                                                                       n_ranks,\n                                                                                       slice,\n                                                                                       count,\n                                                                                       dim / vec_size,\n                                                                                       1.f / dim,\n                                                                                       eps,\n                                                                                       constant<vec_size>{},\n                                                                                       constant<block_dim>{},\n                                                                                       groups,\n                                                                                       std::false_type{});\n        }\n\n        return true;\n    };\n\n    auto dispatch_D = [&](auto t) {\n        using T                = decltype(t);\n        constexpr int vec_size = sizeof(uint4) / sizeof(T);\n        if (dim % 
vec_size) {\n            return false;  // non-aligned\n        }\n        const int vdim = dim / vec_size;\n        if (0) {}\n        else if (vdim <= 256) {\n            return invoke(t, constant<4>{});\n        }\n        else if (vdim <= 512) {\n            return invoke(t, constant<2>{});\n        }\n        else if (vdim <= 1024) {\n            return invoke(t, constant<1>{});\n        }\n        return false;  // > 1024 vdim\n    };\n\n    auto dispatch = [&]() -> bool { TM_DISPATCH_PRIMARY_DTYPES_RET(dtype, dispatch_D); };\n\n    if (dispatch()) {\n        return;\n    }\n\n    // fallback\n    AllReduceSum(hidden, hidden, token_num * dim, dtype, group, stream);\n    invokeResidualBiasRMSNorm(hidden, residual, weights, bias, dtype, dim, token_num, eps, stream);\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/fused_allreduce_ex.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/comm/cuda_ipc/common.h\"\n\n#include \"src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h\"\n#include \"src/turbomind/comm/cuda_ipc/group_sum.h\"\n#include \"src/turbomind/comm/cuda_ipc/semaphore.cuh\"\n\n#include \"src/turbomind/comm/cuda_ipc/multimem.cuh\"\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::comm {\n\ntemplate<class T, int vec_size, int block_dim, int groups, class Relaxed>\n__global__ void AllreduceResidualBiasRMSnormV_Simple_Pull(T*                     buf,\n                                                          T*                     res,\n                                                          const T*               bias,\n                                                          const T*               weights,\n                                                          Array<T*, kMaxRanks>   rs_buf,\n                                                          Array<T*, kMaxRanks>   ag_buf,\n                                                          SystemSemaphoreInfo*   g_semaphores,\n                                                          int                    rs_rank,\n                                                          int                    ag_rank,\n                                                          int                    rs_ranks,\n                                                          int                    ag_ranks,\n                                                          int                    g_rank,\n                                                          int                    g_ranks,\n                                                          int                    offset,\n                                   
                       int                    first,\n                                                          int                    last,\n                                                          Array<int2, kMaxRanks> ag_ranges,\n                                                          int                    vdim,\n                                                          float                  inv_dim,\n                                                          float                  eps,\n                                                          constant<vec_size>,\n                                                          constant<block_dim>,\n                                                          constant<groups>,\n                                                          Relaxed relaxed)\n{\n    SystemSemaphore sem(g_semaphores, g_ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n\n    using Vec = Array<T, vec_size>;\n\n    using namespace ops;\n\n    static_assert(block_dim % groups == 0);\n    constexpr int threads = block_dim / groups;\n\n    static_assert(threads % WARP_SIZE == 0);\n    constexpr int warps = threads / WARP_SIZE;\n\n    const int xi = threadIdx.x / threads;\n    const int di = threadIdx.x % threads;\n\n    Vec b_vec{};\n    if (bias && di < vdim) {\n        Ldg(b_vec, bias + di * vec_size);\n    }\n\n    Vec w_vec;\n    if (di < vdim) {\n        Ldg(w_vec, weights + di * vec_size);\n    }\n\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    const int bi = blockIdx.x * groups + xi;\n    const int bn = gridDim.x * groups;\n\n    for (int i = 1; i < rs_ranks - 1; ++i) {\n        const int  p   = rs_rank + i < rs_ranks ? 
rs_rank + i : rs_rank + i - rs_ranks;\n        const auto src = cvta_generic_to_global(rs_buf[p]);\n        Vec        acc, tmp;\n        for (int ti = offset + first + bi; ti < offset + last; ti += bn) {\n            const int idx = (ti * vdim + di) * vec_size;\n            if (di < vdim) {\n                Load(tmp, src + idx);\n                Load(acc, buf + idx);\n                acc = acc + tmp;\n                Store(buf + idx, acc);\n            }\n        }\n    }\n\n    auto syncgroup = [&] {  //\n        asm volatile(\"bar.sync %0, %1;\" : : \"r\"(15 - xi), \"r\"(threads) : \"memory\");\n    };\n\n    {\n        const T* chn{};\n        if (rs_ranks > 1) {\n            const int p = rs_rank > 0 ? rs_rank - 1 : rs_ranks - 1;  // last peer\n            chn         = cvta_generic_to_global(rs_buf[p]);\n        }\n        for (int ti = first + bi; ti < last; ti += bn) {\n            const int idx = ((offset + ti) * vdim + di) * vec_size;\n            Vec       acc, tmp;\n            Vec       r_vec{};\n            float     sum{};\n            if (di < vdim) {\n                if (chn) {\n                    Load(tmp, chn + idx);\n                }\n                Load(acc, buf + idx);\n                if (chn) {\n                    acc = acc + tmp;\n                }\n                Load(r_vec, res + (ti * vdim + di) * vec_size);\n                r_vec = r_vec + acc;\n                if (bias) {\n                    r_vec = r_vec + b_vec;\n                }\n                Store(res + (ti * vdim + di) * vec_size, r_vec);\n                PRAGMA_UNROLL\n                for (int i = 0; i < vec_size; ++i) {\n                    sum += (float)r_vec[i] * (float)r_vec[i];\n                }\n            }\n            sum = detail::GroupSum(sum, warps, syncgroup);\n            __shared__ float shared_sum[groups];\n            if (di == 0) {\n                shared_sum[xi] = rsqrtf(sum * inv_dim + eps);\n            }\n            syncgroup();\n            sum 
= shared_sum[xi];\n            if (di < vdim) {\n                PRAGMA_UNROLL\n                for (int i = 0; i < vec_size; ++i) {\n                    r_vec[i] = static_cast<T>(((float)r_vec[i] * sum)) * w_vec[i];\n                }\n                Store(buf + idx, r_vec);\n            }\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(relaxed);\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n#if 1\n    for (int i = 1; i < ag_ranks; ++i) {\n        const int p   = ag_rank + i < ag_ranks ? ag_rank + i : ag_rank + i - ag_ranks;\n        auto      dst = cvta_generic_to_global(ag_buf[p]);\n        for (int ti = offset + first + bi; ti < offset + last; ti += bn) {\n            const int idx = (ti * vdim + di) * vec_size;\n            if (di < vdim) {\n                Vec vec;\n                Load(vec, buf + idx);\n                Store(dst + idx, vec);\n            }\n        }\n    }\n#else\n    for (int i = 1; i < ag_ranks; ++i) {\n        const int p              = ag_rank + i < ag_ranks ? 
ag_rank + i : ag_rank + i - ag_ranks;\n        const auto [first, last] = ag_ranges[p];\n        auto src                 = cvta_generic_to_global(ag_buf[p]);\n        for (int ti = first + bi; ti < last; ti += bn) {\n            const int idx = (ti * vdim + di) * vec_size;\n            if (di < vdim) {\n                Vec vec;\n                Load(vec, src + idx);\n                Store(buf + idx, vec);\n            }\n        }\n    }\n#endif\n\n    __syncthreads();\n\n    sem.Signal(true);\n    sem.Wait(true);\n\n    sem.Update(g_semaphores, g_ranks, blockIdx.x, threadIdx.x);\n}\n\ntemplate<class T, int vec_size, int block_dim, int groups, class Relaxed>\n__global__ void AllreduceResidualBiasRMSnormV_NVLS(T*                   rs_mc_buf,\n                                                   T*                   ag_mc_buf,\n                                                   T*                   res,\n                                                   const T*             bias,\n                                                   const T*             weights,\n                                                   SystemSemaphoreInfo* semaphores,\n                                                   int                  g_rank,\n                                                   int                  g_ranks,\n                                                   int                  first,\n                                                   int                  last,\n                                                   int                  offset,\n                                                   int                  vdim,\n                                                   float                inv_dim,\n                                                   float                eps,\n                                                   constant<vec_size>,\n                                                   constant<block_dim>,\n                                                   
constant<groups>,\n                                                   Relaxed relaxed)\n{\n\n#if TURBOMIND_ARCH_SM90\n\n    SystemSemaphore sem(semaphores, g_ranks, blockIdx.x, threadIdx.x);\n\n    sem.Signal(relaxed);\n\n    using Vec = Array<T, vec_size>;\n\n    using namespace ops;\n\n    static_assert(block_dim % groups == 0);\n    constexpr int threads = block_dim / groups;\n\n    static_assert(threads % WARP_SIZE == 0);\n    constexpr int warps = threads / WARP_SIZE;\n\n    const int xi = threadIdx.x / threads;\n    const int di = threadIdx.x % threads;\n\n    using Vec = Array<T, vec_size>;\n\n    Vec b_vec{};\n    if (bias && di < vdim) {\n        Ldg(b_vec, bias + di * vec_size);\n    }\n\n    Vec w_vec;\n    if (di < vdim) {\n        Ldg(w_vec, weights + di * vec_size);\n    }\n\n    sem.Wait(relaxed);\n\n    __syncthreads();\n\n    const int bi = blockIdx.x * groups + xi;\n    const int bn = gridDim.x * groups;\n\n    auto syncgroup = [&] {  //\n        asm volatile(\"bar.sync %0, %1;\" : : \"r\"(15 - xi), \"r\"(threads) : \"memory\");\n    };\n\n    for (int ti = first + bi; ti < last; ti += bn) {\n        const int idx = ((offset + ti) * vdim + di) * vec_size;\n        float     sum{};\n        Vec       vec;\n        if (di < vdim) {\n            Vec acc = multimem_ld_reduce_sum((const Vec*)(rs_mc_buf + idx));\n            Load(vec, res + (ti * vdim + di) * vec_size);\n            vec = vec + acc;\n            if (bias) {\n                vec = vec + b_vec;\n            }\n            Store(res + (ti * vdim + di) * vec_size, vec);\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                sum += (float)vec[i] * (float)vec[i];\n            }\n        }\n        sum = detail::GroupSum(sum, warps, syncgroup);\n        __shared__ float shared_sum[groups];\n        if (di == 0) {\n            shared_sum[xi] = rsqrtf(sum * inv_dim + eps);\n        }\n        syncgroup();\n        sum = shared_sum[xi];\n        if (di < 
vdim) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                vec[i] = static_cast<T>(((float)vec[i] * sum)) * w_vec[i];\n            }\n            multimem_st(ag_mc_buf + idx, vec);\n        }\n    }\n\n    __syncthreads();\n\n    sem.Signal(true);\n    sem.Wait(true);\n\n    sem.Update(semaphores, g_ranks, blockIdx.x, threadIdx.x);\n#endif\n}\n\nvoid CudaIpcCommImpl::AllreduceResidualBiasRMSnormEx(void*        hidden,\n                                                     void*        residual,\n                                                     const void*  bias,\n                                                     const void*  weights,\n                                                     float        eps,\n                                                     int          dim,\n                                                     DataType     dtype,\n                                                     int          group0,\n                                                     int          group1,\n                                                     const int*   local_token_nums,\n                                                     cudaStream_t stream)\n{\n    FT_CHECK(group0 * group1 == 0);\n\n    const auto& g0 = groups_.at(group0);\n    const auto& g1 = groups_.at(group1);\n\n    const int tp0 = n_ranks(group0);\n    const int tp1 = n_ranks(group1);\n\n    const int inner_tp = std::min(tp0, tp1);\n\n    FT_CHECK(tp0 % inner_tp == 0 && tp1 % inner_tp == 0);\n\n    Array<int, kMaxRanks> offsets{};\n    Array<int, kMaxRanks> firsts{};\n    Array<int, kMaxRanks> lasts{};\n\n    for (int i = 0, offset = 0; i < global_n_ranks_; ++i) {\n        const int num   = local_token_nums[i / inner_tp];\n        const int slice = (num + inner_tp - 1) / inner_tp;\n        const int first = std::min(num, i % inner_tp * slice);\n        const int last  = std::min(num, first + slice);\n\n        std::tie(offsets[i], firsts[i], 
lasts[i]) = std::tie(offset, first, last);\n\n        if ((i + 1) % inner_tp == 0) {\n            offset += num;\n        }\n    }\n    const int g_rank = rank(0);\n\n    const int first  = firsts[g_rank];\n    const int last   = lasts[g_rank];\n    const int offset = offsets[g_rank];\n\n    auto semaphore = groups_.at(0).semaphore.handle();\n\n    auto invoke = [&](auto t, auto groups) {\n        using T                = decltype(t);\n        constexpr int vec_size = sizeof(uint4) / sizeof(T);\n\n        auto rs_symm_ptr = get_symmetric_v2((T*)hidden, group0);\n        auto ag_symm_ptr = get_symmetric_v2((T*)hidden, group1);\n\n        if (rs_symm_ptr.mc && ag_symm_ptr.mc) {\n            const int max_ctas = max_ctas_.apply(40);\n            AllreduceResidualBiasRMSnormV_NVLS<<<max_ctas, 1024, 0, stream>>>(rs_symm_ptr.mc,\n                                                                              ag_symm_ptr.mc,\n                                                                              (T*)residual,\n                                                                              (const T*)bias,\n                                                                              (const T*)weights,\n                                                                              semaphore,\n                                                                              g_rank,\n                                                                              n_ranks(0),\n                                                                              first,\n                                                                              last,\n                                                                              offset,\n                                                                              dim / vec_size,\n                                                                              1.f / dim,\n                                                                        
      eps,\n                                                                              constant<vec_size>{},\n                                                                              constant<1024>{},\n                                                                              constant<1>{},\n                                                                              std::true_type{});\n        }\n        else {\n            Array<int2, kMaxRanks> ag_ranges{};\n            for (int i = 0; i < tp1; ++i) {\n                const auto r = g1.l2g[i];\n                ag_ranges[i] = {offsets[r] + firsts[r], offsets[r] + lasts[r]};\n            }\n            const int max_ctas = max_ctas_.apply(48);\n            AllreduceResidualBiasRMSnormV_Simple_Pull<<<max_ctas, 1024, 0, stream>>>((T*)hidden,\n                                                                                     (T*)residual,\n                                                                                     (const T*)bias,\n                                                                                     (const T*)weights,\n                                                                                     rs_symm_ptr.uc,\n                                                                                     ag_symm_ptr.uc,\n                                                                                     semaphore,\n                                                                                     rank(group0),\n                                                                                     rank(group1),\n                                                                                     tp0,\n                                                                                     tp1,\n                                                                                     rank(0),\n                                                                                     n_ranks(0),\n  
                                                                                   offset,\n                                                                                     first,\n                                                                                     last,\n                                                                                     ag_ranges,\n                                                                                     dim / vec_size,\n                                                                                     1.f / dim,\n                                                                                     eps,\n                                                                                     constant<vec_size>{},\n                                                                                     constant<1024>{},\n                                                                                     constant<1>{},\n                                                                                     std::true_type{});\n        }\n        return true;\n    };\n\n    sync_check_cuda_error();\n\n    auto dispatch_D = [&](auto t) {\n        using T                = decltype(t);\n        constexpr int vec_size = sizeof(uint4) / sizeof(T);\n        if (dim % vec_size) {\n            return false;  // non-aligned\n        }\n        const int vdim = dim / vec_size;\n        if (0) {}\n        else if (vdim <= 256) {\n            return invoke(t, constant<4>{});\n        }\n        else if (vdim <= 512) {\n            return invoke(t, constant<2>{});\n        }\n        else if (vdim <= 1024) {\n            return invoke(t, constant<1>{});\n        }\n        return false;  // > 1024 vdim\n    };\n\n    auto dispatch = [&]() -> bool {  //\n        TM_DISPATCH_PRIMARY_DTYPES_RET(dtype, dispatch_D);\n    };\n\n    TM_CHECK(dispatch());\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/group_sum.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n\nnamespace turbomind::comm {\n\nnamespace detail {\n\ntemplate<class Syncgroup>\n__device__ float GroupSum(const float val, int warps, Syncgroup syncgroup)\n{\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    const int lane_id = threadIdx.x % WARP_SIZE;\n    float     sum     = val;\n    PRAGMA_UNROLL\n    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {\n        sum += __shfl_xor_sync((uint32_t)-1, sum, mask);\n    }\n    __shared__ float smem[32];\n    // syncgroup();\n    if (lane_id == 0) {\n        smem[warp_id] = sum;\n    }\n    syncgroup();\n    // Only the first warp of each group ends up with the exact group total: it starts from its own partial\n    // and adds the other warps' partials. The remaining warps re-add their own partial in place of the first\n    // warp's, so callers must read the result from the group's first thread and broadcast it.\n    for (int i = 1; i < warps; ++i) {\n        sum += smem[warp_id / warps * warps + i];\n    }\n    // sum = {};\n    // for (int i = 0; i < warps; ++i) {\n    //     sum += smem[warp_id / warps * warps + i];\n    // }\n    return sum;\n}\n\n}  // namespace detail\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/mscclpp.h",
    "content": "// Copyright (c) Microsoft Corporation.\n// Licensed under the MIT license.\n\n#pragma once\n\n#include <cstdint>\n#include <cuda_runtime.h>\n\nnamespace mscclpp {\n\n// Copied from\n// https://github.com/microsoft/mscclpp/blob/591276f9d07d2df8e2a45a16738e27867e468ca3/include/mscclpp/packet_device.hpp#L19\nunion alignas(16) LL16Packet {\n    // Assume data is written with an atomicity of 8 bytes (IB/RDMA).\n    struct {\n        uint32_t data1;\n        uint32_t flag1;\n        uint32_t data2;\n        uint32_t flag2;\n    };\n    using Payload = uint2;\n\n    ulonglong2 raw_;\n\n    __device__ LL16Packet() {}\n\n    __device__ LL16Packet(uint2 val, uint32_t flag)\n    {\n        data1 = val.x;\n        flag1 = flag;\n        data2 = val.y;\n        flag2 = flag;\n    }\n\n    /// Write 8 bytes of data to the packet.\n    /// @param val1 The first 4-byte data to write.\n    /// @param val2 The second 4-byte data to write.\n    /// @param flag The flag to write.\n    __device__ void write(uint32_t val1, uint32_t val2, uint32_t flag)\n    {\n        asm volatile(\n            \"st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};\" ::\"l\"(&raw_), \"r\"(val1), \"r\"(flag), \"r\"(val2), \"r\"(flag));\n    }\n\n    /// Write 8 bytes of data to the packet.\n    /// @param val The 8-byte data to write.\n    /// @param flag The flag to write.\n    __device__ void write(uint64_t val, uint32_t flag)\n    {\n        write((uint32_t)val, (uint32_t)(val >> 32), flag);\n    }\n\n    /// Write 8 bytes of data to the packet.\n    /// @param val The 8-byte data to write.\n    /// @param flag The flag to write.\n    __device__ void write(uint2 val, uint32_t flag)\n    {\n        write(val.x, val.y, flag);\n    }\n\n    /// Helper of @ref read().\n    /// @param flag The flag to read.\n    /// @param data The 8-byte data read.\n    /// @return True if the flag is not equal to the given flag.\n    __device__ bool readOnce(uint32_t flag, uint2& data) const\n    {\n        
uint32_t flag1, flag2;\n        asm volatile(\"ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];\"\n                     : \"=r\"(data.x), \"=r\"(flag1), \"=r\"(data.y), \"=r\"(flag2)\n                     : \"l\"(&raw_));\n        return (flag1 != flag) || (flag2 != flag);\n    }\n\n    /// Read 8 bytes of data from the packet.\n    /// @param flag The flag to read.\n    /// @return The 8-byte data read.\n    __device__ uint2 read(uint32_t flag) const\n    {\n        uint2 data;\n        while (readOnce(flag, data)) {}\n        return data;\n    }\n\n    /// Clear the packet.\n    __device__ void clear()\n    {\n        raw_ = make_ulonglong2(0, 0);\n    }\n};\n\nusing LLPacket = LL16Packet;\n\n}  // namespace mscclpp\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/multimem.cuh",
    "content": "#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include <cuda_bf16.h>\n\nnamespace turbomind {\n\ntemplate<class T, int N>\ninline __device__ Array<T, N> multimem_ld_reduce_sum(const Array<T, N>* mc_ptr)\n{\n    return {};\n}\n\ninline __device__ Array<half, 8> multimem_ld_reduce_sum(const Array<half, 8>* mc_ptr)\n{\n    union {\n        Array<half, 8>     x;\n        Array<uint32_t, 4> u;\n    };\n    // LDGMC.E.ADD.F16x8.RN.STRONG.SYS\n    asm volatile(\"multimem.ld_reduce.weak.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];\"\n                 : \"=r\"(u[0]), \"=r\"(u[1]), \"=r\"(u[2]), \"=r\"(u[3])\n                 : \"l\"(mc_ptr)\n                 : \"memory\");\n    return x;\n}\n\ninline __device__ Array<nv_bfloat16, 8> multimem_ld_reduce_sum(const Array<nv_bfloat16, 8>* mc_ptr)\n{\n    union {\n        Array<nv_bfloat16, 8> x;\n        Array<uint32_t, 4>    u;\n    };\n    asm volatile(\"multimem.ld_reduce.weak.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];\"\n                 : \"=r\"(u[0]), \"=r\"(u[1]), \"=r\"(u[2]), \"=r\"(u[3])\n                 : \"l\"(mc_ptr)\n                 : \"memory\");\n    return x;\n}\n\ntemplate<class T, int N>\ninline __device__ void multimem_st(T* mc_ptr, const Array<T, N>& vec)\n{\n}\n\ninline __device__ void multimem_st(half* mc_ptr, const Array<half, 8>& vec)\n{\n    union {\n        Array<half, 8>     x;\n        Array<uint32_t, 4> u;\n    };\n    x = vec;\n    // STG.E.128\n    asm volatile(\"multimem.st.weak.global.v4.f16x2 [%0], {%1,%2,%3,%4};\" ::\"l\"(mc_ptr),\n                 \"r\"(u[0]),\n                 \"r\"(u[1]),\n                 \"r\"(u[2]),\n                 \"r\"(u[3]));\n}\n\ninline __device__ void multimem_st(nv_bfloat16* mc_ptr, const Array<nv_bfloat16, 8>& vec)\n{\n    union {\n        Array<nv_bfloat16, 8> x;\n        Array<uint32_t, 4>    u;\n    };\n    x = vec;\n    asm volatile(\"multimem.st.weak.global.v4.bf16x2 [%0], {%1,%2,%3,%4};\" ::\"l\"(mc_ptr),\n                 
\"r\"(u[0]),\n                 \"r\"(u[1]),\n                 \"r\"(u[2]),\n                 \"r\"(u[3]));\n}\n\ninline __device__ void multimem_st(uint4* mc_ptr, const uint4& u)\n{\n    asm volatile(\n        \"multimem.st.weak.global.v4.f16x2 [%0], {%1,%2,%3,%4};\" ::\"l\"(mc_ptr), \"r\"(u.x), \"r\"(u.y), \"r\"(u.z), \"r\"(u.w));\n}\n\ninline __device__ void multimem_st(uint2* mc_ptr, const uint2& u) {}\n\ninline __device__ void multimem_st(uint* mc_ptr, const uint& u) {}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/semaphore.cuh",
    "content": "#pragma once\n\n#include <vector>\n\n#include \"src/turbomind/kernels/core/array.h\"\n\n#include \"src/turbomind/comm/cuda_ipc/common.h\"\n#include \"src/turbomind/comm/cuda_ipc/semaphore.h\"\n\nnamespace turbomind::comm {\n\ntemplate<class T>\n__device__ T* cvta_generic_to_global(T* p)\n{\n    uintptr_t ret;\n    asm(\"cvta.to.global.u64 %0, %1;\" : \"=l\"(ret) : \"l\"(p));\n    return reinterpret_cast<T*>(ret);\n}\n\nstruct SystemSemaphore {\n\n    using T = uint64_t;\n\n    T* outbound_;\n    T* inbound_;\n    T  expected_;\n    // T* mc_ptr_;\n\n    bool uc_predicate_;\n    // bool mc_predicate_;\n\n    __device__ SystemSemaphore(const SystemSemaphoreInfo* info, int ranks, int channel, int thread_idx)\n    {\n        uc_predicate_ = thread_idx < ranks;\n        // mc_predicate_ = thread_idx == 0;\n\n        if (uc_predicate_) {\n            int index = channel * kMaxRanks + thread_idx;\n            inbound_  = info->inbound[index];\n            outbound_ = info->outbound[index];\n            expected_ = info->expected[index];\n            // mc_ptr_   = info->mc_ptr[channel];\n        }\n    }\n\n    __device__ void Update(SystemSemaphoreInfo* info, int ranks, int channel, int thread_idx)\n    {\n        if (uc_predicate_) {\n            info->expected[channel * kMaxRanks + thread_idx] = expected_;\n        }\n    }\n\n    __device__ void Signal(bool relaxed)\n    {\n        if (uc_predicate_) {\n            if (relaxed) {\n                asm volatile(\"atom.relaxed.sys.global.add.u64 _, [%0], %1;\" ::\"l\"(outbound_), \"n\"(1) : \"memory\");\n            }\n            else {\n                asm volatile(\"atom.release.sys.global.add.u64 _, [%0], %1;\" ::\"l\"(outbound_), \"n\"(1) : \"memory\");\n            }\n        }\n    }\n\n    __device__ void Wait(bool relaxed)\n    {\n        if (uc_predicate_) {\n            ++expected_;\n            T x{};\n            do {\n                if (relaxed) {\n                    asm 
volatile(\"ld.relaxed.sys.global.u64 %0,[%1];\" : \"=l\"(x) : \"l\"(inbound_) : \"memory\");\n                }\n                else {\n                    asm volatile(\"ld.acquire.sys.global.u64 %0,[%1];\" : \"=l\"(x) : \"l\"(inbound_) : \"memory\");\n                }\n            } while (x < expected_);\n        }\n    }\n\n    //     __device__ void SignalMulticast(bool relaxed)\n    //     {\n    // #if TURBOMIND_ARCH_SM90\n    //         if (mc_predicate_) {\n    //             if (relaxed) {\n    //                 asm volatile(\"multimem.red.relaxed.sys.global.add.u64 [%0], %1;\" ::\"l\"(mc_ptr_), \"n\"(1) :\n    //                 \"memory\");\n    //             }\n    //             else {\n    //                 asm volatile(\"multimem.red.release.sys.global.add.u64 [%0], %1;\" ::\"l\"(mc_ptr_), \"n\"(1) :\n    //                 \"memory\");\n    //             }\n    //             asm volatile(\"fence.proxy.alias;\" ::: \"memory\");\n    //         }\n    // #endif\n    //     }\n};\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/cuda_ipc/semaphore.h",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/comm/cuda_ipc/common.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::comm {\n\nstruct SystemSemaphoreInfo {\n    uint64_t* outbound[kMaxChannels * kMaxRanks];\n    uint64_t* inbound[kMaxChannels * kMaxRanks];\n    uint64_t  expected[kMaxChannels * kMaxRanks];\n    // uint64_t* mc_ptr[kMaxChannels];\n};\n\nstruct SystemSemaphoreStorage {\n\n    uint64_t*            data_{};  // uint64[kMaxChannels][kMaxRanks], symmetric\n    SystemSemaphoreInfo* info_{};\n\n    template<class AllocReg>\n    void Allocate(int ranks, int rank, AllocReg alloc_reg)\n    {\n        const size_t byte_size = sizeof(uint64_t) * kMaxChannels * kMaxRanks;\n\n        SymmetricPtr_V2<uint64_t> v = alloc_reg(byte_size);\n\n        data_ = v.uc[rank];\n\n        SystemSemaphoreInfo info{};\n\n        for (int c = 0; c < kMaxChannels; ++c) {  // block idx\n            for (int r = 0; r < ranks; ++r) {     // thread idx\n                info.inbound[c * kMaxRanks + r]  = v.uc[rank] + c * kMaxRanks + r;\n                info.outbound[c * kMaxRanks + r] = v.uc[r] + c * kMaxRanks + rank;\n                // info.mc_ptr[c]                   = v.mc + c * kMaxRanks + rank;\n            }\n        }\n\n        check_cuda_error(cudaMallocAsync(&info_, sizeof(SystemSemaphoreInfo), 0));\n        check_cuda_error(cudaMemcpyAsync(info_, &info, sizeof(SystemSemaphoreInfo), cudaMemcpyDefault, 0));\n\n        check_cuda_error(cudaStreamSynchronize(0));\n    }\n\n    template<class DeregFree>\n    void Free(DeregFree dereg_free)\n    {\n        check_cuda_error(cudaFreeAsync(info_, 0));\n        info_ = {};\n\n        dereg_free(data_);\n        data_ = {};\n    }\n\n    SystemSemaphoreInfo* handle()\n    {\n        return info_;\n    }\n};\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/device_comm.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::comm {\n\nDeviceCommImpl::~DeviceCommImpl() = default;\n\nDeviceComm CreateNcclCommunicator(int n_ranks, int rank, HostComm h_comm);\n\nDeviceComm CreateCudaIpcCommunicator(int n_ranks, int rank, HostComm h_comm);\n\nDeviceComm CreateDeviceCommunicator(const std::string& backend, int n_ranks, int rank, HostComm h_comm)\n{\n#if BUILD_MULTI_GPU && USE_NCCL\n    if (backend == \"nccl\") {\n        return CreateNcclCommunicator(n_ranks, rank, h_comm);\n    }\n#endif\n\n#if BUILD_MULTI_GPU\n    if (backend == \"native\" || backend == \"cuda-ipc\") {\n        return CreateCudaIpcCommunicator(n_ranks, rank, h_comm);\n    }\n#endif\n\n    TM_CHECK(0) << \"Unknown communication backend: \" << backend;\n    return {};\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/device_comm.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <memory>\n\n#include <stdexcept>\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/comm/host_comm.h\"\n\nnamespace turbomind::comm {\n\nenum QueryAttr\n{\n    kHasAllGather2D\n};\n\nclass DeviceCommImpl {\npublic:\n    virtual ~DeviceCommImpl();\n\n    virtual int n_ranks(int group) const = 0;\n\n    virtual int rank(int group) const = 0;\n\n    virtual void* Allocate(size_t size) = 0;\n\n    virtual void Free(void* ptr) = 0;\n\n    virtual void Register(void* ptr, size_t size) = 0;\n\n    virtual void Deregister(void* ptr) = 0;\n\n    virtual int Split(int color, int key, int group)\n    {\n        throw std::runtime_error(\"not implemented\");\n    }\n\n    virtual int Query(QueryAttr attr) const noexcept = 0;\n\n    virtual void AllReduceSum(const void*  sendbuff,  //\n                              void*        recvbuff,\n                              size_t       count,\n                              DataType     type,\n                              int          group,\n                              cudaStream_t stream) = 0;\n\n    virtual void AllGather(const void*  sendbuff,  //\n                           void*        recvbuff,\n                           size_t       sendcount,\n                           DataType     type,\n                           int          group,\n                           cudaStream_t stream) = 0;\n\n    virtual void ReduceScatter(const void*  sendbuff,  //\n                               void*        recvbuff,\n                               size_t       recvcount,\n                               DataType     type,\n                               int          group,\n                               cudaStream_t stream)\n    {\n        throw std::runtime_error(\"not implemented\");\n    }\n\n    virtual void AllreduceResidualBiasRMSnorm(void*        hidden,\n                                              void*        residual,\n       
                                       const void*  bias,\n                                              const void*  weights,\n                                              float        eps,\n                                              int          dim,\n                                              int          token_num,\n                                              DataType     dtype,\n                                              int          group,\n                                              cudaStream_t stream)\n    {\n        throw std::runtime_error(\"not implemented\");\n    }\n\n    virtual void AllreduceResidualBiasRMSnormEx(void*        hidden,\n                                                void*        residual,\n                                                const void*  bias,\n                                                const void*  weights,\n                                                float        eps,\n                                                int          dim,\n                                                DataType     type,\n                                                int          group0,\n                                                int          group1,\n                                                const int*   local_token_nums,\n                                                cudaStream_t stream)\n    {\n        throw std::runtime_error(\"not implemented\");\n    }\n\n    virtual void AllGather2D(const void*  sendbuff,\n                             void*        recvbuff,\n                             size_t       pitch,\n                             size_t       stride,\n                             int          width,\n                             int          height,\n                             DataType     type,\n                             int2         flags,  // (is_first, is_last)\n                             int          group,\n                             cudaStream_t stream)\n    {\n        
throw std::runtime_error(\"not implemented\");\n    }\n\n    virtual void Broadcast(const void*  sendbuff,  //\n                           void*        recvbuff,\n                           size_t       count,\n                           DataType     type,\n                           int          root,\n                           int          group,\n                           cudaStream_t stream)\n    {\n        throw std::runtime_error(\"not implemented\");\n    }\n};\n\nclass DeviceComm {\npublic:\n    DeviceComm() = default;\n\n    /* implicit */ DeviceComm(std::unique_ptr<DeviceCommImpl> impl): impl_{std::move(impl)} {}\n\n    DeviceCommImpl* operator->() const noexcept\n    {\n        return impl_.get();\n    }\n\n    operator DeviceCommImpl*() const noexcept\n    {\n        return impl_.get();\n    }\n\nprivate:\n    std::unique_ptr<DeviceCommImpl> impl_;\n};\n\nDeviceComm CreateDeviceCommunicator(const std::string& backend,  //\n                                    int                n_ranks,\n                                    int                rank,\n                                    HostComm           h_comm);\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/env.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cstdlib>\n#include <sstream>\n#include <string>\n#include <type_traits>\n\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind {\n\ntemplate<class E>\nauto GetEnv()\n{\n    static auto value = [] {\n        bool is_set{};\n        auto x  = E::init();\n        using T = decltype(x);\n        try {\n            if (auto p = std::getenv(E::full_name)) {\n                is_set = true;\n                if constexpr (std::is_integral_v<T>) {\n                    x = std::stoll(p);\n                }\n                else if constexpr (std::is_floating_point_v<T>) {\n                    x = std::stod(p);\n                }\n                else if constexpr (std::is_same_v<T, std::string>) {\n                    x = std::string{p};\n                }\n                else {\n                    static_assert(!std::is_same_v<T, T>, \"not implemented\");\n                }\n            }\n        }\n        catch (...) 
{\n        }\n        if (is_set) {\n            std::stringstream ss;\n            ss << x;\n            TM_LOG_INFO(\"[%s] %s=%s\", E::prefix, E::name, ss.str().c_str());\n        }\n        return x;\n    }();\n    return value;\n}\n\n#define TM_ENV_VAR(prefix_, name_, init_)                                                                              \\\n    struct prefix_##_##name_ {                                                                                         \\\n        static auto init()                                                                                             \\\n        {                                                                                                              \\\n            return init_;                                                                                              \\\n        }                                                                                                              \\\n        static constexpr auto prefix    = #prefix_;                                                                    \\\n        static constexpr auto name      = #name_;                                                                      \\\n        static constexpr auto full_name = \"TM_\" #prefix_ \"_\" #name_;                                                   \\\n    }\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/comm/gloo/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\ncmake_minimum_required(VERSION 3.8)\n\ninclude(FetchContent)\nFetchContent_Declare(\n  gloo\n  GIT_REPOSITORY https://github.com/pytorch/gloo.git\n  GIT_TAG c7b7b022c124d9643957d9bd55f57ac59fce8fa2 # pytorch-v2.8.0-rc4\n)\n\n# some settings of gloo,\nset(GLOO_INSTALL OFF CACHE BOOL \"\" FORCE)\nset(GLOO_STATIC_OR_SHARED STATIC CACHE STRING \"\" FORCE)\nset(USE_NCCL OFF)\nset(BUILD_TEST OFF)\nset(USE_IBVERBS OFF)\nFetchContent_MakeAvailable(gloo)\n\n# gloo build doesn't add include directories as a target property...\ntarget_include_directories(gloo PUBLIC\n    $<BUILD_INTERFACE:${gloo_SOURCE_DIR}>\n    $<BUILD_INTERFACE:${gloo_BINARY_DIR}> # config.h generated at cmake config time\n)\n\ntarget_compile_options(gloo PRIVATE\n    $<$<CXX_COMPILER_ID:MSVC>:/W0>\n    $<$<OR:$<CXX_COMPILER_ID:GNU>,$<CXX_COMPILER_ID:Clang>>:-w>\n)\n\nadd_library(gloo_comm STATIC\n    gloo_comm.cc\n    hybrid_comm.cc\n    tcp_store.cc\n)\nset_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON)\ntarget_link_libraries(gloo_comm PUBLIC gloo host_comm logger xgrammar)\n\nadd_executable(test_ipc_comm test_ipc_comm.cc)\ntarget_link_libraries(test_ipc_comm PRIVATE gloo_comm Threads::Threads)\n"
  },
  {
    "path": "src/turbomind/comm/gloo/gloo_comm.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <mutex>\n\n#include <gloo/allgather.h>\n#include <gloo/allreduce.h>\n#include <gloo/barrier.h>\n#include <gloo/broadcast.h>\n#include <gloo/common/utils.h>\n#include <gloo/config.h>\n#include <gloo/context.h>\n#include <gloo/math.h>\n#include <gloo/rendezvous/context.h>\n#include <gloo/rendezvous/prefix_store.h>\n#include <gloo/rendezvous/store.h>\n#include <gloo/transport/tcp/attr.h>\n#include <gloo/transport/tcp/device.h>\n\n#if GLOO_HAVE_TRANSPORT_IBVERBS\n#include \"gloo/transport/ibverbs/device.h\"\n#endif\n\n#include \"src/turbomind/comm/gloo/tcp_store.h\"\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind::comm {\n\nconst char* GLOO_SOCKET_IFNAME_ENV = \"GLOO_SOCKET_IFNAME\";\nconst char  STORE_INFO_DELIM       = ',';\n\nstd::shared_ptr<::gloo::transport::Device> createGlooDevice()\n{\n#if GLOO_HAVE_TRANSPORT_IBVERBS\n    if (auto transport = std::getenv(\"GLOO_DEVICE_TRANSPORT\");\n        transport != nullptr && strcmp(transport, \"ibverbs\") == 0) {\n        ::gloo::transport::ibverbs::attr ib_attr{};\n        ib_attr.name  = \"\";\n        ib_attr.port  = 1;\n        ib_attr.index = 3;  // use IBV_GID_TYPE_ROCE_V2 and ipv4\n        return ::gloo::transport::ibverbs::CreateDevice(ib_attr);\n    }\n#endif\n    ::gloo::transport::tcp::attr attr;\n    if (auto ifname = std::getenv(GLOO_SOCKET_IFNAME_ENV); ifname) {\n        attr.iface = ifname;\n    }\n    else {\n        attr.hostname = ::gloo::getHostname();\n    }\n    return ::gloo::transport::tcp::CreateDevice(attr);\n}\n\nclass Store: public ::gloo::rendezvous::PrefixStore {\npublic:\n    explicit Store(const std::string& host, int port, const std::string& prefix):\n        host_(host), port_(port), ::gloo::rendezvous::PrefixStore(prefix, nullptr)\n    {\n        store_ = std::make_shared<TCPStore>(host_, port_);\n    };\n\n    ~Store() = 
default;\n\n    std::shared_ptr<Store> New(const std::string& prefix)\n    {\n        std::string new_prefix = prefix + \"/\" + prefix_;\n        return std::make_shared<Store>(host_, port_, new_prefix);\n    }\n\npublic:\n    std::string host_;\n    int         port_;\n\n    using ::gloo::rendezvous::PrefixStore::store_;\n    using ::gloo::rendezvous::PrefixStore::prefix_;\n};\n\nclass GlobalStoreFactory {\npublic:\n    static GlobalStoreFactory& Instance()\n    {\n        static GlobalStoreFactory instance;\n        return instance;\n    }\n\n    std::string New()\n    {\n        std::lock_guard<std::mutex> lock(mutex_);\n        TM_CHECK(std::getenv(\"LMDEPLOY_DIST_INIT_ADDR\") != nullptr) << \"LMDEPLOY_DIST_INIT_ADDR not set\";\n        TM_CHECK(std::getenv(\"LMDEPLOY_DIST_INIT_PORT\") != nullptr) << \"LMDEPLOY_DIST_INIT_PORT not set\";\n\n        std::string host = std::getenv(\"LMDEPLOY_DIST_INIT_ADDR\");\n        int         port = std::stoi(std::getenv(\"LMDEPLOY_DIST_INIT_PORT\"));\n\n        std::stringstream ss;\n        ss << host << STORE_INFO_DELIM << port << STORE_INFO_DELIM << prefix_++;\n        return ss.str();\n    }\n\n    std::shared_ptr<Store> Load(const std::string& info)\n    {\n        std::stringstream        ss(info);\n        std::vector<std::string> keys;\n        std::string              local;\n        while (getline(ss, local, STORE_INFO_DELIM)) {\n            keys.push_back(std::move(local));\n        }\n        TM_CHECK(keys.size() == 3);\n\n        std::string host   = keys[0];\n        int         port   = stoi(keys[1]);\n        std::string prefix = keys[2];\n\n        return std::make_shared<Store>(host, port, prefix);\n    }\n\nprivate:\n    GlobalStoreFactory() {}\n\n    std::mutex mutex_;\n    int        prefix_{0};\n};\n\ntypedef void (*ReduceFunc)(void*, const void*, const void*, size_t);\n\nstruct GlooCommImpl: public HostCommImpl {\n\n    struct SplitInfo {\n        int color;\n        int rank;\n\n        bool 
operator<(const SplitInfo& other) const\n        {\n            return (color < other.color) || (color == other.color && rank < other.rank);\n        }\n\n        bool operator==(const SplitInfo& other) const\n        {\n            return (color == other.color) && (rank == other.rank);\n        }\n    };\n\n    GlooCommImpl(std::shared_ptr<Store> store, int n_ranks, int rank):\n        store_{std::move(store)}, rank_{rank}, n_ranks_{n_ranks}\n    {\n        device_  = createGlooDevice();\n        context_ = std::make_shared<::gloo::rendezvous::Context>(rank_, n_ranks_);\n        context_->setTimeout(kTimeOut);\n        context_->connectFullMesh(store_, device_);\n    }\n\n    ~GlooCommImpl() {}\n\n    int rank() const override\n    {\n        return rank_;\n    }\n\n    int n_ranks() const override\n    {\n        return n_ranks_;\n    }\n\n    bool is_same_process() const override\n    {\n        return false;\n    }\n\n    std::shared_ptr<HostCommImpl> Split(int color, int key) override\n    {\n        auto vec  = comm::AllGather(this, SplitInfo{color, rank_});\n        auto last = std::stable_partition(vec.begin(), vec.end(), [&](auto x) {  //\n            return x.color == color;\n        });\n        vec.erase(last, vec.end());\n        std::stable_sort(vec.begin(), vec.end(), [](auto& a, auto& b) {  //\n            return a < b;\n        });\n\n        auto new_prefix  = std::to_string(color) + \":\" + std::to_string(n_split_++);\n        auto new_store   = store_->New(new_prefix);\n        int  new_n_ranks = vec.size();\n        int  new_rank    = std::find(vec.begin(), vec.end(), SplitInfo{color, rank_}) - vec.begin();\n        return std::make_shared<GlooCommImpl>(new_store, new_n_ranks, new_rank);\n    }\n\n    void Sync(bool blocking) override\n    {\n        ::gloo::BarrierOptions opts(context_);\n        ::gloo::barrier(opts);\n    }\n\n    void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override\n 
   {\n        // trivially copyable if no ser/des function\n        if (!ser || !des) {\n            return Broadcast(data, count, dtype, root);\n        }\n\n        // broadcast buffer size\n        size_t size;\n        if (root == rank()) {\n            ser(data, 0, count, size, nullptr);\n        }\n        Broadcast(&size, 1, data_type_v<size_t>, root);\n\n        // serialize data on root rank (resize, not reserve: the buffer must hold `size` valid bytes)\n        std::vector<std::byte> bytes(size);\n        if (root == rank()) {\n            ser(data, 0, count, size, bytes.data());\n        }\n\n        // broadcast serialized data\n        Broadcast(bytes.data(), size, data_type_v<uint8_t>, root);\n\n        // deserialize data on all ranks\n        if (root != rank()) {\n            des(data, 0, count, bytes.data(), size);\n        }\n    }\n\n    void AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override\n    {\n        // trivially copyable if no ser/des function\n        if (!ser || !des) {\n            return AllGather(data, count, dtype);\n        }\n\n        // get buffer size on each rank and find max size\n        size_t size;\n        ser(data, count * rank(), count, size, nullptr);\n        std::vector<size_t> sizes(n_ranks());\n        sizes[rank()] = size;\n        AllGather(sizes.data(), 1, data_type_v<size_t>);\n        auto max_size = *std::max_element(sizes.begin(), sizes.end());\n\n        // serialize data on each rank\n        std::vector<std::byte> bytes(max_size * n_ranks());\n        ser(data, count * rank(), count, size, bytes.data() + rank() * max_size);\n\n        // gather serialized data\n        AllGather(bytes.data(), max_size, data_type_v<uint8_t>);\n\n        // deserialize data on each rank\n        for (int i = 0; i < n_ranks(); ++i) {\n            if (i != rank()) {\n                des(data, i * count, count, bytes.data() + i * max_size, sizes[i]);\n            }\n        }\n    }\n\n    void Broadcast(void* data, 
int count, DataType dtype, int root)\n    {\n        ::gloo::BroadcastOptions opts(context_);\n        opts.setRoot(root);\n        opts.setOutput((char*)data, count * byte_size(dtype));\n        ::gloo::broadcast(opts);\n    }\n\n    void AllGather(void* data, int count, DataType dtype)\n    {\n        ::gloo::AllgatherOptions opts(context_);\n        opts.setOutput((char*)data, count * byte_size(dtype) * n_ranks_);\n        ::gloo::allgather(opts);\n    }\n\n    static ReduceFunc getReduceFunc(DataType dtype, RedOp red_op)\n    {\n\n        auto dispatch_op = [&](auto t) -> ReduceFunc {\n            using T = decltype(t);\n            switch (red_op) {\n                case RedOp::kSum:\n                    return ::gloo::sum<T>;\n                case RedOp::kMax:\n                    return ::gloo::max<T>;\n                case RedOp::kMin:\n                    return ::gloo::min<T>;\n                default:\n                    return {};\n            }\n        };\n\n        auto dispatch = [&]() -> ReduceFunc {\n            switch (dtype) {\n                case kInt32:\n                    return dispatch_op(int32_t{});\n                case kInt64:\n                    return dispatch_op(int64_t{});\n                case kUint32:\n                    return dispatch_op(uint32_t{});\n                case kUint64:\n                    return dispatch_op(uint64_t{});\n                default:\n                    return {};\n            }\n        };\n\n        if (auto fn = dispatch()) {\n            return fn;\n        }\n        else {\n            throw std::runtime_error(\"not implemented\");\n            return {};\n        }\n    }\n\n    void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override\n    {\n        ::gloo::AllreduceOptions opts(context_);\n        opts.setReduceFunction(getReduceFunc(dtype, red_op));\n        switch (dtype) {\n            case kInt32:\n                opts.setOutput((int32_t*)data, count);\n              
  break;\n            case kInt64:\n                opts.setOutput((int64_t*)data, count);\n                break;\n            case kUint32:\n                opts.setOutput((uint32_t*)data, count);\n                break;\n            case kUint64:\n                opts.setOutput((uint64_t*)data, count);\n                break;\n            default:\n                throw std::runtime_error(\"not implemented\");\n        }\n        ::gloo::allreduce(opts);\n    }\n\n    // there might be very long intervals between receiving requests.\n    static constexpr std::chrono::milliseconds kTimeOut = std::chrono::milliseconds(1000LL * 3600 * 24 * 365);\n\n    int                                          n_split_{};\n    std::shared_ptr<::gloo::transport::Device>   device_;\n    std::shared_ptr<::gloo::rendezvous::Context> context_;\n    std::shared_ptr<Store>                       store_;\n    int                                          rank_;\n    int                                          n_ranks_;\n};\n\nclass GlooGroupId: public HostGroupId {\n\n    void Initialize() override\n    {\n        info_ = GlobalStoreFactory::Instance().New();\n        TM_LOG_INFO(\"[TM][COMM] GlooGroupId=%s\", info_.c_str());\n    }\n\n    void Export(std::ostream& os) override\n    {\n        os << info_;\n    }\n\n    void Import(std::istream& is) override\n    {\n        std::stringstream ss;\n        ss << is.rdbuf();\n        info_ = ss.str();\n    }\n\n    HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) override\n    {\n        TM_CHECK(info_ != \"\");\n        auto impl = std::make_shared<GlooCommImpl>(GlobalStoreFactory::Instance().Load(info_), n_ranks, rank);\n        return std::static_pointer_cast<HostCommImpl>(impl);\n    }\n\nprivate:\n    std::string                                info_;  // ip,port,prefix\n    std::shared_ptr<::gloo::rendezvous::Store> store_;\n};\n\nstd::unique_ptr<HostGroupId> CreateGlooGroupId()\n{\n    return 
std::make_unique<GlooGroupId>();\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/gloo/hybrid_comm.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/core/check.h\"\n\nnamespace turbomind::comm {\n\nextern std::unique_ptr<HostGroupId> CreateThreadGroupId();\nextern std::unique_ptr<HostGroupId> CreateGlooGroupId();\n\nstruct HybridCommImpl: public HostCommImpl {\n\n    HybridCommImpl(int n_ranks, int rank, int node_rank, HostGroupId* gloo_group_id, HostGroupId* thread_group_id):\n        n_ranks_{n_ranks},  //\n        rank_{rank},\n        node_rank_(node_rank)\n    {\n        gloo_comm_     = gloo_group_id->CreateCommunicator(n_ranks, rank);\n        rank_to_nodes_ = ::turbomind::comm::AllGather(gloo_comm_, node_rank);\n        same_process_  = rank_to_nodes_.front() == rank_to_nodes_.back();\n        if (same_process_) {\n            intra_comm_ = thread_group_id->CreateCommunicator(n_ranks, rank);\n        }\n        else {\n            init_inter_comm();\n            intra_comm_ = thread_group_id->CreateCommunicator(intra_n_ranks_, rank_to_intra_[rank_]);\n        }\n    }\n\n    HybridCommImpl(std::shared_ptr<HostCommImpl> gloo_comm, std::shared_ptr<HostCommImpl> intra_comm, int node_rank):\n        gloo_comm_{std::move(gloo_comm)},\n        intra_comm_{std::move(intra_comm)},\n        rank_{gloo_comm_->rank()},\n        n_ranks_{gloo_comm_->n_ranks()},\n        node_rank_(node_rank)\n    {\n        rank_to_nodes_ = ::turbomind::comm::AllGather(gloo_comm_, node_rank);\n        same_process_  = rank_to_nodes_.front() == rank_to_nodes_.back();\n        if (same_process_) {}\n        else {\n            init_inter_comm();\n        }\n    }\n\n    void init_inter_comm()\n    {\n        int intra_n_ranks = 0;\n        int intra_rank    = -1;\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (rank_to_nodes_[r] == node_rank_) {\n                if (r == rank_) {\n                    intra_rank = intra_n_ranks;\n                }\n                intra_n_ranks++;\n    
        }\n        }\n\n        intra_n_ranks_ = intra_n_ranks;\n        gloo_comm_->AllReduce(&intra_n_ranks_, 1, DataType::kInt, RedOp::kMin);\n        TM_CHECK_EQ(intra_n_ranks_, intra_n_ranks) << \"The number of ranks in each node should be the same.\";\n        TM_CHECK_GT(intra_rank, -1) << \"Invalid intra_rank.\";\n        rank_to_intra_ = ::turbomind::comm::AllGather(gloo_comm_, intra_rank);\n\n        inter_comm_    = gloo_comm_->Split(rank_to_intra_[rank_], 0);\n        rank_to_inter_ = ::turbomind::comm::AllGather(gloo_comm_, inter_comm_->rank());\n    }\n\n    std::shared_ptr<HostCommImpl> Split(int color, int key) override\n    {\n        if (!is_same_process()) {\n            auto new_gloo_comm  = gloo_comm_->Split(color, key);\n            auto new_intra_comm = intra_comm_->Split(color, key);\n            return std::make_shared<HybridCommImpl>(new_gloo_comm, new_intra_comm, node_rank_);\n        }\n        else {\n            return intra_comm_->Split(color, key);\n        }\n    }\n\n    int rank() const override\n    {\n        return rank_;\n    }\n\n    int n_ranks() const override\n    {\n        return n_ranks_;\n    }\n\n    bool is_same_process() const override\n    {\n        return same_process_;\n    }\n\n    void Sync(bool blocking) override\n    {\n        if (!is_same_process() && rank_to_intra_[rank_] == 0) {\n            inter_comm_->Sync(blocking);\n        }\n        intra_comm_->Sync(blocking);\n    }\n\n    void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override\n    {\n        // no serialization needed for trivially copyable data or when all ranks share one process\n        if (!ser || !des || is_same_process()) {\n            return Broadcast(data, count, dtype, root, copy);\n        }\n\n        if (rank_to_intra_[root] == rank_to_intra_[rank_]) {  // same ith rank in node\n            inter_comm_->Broadcast(data, count, dtype, rank_to_inter_[root], copy, ser, des);\n        }\n        intra_comm_->Broadcast(data, count, dtype, rank_to_intra_[root], copy);\n    }\n\n    void Broadcast(void* 
data, int count, DataType dtype, int root, copy_fn copy)\n    {\n        if (is_same_process()) {\n            return intra_comm_->Broadcast(data, count, dtype, root, copy);\n        }\n\n        if (rank_to_intra_[root] == rank_to_intra_[rank_]) {  // same ith rank in node\n            inter_comm_->Broadcast(data, count, dtype, rank_to_inter_[root], copy);\n        }\n        intra_comm_->Broadcast(data, count, dtype, rank_to_intra_[root], copy);\n    }\n\n    void AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override\n    {\n        if (!ser || !des) {\n            return AllGather(data, count, dtype, copy);\n        }\n\n        return gloo_comm_->AllGather(data, count, dtype, copy, ser, des);\n    }\n\n    void AllGather(void* data, int count, DataType dtype, copy_fn copy)\n    {\n        if (is_same_process()) {\n            return intra_comm_->AllGather(data, count, dtype, copy);\n        }\n\n        // TODO: support allgatherv in gloo comm (each node may have a different number of ranks)\n        return gloo_comm_->AllGather(data, count, dtype, copy);\n    }\n\n    void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override\n    {\n        if (is_same_process()) {\n            return intra_comm_->AllReduce(data, count, dtype, red_op);\n        }\n\n        intra_comm_->AllReduce(data, count, dtype, red_op);\n        if (rank_to_intra_[rank_] == 0) {\n            inter_comm_->AllReduce(data, count, dtype, red_op);\n        }\n        intra_comm_->Broadcast(data, byte_size(dtype) * count, data_type_v<uint8_t>, 0, detail::copy_fn<uint8_t>);\n    }\n\n    HostComm gloo_comm_{};   // primitive comm, used for initializing inter_comm and intra_comm\n    HostComm inter_comm_{};  // inter-node comm\n    HostComm intra_comm_{};  // intra-node comm\n\n    int rank_;       // group rank\n    int n_ranks_;    // group size\n    int node_rank_;  // node rank\n    int intra_n_ranks_;\n\n    std::vector<int> 
rank_to_nodes_{};  // map group rank to node rank (not global)\n    std::vector<int> rank_to_intra_{};  // map group rank to intra-node rank\n    std::vector<int> rank_to_inter_{};  // map group rank to inter-node rank\n\n    bool same_process_;\n};\n\nclass HybridGroupId: public HostGroupId {\npublic:\n    HybridGroupId()\n    {\n        thread_group_id_ = CreateThreadGroupId();\n        gloo_group_id_   = CreateGlooGroupId();\n    }\n\n    void Initialize() override\n    {\n        thread_group_id_->Initialize();\n        gloo_group_id_->Initialize();\n    }\n\n    void Export(std::ostream& os) override\n    {\n        thread_group_id_->Export(os);\n        gloo_group_id_->Export(os);\n    }\n\n    void Import(std::istream& is) override\n    {\n        thread_group_id_->Import(is);\n        gloo_group_id_->Import(is);\n    }\n\n    HostComm CreateCommunicator(int n_ranks, int rank, int node_rank)\n    {\n        auto impl = std::make_shared<HybridCommImpl>(n_ranks,  //\n                                                     rank,\n                                                     node_rank,\n                                                     gloo_group_id_.get(),\n                                                     thread_group_id_.get());\n        return std::static_pointer_cast<HostCommImpl>(impl);\n    }\n\n    std::unique_ptr<HostGroupId> thread_group_id_;\n    std::unique_ptr<HostGroupId> gloo_group_id_;\n};\n\nstd::unique_ptr<HostGroupId> CreateHybridGroupId()\n{\n    return std::make_unique<HybridGroupId>();\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/gloo/tcp_store.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <chrono>\n#include <netdb.h>\n#include <thread>\n#include <unistd.h>\n\n#include <gloo/transport/tcp/device.h>\n#include <gloo/transport/tcp/socket.h>\n\n#include \"src/turbomind/comm/gloo/tcp_store.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind::comm {\n\nnamespace {\n\n// copy from pytorch https://github.com/pytorch/pytorch/blob/v2.8.0-rc4/torch/csrc/distributed/c10d/TCPStoreBackend.hpp\n\nstatic const uint32_t validationMagicNumber = 0x3C85F7CE;\n\nenum class CheckResponseType : uint8_t\n{\n    READY,\n    NOT_READY\n};\n\nenum class QueryType : uint8_t\n{\n    VALIDATE,\n    SET,\n    COMPARE_SET,\n    GET,\n    ADD,\n    CHECK,\n    WAIT,\n    GETNUMKEYS,\n    DELETE_KEY,\n    APPEND,\n    MULTI_GET,\n    MULTI_SET,\n    CANCEL_WAIT,\n    PING,\n    QUEUE_PUSH,\n    QUEUE_POP,\n    QUEUE_LEN,\n};\n\n}  // namespace\n\nstruct Buffer {\n    std::vector<char> buffer;\n\n    template<typename T, typename = std::enable_if_t<std::is_trivially_copyable_v<T>>>\n    void append(T val)\n    {\n        char* ptr = (char*)&val;\n        buffer.insert(buffer.end(), ptr, ptr + sizeof(T));\n    }\n\n    void append(const std::vector<char>& vec)\n    {\n        append((uint64_t)vec.size());\n        buffer.insert(buffer.end(), vec.begin(), vec.end());\n    }\n\n    void append(const std::string& str)\n    {\n        append((uint64_t)str.size());\n        buffer.insert(buffer.end(), str.begin(), str.end());\n    }\n\n    const char* data() const\n    {\n        return buffer.data();\n    }\n\n    size_t count() const\n    {\n        return buffer.size();\n    }\n};\n\nvoid validate(std::shared_ptr<::gloo::transport::tcp::Socket>& socket)\n{\n    Buffer buffer;\n    buffer.append(QueryType::VALIDATE);\n    buffer.append(validationMagicNumber);\n    socket->write(buffer.data(), buffer.count());\n}\n\nvoid ping(std::shared_ptr<::gloo::transport::tcp::Socket>& socket)\n{\n    Buffer 
buffer;\n    buffer.append(QueryType::PING);\n    uint32_t nonce         = getpid();\n    uint32_t returnedNonce = -1;\n    buffer.append(nonce);\n    socket->write(buffer.data(), buffer.count());\n    int r = socket->read(&returnedNonce, sizeof(returnedNonce));\n    if (nonce != returnedNonce) {\n        std::stringstream ss;\n        ss << \"Ping failed, nonce=\" << nonce << \", returnedNonce=\" << returnedNonce << \", socket read=\" << r;\n        throw std::runtime_error(ss.str());\n    }\n}\n\nTCPStore::TCPStore(const std::string& host, int port)\n{\n    auto retry = 0;\n    do {\n        try {\n            ::addrinfo hints{}, *res{};\n            hints.ai_flags    = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;\n            hints.ai_family   = AF_UNSPEC;\n            hints.ai_socktype = SOCK_STREAM;\n\n            int status = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res);\n\n            std::shared_ptr<addrinfo> holder(res, [](addrinfo* p) {\n                if (p != nullptr) {\n                    freeaddrinfo(p);\n                }\n            });\n\n            if (status != 0) {\n                throw std::runtime_error(\"getaddrinfo failed: \" + std::string(gai_strerror(status)));\n            }\n\n            for (::addrinfo* addr = res; addr != nullptr; addr = addr->ai_next) {\n                int fd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);\n                if (fd == -1) {\n                    continue;\n                }\n                auto socket = std::make_shared<::gloo::transport::tcp::Socket>(fd);\n                socket->connect(addr->ai_addr, addr->ai_addrlen);\n                socket->noDelay(true);\n                socket->recvTimeout(std::chrono::milliseconds(5000));\n                socket->sendTimeout(std::chrono::milliseconds(5000));\n                validate(socket);  // validate the connection\n                ping(socket);      // check send/recv\n                socket_ = 
std::move(socket);\n                break;\n            }\n\n            if (socket_ == nullptr) {\n                throw std::runtime_error(\"unable to connect to \" + host + \":\" + std::to_string(port));\n            }\n        }\n        catch (const std::exception& e) {\n            TM_LOG_WARNING(\"[TM][COMM] Failed to connect to store after %d retries: %s\", retry, e.what());\n            std::this_thread::sleep_for(std::chrono::seconds(1));\n            retry += 1;\n        }\n    } while (socket_ == nullptr);\n}\n\nvoid TCPStore::set(const std::string& key, const std::vector<char>& data)\n{\n    std::lock_guard<std::mutex> lock(mutex_);\n    Buffer                      buffer;\n    buffer.append(QueryType::SET);\n    buffer.append(key);\n    buffer.append(data);\n    socket_->write(buffer.data(), buffer.count());\n}\n\nstd::vector<char> TCPStore::get(const std::string& key)\n{\n    wait({key});\n    std::lock_guard<std::mutex> lock(mutex_);\n    Buffer                      buffer;\n    buffer.append(QueryType::GET);\n    buffer.append(key);\n    socket_->write(buffer.data(), buffer.count());\n\n    uint64_t vec_size;\n    socket_->read(&vec_size, sizeof(vec_size));\n    std::vector<char> value(vec_size);\n    socket_->read(value.data(), value.size());\n    return value;\n}\n\nbool TCPStore::check(const std::vector<std::string>& keys)\n{\n    std::lock_guard<std::mutex> lock(mutex_);\n    Buffer                      buffer;\n    buffer.append(QueryType::CHECK);\n    buffer.append((uint64_t)keys.size());\n    for (const auto& key : keys) {\n        buffer.append(key);\n    }\n    socket_->write(buffer.data(), buffer.count());\n\n    CheckResponseType response;\n    socket_->read(&response, sizeof(response));\n    return response == CheckResponseType::READY;\n}\n\nvoid TCPStore::wait(const std::vector<std::string>& keys, const std::chrono::milliseconds& timeout)\n{\n    const auto start = std::chrono::steady_clock::now();\n    while (!check(keys)) {\n        
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start);\n        if (elapsed > timeout) {\n            std::stringstream ss;\n            ss << \"Wait timeout for key(s): [\";\n            for (const auto& key : keys) {\n                ss << key << \" \";\n            }\n            ss << \"]\";\n            TM_LOG_ERROR(\"[TM][COMM] %s, elapsed %lld s\", ss.str().c_str(), static_cast<long long>(elapsed.count()));\n            throw std::runtime_error(ss.str());\n        }\n        std::this_thread::sleep_for(std::chrono::milliseconds(1000));\n    }\n}\n\nTCPStore::~TCPStore() = default;\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/gloo/tcp_store.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <chrono>\n#include <memory>\n#include <mutex>\n#include <string>\n#include <vector>\n\n#include <gloo/rendezvous/store.h>\n#include <gloo/transport/tcp/socket.h>\n\nnamespace turbomind::comm {\n\nclass TCPStore: public gloo::rendezvous::Store {\npublic:\n    explicit TCPStore(const std::string& host, int port);\n\n    ~TCPStore();\n\n    void set(const std::string& key, const std::vector<char>& data) override;\n\n    std::vector<char> get(const std::string& key) override;\n\n    bool check(const std::vector<std::string>& keys);\n\n    void wait(const std::vector<std::string>& keys) override\n    {\n        wait(keys, std::chrono::seconds(30));\n    }\n\n    void wait(const std::vector<std::string>& keys, const std::chrono::milliseconds& timeout) override;\n\nprivate:\n    std::shared_ptr<::gloo::transport::tcp::Socket> socket_;\n    std::mutex                                      mutex_;\n};\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/gloo/test_ipc_comm.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <chrono>\n#include <cstdlib>\n#include <fstream>\n#include <iostream>\n#include <thread>\n#include <unistd.h>\n\n#include \"src/turbomind/comm/host_comm.h\"\n\nusing namespace turbomind::comm;\n\n#define TEST_TRIVIALLY_COPYABLE 1\n\n// #define SKIP_SERIALIZE 0 // useless now\n\n// const std::string backend = \"\";\nconst std::string backend = \"hybrid\";\n// const std::string backend = \"gloo\";\n\nstruct Store {\n    std::string hostname_;\n    std::string port_;\n    int         nnodes_;\n    int         node_rank_;\n    std::string py_script_;\n    std::string py_file_path_ = \"/tmp/start_tcp_store.py\";\n\n    std::thread thread_;\n\n    Store(const std::string& hostname, const std::string& port, int nnodes, int node_rank):\n        hostname_(hostname), port_(port), nnodes_(nnodes), node_rank_(node_rank)\n    {\n\n        int pid = getpid();\n\n        // clang-format off\n    py_script_ =\n\"import psutil\\n\"\n\"import os\\n\"\n\"import time\\n\"\n\"from torch.distributed import TCPStore\\n\"\n\"store = TCPStore(host_name='\" + hostname_ + \"',\\n\"\n\"                 port=\" + port_ + \",\\n\"\n\"                 world_size=\" + std::to_string(nnodes_) + \",\\n\"\n\"                 is_master=\" + (node_rank_ == 0 ? 
\"True\" : \"False\") + \")\\n\"\n\"while True:\\n\"\n\"    time.sleep(1)\\n\"\n\"    if not psutil.pid_exists(\" + std::to_string(pid) + \"):\\n\"\n\"        break\\n\"\n\"    if not os.path.exists('/tmp/start_tcp_store.py'):\\n\"\n\"        break\\n\";\n\n        // clang-format on\n        std::ofstream py_file(py_file_path_);\n        py_file << py_script_;\n        py_file.close();\n\n        std::string env_addr = \"LMDEPLOY_DIST_INIT_ADDR=\" + hostname_;\n        std::string env_port = \"LMDEPLOY_DIST_INIT_PORT=\" + port_;\n        setenv(\"LMDEPLOY_DIST_INIT_ADDR\", hostname_.c_str(), 1);\n        setenv(\"LMDEPLOY_DIST_INIT_PORT\", port_.c_str(), 1);\n\n        start();\n        // wait a moment for the store to start.\n        std::this_thread::sleep_for(std::chrono::seconds(3));\n    }\n\n    ~Store()\n    {\n        stop();\n    }\n\n    void start()\n    {\n        const std::string cmd = (\"python \" + py_file_path_);\n        thread_               = std::thread([](const std::string& cmd) { int result = system(cmd.c_str()); }, cmd);\n    }\n\n    void stop()\n    {\n        int r = system(\"rm /tmp/start_tcp_store.py\");\n        thread_.join();\n    }\n};\n\nstruct TestGlooComm {\n    std::string hostname_;\n    std::string port_;\n    int         nnodes_;\n    int         node_rank_;\n    int         n_ranks_per_node_;\n\n    std::vector<HostComm> h_comm_;\n\n    TestGlooComm(const std::string& host, const std::string& port, int nnodes, int node_rank, int n_ranks_per_node):\n        hostname_(host), port_(port), nnodes_(nnodes), node_rank_(node_rank), n_ranks_per_node_(n_ranks_per_node)\n    {\n        h_comm_.resize(n_ranks_per_node_);\n    }\n\n    void init()\n    {\n        std::unique_ptr<HostGroupId> group_id = CreateHostGroupId(backend);\n        std::string                  group_id_data;\n        if (1) {  // master\n            group_id->Initialize();\n            std::stringstream ss;\n            group_id->Export(ss);\n            
group_id_data = ss.str();\n        }\n\n        auto init = [&](int rank) {\n            // initialize host communicators\n            std::stringstream            ss(group_id_data);\n            std::unique_ptr<HostGroupId> host_id = CreateHostGroupId(backend);\n            host_id->Import(ss);\n            h_comm_[rank % n_ranks_per_node_] =\n                host_id->CreateCommunicator(n_ranks_per_node_ * nnodes_, rank, node_rank_);\n        };\n\n        std::vector<std::thread> threads;\n        for (int i = 0; i < n_ranks_per_node_; ++i) {\n            threads.emplace_back(init, n_ranks_per_node_ * node_rank_ + i);\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    void test_broadcast()\n    {\n        const int count = 10;\n\n        auto fun = [&](HostComm& comm, int rank) {\n            for (int r = 0; r < comm->n_ranks(); ++r) {\n\n#if TEST_TRIVIALLY_COPYABLE\n                std::vector<int> data(count);\n#else\n                std::shared_ptr<std::vector<int>> data_ptr = std::make_shared<std::vector<int>>(count);\n                int*                              data     = data_ptr->data();\n#endif\n\n                for (int i = 0; i < count; ++i) {\n                    data[i] = i + rank * count;  // i + rank * count\n                }\n\n#if TEST_TRIVIALLY_COPYABLE\n                Broadcast(comm, data.data(), count, r);\n#else\n                Broadcast(comm, data_ptr, r);\n                data = data_ptr->data();\n#endif\n                // check result\n                for (int i = 0; i < count; ++i) {\n                    int expected = i + r * count;\n                    if (data[i] != expected) {\n                        printf(\"Rank %d: Broadcast failed at root %d, index %d, got %d, expected %d\\n\",\n                               rank,\n                               r,\n                               i,\n                               data[i],\n                               expected);\n           
         }\n                }\n            }\n        };\n\n        std::vector<std::thread> threads;\n        for (size_t i = 0; i < n_ranks_per_node_; ++i) {\n            threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i);\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    void test_allgather()\n    {\n        const int count = 40;\n\n        auto fun = [&](HostComm& comm, int rank) {\n\n#if TEST_TRIVIALLY_COPYABLE\n            std::vector<int> data(count * comm->n_ranks());\n            for (int i = 0; i < count; ++i) {\n                data[i + count * comm->rank()] = i + rank * count;  // i + rank * count\n            }\n#else\n            std::vector<std::shared_ptr<std::vector<int>>> data_ptrs(comm->n_ranks());\n            data_ptrs[comm->rank()] = std::make_shared<std::vector<int>>(count);\n            int* data = data_ptrs[comm->rank()]->data();\n            for (int i = 0; i < count; ++i) {\n                data[i] = i + rank * count;  // i + rank * count\n            }\n#endif\n\n#if TEST_TRIVIALLY_COPYABLE\n            AllGather(comm, data.data(), count);\n            for (int r = 0; r < comm->n_ranks(); ++r) {\n                for (int j = 0; j < count; ++j) {\n                    int expected = j + r * count;\n                    if (data[j + r * count] != expected) {\n                        printf(\"Rank %d: AllGather failed, index %d, got %d, expected %d\\n\",\n                               rank,\n                               j + r * count,\n                               data[j + r * count],\n                               expected);\n                    }\n                }\n            }\n#else\n            AllGather(comm, data_ptrs.data(), 1);\n            for (int r = 0; r < comm->n_ranks(); ++r) {\n                data = data_ptrs[r]->data();\n                for (int j = 0; j < count; ++j) {\n                    int expected = j + r * count;\n                    
if (data[j] != expected) {\n                        printf(\"Rank %d: AllGather failed, index %d, got %d, expected %d\\n\",\n                               rank,\n                               j + r * count,\n                               data[j],\n                               expected);\n                    }\n                }\n            }\n#endif\n        };\n\n        std::vector<std::thread> threads;\n        for (size_t i = 0; i < n_ranks_per_node_; ++i) {\n            threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i);\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    void test_allreduce()\n    {\n        const int count = 10;\n\n        auto fun = [&](HostComm& comm, int rank) {\n            std::vector<int> data(count);\n            for (int i = 0; i < count; ++i) {\n                data[i] = i + rank * count;  // i + rank * count\n            }\n\n            AllReduce(comm, data.data(), count, RedOp::kSum);\n            for (int j = 0; j < count; ++j) {\n                int expected{};\n                for (int r = 0; r < comm->n_ranks(); ++r) {\n                    expected += j + r * count;\n                }\n                if (data[j] != expected) {\n                    printf(\"Rank %d: AllReduce failed, index %d, got %d, expected %d\\n\", rank, j, data[j], expected);\n                }\n            }\n        };\n\n        std::vector<std::thread> threads;\n        for (size_t i = 0; i < n_ranks_per_node_; ++i) {\n            threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i);\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    void test_perf()\n    {\n        const long  kMinDurationNs   = 2e9;  // 2 second\n        const long  kWarmupIter      = 5;    // warmup iter\n        const float kItersMultiplier = 1.2;\n\n        std::vector<int> count = {1024, 262144, 524288, 1048576, 2097152, 
4194304, 67108864};\n        //                              1M,     2M,     4M,      8M,      16M,     256M\n\n        if (node_rank_ == 0) {\n            printf(\"%10s %10s %10s %10s %11s %18s %10s\\n\",\n                   \"size(MB)\",\n                   \"elements\",\n                   \"avg(us)\",\n                   \"p50(us)\",\n                   \"p99(us)\",\n                   \"bandwidth(GB/s)\",\n                   \"iterations\");\n        }\n\n        auto fun = [&](HostComm& comm, int rank, int n) {\n\n#if TEST_TRIVIALLY_COPYABLE\n            std::vector<int> data(n);\n#else\n            std::shared_ptr<std::vector<int>> sptr;\n            if (rank == 0) {\n                sptr = std::make_shared<std::vector<int>>(n);\n            }\n#endif\n\n            std::vector<int64_t> times;\n\n            auto job = [&](int n_iters) {\n                times.clear();\n                int64_t total = 0;\n                int64_t ns    = 0;\n                comm->Sync();\n                for (int i = 0; i < n_iters; ++i) {\n                    auto start = std::chrono::high_resolution_clock::now();\n#if TEST_TRIVIALLY_COPYABLE\n                    Broadcast(comm, data.data(), n, 0);\n#else\n                    Broadcast(comm, sptr, 0);\n#endif\n                    auto    now = std::chrono::high_resolution_clock::now();\n                    int64_t ns  = std::chrono::duration_cast<std::chrono::nanoseconds>(now - start).count();\n                    total += ns;\n                    times.push_back(ns);\n                }\n                Broadcast(comm, total, 0);\n                return total;\n            };\n\n            auto warmup_dur = job(kWarmupIter) / kWarmupIter;\n            auto iter       = (int)std::max(kMinDurationNs / warmup_dur * 0.5f, 100.f);\n\n            while (1) {\n                auto dur = job(iter);\n                std::sort(times.begin(), times.end());\n\n                if (rank == 0) {\n                    size_t bytes = n * 
sizeof(int);\n                    int    p50   = std::min(times.size() / 2, times.size() - 1);\n                    int    p99   = std::min((int)(times.size() * 0.99), (int)times.size() - 1);\n                    printf(\"%10.5f %10d %10lld %10lld %10lld %18.3f %10lld\\n\",\n                           bytes / 1024.f / 1024.f,\n                           n,\n                           static_cast<long long>(dur / 1e3f / iter),\n                           static_cast<long long>(times[p50] / 1e3f),\n                           static_cast<long long>(times[p99] / 1e3f),\n                           (bytes * iter) / (dur / 1e9f) / (1024 * 1024 * 1024),\n                           static_cast<long long>(iter));\n                }\n\n                if (dur >= kMinDurationNs) {\n                    break;\n                }\n                iter = std::max(iter * kItersMultiplier, iter + 1.f);\n            }\n        };\n\n        for (auto n : count) {\n            std::vector<std::thread> threads;\n            for (size_t i = 0; i < n_ranks_per_node_; ++i) {\n                threads.emplace_back(fun, std::ref(h_comm_[i]), n_ranks_per_node_ * node_rank_ + i, n);\n            }\n            for (auto& t : threads) {\n                t.join();\n            }\n        }\n    }\n};\n\n// ./test_gloo_comm <nnodes> <node_rank> <n_ranks_per_node> <init_addr>\nint main(int argc, char* argv[])\n{\n    if (argc != 5) {\n        std::cerr << \"Usage: \" << argv[0] << \" <nnodes> <node_rank> <n_ranks_per_node> <init_addr>\" << std::endl;\n        return -1;\n    }\n\n    int nnodes           = std::atoi(argv[1]);\n    int node_rank        = std::atoi(argv[2]);\n    int n_ranks_per_node = std::atoi(argv[3]);\n\n    const std::string init_addr = argv[4];\n    auto              pos       = init_addr.find(\":\");\n    const std::string host      = init_addr.substr(0, pos);\n    const std::string port      = init_addr.substr(pos + 1);\n\n    Store store(host, port, nnodes, node_rank);\n\n  
  {\n        TestGlooComm test(host, port, nnodes, node_rank, n_ranks_per_node);\n        test.init();\n\n        test.test_broadcast();\n        test.test_allgather();\n        test.test_allreduce();\n\n        // test.test_perf();\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/comm/host_comm.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/comm/host_comm.h\"\n\nnamespace turbomind::comm {\n\nHostCommImpl::~HostCommImpl() = default;\n\nstd::unique_ptr<HostGroupId> CreateThreadGroupId();\n\nstd::unique_ptr<HostGroupId> CreateGlooGroupId();\n\nstd::unique_ptr<HostGroupId> CreateHybridGroupId();\n\nstd::unique_ptr<HostGroupId> CreateHostGroupId(const std::string& backend)\n{\n#ifdef BUILD_MULTI_GPU\n    if (backend == \"hybrid\") {\n        return CreateHybridGroupId();\n    }\n    if (backend == \"gloo\") {\n        return CreateGlooGroupId();\n    }\n#endif\n\n    return CreateThreadGroupId();\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/host_comm.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <algorithm>\n#include <cstring>\n#include <memory>\n#include <stdexcept>\n#include <tuple>\n#include <type_traits>\n#include <vector>\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/serdes.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind::comm {\n\nenum class RedOp\n{\n    kSum,\n    kMin,\n    kMax,\n};\n\ntypedef void (*copy_fn)(void* src, int n, void* dst, int offset);\n\ntypedef void (*reduce_fn)(void* src, int n, void* dst, int offset);\n\ntypedef void (*ser_fn)(void* data, int offset, int n, size_t& size, void* out);\n\ntypedef void (*des_fn)(void* data, int offset, int n, void* in, size_t size);\n\nclass HostCommImpl {\npublic:\n    virtual ~HostCommImpl();\n\n    virtual int rank() const = 0;\n\n    virtual int n_ranks() const = 0;\n\n    virtual bool is_same_process() const = 0;\n\n    virtual std::shared_ptr<HostCommImpl> Split(int color, int key) = 0;\n\n    virtual void Sync(bool blocking = false) = 0;\n\n    virtual void Broadcast(void*    data,  //\n                           int      count,\n                           DataType dtype,\n                           int      root,\n                           copy_fn  copy,\n                           ser_fn   ser = nullptr,\n                           des_fn   des = nullptr) = 0;\n\n    virtual void AllGather(void*    data,  //\n                           int      count,\n                           DataType dtype,\n                           copy_fn  copy,\n                           ser_fn   ser = nullptr,\n                           des_fn   des = nullptr) = 0;\n\n    virtual void AllReduce(void* data, int count, DataType dtype, RedOp red_op) = 0;\n};\n\nclass HostComm {\npublic:\n    HostComm() = default;\n\n    /* implicit */ HostComm(std::shared_ptr<HostCommImpl> impl): impl_{std::move(impl)} {}\n\n    HostCommImpl* operator->() const noexcept\n    {\n        
return impl_.get();\n    }\n\n    operator HostCommImpl*() const noexcept\n    {\n        return impl_.get();\n    }\n\nprivate:\n    std::shared_ptr<HostCommImpl> impl_;\n};\n\nnamespace detail {\ntemplate<class T>\nvoid copy_fn(void* src, int n, void* dst, int offset)\n{\n    std::copy_n((T*)src + offset, n, (T*)dst + offset);\n}\n\ntemplate<class T>\nvoid ser_fn(void* data, int offset, int n, size_t& size, void* out)\n{\n    if (out == nullptr) {\n        size = 0;\n        core::BinarySizeArchive sa;\n        for (int i = 0; i < n; ++i) {\n            sa&((T*)data)[offset + i];\n        }\n        size = sa.size();\n    }\n    else {\n        core::BinaryOutputArchive oa(core::ArrayWrapper((std::byte*)out, size));\n        for (int i = 0; i < n; ++i) {\n            oa&((T*)data)[offset + i];\n        }\n    }\n}\n\ntemplate<class T>\nvoid des_fn(void* data, int offset, int n, void* in, size_t size)\n{\n    core::BinaryInputArchive ia(core::ArrayWrapper((std::byte*)in, size));\n    for (int i = 0; i < n; ++i) {\n        ia&((T*)data)[offset + i];\n    }\n}\n\n}  // namespace detail\n\n//////////////////////////////////////////////////////////////////////////////////\n// Typed array interface\ntemplate<class T>\nvoid Broadcast(HostCommImpl* comm, T* data, int n, int root)\n{\n    if constexpr (std::is_trivially_copyable_v<T>) {\n        comm->Broadcast(data, sizeof(T) * n, data_type_v<uint8_t>, root, detail::copy_fn<uint8_t>);\n    }\n    else {\n        if (comm->is_same_process()) {\n            /// TODO: Constness should be considered\n            comm->Broadcast(data, n, kNull, root, detail::copy_fn<T>);\n        }\n        else {\n            comm->Broadcast(data, n, kNull, root, detail::copy_fn<T>, detail::ser_fn<T>, detail::des_fn<T>);\n        }\n    }\n}\n\ntemplate<class T>\nvoid AllGather(HostCommImpl* comm, T* data, int n)\n{\n    if constexpr (std::is_trivially_copyable_v<T>) {\n        comm->AllGather(data, sizeof(T) * n, data_type_v<uint8_t>, 
detail::copy_fn<uint8_t>);\n    }\n    else {\n        if (comm->is_same_process()) {\n            /// TODO: Constness should be considered\n            comm->AllGather(data, n, kNull, detail::copy_fn<T>);\n        }\n        else {\n            comm->AllGather(data, n, kNull, detail::copy_fn<T>, detail::ser_fn<T>, detail::des_fn<T>);\n        }\n    }\n}\n\ntemplate<class T>\nvoid AllReduce(HostCommImpl* comm, T* data, int n, RedOp red_op)\n{\n    comm->AllReduce(data, n, data_type_v<T>, red_op);\n}\n\n//////////////////////////////////////////////////////////////////////////////////\n// Typed value interface\ntemplate<class T>\nvoid Broadcast(HostCommImpl* comm, T& value, int root)\n{\n    Broadcast(comm, &value, 1, root);\n}\n\ntemplate<class T>\nstd::vector<T> AllGather(HostCommImpl* comm, const T& value)\n{\n    std::vector<T> ret(comm->n_ranks());\n    ret.at(comm->rank()) = value;\n    AllGather(comm, ret.data(), 1);\n    return ret;\n}\n\ntemplate<class T>\nT AllReduce(HostCommImpl* comm, const T& value, RedOp red_op)\n{\n    T tmp = value;\n    AllReduce(comm, &tmp, 1, red_op);\n    return tmp;\n}\n\nclass HostGroupId {\npublic:\n    virtual ~HostGroupId() = default;\n\n    virtual void Initialize()             = 0;\n    virtual void Export(std::ostream& os) = 0;\n    virtual void Import(std::istream& is) = 0;\n\n    virtual HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) = 0;\n};\n\nstd::unique_ptr<HostGroupId> CreateHostGroupId(const std::string& backend);\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/nccl/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\ncmake_minimum_required(VERSION 3.11)\n\nadd_library(nccl_comm STATIC nccl.cu)\ntarget_link_libraries(nccl_comm PRIVATE rms_norm core ${NCCL_LIBRARIES} logger)\ntarget_include_directories(nccl_comm PRIVATE ${NCCL_INCLUDE_DIRS})\n\nset_property(TARGET nccl_comm PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET nccl_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n"
  },
  {
    "path": "src/turbomind/comm/nccl/nccl.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cstdint>\n#include <memory>\n#include <numeric>\n#include <type_traits>\n#include <unordered_map>\n\n#include <dlfcn.h>\n\n#include <nccl.h>\n\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n#include \"src/turbomind/utils/string_utils.h\"\n\n#include \"src/turbomind/kernels/norm/rms_norm.h\"\n\n#define NCCLCHECK(e)                                                                                                   \\\n    if (auto ec = e; ec != ncclSuccess) {                                                                              \\\n        auto msg = fmtstr(\"NCCL error %s:%d '%s'\", __FILE__, __LINE__, ncclGetErrorString(ec));                        \\\n        throw std::runtime_error(msg.c_str());                                                                         \\\n    }\n\n#if NCCL_VERSION_CODE < NCCL_VERSION(2, 27, 0)\n/* Window Registration flags */\n#define NCCL_WIN_DEFAULT 0x00\n#define NCCL_WIN_COLL_SYMMETRIC 0x01\n#endif\n\nnamespace turbomind::comm {\n\nstatic inline ncclDataType_t to_nccl_dtype(DataType type)\n{\n    switch (type) {\n        case kFloat32:\n            return ncclFloat;\n        case kFloat16:\n            return ncclHalf;\n        case kBfloat16:\n            return ncclBfloat16;\n        case kUint8:\n            return ncclUint8;\n        default:\n            throw std::runtime_error(\"not supported\");\n    }\n}\n\nstruct NcclApis {\n    ncclResult_t (*ncclMemAlloc)(void** ptr, size_t size);\n    ncclResult_t (*ncclMemFree)(void* ptr);\n    ncclResult_t (*ncclCommRegister)(const ncclComm_t comm, void* buff, size_t size, void** handle);\n    ncclResult_t (*ncclCommDeregister)(const ncclComm_t comm, void* handle);\n    ncclResult_t (*ncclCommWindowRegister)(ncclComm_t comm, 
void* buff, size_t size, void** win, int winFlags);\n    ncclResult_t (*ncclCommWindowDeregister)(ncclComm_t comm, void* win);\n    // `ncclConfig_t` varies between versions, should be fine as long as we are passing nullptr to it\n    ncclResult_t (*ncclCommSplit)(ncclComm_t comm, int color, int key, ncclComm_t* newcomm, void* config);\n};\n\nstatic NcclApis& nccl_apis()\n{\n    static auto value = [] {\n        int version{};\n        ncclGetVersion(&version);\n        auto     handle = dlopen(\"libnccl.so.2\", RTLD_LAZY);\n        NcclApis apis{};\n        if (!handle) {\n            return apis;\n        }\n        auto load_symbol = [&](auto& dst, auto name) {\n            using T = std::remove_reference_t<decltype(dst)>;\n            dst     = reinterpret_cast<T>(dlsym(handle, name));\n        };\n        if (version >= NCCL_VERSION(2, 27, 0)) {\n            if (version < NCCL_VERSION(2, 28, 0)) {\n                TM_LOG_WARNING(\n                    \"[NCCL] Window registration may cause memory leaks in NCCL 2.27, use NCCL 2.28+ or disable the feature by setting NCCL_WIN_ENABLE=0.\");\n            }\n            load_symbol(apis.ncclCommWindowRegister, \"ncclCommWindowRegister\");\n            load_symbol(apis.ncclCommWindowDeregister, \"ncclCommWindowDeregister\");\n        }\n        else {\n            TM_LOG_WARNING(\n                \"[NCCL] Window registration is not supported by NCCL %d, use NCCL 2.28+ for better performance.\",\n                version);\n        }\n        if (version >= NCCL_VERSION(2, 19, 0)) {\n            load_symbol(apis.ncclMemAlloc, \"ncclMemAlloc\");\n            load_symbol(apis.ncclMemFree, \"ncclMemFree\");\n            load_symbol(apis.ncclCommRegister, \"ncclCommRegister\");\n            load_symbol(apis.ncclCommDeregister, \"ncclCommDeregister\");\n        }\n        if (version >= NCCL_VERSION(2, 18, 0)) {\n            load_symbol(apis.ncclCommSplit, \"ncclCommSplit\");\n        }\n        else {\n            
TM_LOG_WARNING(\"[NCCL] Splitting communicators is not supported by NCCL %d, use NCCL 2.18+ if needed.\",\n                           version);\n        }\n        return apis;\n    }();\n    return value;\n}\n\nclass NcclCommImpl: public DeviceCommImpl {\npublic:\n    NcclCommImpl(ncclComm_t comm, int n_ranks, int rank, HostComm h_comm):\n        h_comm_{h_comm}, global_n_ranks_{n_ranks}, global_rank_{rank}, groups_{comm}\n    {\n        handles_.emplace_back();\n    }\n\n    ~NcclCommImpl()\n    {\n        for (const auto& [ptr, _] : handles_.at(0)) {\n            TM_LOG_WARNING(\"[NCCL][%d] Buffer %p is not deregistered\", global_rank_, ptr);\n        }\n\n        for (const auto& [ptr, size] : buffers_) {\n            TM_LOG_WARNING(\"[NCCL][%d] Allocation (%p, %lu) is not freed\", global_rank_, ptr, size);\n        }\n\n        for (auto& c : groups_) {\n            if (auto ec = ncclCommDestroy(c); ec != ncclSuccess) {\n                TM_LOG_ERROR(\"[NCCL][%d] Failed to destroy communicator: %s\", global_rank_, ncclGetErrorString(ec));\n            }\n        }\n    }\n\n    int rank(int group) const override\n    {\n        int rank{};\n        NCCLCHECK(ncclCommUserRank(groups_.at(group), &rank));\n        return rank;\n    }\n\n    int n_ranks(int group) const override\n    {\n        int n_ranks{};\n        NCCLCHECK(ncclCommCount(groups_.at(group), &n_ranks));\n        return n_ranks;\n    }\n\n    void* Allocate(size_t size) override\n    {\n        void* ptr{};\n        if (auto alloc_fn = nccl_apis().ncclMemAlloc) {\n            NCCLCHECK(alloc_fn(&ptr, size));\n        }\n        else {\n            check_cuda_error(cudaMalloc(&ptr, size));\n        }\n        buffers_.emplace(ptr, size);\n        return ptr;\n    }\n\n    void Free(void* ptr) override\n    {\n        if (auto it = buffers_.find(ptr); it != buffers_.end()) {\n            if (auto free_fn = nccl_apis().ncclMemFree) {\n                NCCLCHECK(free_fn(ptr));\n            }\n          
  else {\n                check_cuda_error(cudaFree(ptr));\n            }\n            buffers_.erase(ptr);\n        }\n        else {\n            TM_LOG_WARNING(\"[NCCL][%d] Freeing %p which is not allocated by NcclComm\", global_rank_, ptr);\n        }\n    }\n\n    void Register(void* ptr, size_t size) override\n    {\n        if (!handles_.at(0).count(ptr)) {\n            for (size_t i = 0; i < handles_.size(); ++i) {\n                Register(i, ptr, size);\n            }\n        }\n        else {\n            TM_LOG_WARNING(\"[NCCL][%d] Duplicated registration on (%p, %lu)\", global_rank_, ptr, size);\n        }\n    }\n\n    void Deregister(void* ptr) override\n    {\n        if (handles_.at(0).count(ptr)) {\n            for (size_t i = 0; i < handles_.size(); ++i) {\n                Deregister(i, ptr);\n            }\n        }\n        else {\n            TM_LOG_WARNING(\"[NCCL][%d] Deregistering non-registered address %p\", global_rank_, ptr);\n        }\n    }\n\n    void Register(int group, void* buff, size_t size)\n    {\n        void* handle{};\n        auto  comm = groups_.at(group);\n        if (auto func = nccl_apis().ncclCommWindowRegister) {\n            NCCLCHECK(func(comm, buff, size, &handle, NCCL_WIN_COLL_SYMMETRIC));\n        }\n        else if (auto func = nccl_apis().ncclCommRegister) {\n            NCCLCHECK(func(comm, buff, size, &handle));\n        }\n        handles_.at(group).emplace(buff, std::make_pair(handle, size));\n    }\n\n    void Deregister(int group, void* buff)\n    {\n        auto& handles = handles_.at(group);\n        if (auto it = handles.find(buff); it != handles.end()) {\n            if (auto func = nccl_apis().ncclCommWindowDeregister) {\n                NCCLCHECK(func(groups_.at(group), it->second.first));\n            }\n            else if (auto func = nccl_apis().ncclCommDeregister) {\n                NCCLCHECK(func(groups_.at(group), it->second.first));\n            }\n            handles.erase(it);\n        
}\n    }\n\n    int Split(int color, int key, int group) override\n    {\n        auto split_fn = TM_CHECK_NOTNULL(nccl_apis().ncclCommSplit);\n\n        ncclComm_t comm{};\n        NCCLCHECK(split_fn(groups_.at(group), color, key, &comm, nullptr));\n\n        int index = groups_.size();\n        groups_.push_back(comm);\n        handles_.emplace_back();\n\n        // register all existing buffers on the group\n        for (const auto& [k, v] : handles_.at(0)) {\n            Register(index, k, v.second);\n        }\n\n        return index;\n    }\n\n    int Query(QueryAttr attr) const noexcept override\n    {\n        return 0;\n    }\n\n    void AllReduceSum(\n        const void* sendbuff, void* recvbuff, size_t count, DataType type, int group, cudaStream_t stream) override\n    {\n        NCCLCHECK(ncclGroupStart());\n        NCCLCHECK(ncclAllReduce(sendbuff, recvbuff, count, to_nccl_dtype(type), ncclSum, groups_.at(group), stream));\n        NCCLCHECK(ncclGroupEnd());\n    }\n\n    void AllGather(\n        const void* sendbuff, void* recvbuff, size_t sendcount, DataType type, int group, cudaStream_t stream) override\n    {\n        NCCLCHECK(ncclGroupStart());\n        NCCLCHECK(ncclAllGather(sendbuff, recvbuff, sendcount, to_nccl_dtype(type), groups_.at(group), stream));\n        NCCLCHECK(ncclGroupEnd());\n    }\n\n    void ReduceScatter(\n        const void* sendbuff, void* recvbuff, size_t recvcount, DataType type, int group, cudaStream_t stream) override\n    {\n        NCCLCHECK(ncclGroupStart());\n        NCCLCHECK(\n            ncclReduceScatter(sendbuff, recvbuff, recvcount, to_nccl_dtype(type), ncclSum, groups_.at(group), stream));\n        NCCLCHECK(ncclGroupEnd());\n    }\n\n    void AllreduceResidualBiasRMSnorm(void*        hidden,\n                                      void*        residual,\n                                      const void*  bias,\n                                      const void*  weights,\n                                      
float        eps,\n                                      int          dim,\n                                      int          token_num,\n                                      DataType     dtype,\n                                      int          group,\n                                      cudaStream_t stream) override\n    {\n        const auto elem_size = byte_size(dtype);\n\n        auto rms_norm = [&](int64_t first, int64_t count) {\n            invokeResidualBiasRMSNorm((char*)hidden + elem_size * first * dim,\n                                      (char*)residual + elem_size * first * dim,\n                                      weights,\n                                      bias,\n                                      dtype,\n                                      dim,\n                                      count,\n                                      eps,\n                                      stream);\n        };\n\n        if (1) {\n            AllReduceSum(hidden, hidden, token_num * dim, dtype, group, stream);\n            rms_norm(0, token_num);\n        }\n        else {  // Only useful for large input size\n            const int    n_ranks   = this->n_ranks(group);\n            const int    rank      = this->rank(group);\n            const int    slice     = (token_num + n_ranks - 1) / n_ranks;\n            const size_t recvcount = slice * dim;\n            auto         sendbuff  = hidden;\n            auto         recvbuff  = (char*)hidden + elem_size * rank * recvcount;\n            ReduceScatter(sendbuff, recvbuff, recvcount, dtype, group, stream);\n            rms_norm(rank * slice, slice);\n            AllGather(recvbuff, sendbuff, recvcount, dtype, group, stream);\n        }\n    }\n\n    void AllreduceResidualBiasRMSnormEx(void*        hidden,\n                                        void*        residual,\n                                        const void*  bias,\n                                        const void*  weights,\n            
                            float        eps,\n                                        int          dim,\n                                        DataType     type,\n                                        int          group0,\n                                        int          group1,\n                                        const int*   local_token_nums,\n                                        cudaStream_t stream) override\n    {\n        const size_t         elem_size = byte_size(type);\n        const ncclDataType_t nccl_type = to_nccl_dtype(type);\n\n        FT_CHECK(group0 == 0 || group1 == 0);\n\n        ncclComm_t comm0 = groups_.at(group0);\n        ncclComm_t comm1 = groups_.at(group1);\n\n        int tp0{}, tp1{};\n        NCCLCHECK(ncclCommCount(comm0, &tp0));\n        NCCLCHECK(ncclCommCount(comm1, &tp1));\n\n        const int inner_tp = std::min(tp0, tp1);\n\n        FT_CHECK(tp0 % inner_tp == 0 && tp1 % inner_tp == 0);\n\n        std::vector<std::tuple<int, int, int>> tasks;\n        tasks.reserve(global_n_ranks_);\n\n        for (int i = 0, offset = 0; i < global_n_ranks_; ++i) {\n            const int num   = local_token_nums[i / inner_tp];\n            const int slice = (num + inner_tp - 1) / inner_tp;\n            const int first = std::min(num, i % inner_tp * slice);\n            const int last  = std::min(num, first + slice);\n            tasks.emplace_back(offset, first, last - first);\n            if ((i + 1) % inner_tp == 0) {\n                offset += num;\n            }\n        }\n\n        if (tp0 > 1) {\n            NCCLCHECK(ncclGroupStart());\n            for (int i = 0; i < global_n_ranks_; ++i) {\n                if (auto& [offset, first, num] = tasks[i]; num > 0) {\n                    char* buff = (char*)hidden + elem_size * (offset + first) * dim;\n                    NCCLCHECK(ncclReduce(buff, buff, (size_t)num * dim, nccl_type, ncclSum, i % tp0, comm0, stream));\n                }\n            }\n            
NCCLCHECK(ncclGroupEnd());\n            sync_check_cuda_error();\n        }\n\n        if (auto& [offset, first, num] = tasks[global_rank_]; num > 0) {\n            char* buff = (char*)hidden + elem_size * (offset + first) * dim;\n            invokeResidualBiasRMSNorm(\n                buff, (char*)residual + elem_size * first * dim, weights, bias, type, dim, num, eps, stream);\n            sync_check_cuda_error();\n        }\n\n        if (tp1 > 1) {\n            NCCLCHECK(ncclGroupStart());\n            for (int i = 0; i < global_n_ranks_; ++i) {\n                if (auto& [offset, first, num] = tasks[i]; num > 0) {\n                    char* buff = (char*)hidden + elem_size * (offset + first) * dim;\n                    NCCLCHECK(ncclBroadcast(buff, buff, (size_t)num * dim, nccl_type, i % tp1, comm1, stream));\n                }\n            }\n            NCCLCHECK(ncclGroupEnd());\n            sync_check_cuda_error();\n        }\n    }\n\n    void Broadcast(const void*  sendbuff,  //\n                   void*        recvbuff,\n                   size_t       count,\n                   DataType     type,\n                   int          root,\n                   int          group,\n                   cudaStream_t stream) override\n    {\n        NCCLCHECK(ncclBroadcast(sendbuff, recvbuff, count, to_nccl_dtype(type), root, groups_.at(group), stream));\n    }\n\nprivate:\n    HostComm h_comm_;\n\n    int global_n_ranks_;\n    int global_rank_;\n\n    std::vector<ncclComm_t> groups_;\n\n    std::vector<std::unordered_map<void*, std::pair<void*, size_t>>> handles_;\n\n    std::unordered_map<void*, size_t> buffers_;\n};\n\nDeviceComm CreateNcclCommunicator(int n_ranks, int rank, HostComm h_comm)\n{\n    ncclUniqueId uid{};\n    if (rank == 0) {\n        NCCLCHECK(ncclGetUniqueId(&uid));\n    }\n\n    static_assert(std::is_trivially_copyable_v<ncclUniqueId>);\n    Broadcast(h_comm, uid, 0);\n\n    ncclComm_t comm{};\n    NCCLCHECK(ncclCommInitRank(&comm, n_ranks, 
uid, rank));\n\n    return DeviceComm{std::make_unique<NcclCommImpl>(comm, n_ranks, rank, h_comm)};\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/comm/test_comm.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <chrono>\n#include <cmath>\n#include <cstdio>\n#include <memory>\n#include <numeric>\n#include <optional>\n#include <ostream>\n#include <random>\n#include <sstream>\n#include <thread>\n\n// #include <cuda_profiler_api.h>\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nusing namespace turbomind::comm;\nusing turbomind::data_type_v;\nusing turbomind::check;\nusing turbomind::myAssert;\nusing std::vector;\n\n[[maybe_unused]] static constexpr bool is_ncu = 0;\n\nstruct Context {\n\n    cudaStream_t stream;\n\n    cudaEvent_t ev_start;\n    cudaEvent_t ev_end;\n\n    std::vector<void*> buffers;\n\n    template<class F>\n    float exec(F func)\n    {\n        check_cuda_error(cudaStreamSynchronize(stream));\n        check_cuda_error(cudaEventRecord(ev_start, stream));\n\n        func(stream);\n\n        check_cuda_error(cudaEventRecord(ev_end, stream));\n        check_cuda_error(cudaEventSynchronize(ev_end));\n        float ms{};\n        check_cuda_error(cudaEventElapsedTime(&ms, ev_start, ev_end));\n        return ms;\n    }\n\n    template<class T>\n    T* malloc(size_t count)\n    {\n        T* data;\n        check_cuda_error(cudaMallocAsync(&data, sizeof(T) * count, stream));\n        buffers.push_back(data);\n        return data;\n    }\n\n    template<class T>\n    void copy_n(const T* src, size_t count, T* dst)\n    {\n        check_cuda_error(cudaMemcpyAsync(dst, src, sizeof(T) * count, cudaMemcpyDefault, stream));\n    }\n\n    void sync()\n    {\n        check_cuda_error(cudaStreamSynchronize(stream));\n    }\n\n    Context(int device_id)\n    {\n        check_cuda_error(cudaSetDevice(device_id));\n        check_cuda_error(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));\n        check_cuda_error(cudaEventCreate(&ev_start));\n      
  check_cuda_error(cudaEventCreate(&ev_end));\n    }\n    ~Context()\n    {\n        for (auto& p : buffers) {\n            cudaFreeAsync(p, stream);\n            p = {};\n        }\n        cudaStreamSynchronize(stream);\n        cudaEventDestroy(ev_end);\n        cudaEventDestroy(ev_start);\n        cudaStreamDestroy(stream);\n    }\n};\n\nstruct TestComm {\n    std::vector<HostComm>   h_comm_;\n    std::vector<DeviceComm> d_comm_;\n    std::vector<HostComm>   h_split_;\n    std::vector<int>        d_split_;\n\n    int              warmup_;\n    int              iters_;\n    std::vector<int> tokens_;\n    size_t           max_tokens_;\n\n    static auto Init(int n_ranks, int split, const std::string& backend)\n    {\n\n        std::unique_ptr<HostGroupId> group_id = CreateHostGroupId({});\n        std::string                  group_id_data;\n        if (1) {  // master\n            group_id->Initialize();\n            std::stringstream ss;\n            group_id->Export(ss);\n            group_id_data = ss.str();\n        }\n\n        std::vector<DeviceComm> d_comm(n_ranks);\n        std::vector<HostComm>   h_comm(n_ranks);\n        std::vector<int>        d_split(n_ranks);\n        std::vector<HostComm>   h_split(n_ranks);\n\n        auto init = [&](int rank) {\n            // initialize host communicators\n            std::stringstream            ss(group_id_data);\n            std::unique_ptr<HostGroupId> host_id = CreateHostGroupId({});\n            host_id->Import(ss);\n            h_comm[rank] = host_id->CreateCommunicator(n_ranks, rank);\n\n            // initialize device communicators\n            cudaSetDevice(rank);\n            d_comm[rank] = CreateDeviceCommunicator(backend, n_ranks, rank, h_comm[rank]);\n\n            // split communicators\n            if (split) {\n                h_split[rank] = h_comm[rank]->Split(rank / split, 0);\n                d_split[rank] = d_comm[rank]->Split(rank / split, 0, 0);\n            }\n            else {\n       
         h_split[rank] = h_comm[rank];\n                d_split[rank] = 0;\n            }\n        };\n\n        std::vector<std::thread> threads;\n        for (int i = 0; i < n_ranks; ++i) {\n            threads.emplace_back(init, i);\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n\n        return std::make_tuple(h_comm, std::move(d_comm), h_split, d_split);\n    }\n\n    void Run(int hidden_dim, int vocab_size, int tp, int warmup, int iters, std::vector<int> tokens)\n    {\n        int device_num{};\n        cudaGetDeviceCount(&device_num);\n\n        std::cout << \"Device count: \" << device_num << \"\\n\";\n\n        if (tp < 0) {\n            tp = device_num;\n        }\n\n        std::tie(h_comm_, d_comm_, h_split_, d_split_) = Init(device_num, 4, \"cuda-ipc\");\n\n        TM_CHECK_GT(h_comm_.size(), 0);\n        TM_CHECK_GT(d_comm_.size(), 0);\n\n        warmup_ = warmup;\n        iters_  = iters;\n        tokens_ = tokens;\n\n        max_tokens_ = *std::max_element(tokens_.begin(), tokens_.end());\n\n        const int g = 0;\n\n        TestAllReduce<half>(hidden_dim, 0);\n        // TestAllreduceResidualBiasRMSnorm<half>(hidden_dim, g);\n        // TestAllreduceResidualBiasRMSnormEx<half>(hidden_dim, 0, 0);\n        // TestAllreduceResidualBiasRMSnormEx<half>(hidden_dim, 1, 0);\n        // TestAllreduceResidualBiasRMSnormEx<half>(hidden_dim, 0, 1);\n        // TestAllGather<half>(hidden_dim / tp, g);  // tp embedding\n        // TestAllGather<half>(vocab_size / tp, g);\n        // TestBroadcast<half>(32768, g);\n    }\n\n    template<class T>\n    void TestAllReduce(size_t dim, int group = 0)\n    {\n        const auto dtype = data_type_v<T>;\n\n        const int tp_size = d_comm_[0]->n_ranks(group);\n        const int dp_size = d_comm_.size() / tp_size;\n\n        //    dp         tp           dim\n        std::vector<std::vector<std::vector<T>>> data(dp_size);\n        //    dp         dim\n        
std::vector<std::vector<T>> ref_data(dp_size);\n\n        for (int i = 0; i < dp_size; ++i) {\n            data[i].resize(tp_size);\n            ref_data[i].resize(max_tokens_ * dim);\n        }\n\n        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {\n            const int rank    = d_comm->rank(group);\n            const int n_ranks = d_comm->n_ranks(group);\n            const int g_rank  = d_comm->rank(0);\n            const int d       = g_rank / n_ranks;\n\n            const size_t max_count = max_tokens_ * dim;\n\n            std::mt19937                  gen{(unsigned)index};\n            std::uniform_int_distribution dist{0, 31};  // 5 mantissa bits\n            if (g_rank == 0) {\n                std::cout << \"preparing data ... \" << std::flush;\n            }\n            data[d][rank].resize(max_count);\n            for (size_t i = 0; i < max_count; ++i) {\n                data[d][rank][i] = T(dist(gen));\n            }\n            h_comm->Sync();\n            const size_t slice = (max_count + n_ranks - 1) / n_ranks;\n            for (int r = 0; r < n_ranks; ++r) {\n                for (size_t i = rank * slice; i < (rank + 1) * slice && i < max_count; ++i) {\n                    ref_data[d][i] += data[d][r][i];\n                }\n            }\n            h_comm->Sync();\n            if (g_rank == 0) {\n                std::cout << \"done.\\n\";\n            }\n\n            Context ctx{g_rank};\n\n            T* d_data = ctx.malloc<T>(max_count);\n\n            T* d_tmp = (T*)d_comm->Allocate(sizeof(T) * max_count);\n            d_comm->Register(d_tmp, sizeof(T) * max_count);\n\n            ctx.copy_n(data[d][rank].data(), max_count, d_data);\n\n            [[maybe_unused]] auto verify = [&](auto count) {\n                std::vector<T> res(count);\n                ctx.copy_n(d_tmp, count, res.data());\n                ctx.sync();\n                size_t diff = 0;\n                for (size_t i = 0; i < count; ++i) {\n          
          auto& x = res[i];\n                    auto& y = ref_data[d][i];\n                    diff += x != y;\n                    if (diff == 1) {\n                        printf(\"%d: %f vs %f\\n\", (int)i, (float)x, (float)y);\n                    }\n                }\n                if (diff) {\n                    printf(\"[rank %d] count = %d, diff = %lu\\n\", g_rank, (int)count, diff);\n                    std::this_thread::sleep_for(std::chrono::seconds(1));\n                    std::abort();\n                }\n            };\n\n            std::vector<float> deltas;\n            for (const auto& n : tokens_) {\n                const size_t count = (size_t)n * dim;\n                auto&        delta = deltas.emplace_back();\n                h_comm->Sync();\n                for (int i = 0; i < warmup_ + iters_; ++i) {\n                    ctx.copy_n(d_data, count, d_tmp);\n                    auto ms = ctx.exec([&](auto stream) {  //\n                        d_comm->AllReduceSum(d_tmp, d_tmp, count, dtype, group, stream);\n                    });\n                    if (i >= warmup_) {\n                        delta += ms;\n                    }\n                    // verify(count);\n                }\n                verify(count);\n            }\n\n            if (g_rank == 0) {\n                SummaryHeader(\"allreduce\", dim, n_ranks);\n                for (size_t i = 0; i < tokens_.size(); ++i) {\n                    const float  avg   = deltas[i] / iters_;\n                    const size_t count = tokens_[i] * dim;\n                    const float  algbw = sizeof(T) * count / 1e9f / avg * 1000.f;\n                    const float  busbw = algbw * (2 * (n_ranks - 1)) / n_ranks;\n                    SummaryEntry(tokens_[i], count, sizeof(T), avg, algbw, busbw);\n                }\n            }\n\n            d_comm->Deregister(d_tmp);\n            d_comm->Free(d_tmp);\n        };\n\n        std::vector<std::thread> threads;\n        for (size_t i 
= 0; i < d_comm_.size(); ++i) {\n            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    template<class T>\n    void TestAllreduceResidualBiasRMSnorm(size_t dim, int group)\n    {\n        vector<T> weight(dim);\n        vector<T> bias(dim);\n\n        constexpr float eps      = 1e-5;\n        constexpr bool  has_bias = true;\n\n        std::cout << \"preparing data ... \" << std::flush;\n\n        {\n            std::mt19937                  gen{};\n            std::uniform_int_distribution dist{0, 31};  // 5 mantissa bits\n            for (size_t i = 0; i < dim; ++i) {\n                weight[i] = T(dist(gen));\n            }\n            if (has_bias) {\n                for (size_t i = 0; i < dim; ++i) {\n                    bias[i] = T(dist(gen));\n                }\n            }\n        }\n\n        const auto dtype = data_type_v<T>;\n\n        const int tp_size = d_comm_[0]->n_ranks(group);\n        const int dp_size = d_comm_.size() / tp_size;\n        // dp    tp     dim\n        vector<vector<vector<T>>> src_data(dp_size);\n        // dp    dim\n        vector<vector<T>> ref_data(dp_size);\n        vector<vector<T>> src_res(dp_size);\n        vector<vector<T>> ref_res(dp_size);\n\n        for (int i = 0; i < dp_size; ++i) {\n            src_data[i].resize(tp_size);\n            ref_data[i].resize(max_tokens_ * dim);\n            src_res[i].resize(max_tokens_ * dim);\n            ref_res[i].resize(max_tokens_ * dim);\n        }\n\n        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {\n            const int rank    = d_comm->rank(group);\n            const int n_ranks = d_comm->n_ranks(group);\n            const int g_rank  = d_comm->rank(0);\n            const int d       = g_rank / n_ranks;\n\n            const size_t max_count = max_tokens_ * dim;\n\n            std::mt19937                  
gen{(unsigned)index};\n            std::uniform_int_distribution dist{0, 31};  // 5 mantissa bits\n\n            src_data[d][rank].resize(max_count);\n            for (size_t i = 0; i < max_count; ++i) {\n                src_data[d][rank][i] = T(dist(gen));\n            }\n            h_comm->Sync();\n            const size_t slice = (max_tokens_ + n_ranks - 1) / n_ranks;\n            for (size_t t = rank * slice; t < (rank + 1) * slice && t < max_tokens_; ++t) {\n                for (int r = 0; r < n_ranks; ++r) {\n                    for (size_t i = 0; i < dim; ++i) {\n                        ref_data[d][t * dim + i] += src_data[d][r][t * dim + i];\n                    }\n                }\n                float sum = 0.f;\n                for (size_t i = 0; i < dim; ++i) {\n                    const size_t idx = t * dim + i;\n                    src_res[d][idx]  = T(dist(gen));\n                    ref_res[d][idx]  = src_res[d][idx] + ref_data[d][idx] + bias[i];  // r' <- r + (h + b)\n                    sum += (float)ref_res[d][idx] * (float)ref_res[d][idx];\n                }\n                sum = 1 / (sqrtf(sum / dim) + eps);\n                for (size_t i = 0; i < dim; ++i) {\n                    const size_t idx = t * dim + i;\n                    float        tmp = (float)ref_res[d][idx];\n                    ref_data[d][idx] = tmp * sum * (float)weight[i];  // h' <- norm(r) * w\n                }\n            }\n            h_comm->Sync();\n            if (g_rank == 0) {\n                std::cout << \"done.\\n\";\n            }\n\n            Context ctx{g_rank};\n\n            T* d_bias   = ctx.malloc<T>(dim);\n            T* d_weight = ctx.malloc<T>(dim);\n\n            T* d_data    = ctx.malloc<T>(max_count);\n            T* d_res     = ctx.malloc<T>(max_count);\n            T* d_tmp_res = ctx.malloc<T>(max_count);\n\n            T* d_tmp_data = (T*)d_comm->Allocate(sizeof(T) * max_count);\n            d_comm->Register(d_tmp_data, sizeof(T) * 
max_count);\n\n            ctx.copy_n(src_data[d][rank].data(), max_count, d_data);\n            ctx.copy_n(src_res[d].data(), max_count, d_res);\n            ctx.copy_n(bias.data(), dim, d_bias);\n            ctx.copy_n(weight.data(), dim, d_weight);\n\n            [[maybe_unused]] auto verify = [&](auto token_num) {\n                const size_t count = (size_t)token_num * dim;\n                vector<T>    h_data(count);\n                vector<T>    h_res(count);\n                ctx.copy_n(d_tmp_data, count, h_data.data());\n                ctx.copy_n(d_tmp_res, count, h_res.data());\n                ctx.sync();\n                const size_t slice    = (token_num + n_ranks - 1) / n_ranks * dim;\n                const size_t first    = rank * slice;\n                const size_t last     = std::min(first + slice, count);\n                size_t       res_diff = 0;\n                for (size_t i = first; i < last; ++i) {\n                    auto& x       = h_res[i];\n                    auto& y       = ref_res[d][i];\n                    int   is_diff = !(x == y);\n                    if (!res_diff && is_diff) {\n                        printf(\"[rank %d], %ld: %f vs %f\\n\", g_rank, i - first, (float)x, (float)y);\n                    }\n                    res_diff += is_diff;\n                }\n                float data_diff = 0;\n                for (size_t i = 0; i < count; ++i) {\n                    float diff = (float)h_data[i] - (float)ref_data[d][i];\n                    data_diff += std::abs(diff);\n                }\n                data_diff /= count;\n                if (res_diff || data_diff > 0.1f || std::isnan(data_diff)) {\n                    printf(\"[rank %d] count = %d, res_diff = %lu, data_diff = %f\\n\",\n                           g_rank,\n                           (int)token_num,\n                           res_diff,\n                           data_diff);\n                    std::this_thread::sleep_for(std::chrono::seconds(5));\n  
                  std::abort();\n                }\n                else if (g_rank == 0) {\n                    printf(\"[rank %d] count = %d, data_diff = %f\\n\", g_rank, (int)token_num, data_diff);\n                }\n            };\n\n            vector<float> deltas;\n            for (const auto& n : tokens_) {\n                const size_t count = (size_t)n * dim;\n                auto&        delta = deltas.emplace_back();\n                h_comm->Sync();\n                for (int i = 0; i < warmup_ + iters_; ++i) {\n                    ctx.copy_n(d_data, count, d_tmp_data);\n                    ctx.copy_n(d_res, count, d_tmp_res);\n                    auto ms = ctx.exec([&](auto stream) {  //\n                        d_comm->AllreduceResidualBiasRMSnorm(d_tmp_data,\n                                                             d_tmp_res,\n                                                             has_bias ? d_bias : nullptr,\n                                                             d_weight,\n                                                             eps,\n                                                             dim,\n                                                             n,\n                                                             dtype,\n                                                             group,\n                                                             stream);\n                    });\n                    if (i >= warmup_) {\n                        delta += ms;\n                    }\n                    // verify(n);\n                }\n                verify(n);\n            }\n\n            d_comm->Deregister(d_tmp_data);\n            d_comm->Free(d_tmp_data);\n\n            if (g_rank == 0) {\n                SummaryHeader(\"allreduce | rmsnorm\", dim, n_ranks);\n                for (size_t i = 0; i < tokens_.size(); ++i) {\n                    const float  avg   = deltas[i] / iters_;\n                    const 
size_t count = tokens_[i] * dim;\n                    const float  algbw = sizeof(T) * count / 1e9f / avg * 1000.f;\n                    const float  busbw = algbw * (2 * (n_ranks - 1)) / n_ranks;\n                    SummaryEntry(tokens_[i], count, sizeof(T), avg, algbw, busbw);\n                }\n            }\n        };\n\n        std::vector<std::thread> threads;\n        for (size_t i = 0; i < d_comm_.size(); ++i) {\n            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    template<class T>\n    void TestAllGather(size_t dim, int group)\n    {\n        const auto dtype = data_type_v<T>;\n\n        const int tp_size = d_comm_[0]->n_ranks(group);\n        const int dp_size = d_comm_.size() / tp_size;\n\n        vector<vector<vector<T>>> data(dp_size);\n\n        for (int i = 0; i < dp_size; ++i) {\n            data[i].resize(tp_size);\n        }\n\n        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {\n            const int rank    = d_comm->rank(group);\n            const int n_ranks = d_comm->n_ranks(group);\n            const int g_rank  = d_comm->rank(0);\n            const int d       = g_rank / n_ranks;\n\n            const size_t max_count = max_tokens_ * dim;\n\n            if (h_comm->rank() == 0) {\n                std::cout << \"preparing data ... 
\" << std::flush;\n            }\n            std::mt19937                  gen{(unsigned)index};\n            std::uniform_int_distribution dist{0, 100};\n            data[d][rank].resize(max_count);\n            for (size_t i = 0; i < max_count; ++i) {\n                data[d][rank][i] = T(dist(gen));\n            }\n            h_comm->Sync();\n            if (h_comm->rank() == 0) {\n                std::cout << \"done.\\n\";\n            }\n\n            Context ctx{g_rank};\n\n            T* d_data = ctx.malloc<T>(max_count);\n\n            T* d_tmp = (T*)d_comm->Allocate(sizeof(T) * max_count * n_ranks);\n            d_comm->Register(d_tmp, sizeof(T) * max_count * n_ranks);\n\n            ctx.copy_n(data[d][rank].data(), max_count, d_data);\n\n            [[maybe_unused]] auto verify = [&](int64_t count) {\n                auto           total_count = count * n_ranks;\n                std::vector<T> res(total_count);\n                ctx.copy_n(d_tmp, total_count, res.data());\n                ctx.sync();\n                size_t diff = 0;\n                for (int r = 0; r < n_ranks; ++r) {\n                    for (auto i = 0; i < count; ++i) {\n                        auto& x = res[r * count + i];\n                        auto& y = data[d][r][i];\n                        diff += (x != y);\n                        if (diff == 1) {\n                            printf(\"%d: %f vs %f\\n\", (int)i, (float)x, (float)y);\n                        }\n                    }\n                }\n                if (diff) {\n                    printf(\"[rank %d] count = %d, diff = %lu\\n\", g_rank, (int)count, diff);\n                    std::this_thread::sleep_for(std::chrono::seconds(1));\n                    std::abort();\n                }\n            };\n\n            std::vector<float> deltas;\n            for (const auto& n : tokens_) {\n                const size_t count = (size_t)n * dim;  // dim = hidden_dim / tp\n                auto&        delta = 
deltas.emplace_back();\n                h_comm->Sync();\n                for (int i = 0; i < warmup_ + iters_; ++i) {\n                    check_cuda_error(cudaMemsetAsync(d_tmp, 0, sizeof(T) * count * n_ranks, ctx.stream));\n                    ctx.copy_n(d_data, count, d_tmp + rank * count);\n                    auto ms = ctx.exec([&](auto stream) {  //\n                        if (d_comm->Query(kHasAllGather2D) && 0) {\n                            d_comm->AllGather2D(\n                                d_tmp + rank * count, d_tmp, dim, count, dim, n, dtype, {1, 1}, group, stream);\n                        }\n                        else {\n                            d_comm->AllGather(d_tmp + rank * count, d_tmp, count, dtype, group, stream);\n                        }\n                    });\n                    if (i >= warmup_) {\n                        delta += ms;\n                    }\n                    // verify(count);\n                }\n                verify(count);\n            }\n\n            if (g_rank == 0) {\n                SummaryHeader(\"allgather\", dim, n_ranks);\n                for (size_t i = 0; i < tokens_.size(); ++i) {\n                    const float  avg   = deltas[i] / iters_;\n                    const size_t count = n_ranks * tokens_[i] * dim;\n                    const float  algbw = sizeof(T) * count / 1e9f / avg * 1000.f;\n                    const float  busbw = algbw * (n_ranks - 1) / n_ranks;\n\n                    SummaryEntry(tokens_[i], count, sizeof(T), avg, algbw, busbw);\n                }\n            }\n\n            d_comm->Deregister(d_tmp);\n            d_comm->Free(d_tmp);\n        };\n\n        std::vector<std::thread> threads;\n        for (size_t i = 0; i < d_comm_.size(); ++i) {\n            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    template<class T>\n    void TestBroadcast(size_t 
dim, int group)\n    {\n        const auto dtype = data_type_v<T>;\n\n        const int tp_size = d_comm_[0]->n_ranks(group);\n        const int dp_size = d_comm_.size() / tp_size;\n\n        constexpr int root = 0;\n\n        vector<vector<T>> data(dp_size);\n\n        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {\n            const int rank    = d_comm->rank(group);\n            const int n_ranks = d_comm->n_ranks(group);\n            const int g_rank  = d_comm->rank(0);\n            const int d       = g_rank / n_ranks;\n\n            const size_t max_count = max_tokens_ * dim;\n\n            if (h_comm->rank() == root) {\n                std::cout << \"preparing data ... \" << std::flush;\n                std::mt19937                  gen{(unsigned)index};\n                std::uniform_int_distribution dist{0, 100};\n                data[d].resize(max_count);\n                for (size_t i = 0; i < max_count; ++i) {\n                    data[d][i] = T(dist(gen));\n                }\n                std::cout << \"done.\\n\";\n            }\n\n            h_comm->Sync();\n\n            Context ctx{g_rank};\n\n            T* d_data = ctx.malloc<T>(max_count);\n\n            T* d_tmp = (T*)d_comm->Allocate(sizeof(T) * max_count);\n            d_comm->Register(d_tmp, sizeof(T) * max_count);\n\n            if (rank == root) {\n                ctx.copy_n(data[d].data(), max_count, d_data);\n            }\n\n            [[maybe_unused]] auto verify = [&](int64_t count) {\n                auto           total_count = count;\n                std::vector<T> res(total_count);\n                ctx.copy_n(d_tmp, total_count, res.data());\n                ctx.sync();\n                size_t diff = 0;\n                for (auto i = 0; i < count; ++i) {\n                    auto& x = res[i];\n                    auto& y = data[d][i];\n                    diff += (x != y);\n                    if (diff == 1) {\n                        printf(\"%d: %f vs 
%f\\n\", (int)i, (float)x, (float)y);\n                    }\n                }\n                if (diff) {\n                    printf(\"[rank %d] count = %d, diff = %lu\\n\", g_rank, (int)count, diff);\n                    std::this_thread::sleep_for(std::chrono::seconds(1));\n                    std::abort();\n                }\n            };\n\n            std::vector<float> deltas;\n            for (const auto& n : tokens_) {\n                const size_t count = (size_t)n * dim;  // dim = hidden_dim / tp\n                auto&        delta = deltas.emplace_back();\n                h_comm->Sync();\n                for (int i = 0; i < warmup_ + iters_; ++i) {\n                    check_cuda_error(cudaMemsetAsync(d_tmp, 0, sizeof(T) * count, ctx.stream));\n                    if (rank == root) {\n                        ctx.copy_n(d_data, count, d_tmp);\n                    }\n                    auto ms = ctx.exec([&](auto stream) {  //\n                        d_comm->Broadcast(d_tmp, d_tmp, count, dtype, 0, group, stream);\n                    });\n                    if (i >= warmup_) {\n                        delta += ms;\n                    }\n                    // verify(count);\n                }\n                verify(count);\n            }\n\n            if (g_rank == 0) {\n                SummaryHeader(\"broadcast\", dim, n_ranks);\n                for (size_t i = 0; i < tokens_.size(); ++i) {\n                    const float  avg   = deltas[i] / iters_;\n                    const size_t count = tokens_[i] * dim;\n                    const float  algbw = sizeof(T) * count / 1e9f / avg * 1000.f;\n                    const float  busbw = algbw;\n                    SummaryEntry(tokens_[i], count, sizeof(T), avg, algbw, busbw);\n                }\n            }\n\n            d_comm->Deregister(d_tmp);\n            d_comm->Free(d_tmp);\n        };\n\n        std::vector<std::thread> threads;\n        for (size_t i = 0; i < d_comm_.size(); ++i) {\n  
          threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    template<class T>\n    void TestAllreduceResidualBiasRMSnormEx(size_t dim, int group0, int group1)\n    {\n        const int tp_size_0 = d_comm_.at(0)->n_ranks(group0);\n        const int tp_size_1 = d_comm_.at(0)->n_ranks(group1);\n        const int dp_size_0 = d_comm_.size() / tp_size_0;\n        const int dp_size_1 = d_comm_.size() / tp_size_1;\n\n        const int inner_tp = std::gcd(tp_size_0, tp_size_1);\n\n        const auto dtype = data_type_v<T>;\n\n        std::mt19937                  gen{};\n        std::uniform_int_distribution dist{0, 31};  // 5 mantissa bits\n\n        TM_LOG_INFO(\"dp_size_0 %d, tp_size_0 %d\", dp_size_0, tp_size_0);\n        TM_LOG_INFO(\"dp_size_1 %d, tp_size_1 %d\", dp_size_1, tp_size_1);\n        TM_LOG_INFO(\"inner_tp %d\", inner_tp);\n\n        vector tokens = tokens_;\n        for (auto& x : tokens) {\n            x = (x + dp_size_0 - 1) / dp_size_0;\n        }\n        std::sort(tokens.begin(), tokens.end());\n        tokens.erase(std::unique(tokens.begin(), tokens.end()), tokens.end());\n        const size_t max_tokens = tokens.back();\n\n        vector<T> ref_data(dp_size_0 * max_tokens * dim);\n        vector<T> src_res(ref_data.size());\n        vector<T> ref_res(ref_data.size());\n\n        vector<T> weight(dim);\n        vector<T> bias(dim);\n\n        constexpr float eps      = 1e-5;\n        constexpr bool  has_bias = true;\n\n        std::cout << \"preparing data ... 
\" << std::flush;\n\n        for (size_t i = 0; i < dim; ++i) {\n            weight[i] = T(dist(gen));\n        }\n\n        if (has_bias) {\n            for (size_t i = 0; i < dim; ++i) {\n                bias[i] = T(dist(gen));\n            }\n        }\n\n        std::vector<std::vector<T>> src_data(tp_size_0);\n        for (int r = 0; r < tp_size_0; ++r) {\n            src_data[r].resize(ref_data.size());\n            for (size_t i = 0; i < ref_data.size(); ++i) {\n                src_data[r][i] = T(dist(gen));\n            }\n        }\n\n        for (size_t i = 0; i < src_res.size(); ++i) {\n            src_res[i] = T(dist(gen));\n        }\n\n        for (int r = 0; r < tp_size_0; ++r) {\n            for (size_t i = 0; i < ref_data.size(); ++i) {\n                ref_data[i] += src_data[r][i];\n            }\n        }\n\n        for (size_t i = 0; i < dp_size_0 * max_tokens; ++i) {\n            float sum = 0.f;\n            for (size_t d = 0; d < dim; ++d) {\n                size_t idx   = i * dim + d;\n                ref_res[idx] = src_res[idx] + ref_data[idx] + bias[d];  // r' <- r + (h + b)\n                sum += (float)ref_res[idx] * (float)ref_res[idx];\n            }\n            sum = 1 / (sqrtf(sum / dim) + eps);\n            for (size_t d = 0; d < dim; ++d) {\n                size_t idx    = i * dim + d;\n                ref_data[idx] = (float)ref_res[idx] * sum * (float)weight[d];  // h' <- norm(r) * w\n            }\n        }\n\n        std::cout << \"done\" << std::endl;\n\n        auto func = [&](int index, DeviceComm& d_comm, HostComm& h_comm) {\n            const int g_rank    = d_comm->rank(0);\n            const int g_n_ranks = d_comm->n_ranks(0);\n            const int dp_rank_0 = g_rank / tp_size_0;\n            const int dp_rank_1 = g_rank / tp_size_1;\n            const int tp_rank_0 = d_comm->rank(group0);\n            const int tp_rank_1 = d_comm->rank(group1);\n            const int local_id  = g_rank / inner_tp;  // which local 
partition this rank belongs to\n\n            // TM_LOG_INFO(\"g_rank %d, dp_rank_0 %d, tp_rank_0 %d, dp_rank_1 %d, tp_rank_1 %d, local_id %d\",\n            //             g_rank,\n            //             dp_rank_0,\n            //             tp_rank_0,\n            //             dp_rank_1,\n            //             tp_rank_1,\n            //             local_id);\n\n            const size_t max_count = max_tokens * dim;\n\n            Context ctx{g_rank};\n\n            T* d_bias    = ctx.malloc<T>(dim);\n            T* d_weight  = ctx.malloc<T>(dim);\n            T* d_data    = ctx.malloc<T>(max_count);\n            T* d_res     = ctx.malloc<T>(max_count);\n            T* d_tmp_res = ctx.malloc<T>(max_count);\n\n            T* d_tmp_data = (T*)d_comm->Allocate(sizeof(T) * dp_size_0 * max_count);\n            d_comm->Register(d_tmp_data, sizeof(T) * dp_size_0 * max_count);\n\n            ctx.copy_n(bias.data(), dim, d_bias);\n            ctx.copy_n(weight.data(), dim, d_weight);\n\n            [[maybe_unused]] auto verify = [&](auto n) {\n                const size_t dst_tokens = n / dp_size_1 * dp_size_0;\n                const size_t dst_count  = dst_tokens * dim;\n                vector<T>    h_data(dst_count);\n                ctx.copy_n(d_tmp_data + dp_rank_1 * dst_count, dst_count, h_data.data());\n                const size_t local_tokens = (size_t)n / dp_size_1;\n                const size_t local_count  = local_tokens * dim;\n                const size_t slice        = (local_tokens + inner_tp - 1) / inner_tp * dim;\n                const size_t first        = std::min(local_count, g_rank % inner_tp * slice);\n                const size_t last         = std::min(local_count, first + slice);\n                vector<T>    h_res(last - first);\n                ctx.copy_n(d_tmp_res + first, h_res.size(), h_res.data());\n                ctx.sync();\n                size_t res_diff = 0;\n                for (size_t i = first; i < last; ++i) {\n         
           auto& val  = h_res[i - first];\n                    auto& ref  = ref_res[local_id * local_count + i];\n                    int   diff = !(val == ref);\n                    if (res_diff < 5 && diff) {\n                        printf(\"[rank %d], %ld: %f vs %f\\n\", g_rank, i - first, (float)val, (float)ref);\n                    }\n                    res_diff += diff;\n                }\n                float data_diff = 0;\n                for (size_t i = 0; i < dst_count; ++i) {\n                    float diff = (float)h_data[i] - (float)ref_data[dp_rank_1 * dst_count + i];\n                    data_diff += std::abs(diff);\n                }\n                data_diff /= dst_count;\n                if (res_diff || data_diff > 0.1f || std::isnan(data_diff)) {\n                    printf(\n                        \"[rank %d] count = %d, res_diff = %lu, data_diff = %f\\n\", g_rank, (int)n, res_diff, data_diff);\n                    std::this_thread::sleep_for(std::chrono::seconds(5));\n                    std::abort();\n                }\n                else if (tp_rank_1 == 0) {\n                    printf(\"[rank %d] count = %d, data_diff = %f\\n\", g_rank, (int)n, data_diff);\n                }\n            };\n\n            std::vector<std::pair<int, float>> stats;\n            for (const auto& n : tokens) {\n                if (n % dp_size_1) {\n                    if (g_rank == 0) {\n                        TM_LOG_INFO(\"Skipped %d\", n);\n                    }\n                    continue;\n                }\n                // const int src_token_num = n;\n                // const int dst_token_num = n / dp_size_1 * dp_size_0;\n                const size_t count       = (size_t)n * dim;\n                const size_t local_count = count / dp_size_1;\n                std::vector  local_token_nums(dp_size_0 * dp_size_1, n / dp_size_1);\n                ctx.copy_n(src_data[tp_rank_0].data() + dp_rank_0 * count, count, d_data);\n                
ctx.copy_n(src_res.data() + local_id * local_count, local_count, d_res);\n                auto& [_, delta] = stats.emplace_back(n * dp_size_0, 0.f);\n                h_comm->Sync();\n                for (int i = 0; i < warmup_ + iters_; ++i) {\n                    ctx.copy_n(d_data, count, d_tmp_data + dp_rank_0 * count);\n                    ctx.copy_n(d_res, local_count, d_tmp_res);\n                    auto ms = ctx.exec([&](auto stream) {  //\n                        d_comm->AllreduceResidualBiasRMSnormEx(d_tmp_data,\n                                                               d_tmp_res,\n                                                               has_bias ? d_bias : nullptr,\n                                                               d_weight,\n                                                               eps,\n                                                               dim,\n                                                               dtype,\n                                                               group0,\n                                                               group1,\n                                                               local_token_nums.data(),\n                                                               stream);\n                    });\n                    if (i >= warmup_) {\n                        delta += ms;\n                    }\n                    // verify(n);\n                }\n                verify(n);\n            }\n\n            d_comm->Deregister(d_tmp_data);\n            d_comm->Free(d_tmp_data);\n\n            if (g_rank == 0) {\n                SummaryHeader(\"rs | rmsnorm | ag\", dim, g_n_ranks);\n                for (const auto& [num, ms] : stats) {\n                    const float  avg    = ms / iters_;\n                    const size_t count  = num * dim;\n                    const float  algbw  = sizeof(T) * count / 1e9f / avg * 1000.f;\n                    const float  factor = 
(tp_size_0 + tp_size_1 - 2) / (float)g_n_ranks;\n                    const float  busbw  = algbw * factor;\n                    // g_n_ranks;\n                    SummaryEntry(num, count, sizeof(T), avg, algbw, busbw);\n                }\n            }\n        };\n\n        std::vector<std::thread> threads;\n        for (size_t i = 0; i < d_comm_.size(); ++i) {\n            threads.emplace_back(func, i, std::ref(d_comm_[i]), std::ref(h_comm_[i]));\n        }\n        for (auto& t : threads) {\n            t.join();\n        }\n    }\n\n    void SummaryHeader(const char* name, int dim, int world_size)\n    {\n        printf(\"[%s] dim %d tp %d warmup %d iters %d\\n\", name, dim, world_size, warmup_, iters_);\n        printf(\"%15s%15s%15s%15s%15s%15s\\n\", \"num\", \"count\", \"size\", \"time\", \"algbw\", \"busbw\");\n        printf(\"%15s%15s%15s%15s%15s%15s\\n\", \"(tokens)\", \"(elements)\", \"(MB)\", \"(us)\", \"(GB/s)\", \"(GB/s)\");\n    }\n\n    void SummaryEntry(int num, size_t count, size_t elem_size, float time, float algbw, float busbw)\n    {\n        float mb_size = count * elem_size / (1024.f * 1024);\n        printf(\"%15d%15ld%15.2f%15.3f%15.3f%15.3f\\n\", num, count, mb_size, time * 1e3f, algbw, busbw);\n    }\n};\n\nint main(int argc, char* argv[])\n{\n\n    TestComm test;\n\n    test.Run(2048,  //\n             128000,\n             -1,\n             10,\n             10000,\n             //   {1024});\n             //   {1024, 2048, 4096, 8192});\n             // {512});\n             //    {1, 2, 3, 4, 5, 6, 7, 8, 12, 16, 24, 32, 48, 64, 96, 128});\n             //  {2, 4, 6, 8, 12, 16, 24, 32, 48, 64, 96, 128});\n             //  {128, 256, 512, 1024, 2048, 4096, 8192});\n             //  {8, 16, 24, 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 1536, 2048, 4096, 6144, 8192});\n             //   {8192, 16384, 32768});\n             //  {1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 8192});\n             {1,   
2,   4,   6,   8,   12,   16,   24,   32,   48,   64,   96,   128,\n              192, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 16384});\n\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/comm/test_host_comm.cc",
    "content": "\n#include <chrono>\n#include <iostream>\n#include <memory>\n#include <numeric>\n#include <thread>\n#include <vector>\n\n#include \"src/turbomind/comm/host_comm.h\"\n\nusing namespace turbomind;\nusing namespace turbomind::comm;\n\nint main(int argc, char* argv[])\n{\n    const int                    N        = 32;\n    std::unique_ptr<HostGroupId> group_id = CreateHostGroupId({});\n    group_id->Initialize();\n    std::vector<std::thread> threads;\n    for (int r = 0; r < N; ++r) {\n        threads.emplace_back([&, r] {\n            HostComm world = group_id->CreateCommunicator(N, r);\n\n            HostComm group = world;\n            group          = world->Split(r / (N / 4), 0);\n\n            auto tick = std::chrono::steady_clock::now();\n\n            // int data = 100;\n            // for (int i = 0; i < 10000; ++i, ++data) {\n            //     group->Sync(true);\n            // }\n\n            volatile int a;\n            volatile int b;\n            for (int i = 0; i < 1; ++i) {\n                a      = AllReduce(group, r, RedOp::kSum);\n                auto v = AllGather(group, r);\n                b      = std::accumulate(v.begin(), v.end(), 0);\n                for (int j = 0; j < N; ++j) {\n                    world->Sync();\n                    if (j == r) {\n                        std::cout << a << \" \" << b << std::endl;\n                    }\n                }\n            }\n\n            auto tock = std::chrono::steady_clock::now();\n\n            for (int i = 0; i < N; ++i) {\n                world->Sync();\n                if (i == r) {\n                    std::cout << std::chrono::duration<float, std::milli>(tock - tick).count() << std::endl;\n                }\n            }\n        });\n    }\n\n    std::cout << \"main thread waiting.\\n\";\n\n    for (auto& t : threads) {\n        t.join();\n    }\n\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/comm/thread_comm.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <atomic>\n#include <deque>\n#include <memory>\n#include <mutex>\n#include <new>\n\n#include \"src/turbomind/comm/barrier.h\"\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/serdes.h\"\nnamespace turbomind::comm {\n\nstruct ThreadCommImpl: public HostCommImpl {\n\n    class State {\n    public:\n        explicit State(int n): n_{n}, channels_(n * n), barrier_{n} {}\n\n        std::atomic<void*>& channel(int from, int to)\n        {\n            return channels_[from * n_ + to];\n        }\n\n        void sync()\n        {\n            barrier_.arrive_and_wait();\n        }\n\n    private:\n        int                            n_;\n        std::deque<std::atomic<void*>> channels_;\n        Barrier                        barrier_;\n    };\n\n    std::shared_ptr<State> state_;\n\n    int n_ranks_;\n    int rank_;\n\n    ThreadCommImpl(int n_ranks, std::shared_ptr<State> state, int rank):\n        state_{std::move(state)}, n_ranks_{n_ranks}, rank_{rank}\n    {\n    }\n\n    int rank() const override\n    {\n        return rank_;\n    }\n\n    int n_ranks() const override\n    {\n        return n_ranks_;\n    }\n\n    bool is_same_process() const override\n    {\n        return true;\n    }\n\n    std::atomic<void*>& channel(int from, int to)\n    {\n        return state_->channel(from, to);\n    }\n\n    std::shared_ptr<HostCommImpl> Split(int color, int key) override\n    {\n        TM_CHECK(color >= 0);\n\n        auto ranks = comm::AllGather(this, std::make_tuple(color, key, rank_));\n\n        auto same_color = [&](auto x) { return std::get<0>(x) == color; };\n        ranks.erase(std::stable_partition(ranks.begin(), ranks.end(), same_color), ranks.end());\n\n        std::stable_sort(ranks.begin(), ranks.end(), [](auto& a, auto& b) { return a < b; });\n\n  
      std::shared_ptr<State> state;\n\n        int rank = -1;\n        for (int i = 0; i < ranks.size(); ++i) {\n            if (std::get<2>(ranks[i]) == rank_) {\n                rank = i;\n            }\n        }\n\n        TM_CHECK_GE(rank, 0);\n\n        if (rank == 0) {\n            state = std::make_shared<State>(ranks.size());\n        }\n\n        auto states = comm::AllGather(this, state);\n        if (rank != 0) {\n            const int root = std::get<2>(ranks[0]);\n            state          = states[root];\n        }\n\n        return std::make_shared<ThreadCommImpl>(ranks.size(), state, rank);\n    }\n\n    void Sync(bool blocking) override\n    {\n        if (n_ranks_ == 1) {\n            return;\n        }\n\n        if (blocking) {\n            state_->sync();\n            return;\n        }\n\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (r != rank_) {\n                auto& c = channel(rank_, r);\n                void* expected{};\n                while (!c.compare_exchange_weak(expected, (void*)1, std::memory_order_release)) {\n                    expected = {};\n                }\n            }\n        }\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (r != rank_) {\n                auto& c        = channel(r, rank_);\n                void* expected = (void*)1;\n                while (!c.compare_exchange_weak(expected, nullptr, std::memory_order_acquire)) {\n                    expected = (void*)1;\n                }\n            }\n        }\n    }\n\n    void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy, ser_fn ser, des_fn des) override\n    {\n        TM_CHECK(copy);\n        if (n_ranks_ == 1) {\n            return;\n        }\n        // transform root to global rank\n        if (rank_ == root) {\n            for (int r = 0; r < n_ranks_; ++r) {\n                if (r != rank_) {\n                    auto& c = channel(rank_, r);\n                    void* expected{};\n              
      while (!c.compare_exchange_weak(expected, data, std::memory_order_release)) {\n                        expected = {};\n                    }\n                }\n            }\n            for (int r = 0; r < n_ranks_; ++r) {\n                if (r != rank_) {\n                    auto& c = channel(rank_, r);\n                    while (c.load(std::memory_order_relaxed)) {}\n                }\n            }\n        }\n        else {\n            auto& c = channel(root, rank_);\n            void* incoming{};\n            while (!(incoming = c.load(std::memory_order_acquire))) {}\n            copy(incoming, count, data, 0);\n            c.store(nullptr, std::memory_order_relaxed);\n        }\n    }\n\n    void AllGather(void* data, int count, DataType dtype, copy_fn copy, ser_fn ser, des_fn des) override\n    {\n        TM_CHECK(copy);\n        if (n_ranks_ == 1) {\n            return;\n        }\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (r != rank_) {\n                auto& c = channel(rank_, r);\n                void* expected{};\n                while (!c.compare_exchange_weak(expected, data, std::memory_order_release)) {\n                    expected = {};\n                }\n            }\n        }\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (r != rank_) {\n                auto& c = channel(r, rank_);\n                void* incoming{};\n                while (!(incoming = c.load(std::memory_order_acquire))) {}\n                copy(incoming, count, data, r * count);\n                c.store(nullptr, std::memory_order_relaxed);\n            }\n        }\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (r != rank_) {\n                auto& c = channel(rank_, r);\n                while (c.load(std::memory_order_relaxed)) {}\n            }\n        }\n    }\n\n    template<class T, RedOp op>\n    static void reduce(void* src, int n, void* dst, int offset)\n    {\n        for (int i = 0; i < n; ++i) {\n        
    auto& s = *((T*)src + offset + i);\n            auto& a = *((T*)dst + offset + i);\n            if constexpr (op == RedOp::kSum) {\n                a += s;\n            }\n            else if constexpr (op == RedOp::kMin) {\n                a = std::min(a, s);\n            }\n            else if constexpr (op == RedOp::kMax) {\n                a = std::max(a, s);\n            }\n            else {\n                static_assert(sizeof(T) != sizeof(T), \"not implemented\");\n            }\n        }\n    }\n\n    static reduce_fn get_reduce(DataType dtype, RedOp red_op)\n    {\n        auto dispatch_op = [&](auto t) -> reduce_fn {\n            using T = decltype(t);\n            switch (red_op) {\n                case RedOp::kSum:\n                    return reduce<T, RedOp::kSum>;\n                case RedOp::kMax:\n                    return reduce<T, RedOp::kMax>;\n                case RedOp::kMin:\n                    return reduce<T, RedOp::kMin>;\n                default:\n                    return {};\n            }\n        };\n        auto dispatch = [&]() -> reduce_fn {\n            switch (dtype) {\n                case kInt32:\n                    return dispatch_op(int32_t{});\n                case kInt64:\n                    return dispatch_op(int64_t{});\n                case kUint32:\n                    return dispatch_op(uint32_t{});\n                case kUint64:\n                    return dispatch_op(uint64_t{});\n                default:\n                    return {};\n            }\n        };\n        if (auto fn = dispatch()) {\n            return fn;\n        }\n        else {\n            throw std::runtime_error(\"not implemented\");\n            return {};\n        }\n    }\n\n    void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override\n    {\n        const auto reduce    = get_reduce(dtype, red_op);\n        const auto elem_size = byte_size(dtype);\n        if (n_ranks_ == 1) {\n            return;\n        
}\n        std::unique_ptr<char[]> tmp((char*)::operator new[](elem_size* count));\n        std::copy_n((char*)data, elem_size * count, tmp.get());\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (r != rank_) {\n                auto& c = channel(rank_, r);\n                void* expected{};\n                while (!c.compare_exchange_weak(expected, (void*)tmp.get(), std::memory_order_release)) {\n                    expected = {};\n                }\n            }\n        }\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (r != rank_) {\n                auto& c = channel(r, rank_);\n                void* incoming{};\n                while (!(incoming = c.load(std::memory_order_acquire))) {}\n                reduce(incoming, count, data, 0);\n                c.store(nullptr, std::memory_order_relaxed);\n            }\n        }\n        for (int r = 0; r < n_ranks_; ++r) {\n            if (r != rank_) {\n                auto& c = channel(rank_, r);\n                while (c.load(std::memory_order_relaxed)) {}\n            }\n        }\n    }\n};\n\nclass ThreadGroupId: public HostGroupId {\npublic:\n    void Initialize() override\n    {\n        internal_ = std::make_shared<Internal>();\n    }\n\n    void Export(std::ostream& os) override\n    {\n        TM_CHECK((bool)internal_);  // `Initialize` must come before `Export`\n\n        const void* ptr = this;\n        os.write((const char*)&ptr, sizeof(ptr));\n    }\n\n    void Import(std::istream& is) override\n    {\n        void* ptr{};\n        is.read((char*)&ptr, sizeof(ptr));\n        internal_ = reinterpret_cast<ThreadGroupId*>(ptr)->internal_;\n\n        TM_CHECK((bool)internal_);\n    }\n\n    HostComm CreateCommunicator(int n_ranks, int rank, int node_rank = 0) override\n    {\n        auto init_shared_state = [&] {  //\n            internal_->state = std::make_shared<ThreadCommImpl::State>(n_ranks);\n        };\n\n        TM_CHECK((bool)internal_);\n\n        // Exactly one rank initializes the shared state
\n        std::call_once(internal_->flag, init_shared_state);\n\n        TM_CHECK((bool)internal_->state);\n\n        auto impl = std::make_shared<ThreadCommImpl>(n_ranks, internal_->state, rank);\n\n        return std::static_pointer_cast<HostCommImpl>(impl);\n    }\n\nprivate:\n    struct Internal {\n        std::once_flag                         flag;\n        std::shared_ptr<ThreadCommImpl::State> state;\n    };\n\nprivate:\n    std::shared_ptr<Internal> internal_;\n};\n\nstd::unique_ptr<HostGroupId> CreateThreadGroupId()\n{\n    return std::make_unique<ThreadGroupId>();\n}\n\ntemplate<class Archive>\nvoid save(Archive& ar, const std::shared_ptr<ThreadCommImpl::State>& p)\n{\n    TM_CHECK(false) << \"should never be called\";\n}\n\ntemplate<class Archive>\nvoid load(Archive& ar, std::shared_ptr<ThreadCommImpl::State>& p)\n{\n    TM_CHECK(false) << \"should never be called\";\n}\n\n}  // namespace turbomind::comm\n"
  },
  {
    "path": "src/turbomind/core/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\ncmake_minimum_required(VERSION 3.11)\n\nadd_library(core STATIC\n        check.cc\n        allocator.cc\n        stream.cc\n        context.cc\n        buffer.cc\n        layout.cc\n        tensor.cc\n        tensor.cu\n        module.cc\n        copy.cc)\n\ntarget_link_libraries(core PUBLIC cuda_utils logger CUDA::cudart CUDA::cuda_driver)\n\nset_property(TARGET core PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET core PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\n\ntarget_compile_options(core PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xptxas=-v>)\n\nif (BUILD_TEST)\n    add_executable(test_core test_core.cc)\n    target_link_libraries(test_core PRIVATE core logger Catch2::Catch2WithMain)\nendif ()\n"
  },
  {
    "path": "src/turbomind/core/allocator.cc",
    "content": "\n#include <cuda_runtime.h>\n#include <cuda_runtime_api.h>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n\nnamespace turbomind::core {\n\nAllocatorImpl::~AllocatorImpl() = default;\n\nStream AllocatorImpl::stream() const noexcept\n{\n    return Stream{};\n}\n\nclass CudaMemPoolAllocator: public AllocatorImpl {\npublic:\n    CudaMemPoolAllocator(Stream stream, bool use_default_pool):\n        pool_{}, stream_{stream}, device_{kDEVICE}, use_default_pool_{use_default_pool}\n    {\n        check_cuda_error(cudaGetDevice(&device_.id));\n        if (use_default_pool_) {\n            check_cuda_error(cudaDeviceGetDefaultMemPool(&pool_, device_.id));\n        }\n        else {\n            cudaMemPoolProps props{};\n            props.allocType     = cudaMemAllocationTypePinned;\n            props.handleTypes   = cudaMemHandleTypeNone;\n            props.location.type = cudaMemLocationTypeDevice;\n            props.location.id   = device_.id;\n            check_cuda_error(cudaMemPoolCreate(&pool_, &props));\n            cuuint64_t thres = (cuuint64_t)-1;\n            check_cuda_error(cudaMemPoolSetAttribute(pool_, cudaMemPoolAttrReleaseThreshold, &thres));\n        }\n    }\n\n    ~CudaMemPoolAllocator() override\n    {\n        if (!use_default_pool_) {\n            check_cuda_error(cudaMemPoolDestroy(pool_));\n        }\n        pool_ = {};\n    }\n\n    void* allocate(ssize_t size) override\n    {\n        void* ptr{};\n        check_cuda_error(cudaMallocFromPoolAsync(&ptr, size, pool_, stream_.handle()));\n        return ptr;\n    }\n\n    void deallocate(void* p, ssize_t) override\n    {\n        check_cuda_error(cudaFreeAsync(p, stream_.handle()));\n    }\n\n    Device device() const noexcept override\n    {\n        return device_;\n    }\n\n    Stream stream() const noexcept override\n    {\n        return stream_;\n    }\n\n    void trim(size_t bytes_to_keep) override\n    {\n        
check_cuda_error(cudaMemPoolTrimTo(pool_, bytes_to_keep));\n    }\n\nprivate:\n    cudaMemPool_t pool_;\n    Stream        stream_;\n    Device        device_;\n    bool          use_default_pool_;\n};\n\nclass CudaAllocator: public AllocatorImpl {\npublic:\n    void* allocate(ssize_t size) override\n    {\n        void* ptr{};\n        check_cuda_error(cudaMalloc(&ptr, size));\n        return ptr;\n    }\n\n    void deallocate(void* p, ssize_t) override\n    {\n        check_cuda_error(cudaFree(p));\n    }\n\n    Device device() const noexcept override\n    {\n        return kDEVICE;\n    }\n};\n\nclass CudaHostAllocator: public AllocatorImpl {\npublic:\n    void* allocate(ssize_t size) override\n    {\n        void* ptr{};\n        check_cuda_error(cudaHostAlloc(&ptr, size, cudaHostAllocDefault));\n        return ptr;\n    }\n\n    void deallocate(void* p, ssize_t) override\n    {\n        check_cuda_error(cudaFreeHost(p));\n    }\n\n    Device device() const noexcept override\n    {\n        return kCPUpinned;\n    }\n};\n\nclass HostAllocator: public AllocatorImpl {\npublic:\n    void* allocate(ssize_t size) override\n    {\n        return ::operator new(size);\n    }\n\n    void deallocate(void* p, ssize_t) override\n    {\n        ::operator delete(p);\n    }\n\n    Device device() const noexcept override\n    {\n        return kCPU;\n    }\n};\n\nAllocator::Allocator(DeviceType type)\n{\n    impl_ = [&]() -> shared_ptr<AllocatorImpl> {\n        switch (type) {\n            case kCPU:\n                return std::make_shared<HostAllocator>();\n            case kDEVICE:\n                return std::make_shared<CudaAllocator>();\n            case kCPUpinned:\n                return std::make_shared<CudaHostAllocator>();\n        }\n        return {};\n    }();\n    TM_CHECK_NOTNULL(impl_);\n}\n\nAllocator::Allocator(Stream stream, bool use_default_pool)\n{\n    impl_ = std::make_shared<CudaMemPoolAllocator>(std::move(stream), use_default_pool);\n}\n\n}  // 
namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/allocator.h",
    "content": "#pragma once\n\n#include <algorithm>\n#include <functional>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/common.h\"\n#include \"src/turbomind/core/stream.h\"\n\n#include \"src/turbomind/kernels/core/math.h\"\n\nnamespace turbomind {\n\nenum class DeviceType : int\n{\n    kCPU,\n    kCPUpinned,\n    kDEVICE\n};\n\ninline constexpr DeviceType kCPU       = DeviceType::kCPU;\ninline constexpr DeviceType kCPUpinned = DeviceType::kCPUpinned;\ninline constexpr DeviceType kDEVICE    = DeviceType::kDEVICE;\n\nconstexpr const char* to_string(DeviceType device)\n{\n    switch (device) {\n        case kCPU:\n            return \"cpu\";\n        case kCPUpinned:\n            return \"cpu_pinned\";\n        case kDEVICE:\n            return \"device\";\n    }\n    return \"\";\n}\n\ninline std::ostream& operator<<(std::ostream& os, DeviceType device)\n{\n    return os << to_string(device);\n}\n\n}  // namespace turbomind\n\nnamespace turbomind::core {\n\nstruct Device {\n    DeviceType type;\n    int        id;\n    Device(): Device{kCPU} {}\n    Device(DeviceType type_): type{type_}, id{-1} {}\n    Device(DeviceType type_, int device_): type{type_}, id{device_} {}\n    friend bool operator==(const Device& a, const Device& b)\n    {\n        return a.type == b.type && a.id == b.id;\n    }\n    friend bool operator!=(const Device& a, const Device& b)\n    {\n        return !(a == b);\n    }\n};\n\nclass AllocatorImpl {\npublic:\n    virtual ~AllocatorImpl();\n\n    virtual void* allocate(ssize_t size) = 0;\n\n    virtual void deallocate(void* p, ssize_t size) = 0;\n\n    // Returns invalid stream by default\n    virtual Stream stream() const noexcept;\n\n    virtual Device device() const noexcept = 0;\n\n    virtual void trim(size_t bytes_to_keep){};\n};\n\nclass Allocator {\npublic:\n    Allocator() = default;\n\n    explicit Allocator(DeviceType type);\n\n    Allocator(Stream stream, bool use_default_pool);\n\n    
Allocator(shared_ptr<AllocatorImpl> impl): impl_{std::move(impl)} {}\n\n    AllocatorImpl* operator->() const\n    {\n        TM_CHECK_NOTNULL(impl_);\n        return impl_.get();\n    }\n\n    explicit operator bool() const noexcept\n    {\n        return static_cast<bool>(impl_);\n    }\n\n    friend bool operator==(const Allocator& a, const Allocator& b)\n    {\n        return a.impl_ == b.impl_;\n    }\n\n    friend bool operator!=(const Allocator& a, const Allocator& b)\n    {\n        return !(a == b);\n    }\n\n    template<class T, class... Args>\n    shared_ptr<T> adapt(Args&&... args) const\n    {\n        return {std::make_shared<T>(impl_, ((Args &&) args)...)};\n    }\n\nprivate:\n    shared_ptr<AllocatorImpl> impl_;\n};\n\nclass StackAllocatorImpl: public AllocatorImpl {\npublic:\n    static constexpr ssize_t kAlignment = 256;\n\n    explicit StackAllocatorImpl(shared_ptr<AllocatorImpl> underlying_impl): underlying_impl_{std::move(underlying_impl)}\n    {\n    }\n\n    ~StackAllocatorImpl() override\n    {\n        if (cached_beg_) {\n            underlying_impl_->deallocate(cached_beg_, cached_end_ - cached_beg_);\n        }\n    }\n\n    void* allocate(ssize_t size) override\n    {\n        size = round_up(size, kAlignment);\n\n        void* p{};\n        if (cached_ptr_ + size <= cached_end_) {\n            p = cached_ptr_;\n            cached_ptr_ += size;\n        }\n        else {\n            TM_CHECK(!cached_beg_);\n            p = underlying_impl_->allocate(size);\n        }\n\n        // TM_LOG_ERROR(\"allocate %p, %ld\", p, size);\n\n        size_ += size;\n        ++num_;\n        max_size_ = std::max(size_, max_size_);\n        max_num_  = std::max(num_, max_num_);\n        return p;\n    }\n\n    void deallocate(void* p, ssize_t size) override\n    {\n        size = round_up(size, kAlignment);\n\n        // TM_LOG_ERROR(\"deallocate %p, %p, %ld\", p, cached_ptr_, size);\n\n        if ((char*)p + size == cached_ptr_) {\n            
cached_ptr_ -= size;\n        }\n        else {\n            TM_CHECK(!cached_beg_);\n            underlying_impl_->deallocate(p, size);\n        }\n        size_ -= size;\n        --num_;\n    }\n\n    Stream stream() const noexcept override\n    {\n        return underlying_impl_->stream();\n    }\n\n    Device device() const noexcept override\n    {\n        return underlying_impl_->device();\n    }\n\n    void iter()\n    {\n        TM_CHECK_EQ((void*)cached_beg_, (void*)cached_ptr_);\n        auto expected = max_size_ + kAlignment * max_num_;\n        if (cached_end_ - cached_beg_ < expected) {\n            if (cached_beg_) {\n                underlying_impl_->deallocate(cached_beg_, cached_end_ - cached_beg_);\n            }\n            cached_ptr_ = cached_beg_ = (char*)underlying_impl_->allocate(expected);\n            cached_end_               = cached_beg_ + expected;\n        }\n        size_ = num_ = max_size_ = max_num_ = 0;\n    }\n\nprivate:\n    ssize_t size_{};\n    ssize_t num_{};\n    ssize_t max_size_{};\n    ssize_t max_num_{};\n\n    char* cached_beg_{};\n    char* cached_end_{};\n    char* cached_ptr_{};\n\n    std::shared_ptr<AllocatorImpl> underlying_impl_;\n};\n\nclass SimpleAllocator: public AllocatorImpl {\npublic:\n    template<class Alloc, class Dealloc>\n    static Allocator Create(Alloc&& alloc, Dealloc&& dealloc, Device device)\n    {\n        return Allocator{std::make_shared<SimpleAllocator>((Alloc &&) alloc, (Dealloc &&) dealloc, device)};\n    }\n\n    template<class Alloc, class Dealloc>\n    SimpleAllocator(Alloc&& alloc, Dealloc&& dealloc, Device device):\n        alloc_{std::forward<Alloc>(alloc)}, dealloc_{std::forward<Dealloc>(dealloc)}, device_{device}\n    {\n    }\n\n    void* allocate(ssize_t size) override\n    {\n        return alloc_(size);\n    }\n\n    void deallocate(void* p, ssize_t size) override\n    {\n        return dealloc_(p, size);\n    }\n\n    Device device() const noexcept override\n    {\n        return device_;\n   
 }\n\nprivate:\n    std::function<void*(ssize_t)>       alloc_;\n    std::function<void(void*, ssize_t)> dealloc_;\n    Device                              device_;\n};\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/buffer.cc",
    "content": "\n#include \"src/turbomind/core/buffer.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/stream.h\"\nnamespace turbomind::core {\n\nBuffer Buffer::view(DataType dtype) const\n{\n    auto b = *this;\n    if (dtype == dtype_) {\n        return b;\n    }\n    b.dtype_ = dtype;\n    b.size_  = numel(dtype, byte_size());\n    if (base_) {\n        b.base_ = numel(dtype, turbomind::byte_size(dtype_, base_));\n    }\n    return b;\n}\n\nBuffer Buffer::slice(ssize_t base, ssize_t size) const\n{\n    if (size == -1) {\n        size = size_ - base;  // -1 selects everything from base to the end\n    }\n    TM_CHECK_GE(base, 0);\n    TM_CHECK_GE(size, 0);\n    TM_CHECK_LE(base + size, size_);\n    auto b = *this;\n    b.base_ += base;\n    b.size_  = size;\n    return b;\n}\n\nstd::ostream& operator<<(std::ostream& os, const Buffer& b)\n{\n    os << b.dtype() << \"[\" << b.size() << \"]@\" << b.data_;\n    if (b.base_) {\n        os << \"+\" << b.base_;\n    }\n    return os;\n}\n\nvoid Copy(const Buffer& a, ssize_t n, Ref<Buffer> b_, const Stream& stream)\n{\n    auto& b = b_.get();\n    TM_CHECK_EQ(a.dtype(), b.dtype());\n    TM_CHECK_LE(n, a.size());\n    TM_CHECK_LE(n, b.size());\n    if (auto size = byte_size(a.dtype(), n)) {\n        check_cuda_error(cudaMemcpyAsync(b.raw_data(), a.raw_data(), size, cudaMemcpyDefault, stream.handle()));\n    }\n}\n\nvoid Copy(const Buffer& a, ssize_t n, Ref<Buffer> b_)\n{\n    Copy(a, n, b_, Context::stream());\n}\n\nvoid Copy(const Buffer& a, Ref<Buffer> b_, const Stream& stream)\n{\n    TM_CHECK_EQ(a.size(), b_.get().size());\n    Copy(a, a.size(), b_, stream);\n}\n\nvoid Copy(const Buffer& a, Ref<Buffer> b_)\n{\n    Copy(a, b_, Context::stream());\n}\n\nnamespace detail {\n\nvoid* Copy(const void* a, ssize_t n, void* b, const Stream& stream)\n{\n    if (n) {\n        check_cuda_error(cudaMemcpyAsync(b, a, n, cudaMemcpyDefault, stream.handle()));\n    }\n    return (uint8_t*)b + 
n;\n}\n\n}  // namespace detail\n\nvoid Clear(Ref<Buffer> b_, const Stream& stream)\n{\n    auto& b = b_.get();\n    if (auto size = b.byte_size()) {\n        check_cuda_error(cudaMemsetAsync(b.raw_data(), 0, size, stream.handle()));\n    }\n}\n\nvoid Clear(Ref<Buffer> b_)\n{\n    Clear(b_, Context::stream());\n}\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/buffer.h",
    "content": "#pragma once\n\n#include <memory>\n\n#include <cuda_runtime.h>\n#include <type_traits>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/common.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/serdes.h\"\n\nnamespace turbomind::core {\n\nclass Buffer {\npublic:\n    Buffer(): data_{}, base_{}, size_{}, device_{}, dtype_{} {}\n\n    // Typed empty buffer\n    explicit Buffer(DataType dtype): Buffer()\n    {\n        dtype_ = dtype;\n    }\n\n    // Reference into `data` buffer\n    template<class T>\n    Buffer(T* data, ssize_t size, Device device):\n        data_{data, [](auto) {}}, base_{}, size_{size}, device_{device}, dtype_{data_type_v<T>}\n    {\n    }\n\n    Buffer(void* data, ssize_t size, DataType dtype, Device device):\n        data_{data, [](auto) {}}, base_{}, size_{size}, device_{device}, dtype_{dtype}\n    {\n    }\n\n    // Share ownership of `data`\n    Buffer(shared_ptr<void> data, ssize_t size, DataType dtype, Device device):\n        data_{std::move(data)}, base_{}, size_{size}, device_{device}, dtype_{dtype}\n    {\n    }\n\n    // Create from the allocator\n    Buffer(ssize_t size, DataType dtype, Allocator& alloc):\n        base_{}, size_{size}, device_{alloc->device()}, dtype_{dtype}\n    {\n        auto bytes = turbomind::byte_size(dtype, size);\n        data_      = {alloc->allocate(bytes), [=](auto p) { alloc->deallocate(p, bytes); }};\n    }\n\n    Buffer(ssize_t size, DataType dtype, Device device): Buffer{size, dtype, Context::alloc(device)} {}\n\n    template<class T>\n    T* data()\n    {\n        TM_CHECK_EQ(data_type_v<T>, dtype_);\n        return (T*)((char*)TM_CHECK_NOTNULL(data_).get() + turbomind::byte_size<T>(base_));\n    }\n\n    template<class T>\n    const T* data() const\n    {\n        return const_cast<Buffer*>(this)->data<T>();\n    }\n\n    void* raw_data(ssize_t 
offset = 0)\n    {\n        return (char*)TM_CHECK_NOTNULL(data_).get() + turbomind::byte_size(dtype_, base_ + offset);\n    }\n\n    const void* raw_data(ssize_t offset = 0) const\n    {\n        return const_cast<Buffer*>(this)->raw_data(offset);\n    }\n\n    template<class T>\n    T* data_or(T* other) noexcept\n    {\n        if constexpr (std::is_void_v<T>) {\n            return data_ ? (T*)raw_data() : other;\n        }\n        else {\n            return data_ ? data<T>() : other;\n        }\n    }\n\n    template<class T>\n    const T* data_or(const T* other) const noexcept\n    {\n        return const_cast<Buffer*>(this)->data_or(other);\n    }\n\n    DataType dtype() const\n    {\n        return dtype_;\n    }\n\n    Device device() const\n    {\n        return device_;\n    }\n\n    ssize_t size() const\n    {\n        return size_;\n    }\n\n    ssize_t byte_size() const\n    {\n        return turbomind::byte_size(dtype_, size_);\n    }\n\n    explicit operator bool() const noexcept\n    {\n        return static_cast<bool>(data_);\n    }\n\n    Buffer view(DataType dtype) const;\n\n    template<class T>\n    Buffer view() const\n    {\n        return view(data_type_v<T>);\n    }\n\n    Buffer slice(ssize_t base, ssize_t size) const;\n\n    Buffer borrow() const\n    {\n        return Buffer{const_cast<void*>(raw_data()), size_, dtype_, device_};\n    }\n\n    friend bool operator==(const Buffer& a, const Buffer& b);\n\n    friend bool operator!=(const Buffer& a, const Buffer& b);\n\n    friend std::ostream& operator<<(std::ostream& os, const Buffer& b);\n\nprotected:\n    auto as_tuple() const\n    {\n        return std::tie(data_, base_, size_, dtype_, device_);\n    }\n\n    shared_ptr<void> data_;\n    ssize_t          base_;\n    ssize_t          size_;\n    Device           device_;\n    DataType         dtype_;\n};\n\ninline bool operator==(const Buffer& a, const Buffer& b)\n{\n    return a.as_tuple() == b.as_tuple();\n}\n\ninline bool 
operator!=(const Buffer& a, const Buffer& b)\n{\n    return !(a == b);\n}\n\ninline Buffer empty_like(const Buffer& buffer)\n{\n    return Buffer{buffer.size(), buffer.dtype(), buffer.device()};\n}\n\ninline Buffer empty_like(const Buffer& buffer, Device device)\n{\n    return Buffer{buffer.size(), buffer.dtype(), device};\n}\n\ninline Buffer empty_like(const Buffer& buffer, DataType dtype)\n{\n    return Buffer{buffer.size(), dtype, buffer.device()};\n}\n\ntemplate<class T>\nstruct Buffer_: public Buffer {\n\n    Buffer_(): Buffer{data_type_v<T>} {}\n\n    Buffer_(T* data, ssize_t size, Device device): Buffer{data, size, device} {}\n\n    Buffer_(shared_ptr<void> data, ssize_t size, Device device): Buffer{std::move(data), size, data_type_v<T>, device}\n    {\n    }\n\n    Buffer_(ssize_t size, Allocator& alloc): Buffer{size, data_type_v<T>, alloc} {}\n\n    Buffer_(ssize_t size, Device device): Buffer{size, data_type_v<T>, device} {}\n\n    Buffer_(const Buffer_&) = default;\n    Buffer_& operator=(const Buffer_&) = default;\n\n    Buffer_(Buffer_&&) noexcept = default;\n    Buffer_& operator=(Buffer_&&) noexcept = default;\n\n    Buffer_(const Buffer& b)\n    {\n        *static_cast<Buffer*>(this) = ensure_dtype(b);\n    }\n    Buffer_(Buffer&& b) noexcept\n    {\n        *static_cast<Buffer*>(this) = ensure_dtype(std::move(b));\n    }\n\n    T* data_or(T* other)\n    {\n        return data_ ? data() : other;\n    }\n\n    const T* data_or(const T* other) const\n    {\n        return data_ ? 
data() : other;\n    }\n\n    void* raw_data(ssize_t offset = 0)\n    {\n        return (char*)TM_CHECK_NOTNULL(data_).get() + turbomind::byte_size<T>(base_ + offset);\n    }\n\n    const void* raw_data(ssize_t offset = 0) const\n    {\n        return const_cast<Buffer_*>(this)->raw_data(offset);\n    }\n\n    T* data()\n    {\n        return static_cast<T*>(raw_data());\n    }\n\n    const T* data() const\n    {\n        return static_cast<const T*>(raw_data());\n    }\n\n    T* begin()\n    {\n        return data();\n    }\n\n    const T* begin() const\n    {\n        return data();\n    }\n\n    T* end()\n    {\n        return begin() + size();\n    }\n\n    const T* end() const\n    {\n        return begin() + size();\n    }\n\n    T& operator[](ssize_t i)\n    {\n        return data()[i];\n    }\n\n    const T& operator[](ssize_t i) const\n    {\n        return data()[i];\n    }\n\n    T& at(ssize_t i)\n    {\n        TM_CHECK_LT(i, size());\n        return data()[i];\n    }\n\n    const T& at(ssize_t i) const\n    {\n        TM_CHECK_LT(i, size());\n        return data()[i];\n    }\n\n    constexpr DataType dtype() const noexcept\n    {\n        return data_type_v<T>;\n    }\n\nprivate:\n    template<class U>\n    static decltype(auto) ensure_dtype(U&& u) noexcept\n    {\n        TM_CHECK_EQ(u.dtype(), data_type_v<T>);\n        return (U &&) u;\n    }\n};\n\ntemplate<class T>\nclass Ref {\npublic:\n    Ref(T& x): ref_{x} {}\n    Ref(T&& x): ref_{x} {}\n\n    operator T&()\n    {\n        return ref_;\n    }\n\n    T& get()\n    {\n        return ref_;\n    }\n\nprivate:\n    T& ref_;\n};\n\nvoid Copy(const Buffer& a, ssize_t n, Ref<Buffer> b_, const Stream& stream);\n\nvoid Copy(const Buffer& a, ssize_t n, Ref<Buffer> b_);\n\nvoid Copy(const Buffer& a, Ref<Buffer> b_, const Stream& stream);\n\nvoid Copy(const Buffer& a, Ref<Buffer> b_);\n\n// Static type checking\ntemplate<class T>\ninline void Copy_(const Buffer_<T>& a, ssize_t n, Buffer_<T>& b_)\n{\n    
Copy((const Buffer&)a, n, (Buffer&)b_);\n}\n\nnamespace detail {\n\nvoid* Copy(const void* a, ssize_t n, void* b, const Stream& stream);\n\n}  // namespace detail\n\ntemplate<class T>\ninline T* Copy(const T* a, ssize_t n, T* b, const Stream& stream)\n{\n    return (T*)detail::Copy((const void*)a, sizeof(T) * n, (void*)b, stream);\n}\n\ntemplate<class T>\ninline T* Copy(const T* a, ssize_t n, T* b)\n{\n    return (T*)detail::Copy((const void*)a, sizeof(T) * n, (void*)b, Context::stream());\n}\n\nstruct CopyT {\n    template<class... Args>\n    auto operator()(Args&&... args) const\n    {\n        return Copy(((Args &&) args)...);\n    }\n};\n\nvoid Clear(Ref<Buffer> b_, const Stream& stream);\n\nvoid Clear(Ref<Buffer> b_);\n\ntemplate<class T>\nstd::vector<T> to_vector(const Buffer_<T>& b)\n{\n    TM_CHECK(b.device().type == kCPU || b.device().type == kCPUpinned);\n    return std::vector<T>(b.begin(), b.end());\n}\n\n// clang-format off\ntemplate<class Archive>\nvoid save(Archive& ar, const Buffer& buffer)\n{\n    TM_CHECK(buffer.device().type == kCPU);\n    ar & buffer.size();\n    ar & buffer.dtype();\n    ar & ArrayWrapper((char*)buffer.raw_data(), buffer.byte_size());\n}\n\ntemplate<class Archive>\nvoid load(Archive& ar, Buffer& buffer)\n{\n    decltype(buffer.size())  size;\n    decltype(buffer.dtype()) dtype;\n\n    ar & size;\n    ar & dtype;\n    buffer = Buffer(size, dtype, kCPU);\n    ar & ArrayWrapper((char*)buffer.raw_data(), buffer.byte_size());\n}\n// clang-format on\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/check.cc",
    "content": "\n#include <cstdlib>\n#include <filesystem>\n#include <iostream>\n#include <sstream>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind::core {\n\nnamespace {\n\nstd::string StripSrcPrefix(const char* file)\n{\n    static const char* flag = std::getenv(\"TM_SRC_FULL_PATH\");\n    if (flag) {\n        return file;\n    }\n\n    std::filesystem::path path{file};\n    std::filesystem::path ret{path};  // return the original path if anchor is not found\n\n    constexpr auto anchor = \"turbomind\";\n\n    bool found = false;\n\n    for (const auto& x : path) {\n        if (x == anchor) {\n            found = true;\n            ret.clear();\n        }\n        else if (found) {\n            ret /= x;\n        }\n    }\n\n    return ret.string();\n}\n\n}  // namespace\n\nCheckOpStringBuilder::CheckOpStringBuilder()\n{\n    oss_ = new std::ostringstream;\n}\n\nstd::ostream* CheckOpStringBuilder::ForVal1()\n{\n    (*oss_) << \"(\";\n    return oss_;\n}\nstd::ostream* CheckOpStringBuilder::ForVal2()\n{\n    (*oss_) << \" vs. \";\n    return oss_;\n}\nstd::string* CheckOpStringBuilder::NewString()\n{\n    (*oss_) << \")\";\n    return new std::string{oss_->str()};\n}\n\nCheckErrorStream::CheckErrorStream(const char* file, int line, const char* expr)\n{\n    oss_ = new std::ostringstream{};\n    *oss_ << StripSrcPrefix(file) << \"(\" << line << \"): Check failed: \" << expr << \" \";\n}\n\nCheckErrorStream::CheckErrorStream(const char* file, int line, const char* expr, std::string* str):\n    CheckErrorStream{file, line, expr}\n{\n    *oss_ << *str << \" \";\n}\n\nvoid CheckErrorStream::Report()\n{\n    // ! Be aware of `%` in expr\n    std::cerr << \"[TM][FATAL] \" << oss_->str() << \"\\n\";\n    std::abort();\n}\n\nvoid ReportNullError(const char* file, int line, const char* expr)\n{\n    // ! 
Be aware of `%` in expr\n    std::cerr << \"[TM][FATAL] \" << StripSrcPrefix(file) << \"(\" << line << \"): '\" << expr << \"' must be non-NULL\\n\";\n    std::abort();\n}\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/check.h",
    "content": "\n// Inspired by <glog/logging.h>\n\n#pragma once\n\n#include <sstream>\n#include <string>\n#include <utility>\n\nnamespace turbomind::core {\n\n#if defined(_MSC_VER) && !defined(__clang__)\n#define TM_LIKELY(expr) (expr)\n#define TM_UNLIKELY(expr) (expr)\n#define TM_NOINLINE\n#define TM_UNREACHABLE __assume(0)\n#else\n#define TM_LIKELY(expr) (__builtin_expect(bool(expr), 1))\n#define TM_UNLIKELY(expr) (__builtin_expect(bool(expr), 0))\n#define TM_NOINLINE __attribute__((noinline))\n#define TM_UNREACHABLE __builtin_unreachable()\n#endif\n\n#define TM_DISABLE_CHECK_STREAM 0\n#define TM_DISABLE_CHECK_OP 0\n\nclass CheckErrorStream {\npublic:\n    CheckErrorStream(const char* file, int line, const char* expr);\n\n    CheckErrorStream(const char* file, int line, const char* expr, std::string* str);\n\n#if defined(_MSC_VER) && !defined(__clang__)\n#pragma warning(push)\n#pragma warning(disable : 4722)  // MSVC C4722: destructor never returns\n#endif\n    ~CheckErrorStream()\n    {\n        Report();\n    }\n#if defined(_MSC_VER) && !defined(__clang__)\n#pragma warning(pop)\n#endif\n\n    template<class T>\n    CheckErrorStream& operator<<(const T& msg)\n    {\n#if TM_DISABLE_CHECK_STREAM\n#else\n        *oss_ << msg;\n#endif\n        return *this;\n    }\n\nprivate:\n    [[noreturn]] void Report();\n\n    std::ostringstream* oss_;\n};\n\nclass CheckOpStringBuilder {\npublic:\n    CheckOpStringBuilder();\n    std::ostream* ForVal1();\n    std::ostream* ForVal2();\n    std::string*  NewString();\n\nprivate:\n    std::ostringstream* oss_;\n};\n\ntemplate<class T1, class T2>\nstd::string* MakeCheckOpString(const T1& v1, const T2& v2) TM_NOINLINE;\n\ntemplate<class T1, class T2>\nstd::string* MakeCheckOpString(const T1& v1, const T2& v2)\n{\n    CheckOpStringBuilder builder;\n    *builder.ForVal1() << v1;\n    *builder.ForVal2() << v2;\n    return builder.NewString();\n}\n\n#define DEFINE_CHECK_OP_IMPL(name, op)                                                                                 \\\n    
template<class T1, class T2>                                                                                       \\\n    inline std::pair<bool, std::string*> name##Impl(const T1& v1, const T2& v2)                                        \\\n    {                                                                                                                  \\\n        if (TM_LIKELY(v1 op v2))                                                                                       \\\n            return {false, nullptr};                                                                                   \\\n        else                                                                                                           \\\n            return {true, MakeCheckOpString(v1, v2)};                                                                  \\\n    }\n\nDEFINE_CHECK_OP_IMPL(Check_EQ, ==);\nDEFINE_CHECK_OP_IMPL(Check_NE, !=);\nDEFINE_CHECK_OP_IMPL(Check_LE, <=);\nDEFINE_CHECK_OP_IMPL(Check_LT, <);\nDEFINE_CHECK_OP_IMPL(Check_GE, >=);\nDEFINE_CHECK_OP_IMPL(Check_GT, >);\n\n#undef DEFINE_CHECK_OP_IMPL\n\n// clang-format off\n#define TM_CHECK(e)                                                                  \\\n    if (TM_UNLIKELY(!(e))) turbomind::core::CheckErrorStream(__FILE__, __LINE__, #e)\n\n#define TM_CHECK_OP(name, op, a, b)                                                  \\\n    if (auto&& [__p, __s] = turbomind::core::Check##name##Impl(a, b); __p) \\\n        turbomind::core::CheckErrorStream(__FILE__, __LINE__, #a \" \" #op \" \" #b, __s)\n// clang-format on\n\n#if TM_DISABLE_CHECK_OP\n\n#define TM_CHECK_EQ(a, b) TM_CHECK(a == b)\n#define TM_CHECK_NE(a, b) TM_CHECK(a != b)\n#define TM_CHECK_LE(a, b) TM_CHECK(a <= b)\n#define TM_CHECK_LT(a, b) TM_CHECK(a < b)\n#define TM_CHECK_GE(a, b) TM_CHECK(a >= b)\n#define TM_CHECK_GT(a, b) TM_CHECK(a > b)\n\n#else\n\n#define TM_CHECK_EQ(a, b) TM_CHECK_OP(_EQ, ==, a, b)\n#define TM_CHECK_NE(a, b) TM_CHECK_OP(_NE, !=, 
a, b)\n#define TM_CHECK_LE(a, b) TM_CHECK_OP(_LE, <=, a, b)\n#define TM_CHECK_LT(a, b) TM_CHECK_OP(_LT, <, a, b)\n#define TM_CHECK_GE(a, b) TM_CHECK_OP(_GE, >=, a, b)\n#define TM_CHECK_GT(a, b) TM_CHECK_OP(_GT, >, a, b)\n\n#endif\n\n[[noreturn]] void ReportNullError(const char* file, int line, const char* expr);\n\ntemplate<class T>\ndecltype(auto) EnsureNotNull(const char* file, int line, const char* expr, T&& p)\n{\n    if (TM_UNLIKELY(p == nullptr)) {\n        ReportNullError(file, line, expr);\n    }\n    return (T &&) p;\n}\n\n#define TM_CHECK_NOTNULL(p) ::turbomind::core::EnsureNotNull(__FILE__, __LINE__, #p, (p))\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/common.h",
    "content": "\n#pragma once\n\n#include <cstddef>\n#include <memory>\n#include <vector>\n\n/// TODO: remove this dependency\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind::core {\n\nclass Allocator;\nclass Buffer;\nclass Stream;\nclass Event;\nclass Context;\n\nusing std::shared_ptr;\nusing std::vector;\n\nusing ssize_t = std::ptrdiff_t;\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/context.cc",
    "content": "\n#include <stack>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/context.h\"\n\nnamespace turbomind::core {\n\nnamespace {\n\nstruct ContextStorage {\n    enum\n    {\n        stream_bit       = 1,\n        host_alloc_bit   = 2,\n        device_alloc_bit = 4,\n        pinned_alloc_bit = 8,\n    };\n\n    std::stack<Stream>    stream_;\n    std::stack<Allocator> host_alloc_;\n    std::stack<Allocator> device_alloc_;\n    std::stack<Allocator> pinned_alloc_;\n    std::stack<int>       mask_;\n\n    ContextStorage()\n    {\n        push(Allocator{kCPU});\n    }\n\n    void push(const Stream& stream)\n    {\n        int mask{};\n        if (stream) {\n            stream_.push(stream);\n            mask = stream_bit;\n        }\n        mask_.push(mask);\n    }\n\n    void push(const Allocator& alloc)\n    {\n        int mask{};\n        if (alloc) {\n            const auto type = alloc->device().type;\n            if (type == kCPU) {\n                mask = host_alloc_bit;\n                host_alloc_.push(alloc);\n            }\n            else if (type == kDEVICE) {\n                mask = device_alloc_bit;\n                device_alloc_.push(alloc);\n            }\n            else if (type == kCPUpinned) {\n                mask = pinned_alloc_bit;\n                pinned_alloc_.push(alloc);\n            }\n        }\n        mask_.push(mask);\n    }\n\n    void pop()\n    {\n        if (mask_.top() & stream_bit) {\n            stream_.pop();\n        }\n        if (mask_.top() & host_alloc_bit) {\n            host_alloc_.pop();\n        }\n        if (mask_.top() & device_alloc_bit) {\n            device_alloc_.pop();\n        }\n        if (mask_.top() & pinned_alloc_bit) {\n            pinned_alloc_.pop();\n        }\n        mask_.pop();\n    }\n\n    static ContextStorage& instance()\n    {\n        thread_local ContextStorage inst{};\n        return inst;\n    }\n};\n\n}  // namespace\n\nvoid Context::push(const 
Stream& stream)\n{\n    ContextStorage::instance().push(stream);\n}\n\nvoid Context::push(const Allocator& alloc)\n{\n    ContextStorage::instance().push(alloc);\n}\n\nvoid Context::pop()\n{\n    ContextStorage::instance().pop();\n}\n\nStream& Context::stream()\n{\n    auto& stream_ = ContextStorage::instance().stream_;\n    TM_CHECK(!stream_.empty()) << \"No STREAM available in current context\";\n    return stream_.top();\n}\n\nAllocator& Context::host_alloc()\n{\n    auto& host_alloc_ = ContextStorage::instance().host_alloc_;\n    TM_CHECK(!host_alloc_.empty()) << \"No HOST memory allocator available in current context\";\n    return host_alloc_.top();\n}\n\nAllocator& Context::device_alloc()\n{\n    auto& device_alloc_ = ContextStorage::instance().device_alloc_;\n    TM_CHECK(!device_alloc_.empty()) << \"No DEVICE memory allocator available in current context\";\n    return device_alloc_.top();\n}\n\nAllocator& Context::pinned_alloc()\n{\n    auto& pinned_alloc_ = ContextStorage::instance().pinned_alloc_;\n    TM_CHECK(!pinned_alloc_.empty()) << \"No PINNED memory allocator available in current context\";\n    return pinned_alloc_.top();\n}\n\nAllocator& Context::alloc(Device device)\n{\n    switch (device.type) {\n        case kDEVICE:\n            return device_alloc();\n        case kCPU:\n            return host_alloc();\n        case kCPUpinned:\n            return pinned_alloc();\n    }\n    TM_UNREACHABLE;\n}\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/context.h",
    "content": "#pragma once\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/common.h\"\n#include \"src/turbomind/core/stream.h\"\n\nnamespace turbomind::core {\n\nclass Context {\npublic:\n    static Stream&    stream();\n    static Allocator& host_alloc();\n    static Allocator& device_alloc();\n    static Allocator& pinned_alloc();\n    static Allocator& alloc(Device device);\n\nprivate:\n    friend class ContextGuard;\n    static void push(const Stream& stream);\n    static void push(const Allocator& alloc);\n    static void pop();\n};\n\nclass ContextGuard {\npublic:\n    template<class... Args>\n    explicit ContextGuard(Args&&... args): n_{}\n    {\n        (Context::push((Args &&) args), ...);\n        n_ = sizeof...(Args);\n    }\n    ~ContextGuard()\n    {\n        for (int i = 0; i < n_; ++i) {\n            Context::pop();\n        }\n    }\n\nprivate:\n    int n_;\n};\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/copy.cc",
    "content": "\n#include \"src/turbomind/core/copy.h\"\n\n#include <cstdint>\n#include <type_traits>\n#include <variant>\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/check.h\"\n\nnamespace turbomind::core {\n\n// picked from \"cudaTypedefs.h\" / \"cuda.h\"\n\ntypedef enum CUmemcpyFlags_enum\n{\n    CU_MEMCPY_FLAG_DEFAULT                     = 0x0,\n    CU_MEMCPY_FLAG_PREFER_OVERLAP_WITH_COMPUTE = 0x1\n} CUmemcpyFlags;\n\ntypedef enum CUmemcpySrcAccessOrder_enum\n{\n    CU_MEMCPY_SRC_ACCESS_ORDER_INVALID         = 0x0,\n    CU_MEMCPY_SRC_ACCESS_ORDER_STREAM          = 0x1,\n    CU_MEMCPY_SRC_ACCESS_ORDER_DURING_API_CALL = 0x2,\n    CU_MEMCPY_SRC_ACCESS_ORDER_ANY             = 0x3,\n    CU_MEMCPY_SRC_ACCESS_ORDER_MAX             = 0x7FFFFFFF\n} CUmemcpySrcAccessOrder;\n\ntypedef struct CUmemcpyAttributes_st {\n    CUmemcpySrcAccessOrder srcAccessOrder;\n    CUmemLocation          srcLocHint;\n    CUmemLocation          dstLocHint;\n    unsigned int           flags;\n} CUmemcpyAttributes_v1;\n\ntypedef CUresult(CUDAAPI* PFN_cuMemcpyBatchAsync_v12080)(CUdeviceptr_v2*        dsts,\n                                                         CUdeviceptr_v2*        srcs,\n                                                         size_t*                sizes,\n                                                         size_t                 count,\n                                                         CUmemcpyAttributes_v1* attrs,\n                                                         size_t*                attrIdxs,\n                                                         size_t                 numAttrs,\n                                                         size_t*                failIdx,\n                                                         CUstream               hStream);\n\n/// TODO: add `PFN_cuMemcpyBatchAsync_v13000`\n\nnamespace {\n\nconst auto& GetCopyAPI()\n{\n    static auto inst = []() -> std::variant<std::monostate, 
PFN_cuMemcpyBatchAsync_v12080> {\n        const auto                      symbol = \"cuMemcpyBatchAsync\";\n        cudaDriverEntryPointQueryResult status{};\n        void*                           fpn{};\n        TM_CHECK_EQ(cudaGetDriverEntryPoint(symbol, &fpn, cudaEnableDefault, &status), 0);\n        if (fpn && status == cudaDriverEntryPointSuccess) {\n            return (PFN_cuMemcpyBatchAsync_v12080)fpn;\n        }\n        else {\n            return {};\n        }\n    }();\n    return inst;\n}\n\n}  // namespace\n\nBatchCopy::~BatchCopy() = default;\n\nBatchCopy::BatchCopy(): self_{this}\n{\n    Reset();\n}\n\nvoid BatchCopy::Run()\n{\n    if (src_.empty()) {\n        return;\n    }\n\n    std::visit(\n        [&](auto&& copy) {\n            using T = std::decay_t<decltype(copy)>;\n            if constexpr (std::is_same_v<T, PFN_cuMemcpyBatchAsync_v12080>) {\n                CUmemcpyAttributes_v1 attr{};\n                attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;\n                attr.flags          = CU_MEMCPY_FLAG_PREFER_OVERLAP_WITH_COMPUTE;\n                std::vector<size_t> ais(src_.size(), 0);\n                size_t              fail_idx{SIZE_MAX};\n\n                auto status = copy((CUdeviceptr_v2*)dst_.data(),\n                                   (CUdeviceptr_v2*)src_.data(),\n                                   size_.data(),\n                                   src_.size(),\n                                   &attr,\n                                   ais.data(),\n                                   1,\n                                   &fail_idx,\n                                   core::Context::stream().handle());\n\n                if (auto i = fail_idx; i != SIZE_MAX) {\n                    TM_CHECK(0) << (void*)src_[i] << \" \" << size_[i] << \" \" << (void*)dst_[i] << \" code \" << status;\n                }\n            }\n            else {\n                for (unsigned i = 0; i < src_.size(); ++i) {\n                    
core::Copy(src_[i], size_[i], dst_[i]);\n                }\n            }\n        },\n        GetCopyAPI());\n\n    Reset();\n}\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/copy.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/buffer.h\"\n#include \"src/turbomind/core/check.h\"\n\nnamespace turbomind::core {\n\nclass BatchCopy {\npublic:\n    ~BatchCopy();\n\n    BatchCopy();\n\n    BatchCopy(const BatchCopy&) = delete;\n    BatchCopy& operator=(const BatchCopy&) = delete;\n    BatchCopy(BatchCopy&&) noexcept        = delete;\n    BatchCopy& operator=(BatchCopy&&) noexcept = delete;\n\n    // clang-format off\n    class Group {\n    public:\n        ~Group() { parent_.group_end(); }\n        Group(BatchCopy& parent): parent_{parent} { parent_.group_begin(); }\n        explicit constexpr operator bool() const noexcept { return true; }\n    private:\n        BatchCopy& parent_;\n    };\n    // clang-format on\n\n    friend Group;\n\n    Group group()\n    {\n        return {*this};\n    }\n\n    template<class T>\n    T* operator()(const T* src, ssize_t size, T* dst)\n    {\n        // return core::Copy(src, size, dst);\n\n        /// TODO: verify this is actually a fast path in a loop (without extra jump)\n        if (TM_LIKELY(group_ && src == (const T*)src_ptr_ && dst == (T*)dst_ptr_)) {\n            src_ptr_ += sizeof(T) * size;\n            dst_ptr_ += sizeof(T) * size;\n            gsize_ += sizeof(T) * size;\n            count_ += 1;\n            return dst + size;\n        }\n        else if (group_) {\n            group_commit();\n            gsize_   = sizeof(T) * size;\n            src_ptr_ = reinterpret_cast<const char*>(src + size);\n            dst_ptr_ = reinterpret_cast<char*>(dst + size);\n            count_ += 1;\n            return dst + size;\n        }\n        else {\n            gsize_   = sizeof(T) * size;\n            src_ptr_ = reinterpret_cast<const char*>(src + size);\n            dst_ptr_ = reinterpret_cast<char*>(dst + size);\n            count_   = 1;\n            group_commit();\n            return dst + size;\n        }\n    }\n\n    void 
operator()(const Buffer& src, ssize_t size, Ref<Buffer> dst_)\n    {\n        auto& dst = dst_.get();\n        TM_CHECK_EQ(src.dtype(), dst.dtype());\n        TM_CHECK_LE(size, src.size());\n        TM_CHECK_LE(size, dst.size());\n        (*this)((const char*)src.raw_data(), byte_size(src.dtype(), size), (char*)dst.raw_data());\n    }\n\n    void Run();\n\n    Buffer_<BatchCopy*> buf()\n    {\n        return {&self_, 1, kCPU};\n    }\n\n    friend std::ostream& operator<<(std::ostream& os, const BatchCopy& a)\n    {\n        os << \"(\" << a.count_ << \", \" << a.src_.size() << \")\";\n        return os;\n    }\n\nprivate:\n    void Reset()\n    {\n        src_.clear();\n        dst_.clear();\n        size_.clear();\n        count_ = 0;\n    }\n\n    void group_begin()\n    {\n        TM_CHECK(!group_) << \"Nested group is not supported\";\n        group_ = true;\n    }\n\n    void group_end()\n    {\n        TM_CHECK(group_) << \"Mismatched group end\";\n        group_commit();\n        group_ = false;\n    }\n\n    void group_commit()\n    {\n        if (gsize_) {\n            src_.push_back(src_ptr_ - gsize_);\n            dst_.push_back(dst_ptr_ - gsize_);\n            size_.push_back(gsize_);\n        }\n        src_ptr_ = dst_ptr_ = {};\n        gsize_              = {};\n    }\n\nprivate:\n    std::vector<const char*> src_;\n    std::vector<char*>       dst_;\n    std::vector<size_t>      size_;\n\n    int         group_   = 0;\n    size_t      gsize_   = 0;\n    const char* src_ptr_ = {};\n    char*       dst_ptr_ = {};\n\n    size_t count_;\n\n    BatchCopy* self_;\n};\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/core.h",
    "content": "#pragma once\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/buffer.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/copy.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/layout.h\"\n#include \"src/turbomind/core/ranges.h\"\n#include \"src/turbomind/core/stream.h\"\n#include \"src/turbomind/core/tensor.h\"\n\nnamespace turbomind {\n\nusing core::ssize_t;\nusing core::Buffer;\nusing core::Buffer_;\nusing core::Tensor;\nusing core::Tensor_;\nusing core::TensorMap;\nusing core::Ref;\nusing core::Layout;\nusing core::Allocator;\nusing core::Stream;\nusing core::Event;\nusing core::BatchCopy;\n\nusing core::subrange;\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/core/cuda_data_type.h",
    "content": "#pragma once\n\n#include <stdexcept>\n#include <string>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <cublas_v2.h>\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_fp8.h>\n\n#include \"src/turbomind/core/data_type.h\"\n\nnamespace turbomind {\n\n// clang-format off\n\nconstexpr cudaDataType to_cuda_dtype(DataType type)\n{\n    switch (type) {\n        case kUint8:  return CUDA_R_8U;\n        case kUint16: return CUDA_R_16U;\n        case kUint32: return CUDA_R_32U;\n        case kUint64: return CUDA_R_64U;\n        case kInt8:  return CUDA_R_8I;\n        case kInt16: return CUDA_R_16I;\n        case kInt32: return CUDA_R_32I;\n        case kInt64: return CUDA_R_64I;\n        case kFloat16: return CUDA_R_16F;\n        case kFloat32: return CUDA_R_32F;\n        case kFloat64: return CUDA_R_64F;\n        case kBfloat16: return CUDA_R_16BF;\n        case kFloat8_e4m3: return CUDA_R_8F_E4M3;\n        case kFloat8_e5m2: return CUDA_R_8F_E5M2;\n        default:\n            throw std::runtime_error(\"Not supported \" + std::string{to_string(type)});\n    }\n}\n\nconstexpr DataType from_cuda_dtype(cudaDataType type) {\n    switch (type) {\n        case CUDA_R_8U:  return kUint8;\n        case CUDA_R_16U: return kUint16;\n        case CUDA_R_32U: return kUint32;\n        case CUDA_R_64U: return kUint64;\n        case CUDA_R_8I:  return kInt8;\n        case CUDA_R_16I: return kInt16;\n        case CUDA_R_32I: return kInt32;\n        case CUDA_R_64I: return kInt64;\n        case CUDA_R_16F: return kFloat16;\n        case CUDA_R_32F: return kFloat32;\n        case CUDA_R_64F: return kFloat64;\n        case CUDA_R_16BF: return kBfloat16;\n        case CUDA_R_8F_E4M3: return kFloat8_e4m3;\n        case CUDA_R_8F_E5M2: return kFloat8_e5m2;\n        default:\n            throw std::runtime_error(\"Not supported \" + std::string{std::to_string(type)});\n    }\n}\n\n#if __CUDACC_VER_MAJOR__ >= 12\n\nconstexpr CUtensorMapDataType to_CUtensorMap_dtype(DataType type) {\n    switch (type) {\n  
      case kFloat32:\n            return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;\n        case kFloat16:\n            return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;\n        case kBfloat16:\n            return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;\n        case kFloat8_e4m3:\n        case kFloat8_e5m2:\n            return CU_TENSOR_MAP_DATA_TYPE_UINT8;\n        default:\n            throw std::runtime_error(\"Not supported \" + std::string{to_string(type)});\n    }\n}\n\n#endif\n\n// clang-format on\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/core/data_type.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/check.h\"\n\n#include <cstddef>\n#include <cstdint>\n#include <type_traits>\n\n// forward declarations for CUDA floating point types\nstruct __half;\nstruct __nv_bfloat16;\nstruct __nv_fp8_e4m3;\nstruct __nv_fp8_e5m2;\n\nnamespace turbomind {\n\n// clang-format off\n\nstruct uint2_t {};\nstruct uint4_t {};\nstruct uint6_t {};\n\ntemplate <int I>\nstruct int_constant: std::integral_constant<int, I> {};\n\ntemplate <class T>\nstruct bitsof_t: int_constant<sizeof(T) * 8> {};\n\ntemplate <> struct bitsof_t<uint2_t>: int_constant<2> {};\ntemplate <> struct bitsof_t<uint4_t>: int_constant<4> {};\ntemplate <> struct bitsof_t<uint6_t>: int_constant<6> {};\n\ntemplate <class T>\ninline constexpr bitsof_t<T> bitsof{};\n\nusing half_t = __half;\nusing bfloat16_t = __nv_bfloat16;\nusing fp8_e4m3_t = __nv_fp8_e4m3;\nusing fp8_e5m2_t = __nv_fp8_e5m2;\n\nstruct fp4_e2m1_t {};\n\ntemplate <> struct bitsof_t<fp4_e2m1_t>: int_constant<4> {};\n\n\nconstexpr int encode_data_type(bool sign, int exponent, int mantissa) {\n    return ((sign << 16) | (exponent << 8) | mantissa);\n}\n\nenum class DataType: int {\n    kNull        = 0,\n    kBool        = 1,\n    kUint8       = encode_data_type(0,  0,  8),\n    kUint16      = encode_data_type(0,  0, 16),\n    kUint32      = encode_data_type(0,  0, 32),\n    kUint64      = encode_data_type(0,  0, 64),\n    kInt8        = encode_data_type(1,  0,  8),\n    kInt16       = encode_data_type(1,  0, 16),\n    kInt32       = encode_data_type(1,  0, 32),\n    kInt64       = encode_data_type(1,  0, 64),\n    kFloat16     = encode_data_type(1,  5, 10),\n    kFloat32     = encode_data_type(1,  8, 23),\n    kFloat64     = encode_data_type(1, 11, 52),\n    kBfloat16    = encode_data_type(1,  8,  7),\n    kFloat4_e2m1 = encode_data_type(1,  2,  1),\n    kFloat6_e2m3 = encode_data_type(1,  2,  3),\n    kFloat6_e3m2 = encode_data_type(1,  3,  
2),\n    kFloat8_e4m3 = encode_data_type(1,  4,  3),\n    kFloat8_e5m2 = encode_data_type(1,  5,  2),\n    kUint2       = encode_data_type(0,  0,  2),\n    kUint4       = encode_data_type(0,  0,  4),\n    kUint6       = encode_data_type(0,  0,  6),\n    kPointer,\n    kUint        = kUint32,\n    kInt         = kInt32,\n    kFloat       = kFloat32,\n    kHalf        = kFloat16,\n    kDouble      = kFloat64,\n    kE2m1        = kFloat4_e2m1,\n    kE2m3        = kFloat6_e2m3,\n    kE3m2        = kFloat6_e3m2,\n    kE4m3        = kFloat8_e4m3,\n    kE5m2        = kFloat8_e5m2,\n};\n\ninline constexpr DataType kNull = DataType::kNull;\ninline constexpr DataType kBool = DataType::kBool;\ninline constexpr DataType kPointer = DataType::kPointer;\ninline constexpr DataType kUint8  = DataType::kUint8;\ninline constexpr DataType kUint16 = DataType::kUint16;\ninline constexpr DataType kUint32 = DataType::kUint32;\ninline constexpr DataType kUint64 = DataType::kUint64;\ninline constexpr DataType kInt8  = DataType::kInt8;\ninline constexpr DataType kInt16 = DataType::kInt16;\ninline constexpr DataType kInt32 = DataType::kInt32;\ninline constexpr DataType kInt64 = DataType::kInt64;\ninline constexpr DataType kFloat16 = DataType::kFloat16;\ninline constexpr DataType kFloat32 = DataType::kFloat32;\ninline constexpr DataType kFloat64 = DataType::kFloat64;\ninline constexpr DataType kBfloat16 = DataType::kBfloat16;\ninline constexpr DataType kFloat8_e4m3 = DataType::kFloat8_e4m3;\ninline constexpr DataType kFloat8_e5m2 = DataType::kFloat8_e5m2;\ninline constexpr DataType kFloat4_e2m1 = DataType::kFloat4_e2m1;\ninline constexpr DataType kUint2  = DataType::kUint2;\ninline constexpr DataType kUint4  = DataType::kUint4;\ninline constexpr DataType kUint6  = DataType::kUint6;\ninline constexpr DataType kUint = DataType::kUint;\ninline constexpr DataType kInt = DataType::kInt;\ninline constexpr DataType kHalf = DataType::kHalf;\ninline constexpr DataType kFloat = DataType::kFloat;\ninline 
constexpr DataType kDouble = DataType::kDouble;\n\ntemplate <class T>\nstruct to_data_type;\n\ntemplate <DataType D>\nstruct from_data_type;\n\n#define CVT_DATA_TYPE(D, T) \\\n    template <> struct to_data_type<T> { static constexpr auto value = DataType::D; }; \\\n    template <> struct from_data_type<DataType::D> { using type = T; }\n\nCVT_DATA_TYPE(kNull, void);\n\nCVT_DATA_TYPE(kBool, bool);\nCVT_DATA_TYPE( kUint8, uint8_t);\nCVT_DATA_TYPE(kUint16, uint16_t);\nCVT_DATA_TYPE(kUint32, uint32_t);\nCVT_DATA_TYPE(kUint64, uint64_t);\n\nCVT_DATA_TYPE( kInt8, int8_t);  // NOTE: `int8_t` is `signed char` and is different from `char`\nCVT_DATA_TYPE(kInt16, int16_t);\nCVT_DATA_TYPE(kInt32, int32_t);\nCVT_DATA_TYPE(kInt64, int64_t);\n\nCVT_DATA_TYPE(kFloat16, half_t);\nCVT_DATA_TYPE(kFloat32, float);\nCVT_DATA_TYPE(kFloat64, double);\nCVT_DATA_TYPE(kBfloat16, bfloat16_t);\nCVT_DATA_TYPE(kFloat4_e2m1, fp4_e2m1_t);\nCVT_DATA_TYPE(kFloat8_e4m3, fp8_e4m3_t);\nCVT_DATA_TYPE(kFloat8_e5m2, fp8_e5m2_t);\n\nCVT_DATA_TYPE(kUint2, uint2_t);\nCVT_DATA_TYPE(kUint4, uint4_t);\nCVT_DATA_TYPE(kUint6, uint6_t);\n\n#undef CVT_DATA_TYPE\n\ntemplate <class T> struct to_data_type<T*> { static constexpr auto value = DataType::kPointer; };\ntemplate <>  struct from_data_type<DataType::kPointer> { using type = void*; };\n\ntemplate <class T>\ninline constexpr auto data_type_v = to_data_type<std::remove_cv_t<T>>::value;\n\ntemplate <DataType D>\nusing data_type_t = typename from_data_type<D>::type;\n\nconstexpr std::ptrdiff_t byte_size(DataType type, std::ptrdiff_t size = 1) {\n    switch (type) {\n        case kNull: return 0;\n        case kBool:\n        case kUint8:\n        case kInt8:\n        case kFloat8_e4m3:\n        case kFloat8_e5m2:\n            return size;\n        case kUint16:\n        case kInt16:\n        case kFloat16:\n        case kBfloat16:\n            return size * 2;\n        case kUint32:\n        case kInt32:\n        case kFloat32:\n            return size * 4;\n     
   case kUint64:\n        case kInt64:\n        case kFloat64:\n            return size * 8;\n        case kUint2: return size * 2 / 8;\n        case kUint4:\n        case kFloat4_e2m1:\n            return size * 4 / 8;\n        case kUint6: return size * 6 / 8;\n        case kPointer: return size * sizeof(void*);\n        default:\n            return 0;\n    }\n    return 0;\n}\n\ntemplate <class T>\nconstexpr std::ptrdiff_t byte_size(std::ptrdiff_t size = 1) { return byte_size(data_type_v<T>, size); }\n\nconstexpr std::ptrdiff_t numel(DataType type, std::ptrdiff_t size = 1) {\n    switch (type) {\n        case kNull: return 0;\n        case kBool:\n        case kUint8:\n        case kInt8:\n        case kFloat8_e4m3:\n        case kFloat8_e5m2:\n            return size;\n        case kUint16:\n        case kInt16:\n        case kFloat16:\n        case kBfloat16:\n            return size / 2;\n        case kUint32:\n        case kInt32:\n        case kFloat32:\n            return size / 4;\n        case kUint64:\n        case kInt64:\n        case kFloat64:\n            return size / 8;\n        case kUint2: return size * 8 / 2;\n        case kUint4:\n        case kFloat4_e2m1:\n            return size * 8 / 4;\n        case kUint6: return size * 8 / 6;\n        case kPointer: return size / sizeof(void*);\n        default:\n            return 0;\n    }\n    return 0;\n}\n\ntemplate <class T>\nconstexpr std::ptrdiff_t numel(std::ptrdiff_t size) { return numel(data_type_v<T>, size); }\n\nconstexpr const char* to_string(DataType type) {\n    switch (type) {\n        case kNull: return \"nil\";\n        case kBool: return \"bool\";\n        case kUint8: return \"u8\";\n        case kUint16: return \"u16\";\n        case kUint32: return \"u32\";\n        case kUint64: return \"u64\";\n        case kInt8: return \"i8\";\n        case kInt16: return \"i16\";\n        case kInt32: return \"i32\";\n        case kInt64: return \"i64\";\n        case kFloat16: return 
\"f16\";\n        case kFloat32: return \"f32\";\n        case kFloat64: return \"f64\";\n        case kBfloat16: return \"bf16\";\n        case kFloat8_e4m3: return \"e4m3\";\n        case kFloat8_e5m2: return \"e5m2\";\n        case kFloat4_e2m1: return \"e2m1\";\n        case kUint2: return \"u2\";\n        case kUint4: return \"u4\";\n        case kUint6: return \"u6\";\n        case kPointer: return \"pointer\";\n        default:\n            return \"unknown\";\n    }\n    return \"\";\n}\n\ninline std::ostream& operator<<(std::ostream& os, DataType type) {\n    os << to_string(type);\n    return os;\n}\n\n/// TODO: mapping with DLPack\n\n// clang-format on\n\n#define TM_PP_NARGS(...) TM_PP_NARGS_IMPL(__VA_ARGS__, 8, 7, 6, 5, 4, 3, 2, 1, 0)\n#define TM_PP_NARGS_IMPL(_0, _1, _2, _3, _4, _5, _6, _7, N, ...) N\n\n#define TM_PP_CAT(a, b) a##b\n#define TM_PP_STR(x) #x\n\n#define TM_PP_DISPATCH_N(macro, ...) TM_PP_DISPATCH_N_IMPL(macro, TM_PP_NARGS(__VA_ARGS__))\n#define TM_PP_DISPATCH_N_IMPL(macro, x) TM_PP_CAT(macro, x)\n\n#define TM_PP_INVOKE_1(macro, f, _0) macro(f, _0)\n\n#define TM_PP_INVOKE_2(macro, f, _0, _1)                                                                               \\\n    macro(f, _0);                                                                                                      \\\n    macro(f, _1)\n\n#define TM_PP_INVOKE_3(macro, f, _0, _1, _2)                                                                           \\\n    macro(f, _0);                                                                                                      \\\n    macro(f, _1);                                                                                                      \\\n    macro(f, _2)\n\n#define TM_PP_INVOKE_4(macro, f, _0, _1, _2, _3)                                                                       \\\n    macro(f, _0);                                                                                                      \\\n    
macro(f, _1);                                                                                                      \\\n    macro(f, _2);                                                                                                      \\\n    macro(f, _3)\n\n#define TM_PP_INVOKE_5(macro, f, _0, _1, _2, _3, _4)                                                                   \\\n    macro(f, _0);                                                                                                      \\\n    macro(f, _1);                                                                                                      \\\n    macro(f, _2);                                                                                                      \\\n    macro(f, _3);                                                                                                      \\\n    macro(f, _4)\n\n#define TM_DISPATCH_DTYPE_RET_CASE(f, t)                                                                               \\\n    case ::turbomind::data_type_v<t>:                                                                                  \\\n        return f(t{});\n\n#define TM_DISPATCH_DTYPE_CASE(f, t)                                                                                   \\\n    case ::turbomind::data_type_v<t>:                                                                                  \\\n        f(t{});                                                                                                        \\\n        break\n\n// clang-format off\n#define TM_DISPATCH_DTYPES_RET(var, f, ...)                                                                            
\\\n    switch (var) {                                                                                                     \\\n        TM_PP_DISPATCH_N(TM_PP_INVOKE_, __VA_ARGS__)(TM_DISPATCH_DTYPE_RET_CASE, f, __VA_ARGS__);                      \\\n        default:                                                                                                       \\\n            TM_CHECK(0) << \"unsupported type: \"  << to_string(var);                                                    \\\n            return {};                                                                                                 \\\n    }\n\n#define TM_DISPATCH_DTYPES(var, f, ...)                                                                                \\\n    switch (var) {                                                                                                     \\\n        TM_PP_DISPATCH_N(TM_PP_INVOKE_, __VA_ARGS__)(TM_DISPATCH_DTYPE_CASE, f, __VA_ARGS__);                          \\\n        default:                                                                                                       \\\n            TM_CHECK(0) << \"unsupported type: \"  << to_string(var);                                                    \\\n    }\n// clang-format on\n\n#define TM_PRIMARY_DTYPES_0 ::turbomind::half_t\n\n#if ENABLE_BF16\n#define TM_PRIMARY_DTYPES_1 TM_PRIMARY_DTYPES_0, ::turbomind::bfloat16_t\n#else\n#define TM_PRIMARY_DTYPES_1 TM_PRIMARY_DTYPES_0\n#endif\n\n#if ENABLE_FP32\n#define TM_PRIMARY_DTYPES TM_PRIMARY_DTYPES_1, float\n#else\n#define TM_PRIMARY_DTYPES TM_PRIMARY_DTYPES_1\n#endif\n\n#define TM_DISPATCH_PRIMARY_DTYPES(var, func) TM_DISPATCH_DTYPES(var, func, TM_PRIMARY_DTYPES)\n\n#define TM_DISPATCH_PRIMARY_DTYPES_RET(var, func) TM_DISPATCH_DTYPES_RET(var, func, TM_PRIMARY_DTYPES)\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/core/interval.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <algorithm>\n#include <climits>\n#include <ostream>\n\nnamespace turbomind {\n\nclass Interval {\npublic:\n    struct Size {\n        int      x;\n        explicit operator int() const noexcept\n        {\n            return x;\n        }\n        friend bool operator<(const Size& a, const Size& b)\n        {\n            return a.x < b.x;\n        }\n    };\n\n    Interval(): first_{0}, last_{0} {}\n\n    explicit Interval(int first): first_{first}, last_{INT_MAX} {};\n\n    Interval(int first, int last): first_{first}, last_{last} {}\n\n    Interval(int first, Size size): first_{first}, last_{first + (int)size} {}\n\n    bool empty() const noexcept\n    {\n        return first_ >= last_;\n    }\n\n    explicit operator bool() const noexcept\n    {\n        return !empty();\n    }\n\n    Size size() const noexcept\n    {\n        return Size{std::max(0, last_ - first_)};\n    }\n\n    int begin() const noexcept\n    {\n        return first_;\n    }\n\n    int end() const noexcept\n    {\n        return last_;\n    }\n\n    friend Interval operator&(const Interval& a, const Interval& b)\n    {\n        return {std::max(a.first_, b.first_), std::min(a.last_, b.last_)};\n    }\n\n    friend Interval operator|(const Interval& a, const Interval& b)\n    {\n        return {std::min(a.first_, b.first_), std::max(a.last_, b.last_)};\n    }\n\n    // dilate / erode left\n    friend Interval operator|(int x, const Interval& a)\n    {\n        return {a.begin() - x, a.end()};\n    }\n\n    // dilate / erode right\n    friend Interval operator|(const Interval& a, int x)\n    {\n        return {a.begin(), a.end() + x};\n    }\n\n    friend std::ostream& operator<<(std::ostream& os, const Interval& a)\n    {\n        return os << \"[\" << a.first_ << \", \" << a.last_ << \")\";\n    }\n\n    friend std::ostream& operator<<(std::ostream& os, const Interval* a)\n    {\n        return os 
<< *a;\n    }\n\nprivate:\n    int first_;\n    int last_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/core/layout.cc",
    "content": "\n#include <numeric>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/layout.h\"\n\nnamespace turbomind::core {\n\nLayout::Layout(std::vector<ssize_t> shape): shape_{std::move(shape)}\n{\n    TM_CHECK(shape_.size());\n    stride_.resize(shape_.size());\n    size_ = 1;\n    for (int i = shape_.size() - 1; i >= 0; --i) {\n        stride_[i] = size_;\n        size_ *= shape_[i];\n    }\n}\n\nLayout::Layout(vector<ssize_t> shape, vector<ssize_t> stride): shape_{std::move(shape)}, stride_{std::move(stride)}\n{\n    TM_CHECK(shape_.size());\n    TM_CHECK_EQ(shape_.size(), stride_.size());\n\n    size_ = std::accumulate(shape_.begin(), shape_.end(), ssize_t{1}, std::multiplies<>{});\n\n    TM_CHECK_GE(size_, 0);\n}\n\nssize_t Layout::cosize() const noexcept\n{\n    if (rank() == 0) {\n        return 0;\n    }\n    ssize_t value{1};\n    for (size_t i = 0; i < shape_.size(); ++i) {\n        value += (shape_[i] - 1) * stride_[i];\n    }\n    return value;\n}\n\nLayout Layout::coalesce() const noexcept\n{\n    vector<ssize_t> shape{shape_.front()};\n    vector<ssize_t> stride{stride_.front()};\n\n    for (size_t i = 1; i < shape_.size(); ++i) {\n        if (shape_[i] == 1) {\n            continue;\n        }\n        else if (shape.back() == 1) {\n            shape.back()  = shape_[i];\n            stride.back() = stride_[i];\n        }\n        else if (stride.back() == shape_[i] * stride_[i]) {\n            stride.back() = stride_[i];\n            shape.back() *= shape_[i];\n        }\n        else {\n            shape.push_back(shape_[i]);\n            stride.push_back(stride_[i]);\n        }\n    }\n\n    return Layout{shape, stride};\n}\n\nLayout Layout::view(vector<ssize_t> shape) const\n{\n    if (shape == shape_) {\n        return *this;\n    }\n\n    TM_CHECK(!shape.empty());\n\n    // size check & wildcard resolution\n    auto wildcard = std::find(shape.begin(), shape.end(), -1);\n    if (wildcard != shape.end()) {\n        
TM_CHECK(std::find(wildcard + 1, shape.end(), -1) == shape.end());\n        *wildcard = 1;\n    }\n    auto new_size = std::accumulate(shape.begin(), shape.end(), ssize_t{1}, std::multiplies<>{});\n    if (wildcard != shape.end()) {\n        TM_CHECK(size_ % new_size == 0) << size_ << \" % \" << new_size;\n        *wildcard = size_ / new_size;\n    }\n    else {\n        TM_CHECK_EQ(size_, new_size);\n    }\n\n    if (is_contiguous()) {\n        return Layout{shape};\n    }\n\n    const Layout c = coalesce();  // merge contiguous dimensions\n\n    ssize_t p = c.rank();\n    ssize_t s = 1;\n    ssize_t d = 0;\n\n    vector<ssize_t> stride(shape.size());\n\n    for (int i = shape.size() - 1; i >= 0; --i) {\n        if (shape[i] == 1) {\n            stride[i] = 0;\n        }\n        else {\n            if (s == 1) {\n                --p;\n                s = c.shape().at(p);\n                d = c.stride().at(p);\n            }\n            TM_CHECK_EQ(s % shape[i], 0);  // crossing non-contiguous dimensions\n            stride[i] = d;\n            d *= shape[i];\n            s /= shape[i];\n        }\n    }\n    return Layout{std::move(shape), std::move(stride)};\n}\n\nstd::pair<Layout, ssize_t> Layout::slice(const vector<ssize_t>& base, vector<ssize_t> shape) const\n{\n    TM_CHECK_EQ(base.size(), shape.size());\n    TM_CHECK_EQ(shape_.size(), shape.size());\n    ssize_t offset = 0;\n    for (size_t i = 0; i < shape.size(); ++i) {\n        const auto space = shape_[i] - base[i];\n        TM_CHECK_GE(space, 0);\n        if (shape[i] == -1) {\n            shape[i] = space;\n        }\n        TM_CHECK_LE(shape[i], space);\n        offset += base[i] * stride_[i];\n    }\n    return {Layout{std::move(shape), stride_}, offset};\n}\n\nstd::ostream& operator<<(std::ostream& os, const Layout& x)\n{\n    os << \"(\";\n    for (int i = 0; i < x.rank(); ++i) {\n        os << (i ? 
\",\" : \"\") << x.shape_[i];\n    }\n    os << \"):(\";\n    for (int i = 0; i < x.rank(); ++i) {\n        os << (i ? \",\" : \"\") << x.stride_[i];\n    }\n    os << \")\";\n    return os;\n}\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/layout.h",
    "content": "\n#pragma once\n\n#include <initializer_list>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/common.h\"\n\nnamespace turbomind::core {\n\nclass Layout {\npublic:\n    Layout(): size_{0} {}\n\n    /* implicit */ Layout(vector<ssize_t> shape);\n\n    /* implicit */ Layout(std::initializer_list<ssize_t> shape): Layout(vector(shape)) {}\n\n    Layout(vector<ssize_t> shape, vector<ssize_t> stride);\n\n    ssize_t size() const noexcept\n    {\n        return size_;\n    }\n\n    ssize_t cosize() const noexcept;\n\n    ssize_t rank() const noexcept\n    {\n        return shape_.size();\n    }\n\n    auto& shape() const noexcept\n    {\n        return shape_;\n    }\n\n    auto shape(int i) const\n    {\n        return shape_.at(wrap(i));\n    }\n\n    template<class... Is>\n    auto shapes(Is... is) const\n    {\n        return std::make_tuple(shape(is)...);\n    }\n\n    auto& stride() const noexcept\n    {\n        return stride_;\n    }\n\n    auto stride(int i) const\n    {\n        return stride_.at(wrap(i));\n    }\n\n    template<class... Is>\n    auto strides(Is... 
is) const\n    {\n        return std::make_tuple(stride(is)...);\n    }\n\n    bool is_contiguous() const noexcept\n    {\n        if (stride_.back() != 1) {\n            return false;\n        }\n        if (size() != cosize()) {\n            return false;\n        }\n        for (int i = 0; i < rank() - 1; ++i) {\n            // TODO: skip when shape == 1\n            if (stride_[i] < stride_[i + 1]) {\n                return false;\n            }\n        }\n        return true;\n    }\n\n    Layout permute(const vector<int>& dims) const\n    {\n        TM_CHECK((int)dims.size() == rank());\n        auto a = *this;\n        for (int i = 0; i < rank(); ++i) {\n            a.shape_[i]  = shape_[dims[i]];\n            a.stride_[i] = stride_[dims[i]];\n        }\n        return a;\n    }\n\n    Layout transpose(int a, int b) const\n    {\n        TM_CHECK_LT(a, rank());\n        TM_CHECK_LT(b, rank());\n        auto x = *this;\n        std::swap(x.shape_[a], x.shape_[b]);\n        std::swap(x.stride_[a], x.stride_[b]);\n        return x;\n    }\n\n    ssize_t offset(const vector<ssize_t>& idxs) const\n    {\n        TM_CHECK((int)idxs.size() < rank());\n        ssize_t val = 0;\n        for (size_t i = 0; i < idxs.size(); ++i) {\n            TM_CHECK_LT(idxs[i], shape_[i]);\n            val += idxs[i] * stride_[i];\n        }\n        return val;\n    }\n\n    ssize_t offset(ssize_t idx0) const\n    {\n        TM_CHECK(rank());\n        TM_CHECK_LT(idx0, shape_[0]);\n        return stride_[0] * idx0;\n    }\n\n    Layout coalesce() const noexcept;\n\n    Layout view(vector<ssize_t> shape) const;\n\n    std::pair<Layout, ssize_t> slice(const vector<ssize_t>& base, vector<ssize_t> shape) const;\n\n    Layout squeeze(int dim) const\n    {\n        if (rank() == 1 || shape(dim) != 1) {\n            return *this;\n        }\n        Layout a;\n        a.shape_.reserve(rank() - 1);\n        a.stride_.reserve(rank() - 1);\n        for (int i = 0; i < rank(); ++i) {\n       
     if (i != wrap(dim)) {\n                a.shape_.push_back(shape_[i]);\n                a.stride_.push_back(stride_[i]);\n            }\n        }\n        a.size_ = size_;\n        return a;\n    }\n\n    friend std::ostream& operator<<(std::ostream& os, const Layout& x);\n\n    friend bool operator==(const Layout& a, const Layout& b)\n    {\n        return a.shape_ == b.shape_ && a.stride_ == b.stride_;\n    }\n\n    friend bool operator!=(const Layout& a, const Layout& b)\n    {\n        return !(a == b);\n    }\n\nprivate:\n    int wrap(int dim) const noexcept\n    {\n        return dim < 0 ? dim + shape_.size() : dim;\n    }\n\nprivate:\n    vector<ssize_t> shape_;\n    vector<ssize_t> stride_;\n    ssize_t         size_;\n};\n\ninline std::string to_string(const Layout& x)\n{\n    std::stringstream ss;\n    ss << x;\n    return ss.str();\n}\n\n// clang-format off\ntemplate<class Archive>\nvoid save(Archive& ar, const Layout& layout)\n{\n    ar & layout.shape();\n    ar & layout.stride();\n}\n\ntemplate<class Archive>\nvoid load(Archive& ar, Layout& layout)\n{\n    vector<ssize_t> shape;\n    vector<ssize_t> stride;\n    ar & shape;\n    ar & stride;\n    layout = Layout(std::move(shape), std::move(stride));\n}\n// clang-format on\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/module.cc",
    "content": "\n#include \"src/turbomind/core/module.h\"\n#include \"src/turbomind/core/check.h\"\n#include <optional>\n\nnamespace turbomind::core {\n\nModule::Module(): parent_{} {}\n\nModule::~Module()\n{\n    if (parent_) {\n        parent_->remove_module(*this);\n        parent_ = {};\n    }\n}\n\nvoid Module::register_module(std::string name, Module& module, std::optional<int> index)\n{\n    module.parent_ = this;\n    if (index) {\n        name += \".\";\n        name += std::to_string(*index);\n    }\n    // std::cout << \"register Module \" << name << \" \" << &module << \", parent \" << this << \"\\n\";\n    modules_.emplace_back(std::move(name), &module);\n}\n\nvoid Module::register_parameter(std::string name, Tensor& param)\n{\n    // std::cout << \"register Parameter \" << name << \" \" << &param << \" \" << param.layout() << \"\\n\";\n    params_.emplace_back(std::move(name), &param);\n}\n\nvoid Module::remove_module(Module& module)\n{\n    for (auto it = modules_.begin(); it != modules_.end(); ++it) {\n        if (it->second == &module) {\n            // std::cout << \"erase \" << it->first << \" \" << &module << \" from \" << this << \"\\n\";\n            modules_.erase(it);\n            return;\n        }\n    }\n    TM_CHECK(0) << \"module \" << &module << \" not found\";\n}\n\nvoid Module::remove_parameter(Tensor& param)\n{\n    for (auto it = params_.begin(); it != params_.end(); ++it) {\n        if (it->second == &param) {\n            params_.erase(it);\n            return;\n        }\n    }\n    TM_CHECK(0) << \"param \" << &param << \" not found\";\n}\n\nstd::unordered_map<std::string, Tensor*> Module::get_parameters() const\n{\n    std::unordered_map<std::string, Tensor*> m;\n    get_parameters_impl({}, m);\n    return m;\n}\n\nvoid Module::get_parameters_impl(std::string prefix, std::unordered_map<std::string, Tensor*>& m) const\n{\n    if (!prefix.empty()) {\n        prefix += \".\";\n    }\n    for (const auto& [k, v] : params_) {\n    
    m.emplace(prefix + k, v);\n    }\n    for (const auto& [k, v] : modules_) {\n        v->get_parameters_impl(prefix + k, m);\n    }\n}\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/module.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n#ifndef TURBOMIND_CORE_MODULE_H\n#define TURBOMIND_CORE_MODULE_H\n\n#include \"src/turbomind/core/tensor.h\"\n\nnamespace turbomind::core {\n\nclass Module {\npublic:\n    virtual ~Module();\n\n    Module();\n\n    Module(const Module&) = delete;\n    Module& operator=(const Module&) = delete;\n\n    Module(Module&&) noexcept = delete;\n    Module& operator=(Module&&) noexcept = delete;\n\n    void register_module(std::string name, Module& module, std::optional<int> index = {});\n    void register_parameter(std::string name, Tensor& param);\n\n    void remove_module(Module& module);\n    void remove_parameter(Tensor& param);\n\n    std::unordered_map<std::string, Tensor*> get_parameters() const;\n\nprivate:\n    void get_parameters_impl(std::string prefix, std::unordered_map<std::string, Tensor*>& m) const;\n\nprotected:\n    Module* parent_;\n\n    std::vector<std::pair<std::string, Module*>> modules_;\n    std::vector<std::pair<std::string, Tensor*>> params_;\n};\n\n}  // namespace turbomind::core\n\n#endif  // TURBOMIND_CORE_MODULE_H\n"
  },
  {
    "path": "src/turbomind/core/ranges.h",
    "content": "#pragma once\n\nnamespace turbomind::core {\n\ntemplate<class Iterator>\nclass subrange {\npublic:\n    subrange(Iterator first, Iterator last): first_{first}, last_{last} {}\n\n    Iterator begin()\n    {\n        return first_;\n    }\n\n    Iterator end()\n    {\n        return last_;\n    }\n\n    auto empty() const\n    {\n        return first_ == last_;\n    }\n\n    auto size() const\n    {\n        return last_ - first_;\n    }\n\nprivate:\n    Iterator first_;\n    Iterator last_;\n};\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/serdes.h",
    "content": "#pragma once\n\n#include <algorithm>\n#include <array>\n#include <memory>\n#include <type_traits>\n#include <vector>\n\nnamespace turbomind::core {\n\ntemplate<template<class...> typename F, class SFINAE, class... Args>\nstruct is_detected: std::false_type {\n};\n\ntemplate<template<class... Args> typename F, class... Args>\nstruct is_detected<F, std::void_t<F<Args...>>, Args...>: std::true_type {\n};\n\ntemplate<class Archive, class T>\nusing save_t = decltype(save(std::declval<Archive&>(), std::declval<T>()));\n\ntemplate<class Archive, class T>\ninline constexpr bool has_save_v = is_detected<save_t, void, Archive, T>::value;\n\ntemplate<class Archive, class T>\nusing load_t = decltype(load(std::declval<Archive&>(), std::declval<T>()));\n\ntemplate<class Archive, class T>\ninline constexpr bool has_load_v = is_detected<load_t, void, Archive, T>::value;\n\ntemplate<class Archive, class T>\nusing serdes_t = decltype(serdes(std::declval<Archive&>(), std::declval<T>()));\n\ntemplate<class Archive, class T>\ninline constexpr bool has_serdes_v = is_detected<serdes_t, void, Archive, T>::value;\n\ntemplate<typename T>\nclass ArrayWrapper {\npublic:\n    ArrayWrapper(T* t, std::size_t size): t_{t}, size_{size}\n    {\n        static_assert(std::is_trivially_copyable_v<T>, \"ArrayWrapper requires trivially copyable type\");\n    }\n\n    T* data() const\n    {\n        return t_;\n    }\n\n    std::size_t count() const\n    {\n        return size_;\n    }\n\n    T* const          t_;\n    const std::size_t size_;\n};\n\ntemplate<typename T>\ninline constexpr bool is_array_wrapper_v = std::false_type{};\n\ntemplate<typename T>\ninline constexpr bool is_array_wrapper_v<ArrayWrapper<T>> = std::true_type{};\n\ntemplate<class Derived>\nstruct OutputArchive {\n    static constexpr bool is_loading = false;\n\n    template<class T>\n    void operator&(T&& x)\n    {\n        if constexpr (has_save_v<Derived, T>) {\n            save(*this, (T &&) x);\n        }\n     
   else if constexpr (has_serdes_v<Derived, T>) {\n            serdes(*this, (T &&) x);\n        }\n        else {\n            reinterpret_cast<Derived*>(this)->write((T &&) x);\n        }\n    }\n};\n\ntemplate<class Derived>\nstruct InputArchive {\n    static constexpr bool is_loading = true;\n\n    template<class T>\n    void operator&(T&& x)\n    {\n        if constexpr (has_load_v<Derived, T>) {\n            load(*this, (T &&) x);\n        }\n        else if constexpr (has_serdes_v<Derived, T>) {\n            serdes(*this, (T &&) x);\n        }\n        else {\n            reinterpret_cast<Derived*>(this)->read((T &&) x);\n        }\n    }\n};\n\nstruct BinarySizeArchive: OutputArchive<BinarySizeArchive> {\n    size_t size_{};\n\n    size_t size()\n    {\n        return size_;\n    }\n\n    template<class T>\n    void write(const T& x)\n    {\n        static_assert(std::is_trivially_copyable_v<T>);\n        size_ += sizeof(x);\n    }\n\n    template<class T>\n    void write(const ArrayWrapper<T>& arr)\n    {\n        static_assert(std::is_trivially_copyable_v<T>);\n        size_ += sizeof(T) * arr.count();\n    }\n};\n\nstruct BinaryOutputArchive: OutputArchive<BinaryOutputArchive> {\n\n    ArrayWrapper<std::byte> external_;\n    size_t                  ptr_;\n\n    BinaryOutputArchive(ArrayWrapper<std::byte> external): external_{external}, ptr_{} {}\n\n    template<class T>\n    void write(const T& x)\n    {\n        static_assert(std::is_trivially_copyable_v<T>);\n        auto data = (const std::byte*)&x;\n        TM_CHECK_LE(ptr_ + sizeof(T), external_.count());\n        std::copy_n(data, sizeof(T), external_.data() + ptr_);\n        ptr_ += sizeof(T);\n    }\n\n    template<class T>\n    void write(const ArrayWrapper<T>& arr)\n    {\n        static_assert(std::is_trivially_copyable_v<T>);\n        auto data = (const std::byte*)arr.data();\n        TM_CHECK_LE(ptr_ + sizeof(T) * arr.count(), external_.count());\n        std::copy_n(data, sizeof(T) * 
arr.count(), external_.data() + ptr_);\n        ptr_ += sizeof(T) * arr.count();\n    }\n};\n\nstruct BinaryInputArchive: InputArchive<BinaryInputArchive> {\n\n    ArrayWrapper<std::byte> external_;\n    size_t                  ptr_;\n\n    BinaryInputArchive(ArrayWrapper<std::byte> external): external_{external}, ptr_{} {}\n\n    template<class T>\n    void read(T& x)\n    {\n        static_assert(std::is_trivially_copyable_v<T>);\n        TM_CHECK_LE(ptr_ + sizeof(T), external_.count());\n        std::copy_n(external_.data() + ptr_, sizeof(T), (std::byte*)&x);\n        ptr_ += sizeof(T);\n    }\n\n    template<class T>\n    void read(ArrayWrapper<T>&& arr)\n    {\n        static_assert(std::is_trivially_copyable_v<T>);\n        TM_CHECK_LE(ptr_ + sizeof(T) * arr.count(), external_.count());\n        std::copy_n(external_.data() + ptr_, sizeof(T) * arr.count(), (std::byte*)arr.data());\n        ptr_ += sizeof(T) * arr.count();\n    }\n};\n\ntemplate<class Archive, class T>\nvoid save(Archive& ar, const std::vector<T>& xs)\n{\n    // clang-format off\n    ar & xs.size();\n    if constexpr (std::is_trivially_copyable_v<T>) {\n        ar & ArrayWrapper(xs.data(), xs.size());\n    }\n    else {\n        for (const auto& x : xs) {\n            ar & x;\n        }\n    }\n    // clang-format on\n}\n\ntemplate<class Archive, class T>\nvoid load(Archive& ar, std::vector<T>& xs)\n{\n    // clang-format off\n    decltype(xs.size()) size;\n    ar & size;\n    xs.resize(size);\n\n    if constexpr (std::is_trivially_copyable_v<T>) {\n        ar & ArrayWrapper(xs.data(), size);\n    } else {\n        for (size_t i = 0; i < size; ++i) {\n            ar & xs[i];\n        }\n    }\n    // clang-format on\n}\n\ntemplate<class Archive>\nvoid save(Archive& ar, const std::string& s)\n{\n    // clang-format off\n    ar & s.size();\n    ar & ArrayWrapper(s.data(), s.size());\n    // clang-format on\n}\n\ntemplate<class Archive>\nvoid load(Archive& ar, std::string& s)\n{\n    // 
clang-format off\n    decltype(s.size()) size;\n    ar & size;\n    s.resize(size);\n    ar & ArrayWrapper(s.data(), size);\n    // clang-format on\n}\n\ntemplate<class Archive, class T>\nvoid save(Archive& ar, const std::shared_ptr<T>& p)\n{\n    // clang-format off\n    ar & (bool)p;\n    if (p) {\n        ar & (*p);\n    }\n    // clang-format on\n}\n\ntemplate<class Archive, class T>\nvoid load(Archive& ar, std::shared_ptr<T>& p)\n{\n    // clang-format off\n    bool pred;\n    ar & pred;\n    if (pred) {\n        p = std::make_shared<T>();\n        ar & (*p);\n    }\n}\n\ntemplate<class Archive, class T, size_t N>\nvoid serdes(Archive& ar, std::array<T, N>& xs)\n{\n    // clang-format off\n    if constexpr (std::is_trivially_copyable_v<T>) {\n        ar & ArrayWrapper(xs.data(), N);\n    }\n    else {\n        for (size_t i = 0; i < N; ++i) {\n            ar & xs[i];\n        }\n    }\n    // clang-format on\n}\n\ntemplate<class Archive, class... Ts>\nvoid serdes(Archive& ar, std::tuple<Ts...>& tpl)\n{\n    std::apply([&](auto&... elems) { ((ar & elems), ...); }, tpl);\n}\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/state.h",
    "content": "\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/layout.h\"\n#include \"src/turbomind/core/tensor.h\"\n#include <algorithm>\n\nnamespace turbomind {\n\n// Goals:\n// 1. constant number of cudaMemcpy / kernel launches\n// 2. single stream synchronization / iteration\n\nstruct State {\n\n    Tensor data_[2];\n\n    State() = default;\n\n    State(const Layout& layout, DataType dtype, const core::Device& device)\n    {\n        data_[0] = {layout, dtype, device};\n        data_[1] = {layout, dtype, device};\n    }\n\n    Tensor& front()\n    {\n        return data_[0];\n    }\n\n    Tensor& back()\n    {\n        return data_[1];\n    }\n\n    void Swap()\n    {\n        std::swap(data_[0], data_[1]);\n    }\n};\n\ntemplate<class Copy>\nvoid Warp(const Tensor& a0, int size0, const Buffer_<int>& perm, Tensor b1, Copy& copy)\n{\n    auto a0_ptr = (const uint8_t*)a0.raw_data();\n    auto b1_ptr = (uint8_t*)b1.raw_data();\n\n    const auto vec_size = byte_size(a0.dtype(), a0.stride(0));\n\n    for (int i = 0; i < perm.size(); ++i) {\n        if (const int j = perm[i]; TM_LIKELY(j < size0)) {\n            copy(a0_ptr + j * vec_size, vec_size, b1_ptr + i * vec_size);\n        }\n    }\n}\n\ntemplate<class Copy>\nvoid Warp(const Tensor& a0, const Tensor& b1, int size0, const Buffer_<int>& perm, Tensor c1, Copy& copy)\n{\n    auto a0_ptr = (const uint8_t*)a0.raw_data();\n    auto b1_ptr = (const uint8_t*)b1.raw_data();\n    auto c1_ptr = (uint8_t*)c1.raw_data();\n\n    const auto vec_size = byte_size(a0.dtype(), a0.stride(0));\n\n    for (int i = 0; i < perm.size(); ++i) {\n        const uint8_t* src_ptr = TM_LIKELY(perm[i] < size0) ? 
a0_ptr + perm[i] * vec_size : b1_ptr + i * vec_size;\n        copy(src_ptr, vec_size, c1_ptr + i * vec_size);\n    }\n}\n\ntemplate<class Copy>\nvoid Warp(const Tensor&       src0,\n          const Buffer_<int>& offset0,\n          int                 size0,\n          const Tensor&       src1,\n          const Buffer_<int>& offset1,\n          const Buffer_<int>& perm0,\n          Tensor              dst,\n          Buffer_<int>        offsetd,\n          Copy&               copy)\n{\n    auto p_src0 = (const uint8_t*)src0.raw_data();\n    auto p_src1 = (const uint8_t*)src1.raw_data();\n\n    const ssize_t vec_size = byte_size(src0.dtype(), src0.stride(0));\n\n    auto p_dst = (uint8_t*)dst.raw_data();\n\n    offsetd[0] = 0;\n\n    for (int i = 0; i < perm0.size(); ++i) {\n        const uint8_t* p_src;\n        ssize_t        n;\n        if (const int j = perm0[i]; TM_LIKELY(j < size0)) {\n            p_src = p_src0 + offset0[j] * vec_size;\n            n     = offset0[j + 1] - offset0[j];\n        }\n        else {\n            p_src = p_src1 + offset1[i] * vec_size;\n            n     = offset1[i + 1] - offset1[i];\n        }\n        offsetd[i + 1] = offsetd[i] + n;\n        copy(p_src, n * vec_size, p_dst + offsetd[i] * vec_size);\n    }\n}\n\n// d1[i] = a0[perm[i]]:b0[perm[i]] if perm[i] < size0 else c1[i]\n// where `a0` has variable size with fixed stride\n//       `b0` has fixed size (1)\n//       `a1` has variable size\n//       `c1` has variable size with fixed stride\ntemplate<class Copy>\nvoid Append(const Tensor&       a0,\n            const Buffer_<int>& a0_size,\n            const Tensor&       b0,\n            const Tensor&       c1,\n            const Buffer_<int>& c1_offset,\n            const Buffer_<int>& perm,\n            int                 size0,\n            Tensor              d1,\n            Buffer_<int>        d1_size,\n            Copy&               copy)\n{\n    auto a0_ptr = (const uint8_t*)a0.raw_data();\n    auto b0_ptr = (const 
uint8_t*)b0.raw_data();\n    auto c1_ptr = (const uint8_t*)c1.raw_data();\n\n    auto d1_ptr = (uint8_t*)d1.raw_data();\n\n    TM_CHECK_EQ(a0.stride(0), d1.stride(0));\n\n    const auto stride   = byte_size(a0.dtype(), a0.stride(0));\n    const auto vec_size = byte_size(a0.dtype(), a0.stride(1));\n\n    for (int i = 0; i < perm.size(); ++i) {\n        if (const int j = perm[i]; TM_LIKELY(j < size0)) {\n            uint8_t* out = copy(a0_ptr + j * stride, vec_size * a0_size[j], d1_ptr + i * stride);\n            copy(b0_ptr + j * vec_size, vec_size, out);\n            d1_size[i] = a0_size[j] + 1;\n        }\n        else {\n            const auto n = c1_offset[i + 1] - c1_offset[i];\n            copy(c1_ptr + c1_offset[i] * vec_size, n * vec_size, d1_ptr + i * stride);\n            d1_size[i] = n;\n        }\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/core/stream.cc",
    "content": "\n#include \"src/turbomind/core/stream.h\"\n#include <memory>\n\nnamespace turbomind::core {\n\nStream Stream::create(int priority)\n{\n    Stream stream;\n    stream.impl_ = std::make_shared<StreamImpl>(priority);\n    return stream;\n}\n\nvoid StreamImpl::Wait(const Event& event)\n{\n    check_cuda_error(cudaStreamWaitEvent(stream_, event));\n}\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/stream.h",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/common.h\"\n\nnamespace turbomind::core {\n\nclass StreamImpl {\npublic:\n    StreamImpl(int priority): stream_{}\n    {\n        check_cuda_error(cudaStreamCreateWithPriority(&stream_, cudaStreamNonBlocking, priority));\n    }\n\n    ~StreamImpl()\n    {\n        if (auto ec = cudaStreamDestroy(stream_); ec != cudaSuccess) {\n            TM_LOG_ERROR(cudaGetErrorString(ec));\n        }\n        stream_ = {};\n    }\n\n    void Sync()\n    {\n        check_cuda_error(cudaStreamSynchronize(stream_));\n    }\n\n    void Wait(const Event& event);\n\n    cudaStream_t handle() const\n    {\n        return stream_;\n    }\n\npublic:\n    cudaStream_t stream_;\n};\n\nclass Stream {\npublic:\n    Stream() = default;\n\n    static Stream create(int priority = 0);\n\n    void Sync()\n    {\n        impl_->Sync();\n    }\n\n    void Wait(const Event& event)\n    {\n        impl_->Wait(event);\n    }\n\n    cudaStream_t handle() const\n    {\n        return TM_CHECK_NOTNULL(impl_)->handle();\n    }\n\n    explicit operator cudaStream_t() const\n    {\n        return handle();\n    }\n\n    explicit operator bool() const noexcept\n    {\n        return static_cast<bool>(impl_);\n    }\n\n    friend bool operator==(const Stream& a, const Stream& b)\n    {\n        return a.impl_ == b.impl_;\n    }\n\n    friend bool operator!=(const Stream& a, const Stream& b)\n    {\n        return !(a == b);\n    }\n\n    friend std::ostream& operator<<(std::ostream& os, const Stream& s)\n    {\n        os << s.impl_;\n        return os;\n    }\n\nprivate:\n    shared_ptr<StreamImpl> impl_;\n};\n\nclass EventImpl {\npublic:\n    explicit EventImpl(unsigned flags)\n    {\n        check_cuda_error(cudaEventCreateWithFlags(&event_, flags));\n    }\n\n    ~EventImpl()\n    {\n        if (auto ec = cudaEventDestroy(event_); ec != cudaSuccess) {\n            
TM_LOG_ERROR(cudaGetErrorString(ec));\n        }\n    }\n\n    void Record(const Stream& stream)\n    {\n        check_cuda_error(cudaEventRecord(event_, stream.handle()));\n    }\n\n    void Sync() const\n    {\n        check_cuda_error(cudaEventSynchronize(event_));\n    }\n\n    cudaEvent_t handle() const\n    {\n        return event_;\n    }\n\nprivate:\n    cudaEvent_t event_;\n};\n\nclass Event {\npublic:\n    Event() = default;\n\n    static Event create(bool timing = false)\n    {\n        Event e{};\n        e.impl_ = std::make_shared<EventImpl>(timing ? 0 : cudaEventDisableTiming);\n        return e;\n    }\n\n    void Record(const Stream& stream)\n    {\n        TM_CHECK_NOTNULL(impl_)->Record(stream);\n    }\n\n    void Sync() const\n    {\n        TM_CHECK_NOTNULL(impl_)->Sync();\n    }\n\n    operator cudaEvent_t() const\n    {\n        return TM_CHECK_NOTNULL(impl_)->handle();\n    }\n\n    explicit operator bool() const noexcept\n    {\n        return static_cast<bool>(impl_);\n    }\n\nprivate:\n    shared_ptr<EventImpl> impl_;\n};\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/tensor.cc",
    "content": "\n#include \"src/turbomind/core/tensor.h\"\n#include \"src/turbomind/core/buffer.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/stream.h\"\n\nnamespace turbomind::core {\n\nstd::ostream& operator<<(std::ostream& os, const Tensor& t)\n{\n    os << t.dtype() << \"[\" << t.layout() << \"]@\" << t.buffer_.data_or((void*)nullptr);\n    return os;\n}\n\nTensor& TensorMap::at(const std::string& key)\n{\n    auto it = find(key);\n    TM_CHECK(it != end()) << get_out_of_range_msg(key);\n    return it->second;\n}\n\nstd::string TensorMap::get_out_of_range_msg(const std::string& key) const\n{\n    std::ostringstream oss;\n    oss << \"Cannot find a tensor of name '\" << key << \"' in the tensor map (keys: \";\n    auto sep = \"\";\n    for (const auto& [k, _] : *this) {\n        oss << std::exchange(sep, \", \") << k;\n    }\n    oss << \")\";\n    return oss.str();\n}\n\nTensor* TensorMap::try_(const std::string& key)\n{\n    auto it = find(key);\n    if (it != end()) {\n        return &it->second;\n    }\n    return nullptr;\n}\n\nvoid Copy(const Tensor& src, Ref<Tensor> dst_, const Stream& stream)\n{\n    auto& dst = dst_.get();\n    TM_CHECK(src.dtype() == dst.dtype());\n    TM_CHECK(src.shape() == dst.shape());\n    TM_CHECK(src.is_contiguous());\n    TM_CHECK(dst.is_contiguous());\n    if (auto size = src.byte_size()) {\n        check_cuda_error(cudaMemcpyAsync(dst.raw_data(), src.raw_data(), size, cudaMemcpyDefault, stream.handle()));\n    }\n}\n\nvoid Copy(const Tensor& src, Ref<Tensor> dst_)\n{\n    Copy(src, dst_, Context::stream());\n}\n\nvoid Clear(Ref<Tensor> a_, const Stream& stream)\n{\n    auto& a = a_.get();\n    TM_CHECK(a.is_contiguous());\n    if (auto size = a.byte_size()) {\n        check_cuda_error(cudaMemsetAsync(a.raw_data(), 0, size, stream.handle()));\n    }\n}\n\nvoid Clear(Ref<Tensor> a_)\n{\n    Clear(a_, Context::stream());\n}\n\n#if 0\n\nvoid Copy(const Tensor& src, Tensor& dst, Stream& stream)\n{\n  
  TM_CHECK(src.dtype() == dst.dtype());\n    TM_CHECK(src.shape() == dst.shape());\n\n    const DataType dtype = src.dtype();\n\n    auto trivial = [&] {\n        const ssize_t bytesize = get_byte_size(dtype, src.size());\n        check_cuda_error(cudaMemcpyAsync(dst.raw_data(), src.raw_data(), bytesize, cudaMemcpyDefault, stream.handle()));\n    };\n\n    if (src.layout().is_contiguous() && dst.layout().is_contiguous()) {\n        return trivial();\n    }\n\n    auto a = src.layout();\n    auto b = dst.layout();\n\n    vector<int> idxs(a.rank());\n    std::iota(idxs.begin(), idxs.end(), 0);\n    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {  //\n        return a.stride()[j] < a.stride()[i];\n    });\n\n    // innermost dim is not contiguous\n    if (a.stride(idxs.back()) > 1 || b.stride(idxs.back()) > 1) {\n        return GenericCopy(src, dst, stream);\n    }\n\n    a = a.reorder(idxs);\n    b = b.reorder(idxs);\n\n    // trivial after reorder (e.g. transposed matrices)\n    if (a.is_contiguous() && b.is_contiguous()) {\n        return trivial();\n    }\n\n    a = a.coalesce();\n    b = b.coalesce();\n\n    int rank = std::max(a.rank(), b.rank());\n\n    if (rank > 3) {\n        return GenericCopy(src, dst, stream);\n    }\n\n    if (a.rank() < rank) {\n        a = a.view(b.shape());\n    }\n    else if (b.rank() < rank) {\n        b = b.view(b.shape());\n    }\n\n    if (rank == 2) {\n        check_cuda_error(cudaMemcpy2DAsync(dst.raw_data(),\n                                           get_byte_size(dtype, b.stride(0)),\n                                           src.raw_data(),\n                                           get_byte_size(dtype, a.stride(0)),\n                                           get_byte_size(dtype, a.shape(1)),\n                                           a.shape(0),\n                                           cudaMemcpyDefault,\n                                           stream.handle()));\n        return;\n    }\n\n    auto [a0, 
a1] = a.strides(0, 1);\n    auto [b0, b1] = b.strides(0, 1);\n\n    // make sure the underlying space is actually a cube [x % (y * z) == 0]\n    if (rank == 3 && a0 % a1 == 0 && b0 % b1 == 0) {\n        const auto xsz_a = get_byte_size(dtype, a.stride(1));\n        const auto xsz_b = get_byte_size(dtype, b.stride(1));\n        const auto ysz_a = a0 / a1;\n        const auto ysz_b = b0 / b1;\n\n        cudaMemcpy3DParms param{};\n        param.srcPtr = make_cudaPitchedPtr((void*)src.raw_data(), xsz_a, xsz_a, ysz_a);\n        param.dstPtr = make_cudaPitchedPtr((void*)dst.raw_data(), xsz_b, xsz_b, ysz_b);\n        param.extent = make_cudaExtent(get_byte_size(dtype, a.shape(2)), a.shape(1), a.shape(0));\n        param.kind   = cudaMemcpyDefault;\n\n        if (auto ec = cudaMemcpy3DAsync(&param, stream.handle()); ec == cudaSuccess) {\n            return;\n        }\n        else {\n            // 3D copy rejected, warn and fall back to the generic kernel\n            TM_LOG_WARNING(cudaGetErrorString(ec));\n        }\n    }\n\n    return GenericCopy(src, dst, stream);\n}\n\nvoid Copy(const Tensor& src, Tensor&& dst, Stream& stream)\n{\n    return Copy(src, dst, stream);\n}\n\n#endif\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/tensor.cu",
    "content": "\n\n#include \"src/turbomind/core/buffer.h\"\n#include \"src/turbomind/core/tensor.h\"\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n\nnamespace turbomind::core {\n\n#if 0\n\nnamespace kernel {\n\n// This is going to be slow for transposing the innermost dim\ntemplate<class T, class Index, int D>\n__global__ void GenericCopy(const T*          a,\n                            T*                b,\n                            Array<int64_t, D> stride_a,\n                            Array<int64_t, D> stride_b,\n                            Array<Index, D>   shape,\n                            int               ndim,\n                            int64_t           size)\n{\n    Index idx = threadIdx.x + (Index)blockIdx.x * blockDim.x;\n\n    if (idx >= size) {\n        return;\n    }\n\n    Array<int64_t, D> coord;\n    PRAGMA_UNROLL\n    for (int i = 0; i < D; ++i) {\n        if (i < ndim) {\n            auto div = idx / shape[i];\n            auto mod = idx % shape[i];\n            coord[i] = mod;\n            idx      = div;\n        }\n    }\n\n    int64_t idx_a = 0;\n    int64_t idx_b = 0;\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < D; ++i) {\n        if (i < ndim) {\n            idx_a += coord[i] * stride_a[i];\n            idx_b += coord[i] * stride_b[i];\n        }\n    }\n\n    b[idx_b] = a[idx_a];\n}\n\n}  // namespace kernel\n\nvoid GenericCopy(const Tensor& src, Tensor& dst, Stream& stream)\n{\n    auto a = src.layout();\n    auto b = dst.layout();\n\n    // Sort strides ascending\n    vector<int> idxs(a.rank());\n    std::iota(idxs.begin(), idxs.end(), 0);\n    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {  //\n        return a.stride()[i] < a.stride()[j];\n    });\n\n    a = a.permute(idxs);\n    b = b.permute(idxs);\n\n    a = a.coalesce();\n    b = b.coalesce();\n\n    int rank = std::max(a.rank(), b.rank());\n\n    if (a.rank() < 
rank) {\n        a = a.view(b.shape());\n    }\n    else if (b.rank() < rank) {\n        b = b.view(b.shape());\n    }\n\n    const DataType dtype = src.dtype();\n\n    int64_t alignment = 16;\n\n    auto align = [&](auto v) { alignment = std::gcd(alignment, v); };\n\n    if (a.stride(0) > 1 || b.stride(0) > 1) {\n        alignment = get_byte_size(dtype);\n    }\n\n    align(get_byte_size(dtype, a.shape(0)));\n\n    auto data_a = src.raw_data();\n    auto data_b = dst.raw_data();\n\n    align(reinterpret_cast<uintptr_t>(data_a));\n    align(reinterpret_cast<uintptr_t>(data_b));\n\n    for (int i = 1; i < rank; ++i) {\n        align(get_byte_size(dtype, a.stride(i)));\n        align(get_byte_size(dtype, b.stride(i)));\n    }\n\n    const auto vec_size = get_elem_num(alignment, dtype);\n\n    const auto size = a.size() / vec_size;\n\n    int device{};\n    check_cuda_error(cudaGetDevice(&device));\n    int sm_num{};\n    check_cuda_error(cudaDeviceGetAttribute(&sm_num, cudaDevAttrMultiProcessorCount, device));\n\n    auto invoke = [&](auto vec_t, auto index_t, auto d) {\n        using T         = decltype(vec_t);\n        using Index     = decltype(index_t);\n        constexpr int D = d.value;\n\n        Array<Index, D> shape;\n        std::fill(shape.begin() + rank, shape.end(), 1);\n        std::copy_n(a.shape().data(), rank, shape.data());\n\n        Array<int64_t, D> stride_a{};\n        Array<int64_t, D> stride_b{};\n        std::copy_n(a.stride().data(), rank, stride_a.data());\n        std::copy_n(b.stride().data(), rank, stride_b.data());\n\n        if (vec_size > 1) {\n            shape[0] /= vec_size;\n            for (int i = 0; i < rank; ++i) {\n                stride_a[i] /= vec_size;\n                stride_b[i] /= vec_size;\n            }\n        }\n\n        auto func = kernel::GenericCopy<T, Index, D>;\n\n        int min_waves  = INT_MAX;\n        int block_size = 0;\n        int grid_size  = 0;\n\n        for (int threads = 256; threads <= 1024; 
threads *= 2) {\n            int blocks = cdiv<ssize_t>(size, threads);\n            int n_active{};\n            check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_active, func, threads, 0));\n            int waves = cdiv(blocks, n_active * sm_num);\n            if (waves < min_waves) {\n                min_waves  = waves;\n                block_size = threads;\n                grid_size  = blocks;\n            }\n        }\n\n        func<<<grid_size, block_size, 0, stream.handle()>>>(\n            (const T*)data_a, (T*)data_b, stride_a, stride_b, shape, rank, size);\n    };\n\n    auto invoke_d = [&](auto vec_t, auto idx_t) {\n        if (rank <= 2) {\n            invoke(vec_t, idx_t, constant<2>{});\n        }\n        else if (rank <= 4) {\n            invoke(vec_t, idx_t, constant<4>{});\n        }\n        else if (rank <= 8) {\n            invoke(vec_t, idx_t, constant<8>{});\n        }\n        else {\n            throw std::runtime_error(\"not implemented\");\n        }\n    };\n\n    auto invoke_i = [&](auto vec_t) {\n        if (size < INT_MAX) {\n            invoke_d(vec_t, int{});\n        }\n        else {\n            invoke_d(vec_t, int64_t{});\n        }\n    };\n\n    switch (alignment) {\n        case 16:\n            return invoke_i(uint4{});\n        case 8:\n            return invoke_i(uint2{});\n        case 4:\n            return invoke_i(uint{});\n        case 2:\n            return invoke_i(ushort{});\n        default:\n            return invoke_i(char{});\n    }\n}\n\n#endif\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/tensor.h",
    "content": "#pragma once\n\n#include <optional>\n#include <string>\n#include <unordered_map>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/buffer.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/layout.h\"\n\nnamespace turbomind::core {\n\nclass Tensor {\npublic:\n    Tensor() = default;\n\n    Tensor(Layout layout, DataType dtype, Device device): Tensor{layout, dtype, Context::alloc(device)} {}\n\n    Tensor(Layout layout, DataType dtype, Allocator& alloc): layout_{std::move(layout)}\n    {\n        buffer_ = Buffer(layout_.cosize(), dtype, alloc);\n    }\n\n    Tensor(Buffer buffer, Layout layout): layout_{std::move(layout)}, buffer_{buffer.slice(0, layout_.cosize())} {}\n\n    Tensor(Buffer buffer): layout_{buffer.size()}, buffer_{buffer} {}\n\n    Tensor(void* data, Layout layout, DataType dtype, Device device):\n        Tensor{Buffer{data, layout.cosize(), dtype, device}, layout}\n    {\n    }\n\n    Tensor(std::shared_ptr<void> data, Layout layout, DataType dtype, Device device):\n        Tensor{Buffer{data, layout.cosize(), dtype, device}, layout}\n    {\n    }\n\n    template<class T>\n    Tensor(T* data, Layout layout, Device device): Tensor{Buffer{data, layout.cosize(), device}, layout}\n    {\n    }\n\n    Buffer& buffer() noexcept\n    {\n        return buffer_;\n    }\n\n    const Buffer& buffer() const noexcept\n    {\n        return buffer_;\n    }\n\n    DataType dtype() const\n    {\n        return buffer_.dtype();\n    }\n\n    Device device() const\n    {\n        return buffer_.device();\n    }\n\n    ssize_t size() const noexcept\n    {\n        return layout_.size();\n    }\n\n    ssize_t byte_size() const noexcept\n    {\n        return turbomind::byte_size(dtype(), size());\n    }\n\n    explicit operator bool() const noexcept\n    {\n        return static_cast<bool>(buffer_);\n    }\n\n    template<class T>\n    T* data()\n    {\n        return buffer_.data<T>();\n    }\n\n    
template<class T>\n    const T* data() const\n    {\n        return const_cast<Tensor*>(this)->data<T>();\n    }\n\n    void* raw_data()\n    {\n        return buffer_.raw_data();\n    }\n\n    const void* raw_data() const\n    {\n        return const_cast<Tensor*>(this)->raw_data();\n    }\n\n    template<class T>\n    T* data_or(T* other)\n    {\n        return buffer_.data_or(other);\n    }\n\n    template<class T>\n    const T* data_or(T* other) const\n    {\n        return buffer_.data_or(other);\n    }\n\n    Tensor view(std::vector<ssize_t> shape) const\n    {\n        return Tensor{buffer_, layout_.view(std::move(shape))};\n    }\n\n    auto& layout() const noexcept\n    {\n        return layout_;\n    }\n\n    auto& shape() const noexcept\n    {\n        return layout_.shape();\n    }\n\n    auto shape(int i) const\n    {\n        return layout_.shape(i);\n    }\n\n    template<class... Is>\n    auto shapes(Is&&... is) const\n    {\n        return layout_.shapes(((Is &&) is)...);\n    }\n\n    auto& stride() const noexcept\n    {\n        return layout_.stride();\n    }\n\n    auto stride(int i) const\n    {\n        return layout_.stride(i);\n    }\n\n    template<class... Is>\n    auto strides(Is&&... 
is) const\n    {\n        return layout_.strides(((Is &&) is)...);\n    }\n\n    bool is_contiguous() const noexcept\n    {\n        return layout().is_contiguous();\n    }\n\n    Tensor slice(std::vector<ssize_t> base, std::vector<ssize_t> shape) const\n    {\n        auto&& [layout, offset] = layout_.slice(base, std::move(shape));\n        const auto cosize       = layout.cosize();\n        return Tensor{buffer_.slice(offset, cosize), std::move(layout)};\n    }\n\n    // The outermost dimension\n    Tensor slice(ssize_t base, ssize_t size = 1) const\n    {\n        vector<ssize_t> bases(shape().size());\n        bases.front() = base;\n        vector<ssize_t> sizes{this->shape()};\n        sizes.front() = size;\n        return slice(bases, sizes);\n    }\n\n    Tensor borrow() const\n    {\n        return Tensor{buffer_.borrow(), layout_};\n    }\n\n    Tensor squeeze(int dim) const\n    {\n        return Tensor{buffer_, layout_.squeeze(dim)};\n    }\n\n    Tensor transpose(int a, int b) const\n    {\n        return Tensor{buffer_, layout_.transpose(a, b)};\n    }\n\n    Tensor t() const\n    {\n        TM_CHECK_EQ(ndim(), 2);\n        return transpose(0, 1);\n    }\n\n    int ndim() const noexcept\n    {\n        return layout_.rank();\n    }\n\n    friend std::ostream& operator<<(std::ostream& os, const Tensor& t);\n\nprivate:\n    Layout layout_;\n    Buffer buffer_;\n};\n\ninline Tensor empty_like(const Tensor& tensor)\n{\n    return Tensor{tensor.layout(), tensor.dtype(), tensor.device()};\n}\n\ninline Tensor empty_like(const Tensor& tensor, Device device)\n{\n    return Tensor{tensor.layout(), tensor.dtype(), device};\n}\n\ninline Tensor empty_like(const Tensor& tensor, DataType dtype)\n{\n    return Tensor{tensor.layout(), dtype, tensor.device()};\n}\n\nvoid Copy(const Tensor& src, Ref<Tensor> dst_, const Stream& stream);\n\nvoid Copy(const Tensor& src, Ref<Tensor> dst_);\n\nvoid Clear(Ref<Tensor> a_, const Stream& stream);\n\nvoid Clear(Ref<Tensor> 
a_);\n\n#if 0\n\nvoid Copy(const Tensor& src, Tensor&& dst, Stream& stream);\n\n// Launch a kernel to perform the complicated copying\nvoid GenericCopy(const Tensor& src, Tensor& dst, Stream& stream);\n\nTensor Reshape(const Tensor& t, vector<ssize_t> shape);\n\nTensor Transpose(const Tensor& t, int dim0, int dim1);\n\nTensor Permute(const Tensor& t, vector<int> dims);\n\nTensor Contiguous(const Tensor& t);\n#endif\n\ntemplate<class T>\nstruct Tensor_: public Tensor {\n    Tensor_() = default;\n\n    Tensor_(Layout layout, Device device): Tensor{std::move(layout), data_type_v<T>, device} {}\n\n    Tensor_(Layout layout, Allocator& alloc): Tensor{std::move(layout), data_type_v<T>, alloc} {}\n\n    Tensor_(Buffer buffer, Layout layout): Tensor{ensure_dtype(std::move(buffer)), std::move(layout)} {}\n\n    Tensor_(T* data, Layout layout, Device device): Tensor{data, std::move(layout), device} {}\n\n    Tensor_(shared_ptr<void> data, Layout layout, Device device):\n        Tensor{Buffer{std::move(data), layout.cosize(), data_type_v<T>, device}, layout}\n    {\n    }\n\n    Tensor_(const Tensor_&) = default;\n    Tensor_& operator=(const Tensor_&) = default;\n\n    Tensor_(Tensor_&&) noexcept = default;\n    Tensor_& operator=(Tensor_&&) noexcept = default;\n\n    Tensor_(const Tensor& other)\n    {\n        *static_cast<Tensor*>(this) = ensure_dtype(other);\n    }\n    Tensor_(Tensor&& other) noexcept\n    {\n        *static_cast<Tensor*>(this) = ensure_dtype(std::move(other));\n    }\n\n    ssize_t offset(const vector<ssize_t>& idxs)\n    {\n        return layout().offset(idxs);\n    }\n\n    T* data() noexcept\n    {\n        return Tensor::data<T>();\n    }\n\n    const T* data() const noexcept\n    {\n        return Tensor::data<T>();\n    }\n\n    T* data_or(T* other)\n    {\n        return Tensor::data_or<T>(other);\n    }\n\n    const T* data_or(T* other) const\n    {\n        return Tensor::data_or<T>(other);\n    }\n\n    constexpr DataType dtype() const 
noexcept\n    {\n        return data_type_v<T>;\n    }\n\nprivate:\n    template<class U>\n    static decltype(auto) ensure_dtype(U&& u)\n    {\n        TM_CHECK_EQ(u.dtype(), data_type_v<T>);\n        return (U &&) u;\n    }\n};\n\nclass TensorMap: public std::unordered_map<std::string, Tensor> {\npublic:\n    using std::unordered_map<std::string, Tensor>::unordered_map;\n\n    Tensor& at(const std::string& key);\n\n    const Tensor& at(const std::string& key) const\n    {\n        return const_cast<TensorMap*>(this)->at(key);\n    }\n\n    Tensor* try_(const std::string& key);\n\n    const Tensor* try_(const std::string& key) const\n    {\n        return const_cast<TensorMap*>(this)->try_(key);\n    }\n\n    bool contains(const std::string& key) const\n    {\n        return find(key) != end();\n    }\n\n    void produce(const std::string& key, Tensor value)\n    {\n        TM_CHECK(emplace(key, std::move(value)).second);\n    }\n\n    Tensor try_consume(const std::string& key)\n    {\n        if (auto it = find(key); it != end()) {\n            auto value = std::move(it->second);\n            erase(it);\n            return value;\n        }\n        return Tensor{};\n    }\n\n    Tensor consume(const std::string& key)\n    {\n        auto value = try_consume(key);\n        TM_CHECK(value) << get_out_of_range_msg(key);\n        return value;\n    }\n\nprivate:\n    std::string get_out_of_range_msg(const std::string& key) const;\n};\n\n// clang-format off\ntemplate<class Archive, class T, std::enable_if_t<std::is_same_v<Tensor, T>, int> = 0>\nvoid save(Archive& ar, const T& tensor)\n{\n    TM_CHECK(tensor.size() == 0 || tensor.is_contiguous());\n    ar & tensor.buffer(); // implicit convert to tensor\n    ar & tensor.layout();\n}\n\ntemplate<class Archive>\nvoid load(Archive& ar, Tensor& tensor)\n{\n    Buffer buffer;\n    Layout layout;\n    ar & buffer;\n    ar & layout;\n    tensor = Tensor{std::move(buffer), std::move(layout)};\n}\n\n\ntemplate<class 
Archive>\nvoid save(Archive& ar, const TensorMap& map)\n{\n    ar & map.size();\n    for (const auto& [k, t]: map) {\n        ar & k;\n        ar & t;\n    }\n}\n\ntemplate<class Archive>\nvoid load(Archive& ar, TensorMap& map)\n{\n    map.clear();\n    decltype(map.size()) size;\n    ar & size;\n    for (int i = 0; i < size; ++i) {\n        std::string k;\n        Tensor   t;\n        ar & k;\n        ar & t;\n        map.emplace(std::move(k), std::move(t));\n    }\n}\n// clang-format on\n\n}  // namespace turbomind::core\n"
  },
  {
    "path": "src/turbomind/core/test_core.cc",
    "content": "\n#include <numeric>\n\n#include \"src/turbomind/core/core.h\"\n\n#include \"catch2/catch_test_macros.hpp\"\n\nusing namespace turbomind;\n\nTEST_CASE(\"test check\", \"[check]\")\n{\n    int zero = 0;\n\n    TM_CHECK(!zero);\n\n    TM_CHECK_EQ(42, 42) << \"Ok\";\n    TM_CHECK_NE(42, 24) << \"Ok\";\n    TM_CHECK_GE(50, 42) << \"Ok\";\n    TM_CHECK_GT(50, 42) << \"Ok\";\n    TM_CHECK_LE(42, 50) << \"Ok\";\n    TM_CHECK_LT(42, 50) << \"Ok\";\n\n    if (0) {\n        TM_CHECK(zero);\n        TM_CHECK_EQ(42, 43) << \"Not \"\n                            << \"Ok\";\n    }\n\n    int  x = 42;\n    auto p = TM_CHECK_NOTNULL(&x);\n    REQUIRE(p == &x);\n\n    if (0) {\n        int* y{};\n        TM_CHECK_NOTNULL(y);\n        TM_CHECK_NOTNULL(std::shared_ptr<void>{});\n    }\n\n    auto y = TM_CHECK_NOTNULL(std::make_shared<int>(42));\n    REQUIRE(*y == 42);\n\n    TM_CHECK(y);\n}\n\nTEST_CASE(\"test allocator\", \"[allocator]\")\n{\n\n    using core::Allocator;\n    using core::Stream;\n\n    Allocator a;\n    REQUIRE(!a);\n\n    Allocator b{kCPU};\n    REQUIRE(b);\n    REQUIRE(a != b);\n    REQUIRE(b->device() == kCPU);\n    Stream s{};\n    REQUIRE(!b->stream());\n\n    // std::vector<int> v(1 << 20);\n    // std::iota(v.begin(), v.end(), 0);\n\n    // auto p = (int*)b->allocate(sizeof(int) * v.size());\n    // std::iota(p, p + v.size(), 0);\n\n    // REQUIRE(v == std::vector(p, p + v.size()));\n}\n\nTEST_CASE(\"test context\", \"[context]\")\n{\n    using core::Context;\n    using core::ContextGuard;\n    using core::Stream;\n    using core::Allocator;\n\n    Stream s0 = Stream::create();\n\n    ContextGuard g0{s0, Allocator{kCPU}};\n\n    REQUIRE(Context::stream());\n    REQUIRE(Context::stream() == s0);\n\n    auto a0 = Context::host_alloc();\n\n    {\n        Allocator a1(Context::stream(), false);  // device allocator\n        REQUIRE(a1->device().type == kDEVICE);\n\n        ContextGuard g1{a1};\n\n        REQUIRE(Context::stream() == s0);\n        
REQUIRE(Context::device_alloc() == a1);\n        REQUIRE(Context::host_alloc() == a0);\n\n        {\n            ContextGuard g2{Stream::create(), Allocator(kDEVICE)};\n            REQUIRE(Context::device_alloc() != a1);\n            REQUIRE(Context::stream() != s0);\n        }\n\n        REQUIRE(Context::stream() == s0);\n        REQUIRE(Context::device_alloc() == a1);\n    }\n\n    REQUIRE(Context::stream() == s0);\n}\n\nTEST_CASE(\"test basic buffer\", \"[buffer]\")\n{\n    using core::Buffer;\n    using core::Buffer_;\n    using core::Allocator;\n\n    Buffer a;\n    REQUIRE(!a);\n\n    Buffer b;\n    REQUIRE(!b);\n    REQUIRE(a == b);\n\n    std::vector v{0, 1, 2, 3, 4, 5, 6, 7};\n\n    SECTION(\"reference into v\")\n    {\n        b = Buffer(v.data(), v.size(), kCPU);\n        REQUIRE(b.data<int>() == v.data());\n        REQUIRE(b.raw_data() == v.data());\n    }\n    SECTION(\"shared ownership\")\n    {\n        auto x = std::shared_ptr<int[]>(new int[v.size()]);\n        std::copy(v.begin(), v.end(), x.get());\n        b = Buffer(x, v.size(), data_type_v<int>, kCPU);\n        REQUIRE(b.data<int>() == x.get());\n        REQUIRE(b.raw_data() == x.get());\n    }\n    SECTION(\"allocation\")\n    {\n        Allocator alloc{kCPU};\n        b = Buffer(v.size(), data_type_v<int>, alloc);\n        std::copy(v.begin(), v.end(), b.data<int>());\n    }\n\n    REQUIRE(b);\n    REQUIRE(b.size() == v.size());\n    REQUIRE(b.dtype() == data_type_v<int>);\n    REQUIRE(b.byte_size() == sizeof(int) * v.size());\n    auto c = b;\n    REQUIRE(c == b);\n    REQUIRE(b == c);\n    REQUIRE(a != b);\n    REQUIRE(b != a);\n    REQUIRE(std::vector(b.data<int>(), b.data<int>() + b.size()) == v);\n\n    auto s = b.slice(3, 2);\n    REQUIRE(s.size() == 2);\n    REQUIRE(s.raw_data() == b.data<int>() + 3);\n\n    Buffer_<int> x;\n    Buffer_<int> y = Buffer{data_type_v<int>};\n\n    Buffer z = Buffer_<int>(1024, kCPU);\n\n    x = z;\n\n    for (int i = 0; i < z.size(); ++i) {\n        x[i] 
= i;\n    }\n\n    std::vector<int> ref(1024);\n    std::iota(ref.begin(), ref.end(), 0);\n    REQUIRE(std::vector(x.begin(), x.end()) == ref);\n\n    Buffer e;\n    REQUIRE(!e.data_or((void*)0));\n    REQUIRE(!e.data_or<int>(nullptr));\n\n    Buffer_<int> w;\n    REQUIRE(!w.data_or(nullptr));\n    REQUIRE(!std::as_const(w).data_or(nullptr));\n\n    w = {1024, kCPU};\n    REQUIRE(w.raw_data());\n    REQUIRE(std::as_const(w).raw_data());\n}\n\nTEST_CASE(\"test buffer view\", \"[buffer]\")\n{\n    using core::Buffer;\n\n    std::vector<int64_t> v{0, 1, 2, 3, 4, 5, 6, 7};\n\n    Buffer b(v.data(), v.size(), kCPU);\n\n    auto c = b.slice(2, 4);\n    REQUIRE(c.size() == 4);\n    REQUIRE(c.raw_data() == b.data<int64_t>() + 2);\n\n    std::cout << c << std::endl;\n\n    auto d = c.view<int>();\n\n    REQUIRE(d.size() == c.size() * 2);\n    REQUIRE(d.raw_data() == c.raw_data());\n}\n\nTEST_CASE(\"test layout\", \"[layout]\")\n{\n    using core::Layout;\n\n    Layout a;  // default ctor\n    REQUIRE(a.size() == 0);\n    REQUIRE(a.cosize() == 0);\n\n    Layout b({20, 50});\n    REQUIRE(b.size() == 1000);\n    REQUIRE(b.cosize() == b.size());\n    REQUIRE(to_string(b) == \"(20,50):(50,1)\");\n\n    Layout c = b.coalesce();\n    REQUIRE(c.size() == b.size());\n    REQUIRE(c.cosize() == b.cosize());\n    REQUIRE(to_string(c) == \"(1000):(1)\");\n\n    Layout v = b.view({50, 20});\n    REQUIRE(v.size() == b.size());\n    REQUIRE(v.cosize() == b.cosize());\n    REQUIRE(to_string(v) == \"(50,20):(20,1)\");\n\n    v = b.view({25, -1});\n    REQUIRE(to_string(v) == \"(25,40):(40,1)\");\n\n    v = b.view({5, -1, 5});\n    REQUIRE(to_string(v) == \"(5,40,5):(200,5,1)\");\n\n    v = b.view({-1, 20, 10, 1});\n    REQUIRE(to_string(v) == \"(5,20,10,1):(200,10,1,1)\");\n\n    REQUIRE(to_string(v.coalesce()) == \"(1000):(1)\");\n\n    auto [s, offset] = b.slice({10, 20}, {-1, -1});\n    REQUIRE(to_string(s) == \"(10,30):(50,1)\");\n    REQUIRE(offset == 520);\n\n    v = s.view({2, -1, 3, 
10});\n    std::cout << v << std::endl;\n\n    std::cout << v.coalesce() << std::endl;\n\n    // v = s.view({30, 10});\n    // std::cout << v << std::endl;\n}\n\nTEST_CASE(\"test tensor\", \"[tensor]\")\n{\n    using core::Tensor;\n    using core::Tensor_;\n    using core::Allocator;\n\n    Tensor a;\n    REQUIRE(!a);\n\n    Tensor_<float> b{{10, 20}, kCPU};\n    Tensor_<float> c = b.slice(0, 5);\n\n    std::cout << b << std::endl;\n\n    REQUIRE(c.shape() == std::vector<ssize_t>{5, 20});\n    REQUIRE(c.data() == b.data());\n\n    auto d = b.view({2, -1, 10});\n    REQUIRE(d.shape() == std::vector<ssize_t>{2, 10, 10});\n\n    // this is typed\n    Tensor_<float> x = Tensor_<float>{};\n    // while being empty\n    REQUIRE(!x);\n\n    if (0) {\n        // empty Tensor has invalid type\n        Tensor_<float> x = Tensor{};\n    }\n    a = {};\n    x = {};\n\n    Tensor y = core::Buffer{100, kInt32, kCPU};\n    REQUIRE(y.ndim() == 1);\n    REQUIRE(y.shape(0) == 100);\n}\n"
  },
  {
    "path": "src/turbomind/engine/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\ncmake_minimum_required(VERSION 3.11)\n\nadd_library(engine STATIC\n    gateway.cc\n    request.cc\n    request_queue.cc\n    model_request.cc\n    model_executor.cc\n    engine.cc\n    )\ntarget_link_libraries(engine PRIVATE xgrammar core)\nset_property(TARGET engine PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET engine PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n"
  },
  {
    "path": "src/turbomind/engine/batch.h",
    "content": "\n#pragma once\n\n#include <future>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/engine/request.h\"\n\nnamespace turbomind {\n\nenum class BatchOp\n{\n    kAdd,      //  Se ->  Rc         H\n    kSetup,    //  Rc -> (B  -> D)   H2D\n    kPrepare,  // (D  ->  St)        D\n    kForward,  //  St ->  St         D\n    kUnprep,   // (St ->  D)         D\n    kFetch,    // (D  ->  B)         D2H\n    kUpdate,   //  B  ->  Rc         H\n    kDel,      //  Rc ->  Se         H\n};\n\n// Se -> Rc -> (B -> D) -> St -> (D -> B) -> Rc -> Se\n\n/*\nSe -> Rc                   (add: rc)\n    Rc -> B\n        (B -> D)           (setup: rc, d, copy)\n            (D -> St)\n                St -> St   (forward)\n            (St -> D)\n        (D -> B)\n    B -> Rc                (sync)\nRc -> Se                   (del: rc)\n*/\n\nstruct BatchData {\n\n    explicit BatchData(int phase): self{this}, phase{phase}\n    {\n        ready = Event::create();\n        done  = Event::create();\n        next  = Event::create();\n    }\n\n    BatchData(const BatchData&)     = delete;\n    BatchData(BatchData&&) noexcept = delete;\n    BatchData& operator=(const BatchData&) = delete;\n    BatchData& operator=(BatchData&&) noexcept = delete;\n\n    BatchData* self;\n\n    const int phase;\n\n    int bs0 = 0;\n    int bsz = 0;\n\n    Buffer_<int> perm;\n\n    std::vector<std::shared_ptr<RequestCache>> rc;\n\n    std::vector<int> local_token_num;\n    int              global_token_num = 0;\n\n    Event ready;\n    Event done;\n    Event next;\n\n    std::promise<Event> promise;\n\n    Buffer buf()\n    {\n        return Buffer{&self, 1, kCPU};\n    }\n\n    void Notify()\n    {\n        next.Record(core::Context::stream());\n        promise.set_value(next);\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/engine.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <atomic>\n#include <chrono>\n#include <memory>\n#include <thread>\n\n#include \"nvtx3/nvToolsExt.h\"\n\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/engine/engine.h\"\n#include \"src/turbomind/engine/model_executor.h\"\n#include \"src/turbomind/engine/request.h\"\n\n#include \"src/turbomind/core/copy.h\"\n#include \"src/turbomind/models/language_model.h\"\n#include \"src/turbomind/models/llama/SequenceManager.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n#include \"src/turbomind/utils/logger.h\"\n#include \"src/turbomind/utils/metrics.h\"\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nusing std::shared_ptr;\nusing std::unique_ptr;\nusing std::vector;\n\nstruct RequestData {\n    vector<shared_ptr<Request>> infer;  // incoming inference request\n    vector<shared_ptr<Request>> kill;   // incoming kill request\n\n    vector<int> cancel;  // canceled indices in current batch\n    bool        abort;\n};\n\ntemplate<class Archive>\nvoid serdes(Archive& ar, RequestData& r)\n{\n    ar& r.infer;\n    ar& r.kill;\n    ar& r.cancel;\n    ar& r.abort;\n}\n\nstruct Engine::Impl {\n\n    using Requests = vector<shared_ptr<Request>>;\n    using Signal   = std::function<void()>;\n\n    Impl(DataType      dtype,\n         EngineParam   param,\n         LanguageModel model,\n         Context&      ctx,\n         Gateway&      gateway,\n         int           device_id,\n         int           queue_id,\n         int           phases);\n\n    void CreateSequenceManager();\n\n    void InternalThreadEntry();\n\n    void Validate(Requests& infer_rs, Requests& kill_rs);\n\n    void Kill(const Requests& rs, vector<Signal>& signals);\n\n    vector<int> GetCanceled();\n\n    void Cancel(vector<int>& indices, 
vector<Signal>& signals);\n\n    void Accept(const Requests& rs, vector<Signal>& signals);\n\n    void Interrupt(RequestCache& c);\n\n    // Allocation of memory / compute resources\n    void Schedule();\n\n    // initialize RC from `Sequence`\n    void Setup(BatchData& d);\n\n    // Sync vars from batch output to RC\n    void Update(BatchData& d, std::vector<Signal>& signals);\n\n    void Run(BatchOp op, int phase, Ref<TensorMap> env)\n    {\n        model_.Run(op, phase, env);\n    }\n\n    void Start()\n    {\n        internal_thread_ = std::thread(&Impl::InternalThreadEntry, this);\n        executor_.Start();\n    }\n\n    void UpdateScheduleMetrics();\n\n    ~Impl();\n\n    const DataType    dtype_;\n    const EngineParam param_;\n\n    Gateway& gateway_;\n\n    comm::HostComm& tp_group_;\n    comm::HostComm& dp_group_;\n\n    const int tp_rank_;\n    const int dp_rank_;\n    const int dp_size_;\n\n    const int device_id_;\n    const int queue_id_;\n\n    const int async_;\n\n    int& is_warm_up_;\n\n    unique_ptr<SequenceManager> seq_mgr_;\n\n    Queue<unique_ptr<BatchData>> inbound_;\n    Queue<unique_ptr<BatchData>> outbound_;\n\n    LanguageModel model_;\n    ModelExecutor executor_;\n\n    std::thread internal_thread_;\n\n    int session_len_trunc_;\n\n    shared_ptr<ScheduleMetrics> metrics_;\n\n    struct State {\n        vector<shared_ptr<RequestCache>> rc;\n        vector<int>                      perm;\n\n        int bs0     = 0;\n        int active  = 0;\n        int finish  = 0;\n        int swapout = 0;\n\n        int size() const noexcept\n        {\n            return rc.size();\n        }\n    };\n\n    vector<State> states_;\n\n    struct Data {\n    };\n    vector<Data> data_;\n\n    // staging buffers\n    Buffer_<void*> block_ptrs_buf_;\n    Buffer_<int>   block_ptrs_offsets_buf_;\n};\n\nEngine::Impl::~Impl()\n{\n    TM_LOG_INFO(__PRETTY_FUNCTION__);\n    inbound_.close();\n    outbound_.close();\n    if (internal_thread_.joinable()) {\n  
      internal_thread_.join();\n    }\n    executor_ = {};\n}\n\nEngine::Impl::Impl(DataType      dtype,\n                   EngineParam   param,\n                   LanguageModel model,\n                   Context&      ctx,\n                   Gateway&      gateway,\n                   int           device_id,\n                   int           queue_id,\n                   int           phases):\n    dtype_{dtype},\n    param_{param},\n    gateway_{gateway},\n    tp_group_{ctx.comm.h_tp_group},\n    dp_group_{ctx.comm.h_dp_group},\n    tp_rank_{tp_group_->rank()},\n    dp_rank_{dp_group_->rank()},\n    dp_size_{dp_group_->n_ranks()},\n    device_id_{device_id},\n    queue_id_{queue_id},\n    async_{phases > 1},\n    is_warm_up_{*ctx.is_warm_up},\n    model_{std::move(model)}\n{\n    states_.emplace_back();\n\n    for (int i = 0; i < phases; ++i) {\n        data_.emplace_back();\n    }\n\n    executor_ = ModelExecutor{model_, ctx, device_id_, outbound_, inbound_};\n\n    CreateSequenceManager();  // initializes `session_len_trunc_`\n\n    const ssize_t max_batch_block_num =\n        param.max_batch_size * cdiv(session_len_trunc_, model_.attn_param().cache_block_seq_len);\n    block_ptrs_buf_         = {max_batch_block_num, kCPUpinned};\n    block_ptrs_offsets_buf_ = {param.max_batch_size + 1, kCPUpinned};\n}\n\nvoid Engine::Impl::CreateSequenceManager()\n{\n    const auto cache_block_seq_len = model_.attn_param().cache_block_seq_len;\n\n    const auto& model_param = model_.model_param();\n\n    const auto get_free_size = [&] {  //\n        size_t free{}, total{};\n        check_cuda_error(cudaMemGetInfo(&free, &total));\n        return AllReduce(tp_group_, free, comm::RedOp::kMin);\n    };\n\n    seq_mgr_ = std::make_unique<SequenceManager>(model_param,\n                                                 dtype_,\n                                                 cache_block_seq_len,\n                                                 param_.attn_tp_size,\n              
                                   param_.max_batch_size,\n                                                 param_.cache_max_block_count,\n                                                 param_.cache_chunk_size,\n                                                 param_.enable_prefix_caching,\n                                                 tp_rank_,\n                                                 param_.attn_cp_size,\n                                                 core::Context::alloc(kDEVICE),\n                                                 get_free_size);\n\n    const auto max_cached_tokens = seq_mgr_->max_block_count() * (size_t)cache_block_seq_len * param_.attn_cp_size;\n    session_len_trunc_           = std::min(max_cached_tokens, (size_t)param_.session_len);\n    TM_LOG_INFO(\"max cached tokens: %zu\", max_cached_tokens);\n    if (session_len_trunc_ != param_.session_len) {\n        TM_LOG_WARNING(\"`session_len` truncated to %d due to limited KV cache memory\", session_len_trunc_);\n    }\n}\n\nvoid Engine::Impl::Validate(Requests& infer_reqs, Requests& kill_reqs)\n{\n    std::pmr::monotonic_buffer_resource    mbr;\n    std::pmr::unordered_map<uint64_t, int> occur(&mbr);\n\n    const bool has_linear_attention = HasLinearAttention(model_.model_param());\n\n    auto count = [&occur](const auto& reqs) {\n        for (const auto& r : reqs) {\n            ++occur[r->id];\n        }\n    };\n\n    auto validate = [&](auto& reqs, const char* type, bool is_infer) {\n        for (const auto& r : reqs) {\n            if (occur[r->id] > 1) {\n                TM_LOG_ERROR(\"Skip conflicting %s request for ID %lu\", type, r->id);\n                r->ec = Request::kConflict;\n            }\n            if (!r->ec && is_infer && has_linear_attention && !r->session.end_flag) {\n                TM_LOG_ERROR(\"Skip inconsistent %s request for ID %lu. 
Linear attention only supports stateless \"\n                             \"requests\",\n                             type,\n                             r->id);\n                r->ec = Request::kInconsistency;\n            }\n            if (param_.enable_prefix_caching) {\n                if (r->session.step != 0) {\n                    // Prefix caching is incompatible with interactive mode\n                    TM_LOG_ERROR(\"Skip inconsistent %s request for ID %lu step %d\", type, r->id, r->session.step);\n                    r->ec = Request::kInconsistency;\n                }\n                else if (r->gen_cfg.output_logits == GenerationConfig::kAll\n                         || r->gen_cfg.output_last_hidden_state == GenerationConfig::kAll) {\n                    // Prefix caching is incompatible with outputting all tokens' logits or last_hidden_state\n                    TM_LOG_ERROR(\"Skip inconsistent %s request for ID %lu. It cannot output logits or \"\n                                 \"last_hidden_states for all tokens\",\n                                 type,\n                                 r->id);\n                    r->ec = Request::kInconsistency;\n                }\n            }\n        }\n    };\n\n    for (const auto& s : states_) {\n        for (int i = 0; i < s.size(); ++i) {\n            if (s.rc[i]) {\n                ++occur[s.rc[i]->req->id];\n            }\n        }\n    }\n\n    count(kill_reqs);\n    count(infer_reqs);\n\n    validate(kill_reqs, \"kill\", false);\n    validate(infer_reqs, \"infer\", true);\n\n    // New requests that never get a chance to start\n    for (auto& r : infer_reqs) {\n        if (r && r->cancel_flag.load(std::memory_order_acquire) == -1) {\n            r->ec = Request::kCancel;\n        }\n    }\n}\n\nvector<int> Engine::Impl::GetCanceled()\n{\n    auto& s = states_.at(0);\n\n    vector<int> idxs;\n    for (int i = 0; i < s.size(); ++i) {  // current batch\n        const auto& r = s.rc[i];\n        if 
(r && r->req->cancel_flag.load(std::memory_order_acquire) == -1) {\n            idxs.push_back(i);\n        }\n    }\n    return idxs;\n}\n\nvoid Engine::Impl::Kill(const Requests& kills, vector<Signal>& signals)\n{\n    for (auto& r : kills) {\n        if (r) {\n            int ec = r->ec;\n            if (!ec) {\n                if (!seq_mgr_->Erase(r->id)) {\n                    ec = Request::kInvalid;\n                }\n            }\n            signals.push_back([=] { r->end_cb ? r->end_cb(ec) : void(); });\n        }\n    }\n}\n\nvoid Engine::Impl::Interrupt(RequestCache& c)\n{\n    auto& s = *TM_CHECK_NOTNULL(c.seq);\n    if (c.req->session.end_flag) {\n        if (!is_warm_up_ && s.status != Sequence::kCached) {  // At least `Locked` status is required for caching\n            seq_mgr_->CacheGeneration(s);\n        }\n        TM_CHECK(seq_mgr_->Erase(c.req->id));\n    }\n    else {\n        if (s.recurrent_states && c.seq_len != s.cache_len) {\n            TM_LOG_WARNING(\n                \"[Engine][Interrupt] Invalidating cache for ID %llu due to linear-state/cache mismatch (%d vs %d)\",\n                s.id,\n                c.seq_len,\n                s.cache_len);\n            seq_mgr_->InvalidateStatesAndCache(s);\n        }\n        else {\n            seq_mgr_->UpdateAndSetUnlock(s);\n        }\n    }\n    c.seq = nullptr;\n}\n\nvoid Engine::Impl::Cancel(vector<int>& indices, vector<Signal>& signals)\n{\n    auto& s = states_.at(0);\n    for (const auto& i : indices) {\n        auto& c = TM_CHECK_NOTNULL(s.rc[i]);\n        c->done = true;\n        Interrupt(*c);\n        signals.push_back([r = std::move(c->req), l = c->seq_len] {  //\n            UpdateState(*r, Request::kCancel, l);\n        });\n        c.reset();\n        s.finish += 1;\n    }\n}\n\nvoid Engine::Impl::Accept(const Requests& rs, vector<Signal>& signals)\n{\n    auto& s = states_.at(0);\n\n    vector<unique_ptr<RequestCache>> incoming;\n    incoming.reserve(rs.size());\n\n    for 
(const auto& r : rs) {\n\n        if (r->ec) {\n            signals.push_back([r] { UpdateState(*r, r->ec, 0); });\n            continue;\n        }\n\n        const int input_len = r->inputs.at(\"input_ids\").shape(0);\n\n        if (input_len > session_len_trunc_) {\n            signals.push_back([r] { UpdateState(*r, Request::kTooLong, 0); });\n            continue;\n        }\n\n        auto ptr = r->session.start_flag ? seq_mgr_->Create(r->id) : seq_mgr_->Get(r->id);\n        if (!ptr) {\n            signals.push_back([r] { UpdateState(*r, Request::kInvalid, 0); });\n            continue;\n        }\n\n        const int step = [&] {\n            int s = r->session.step;\n            if (s < 0) {\n                s = ptr->tokens.size();\n            }\n            else if (s > ptr->tokens.size()) {\n                if (tp_rank_ == 0) {\n                    TM_LOG_WARNING(\"[ProcessInferRequests] Skipping invalid step (%d) setting for ID %lu\", s, ptr->id);\n                }\n                s = ptr->tokens.size();\n            }\n            return s;\n        }();\n\n        if (step + input_len > session_len_trunc_) {\n            signals.push_back([r] { UpdateState(*r, Request::kTooLong, 0); });\n            continue;\n        }\n\n        if (step && param_.enable_prefix_caching) {\n            // step not supported in prefix-caching mode\n            signals.push_back([r] { UpdateState(*r, Request::kInconsistency, 0); });\n            continue;\n        }\n\n        auto& seq = *ptr;\n        seq_mgr_->AcquireLinearStateSlot(seq);\n\n        if (seq.recurrent_states) {\n            if (step != seq.cache_len) {\n                signals.push_back([r] { UpdateState(*r, Request::kInvalid, 0); });\n                continue;\n            }\n        }\n\n        auto c = std::make_unique<RequestCache>(r, seq);\n\n        if (step < seq.tokens.size()) {\n            seq.tokens.resize(step);\n            seq.cache_len = std::min(seq.cache_len, step);\n        
}\n\n        c->step0 = step;\n\n        // const int* input_ids = r->inputs.at(\"input_ids\").data<int>();\n        auto& input_ids = r->inputs.at(\"input_ids\");\n\n        int* token_ids = c->token_ids = r->output_ids.data();\n\n        /// TODO: move this somewhere else\n        token_ids = std::copy_n(seq.tokens.data(), seq.tokens.size(), token_ids);\n        token_ids = std::copy_n(input_ids.data<int>(), input_len, token_ids);\n\n        c->prompt_len = c->seq_len = token_ids - c->token_ids;  // all known tokens\n\n        // Only prefix cache needs prompt data\n        if (param_.enable_prefix_caching && input_len && r->session.start_flag) {\n            seq.prompt.insert(seq.prompt.end(), input_ids.data<int>(), input_ids.data<int>() + input_len);\n        }\n\n        // dbg(seq.cache_len, seq.tokens.size(), input_len, c->seq_len);\n\n        int max_seq_len = c->prompt_len + c->gen_cfg.max_new_tokens;\n        if (max_seq_len > session_len_trunc_) {\n            max_seq_len = session_len_trunc_;\n            if (tp_rank_ == 0) {\n                const int trunc_output_len = max_seq_len - c->prompt_len;\n                // clang-format off\n                TM_LOG_WARNING(\"[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `max_new_tokens` is truncated to %d\",\n                    (long)seq.id, c->prompt_len, c->gen_cfg.max_new_tokens, session_len_trunc_, trunc_output_len);\n                // clang-format on\n            }\n        }\n        c->max_seq_len = max_seq_len;\n\n        incoming.push_back(std::move(c));\n    }\n\n    Buffer_<RequestCache*> buf(incoming.size(), kCPU);\n    for (int i = 0; i < incoming.size(); ++i) {\n        buf[i] = incoming[i].get();\n    }\n\n    // This includes checks from all modules handling `Add` operation\n    Run(BatchOp::kAdd, -1, TensorMap{{\"requests\", buf}});\n\n    for (auto& x : incoming) {\n        if (x->status == 0) {\n            s.rc.push_back(std::move(x));\n        
}\n        else {\n            Interrupt(*x);\n            signals.push_back([r = x->req, ec = x->status] {  //\n                UpdateState(*r, ec, 0);\n            });\n        }\n    }\n}\n\nvoid Engine::Impl::Schedule()\n{\n    auto& s = states_.at(0);\n\n    vector<const Sequence*>  sequences;\n    vector<Sequence::Status> status;\n    vector<int>              context_length;\n    vector<int>              alpha;\n    vector<uint64_t>         priorities;\n    vector<RequestCache*>    cache;\n    vector<int>              inv;\n\n    for (int i = 0; i < s.size(); ++i) {\n        // skip invalid positions\n        if (const auto& c = s.rc[i]) {\n            cache.push_back(c.get());\n            sequences.push_back(c->seq);\n            status.push_back(c->seq->status);\n            priorities.push_back(c->req->unique_id);\n            context_length.push_back(c->seq_len + c->beta /* plus draft tokens */);\n            alpha.push_back(c->alpha);\n            TM_CHECK(c->seq->status == Sequence::kActive || c->alpha == 0) << c->seq->status << \" \" << c->alpha;\n            inv.push_back(i);\n            c->input_len = c->history_len = 0;\n            // dbg(c->request->id, c->seq_len, c->sequence.cache_len, c->alpha, c->beta, c->is_decoding,\n            // c->is_generate);\n        }\n    }\n\n    // dbg(\"Schedule\");\n\n    seq_mgr_->Materialize(\n        sequences, context_length, alpha, priorities, param_.max_forward_token_num, param_.max_context_token_num);\n\n    vector<int> idxs(sequences.size());\n    std::iota(idxs.begin(), idxs.end(), 0);\n\n    subrange active{idxs.begin(), std::stable_partition(idxs.begin(), idxs.end(), [&](int i) {\n                        return sequences[i]->status == Sequence::kActive;  // IS active\n                    })};\n\n    TM_CHECK(sequences.empty() || !active.empty()) << \"Not enough blocks\";\n\n    if (is_warm_up_) {\n        // Avoid extra iteration for warm up request in async mode (force inactivate)\n        active = 
{active.begin(), std::stable_partition(active.begin(), active.end(), [&](int i) {  //\n                      return alpha[i] == 0;\n                  })};\n    }\n\n    subrange inactive{active.end(), idxs.end()};\n\n    subrange existing{active.begin(), std::stable_partition(active.begin(), active.end(), [&](int i) {\n                          return status[i] == Sequence::kActive;  // WAS active in active\n                      })};\n\n    subrange swap_in{existing.end(), active.end()};\n\n    subrange swap_out{inactive.begin(), std::stable_partition(inactive.begin(), inactive.end(), [&](int i) {\n                          return status[i] == Sequence::kActive;  // WAS active in inactive\n                      })};\n\n    // |<-- existing -->|<-- swap-in -->|<- swap-out ->|\n    // |<----------- active ----------->|<------- inactive ----->|\n\n    for (auto i : swap_in) {\n        cache[i]->autoregres = {};\n        cache[i]->generating = {};\n    }\n\n    if (param_.enable_metrics) {\n        for (auto i : swap_in) {\n            if (auto& m = cache[i]->req->metrics; TM_LIKELY(m)) {\n                int64_t expected = 0;\n                m->scheduled_time.compare_exchange_strong(\n                    expected, RequestMetrics::timestamp(), std::memory_order_relaxed);\n            }\n        }\n    }\n\n    for (auto i : existing) {\n        if (cache[i]->generating) {\n            cache[i]->autoregres = true;\n        }\n    }\n\n    for (auto i : active) {\n        auto& s = *sequences[i];\n        auto& c = *cache[i];\n        if (s.cache_len + c.alpha + s.input_length == c.seq_len + c.beta) {\n            c.generating = true;\n        }\n    }\n\n    // move partially prefilled sequences to the back\n    subrange partial{std::stable_partition(active.begin(), active.end(), [&](int i) { return cache[i]->generating; }),\n                     active.end()};\n    TM_CHECK_LE(partial.size(), 1);\n\n    // dbg(inv);\n\n    vector<shared_ptr<RequestCache>> 
rc(idxs.size());\n    vector<int>                      perm(idxs.size());\n    for (int i = 0; i < idxs.size(); ++i) {\n        perm[i] = inv[idxs[i]];              // inverse map to original indices\n        rc[i]   = std::move(s.rc[perm[i]]);  // permute the request cache\n    }\n    s.rc.swap(rc);\n    s.perm.swap(perm);\n\n    for (auto& c : s.rc) {\n        /// ! input_length not updated for inactive seqs\n        c->input_len   = c->seq->input_length;\n        c->history_len = c->seq->cache_len;\n        // dbg(c->request->id,\n        //     c->seq_len,\n        //     c->history_len,\n        //     c->input_len,\n        //     c->alpha,\n        //     c->beta,\n        //     c->is_decoding,\n        //     c->is_generate);\n    }\n\n    s.bs0     = std::exchange(s.active, active.size());\n    s.swapout = swap_out.size();\n    s.finish  = 0;\n}\n\nvoid Engine::Impl::Setup(BatchData& d)\n{\n    auto& st = states_.at(0);\n\n    d.rc.resize(st.active);\n    std::copy_n(st.rc.begin(), st.active, d.rc.begin());\n\n    block_ptrs_offsets_buf_[0] = 0;\n    auto block_ptrs            = block_ptrs_buf_.data();\n    for (int i = 0; i < st.active; ++i) {\n        const auto& s                  = *st.rc[i]->seq;\n        block_ptrs_offsets_buf_[i + 1] = block_ptrs_offsets_buf_[i] + s.blocks.size();\n        block_ptrs = std::transform(s.blocks.cbegin(), s.blocks.cend(), block_ptrs, [&](int block_id) {\n            return seq_mgr_->GetBlockPtr(block_id);\n        });\n    }\n\n    d.bs0 = st.bs0;\n    d.bsz = st.active;\n\n    d.perm = {d.bsz, kCPU};\n    std::copy_n(st.perm.data(), d.bsz, d.perm.data());\n\n    // dbg(d.bs0, d.bsz, d.perm);\n\n    BatchCopy copy{};\n\n    TensorMap env{{\"batch\", d.buf()},\n                  {\"copy\", copy.buf()},\n                  {\"block_ptrs_offsets\", block_ptrs_offsets_buf_},\n                  {\"block_ptrs\", block_ptrs_buf_}};\n\n    Run(BatchOp::kSetup, d.phase, env);\n\n    // dbg(copy);\n    copy.Run();\n\n    
d.local_token_num.resize(dp_size_);\n    d.local_token_num[dp_rank_] = *env.at(\"token_num\").data<int>();\n    if (dp_size_ > 1) {\n        AllGather(dp_group_, d.local_token_num.data(), 1);\n    }\n    d.global_token_num = std::accumulate(d.local_token_num.begin(), d.local_token_num.end(), 0);\n    // dbg(dp_group_->rank(), d.local_token_num, d.global_token_num);\n}\n\nvoid Engine::Impl::Update(BatchData& b, std::vector<Signal>& signals)\n{\n    auto& s = states_.at(0);\n\n    BatchCopy copy;\n\n    TensorMap env{{\"batch\", b.buf()}, {\"copy\", copy.buf()}};\n\n    // Copy outputs to host buffers\n    Run(BatchOp::kFetch, b.phase, env);\n\n    copy.Run();\n\n    core::Context::stream().Sync();\n\n    //\n    Run(BatchOp::kUpdate, b.phase, env);\n\n    Buffer_<bool> finished        = env.at(\"finished\").buffer();\n    Buffer_<bool> generating      = env.at(\"generating\").buffer();\n    Buffer_<int>  output_ids      = env.at(\"output_ids\").buffer();\n    Buffer_<int>  sequence_length = env.at(\"sequence_length\").buffer();\n\n    env = {};\n\n    vector<const Sequence*> sequences_to_cache;\n\n    for (int i = 0; i < b.rc.size(); ++i) {\n        // In async mode, `seq` may be nullptr when the request is done\n        if (auto& c = *b.rc[i]; c.seq) {\n            if (auto& s = *c.seq; generating[i]) {\n                c.token_ids[c.seq_len] = output_ids[i];\n                c.seq_len              = sequence_length[i];\n                s.cache_len            = sequence_length[i] - 1;\n                if (const int new_tokens = c.seq_len - s.tokens.size()) {\n                    s.tokens.insert(s.tokens.end(), c.token_ids + c.seq_len - new_tokens, c.token_ids + c.seq_len);\n                }\n                if (TM_UNLIKELY(finished[i])) {\n                    signals.push_back([r = c.req, l = c.seq_len] {  //\n                        UpdateState(*r, Request::kFinish, l);\n                    });\n                }\n                else if (c.req->stream_output) 
{\n                    signals.push_back([r = c.req, l = c.seq_len] {  //\n                        UpdateState(*r, Request::kOk, l);\n                    });\n                }\n            }\n            else {\n                s.cache_len = sequence_length[i];\n            }\n            c.done |= finished[i];\n            if (c.seq->status != Sequence::kCached) {  // At least `Locked` status is required for caching\n                sequences_to_cache.push_back(c.seq);\n            }\n            // dbg(c.seq_len, c.sequence.cache_len, c.alpha, c.beta, c.is_decoding, c.is_generate);\n        }\n    }\n\n    if (!is_warm_up_) {\n        seq_mgr_->CachePrompt(sequences_to_cache, sequences_to_cache.size());\n    }\n\n    b.rc.clear();\n\n    if (async_) {\n        const int size = s.active + s.swapout;\n        for (int i = 0; i < size; ++i) {\n            auto& c = *s.rc[i];\n            if (i < s.active) {\n                c.alpha = c.input_len;\n                c.beta  = c.generating;\n            }\n            else {\n                // Just got swapped out\n                c.alpha = c.beta = 0;\n            }\n        }\n    }\n\n    for (auto& x : s.rc) {\n        if (TM_UNLIKELY(x->done)) {\n            Interrupt(*x);\n            x.reset();\n            s.finish += 1;\n        }\n    }\n}\n\nvoid Engine::Impl::InternalThreadEntry()\n{\n    check_cuda_error(cudaSetDevice(device_id_));\n\n    auto stream = Stream::create();\n\n    core::ContextGuard ctx{stream, Allocator(kCPU), Allocator(stream, false)};\n\n    unique_ptr<BatchData> d = std::make_unique<BatchData>(0);\n\n    for (unsigned i = 1; i < data_.size(); ++i) {\n        inbound_.push(std::make_unique<BatchData>(i));\n    }\n\n    while (true) {\n\n        shared_ptr<RequestData> rs;\n\n        auto& st = states_.at(0);\n\n        if (tp_rank_ == 0) {\n            rs = std::make_shared<RequestData>();\n\n            const int  n_free   = param_.max_batch_size - st.size() + st.finish;\n            const 
bool blocking = n_free == param_.max_batch_size;\n\n            gateway_.pop(rs->infer, rs->kill, n_free, blocking, rs->abort, dp_group_, queue_id_);\n\n            Validate(rs->infer, rs->kill);\n\n            rs->cancel = GetCanceled();\n        }\n\n        if (st.size() - st.finish == 0 && tp_group_->is_same_process()) {\n            // Only thread comm has blocking sync\n            tp_group_->Sync(true);\n        }\n\n        if (tp_group_->n_ranks() > 1) {\n            Broadcast(tp_group_, rs, 0);\n        }\n\n        if (rs->abort) {\n            TM_LOG_INFO(\"[Engine] stop requested.\");\n            break;\n        }\n\n        vector<Signal> signals;\n\n        Kill(rs->kill, signals);\n\n        Accept(rs->infer, signals);\n\n        Cancel(rs->cancel, signals);\n\n        gateway_.notify(std::move(signals), tp_rank_ == 0);\n\n        int n_active = st.size() - st.finish;\n\n        TM_CHECK_GE(n_active, 0);\n\n        n_active = AllReduce(dp_group_, n_active, comm::RedOp::kSum);\n\n        if (n_active) {\n\n            Schedule();\n\n            UpdateScheduleMetrics();\n\n            Setup(*d);\n\n            d->ready.Record(core::Context::stream());\n\n            // auto future = (d->promise = {}).get_future();\n\n            outbound_.push(std::move(d));\n\n            if (!inbound_.pop(d)) {\n                break;\n            }\n\n            // Must assume `d` is not the same one as above\n            TM_CHECK_NOTNULL(d);\n\n            core::Context::stream().Wait(d->done);\n\n            Update(*d, signals);\n\n            gateway_.notify(std::move(signals), tp_rank_ == 0);\n\n            // if (future.valid()) {\n            //     future.get().Sync();\n            // }\n        }\n\n        // dbg(\"=========================================================================\");\n    }\n}\n\nEngine::~Engine() = default;\n\nEngine::Engine()                  = default;\nEngine::Engine(Engine&&) noexcept = default;\nEngine& 
Engine::operator=(Engine&&) noexcept = default;\n\nEngine::Engine(DataType      dtype,\n               EngineParam   param,\n               LanguageModel model,\n               Context&      ctx,\n               Gateway&      gateway,\n               int           device_id,\n               int           dp_rank,\n               int           phases):\n    impl_{std::make_unique<Impl>(dtype, param, std::move(model), ctx, gateway, device_id, dp_rank, phases)}\n{\n}\n\nvoid Engine::Start()\n{\n    return impl_->Start();\n}\n\nvoid Engine::Impl::UpdateScheduleMetrics()\n{\n    if (param_.enable_metrics) {\n        const auto& [total, active, cached] = seq_mgr_->seq_stats();\n\n        auto m = std::make_shared<ScheduleMetrics>();\n\n        m->total_seqs   = total;\n        m->active_seqs  = active;\n        m->waiting_seqs = total - active;\n\n        m->total_blocks  = seq_mgr_->total_count();\n        m->active_blocks = seq_mgr_->active_count();\n        m->cached_blocks = seq_mgr_->cached_count();\n        m->free_blocks   = seq_mgr_->free_count();\n\n        std::atomic_store_explicit(&metrics_, std::move(m), std::memory_order_release);\n    }\n}\n\nshared_ptr<ScheduleMetrics> Engine::GetScheduleMetrics()\n{\n    if (impl_->param_.enable_metrics) {\n        return std::atomic_load_explicit(&impl_->metrics_, std::memory_order_acquire);\n    }\n    return {};\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/engine.h",
    "content": "\n#pragma once\n\n#include <memory>\n\n#include \"src/turbomind/engine/gateway.h\"\n\n#include \"src/turbomind/models/language_model.h\"\n#include \"src/turbomind/models/llama/context.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nstruct ScheduleMetrics;\n\nclass Engine {\npublic:\n    ~Engine();\n\n    Engine();\n    Engine(Engine&&) noexcept;\n    Engine& operator=(Engine&&) noexcept;\n\n    explicit operator bool() const noexcept\n    {\n        return static_cast<bool>(impl_);\n    }\n\n    Engine(DataType      dtype,\n           EngineParam   param,\n           LanguageModel model,\n           Context&      ctx,\n           Gateway&      gateway,\n           int           device_id,\n           int           queue_id,\n           int           phases);\n\n    void Start();\n\n    std::shared_ptr<ScheduleMetrics> GetScheduleMetrics();\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/gateway.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <memory>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/engine/gateway.h\"\n#include \"src/turbomind/engine/request_queue.h\"\n\nnamespace turbomind {\n\nGateway::Gateway(int size, std::function<std::shared_ptr<void>()> ctx_factory):\n    size_{size}, dp_thr_{1}, queues_(size_), ctx_factory_{ctx_factory}, next_{0}\n{\n    for (int i = 0; i < size_; ++i) {\n        queues_[i] = std::make_unique<RequestQueue>();\n    }\n\n    signal_thread_ = std::thread(&Gateway::signal_thread_entry, this);\n}\n\nvoid Gateway::shutdown()\n{\n    for (auto& q : queues_) {\n        q->close();\n    }\n\n    signal_buffer_.close();\n    signal_thread_.join();\n}\n\nvoid Gateway::push(std::shared_ptr<Request> r)\n{\n    int rank = -1;\n\n    if (TM_UNLIKELY(!r->session.start_flag)) {\n        // route to corresponding rank\n        rank = binding_.find(r->session.id);\n    }\n    else if (TM_LIKELY(size_)) {\n        rank = next_.fetch_add(1, std::memory_order_relaxed) % size_;\n    }\n    else {\n        TM_LOG_ERROR(\"[Gateway] No queues available for submitting the request\");\n        notify({[r = std::move(r)] { UpdateState(*r, Request::kNoQueue, 0); }});\n        return;\n    }\n\n    if (TM_LIKELY(rank >= 0)) {\n        queues_[rank]->push({std::move(r)});\n    }\n    else {\n        TM_LOG_ERROR(\"[Gateway] Failed to find a bound queue for %lu\", r->session.id);\n        notify({[r = std::move(r)] { UpdateState(*r, Request::kInvalid, 0); }});\n    }\n}\n\nvoid Gateway::pop(std::vector<std::shared_ptr<Request>>& infer_reqs,\n                  std::vector<std::shared_ptr<Request>>& kill_reqs,\n                  unsigned                               max_infer,\n                  bool                                   blocking,\n                  bool&                                  abort,\n                  comm::HostComm&                        dp_group,\n                  int    
                                qid)\n{\n    TM_CHECK_GE(qid, 0);\n\n    auto& q = *queues_.at(qid);\n\n    infer_reqs.clear();\n    kill_reqs.clear();\n\n    if (dp_group->n_ranks() == 1) {\n        q.pop(infer_reqs, kill_reqs, max_infer, blocking, abort);\n    }\n    else {\n        union {\n            uint16_t data[2];\n            uint32_t value;\n        };\n        while (true) {\n            q.pop(infer_reqs, kill_reqs, max_infer, false, abort);\n            data[0] = !(blocking && infer_reqs.empty() && kill_reqs.empty());  // ready?\n            data[1] = abort;\n            value   = comm::AllReduce(dp_group, value, comm::RedOp::kSum);\n            if (data[0] >= dp_thr_ || data[1]) {\n                break;\n            }\n        }\n        abort = data[1];\n    }\n\n    // Assign a monotonically increasing id to each infer request\n    q.assign_unique_ids(infer_reqs);\n\n    // Bind for stateful inference\n    std::vector<uint64_t> bind_ids;\n    for (const auto& r : infer_reqs) {\n        if (r->session.start_flag && !r->session.end_flag) {  // started but not ended\n            bind_ids.push_back(r->session.id);\n        }\n    }\n\n    /// TODO: fix qid <-> rank mapping\n    if (!bind_ids.empty()) {\n        binding_.bind(bind_ids, qid);\n    }\n\n    // Unbind for stateful kill\n    std::vector<uint64_t> unbind_ids;\n    for (const auto& r : kill_reqs) {\n        unbind_ids.push_back(r->session.id);\n    }\n    if (!unbind_ids.empty()) {\n        binding_.unbind(unbind_ids, qid);\n    }\n}\n\nvoid Gateway::cancel(std::shared_ptr<Request> r)\n{\n    // {-1: canceled, 0: queued, 1: active}\n    if (r->cancel_flag.exchange(-1, std::memory_order_acq_rel) == 0) {\n        notify({[r = std::move(r)] {  //\n            UpdateState(*r, Request::kCancel, 0);\n        }});\n    }\n    else {\n        // request has already been picked up by the engine\n    }\n}\n\nvoid Gateway::kill(std::shared_ptr<Request> r)\n{\n    if (auto rank = binding_.find(r->session.id); rank >= 0) {\n  
      queues_[rank]->kill(std::move(r));\n    }\n    else {\n        TM_LOG_ERROR(\"[Gateway] Failed to find a bound queue for %lu\", r->session.id);\n        notify({[r = std::move(r)] {  //\n            UpdateState(*r, Request::kInvalid, 0);\n        }});\n    }\n}\n\nvoid Gateway::notify(std::vector<Signal> signals, bool pred)\n{\n    if (pred) {\n        signal_buffer_.push(std::move(signals));\n    }\n}\n\nvoid Gateway::signal_thread_entry() noexcept\n{\n    while (true) {\n        bool                abort{};\n        std::vector<Signal> signals = signal_buffer_.take_all(abort);\n        if (abort) {\n            break;\n        }\n        else {\n            auto ctx = ctx_factory_();\n            for (const auto& s : signals) {\n                s();\n            }\n        }\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/gateway.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <atomic>\n#include <functional>\n#include <memory>\n#include <mutex>\n#include <thread>\n#include <unordered_map>\n#include <vector>\n\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/engine/request.h\"\n#include \"src/turbomind/engine/request_queue.h\"\n#include \"src/turbomind/engine/signal_buffer.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind {\n\nclass SequenceBinding {\npublic:\n    int find(uint64_t seq_id)\n    {\n        std::lock_guard lock{mutex_};\n        if (auto it = map_.find(seq_id); it != map_.end()) {\n            return it->second;\n        }\n        return -1;\n    }\n\n    void bind(const std::vector<uint64_t>& seq_ids, int rank)\n    {\n        std::lock_guard lock{mutex_};\n        for (const auto& x : seq_ids) {\n            if (auto [it, success] = map_.emplace(x, rank); !success) {\n                TM_LOG_WARNING(\"[TM][Gateway] Duplicated binding for %lu, %d vs %d\", x, rank, it->second);\n            }\n        }\n    }\n\n    void unbind(const std::vector<uint64_t>& seq_ids, int rank)\n    {\n        std::lock_guard lock{mutex_};\n        for (const auto& x : seq_ids) {\n            auto it = map_.find(x);\n            if (it == map_.end()) {\n                TM_LOG_WARNING(\"[TM][Gateway] No entry found for unbinding %lu, %d\", x, rank);\n            }\n            else if (it->second != rank) {\n                TM_LOG_WARNING(\"[TM][Gateway] Mismatched entry for unbinding %lu, %d vs %d\", x, rank, it->second);\n            }\n            else {\n                map_.erase(it);\n            }\n        }\n    }\n\nprivate:\n    std::mutex                        mutex_;\n    std::unordered_map<uint64_t, int> map_;\n};\n\nclass Gateway {\npublic:\n    Gateway(int size, std::function<std::shared_ptr<void>()> ctx_factory);\n\n    void shutdown();\n\n    void push(std::shared_ptr<Request> r);\n\n    void 
pop(std::vector<std::shared_ptr<Request>>& infer_reqs,\n             std::vector<std::shared_ptr<Request>>& kill_reqs,\n             unsigned                               max_infer,\n             bool                                   blocking,\n             bool&                                  abort,\n             comm::HostComm&                        dp_group,\n             int                                    qid);\n\n    void cancel(std::shared_ptr<Request> r);\n\n    void kill(std::shared_ptr<Request> r);\n\n    void notify(std::vector<Signal> signals, bool pred = true);\n\n    void set_threshold(int value)\n    {\n        TM_LOG_INFO(\"set threshold %d -> %d\", dp_thr_, value);\n        dp_thr_ = value;\n    }\n\nprivate:\n    void signal_thread_entry() noexcept;\n\nprivate:\n    const int size_;\n\n    int dp_thr_;\n\n    std::vector<std::unique_ptr<RequestQueue>> queues_;\n\n    std::function<std::shared_ptr<void>()> ctx_factory_;\n\n    SignalBuffer signal_buffer_;\n    std::thread  signal_thread_;\n\n    SequenceBinding binding_;\n\n    std::atomic<uint32_t> next_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/model_executor.cc",
    "content": "\n#include \"src/turbomind/engine/model_executor.h\"\n\n#include <memory>\n#include <thread>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/copy.h\"\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/models/language_model.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/utils/anomaly_handler.h\"\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nusing std::shared_ptr;\nusing std::unique_ptr;\n\nstruct ModelExecutor::Impl {\n\n    LanguageModel& model_;\n    LlamaLinear&   linear_;\n\n    const int device_id_;\n\n    Queue<unique_ptr<BatchData>>& inbound_;\n    Queue<unique_ptr<BatchData>>& outbound_;\n\n    std::thread internal_thread_;\n\n    void InternalThreadEntry()\n    {\n        check_cuda_error(cudaSetDevice(device_id_));\n\n        Stream    stream  = Stream::create();\n        Allocator h_alloc = Allocator(kCPU);\n        Allocator d_alloc = Allocator(stream, false);\n\n        AnomalyHandler::instance().Init(0, 1000, 0, 1000, stream.handle());\n\n        core::ContextGuard ctx{stream, h_alloc, d_alloc};\n\n        unique_ptr<BatchData> d;\n\n        while (inbound_.pop(d)) {\n            TM_CHECK_NOTNULL(d);\n            core::Context::stream().Wait(d->ready);\n            Run(*d);\n            d->done.Record(core::Context::stream());\n            outbound_.push(std::move(d));\n        }\n    }\n\n    void Run(BatchData& d)\n    {\n        BatchCopy copy;\n        TensorMap env{{\"batch\", d.buf()}, {\"copy\", copy.buf()}};\n\n        model_.Run(BatchOp::kPrepare, d.phase, env);\n        // dbg(copy);\n        copy.Run();\n\n        model_.Run(BatchOp::kForward, d.phase, env);\n\n        model_.Run(BatchOp::kUnprep, d.phase, env);\n        // dbg(copy);\n        copy.Run();\n\n        // TM_CHECK(0);\n        AnomalyHandler::instance().Summarize([](...) 
{});\n        AnomalyHandler::instance().Reset();\n    }\n\n    Impl(LanguageModel&                model,\n         Context&                      context,\n         int                           device_id,\n         Queue<unique_ptr<BatchData>>& inbound,\n         Queue<unique_ptr<BatchData>>& outbound):\n        model_{model}, linear_{*context.linear}, device_id_{device_id}, inbound_{inbound}, outbound_{outbound}\n    {\n    }\n\n    ~Impl()\n    {\n        if (internal_thread_.joinable()) {\n            internal_thread_.join();\n        }\n    }\n\n    void Start()\n    {\n        internal_thread_ = std::thread(&Impl::InternalThreadEntry, this);\n    }\n};\n\nModelExecutor::~ModelExecutor() = default;\n\nModelExecutor::ModelExecutor()                         = default;\nModelExecutor::ModelExecutor(ModelExecutor&&) noexcept = default;\nModelExecutor& ModelExecutor::operator=(ModelExecutor&&) noexcept = default;\n\nModelExecutor::ModelExecutor(LanguageModel&                model,\n                             Context&                      context,\n                             int                           device_id,\n                             Queue<unique_ptr<BatchData>>& inbound,\n                             Queue<unique_ptr<BatchData>>& outbound):\n    impl_{std::make_unique<Impl>(model, context, device_id, inbound, outbound)}\n{\n}\n\nvoid ModelExecutor::Start()\n{\n    return impl_->Start();\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/model_executor.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <memory>\n\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/engine/queue.h\"\n#include \"src/turbomind/models/language_model.h\"\n\n#include \"src/turbomind/models/llama/context.h\"\n\nnamespace turbomind {\n\n// Model executor for auto-regressive language models\nclass ModelExecutor {\npublic:\n    ~ModelExecutor();\n\n    ModelExecutor();\n    ModelExecutor(ModelExecutor&&) noexcept;\n    ModelExecutor& operator=(ModelExecutor&&) noexcept;\n\n    explicit operator bool() const noexcept\n    {\n        return static_cast<bool>(impl_);\n    }\n\n    ModelExecutor(LanguageModel&                     model,\n                  Context&                           context,\n                  int                                device_id,\n                  Queue<std::unique_ptr<BatchData>>& inbound,\n                  Queue<std::unique_ptr<BatchData>>& outbound);\n\n    void Start();\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/model_request.cc",
    "content": "\n\n#include <algorithm>\n#include <functional>\n#include <memory>\n#include <type_traits>\n#include <utility>\n\n#include \"xgrammar/compiler.h\"\n#include \"xgrammar/matcher.h\"\n\n#include \"src/turbomind/engine/model_request.h\"\n#include \"src/turbomind/engine/request.h\"\n#include \"src/turbomind/utils/constant.h\"\n#include \"src/turbomind/utils/metrics.h\"\n\nnamespace turbomind {\n\nModelRequest::ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim):\n    gateway_{gateway},\n    data_type_{data_type},\n    session_len_{session_len},\n    vocab_size_{vocab_size},\n    hidden_dim_{hidden_dim}\n{\n}\n\nvoid ModelRequest::Cancel()\n{\n    // request is finished if lock failed\n    if (auto r = request_.lock()) {\n        gateway_->cancel(std::move(r));\n    }\n}\n\nvoid ModelRequest::End(std::function<void(int)> cb, uint64_t session_id)\n{\n    auto r = std::make_shared<Request>();\n\n    r->id = r->session.id = session_id;\n    r->session.kill_flag  = true;\n\n    r->end_cb = std::move(cb);\n\n    gateway_->kill(std::move(r));\n}\n\nauto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> OutputParam\n{\n    inputs_  = std::make_shared<TensorMap>();\n    outputs_ = std::make_shared<TensorMap>();\n\n    auto add = [](auto& dest, auto key, auto dtype, auto where, auto shape, auto&&... 
dims) {\n        Layout shape_;\n        if constexpr (std::is_integral_v<decltype(shape)>) {\n            shape_ = {shape, dims...};\n        }\n        else {\n            shape_ = {shape.cbegin(), shape.cend()};\n        }\n        dest->emplace(key, Tensor{shape_, dtype, where});\n    };\n\n    auto& inputs = *param.tensors;\n\n    TM_CHECK_EQ(inputs.at(\"input_ids\").ndim(), 1);\n\n    const int input_len  = inputs.at(\"input_ids\").shape(0);\n    const int output_len = param.gen_cfg.max_new_tokens;\n\n    // Max possible length of a sequence. This depends on `history_len`, which isn't available here, so `session_len`\n    // is used instead\n    const int max_seq_len = session_len_ + 1;\n    const int max_out_len = std::min(output_len, session_len_) + 1;\n    // This does not include the history length in interactive mode\n    const int max_in_out_len = std::min(input_len + output_len, session_len_) + 1;\n\n    for (auto& [k, v] : *param.tensors) {\n        inputs_->emplace(k, v);\n    }\n\n    add(outputs_, \"output_ids\", data_type_v<int>, kCPU, max_seq_len);\n    add(outputs_, \"sequence_length\", data_type_v<int>, kCPU, 1);\n\n    if (param.gen_cfg.output_logits) {\n        const int len = param.gen_cfg.output_logits == GenerationConfig::kAll ? max_in_out_len : max_out_len;\n        add(outputs_, \"logits\", data_type_, kCPU, len, vocab_size_);\n    }\n\n    if (param.gen_cfg.output_last_hidden_state) {\n        const int len = param.gen_cfg.output_last_hidden_state == GenerationConfig::kAll ? 
max_in_out_len : max_out_len;\n        add(outputs_, \"last_hidden_state\", data_type_, kCPU, len, hidden_dim_);\n    }\n\n    if (param.gen_cfg.output_logprobs) {\n        add(outputs_, \"logprob_vals\", data_type_v<float>, kCPU, max_out_len, kMaxLogProb);\n        add(outputs_, \"logprob_indexes\", data_type_v<int>, kCPU, max_out_len, kMaxLogProb);\n        add(outputs_, \"logprob_nums\", data_type_v<int>, kCPU, max_out_len);\n    }\n\n    auto r = std::make_shared<Request>();\n\n    for (const auto& [k, v] : *inputs_) {\n        r->inputs.emplace(k, v);\n    }\n    for (const auto& [k, v] : *outputs_) {\n        r->outputs.emplace(k, v);\n    }\n\n    auto state = std::make_shared<AtomicRequestState>();\n\n    auto metrics = param.enable_metrics ? std::make_shared<RequestMetrics>() : nullptr;\n    if (metrics) {\n        metrics->enqueue_time.store(RequestMetrics::timestamp(), std::memory_order_relaxed);\n        metrics->scheduled_time.store(0, std::memory_order_relaxed);\n    }\n\n    if (param.session.start_flag) {\n        session_id_ = param.session.id;\n    }\n\n    r->id            = param.session.id;\n    r->session       = param.session;\n    r->gen_cfg       = param.gen_cfg;\n    r->stream_output = param.stream_output;\n    r->forward_cb    = std::move(cb);\n    r->state         = state;\n    r->metrics       = metrics;\n\n    r->output_ids      = outputs_->at(\"output_ids\");\n    r->sequence_length = outputs_->at(\"sequence_length\");\n\n    if (grammar_) {\n        r->grammar = std::move(grammar_);\n        r->matcher = std::make_shared<xgrammar::GrammarMatcher>(*r->grammar);\n    }\n\n    // Keep a WEAK reference for canceling the request\n    request_ = r;\n\n    gateway_->push({std::move(r)});\n\n    return OutputParam{outputs_, state, metrics};\n}\n\nvoid ModelRequest::setGrammar(const xgrammar::CompiledGrammar& grammar)\n{\n    grammar_ = std::make_shared<xgrammar::CompiledGrammar>(grammar);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/model_request.h",
    "content": "\n\n#pragma once\n\n#include <cstdint>\n#include <functional>\n#include <memory>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/engine/gateway.h\"\n\nnamespace xgrammar {\nclass CompiledGrammar;\n}\n\nnamespace turbomind {\n\nclass ModelRequest {\npublic:\n    virtual ~ModelRequest() = default;\n\n    ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim);\n\n    // Cancel the running request\n    void Cancel();\n\n    // Reset the channel to an uninitialized state; calls `cb` when done\n    void End(std::function<void(int)> cb, uint64_t session_id);\n\n    struct InputParam {\n        std::shared_ptr<TensorMap> tensors;\n\n        SessionParam     session;\n        GenerationConfig gen_cfg;\n\n        bool stream_output;\n        bool enable_metrics;\n    };\n\n    struct OutputParam {\n        std::shared_ptr<TensorMap>          tensors;\n        std::shared_ptr<AtomicRequestState> state;\n        std::shared_ptr<RequestMetrics>     metrics;\n    };\n\n    OutputParam Forward(InputParam param, std::function<void()> cb);\n\n    void setGrammar(const xgrammar::CompiledGrammar& grammar);\n\nprotected:\n    Gateway* const gateway_;\n\n    const DataType data_type_;\n\n    const int session_len_;\n    const int vocab_size_;\n    const int hidden_dim_;\n\n    uint64_t session_id_;\n\n    std::weak_ptr<Request> request_;\n\n    std::shared_ptr<TensorMap> inputs_;\n    std::shared_ptr<TensorMap> outputs_;\n\n    std::shared_ptr<xgrammar::CompiledGrammar> grammar_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/queue.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <condition_variable>\n#include <mutex>\n#include <queue>\n\nnamespace turbomind {\n\ntemplate<class T>\nclass Queue {\npublic:\n    template<class X>\n    void push(X&& x)\n    {\n        {\n            std::lock_guard lock{mutex_};\n            queue_.push(std::forward<X>(x));\n        }\n        cv_.notify_one();\n    }\n\n    bool pop(T& x)\n    {\n        std::unique_lock lock{mutex_};\n        cv_.wait(lock, [&] { return !queue_.empty() || is_closed_; });\n        if (is_closed_) {\n            return false;\n        }\n        x = std::move(queue_.front());\n        queue_.pop();\n        return true;\n    }\n\n    void close()\n    {\n        {\n            std::lock_guard lock{mutex_};\n            is_closed_ = true;\n        }\n        cv_.notify_all();\n    }\n\nprivate:\n    std::queue<T>           queue_;\n    std::mutex              mutex_;\n    std::condition_variable cv_;\n    bool                    is_closed_{false};\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/request.cc",
    "content": "\n\n#include \"src/turbomind/engine/request.h\"\n\n#include <cstddef>\n#include <ostream>\n#include <vector>\n\nnamespace turbomind {\n\nnamespace {\n\n// Emit separators between elements instead of seeking back over a trailing \", \": `seekp` may fail on\n// non-seekable streams (e.g. a console or pipe) and leave `failbit` set\ntemplate<typename T>\ninline std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec)\n{\n    os << \"[\";\n    for (std::size_t i = 0; i < vec.size(); ++i) {\n        if (i) {\n            os << \", \";\n        }\n        os << vec[i];\n    }\n    os << \"]\";\n    return os;\n}\n\n}  // namespace\n\nstd::ostream& operator<<(std::ostream& os, const GenerationConfig& c)\n{\n    os << \"GenerationConfig { \";\n    os << \"max_new_tokens=\" << c.max_new_tokens;\n    os << \", min_new_tokens=\" << c.min_new_tokens;\n    os << \", eos_ids=\" << c.eos_ids;\n    os << \", stop_ids=[\" << c.stop_ids[0] << \", \" << c.stop_ids[1] << \"]\";\n    os << \", bad_ids=[\" << c.bad_ids[0] << \", \" << c.bad_ids[1] << \"]\";\n    os << \", top_p=\" << c.top_p;\n    os << \", top_k=\" << c.top_k;\n    os << \", min_p=\" << c.min_p;\n    os << \", temperature=\" << c.temperature;\n    os << \", repetition_penalty=\" << c.repetition_penalty;\n    os << \", random_seed=\" << c.random_seed;\n    os << \", output_logprobs=\" << c.output_logprobs;\n    os << \", output_last_hidden_state=\" << c.output_last_hidden_state;\n    os << \", output_logits=\" << c.output_logits;\n    os << \" }\";\n    return os;\n}\n\nvoid UpdateState(Request& r, int status, int seq_len)\n{\n    try {\n        auto new_state = new RequestState{status, seq_len};\n        auto old_state = r.state->exchange(new_state);\n        if (!old_state && r.forward_cb) {\n            r.forward_cb();\n        }\n    }\n    catch (const std::exception& e) {\n        TM_LOG_ERROR(\"Error invoking callback for (%lu): %s\", r.id, e.what());\n    }\n    catch (...) {\n        TM_LOG_ERROR(\"Unknown error invoking callback for (%lu)\", r.id);\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/request.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <array>\n#include <atomic>\n#include <cstdint>\n#include <functional>\n#include <memory>\n#include <ostream>\n#include <string>\n#include <vector>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/core/interval.h\"\n#include \"src/turbomind/utils/metrics.h\"\n\nnamespace xgrammar {\nclass GrammarMatcher;  // forward declaration\nclass CompiledGrammar;\n}  // namespace xgrammar\n\nnamespace turbomind {\n\nstruct GenerationConfig {\n    int max_new_tokens = 0;\n    int min_new_tokens = 0;\n\n    std::vector<int> eos_ids;  // only single-token EOS ids are supported\n\n    std::array<std::vector<int>, 2> stop_ids;  // (token_id, offset)\n    std::array<std::vector<int>, 2> bad_ids;\n\n    int   top_k       = 1;\n    float top_p       = 0.f;\n    float min_p       = 0.f;\n    float temperature = 1.f;\n\n    float repetition_penalty = 1.f;\n\n    uint64_t random_seed = 0;\n\n    int output_logprobs = 0;\n\n    enum OutType\n    {\n        kNone       = 0,\n        kAll        = 1,\n        kGeneration = 2\n    };\n    int output_last_hidden_state = 0;\n    int output_logits            = 0;\n};\n\nstd::ostream& operator<<(std::ostream& os, const GenerationConfig& c);\n\nstruct SessionParam {\n    uint64_t id;\n\n    int step;\n\n    bool start_flag;\n    bool end_flag;\n    bool kill_flag;\n};\n\nstruct RequestState {\n    int status;\n    int seq_len;\n};\n\nstruct AtomicRequestState {\n\n    std::atomic<RequestState*> data_{nullptr};\n\n    static_assert(std::atomic<RequestState*>::is_always_lock_free);\n\n    ~AtomicRequestState()\n    {\n        exchange(nullptr);  // the returned `unique_ptr` frees any pending state\n    }\n\n    std::unique_ptr<RequestState> exchange(RequestState* data)\n    {\n        return std::unique_ptr<RequestState>{data_.exchange(data, std::memory_order_acq_rel)};\n    }\n};\n\nstruct Request {\n    uint64_t id;         // sequence id\n    uint64_t unique_id;  // monotonically increasing\n\n    SessionParam     session;\n    
GenerationConfig gen_cfg;\n\n    bool stream_output;\n\n    // references to IO tensors\n    TensorMap inputs;\n    TensorMap outputs;\n    // fast path for accessing common output buffers\n    Tensor_<int> output_ids;\n    Tensor_<int> sequence_length;\n\n    std::function<void(int)> end_cb;\n\n    std::atomic<int> cancel_flag{0};\n\n    std::function<void()> forward_cb;\n\n    std::shared_ptr<AtomicRequestState> state;\n\n    std::shared_ptr<RequestMetrics> metrics;\n\n    int ec = 0;  // set when disabling conflicting requests\n\n    enum\n    {\n        kOk            = 0,\n        kInvalid       = 1,  // Sequence does not exist, or both `start` and `stop` (instead of `end`) are set\n        kConflict      = 2,  // Concurrent requests to the same sequence\n        kBusy          = 3,  // Sequence is already running\n        kInactive      = 4,  // Sequence to `stop` is not active\n        kFail          = 5,  // Can't find sequence for `stop` request or internal error during inference\n        kTooLong       = 6,  // history + prompt > session_len\n        kFinish        = 7,\n        kCancel        = 8,\n        kInconsistency = 9,   // Inconsistent request parameters, e.g. 
prefix caching is not allowed in interactive mode\n        kNoQueue       = 10,  // No queue available for submitting the request (in current process)\n    };\n\n    std::shared_ptr<xgrammar::CompiledGrammar> grammar;\n    std::shared_ptr<xgrammar::GrammarMatcher>  matcher;\n};\n\nvoid UpdateState(Request& r, int status, int seq_len);\n\nclass Sequence;\n\n// Unlike `Request`, which is shared by all local TP ranks, each rank has its own `RequestCache`.\nstruct RequestCache {\n    std::shared_ptr<Request> req;\n    const Sequence*          seq;  // May be NULL in `Update` (`seq` is erased when the request is done)\n    const GenerationConfig&  gen_cfg;\n\n    RequestCache(std::shared_ptr<Request> r, const Sequence& s): req{std::move(r)}, seq{&s}, gen_cfg{req->gen_cfg} {}\n\n    int status = Request::kOk;\n\n    // These members may be opaque handles from individual modules (pointers to forward declared types), but we tend to\n    // keep it simple as long as the complexity is manageable\n\n    int*     token_ids    = nullptr;  // currently the `output_ids` buffer of the request\n    uint8_t* random_state = nullptr;\n\n    int step0       = 0;  // set at request init, constant, first prefill step\n    int prompt_len  = 0;  // set at request init, constant, first decode step\n    int max_seq_len = 0;  // set at request init, constant\n\n    int hidden_states_offset = 0;  // set at request init, constant\n    int logits_offset        = 0;  // set at request init, constant\n\n    int seq_len = 0;  // set at request init, updated per step\n\n    int input_len   = 0;  // set at schedule (set to `seq.input_len`)\n    int history_len = 0;  // set at schedule (set to `seq.cache_len`)\n\n    bool autoregres = false;  // set at schedule, `seq_len` and `input_ids` taken from the engine\n    bool generating = false;  // set at schedule\n\n    bool done = false;  // set at cancel / update, whether the request is finished / canceled\n\n    int alpha = 0;  // pending growth of cache_len (draft_len + 
input_len)\n    int beta  = 0;  // pending growth of seq_len (draft_len + {0,1})\n\n    float rope_base = 0.f;\n\n    Interval output_hidden_states;\n    Interval output_logits;\n};\n\ntemplate<class Archive>\nvoid serdes(Archive& ar, GenerationConfig& g)\n{\n    // clang-format off\n    ar & g.max_new_tokens;\n    ar & g.min_new_tokens;\n    ar & g.eos_ids;\n    ar & g.stop_ids[0];\n    ar & g.stop_ids[1];\n    ar & g.bad_ids[0];\n    ar & g.bad_ids[1];\n    ar & g.top_k;\n    ar & g.top_p;\n    ar & g.min_p;\n    ar & g.temperature;\n    ar & g.repetition_penalty;\n    ar & g.random_seed;\n    ar & g.output_logprobs;\n    ar & g.output_last_hidden_state;\n    ar & g.output_logits;\n    // clang-format on\n}\n\ntemplate<class Archive>\nvoid save_req_output(Archive& ar, const TensorMap& map)\n{\n    // clang-format off\n    ar & map.size();\n    for (const auto& [k, t] : map) {\n        TM_CHECK(t.device().type == kCPU);\n        ar & k;\n        ar & t.layout();\n        ar & t.dtype();\n    }\n    // clang-format on\n}\n\ntemplate<class Archive>\nvoid load_req_output(Archive& ar, TensorMap& map)\n{\n    // clang-format off\n    decltype(map.size()) size{};\n    ar & size;\n    for (decltype(size) i = 0; i < size; ++i) {\n        std::string k;\n        Layout      layout;\n        DataType    dtype;\n        ar & k;\n        ar & layout;\n        ar & dtype;\n        map.emplace(std::move(k), Tensor{layout, dtype, kCPU});\n    }\n    // clang-format on\n}\n\ntemplate<class Archive>\nvoid serdes(Archive& ar, Request& r)\n{\n    // clang-format off\n    ar & r.id;\n    ar & r.unique_id;\n    ar & r.session;\n    ar & r.gen_cfg;\n    ar & r.stream_output;\n    ar & r.inputs;\n    if constexpr(Archive::is_loading) {\n        load_req_output(ar, r.outputs);\n        r.output_ids      = r.outputs.at(\"output_ids\");\n        r.sequence_length = r.outputs.at(\"sequence_length\");\n    } else {\n        save_req_output(ar, r.outputs);\n    }\n    ar & r.ec;\n    // clang-format 
on\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/request_queue.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/engine/request_queue.h\"\n#include \"src/turbomind/engine/gateway.h\"\n\n#include \"src/turbomind/engine/request.h\"\n\nnamespace turbomind {\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/request_queue.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <atomic>\n#include <condition_variable>\n#include <cstdint>\n#include <list>\n#include <memory_resource>\n#include <mutex>\n#include <stdexcept>\n#include <vector>\n\n#include \"src/turbomind/engine/request.h\"\n\nnamespace turbomind {\n\nclass RequestQueue {\npublic:\n    explicit RequestQueue(): queue_{&pool_} {}\n\n    void push(std::shared_ptr<Request> r)\n    {\n        {\n            std::lock_guard lock{mutex_};\n            if (closed_) {\n                throw std::runtime_error(\"Queue is closed\");\n            }\n            queue_.push_back(std::move(r));\n        }\n        cv_.notify_one();\n    }\n\n    void kill(std::shared_ptr<Request> r)\n    {\n        {\n            std::lock_guard lock{mutex_};\n            if (closed_) {\n                throw std::runtime_error(\"Queue is closed\");\n            }\n            kill_.push_back(std::move(r));\n        }\n        cv_.notify_one();\n    }\n\n    void pop(std::vector<std::shared_ptr<Request>>& infer_reqs,\n             std::vector<std::shared_ptr<Request>>& kill_reqs,\n             unsigned                               max_infer,\n             bool                                   blocking,\n             bool&                                  abort)\n    {\n        std::unique_lock lock{mutex_};\n\n        if (blocking) {\n            cv_.wait(lock, [this] { return !(queue_.empty() && kill_.empty()) || closed_; });\n        }\n\n        if (closed_) {\n            abort = true;\n        }\n\n        while (!queue_.empty() && infer_reqs.size() < max_infer) {\n            auto& r = queue_.front();\n            // Atomically claim the request; a request whose flag is already set (e.g. canceled) is dropped\n            if (r->cancel_flag.exchange(1, std::memory_order_acq_rel) == 0) {\n                infer_reqs.push_back(std::move(r));\n            }\n            queue_.pop_front();\n        }\n\n        kill_reqs.insert(kill_reqs.end(), kill_.begin(), kill_.end());\n        kill_.clear();\n    }\n\n    void close()\n    {\n        {\n            std::lock_guard<std::mutex> lock(mutex_);\n    
        closed_ = true;\n        }\n        cv_.notify_all();\n    }\n\n    void notify()\n    {\n        cv_.notify_all();\n    }\n\n    void assign_unique_ids(std::vector<std::shared_ptr<Request>>& rs)\n    {\n        for (auto& r : rs) {\n            r->unique_id = unique_id_.fetch_add(1, std::memory_order_relaxed);\n        }\n    }\n\nprivate:\n    std::atomic<uint64_t> unique_id_{};\n\n    std::pmr::unsynchronized_pool_resource   pool_;\n    std::pmr::list<std::shared_ptr<Request>> queue_;\n\n    std::vector<std::shared_ptr<Request>> kill_;\n\n    std::mutex              mutex_;\n    std::condition_variable cv_;\n\n    bool closed_{};\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/engine/signal_buffer.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <condition_variable>\n#include <functional>\n#include <iterator>\n#include <mutex>\n#include <vector>\n\nnamespace turbomind {\n\nusing Signal = std::function<void()>;\n\nclass SignalBuffer {\npublic:\n    void push(std::vector<Signal> signals)\n    {\n        if (signals.empty()) {\n            return;\n        }\n        {\n            std::lock_guard lock{mutex_};\n            signals_.insert(signals_.end(), std::move_iterator{signals.begin()}, std::move_iterator{signals.end()});\n        }\n        cv_.notify_one();\n    }\n\n    void close()\n    {\n        {\n            std::lock_guard lock{mutex_};\n            aborted_ = true;\n        }\n        cv_.notify_all();\n    }\n\n    std::vector<Signal> take_all(bool& abort)\n    {\n        std::vector<Signal> signals;\n        {\n            std::unique_lock lock{mutex_};\n            cv_.wait(lock, [&] { return !signals_.empty() || aborted_; });\n            if (aborted_) {\n                abort = true;\n            }\n            else {\n                signals.swap(signals_);\n            }\n        }\n        return signals;\n    }\n\nprivate:\n    std::vector<Signal> signals_;\n\n    std::mutex              mutex_;\n    std::condition_variable cv_;\n\n    bool aborted_{false};\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/CMakeLists.txt",
    "content": "# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ncmake_minimum_required(VERSION 3.11)\n\nadd_library(guided_decoding STATIC guided_decoding.cc)\ntarget_link_libraries(guided_decoding PRIVATE\n    apply_token_bitmask_inplace_cuda\n    xgrammar\n    core)\nset_property(TARGET guided_decoding PROPERTY POSITION_INDEPENDENT_CODE ON)\n\nadd_library(generation STATIC\n    generation.cc\n    logits_processor.cc\n    sampling.cc\n    stop_criteria.cc)\nset_property(TARGET generation PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET generation PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\ntarget_link_libraries(generation PUBLIC\n    ban_bad_words\n    sampling_penalty_kernels\n    sampling_topk_kernels\n    sampling_topp_kernels\n    sampling_kernels\n    stop_criteria\n    guided_decoding\n    memory_utils\n    CUDA::cudart)\n"
  },
  {
    "path": "src/turbomind/generation/base_param.h",
    "content": "\n\n#pragma once\n\nnamespace turbomind {\n\nclass BaseGenerationParam {\npublic:\n    explicit BaseGenerationParam(int max_batch_size, int vocab_size, int vocab_size_padded):\n        max_batch_size_{max_batch_size}, vocab_size_{vocab_size}, vocab_size_padded_{vocab_size_padded}\n    {\n    }\n\nprotected:\n    int max_batch_size_;\n    int vocab_size_;\n    int vocab_size_padded_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/generation.cc",
    "content": "\n#include <memory>\n\n#include \"src/turbomind/generation/generation.h\"\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/copy.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/state.h\"\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/engine/request.h\"\n\n#include \"src/turbomind/generation/guided_decoding.h\"\n#include \"src/turbomind/generation/logits_processor.h\"\n#include \"src/turbomind/generation/sampling.h\"\n#include \"src/turbomind/generation/stop_criteria.h\"\n\n#include \"src/turbomind/kernels/sampling_topk_kernels.h\"  // InitializeRandomStates\n\n#include \"src/turbomind/models/llama/llama_kernels.h\"  // invokePadLastTokenIds\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nusing std::unique_ptr;\nusing std::shared_ptr;\nusing std::vector;\n\nstruct GenerationData {\n    Buffer_<uint8_t>  random_state;\n    Buffer_<uint64_t> random_seed;\n    Buffer_<bool>     random_init;\n    Buffer_<int>      max_seq_len;\n    Buffer_<int*>     token_ids_ptrs;\n    Buffer_<int>      output_ids;\n\n    bool random_init_needed;\n    int  generation_size;\n};\n\nstruct Generation::Impl {\n\n    // child modules\n    unique_ptr<LogitsProcessor> logits_processor_;\n    unique_ptr<Sampling>        sampling_;\n    shared_ptr<StopCriteria>    stop_criteria_;\n    unique_ptr<GuidedDecoding>  guided_decoding_;\n\n    // persistent\n    Tensor_<int> token_ids_;\n\n    // scheduling states\n    vector<int*> h_token_ids_ptrs_;\n    vector<int*> h_token_ids_free_;\n\n    // execution states\n    State random_state_;\n\n    // immutable states\n    Buffer_<int> output_ids_;\n\n    std::vector<std::unique_ptr<GenerationData>> data_;\n\n    // staging buffers\n    Buffer_<uint8_t>  random_state_buf_;\n    Buffer_<uint64_t> random_seed_buf_;\n    Buffer_<bool>     random_init_buf_;\n    Buffer_<int*>     token_ids_ptrs_buf_;\n    
Buffer_<int>      token_ids_buf_;\n    Buffer_<int>      output_ids_buf_;\n\n    const int max_batch_size_;\n    const int session_len_;\n\n    Impl(DataType              dtype,\n         int                   max_batch_size,\n         int                   session_len,\n         int                   vocab_size,\n         int                   vocab_size_padded,\n         const comm::HostComm& tp_group,\n         int                   phases):\n        max_batch_size_{max_batch_size}, session_len_{session_len}\n    {\n        TM_CHECK_EQ(dtype, kFloat32);\n        BaseGenerationParam base{max_batch_size, vocab_size, vocab_size_padded};\n        logits_processor_ = std::make_unique<LogitsProcessor>(base, phases);\n        sampling_         = std::make_unique<Sampling>(base, phases);\n        stop_criteria_    = std::make_unique<StopCriteria>(base, phases);\n        guided_decoding_  = std::make_unique<GuidedDecoding>(base, tp_group, phases);\n\n        static_assert(sizeof(curandState_t) % alignof(curandState_t) == 0);\n        random_state_ = {{max_batch_size_, (int)sizeof(curandState_t)}, kUint8, kDEVICE};\n        token_ids_    = {{max_batch_size_, session_len_}, kDEVICE};\n        output_ids_   = {max_batch_size_, kDEVICE};\n        for (int i = 0; i < max_batch_size_; ++i) {\n            h_token_ids_free_.push_back(token_ids_.data() + i * token_ids_.stride(0));\n        }\n        h_token_ids_ptrs_.resize(max_batch_size_);\n\n        random_state_buf_ = {max_batch_size_ * (int)sizeof(curandState_t), kCPUpinned};\n        random_seed_buf_  = {max_batch_size_, kCPUpinned};\n        random_init_buf_  = {max_batch_size_, kCPUpinned};\n\n        token_ids_ptrs_buf_ = {max_batch_size_, kCPUpinned};\n        token_ids_buf_      = {max_batch_size_ * (ssize_t)session_len_, kCPUpinned};\n\n        output_ids_buf_ = {max_batch_size_, kCPUpinned};\n\n        for (int i = 0; i < phases; ++i) {\n            auto d = std::make_unique<GenerationData>();\n\n            
d->random_state   = empty_like(random_state_buf_, kDEVICE);\n            d->random_seed    = empty_like(random_seed_buf_, kDEVICE);\n            d->random_init    = empty_like(random_init_buf_, kDEVICE);\n            d->token_ids_ptrs = empty_like(token_ids_ptrs_buf_, kDEVICE);\n            d->output_ids     = empty_like(output_ids_, kDEVICE);\n\n            data_.push_back(std::move(d));\n        }\n    }\n\n    void Setup(int phase, TensorMap& env)\n    {\n        auto& d = *data_.at(phase);\n\n        auto& b    = *env.at(\"batch\").data<BatchData*>()[0];\n        auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n        const auto& rc = b.rc;\n\n        // random states\n        d.random_init_needed = false;\n        for (int i = 0; i < b.perm.size(); ++i) {\n            const auto& c = *rc[i];\n            if (TM_LIKELY(b.perm[i] < b.bs0)) {  // existing\n                random_init_buf_[i] = false;\n            }\n            else if (c.random_state) {  // already initialized, restore instead of re-seeding\n                random_init_buf_[i] = false;\n                std::copy_n(\n                    c.random_state, sizeof(curandState_t), random_state_buf_.data() + i * sizeof(curandState_t));\n            }\n            else {  // uninitialized\n                d.random_init_needed = true;\n                random_init_buf_[i]  = true;\n                random_seed_buf_[i]  = rc[i]->gen_cfg.random_seed;\n            }\n        }\n        // `random_state_buf_` is a byte buffer, so the count is `bsz` states times the state size\n        copy(random_state_buf_, b.bsz * sizeof(curandState_t), d.random_state);\n        if (d.random_init_needed) {\n            copy(random_init_buf_, b.bsz, d.random_init);\n            copy(random_seed_buf_, b.bsz, d.random_seed);\n        }\n\n        vector<int> used(b.bs0);\n        for (int i = 0; i < b.bsz; ++i) {\n            if (b.perm[i] < b.bs0) {\n                used[b.perm[i]] = 1;\n            }\n        }\n        for (int i = 0; i < b.bs0; ++i) {\n            if (!used[i]) {  // free unused chunks\n                h_token_ids_free_.push_back(h_token_ids_ptrs_[i]);\n            }\n        }\n        // 
swap-in token_ids\n        int* token_ids_buf = token_ids_buf_.data();\n        for (int i = 0; i < rc.size(); ++i) {\n            if (const auto& c = *rc[i]; TM_UNLIKELY(b.perm[i] >= b.bs0)) {\n                // allocation\n                TM_CHECK(!h_token_ids_free_.empty());\n                token_ids_ptrs_buf_[i] = h_token_ids_free_.back();\n                h_token_ids_free_.pop_back();\n                // copy to staging buffer\n                std::copy_n(c.token_ids, c.seq_len, token_ids_buf);\n                copy(token_ids_buf, c.seq_len, token_ids_ptrs_buf_[i]);\n                token_ids_buf += c.seq_len;\n            }\n            else {\n                token_ids_ptrs_buf_[i] = h_token_ids_ptrs_[b.perm[i]];\n            }\n        }\n\n        copy(token_ids_ptrs_buf_, b.bsz, d.token_ids_ptrs);\n\n        // update `h_token_ids_ptrs_`\n        std::copy_n(token_ids_ptrs_buf_.data(), b.bsz, h_token_ids_ptrs_.data());\n\n        d.generation_size = 0;\n        for (int i = 0; i < rc.size(); ++i) {\n            const auto& c = *rc[i];\n            d.generation_size += c.generating;\n        }\n        // dbg(d.generation_size);\n\n        logits_processor_->Setup(phase, env);\n        sampling_->Setup(phase, env);\n        stop_criteria_->Setup(phase, env);\n        guided_decoding_->Setup(phase, env);\n    }\n\n    void Prepare(int phase, TensorMap& env)\n    {\n        auto& d = *data_.at(phase);\n\n        auto& b    = *env.at(\"batch\").data<BatchData*>()[0];\n        auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n        if (auto g = copy.group()) {\n            Warp(random_state_.front(), d.random_state, b.bs0, b.perm, random_state_.back(), copy);\n            random_state_.Swap();\n        }\n    }\n\n    void Unprep(int phase, TensorMap& env)\n    {\n        auto& d    = *data_.at(phase);\n        auto& b    = *env.at(\"batch\").data<BatchData*>()[0];\n        auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n        // state -> 
data\n        copy(random_state_.front().buffer(), b.bsz * sizeof(curandState_t), d.random_state);\n        copy(output_ids_, b.bsz, d.output_ids);\n    }\n\n    void Fetch(int phase, TensorMap& env)\n    {\n        auto& d    = *data_.at(phase);\n        auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n        copy(d.random_state, d.random_state.size(), random_state_buf_);\n        env.produce(\"random_state\", random_state_buf_);\n\n        copy(d.output_ids, d.output_ids.size(), output_ids_buf_);\n        env.produce(\"output_ids\", output_ids_buf_);\n\n        sampling_->Fetch(phase, env);\n    }\n\n    void Update(int phase, TensorMap& env)\n    {\n        sampling_->Update(phase, env);\n    }\n\n    void Forward(int phase, TensorMap& env)\n    {\n        auto& d = *data_.at(phase);\n        auto& b = *env.at(\"batch\").data<BatchData*>()[0];\n\n        const auto stream = core::Context::stream().handle();\n\n        if (d.random_init_needed) {\n            InitializeRandomStates((curandState_t*)random_state_.front().raw_data(),\n                                   d.random_seed.data(),\n                                   d.random_init.data(),\n                                   b.bsz,\n                                   stream);\n            sync_check_cuda_error();\n        }\n\n        env.emplace(\"output_ids\", output_ids_);              // out\n        env.emplace(\"curand_state\", random_state_.front());  // inout\n\n        if (const int gs = d.generation_size) {\n\n            env.emplace(\"token_ids_ptrs\", d.token_ids_ptrs.slice(0, gs));\n\n            auto logits = env.consume(\"logits\");\n\n            if (logits.dtype() != kFloat32) {\n                auto tmp = empty_like(logits, kFloat32);\n                invokeCastFloat2D(logits, tmp, stream);\n                logits = std::move(tmp);\n            }\n\n            env.produce(\"logits\", logits.slice(0, gs));\n\n            Buffer_<int> output_pos{max_batch_size_, kDEVICE};\n          
  Copy(env.at(\"sequence_length\").buffer(), gs, output_pos);\n\n            logits_processor_->Forward(phase, env);\n\n            guided_decoding_->FillMask(phase, env);\n            guided_decoding_->ApplyMask(phase, env);\n\n            sampling_->Forward(phase, env);\n\n            guided_decoding_->Update(phase, env);\n\n            AppendTokenIds(d.token_ids_ptrs.data(), output_ids_.data(), output_pos.data(), gs, stream);\n\n            stop_criteria_->Forward(phase, env);\n        }\n    }\n};\n\nGeneration::~Generation() = default;\n\nGeneration::Generation(DataType              dtype,\n                       int                   max_batch_size,\n                       int                   session_len,\n                       int                   vocab_size,\n                       int                   vocab_size_padded,\n                       const comm::HostComm& tp_group,\n                       int                   phases):\n    impl_{std::make_unique<Impl>(dtype, max_batch_size, session_len, vocab_size, vocab_size_padded, tp_group, phases)}\n{\n}\n\nvoid Generation::Run(BatchOp op, int phase, TensorMap& env)\n{\n    if (op == BatchOp::kSetup) {\n        return impl_->Setup(phase, env);\n    }\n    else if (op == BatchOp::kPrepare) {\n        return impl_->Prepare(phase, env);\n    }\n    else if (op == BatchOp::kForward) {\n        return impl_->Forward(phase, env);\n    }\n    else if (op == BatchOp::kUnprep) {\n        return impl_->Unprep(phase, env);\n    }\n    else if (op == BatchOp::kFetch) {\n        return impl_->Fetch(phase, env);\n    }\n    else if (op == BatchOp::kUpdate) {\n        return impl_->Update(phase, env);\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/generation.h",
    "content": "\n\n#pragma once\n\n#include <memory>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/engine/batch.h\"\n\nnamespace turbomind {\n\nnamespace comm {\nclass HostComm;\n}\n\nstruct GenerationData;\n\nclass Generation {\npublic:\n    ~Generation();\n\n    Generation(DataType              data_type,  //\n               int                   max_batch_size,\n               int                   session_len,\n               int                   vocab_size,\n               int                   vocab_size_padded,\n               const comm::HostComm& tp_group,\n               int                   phases);\n\n    void Run(BatchOp op, int phase, TensorMap& env);\n\nprivate:\n    struct Impl;\n\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/guided_decoding.cc",
    "content": "#include \"src/turbomind/generation/guided_decoding.h\"\n\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h\"\n#include \"xgrammar/matcher.h\"\n#include <dlpack/dlpack.h>\n\nnamespace turbomind {\n\nstruct GuidedDecoding::Data {\n    Tensor_<int32_t> bitmask;\n    bool             active{};\n\n    std::vector<std::shared_ptr<xgrammar::GrammarMatcher>> matchers;\n};\n\nGuidedDecoding::GuidedDecoding(const BaseGenerationParam& base, const comm::HostComm& tp_group, int phases):\n    BaseGenerationParam{base},        //\n    tp_group_{tp_group->Split(0, 0)}  // duplicate to avoid data race\n{\n    const auto bitmask_size = xgrammar::GetBitmaskSize(vocab_size_padded_);\n\n    bitmask_buf_    = {{max_batch_size_, bitmask_size}, kCPUpinned};\n    output_ids_buf_ = {max_batch_size_, kCPUpinned};\n\n    for (int i = 0; i < phases; ++i) {\n        auto& d    = data_.emplace_back(std::make_shared<Data>());\n        d->bitmask = empty_like(bitmask_buf_);\n    }\n}\n\nvoid GuidedDecoding::Setup(int phase, TensorMap& env)\n{\n    auto& d = *data_.at(phase);\n    auto& b = *env.at(\"batch\").data<BatchData*>()[0];\n\n    d.matchers.clear();\n    d.active = false;\n    for (const auto& r : b.rc) {\n        if (d.matchers.emplace_back(r->req->matcher)) {\n            d.active = true;\n        }\n    }\n}\n\nvoid GuidedDecoding::FillMask(int phase, TensorMap& env)\n{\n    if (auto& d = *data_.at(phase); d.active) {\n        static_assert(sizeof(ssize_t) == sizeof(int64_t));\n        DLTensor dlbitmask{bitmask_buf_.data(),\n                           DLDevice{kDLCPU, 0},\n                           bitmask_buf_.ndim(),\n                           xgrammar::GetBitmaskDLType(),\n                           (int64_t*)bitmask_buf_.shape().data(),\n                           nullptr,\n                           0};\n        
if (tp_group_->rank() == 0) {\n            for (size_t i = 0; i < d.matchers.size(); ++i) {\n                if (const auto& matcher = d.matchers[i]; matcher && !matcher->IsTerminated()) {\n                    matcher->FillNextTokenBitmask(&dlbitmask, i);\n                }\n                else {\n                    std::fill_n(bitmask_buf_.data() + i * bitmask_buf_.stride(0),\n                                bitmask_buf_.stride(0),\n                                static_cast<int32_t>(-1));\n                }\n            }\n        }\n    }\n}\n\nvoid GuidedDecoding::ApplyMask(int phase, TensorMap& env)\n{\n    if (auto& d = *data_.at(phase); d.active) {\n        const ssize_t numel = d.matchers.size() * bitmask_buf_.stride(0);\n        if (tp_group_->n_ranks() > 1) {\n            // bcast the data instead of `bitmask_buf` instance (which may avoid copying the data)\n            comm::Broadcast(tp_group_, bitmask_buf_.data(), numel, 0);\n        }\n        Copy(bitmask_buf_.buffer(), numel, d.bitmask.buffer());\n        ApplyTokenBitmaskInplace(env.at(\"logits\"), d.bitmask.slice(0, d.matchers.size()));\n    }\n}\n\nvoid GuidedDecoding::Update(int phase, TensorMap& env)\n{\n    if (auto& d = *data_.at(phase); d.active) {\n        Copy(env.at(\"output_ids\").buffer(), d.matchers.size(), output_ids_buf_);\n        core::Context::stream().Sync();\n        if (tp_group_->rank() == 0) {\n            for (size_t i = 0; i < d.matchers.size(); ++i) {\n                if (const auto& matcher = d.matchers[i]; matcher && !matcher->IsTerminated()) {\n                    matcher->AcceptToken(output_ids_buf_[i]);\n                }\n            }\n        }\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/guided_decoding.h",
    "content": "#pragma once\n\n#include <memory>\n\n#include \"src/turbomind/generation/base_param.h\"\n\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nclass GuidedDecoding: public BaseGenerationParam {\npublic:\n    explicit GuidedDecoding(const BaseGenerationParam& base, const comm::HostComm& tp_group, int phases);\n\n    void Setup(int phase, TensorMap& env);\n\n    void FillMask(int phase, TensorMap& env);\n\n    void ApplyMask(int phase, TensorMap& env);\n\n    void Update(int phase, TensorMap& env);\n\nprivate:\n    comm::HostComm tp_group_;\n\n    struct Data;\n    std::vector<std::shared_ptr<Data>> data_;\n\n    Tensor_<int32_t> bitmask_buf_;\n    Buffer_<int>     output_ids_buf_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/logits_processor.cc",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/engine/request.h\"\n\n#include \"src/turbomind/kernels/ban_bad_words.h\"\n#include \"src/turbomind/kernels/sampling_penalty_kernels.h\"\n\n#include \"src/turbomind/generation/logits_processor.h\"\n#include \"src/turbomind/generation/utils.h\"\n\nnamespace turbomind {\n\nstruct LogitsProcessor::Data {\n\n    Data(int max_batch_size, DeviceType device)\n    {\n        repetition_penalty_buf = {max_batch_size, device};\n        min_lengths_buf        = {max_batch_size, device};\n        temperature_buf        = {max_batch_size, device};\n        bad_words_buf          = {max_batch_size * 2 * kMaxStopBadWordsLen, device};\n        end_ids_buf            = {max_batch_size * kMaxEndIdsSize, device};\n    }\n\n    Buffer_<float> repetition_penalty_buf;\n    Buffer_<int>   min_lengths_buf;\n    Buffer_<float> temperature_buf;\n    Buffer_<int>   bad_words_buf;\n    Buffer_<int>   end_ids_buf;\n\n    Tensor_<int> bad_words_ten;\n    Tensor_<int> end_ids_ten;\n\n    bool has_repetition_penalty{};\n    bool has_bad_words_penalty{};\n    bool has_min_length_penalty{};\n    bool has_temperature_penalty{};\n};\n\nLogitsProcessor::LogitsProcessor(const 
BaseGenerationParam& base, int phases): BaseGenerationParam{base}\n{\n    buf_ = std::make_shared<Data>(max_batch_size_, kCPUpinned);\n    for (int i = 0; i < phases; ++i) {\n        data_.push_back(std::make_shared<Data>(max_batch_size_, kDEVICE));\n    }\n}\n\nvoid LogitsProcessor::Forward(int phase, TensorMap& env)\n{\n    // apply repetition penalty -> ban bad words -> min length penalty -> temperature penalty\n    // the order is the same as in transformers\n    TM_LOG_DEBUG(\"%s start\", __PRETTY_FUNCTION__);\n\n    Tensor_<float>      logits          = env.at(\"logits\");\n    const Buffer_<int*> token_ids_ptrs  = env.at(\"token_ids_ptrs\").buffer();\n    const Buffer_<int>  sequence_length = env.at(\"sequence_length\").buffer();\n\n    const auto bsz = logits.shape(0);\n\n    auto& d = *data_.at(phase);\n\n    auto stream = core::Context::stream().handle();\n\n    // repetition penalty\n    if (d.has_repetition_penalty) {\n        ApplyRepetitionPenalty(logits, d.repetition_penalty_buf, token_ids_ptrs, sequence_length, stream);\n        sync_check_cuda_error();\n    }\n\n    // ban bad words\n    if (auto& bad_words = d.bad_words_ten) {\n        BanBadWords(logits, token_ids_ptrs, sequence_length, bad_words, stream);\n        sync_check_cuda_error();\n    }\n\n    // min length\n    if (d.has_min_length_penalty) {\n        invokeMinLengthPenalty(logits.data(),\n                               d.min_lengths_buf.data(),\n                               sequence_length.data(),\n                               vocab_size_padded_,\n                               bsz,\n                               d.end_ids_ten.data(),\n                               d.end_ids_ten.shape(1),\n                               stream);\n        sync_check_cuda_error();\n    }\n\n    // temperature\n    if (d.has_temperature_penalty) {\n        invokeBatchApplyTemperaturePenalty_v2(logits.data(),  //\n                                              (float*)nullptr,\n                           
                   d.temperature_buf.data(),\n                                              bsz,\n                                              vocab_size_,\n                                              vocab_size_padded_,\n                                              stream);\n        sync_check_cuda_error();\n    }\n\n    TM_LOG_DEBUG(\"%s stop\", __PRETTY_FUNCTION__);\n}\n\nvoid LogitsProcessor::Setup(int phase, TensorMap& env)\n{\n    TM_LOG_DEBUG(\"%s start\", __PRETTY_FUNCTION__);\n\n    auto& d = *data_.at(phase);\n\n    const auto& rs   = env.at(\"batch\").data<BatchData*>()[0]->rc;\n    auto&       copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    const int bsz = rs.size();\n\n    auto& repetition_penalty = buf_->repetition_penalty_buf;\n    auto& temperature        = buf_->temperature_buf;\n    auto& min_lengths        = buf_->min_lengths_buf;\n\n    d.has_temperature_penalty = {};\n    d.has_min_length_penalty  = {};\n    d.has_repetition_penalty  = {};\n    d.has_bad_words_penalty   = {};\n\n    for (int i = 0; i < bsz; ++i) {\n        auto& g = rs[i]->gen_cfg;\n\n        // repetition_penalty\n        repetition_penalty[i] = g.repetition_penalty;\n        if (repetition_penalty[i] != 1.f) {\n            d.has_repetition_penalty = true;\n        }\n\n        // temperature\n        temperature[i] = g.temperature;\n        if (g.temperature != 1.f) {\n            d.has_temperature_penalty = true;\n        }\n\n        // min_length\n        min_lengths[i] = rs[i]->prompt_len + g.min_new_tokens;\n        if (rs[i]->seq_len + rs[i]->beta < min_lengths[i]) {\n            d.has_min_length_penalty = true;\n        }\n    }\n\n    if (d.has_temperature_penalty) {\n        copy(temperature, bsz, d.temperature_buf);\n    }\n\n    if (d.has_repetition_penalty) {\n        copy(repetition_penalty, bsz, d.repetition_penalty_buf);\n    }\n\n    if (d.has_min_length_penalty) {\n        copy(min_lengths, bsz, d.min_lengths_buf);\n    }\n\n    
sync_check_cuda_error();\n\n    d.bad_words_ten = {};\n    init_stop_bad_words(&GenerationConfig::bad_ids,  //\n                        \"bad_words\",\n                        rs,\n                        buf_->bad_words_buf.data(),\n                        d.bad_words_buf.data(),\n                        d.bad_words_ten,\n                        copy);\n\n    if (d.has_min_length_penalty) {  // end ids for min length\n        d.end_ids_ten  = {};\n        int max_length = 0;\n        for (int i = 0; i < bsz; ++i) {\n            max_length = std::max(max_length, (int)rs[i]->gen_cfg.eos_ids.size());\n        }\n        if (max_length) {\n            max_length     = std::min(max_length, kMaxEndIdsSize);\n            int* h_end_ids = buf_->end_ids_buf.data();\n            std::fill(h_end_ids, h_end_ids + std::min(kMaxEndIdsSize, max_length) * bsz, -1);\n            for (int i = 0; i < bsz; ++i) {\n                const auto& eos_ids = rs[i]->gen_cfg.eos_ids;\n                if (eos_ids.size() == 0) {\n                    continue;\n                }\n                if (TM_UNLIKELY(eos_ids.size() > kMaxEndIdsSize)) {\n                    TM_LOG_WARNING(\"[InitializeSampling] [%ld] eos length (%d) exceeds %d, truncated to %d\",\n                                   (long)rs[i]->req->id,\n                                   (int)eos_ids.size(),\n                                   kMaxEndIdsSize,\n                                   kMaxEndIdsSize);\n                }\n                std::copy_n(eos_ids.begin(), std::min((int)eos_ids.size(), kMaxEndIdsSize), h_end_ids);\n                h_end_ids += max_length;\n            }\n            copy(buf_->end_ids_buf, bsz * max_length, d.end_ids_buf);\n            d.end_ids_ten = {d.end_ids_buf.data(), {bsz, max_length}, kDEVICE};\n        }\n    }\n\n    TM_LOG_DEBUG(\"%s stop\", __PRETTY_FUNCTION__);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/logits_processor.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <memory>\n\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/generation/base_param.h\"\n\nnamespace turbomind {\n\nclass LogitsProcessor: public BaseGenerationParam {\npublic:\n    explicit LogitsProcessor(const BaseGenerationParam& base, int phases);\n\n    void Setup(int phase, TensorMap& env);\n\n    void Forward(int phase, TensorMap& env);\n\nprivate:\n    struct Data;\n\n    std::vector<std::shared_ptr<Data>> data_;\n\n    std::shared_ptr<Data> buf_;  // temp host buffer\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/sampling.cc",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/generation/sampling.h\"\n\n#include \"src/turbomind/kernels/sampling_kernels.h\"\n#include \"src/turbomind/kernels/sampling_topk_kernels.h\"\n#include \"src/turbomind/kernels/sampling_topp_kernels.h\"\n\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/engine/request.h\"\n\n#include \"src/turbomind/utils/constant.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind {\n\nstruct SamplingData {\n\n    explicit SamplingData(int max_batch_size, DeviceType device)\n    {\n        top_k_buf = {max_batch_size, device};\n        top_p_buf = {max_batch_size, device};\n        min_p_buf = {max_batch_size, device};\n        kept_buf  = {max_batch_size, device};\n\n        sampled_logprobs = {max_batch_size * (ssize_t)kMaxLogProb, device};\n        sampled_indices  = {max_batch_size * (ssize_t)kMaxLogProb, device};\n        sampled_nums     = {max_batch_size, device};\n    }\n\n    int   max_topk = 0;\n    int   min_topk = 0;\n    float min_topp = 0;\n    float max_minp = 0;\n\n    Buffer_<int>   top_k_buf;\n    Buffer_<float> top_p_buf;\n    Buffer_<float> min_p_buf;\n\n    Buffer_<int> kept_buf;  // kept sample\n\n    bool output_logprobs = 0;\n\n    Buffer_<float> sampled_logprobs;\n    Buffer_<int>   sampled_indices;\n    Buffer_<int>   
sampled_nums;\n};\n\nSampling::Sampling(const BaseGenerationParam& base, int phases): BaseGenerationParam{base}\n{\n    top_k_ = {max_batch_size_, kCPUpinned};\n    top_p_ = {max_batch_size_, kCPUpinned};\n    min_p_ = {max_batch_size_, kCPUpinned};\n    kept_  = {max_batch_size_, kCPUpinned};\n\n    sampled_logprobs_buf_ = {max_batch_size_ * (ssize_t)kMaxLogProb, kCPUpinned};\n    sampled_indices_buf_  = {max_batch_size_ * (ssize_t)kMaxLogProb, kCPUpinned};\n    sampled_nums_buf_     = {max_batch_size_, kCPUpinned};\n\n    // constant array\n    std::fill_n(kept_.data(), max_batch_size_, vocab_size_);\n\n    for (int i = 0; i < phases; ++i) {\n        data_.push_back(std::make_shared<SamplingData>(max_batch_size_, kDEVICE));\n    }\n}\n\nvoid Sampling::Forward(int phase, TensorMap& args)\n{\n    // step1:\n    //  - use topk / topp_minp kernel to sort and filter the scores\n    //  - softmax the left score\n    // step2:\n    //  - sampling from left and sorted scores\n\n    TM_LOG_DEBUG(\"%s start\", __PRETTY_FUNCTION__);\n\n    auto& d = *data_.at(phase);\n\n    Tensor_<float> logits = args.at(\"logits\");\n\n    const auto bsz = logits.shape(0);\n\n    Buffer_<int> indices(bsz * vocab_size_padded_, kDEVICE);\n\n    auto stream = core::Context::stream().handle();\n\n    // use topk sort if some request use topk filter\n    if (d.max_topk > 0) {\n        // TODO: top_k >= 64 is much slower than torch.topk()\n        TopKSortFilterParams params{};\n        params.logits            = logits.data();\n        params.sorted_logits     = logits.data();\n        params.sorted_indices    = indices.data();\n        params.kept              = d.kept_buf.data();\n        params.top_ks            = d.top_k_buf.data();\n        params.max_top_k         = d.max_topk;\n        params.batch_size        = bsz;\n        params.vocab_size        = vocab_size_;\n        params.vocab_size_padded = vocab_size_padded_;\n        invokeTopKSortFilter<float>(params, stream);\n    }\n\n    
// use topp sort if some request skip topk filter\n    if (d.min_topk == 0) {\n        invokeSoftmax<float>(logits.data(), vocab_size_padded_, vocab_size_, bsz, d.kept_buf.data(), stream);\n\n        TopPSortParams params{};\n        params.logits            = logits.data();\n        params.sorted_logits     = logits.data();\n        params.sorted_indices    = indices.data();\n        params.kept              = d.kept_buf.data();\n        params.top_ks            = d.top_k_buf.data();\n        params.top_ps            = d.top_p_buf.data();\n        params.batch_size        = bsz;\n        params.vocab_size        = vocab_size_;\n        params.vocab_size_padded = vocab_size_padded_;\n        invokeTopPSort<float>(params, stream);\n    }\n\n    // apply topp minp filter\n    if (d.max_minp != 0.f || d.min_topp != 1.f) {\n        TopPMinPFilterParams params{};\n        params.sorted_logits     = logits.data();\n        params.sorted_indices    = indices.data();\n        params.kept              = d.kept_buf.data();\n        params.top_ps            = d.top_p_buf.data();\n        params.min_ps            = d.min_p_buf.data();\n        params.batch_size        = bsz;\n        params.vocab_size        = vocab_size_;\n        params.vocab_size_padded = vocab_size_padded_;\n        invokeTopPMinPFilter<float>(params, stream);\n    }\n\n    // sample\n    {\n        SamplingParams params{};\n        params.logits          = logits.data();\n        params.stride          = vocab_size_padded_;\n        params.indices         = indices.data();\n        params.kept            = d.kept_buf.data();\n        params.curandstate     = (curandState_t*)args.at(\"curand_state\").raw_data();\n        params.batch_size      = bsz;\n        params.output_ids      = args.at(\"output_ids\").data<int>();  // (B, 1)\n        params.sequence_length = args.at(\"sequence_length\").data<int>();\n\n        if (d.output_logprobs) {\n            params.sampled_logprobs = 
d.sampled_logprobs.data();\n            params.sampled_indexes  = d.sampled_indices.data();\n            params.sampled_nums     = d.sampled_nums.data();\n        }\n\n        invokeSampling<float>(params, stream);\n        sync_check_cuda_error();\n    }\n\n    TM_LOG_DEBUG(\"%s stop\", __PRETTY_FUNCTION__);\n}\n\nvoid Sampling::Setup(int phase, TensorMap& env)\n{\n\n    const auto& rc   = env.at(\"batch\").data<BatchData*>()[0]->rc;\n    auto&       copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    const auto bsz = rc.size();\n\n    for (int i = 0; i < bsz; ++i) {\n        top_k_[i] = rc[i]->gen_cfg.top_k;\n        top_p_[i] = rc[i]->gen_cfg.top_p;\n        min_p_[i] = rc[i]->gen_cfg.min_p;\n    }\n\n    auto& d = *data_.at(phase);\n\n    d.max_topk = *std::max_element(top_k_.begin(), top_k_.begin() + bsz);\n    d.min_topk = *std::min_element(top_k_.begin(), top_k_.begin() + bsz);\n    d.min_topp = *std::min_element(top_p_.begin(), top_p_.begin() + bsz);\n    d.max_minp = *std::max_element(min_p_.begin(), min_p_.begin() + bsz);\n\n    copy(top_k_.data(), bsz, d.top_k_buf.data());\n    copy(top_p_.data(), bsz, d.top_p_buf.data());\n\n    copy(min_p_.data(), bsz, d.min_p_buf.data());\n    copy(kept_.data(), bsz, d.kept_buf.data());\n\n    d.output_logprobs = std::any_of(rc.begin(), rc.end(), [](auto& x) { return x->gen_cfg.output_logprobs; });\n}\n\nvoid Sampling::Fetch(int phase, TensorMap& env)\n{\n    auto& d    = *data_.at(phase);\n    auto& b    = *env.at(\"batch\").data<BatchData*>()[0];\n    auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    if (d.output_logprobs) {\n        copy(d.sampled_logprobs, b.bsz * kMaxLogProb, sampled_logprobs_buf_);\n        copy(d.sampled_indices, b.bsz * kMaxLogProb, sampled_indices_buf_);\n        copy(d.sampled_nums, b.bsz, sampled_nums_buf_);\n    }\n}\n\nvoid Sampling::Update(int phase, TensorMap& env)\n{\n    auto& d = *data_.at(phase);\n    auto& b = *env.at(\"batch\").data<BatchData*>()[0];\n\n    if 
(d.output_logprobs) {\n        float* logprob_buf = sampled_logprobs_buf_.data();\n        int*   indices_buf = sampled_indices_buf_.data();\n        int*   n_buf       = sampled_nums_buf_.data();\n        for (int i = 0; i < b.rc.size(); ++i) {\n            if (auto& x = *b.rc[i]; x.gen_cfg.output_logprobs) {\n                // output buffers\n                auto logprob_out = x.req->outputs.at(\"logprob_vals\").data<float>();\n                auto indices_out = x.req->outputs.at(\"logprob_indexes\").data<int>();\n                auto n_out       = x.req->outputs.at(\"logprob_nums\").data<int>();\n                // offset into output buffers\n                const int offset = x.seq_len - x.prompt_len;\n                std::copy_n(logprob_buf + i * kMaxLogProb, n_buf[i], logprob_out + offset * kMaxLogProb);\n                std::copy_n(indices_buf + i * kMaxLogProb, n_buf[i], indices_out + offset * kMaxLogProb);\n                n_out[offset] = n_buf[i];\n            }\n        }\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/sampling.h",
    "content": "\n#pragma once\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/generation/base_param.h\"\n\nnamespace turbomind {\n\nstruct SamplingData;\n\nclass Sampling: public BaseGenerationParam {\npublic:\n    explicit Sampling(const BaseGenerationParam& base, int phases);\n\n    void Setup(int phase, TensorMap& env);\n\n    void Forward(int phase, TensorMap& env);\n\n    void Fetch(int phase, TensorMap& env);\n\n    void Update(int phase, TensorMap& env);\n\nprivate:\n    std::vector<std::shared_ptr<SamplingData>> data_;\n\n    // host buffer\n    Buffer_<int>   kept_;\n    Buffer_<int>   top_k_;\n    Buffer_<float> top_p_;\n    Buffer_<float> min_p_;\n\n    Buffer_<float> sampled_logprobs_buf_;\n    Buffer_<int>   sampled_indices_buf_;\n    Buffer_<int>   sampled_nums_buf_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/stop_criteria.cc",
    "content": "\n\n#include \"src/turbomind/generation/stop_criteria.h\"\n#include \"src/turbomind/generation/utils.h\"\n\n#include \"src/turbomind/kernels/stop_criteria_kernels.h\"\n\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/engine/request.h\"\n\nnamespace turbomind {\n\nstruct StopCriteriaData {\n    explicit StopCriteriaData(int batch_size)\n    {\n        stop_words  = {batch_size * 2 * kMaxStopBadWordsLen, kDEVICE};\n        max_seq_len = {batch_size, kDEVICE};\n    }\n    Buffer_<int> stop_words;\n    Buffer_<int> max_seq_len;\n    Tensor_<int> stop_words_ten;  // view into `stop_words`\n};\n\nStopCriteria::StopCriteria(const BaseGenerationParam& base, int phases): BaseGenerationParam{base}\n{\n    stop_words_buf_  = {max_batch_size_ * 2 * kMaxStopBadWordsLen, kCPUpinned};\n    max_seq_len_buf_ = {max_batch_size_, kCPUpinned};\n    for (int i = 0; i < phases; ++i) {\n        data_.push_back(std::make_shared<StopCriteriaData>(max_batch_size_));\n    }\n}\n\nvoid StopCriteria::Setup(int phase, TensorMap& env)\n{\n    auto& d = *data_.at(phase);\n\n    const auto& rs   = env.at(\"batch\").data<BatchData*>()[0]->rc;\n    auto&       copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    for (int i = 0; i < rs.size(); ++i) {\n        max_seq_len_buf_[i] = rs[i]->max_seq_len;\n    }\n    copy(max_seq_len_buf_, rs.size(), d.max_seq_len);\n\n    d.stop_words_ten = {};\n    init_stop_bad_words(&GenerationConfig::stop_ids,  //\n                        \"stop_words\",\n                        rs,\n                        stop_words_buf_.data(),\n                        d.stop_words.data(),\n                        d.stop_words_ten,\n                        copy);\n}\n\nvoid StopCriteria::Forward(int phase, TensorMap& env)\n{\n    auto& d = *data_.at(phase);\n\n    const Buffer_<int*> token_ids_ptrs  = env.at(\"token_ids_ptrs\").buffer();\n    const Buffer_<int>  sequence_length = env.at(\"sequence_length\").buffer();\n\n    Buffer_<bool> finished = env.at(\"finished\").buffer();\n\n    const int batch_size = token_ids_ptrs.size();\n\n    auto stream = core::Context::stream().handle();\n\n    if (auto& stop_words = d.stop_words_ten) {\n        TM_CHECK_EQ(stop_words.ndim(), 3);  // [batch, 2, len]\n        size_t stop_words_len = stop_words.shape(2);\n        invokeStopWordsCriterion_v2((const int**)token_ids_ptrs.data(),\n                                    sequence_length.data(),\n                                    stop_words.data(),\n                                    finished.data(),\n                                    stop_words_len,\n                                    batch_size,\n                                    stream);\n        sync_check_cuda_error();\n    }\n\n    invokeLengthCriterion_v2(finished.data(),  //\n                             sequence_length.data(),\n                             d.max_seq_len.data(),\n                             batch_size,\n                             stream);\n    sync_check_cuda_error();\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/stop_criteria.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/generation/base_param.h\"\n\nnamespace turbomind {\n\nstruct StopCriteriaData;\n\nclass StopCriteria: public BaseGenerationParam {\npublic:\n    explicit StopCriteria(const BaseGenerationParam& base, int phases);\n\n    void Setup(int phase, TensorMap& env);\n\n    void Forward(int phase, TensorMap& env);\n\nprivate:\n    std::vector<std::shared_ptr<StopCriteriaData>> data_;\n\n    Buffer_<int> stop_words_buf_;\n    Buffer_<int> max_seq_len_buf_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/generation/utils.h",
    "content": "\n#pragma once\n\n#include <algorithm>\n#include <functional>\n#include <vector>\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nconstexpr int kMaxStopBadWordsLen = 32;\nconstexpr int kMaxEndIdsSize      = 32;\n\nnamespace {\n\ntemplate<class G, class Rs, class T, class Copy>\nvoid init_stop_bad_words(G getter, const char* key, const Rs& rs, T* h_buf, T* d_buf, Tensor_<T>& out, Copy& copy)\n{\n    const int bsz        = rs.size();\n    int       max_length = 0;\n\n    std::vector<std::pair<const int*, int>> copy_tokens(bsz);\n    std::vector<std::pair<const int*, int>> copy_offsets(bsz);\n    for (int i = 0; i < bsz; ++i) {\n        const auto& [token_ids, offsets] = std::invoke(getter, rs[i]->gen_cfg);\n        if (offsets.size() == 0 || token_ids.size() == 0) {\n            continue;\n        }\n        FT_CHECK(offsets.back() == token_ids.size());\n        if (offsets.back() <= kMaxStopBadWordsLen) {\n            copy_tokens[i]  = std::make_pair(token_ids.data(), (int)token_ids.size());\n            copy_offsets[i] = std::make_pair(offsets.data(), (int)offsets.size());\n            max_length      = std::max(max_length, (int)token_ids.size());\n        }\n        else {\n            auto trunc_offset_size =\n                std::upper_bound(offsets.begin(),\n                                 offsets.begin() + std::min(kMaxStopBadWordsLen, (int)offsets.size()),\n                                 kMaxStopBadWordsLen)\n                - offsets.begin();\n            TM_LOG_WARNING(\"[InitializeSampling] [%ld] %s length (%d) exceeds %d, truncated to %d\",\n                           (long)rs[i]->req->id,\n                           key,\n                           offsets.back(),\n                           kMaxStopBadWordsLen,\n                           (int)trunc_offset_size);\n            if (trunc_offset_size > 0) {\n                int trunc_token_size = offsets[trunc_offset_size - 1];\n                copy_tokens[i]       = std::make_pair(token_ids.data(), trunc_token_size);\n                copy_offsets[i]      = std::make_pair(offsets.data(), trunc_offset_size);\n                max_length           = std::max(max_length, trunc_token_size);\n            }\n        }\n    }\n    if (!max_length) {\n        return;\n    }\n    std::fill_n(h_buf, bsz * 2 * max_length, -1);\n    for (int i = 0; i < bsz; ++i) {\n        if (copy_tokens[i].first != nullptr) {\n            std::copy_n(copy_tokens[i].first, copy_tokens[i].second, h_buf + i * 2 * max_length);\n        }\n        if (copy_offsets[i].first != nullptr) {\n            std::copy_n(copy_offsets[i].first, copy_offsets[i].second, h_buf + i * 2 * max_length + max_length);\n        }\n    }\n    copy(h_buf, bsz * 2 * max_length, d_buf);\n    // Construct a tensor from the device buffer\n    out = {d_buf, {bsz, 2, max_length}, kDEVICE};\n}\n\n}  // namespace\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/CMakeLists.txt",
    "content": "# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ncmake_minimum_required(VERSION 3.11)\n\nadd_library(ban_bad_words STATIC ban_bad_words.cu)\nset_property(TARGET ban_bad_words PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET ban_bad_words PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(stop_criteria STATIC stop_criteria_kernels.cu)\nset_property(TARGET stop_criteria PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET stop_criteria PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(activation_kernels STATIC activation_kernels.cu)\nset_property(TARGET activation_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET activation_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(activation STATIC activation.cu)\nset_property(TARGET activation PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET activation PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\n\nadd_library(quantization_kernels STATIC quantization.cu)\nset_property(TARGET quantization_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET quantization_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nif (BUILD_TEST)\nadd_executable(test_quantization test_quantization.cc gemm/test/test_utils.cu)\ntarget_link_libraries(test_quantization PRIVATE quantization_kernels core)\nendif ()\n\nadd_library(logprob_kernels STATIC 
logprob_kernels.cu)\nset_property(TARGET logprob_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET logprob_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(unfused_attention_kernels STATIC unfused_attention_kernels.cu)\nset_property(TARGET unfused_attention_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET unfused_attention_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(decoding_kernels STATIC decoding_kernels.cu)\nset_property(TARGET decoding_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET decoding_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(gpt_kernels STATIC gpt_kernels.cu)\nset_property(TARGET gpt_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET gpt_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(sampling_topk_kernels STATIC sampling_topk_kernels.cu)\nset_property(TARGET sampling_topk_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET sampling_topk_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(sampling_topp_kernels STATIC sampling_topp_kernels.cu)\nset_property(TARGET sampling_topp_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET sampling_topp_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(sampling_penalty_kernels STATIC sampling_penalty_kernels.cu)\nset_property(TARGET sampling_penalty_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET sampling_penalty_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(sampling_kernels STATIC sampling_kernels.cu)\nset_property(TARGET sampling_kernels PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET sampling_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_library(apply_token_bitmask_inplace_cuda STATIC apply_token_bitmask_inplace_cuda.cu)\nset_property(TARGET apply_token_bitmask_inplace_cuda PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET 
apply_token_bitmask_inplace_cuda PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\n\nadd_subdirectory(attention)\nadd_subdirectory(gemm)\nadd_subdirectory(norm)\n"
  },
  {
    "path": "src/turbomind/kernels/activation.cu",
    "content": "\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/activation.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n\nnamespace turbomind {\n\ntemplate<class T>\nstruct SiluGptOss {\n    __device__ T operator()(T gate, T up) const noexcept\n    {\n        gate = __hmin((T)7.f, gate);\n        up   = __hmax((T)-7.f, __hmin((T)7.f, up));\n        return static_cast<T>(fdividef((float)gate, 1.f + expf((float)-gate * 1.702f)) * (1.f + (float)up));\n    }\n};\n\ntemplate<class T>\nstruct Silu {\n    __device__ T operator()(T gate, T up) const noexcept\n    {\n        return static_cast<T>(fdividef((float)gate, 1.f + expf(-(float)gate)) * (float)up);\n    }\n};\n\ntemplate<int vec_size, class Activation, class T>\n__global__ void ActivationKernel(\n    T* gate_buf, const T* __restrict__ up_buf, Activation activation, int64_t stride, int token_num, int dim)\n{\n    if constexpr (TURBOMIND_ARCH_DTYPE_GUARD(data_type_v<T>)) {\n        const int di = threadIdx.x + blockIdx.y * blockDim.x;\n        const int ti = blockIdx.x;\n\n        dim /= vec_size;\n\n        if (di >= dim) {\n            return;\n        }\n\n        using Vec = Array<T, vec_size>;\n\n        auto p_gate = reinterpret_cast<Vec*>(gate_buf + ti * stride);\n        auto p_up   = reinterpret_cast<const Vec*>(up_buf + ti * stride);\n\n        Vec gate;\n        Load(gate, (const T*)&p_gate[di]);\n\n        Vec up;\n        Ldg(up, (T*)&p_up[di]);\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < vec_size; ++i) {\n            gate[i] = activation(gate[i], up[i]);\n        }\n\n        Store((T*)&p_gate[di], gate);\n    }\n}\n\nvoid Activation(Ref<Tensor> gate_, const Tensor& up, ActivationType type, cudaStream_t stream)\n{\n    auto& gate = gate_.get();\n\n    TM_CHECK(gate.shape() == up.shape());\n\n    int num, dim;\n    std::tie(num, dim) = gate.shapes(0, 1);\n\n    auto invoke = [&](auto t, auto act) 
{\n        using T = decltype(t);\n\n        constexpr int vec_size = 4;\n        constexpr int threads  = 512;\n\n        const dim3 blocks(num, cdiv(dim, threads * vec_size));\n\n        ActivationKernel<vec_size><<<blocks, threads, 0, stream>>>(gate.data<T>(),  //\n                                                                   up.data<T>(),\n                                                                   act,\n                                                                   gate.stride(0),\n                                                                   num,\n                                                                   dim);\n    };\n\n    auto dispatch = [&](auto t) {\n        using T = decltype(t);\n        if (type == ActivationType::kSilu) {\n            return invoke(t, Silu<T>{});\n        }\n        else if (type == ActivationType::kSiluGptOss) {\n            return invoke(t, SiluGptOss<T>{});\n        }\n        else {\n            TM_CHECK(0) << \"unknown activation type: \" << (int)type;\n        }\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(gate.dtype(), dispatch);\n}\n\ntemplate<int vec_size, class Activation, class T>\n__global__ void ActivationKernel(\n    T* gate_up, const T* bias, const int* group_ids, int64_t stride, Activation activation, int token_num, int dim)\n{\n    if constexpr (TURBOMIND_ARCH_DTYPE_GUARD(data_type_v<T>)) {\n        const int di = (threadIdx.x + blockIdx.y * blockDim.x) * vec_size;\n        const int ti = blockIdx.x;\n        const int gi = group_ids ? 
group_ids[ti] : 0;\n\n        if (di >= dim) {\n            return;\n        }\n\n        using Vec = Array<T, vec_size>;\n\n        Vec gate_bias{}, up_bias{};\n        Ldg(gate_bias, &bias[gi * stride + di]);\n        Ldg(up_bias, &bias[gi * stride + dim + di]);\n\n        Vec gate, up;\n        Load(gate, &gate_up[ti * stride + di]);\n        Load(up, &gate_up[ti * stride + dim + di]);\n\n        {\n            using namespace ops;\n            gate = gate + gate_bias;\n            up   = up + up_bias;\n        }\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < vec_size; ++i) {\n            gate[i] = activation(gate[i], up[i]);\n        }\n\n        Store(&gate_up[ti * stride + di], gate);\n    }\n}\n\nvoid Activation(Tensor&             gate_up,  //\n                const Tensor&       bias,\n                const Buffer_<int>& group_ids,\n                ActivationType      type,\n                cudaStream_t        stream)\n{\n    const int num = gate_up.shape(0);\n    const int dim = gate_up.shape(1) / 2;\n\n    if (!bias) {\n        Activation(gate_up.slice({0, 0}, {-1, dim}),  //\n                   gate_up.slice({0, dim}, {-1, -1}),\n                   type,\n                   stream);\n        return;\n    }\n\n    TM_CHECK_EQ(gate_up.shape(-1), bias.shape(-1));\n\n    auto invoke = [&](auto t, auto act) {\n        using T = decltype(t);\n\n        constexpr int vec_size = 4;\n        constexpr int threads  = 512;\n\n        const dim3 blocks(num, cdiv(dim, threads * vec_size));\n\n        ActivationKernel<vec_size><<<blocks, threads, 0, stream>>>(gate_up.data<T>(),  //\n                                                                   bias.data_or((T*)nullptr),\n                                                                   group_ids.data_or(nullptr),\n                                                                   gate_up.stride(0),\n                                                                   act,\n                                  
                                 num,\n                                                                   dim);\n    };\n\n    auto dispatch = [&](auto t) {\n        using T = decltype(t);\n        if (type == ActivationType::kSilu) {\n            return invoke(t, Silu<T>{});\n        }\n        else if (type == ActivationType::kSiluGptOss) {\n            return invoke(t, SiluGptOss<T>{});\n        }\n        else {\n            TM_CHECK(0) << \"unknown activation type: \" << (int)type;\n        }\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(gate_up.dtype(), dispatch);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/activation.h",
    "content": "#pragma once\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nenum class ActivationType\n{\n    kSilu,\n    kSiluGptOss\n};\n\nvoid Activation(Ref<Tensor> gate, const Tensor& up, ActivationType type, cudaStream_t stream);\n\nvoid Activation(Tensor&             gate_up,  //\n                const Tensor&       bias,\n                const Buffer_<int>& group_ids,\n                ActivationType      type,\n                cudaStream_t        stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/activation_kernels.cu",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/activation_kernels.h\"\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/macro.h\"\n#include \"src/turbomind/utils/cuda_type_utils.cuh\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\n#ifndef CUDART_VERSION\n#error CUDART_VERSION Undefined!\n#endif\n\nnamespace turbomind {\n\n/* Gelu Activation */\n\n__forceinline__ __device__ float copysignf_pos(float a, float b)\n{\n    float r;\n    r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));\n    return r;\n}\n\n__inline__ __device__ float tanh_opt(float x)\n{\n#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)\n    float r;\n    asm(\"tanh.approx.f32 %0,%1; \\n\\t\" : \"=f\"(r) : \"f\"(x));\n    return r;\n#else\n    const float exp_val = -1.f * fabs(2 * x);\n    return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);\n#endif\n}\n\ntemplate<typename T>\nstruct GeluActivation {\n    using return_type = T;\n    static __device__ __forceinline__ T apply(const T& val)\n    {\n        const float cdf = 0.5f * (1.0f + 
tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val))));\n        return val * cdf;\n    }\n};\n\ntemplate<>\nstruct GeluActivation<half2> {\n    using return_type = half2;\n    static __device__ __forceinline__ half2 apply(const half2& val)\n    {\n        half2  val_pow3 = __hmul2(val, __hmul2(val, val));\n        float2 tmp_pow  = __half22float2(val_pow3);\n        float2 tmp      = __half22float2(val);\n\n        tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));\n        tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));\n        return __hmul2(val, __float22half2_rn(tmp));\n    }\n};\n\n#ifdef ENABLE_BF16\ntemplate<>\nstruct GeluActivation<__nv_bfloat162> {\n    using return_type = __nv_bfloat162;\n    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)\n    {\n        __nv_bfloat162 val_pow3 = bf16hmul2(val, bf16hmul2(val, val));\n        float2         tmp_pow  = bf1622float2(val_pow3);\n        float2         tmp      = bf1622float2(val);\n\n        tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));\n        tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));\n        return bf16hmul2(val, __floats2bfloat162_rn(tmp.x, tmp.y));\n    }\n};\n#endif\n\n/* Relu Activation */\n\ntemplate<typename T>\nstruct ReluActivation {\n    using return_type = T;\n    static __device__ __forceinline__ T apply(const T& val)\n    {\n        return val > static_cast<T>(0.0f) ? val : static_cast<T>(0.0f);\n    }\n};\n\ntemplate<>\nstruct ReluActivation<half2> {\n    using return_type = half2;\n    static __device__ __forceinline__ half2 apply(const half2& val)\n    {\n        const half zero_half = static_cast<half>(0.0f);\n        return make_half2(val.x > zero_half ? val.x : zero_half, val.y > zero_half ? 
val.y : zero_half);\n    }\n};\n\n#ifdef ENABLE_BF16\ntemplate<>\nstruct ReluActivation<__nv_bfloat162> {\n    using return_type = __nv_bfloat162;\n    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)\n    {\n        const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f);\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)\n        return turbomind::make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);\n#else\n        return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);\n#endif\n    }\n};\n#endif\n\n/* Silu Activation */\n\ntemplate<typename T>\nstruct SiluActivation {\n    using return_type = T;\n    static __device__ __forceinline__ T apply(const T& val)\n    {\n        return (T)((float)val / (1.0f + __expf((float)-val)));\n    }\n};\n\ntemplate<>\nstruct SiluActivation<half2> {\n    using return_type = float2;\n    static __device__ __forceinline__ float2 apply(const half2& val)\n    {\n        return make_float2(SiluActivation<float>::apply(val.x), SiluActivation<float>::apply(val.y));\n    }\n};\n\n#ifdef ENABLE_BF16\ntemplate<>\nstruct SiluActivation<__nv_bfloat162> {\n    using return_type = float2;\n    static __device__ __forceinline__ float2 apply(const __nv_bfloat162& val)\n    {\n        return make_float2(SiluActivation<float>::apply(val.x), SiluActivation<float>::apply(val.y));\n    }\n};\n#endif  // ENABLE_BF16\n\n/* Identity Activation (= no activation) */\n\ntemplate<typename T>\nstruct IdentityActivation {\n    using return_type = T;\n    static __device__ __forceinline__ T apply(const T& val)\n    {\n        return val;\n    }\n};\n\n// `output` may be an alias of `inter_buf`\ntemplate<int VecSize, template<typename T> class Activation, typename T>\n__global__ void activation_kernel(T* inter_buf, const T* __restrict__ gate_buf, int64_t stride, int token_num, int dims)\n{\n    const int di = threadIdx.x + blockIdx.y 
* blockDim.x;\n    const int ti = blockIdx.x;\n\n    dims /= VecSize;\n\n    if (di >= dims) {\n        return;\n    }\n\n    using Vec = Array<T, VecSize>;\n\n    auto p_inter = reinterpret_cast<Vec*>(inter_buf + ti * stride);\n    auto p_gate  = reinterpret_cast<const Vec*>(gate_buf + ti * stride);\n\n    Vec inter;\n    Load(inter, (T*)&p_inter[di]);\n\n    Vec gate;\n    Ldg(gate, (const T*)&p_gate[di]);\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < VecSize; ++i) {\n        inter[i] = Activation<T>::apply(inter[i]) * gate[i];\n    }\n\n    Store((T*)&p_inter[di], inter);\n}\n\ntemplate<template<typename T> class Activation, typename T>\nvoid invokeGenericActivation_v2(\n    T* inter_buf, const T* __restrict__ gate_buf, int64_t stride, int token_num, int dims, cudaStream_t stream)\n{\n    constexpr int kVecSize = 4;\n\n    constexpr int block = 512;\n    const dim3    grid(token_num, ceil_div(dims, block * kVecSize));\n\n    activation_kernel<kVecSize, Activation, T>\n        <<<grid, block, 0, stream>>>(inter_buf, gate_buf, stride, token_num, dims);\n}\n\ntemplate<template<typename T> class Activation>\nvoid invokeGenericActivation_v3(Ref<Tensor> inter_, const Tensor& gate, cudaStream_t stream)\n{\n    auto& inter = inter_.get();\n    TM_CHECK_EQ(inter.ndim(), 2);\n    TM_CHECK_EQ(gate.ndim(), 2);\n    TM_CHECK_EQ(inter.stride(0), gate.stride(0));\n\n    TM_CHECK(inter.shape() == gate.shape());\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n\n        const auto [num, dim] = inter.shapes(0, 1);\n\n        constexpr int kVecSize = 4;\n        constexpr int block    = 512;\n\n        const dim3 grid(num, cdiv((int)dim, block * kVecSize));\n\n        activation_kernel<kVecSize, Activation, T>\n            <<<grid, block, 0, stream>>>(inter.data<T>(), gate.data<T>(), inter.stride(0), num, dim);\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(inter.dtype(), invoke);\n}\n\ntemplate void invokeGenericActivation_v3<SiluActivation>(Ref<Tensor> inter_, 
const Tensor& gate, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/activation_kernels.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\n// clang-format off\ntemplate<typename T> struct GeluActivation;\ntemplate<typename T> struct ReluActivation;\ntemplate<typename T> struct SiluActivation;\ntemplate<typename T> struct IdentityActivation;\n// clang-format on\n\ntemplate<template<typename T> class Activation>\nvoid invokeGenericActivation_v3(Ref<Tensor> inter_, const Tensor& gate, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.cu",
    "content": "// Modified from xgrammar python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu\n\n/*\n * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// clang-format off\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h\"\n// clang-format on\n\nusing namespace std;\n\n#ifndef CUDART_INF_FP16\n#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)\n#endif\n\n#if __CUDA_ARCH__ >= 800\n#ifndef CUDART_INF_BF16\n#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)\n#endif\n#endif\n\nconstexpr int32_t BITS_PER_BLOCK           = 32;\nconstexpr int32_t THREADS_PER_THREAD_BLOCK = 256;\n\ntemplate<typename T>\n__device__ T NegativeInfinity()\n{\n    return -INFINITY;\n}\n\ntemplate<>\n__device__ __half NegativeInfinity<__half>()\n{\n    return -CUDART_INF_FP16;\n}\n\n#if __CUDA_ARCH__ >= 800\ntemplate<>\n__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>()\n{\n    return -CUDART_INF_BF16;\n}\n#endif\n\ntemplate<typename T, typename PackedT>\n__device__ PackedT PackedNegativeInfinity()\n{\n    constexpr int kAlignment = sizeof(PackedT) / sizeof(T);\n    T             packed[kAlignment];\n#pragma unroll\n    for (int i = 0; i < kAlignment; 
i++) {\n        packed[i] = NegativeInfinity<T>();\n    }\n    return *reinterpret_cast<PackedT*>(packed);\n}\n\ntemplate<typename T, typename PackedT, int32_t kBitsPerThread>\n__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(T* __restrict__ logits,\n                                                                                const int32_t* __restrict__ bitmask,\n                                                                                const int32_t* __restrict__ indices,\n                                                                                int32_t vocab_size,\n                                                                                int32_t logits_stride,\n                                                                                int32_t bitmask_stride)\n{\n    constexpr int      kAlignment  = sizeof(PackedT) / sizeof(T);\n    constexpr uint32_t kPackedMask = (1 << kAlignment) - 1;\n\n    const int batch_idx = (indices == nullptr) ? 
blockIdx.y : indices[blockIdx.y];\n\n    const int      block_offset      = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread;\n    T*             logits_gmem_ptr   = logits + batch_idx * logits_stride + block_offset;\n    const int32_t* bitmask_gmem_ptr  = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK;\n    const int      bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment);\n    T              logits_reg[kAlignment];\n\n#pragma unroll\n    for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread;\n         offset += THREADS_PER_THREAD_BLOCK * kAlignment) {\n        if (block_offset + offset >= vocab_size) {\n            break;\n        }\n\n        const uint32_t bitmask_val =\n            (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask;\n\n        if (bitmask_val == 0) {\n            continue;\n        }\n\n        if (bitmask_val == kPackedMask) {\n            *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = PackedNegativeInfinity<T, PackedT>();\n            continue;\n        }\n\n        *reinterpret_cast<PackedT*>(logits_reg) = *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset);\n#pragma unroll\n        for (int i = 0; i < kAlignment; i++) {\n            if (((bitmask_val >> i) & 1)) {\n                logits_reg[i] = NegativeInfinity<T>();\n            }\n        }\n        *reinterpret_cast<PackedT*>(logits_gmem_ptr + offset) = *reinterpret_cast<PackedT*>(logits_reg);\n    }\n}\n\ntemplate<typename T, typename = std::enable_if_t<std::is_integral<T>::value>>\nconstexpr auto CeilDiv(T numerator, T denominator)\n{\n    return (numerator + denominator - 1) / denominator;\n}\n\ntemplate<typename T, typename PackedT>\nvoid ApplyTokenBitmaskInplaceDispatchToBitsPerThread(T* __restrict__ logits,\n                                                     const int32_t* __restrict__ bitmask,\n                                                
     const int32_t* __restrict__ indices,\n                                                     int32_t vocab_size,\n                                                     int32_t logits_stride,\n                                                     int32_t bitmask_stride,\n                                                     int32_t num_rows)\n{\n    constexpr int kAlignment          = sizeof(PackedT) / sizeof(T);\n    const int32_t num_blocks_per_row  = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);\n    const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);\n\n    const dim3  block(THREADS_PER_THREAD_BLOCK);\n    const auto& stream = turbomind::core::Context::stream();\n\n    if (num_bits_per_thread <= 4 && kAlignment <= 4) {\n        const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);\n        LogitsBitmaskKernel<T, PackedT, 4>\n            <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);\n    }\n    else if (num_bits_per_thread <= 8 && kAlignment <= 8) {\n        const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);\n        LogitsBitmaskKernel<T, PackedT, 8>\n            <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);\n    }\n    else if (num_bits_per_thread <= 16 && kAlignment <= 16) {\n        const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);\n        LogitsBitmaskKernel<T, PackedT, 16>\n            <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);\n    }\n    else {\n        const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);\n        LogitsBitmaskKernel<T, PackedT, 32>\n            <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);\n    }\n}\n\ntemplate<typename 
T>\nvoid ApplyTokenBitmaskInplaceDispatchToPackedT(T* __restrict__ logits,\n                                               const int32_t* __restrict__ bitmask,\n                                               const int32_t* __restrict__ indices,\n                                               int32_t vocab_size,\n                                               int32_t logits_stride,\n                                               int32_t bitmask_stride,\n                                               int32_t num_rows)\n{\n    if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) {\n        ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, float4>(\n            logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);\n    }\n    else {\n        ApplyTokenBitmaskInplaceDispatchToBitsPerThread<T, T>(\n            logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);\n    }\n}\n\nnamespace turbomind {\nusing namespace turbomind::core;\n\nvoid ApplyTokenBitmaskInplace(Tensor logits, Tensor bitmask, std::optional<Tensor> indices)\n{\n    std::pair<int32_t, int32_t> logits_shape =\n        logits.ndim() == 2 ?\n            std::make_pair(static_cast<int32_t>(logits.shape(0)), static_cast<int32_t>(logits.shape(1))) :\n            std::make_pair(1, static_cast<int32_t>(logits.shape(0)));\n\n    std::pair<int32_t, int32_t> bitmask_shape =\n        bitmask.ndim() == 2 ?\n            std::make_pair(static_cast<int32_t>(bitmask.shape(0)), static_cast<int32_t>(bitmask.shape(1))) :\n            std::make_pair(1, static_cast<int32_t>(bitmask.shape(0)));\n\n    int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK);\n\n    int32_t  num_rows    = logits_shape.first;\n    int32_t* indices_ptr = nullptr;\n    if (indices) {\n        num_rows    = indices->shape(0);\n        indices_ptr = indices->data<int32_t>();\n    }\n    else {\n        TM_CHECK(logits_shape.first == bitmask_shape.first) << \"logits 
and bitmask must have the same batch size.\";\n    }\n\n    // Currently we use only float logits.\n    TM_CHECK(logits.dtype() == kFloat32);\n    ApplyTokenBitmaskInplaceDispatchToPackedT(logits.data<float>(),\n                                              bitmask.data<int32_t>(),\n                                              indices_ptr,\n                                              vocab_size,\n                                              logits.stride(0),\n                                              bitmask.stride(0),\n                                              num_rows);\n}\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h",
    "content": "#pragma once\n\n#include <optional>\n\n#include \"src/turbomind/core/tensor.h\"\n\nnamespace turbomind {\nvoid ApplyTokenBitmaskInplace(core::Tensor                logits,\n                              core::Tensor                bitmask,\n                              std::optional<core::Tensor> indices = std::nullopt);\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nadd_subdirectory(kernel)\n\nadd_library(attention STATIC\n            attention.cu\n            decoding.cu\n            kv_cache_utils_v2.cu\n            cp_utils.cu\n            registry.cu\n            )\nset_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\ntarget_compile_options(attention PRIVATE -O3\n    $<$<COMPILE_LANGUAGE:CUDA>:-use_fast_math --expt-relaxed-constexpr -Xptxas=-v --threads 16>)\ntarget_link_libraries(attention PUBLIC $<LINK_LIBRARY:WHOLE_ARCHIVE,attention_kernels>)\ntarget_link_libraries(attention PRIVATE nvidia::cutlass::cutlass)\n\nif (BUILD_TEST)\n    target_compile_options(attention PRIVATE\n        $<$<COMPILE_LANGUAGE:CUDA>:-Xptxas=-v --generate-line-info>)\n\n    add_executable(test_attention\n        test_utils.cu\n        test_attention.cu\n        reference.cu)\n    target_compile_options(test_attention PRIVATE\n        --generate-line-info -O3 -use_fast_math --expt-relaxed-constexpr)\n    target_link_libraries(test_attention PRIVATE\n        attention\n        # flash_attention\n        nvidia::cutlass::cutlass\n        models\n        unfused_attention_kernels\n        logger\n        cublas)\n\n    add_executable(test_quant test_quant.cu test_utils.cu)\n    target_compile_options(test_quant PRIVATE\n        --generate-line-info -O3 -use_fast_math --expt-relaxed-constexpr)\n    target_link_libraries(test_quant PRIVATE\n        nvidia::cutlass::cutlass\n    )\nendif ()\n"
  },
  {
    "path": "src/turbomind/kernels/attention/arch.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind::arch {\n\n// tags for dispatching & conditional codegen\n\n// Matches `__CUDA_ARCH__` values in [Begin, End), e.g. 800 for sm_80; End == -1 means no upper bound\ntemplate<int Begin, int End = -1>\nstruct Arch {\n    static constexpr bool is_compatible(int arch)\n    {\n        return Begin <= arch && (End == -1 || arch < End);\n    }\n};\n\nstruct Sm70: Arch<700, 750> {\n    static constexpr int value = 700;\n};\n\nstruct Sm75: Arch<750, 800> {\n    static constexpr int value = 750;\n};\n\nstruct Sm80: Arch<800> {\n    static constexpr int value = 800;\n};\n\n// Whether a kernel tagged with arch `karch` can run on a device of arch `darch`; `karch == 0` matches any device\ninline bool is_arch_compatible(int karch, int darch)\n{\n    switch (karch) {\n        case 0:\n            return true;\n        case 700:\n            return Sm70::is_compatible(darch);\n        case 750:\n            return Sm75::is_compatible(darch);\n        case 800:\n            return Sm80::is_compatible(darch);\n        default:\n            return false;\n    }\n}\n\n}  // namespace turbomind::arch\n"
  },
  {
    "path": "src/turbomind/kernels/attention/attention.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"attention.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/attention/registry.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\ntemplate<class T>\nvoid dispatchAttention(const AttentionParams<T>& params)\n{\n    using namespace attention;\n\n    auto&    reg = Registry::instance();\n    AttnDesc desc{};\n    desc.mode      = AttnDesc::kPrefill;\n    desc.head_dim  = params.size_per_head;\n    desc.data_type = data_type_v<T>;\n\n    auto* kernel = reg.Find(desc);\n\n    TM_CHECK(kernel) << \"No attention kernel found: \" + to_string(desc);\n\n    kernel->Launch(&params, reg.sm_count());\n}\n\ntemplate void dispatchAttention(const AttentionParams<half>& params);\n#if ENABLE_BF16\ntemplate void dispatchAttention(const AttentionParams<nv_bfloat16>& params);\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/attention.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"attention_params.h\"\n\nnamespace turbomind {\n\nconstexpr int MAX_CTA_S = 64;\n\ntemplate<typename T>\nvoid dispatchAttention(const AttentionParams<T>& params);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/attention_params.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"cutlass/fast_math.h\"\n#include <cstdint>\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/models/llama/llama_rope.h\"\n\nnamespace turbomind {\n\n// 64-bit offsets may be needed\nstruct LinearIteratorParams {\n    const void* kv_cache;\n    int         stride_h;\n    int         key_to_val;\n};\n\nstruct BlockIteratorParams {\n    char**     block_ptrs;\n    const int* cu_block_nums;\n    int        layer_id;\n    int        block_len;\n};\n\ntypedef void (*cp_post_fn)(void* context);\n\n/// TODO: Rename to attention::Param\ntemplate<typename T>\nstruct AttentionParams {\n    // token-level buffers, [B, qH + 2kvH, D] or [B, kvH, D]\n    T*      out;\n    T*      q;\n    T*      k;\n    T*      v;\n    int64_t stride;\n\n    // bias, [qH, D] or [kvH, D]\n    T* q_bias;\n    T* k_bias;\n    T* v_bias;\n\n    // sequence-level buffers\n    const int*   cu_q_len;\n    const int*   cu_k_len;\n    const bool*  finished;\n    const float* rope_theta;\n\n    const T* sinks;\n    float    scale_sinks;\n\n    LinearIteratorParams linear_iter_params;\n    BlockIteratorParams  block_iter_params;\n\n    // batch-level params\n    int token_num;\n    int batch_size;\n    int max_q_len;\n    int max_k_len;\n\n    // instance-level params\n    int   num_heads;\n    int   num_kv_heads;\n    int   size_per_head;\n    float inv_sqrt_dh;\n    int   window_size;\n    int   layer_id;  // for debugging\n\n    // rotary embedding\n    RopeKernelParam rope_param;\n\n    // log(n) attention\n    bool use_logn_attn;\n    int  max_position_embeddings;\n\n    int quant_policy;\n\n    int    max_split_k;\n    int*   split_cnt;\n    float* partial_O;\n    float* partial_ML;\n\n    // context parallel\n    int                 cp_rank{0};\n    cutlass::FastDivmod cp_size{1};\n    int                 offset_q{0};  // decode offset\n    cp_post_fn          cp_fn{nullptr};\n    void*               
cp_fn_ctx{nullptr};\n\n    int          arch;\n    cudaStream_t stream;\n\n    // debug\n    float* qk;\n    T*     pr;\n};\n\ntemplate<class CacheIterFactory, class SFINAE = void>\nstruct CreateCacheIterFactory {\n    template<class Param>\n    static CacheIterFactory apply(const Param& param)\n    {\n        using Tkv = typename CacheIterFactory::Tkv;\n        return {(const Tkv*)param.linear_iter_params.kv_cache,\n                param.cu_k_len,\n                param.linear_iter_params.stride_h,\n                param.linear_iter_params.key_to_val};\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/attention_template.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"attention_params.h\"\n#include \"attention_universal.h\"\n#include \"reduce.h\"\n#include \"utils.h\"\n\nnamespace turbomind {\n\ntemplate<class Kernel>\nvoid invokeAttention(const typename Kernel::ParamType& params, int sm_count, int max_active_ctas)\n{\n    static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage);\n\n    if constexpr (1) {\n\n        [[maybe_unused]] static const int _ = [&] {\n            // std::cout << __PRETTY_FUNCTION__ << std::endl;\n            // std::cout << \"GmemMap:\\n\";\n            // Print(typename Kernel::Impl::ThreadMapKV{});\n            // std::cout << \"\\nDynamic smem size: \" << kSmemSize << \"\\n\";\n            return 0;\n        }();\n    }\n\n    dim3 block(Kernel::kWarpCount * WARP_SIZE);\n\n    static const auto kernel_func = &attention_kernel<Kernel>;\n\n    const int max_cp_k_len    = cdiv(params.max_k_len, (int)params.cp_size);\n    const int tile_count      = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S);\n    const int max_split_count = std::min(params.max_split_k, tile_count);\n\n    typename Kernel::CtaMap cta_map{\n        params.max_q_len, params.batch_size, params.num_heads, Kernel::CTA_Q, Kernel::CTA_H, 1};\n\n    // grid shape when split cnt = 1\n    dim3 grid = cta_map.get_grid_shape();\n\n    const int grid_size = grid.x * grid.y * grid.z;\n    const int split_cnt = GetSplitCount(max_split_count, grid_size, max_active_ctas, sm_count, 8);\n\n    // printf(\"max split cnt: %d, split cnt: %d\\n\", max_split_count, split_cnt);\n\n    // adjust split cnt and update grid shape\n    cta_map.set_split_cnt(split_cnt);\n    grid = cta_map.get_grid_shape();\n\n    auto cache_iter_factory = CreateCacheIterFactory<typename Kernel::CacheIteratorFactory>::apply(params);\n\n    const int q_group_size = params.num_heads / params.num_kv_heads;\n\n    kernel_func<<<grid, block, kSmemSize, 
params.stream>>>(params,\n                                                           cache_iter_factory,\n                                                           cta_map,\n                                                           q_group_size,\n                                                           1,            // q_head_per_cta\n                                                           q_group_size  // cta_per_q_group\n    );\n\n    if (auto err = cudaGetLastError(); err != cudaSuccess) {\n        std::cout << cudaGetErrorString(err) << \"\\n\";\n        std::abort();\n    }\n\n    if (params.cp_fn) {\n        params.cp_fn(params.cp_fn_ctx);\n    }\n\n    if (split_cnt > 1 || params.cp_size > 1) {\n        attention::invokeReduceV3<Kernel::kHeadDim>(params.out + params.offset_q * params.num_heads * Kernel::kHeadDim,\n                                                    params.partial_ML,\n                                                    params.partial_O,\n                                                    split_cnt > 1 ? params.split_cnt : nullptr,\n                                                    params.max_split_k,\n                                                    split_cnt,\n                                                    params.cp_size,\n                                                    params.cp_rank,\n                                                    params.token_num,\n                                                    params.num_heads,\n                                                    params.inv_sqrt_dh,\n                                                    params.stream);\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/attention_universal.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <limits>\n#include <type_traits>\n\n#include \"quantization.h\"\n\n#include \"src/turbomind/kernels/attention/rotary_embedding.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\n#include \"attention_params.h\"\n\nnamespace turbomind {\n\nnamespace attention {\nstruct DecodingCtaMap;\n}  // namespace attention\n\ntemplate<class Arch_, class Mainloop, class CacheIteratorFactory_, class CtaMap_>\nstruct AttentionUniversal {\n\n    using T   = typename Mainloop::T;\n    using Tkv = typename Mainloop::Tkv;\n\n    using Impl = typename Mainloop::Impl;\n\n    using CacheIteratorFactory = CacheIteratorFactory_;\n    using CtaMap               = CtaMap_;\n\n    using Arch = Arch_;\n\n    static constexpr int kWarpCount = Impl::kWarpCount;\n\n    using ParamType = AttentionParams<T>;\n\n    static constexpr int kHeadDim = Impl::kHeadDim;\n\n    using FragQ = typename Impl::FragQ;\n    using FragO = typename Impl::FragO;\n    using FragM = typename Impl::FragM;\n    using FragL = typename Impl::FragL;\n\n    using GmemIterK = typename Mainloop::GmemIterK;\n    using GmemIterV = typename Mainloop::GmemIterV;\n\n    static constexpr int CTA_H = Impl::CTA_H;\n    static constexpr int CTA_Q = Impl::CTA_Q;\n    static constexpr int CTA_S = Impl::CTA_S;\n\n    using SharedStorage = typename Mainloop::SharedStorage;\n\n    // Only process KV inline during decoding (DecodingCtaMap), not during context attention\n    // (AttentionCtaMap), even when CTA_Q == 1 (e.g. 
SIMT kernels).\n    static constexpr bool kProcessKV = std::is_same_v<CtaMap, attention::DecodingCtaMap>;\n\n    const int q_group_size_;\n    const int q_head_per_cta_;\n    const int cta_per_q_group_;\n\n    // past-the-end hi of the CTA\n    int hi_end_{1};\n\n    __device__ bool check_h(int hi)\n    {\n        if constexpr (CTA_Q > 1) {\n            // bypass the check for prefill kernels since `hi == 0` constantly\n            return true;\n        }\n        else {\n            return hi < hi_end_;\n        }\n    }\n\n    template<class VecQ, class VecKV>\n    __device__ void ApplyBias(\n        VecQ& vec_Q, VecKV& vec_K, VecKV& vec_V, const ParamType& params, int head_idx, int kv_head_idx, int2 offset)\n    {\n        using Map = typename Impl::ThreadMapQ;\n\n        constexpr int kVecSize = Map::kAccessC;\n        constexpr int ITER_C   = Map::kIterC;\n        constexpr int ITER_S   = Map::kIterS;\n\n        constexpr bool HAS_V = kHeadDim != 576;\n\n        if constexpr (kProcessKV) {\n            Array<T, kVecSize> bias_K[ITER_C];\n            Array<T, kVecSize> bias_V[ITER_C];\n            PRAGMA_UNROLL\n            for (int c = 0; c < ITER_C; ++c) {\n                const int di    = offset.x + c * Map::kDeltaC;\n                const int k_idx = kv_head_idx * kHeadDim + di;\n                if (params.k_bias) {\n                    Ldg(bias_K[c], &params.k_bias[k_idx]);\n                }\n                if (params.v_bias && HAS_V) {\n                    Ldg(bias_V[c], &params.v_bias[k_idx]);\n                }\n            }\n            PRAGMA_UNROLL\n            for (int c = 0; c < ITER_C; ++c) {\n                using namespace ops;\n                if (params.k_bias) {\n                    vec_K[0][c] = vec_K[0][c] + bias_K[c];\n                }\n                if (params.v_bias && HAS_V) {\n                    vec_V[0][c] = vec_V[0][c] + bias_V[c];\n                }\n            }\n        }\n\n        if constexpr (CTA_H == 1) {\n           
 Array<T, kVecSize> bias_Q[ITER_C];\n            PRAGMA_UNROLL\n            for (int c = 0; c < ITER_C; ++c) {\n                const int di    = offset.x + c * Map::kDeltaC;\n                const int q_idx = head_idx * kHeadDim + di;\n                if (params.q_bias) {\n                    Ldg(bias_Q[c], &params.q_bias[q_idx]);\n                }\n            }\n            PRAGMA_UNROLL\n            for (int s = 0; s < ITER_S; ++s) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < ITER_C; ++c) {\n                    using namespace ops;\n                    if (params.q_bias) {\n                        vec_Q[s][c] = vec_Q[s][c] + bias_Q[c];\n                    }\n                }\n            }\n        }\n        else if constexpr (CTA_Q == 1) {\n            Array<T, kVecSize> bias_Q[ITER_S][ITER_C];\n            PRAGMA_UNROLL\n            for (int s = 0; s < ITER_S; ++s) {\n                const int hi = offset.y + s * Map::kDeltaS;\n                PRAGMA_UNROLL\n                for (int c = 0; c < ITER_C; ++c) {\n                    const int di    = offset.x + c * Map::kDeltaC;\n                    const int q_idx = (head_idx + hi) * kHeadDim + di;\n                    if (params.q_bias && check_h(hi)) {\n                        Ldg(bias_Q[s][c], &params.q_bias[q_idx]);\n                    }\n                }\n            }\n            PRAGMA_UNROLL\n            for (int s = 0; s < ITER_S; ++s) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < ITER_C; ++c) {\n                    using namespace ops;\n                    if (params.q_bias) {\n                        vec_Q[s][c] = vec_Q[s][c] + bias_Q[s][c];\n                    }\n                }\n            }\n        }\n        else {\n            static_assert(CTA_Q == 1 || CTA_H == 1);\n        }\n    }\n\n    template<class Iterator>\n    __device__ void Prologue(const ParamType& params,\n                             T*               smem_Q,\n            
                 FragQ&           frag_Q,\n                             int              qi_begin,\n                             int              qi_end,\n                             int              query_idx,\n                             int              head_idx,\n                             int              kv_head_idx,\n                             int              batch_idx,\n                             int              history_len,\n                             Iterator&        iterator,\n                             int              warp_id,\n                             int              lane_id)\n    {\n\n        using Map = typename Impl::ThreadMapQ;\n\n        constexpr int kVecSize = Map::kAccessC;\n\n        using Vec = Array<T, kVecSize>;\n\n        constexpr int ITER_C = Map::kIterC;\n        constexpr int ITER_S = Map::kIterS;\n\n        constexpr bool HAS_V = kHeadDim != 576;\n\n        Vec vec_Q[ITER_S][ITER_C]{};  // [QxH, D]\n        Vec vec_K[1][ITER_C];\n        Vec vec_V[1][ITER_C];\n\n        const int2 offset = Map::get_offset(warp_id, lane_id);\n\n        // Load Q\n        PRAGMA_UNROLL\n        for (int s = 0; s < ITER_S; ++s) {\n            const int si = offset.y + s * Map::kDeltaS;\n            const int hi = si % CTA_H + head_idx;\n            const int qi = si / CTA_H + qi_begin;\n            PRAGMA_UNROLL\n            for (int c = 0; c < ITER_C; ++c) {\n                const int     di    = offset.x + c * Map::kDeltaC;\n                const int64_t q_idx = qi * params.stride + hi * kHeadDim + di;\n                const int64_t k_idx = qi * params.stride + kv_head_idx * kHeadDim + di;\n                if (qi < qi_end) {\n                    if (check_h(si % CTA_H)) {\n                        Ldg(vec_Q[s][c], &params.q[q_idx]);\n                    }\n                    if constexpr (kProcessKV) {  // duplicate loads in s\n                        if (s == 0) {\n                            Ldg(vec_K[0][c], &params.k[k_idx]);\n   
                         if constexpr (HAS_V) {\n                                Ldg(vec_V[0][c], &params.v[k_idx]);\n                            }\n                        }\n                    }\n                }\n            }\n        }\n\n        ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset);\n\n        FastRoPE rope(params.rope_param, batch_idx, std::integral_constant<int, kVecSize>{});\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            const int di = offset.x + c * Map::kDeltaC;\n            rope.init(di);\n            PRAGMA_UNROLL\n            for (int s = 0; s < ITER_S; ++s) {\n                const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;\n                rope.apply(vec_Q[s][c], ti);\n                if constexpr (kProcessKV) {\n                    if (s == 0) {\n                        rope.apply(vec_K[0][c], ti);\n                    }\n                }\n            }\n        }\n\n        if (params.use_logn_attn) {\n            PRAGMA_UNROLL\n            for (int s = 0; s < ITER_S; ++s) {\n                const int   ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;\n                LogNScaling logn_scaling(ti, params.max_position_embeddings);\n                PRAGMA_UNROLL\n                for (int c = 0; c < ITER_C; ++c) {\n                    logn_scaling.apply(vec_Q[s][c]);\n                }\n            }\n        }\n\n        if constexpr (kProcessKV) {\n            const int qi = offset.y / CTA_H;\n            const int ti = history_len;\n\n            int local_ti, local_ti_rank;\n            local_ti = params.cp_size.divmod(local_ti_rank, ti);\n\n            Array<T, 2> param_K[1];\n            Array<T, 2> param_V[1];\n\n            if constexpr (!std::is_same_v<T, Tkv>) {\n                warp_stats<Map::kWarpThreadC>(param_K, vec_K, bitsof<Tkv>);\n                if constexpr (HAS_V) {\n                    
warp_stats<Map::kWarpThreadC>(param_V, vec_V, bitsof<Tkv>);\n                }\n            }\n\n            Array<Tkv, kVecSize> out_K[1][ITER_C];\n            Array<Tkv, kVecSize> out_V[1][ITER_C];\n\n            ConvertKvCache<T, Tkv> conv_K{param_K[0][0], param_K[0][1]};\n            ConvertKvCache<T, Tkv> conv_V{param_V[0][0], param_V[0][1]};\n            PRAGMA_UNROLL\n            for (int c = 0; c < ITER_C; ++c) {\n                out_K[0][c] = conv_K(vec_K[0][c]);\n                if constexpr (HAS_V) {\n                    out_V[0][c] = conv_V(vec_V[0][c]);\n                }\n            }\n\n            iterator.block_head_.with(\n                iterator.block_ptrs_, local_ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) {\n                    if (local_ti_rank != params.cp_rank) {\n                        return;\n                    }\n                    PRAGMA_UNROLL\n                    for (int c = 0; c < ITER_C; ++c) {\n                        const int di = offset.x + c * Map::kDeltaC;\n                        if (qi < CTA_Q) {\n                            Store(&k_cache[di], out_K[0][c]);\n                            if constexpr (HAS_V) {\n                                Store(&v_cache[di], out_V[0][c]);\n                            }\n                        }\n                    }\n                    if constexpr (!std::is_same_v<T, Tkv>) {\n                        if (qi < CTA_Q && offset.x == 0) {\n                            StoreQuantParam<Tkv>(k_param, param_K[0]);\n                            if constexpr (HAS_V) {\n                                StoreQuantParam<Tkv>(v_param, param_V[0]);\n                            }\n                        }\n                    }\n                });\n\n            __syncthreads();\n        }\n\n        using SmemLayoutQ = typename Impl::SmemLayoutQ;\n\n        SmemAccessor<T, SmemLayoutQ> sQ{smem_Q};\n\n        // Store to shared memory\n        PRAGMA_UNROLL\n        for (int s = 
0; s < ITER_S; ++s) {\n            const int si = offset.y + s * Map::kDeltaS;\n            const int hi = si % CTA_H;\n            const int qi = si / CTA_H;\n            PRAGMA_UNROLL\n            for (int c = 0; c < ITER_C; ++c) {\n                const int di = offset.x + c * Map::kDeltaC;\n                if (qi < CTA_Q && hi < CTA_H) {\n                    Store(&sQ(si, di), vec_Q[s][c]);\n                }\n            }\n        }\n\n        __syncthreads();\n\n        Impl::TransformQ(smem_Q, frag_Q);\n    }\n\n    __device__ AttentionUniversal(int q_group_size, int q_head_per_cta, int cta_per_q_group):\n        q_group_size_{q_group_size}, q_head_per_cta_{q_head_per_cta}, cta_per_q_group_{cta_per_q_group}\n    {\n    }\n\n    __device__ void\n    operator()(const ParamType& params, CacheIteratorFactory& cache_iter_factory, const CtaMap& cta_map, char* smem_buf)\n    {\n        // [q, h, b]\n        const int query_idx = cta_map.query_idx() * CTA_Q;  // Q offset of this sequence\n        const int batch_idx = cta_map.batch_idx();\n        const int split_idx = cta_map.split_idx();\n        const int split_cnt = cta_map.split_count();\n\n        int head_idx;\n        int kv_head_idx;\n\n        if constexpr (CTA_H == 1) {\n            head_idx    = cta_map.head_idx();\n            kv_head_idx = head_idx / q_group_size_;\n        }\n        else {\n            int cta_h_idx = cta_map.head_idx();\n            int local_idx = cta_h_idx % cta_per_q_group_ * q_head_per_cta_;\n            kv_head_idx   = cta_h_idx / cta_per_q_group_;\n            head_idx      = kv_head_idx * q_group_size_ + local_idx;\n            hi_end_       = q_group_size_ - local_idx;\n        }\n\n        // early exit if finished flag is set\n        if (params.finished[batch_idx]) {\n            return;\n        }\n\n        const int qi_begin = params.cu_q_len[batch_idx] + query_idx;  // global offset into `cu_seqlens`\n        const int qi_end   = params.cu_q_len[batch_idx + 1];\n\n   
     if (qi_begin >= qi_end) {\n            return;\n        }\n\n        const int input_len = qi_end - (qi_begin - query_idx);\n\n        SharedStorage& storage = *(SharedStorage*)smem_buf;\n\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        const int context_len = params.cu_k_len[batch_idx + 1] - params.cu_k_len[batch_idx];\n        const int history_len = context_len - input_len;\n\n        auto get_cp_len = [&](int length, int rank) -> int {\n            int local_ti, local_ti_rank;\n            local_ti = params.cp_size.divmod(local_ti_rank, length);\n            return (local_ti + (local_ti_rank > rank ? 1 : 0));\n        };\n\n        const int last_K = history_len + min(query_idx + CTA_Q, input_len);\n        const int last_K_tile =\n            (get_cp_len(last_K, 0) - 1) / CTA_S + 1;  // past-the-end index to past-the-end tile index conversion\n\n        const int first_K      = max(history_len + query_idx - (params.window_size - 1), 0);\n        const int first_K_tile = get_cp_len(first_K, 0) / CTA_S;\n\n        const int tile_count = last_K_tile - first_K_tile;\n\n        /// FIXME: This scheme produce splits less than expected\n        const int tile_per_split = cdiv(tile_count, split_cnt);\n        const int iter_begin     = tile_per_split * split_idx;\n        const int iter_end       = min(iter_begin + tile_per_split, tile_count);\n\n        if (iter_begin >= iter_end) {\n            return;\n        }\n\n        auto cache_iter = cache_iter_factory.Create(batch_idx, kv_head_idx);\n\n        FragQ frag_Q;\n        Prologue(params,\n                 storage.Q,\n                 frag_Q,\n                 qi_begin,\n                 qi_end,\n                 query_idx,\n                 head_idx,\n                 kv_head_idx,\n                 batch_idx,\n                 history_len,\n                 cache_iter,\n                 warp_id,\n                 lane_id);\n\n        
__align__(16) FragO frag_O{};\n\n        FragL frag_L{};\n        FragM frag_M;\n        fill(frag_M, -std::numeric_limits<float>::infinity());\n\n        __syncthreads();\n\n        const int offset_Q = history_len + query_idx;\n        const int offset_K = (first_K_tile + iter_end - 1) * CTA_S;\n\n        // This is for avoiding OOB access only\n        const int max_K = min(get_cp_len(context_len, params.cp_rank), (first_K_tile + iter_end) * CTA_S);\n\n        int tile_iter = iter_end - iter_begin;\n\n        //    min(Q) >= max(K)\n        // -> offset_Q >= offset_K + CTA_S - x * CTA_S\n        // -> x * CTA_S >= offset_K - offset_Q + CTA_S\n        int mask_iter_back = cdiv(max(0, offset_K - offset_Q + CTA_S), CTA_S);\n        //    max(Q) < min(K) + w\n        // -> offset_Q + CTA_Q - 1 < offset_K - tile_iter * CTA_S + x * CTA_S + w\n        // -> x * CTA_S >= offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - w\n        int mask_iter_front = cdiv(max(0, offset_Q + CTA_Q - offset_K + tile_iter * CTA_S - params.window_size), CTA_S);\n\n        if (params.cp_size > 1) {\n            mask_iter_back =\n                cdiv(max(0, params.cp_size * (offset_K + CTA_S) - offset_Q + params.cp_rank), params.cp_size * CTA_S);\n            mask_iter_front = cdiv(max(0,\n                                       offset_Q + CTA_Q - params.window_size - params.cp_rank\n                                           - params.cp_size * (offset_K - tile_iter * CTA_S)),\n                                   params.cp_size * CTA_S);\n        }\n\n#if 0\n        if (threadIdx.x == 0) {\n            printf(\n                \"tile count: %d, tile per iter: %d, range_Q: [%d, %d), offset_K: %d, max_K: %d, tile_iter: %d, range_K: [%d, %d), range_K_tiles: [%d, %d), mask_iter: %d, mask_iter_front: %d\\n\",\n                tile_count,\n                tile_per_split,\n                offset_Q,\n                offset_Q + min(query_idx + CTA_Q, input_len),\n                offset_K,\n             
   max_K,\n                tile_iter,\n                first_K,\n                last_K,\n                first_K_tile * CTA_S,\n                last_K_tile * CTA_S,\n                mask_iter_back,\n                mask_iter_front);\n        }\n#endif\n\n        cache_iter.SetTile(first_K_tile + iter_end - 1);\n\n        Mainloop mainloop;\n        mainloop.SetCpInfo(params.cp_size, params.cp_rank);\n        mainloop(frag_Q,\n                 cache_iter,\n                 frag_O,\n                 frag_M,\n                 frag_L,\n                 offset_Q,\n                 offset_K,\n                 max_K,\n                 tile_iter,\n                 mask_iter_back,\n                 mask_iter_front,\n                 params.window_size,\n                 params.inv_sqrt_dh,\n                 storage,\n                 StoreS(params, query_idx, head_idx, batch_idx, context_len));\n\n        Impl::Merge(frag_O, frag_M, frag_L, params.inv_sqrt_dh, storage);\n\n        if (params.sinks && iter_end == tile_count && params.cp_rank == 0) {\n            Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float& M, float& L) {\n                if (check_h(hi) && M != -std::numeric_limits<float>::infinity()) {\n                    auto sink = (float)params.sinks[head_idx + hi];\n                    L += expf(sink - M * params.scale_sinks);\n                }\n            });\n        }\n\n        if (split_cnt > 1 && iter_end == tile_count && head_idx == 0) {\n            // Store actual split count, only used by separate reduction kernel\n            for (int ti = threadIdx.x; ti < CTA_Q; ti += kWarpCount * WARP_SIZE) {\n                if (qi_begin + ti < qi_end) {\n                    params.split_cnt[qi_begin + ti] = split_idx ? split_idx + 1 : (params.cp_size > 1 ? 
1 : 0);\n                }\n            }\n        }\n\n        if (iter_begin == 0 && iter_end == tile_count && params.cp_size == 1) {\n            StoreO(frag_O, frag_L, qi_begin, qi_end, head_idx, params, storage);\n        }\n        else {\n            StorePartial(frag_O, frag_M, frag_L, split_cnt, qi_begin, qi_end, head_idx, split_idx, params, storage);\n        }\n    }\n\n    __device__ void StoreO(FragO&           frag_O,\n                           FragL&           frag_L,\n                           int              qi_begin,\n                           int              qi_end,\n                           int              head_idx,\n                           const ParamType& params,\n                           SharedStorage&   storage)\n    {\n        Impl::StoreO<true>(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {\n            if (qi_begin + qi < qi_end && check_h(hi)) {\n                const int offset = (qi_begin + qi) * params.num_heads * kHeadDim + (head_idx + hi) * kHeadDim + di;\n                Store(&params.out[offset], cast<T>(vec));\n            }\n        });\n    }\n\n    __device__ auto StoreS(const ParamType& params,\n                           const int&       query_idx,\n                           const int&       head_idx,\n                           const int&       batch_idx,\n                           const int&       max_context_len)\n    {\n        return [&](auto& frag_S, int offset_K) {\n            Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float score) {\n                qi += query_idx;\n                si += offset_K;\n                if (qi < params.max_q_len && si < max_context_len && check_h(hi)) {\n                    params.qk[batch_idx * params.num_heads * params.max_q_len * max_context_len\n                              + (head_idx + hi) * params.max_q_len * max_context_len + qi * max_context_len + si] =\n                        score;\n                }\n            });\n  
      };\n    }\n\n    __device__ void StorePartial(FragO&           frag_O,\n                                 FragM&           frag_M,\n                                 FragL&           frag_L,\n                                 int              split_cnt,\n                                 int              qi_begin,\n                                 int              qi_end,\n                                 int              head_idx,\n                                 int              split_idx,\n                                 const ParamType& params,\n                                 SharedStorage&   storage)\n    {\n        auto get_index = [&](int hi, int qi) {\n            // [B, H, k, D]\n            return (qi_begin + qi - params.offset_q) * params.num_heads * params.max_split_k\n                   + (head_idx + hi) * params.max_split_k + split_idx;\n        };\n\n        Impl::StoreO<false>(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {\n            if (qi_begin + qi < qi_end && check_h(hi)) {\n                Store(&params.partial_O[get_index(hi, qi) * kHeadDim + di], vec);\n            }\n        });\n\n        Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) {\n            const int index = get_index(hi, qi);\n            if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) {\n                Store(&params.partial_ML[index * 2], Array<float, 2>{M, L});\n            }\n        });\n    }\n};\n\n/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////\n\nextern __shared__ char smem_buf[];\n\ntemplate<class Kernel>\n__global__ void attention_kernel(typename Kernel::ParamType            params,\n                                 typename Kernel::CacheIteratorFactory cache_iter_factory,\n                                 typename Kernel::CtaMap               cta_map,\n                                 int                                   
q_group_size,\n                                 int                                   q_head_per_cta,\n                                 int                                   cta_per_q_group)\n{\n#if __CUDA_ARCH__\n    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {\n        Kernel{q_group_size, q_head_per_cta, cta_per_q_group}(params, cache_iter_factory, cta_map, smem_buf);\n    }\n#endif\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/block.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/sub_byte_ptr.h\"\n#include <iostream>\n#include <type_traits>\n\nnamespace turbomind {\n\nnamespace block {\n\ntemplate<class T, class Tkv, int HeadDim, bool ShareKV = false>\nstruct Config {\n    int head_num_;\n    int block_len_;\n\n    TM_HOST_DEVICE constexpr int t_bits() const\n    {\n        if constexpr (std::is_same_v<T, Tkv>) {\n            return 0;\n        }\n        else {\n            return bitsof<T>;\n        }\n    }\n\n    TM_HOST_DEVICE constexpr int q_bits() const\n    {\n        return bitsof<Tkv>;\n    }\n\n    TM_HOST_DEVICE constexpr int head_dim() const\n    {\n        return HeadDim;\n    }\n\n    TM_HOST_DEVICE int head_num() const\n    {\n        return head_num_;\n    }\n\n    TM_HOST_DEVICE constexpr int block_len() const\n    {\n        return block_len_;\n    }\n\n    TM_HOST_DEVICE constexpr bool is_share_kv() const\n    {\n        return ShareKV;\n    }\n};\n\n// Layout -> LayerId -> HeadId -> Timestep -> [Block] -> (k_data, v_data, k_param, v_param)\n\ntemplate<class T, class Tkv, class Layout>\nclass Head {\npublic:\n    TM_HOST_DEVICE Head(Layout layout, int layer_id, int head_id):\n        layout_{layout}, layer_id_{layer_id}, head_id_{head_id}\n    {\n    }\n\n    TM_HOST_DEVICE auto k_data(char* block, int ti) const\n    {\n        if constexpr (std::is_same_v<Tkv, uint4_t>) {\n            return SubBytePtr<Tkv>{block + layout_.k_data(layer_id_, head_id_, ti)};\n        }\n        else {\n            return reinterpret_cast<Tkv*>(block + layout_.k_data(layer_id_, head_id_, ti));\n        }\n    }\n\n    TM_HOST_DEVICE auto v_data(char* block, int ti) const\n    {\n        if constexpr (std::is_same_v<Tkv, uint4_t>) {\n            return SubBytePtr<Tkv>{block + layout_.v_data(layer_id_, head_id_, ti)};\n  
      }\n        else {\n            return reinterpret_cast<Tkv*>(block + layout_.v_data(layer_id_, head_id_, ti));\n        }\n    }\n\n    TM_HOST_DEVICE T* k_param(char* block, int ti) const\n    {\n        return reinterpret_cast<T*>(block + layout_.k_param(layer_id_, head_id_, ti));\n    }\n\n    TM_HOST_DEVICE T* v_param(char* block, int ti) const\n    {\n        return reinterpret_cast<T*>(block + layout_.v_param(layer_id_, head_id_, ti));\n    }\n\n    TM_HOST_DEVICE void get_block_coord(int seq_ti, int& block_idx, int& block_ti) const\n    {\n        block_idx = seq_ti / block_len();\n        block_ti  = seq_ti % block_len();\n    }\n\n    TM_HOST_DEVICE auto block_len() const\n    {\n        return layout_.config().block_len();\n    }\n\n    template<class Func>\n    TM_HOST_DEVICE auto with(char** block_ptrs, int ti, Func&& func) const\n    {\n        int block_id;\n        int block_ti;\n        get_block_coord(ti, block_id, block_ti);\n\n        char* block = block_ptrs[block_id];\n\n        return ((Func &&) func)(\n            k_data(block, block_ti), v_data(block, block_ti), k_param(block, block_ti), v_param(block, block_ti));\n    }\n\nprivate:\n    Layout layout_;\n\n    int layer_id_;\n    int head_id_;\n};\n\n// L(H2SDQ+H2S2T)\ntemplate<class Config_>\nstruct Layout {\n\n    using Config = Config_;\n\n    Config config_;\n\n    // This trivial ctor is defined for CTAD\n    TM_HOST_DEVICE Layout(Config config): config_{config} {}\n\n    TM_HOST_DEVICE const Config& config() const\n    {\n        return config_;\n    }\n\n    TM_HOST_DEVICE constexpr bool is_share_kv() const\n    {\n        // return 0;\n        return config().is_share_kv();\n    }\n\n    TM_HOST_DEVICE constexpr int kv_num() const\n    {\n        // return 2;\n        return is_share_kv() ? 
1 : 2;\n    }\n\n    TM_HOST_DEVICE int token_data_size() const\n    {\n        return config().q_bits() * config().head_dim() / 8;\n    }\n\n    TM_HOST_DEVICE int token_param_size() const\n    {\n        return config().t_bits() * 2 / 8;  // 2 for scales/zeros\n    }\n\n    TM_HOST_DEVICE int head_data_size() const\n    {\n        return config().block_len() * token_data_size();\n    }\n\n    TM_HOST_DEVICE int head_param_size() const\n    {\n        return config().block_len() * token_param_size();\n    }\n\n    TM_HOST_DEVICE int layer_size() const\n    {\n        // TODO: enforce alignment\n        return config().head_num() * kv_num() * head_data_size() + config().head_num() * kv_num() * head_param_size();\n    }\n\n    TM_HOST_DEVICE int block_size(int layer_num) const\n    {\n        return layer_size() * layer_num;\n    }\n\n    TM_HOST_DEVICE int k_data(int layer, int head, int token) const\n    {\n        return layer_data(layer) + head_data(head) + token_data(token);\n    }\n\n    TM_HOST_DEVICE int v_data(int layer, int head, int token) const\n    {\n        return k_data(layer, head, token) + (is_share_kv() ? 0 : head_data_size());\n    }\n\n    TM_HOST_DEVICE int k_param(int layer, int head, int token) const\n    {\n        return layer_param(layer) + head_param(head) + token_param(token);\n    }\n\n    TM_HOST_DEVICE int v_param(int layer, int head, int token) const\n    {\n        return k_param(layer, head, token) + (is_share_kv() ? 
0 : head_param_size());\n    }\n\n    TM_HOST_DEVICE int layer_data(int layer) const\n    {\n        return layer * layer_size();\n    }\n\n    TM_HOST_DEVICE int layer_param(int layer) const\n    {\n        return layer_data(layer) + head_data(config_.head_num());\n    }\n\n    TM_HOST_DEVICE int head_data(int head) const\n    {\n        return head * kv_num() * head_data_size();\n    }\n\n    TM_HOST_DEVICE int head_param(int head) const\n    {\n        return head * kv_num() * head_param_size();\n    }\n\n    TM_HOST_DEVICE int token_data(int ti) const\n    {\n        return ti * token_data_size();\n    }\n\n    TM_HOST_DEVICE int token_param(int ti) const\n    {\n        return ti * token_param_size();\n    }\n};\n\ntemplate<class Config>\nvoid dump(const Layout<Config>& layout)\n{\n    std::cout << \"head_dim: \" << layout.config().head_dim() << \"\\n\";\n    std::cout << \"head_num: \" << layout.config().head_num() << \"\\n\";\n    std::cout << \"block_len: \" << layout.config().block_len() << \"\\n\";\n    std::cout << \"q_bits: \" << layout.config().q_bits() << \"\\n\";\n    std::cout << \"t_bits: \" << layout.config().t_bits() << \"\\n\";\n    std::cout << \"token_data_size: \" << layout.token_data_size() << \"\\n\";\n    std::cout << \"token_param_size: \" << layout.token_param_size() << \"\\n\";\n    std::cout << \"head_data_size: \" << layout.head_data_size() << \"\\n\";\n    std::cout << \"head_param_size: \" << layout.head_param_size() << \"\\n\";\n    std::cout << \"layer_size: \" << layout.layer_size() << \"\\n\";\n}\n\n}  // namespace block\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/block_iterator.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"attention_params.h\"\n#include \"block.h\"\n\nnamespace turbomind {\n\ntemplate<class BlockHead, int CTA_S>\nstruct BlockIterator {\n\n    BlockHead block_head_;\n    char**    block_ptrs_;\n\n    char* block_{};\n    int   block_id_{};\n    int   block_ti_{};\n\n    __device__ BlockIterator(BlockHead block_head, char** block_ptrs): block_head_{block_head}, block_ptrs_{block_ptrs}\n    {\n    }\n\n    __device__ void SetTile(int iter)\n    {\n        block_head_.get_block_coord(iter * CTA_S, block_id_, block_ti_);\n        block_ = block_ptrs_[block_id_];\n    }\n\n    __device__ void Advance()\n    {\n        block_ti_ -= CTA_S;\n        if (block_ti_ < 0) {\n            block_ti_ += block_head_.block_len();\n            block_id_ -= 1;\n        }\n        if (block_id_ >= 0) {\n            block_ = block_ptrs_[block_id_];\n        }\n    }\n\n    template<int Index>\n    __device__ auto OffsetPtr(int offset) const\n    {\n        if constexpr (Index == 0) {\n            return block_head_.k_data(block_, block_ti_) + offset;\n        }\n        else if constexpr (Index == 1) {\n            return block_head_.v_data(block_, block_ti_) + offset;\n        }\n        else if constexpr (Index == 2) {\n            return block_head_.k_param(block_, block_ti_) + offset;\n        }\n        else if constexpr (Index == 3) {\n            return block_head_.v_param(block_, block_ti_) + offset;\n        }\n        else {\n            static_assert(Index != Index, \"invalid index\");\n        }\n    }\n};\n\ntemplate<class T, class Tkv, class BlockLayout_, int CTA_S>\nstruct BlockIteratorFactory {\n    using BlockLayout = BlockLayout_;\n\n    BlockLayout_ block_layout_;\n    char**       block_ptrs_;\n    const int*   cu_block_nums_;\n    int          layer_idx_;\n\n    __device__ auto Create(int batch_idx, int head_idx)\n    {\n        block::Head<T, Tkv, BlockLayout> 
head{block_layout_, layer_idx_, head_idx};\n\n        char** block_ptrs = block_ptrs_ + cu_block_nums_[batch_idx];\n\n        return BlockIterator<block::Head<T, Tkv, BlockLayout>, CTA_S>{head, block_ptrs};\n    }\n};\n\ntemplate<class CacheIterFactory>\nstruct CreateCacheIterFactory<CacheIterFactory, std::void_t<typename CacheIterFactory::BlockLayout>> {\n    template<class Param>\n    static CacheIterFactory apply(const Param& param)\n    {\n        using BlockLayout = typename CacheIterFactory::BlockLayout;\n        using BlockConfig = typename BlockLayout::Config;\n\n        return {\n            BlockLayout{BlockConfig{param.num_kv_heads, param.block_iter_params.block_len}},\n            param.block_iter_params.block_ptrs,\n            param.block_iter_params.cu_block_nums,\n            param.block_iter_params.layer_id,\n        };\n    }\n};\n\ntemplate<class T, class Tkv, int CTA_S, int HeadDim>\nusing GetBlockIterFactory =\n    BlockIteratorFactory<T, Tkv, block::Layout<block::Config<T, Tkv, HeadDim, HeadDim == 576>>, CTA_S>;\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/cp_utils.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/cp_utils.h\"\n\nnamespace turbomind {\n\nvoid CpPost(void* context)\n{\n    auto ctx = reinterpret_cast<CpPostContext*>(context);\n\n    ctx->d_comm->AllGather(ctx->partial_ML + ctx->cp_rank * ctx->count,  //\n                           ctx->partial_ML,\n                           ctx->count,\n                           DataType::kFloat,\n                           ctx->attn_cp_group,\n                           ctx->stream);\n    sync_check_cuda_error();\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/cp_utils.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\nstruct CpPostContext {\n\n    CpPostContext(comm::DeviceCommImpl* d_comm, int attn_cp_group): d_comm(d_comm), attn_cp_group(attn_cp_group) {}\n\n    comm::DeviceCommImpl* d_comm;\n    int                   attn_cp_group;\n\n    int          cp_rank;\n    int          count;\n    float*       partial_ML;\n    cudaStream_t stream;\n};\n\nvoid CpPost(void* context);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/cta_map.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind::attention {\n\n#if 1\nstruct AttentionCtaMap {\n\n    int q_cta_cnt_;\n    int h_cta_cnt_;\n    int batch_size_;\n    int split_cnt_;\n\n    __host__ __device__\n    AttentionCtaMap(int max_q_len, int batch_size, int head_num, int cta_q, int cta_h, int split_cnt):\n        q_cta_cnt_((max_q_len + cta_q - 1) / cta_q),\n        h_cta_cnt_(head_num / cta_h),\n        batch_size_(batch_size),\n        split_cnt_(split_cnt)\n    {\n    }\n\n    __host__ __device__ void set_split_cnt(int value)\n    {\n        split_cnt_ = value;\n    }\n\n    __host__ dim3 get_grid_shape() const\n    {\n        return dim3(q_cta_cnt_, batch_size_, split_cnt_ * h_cta_cnt_);\n    }\n    __device__ int query_idx() const\n    {\n        return blockIdx.x;\n    }\n    __device__ int head_idx() const\n    {\n        return blockIdx.z % h_cta_cnt_;\n    }\n    __device__ int batch_idx() const\n    {\n        return blockIdx.y;\n    }\n    __device__ int split_idx() const\n    {\n        return blockIdx.z / h_cta_cnt_;\n    }\n    __device__ int split_count() const\n    {\n        return split_cnt_;\n    }\n};\n#else\nstruct AttentionCtaMap {\n\n    int q_cta_cnt_;\n    int h_cta_cnt_;\n    int batch_size_;\n    int split_cnt_;\n\n    __host__ __device__\n    AttentionCtaMap(int max_q_len, int batch_size, int head_num, int cta_q, int cta_h, int split_cnt):\n        q_cta_cnt_((max_q_len + cta_q - 1) / cta_q),\n        h_cta_cnt_(head_num / cta_h),\n        batch_size_(batch_size),\n        split_cnt_(split_cnt)\n    {\n    }\n\n    __host__ __device__ void set_split_cnt(int value)\n    {\n        split_cnt_ = value;\n    }\n\n    __host__ dim3 get_grid_shape() const\n    {\n        return dim3(q_cta_cnt_, h_cta_cnt_, split_cnt_ * batch_size_);\n    }\n    __device__ int query_idx() const\n    {\n        return blockIdx.x;\n    }\n    __device__ int head_idx() const\n    {\n        return 
blockIdx.y;\n    }\n    __device__ int batch_idx() const\n    {\n        return blockIdx.z % batch_size_;\n    }\n    __device__ int split_idx() const\n    {\n        return blockIdx.z / batch_size_;\n    }\n    __device__ int split_count() const\n    {\n        return split_cnt_;\n    }\n};\n#endif\n\nstruct DecodingCtaMap {\n    static __host__ dim3 get_grid_shape(int kv_head_num, int batch_size, int split_count, int cta_per_q_group)\n    {\n        return dim3(cta_per_q_group * kv_head_num, batch_size, split_count);\n    }\n    __device__ int query_idx() const\n    {\n        return 0;\n    }\n    __device__ int head_idx() const\n    {\n        return blockIdx.x;\n    }\n    __device__ int batch_idx() const\n    {\n        return blockIdx.y;\n    }\n    __device__ int split_idx() const\n    {\n        return blockIdx.z;\n    }\n    __device__ int split_count() const\n    {\n        return gridDim.z;\n    }\n};\n\nstruct ReduceCtaMap {\n    static __host__ dim3 get_grid_shape(int query_num, int head_num, int max_split_cnt, int cta_k)\n    {\n        return dim3(head_num, query_num, (max_split_cnt + cta_k - 1) / cta_k);\n    }\n    static __device__ int query_idx()\n    {\n        return blockIdx.y;\n    }\n    static __device__ int head_idx()\n    {\n        return blockIdx.x;\n    }\n    static __device__ int split_idx()\n    {\n        return blockIdx.z;\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/decoding.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"decoding.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/attention/registry.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\ntemplate<class T>\nvoid dispatchDecoding(const AttentionParams<T>& params)\n{\n    using namespace attention;\n\n    const bool is_kv_int8     = params.quant_policy & QuantPolicy::kCacheKVInt8;\n    const bool is_kv_int4     = params.quant_policy & QuantPolicy::kCacheKVInt4;\n    const int  query_group_sz = params.num_heads / params.num_kv_heads;\n\n    FT_CHECK(!(is_kv_int4 && is_kv_int8));\n\n    int kv_quant = is_kv_int4 ? 4 : (is_kv_int8 ? 8 : 0);\n\n    AttnDesc desc{};\n    desc.mode           = AttnDesc::kDecoding;\n    desc.head_dim       = params.size_per_head;\n    desc.data_type      = data_type_v<T>;\n    desc.kv_quant       = kv_quant;\n    desc.query_group_sz = query_group_sz;\n\n    auto& reg    = Registry::instance();\n    auto* kernel = reg.Find(desc);\n\n    TM_CHECK(kernel) << \"No decoding kernel found: \" + to_string(desc);\n\n    kernel->Launch(&params, reg.sm_count());\n}\n\ntemplate void dispatchDecoding(const AttentionParams<half>& params);\n#if ENABLE_BF16\ntemplate void dispatchDecoding(const AttentionParams<nv_bfloat16>& params);\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/decoding.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"attention_params.h\"\n\nnamespace turbomind {\n\ntemplate<class T>\nvoid dispatchDecoding(const AttentionParams<T>& params);\n\n}\n"
  },
  {
    "path": "src/turbomind/kernels/attention/decoding_template.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"attention_params.h\"\n#include \"attention_universal.h\"\n#include \"reduce.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n#include \"utils.h\"\nnamespace turbomind {\n\ntemplate<class Kernel>\nbool invokeDecoding(const typename Kernel::ParamType& params, int sm_count, int max_active_ctas)\n{\n    static const size_t kSmemSize = sizeof(typename Kernel::SharedStorage);\n\n    if constexpr (1) {\n        [[maybe_unused]] static const int _ = [&] {\n            // std::cout << __PRETTY_FUNCTION__ << std::endl;\n            // std::cout << \"GmemMap:\\n\";\n            // Print(typename Kernel::Impl::ThreadMapKV{});\n            // std::cout << \"\\nDynamic smem size: \" << kSmemSize << \"\\n\";\n            return 0;\n        }();\n    }\n\n    const int max_cp_k_len    = cdiv(params.max_k_len, (int)params.cp_size);\n    const int tile_count      = cdiv(std::min(max_cp_k_len, params.window_size), Kernel::CTA_S);\n    const int max_split_count = std::min(params.max_split_k, tile_count);\n\n    using CtaMap = typename Kernel::CtaMap;\n\n    dim3 block(Kernel::kWarpCount * WARP_SIZE);\n\n    auto kernel_func = &attention_kernel<Kernel>;\n\n    const int q_group_size   = params.num_heads / params.num_kv_heads;\n    const int q_head_per_cta = std::min(q_group_size, Kernel::CTA_H);\n\n    // cta needed to process one query group\n    const int cta_per_q_group = (q_group_size + q_head_per_cta - 1) / q_head_per_cta;\n\n    // std::cout << \"CTA_H: \" << Kernel::CTA_H << \", head_per_cta: \" << q_head_per_cta\n    //           << \", cta_per_q_group: \" << cta_per_q_group << \"\\n\";\n\n    dim3 grid = CtaMap::get_grid_shape(params.num_kv_heads, params.batch_size, 1, cta_per_q_group);\n\n    const int grid_size = grid.x * grid.y * grid.z;\n    const int split_cnt = GetSplitCount(max_split_count, grid_size, max_active_ctas, sm_count, 4);\n\n    grid = 
CtaMap::get_grid_shape(params.num_kv_heads, params.batch_size, split_cnt, cta_per_q_group);\n\n    // Print(typename Kernel::Impl::ThreadMapKVp{});\n\n    // std::cout << \"split count: \" << split_cnt << \"\\n\";\n\n    auto cache_iter_factory = CreateCacheIterFactory<typename Kernel::CacheIteratorFactory>::apply(params);\n\n    kernel_func<<<grid, block, kSmemSize, params.stream>>>(\n        params, cache_iter_factory, CtaMap{}, q_group_size, q_head_per_cta, cta_per_q_group);\n\n    if (auto err = cudaGetLastError(); err != cudaSuccess) {\n        std::cout << cudaGetErrorString(err) << \"\\n\";\n        std::abort();\n    }\n\n    if (params.cp_fn) {\n        params.cp_fn(params.cp_fn_ctx);\n    }\n\n    if (split_cnt > 1 || params.cp_size > 1) {\n        attention::invokeReduceV3<Kernel::kHeadDim>(params.out,\n                                                    params.partial_ML,\n                                                    params.partial_O,\n                                                    split_cnt > 1 ? params.split_cnt : nullptr,\n                                                    params.max_split_k,\n                                                    split_cnt,\n                                                    params.cp_size,\n                                                    params.cp_rank,\n                                                    params.token_num,\n                                                    params.num_heads,\n                                                    params.inv_sqrt_dh,\n                                                    params.stream);\n    }\n\n    return true;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/desc.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/data_type.h\"\n#include <cuda_runtime.h>\n#include <sstream>\n#include <string>\n\nnamespace turbomind::attention {\n\nstruct AttnDesc {\n    enum Mode\n    {\n        kPrefill,\n        kDecoding\n    };\n    Mode     mode;\n    int      head_dim;\n    DataType data_type;\n    int      kv_quant;        // 0=none, 8=int8, 4=int4\n    int      query_group_sz;  // num_heads/num_kv_heads for decoding; 0 for prefill\n};\n\ninline std::string to_string(const AttnDesc& d)\n{\n    std::ostringstream ss;\n    ss << (d.mode == AttnDesc::kPrefill ? \"prefill\" : \"decode\");\n    ss << \"_d\" << d.head_dim;\n    ss << \"_\" << to_string(d.data_type);\n    if (d.mode == AttnDesc::kDecoding) {\n        if (d.kv_quant == 8)\n            ss << \"_kvint8\";\n        else if (d.kv_quant == 4)\n            ss << \"_kvint4\";\n        ss << \"_gs\" << d.query_group_sz;\n    }\n    return ss.str();\n}\n\nstruct KernelDesc {\n    AttnDesc::Mode mode;\n    int            arch;  // 700, 750, 800\n    int            head_dim;\n    DataType       data_type;\n    int            kv_quant;  // 0=none, 8=int8, 4=int4\n    int            qh;        // query heads per CTA (1 for prefill)\n};\n\nstruct KernelInfo {\n    int                dynamic_smem_size;\n    int                max_active_ctas;\n    int                num_warps;\n    std::string        name;\n    cudaFuncAttributes attr;\n};\n\ninline std::string to_string(const KernelDesc& d)\n{\n    std::ostringstream ss;\n    ss << (d.mode == AttnDesc::kPrefill ? 
\"prefill\" : \"decode\");\n    ss << \"_sm\" << d.arch / 10;\n    ss << \"_d\" << d.head_dim;\n    ss << \"_\" << to_string(d.data_type);\n    if (d.mode == AttnDesc::kDecoding) {\n        if (d.kv_quant == 8)\n            ss << \"_kvint8\";\n        else if (d.kv_quant == 4)\n            ss << \"_kvint4\";\n        ss << \"_qh\" << d.qh;\n    }\n    return ss.str();\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/impl.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind {\n\nnamespace attention {\n\nstruct MMA_16816 {\n};\n\nstruct MMA_81616 {\n};  // MMA_16816 transposed\n\nstruct MMA_1688 {\n};\n\nstruct MMA_884 {\n};\n\nstruct MMA_SIMT {\n};\n\ntemplate<class Tag,\n         class T,\n         class Tkv,\n         int CTA_H,\n         int CTA_Q,\n         int CTA_S,\n         int WARP_H,\n         int WARP_Q,\n         int WARP_S,\n         int HeadDim,\n         int Stages = 2>\nstruct Impl {\n};\n\n}  // namespace attention\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/impl_16816.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_m16n8.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/mma.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n\nnamespace turbomind::attention {\n\ntemplate<class T_, int CTA_H_, int CTA_Q_, int CTA_S_, int WARP_H, int WARP_Q, int WARP_S, int HeadDim, int Stages>\nstruct Impl<MMA_16816, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H, WARP_Q, WARP_S, HeadDim, Stages>:\n    Impl_m16k8<T_, WARP_H, WARP_Q, WARP_S, HeadDim> {\n\n    using Base = Impl_m16k8<T_, WARP_H, WARP_Q, WARP_S, HeadDim>;\n\n    static constexpr bool MLA = HeadDim == 576;\n\n    using Base::OP_M;\n    using Base::OP_N;\n    using Base::K_M;\n    using Base::K_N;\n    using Base::V_M;\n    using Base::V_N;\n\n    using typename Base::FragS;\n    using typename Base::FragO;\n    using typename Base::FragM;\n    using typename Base::FragL;\n\n    using Base::ForeachS;\n    using Base::Softmax;\n    using Base::ConvertStoP;\n    using Base::StoreO;\n\n    using T   = T_;\n    using Tkv = T_;\n\n    static constexpr int kHeadDim = HeadDim;\n\n    static constexpr int CTA_H = CTA_H_;\n    static constexpr int CTA_Q = CTA_Q_;\n    static constexpr int CTA_S = CTA_S_;\n\n    static constexpr int kWarpCntQ  = CTA_Q * CTA_H / WARP_Q;\n    static constexpr int kWarpCntS  = CTA_S / WARP_S;\n    static constexpr int kWarpCount = kWarpCntQ * kWarpCntS;\n\n    static constexpr int OP_K = 16;\n\n    static constexpr int K_K = HeadDim / OP_K;  // 128 / 16 = 8\n    static constexpr int V_K = WARP_S / OP_K;   //  64 / 16 = 4  -> S4\n\n    using FragQ = Array<T, 8>[K_K][K_M];  // ((q8, d4), (Dk, Qm), (d2, q2, d2))\n                                          //    1   2    16  16     8   8   1\n  
  using FragK = Array<T, 4>[K_K][K_N];  // ((s8, d4), (Dk, Sn), (d2, d2))\n                                          //    1   2    16   8     8   1\n    using FragP = Array<T, 8>[V_M][V_K];  // ((q8, s4), (Qm, Sk), (s2, q2, s2))\n                                          //    1   2    16  16     8   8   1\n    using FragV = Array<T, 4>[V_K][V_N];  // ((d8, s4), (Sk, Dn), (s2, s2))\n                                          //    1   2    16   8     8   1\n\n    static_assert(sizeof(FragS) / 2 == sizeof(FragP));\n\n    using SmemLayoutQ = std::conditional_t<HeadDim % 128 == 0,\n                                           SmemLayoutV2<CTA_Q * CTA_H, HeadDim, 64, 128, Swizzle<3, 3, 4>>,\n                                           SmemLayoutV2<CTA_Q * CTA_H, HeadDim, 64, 64, Swizzle<3, 3, 3>>>;\n    using SmemLayoutK = std::conditional_t<HeadDim % 128 == 0,\n                                           SmemLayoutV2<CTA_S, HeadDim, 16, 128, Swizzle<3, 3, 4>>,\n                                           SmemLayoutV2<CTA_S, HeadDim, 16, 64, Swizzle<3, 3, 3>>>;\n    using SmemLayoutV = std::conditional_t<HeadDim % 128 == 0,\n                                           SmemLayoutV2<CTA_S, HeadDim, 16, 128, Swizzle<3, 3, 4>>,\n                                           SmemLayoutV2<CTA_S, HeadDim, 16, 64, Swizzle<3, 3, 3>>>;\n\n    using SmemLayoutKVp = void;\n\n    static constexpr bool kUseSmemQ = false;\n    static constexpr bool kUseSmemP = false;\n\n    static_assert(!kUseSmemQ, \"current smemQ impl yields inconsistent outputs\");\n\n    union SharedStorage {\n        __align__(16) T KV[Stages * (SmemLayoutK::kSize + SmemLayoutV::kSize) / 2];\n        __align__(16) T Q[SmemLayoutQ::kSize];\n    };\n\n    using ThreadMapQ  = RakedThreadMap<HeadDim, CTA_Q * CTA_H, 8, kWarpCount>;\n    using ThreadMapKV = RakedThreadMap<HeadDim, CTA_S, 8, kWarpCount>;\n\n    using ThreadMapKVp = void;\n\n    static constexpr int kBatchK = std::min(4, ThreadMapKV::kIterS);\n    static 
constexpr int kBatchV = kBatchK;\n\n    __device__ static void Sync()\n    {\n        __syncthreads();\n    }\n\n    template<class GmemIterK, class GmemIterV>\n    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)\n    {\n        int pred = offset_kv;\n        gmem_K.SetSmem(storage.KV);\n        gmem_V.SetSmem(storage.KV + pred * SmemLayoutK::kSize);\n    }\n\n    __device__ static void TransformQ(T* smem_Q, FragQ& frag_Q)\n    {\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        if constexpr (!kUseSmemQ) {\n            __syncwarp();\n\n            SmemAccessor<T, SmemLayoutQ> sQ{smem_Q};\n\n            // Load from shared memory using LDSM, rearrange to m16n8k16 atom layout\n            PRAGMA_UNROLL\n            for (int m = 0; m < K_M; ++m) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < K_K; ++k) {\n                    const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q;\n                    const int di = lane_id / 16 * 8 + k * 16;\n                    ldsm_x4((Array<uint32_t, 4>&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di)));\n                }\n            }\n        }\n\n        if constexpr (0) {\n            __syncthreads();\n\n            // Rearrange Q in smem so that swizzling is not needed for later LDSMs\n            constexpr int THREADS = kWarpCount * WARP_SIZE;\n            PRAGMA_UNROLL\n            for (int k = 0; k < K_K; ++k) {\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    constexpr int kVecSize = 8;\n                    Store(&smem_Q[(k * K_M * THREADS + m * THREADS + threadIdx.x) * kVecSize], frag_Q[k][m]);\n                }\n            }\n        }\n    }\n\n    struct StateQK {\n        SmemAccessor<T, SmemLayoutK> smem_K;\n        T*                           smem_Q;\n\n        FragQ frag_Q;\n        FragK frag_K;\n\n    
    __device__ StateQK(SharedStorage& storage, FragQ frag_Q_): smem_K{storage.KV}\n        {\n            if constexpr (!kUseSmemQ) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < K_K; ++k) {\n                    PRAGMA_UNROLL\n                    for (int m = 0; m < K_M; ++m) {\n                        frag_Q[k][m] = frag_Q_[k][m];\n                    }\n                }\n            }\n            else {\n                smem_Q = storage.Q;\n            }\n        }\n\n        __device__ void Load(int k, int pipe_iter)\n        {\n            const int lane_id       = threadIdx.x % WARP_SIZE;\n            const int group_id      = lane_id / 16;\n            const int group_lane_id = lane_id % 16;\n            const int offset_s      = group_lane_id % 8 + group_id * 8;\n            const int offset_c      = group_lane_id / 8 * 8;\n            const int offset        = pipe_iter * SmemLayoutK::kSize;\n            if constexpr (kUseSmemQ) {\n                const int                    warp_id = threadIdx.x / WARP_SIZE;\n                SmemAccessor<T, SmemLayoutQ> sQ{smem_Q};\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q;\n                    const int di = lane_id / 16 * 8 + k * 16;\n                    ldsm_x4((Array<uint32_t, 4>&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di)));\n                }\n            }\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; n += 2) {  // Load (s16,d16) tiles\n                const int s = n * 8 + offset_s;\n                const int c = k * 16 + offset_c;\n                ldsm_x4((Array<uint32_t, 4>&)frag_K[k][n], cast_smem_ptr_to_uint(&smem_K(s, c, offset)));\n            }\n        }\n\n        __device__ void Transform(int k) {}\n    };\n\n    template<class Prefetch, class Preload>\n    __device__ static void\n    ComputeQK(StateQK state_QK, FragS& frag_S, int offset, 
Prefetch&& prefetch, Preload&& preload)\n    {\n        PRAGMA_UNROLL\n        for (int k = 0; k < K_K; ++k) {\n            if (k < K_K - 1) {\n                state_QK.Load(k + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n            PRAGMA_UNROLL\n            for (int m = 0; m < K_M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < K_N; ++n) {\n                    const int nn = (Stages == 2) ? (n ^ 1) : (n ^ 2);\n                    mma_m16n8k16_row_col(frag_S[m][nn], state_QK.frag_Q[k][m], state_QK.frag_K[k][nn], frag_S[m][nn]);\n                }\n            }\n            if (k < K_K - 1) {\n                ((Prefetch &&) prefetch)(k);\n            }\n            if (k == K_K - 2) {\n                ((Prefetch &&) prefetch)(K_K - 1);\n            }\n        }\n    }\n\n    struct StatePV {\n        SmemAccessor<T, SmemLayoutV> smem_V;\n\n        FragP frag_P;\n        FragV frag_V;\n\n        __device__ StatePV(SharedStorage& storage, bool offset = false):\n            smem_V{storage.KV + (offset ? 
SmemLayoutK::kSize : 0)}\n        {\n        }\n\n        __device__ void Load(int k, int pipe_iter)\n        {\n            const int lane_id  = threadIdx.x % WARP_SIZE;\n            const int offset_s = lane_id % 16;\n            const int offset_c = lane_id / 16 * 8;\n            const int offset   = pipe_iter * SmemLayoutV::kSize;\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; n += 2) {  // Load (d16,s16) tiles\n                const int s = k * 16 + offset_s;\n                const int c = n * 8 + offset_c;\n                ldsm_x4_trans((Array<uint32_t, 4>&)frag_V[k][n], cast_smem_ptr_to_uint(&smem_V(s, c, offset)));\n            }\n        }\n\n        __device__ void Transform(int k) {}\n    };\n\n    template<class Storage>\n    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, Storage& storage)\n    {\n        static_assert(kWarpCntS == 1);\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                if constexpr (Base::kDeferReduceL) {\n                    frag_L[m][q] += __shfl_xor_sync(uint32_t(-1), frag_L[m][q], 1);\n                    frag_L[m][q] += __shfl_xor_sync(uint32_t(-1), frag_L[m][q], 2);\n                }\n            }\n        }\n    }\n\n    template<class Prefetch, class Preload>\n    __device__ static void\n    ComputePV(StatePV state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        PRAGMA_UNROLL\n        for (int k = 0; k < V_K; ++k) {\n            if (k < V_K - 1) {\n                state_PV.Load(k + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n            PRAGMA_UNROLL\n            for (int m = 0; m < V_M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < V_N; ++n) {\n                    const int nn = n ^ 0;\n                    
mma_m16n8k16_row_col(frag_O[m][nn], state_PV.frag_P[m][k], state_PV.frag_V[k][nn], frag_O[m][nn]);\n                }\n            }\n            if (k < V_K - 1) {\n                ((Prefetch &&) prefetch)(k);\n            }\n            if (k == V_K - 2) {\n                ((Prefetch &&) prefetch)(V_K - 1);\n            }\n        }\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/impl_1688.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_m16n8.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/mma.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n\nnamespace turbomind::attention {\n\ntemplate<class T_, int CTA_H_, int CTA_Q_, int CTA_S_, int WARP_H, int WARP_Q, int WARP_S, int HeadDim>\nstruct Impl<MMA_1688, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H, WARP_Q, WARP_S, HeadDim, 2>:\n    Impl_m16k8<T_, WARP_H, WARP_Q, WARP_S, HeadDim> {\n\n    using Base = Impl_m16k8<T_, WARP_H, WARP_Q, WARP_S, HeadDim>;\n\n    static constexpr bool MLA = HeadDim == 576;\n\n    using Base::OP_M;\n    using Base::OP_N;\n    using Base::K_M;\n    using Base::K_N;\n    using Base::V_M;\n    using Base::V_N;\n\n    using typename Base::FragS;\n    using typename Base::FragO;\n    using typename Base::FragM;\n    using typename Base::FragL;\n\n    using Base::ForeachS;\n    using Base::Softmax;\n    using Base::ConvertStoP;\n    using Base::StoreO;\n\n    using T   = T_;\n    using Tkv = T_;\n\n    static constexpr int kHeadDim = HeadDim;\n\n    static constexpr int CTA_H = CTA_H_;\n    static constexpr int CTA_Q = CTA_Q_;\n    static constexpr int CTA_S = CTA_S_;\n\n    static constexpr int kWarpCntQ  = CTA_Q * CTA_H / WARP_Q;\n    static constexpr int kWarpCntS  = CTA_S / WARP_S;\n    static constexpr int kWarpCount = kWarpCntQ * kWarpCntS;\n\n    static constexpr int OP_K = 8;\n\n    static constexpr int K_K = HeadDim / OP_K;  // 128 / 8 = 16\n    static constexpr int V_K = WARP_S / OP_K;   //  64 / 8 = 8  -> S8\n\n    using FragQ = Array<T, 4>[K_K][K_M];  // ((q8, d4), (Dk, Qm), (q2, d2))\n                                          //    1   2     8  16     8   1\n    using FragK = Array<T, 2>[K_K][K_N];  // ((s8, d4), (Dk, Sn), (d2))\n         
                                 //    1   2     8   8     1\n    using FragP = Array<T, 4>[V_M][V_K];  // ((q8, s4), (Qm, Sk), (q2, s2))\n                                          //    1   2    16   8     8   1\n    using FragV = Array<T, 2>[V_K][V_N];  // ((d8, s4), (Sk, Dn), (s2))\n                                          //    1   2     8   8     1\n\n    using SmemLayoutQ = std::conditional_t<HeadDim == 128,\n                                           SmemLayoutV2<CTA_Q * CTA_H, HeadDim, 64, 128, Swizzle<3, 3, 4>>,\n                                           SmemLayoutV2<CTA_Q * CTA_H, HeadDim, 64, 64, Swizzle<3, 3, 3>>>;\n    using SmemLayoutK = std::conditional_t<HeadDim == 128,  // load by (s32,d8) tile\n                                           SmemLayoutV2<CTA_S, HeadDim, 32, 128, Swizzle<3, 3, 4>>,\n                                           SmemLayoutV2<CTA_S, HeadDim, 32, 64, Swizzle<3, 3, 3>>>;\n    using SmemLayoutV = std::conditional_t<HeadDim == 128,  // load by (s8,d32) tile\n                                           SmemLayoutV2<CTA_S, HeadDim, 16, 128, Swizzle<3, 3, 4>>,\n                                           SmemLayoutV2<CTA_S, HeadDim, 16, 64, Swizzle<3, 3, 3>>>;\n\n    using SmemLayoutKVp = void;\n\n    union SharedStorage {\n        __align__(16) T Q[SmemLayoutQ::kSize];\n        struct {\n            __align__(16) Tkv K[SmemLayoutK::kSize];\n            __align__(16) Tkv V[SmemLayoutV::kSize];\n        };\n    };\n\n    static constexpr bool kUseSmemQ = false;\n\n    using ThreadMapQ  = RakedThreadMap<HeadDim, CTA_Q * CTA_H, 8, kWarpCount>;\n    using ThreadMapKV = RakedThreadMap<HeadDim, CTA_S, 8, kWarpCount>;\n\n    using ThreadMapKVp = void;\n\n    __device__ static void Sync()\n    {\n        __syncthreads();\n    }\n\n    template<class GmemIterK, class GmemIterV>\n    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)\n    {\n        gmem_K.SetSmem(storage.K);\n      
  gmem_V.SetSmem(storage.V);\n    }\n\n    __device__ static void TransformQ(T* smem_Q, FragQ& frag_Q)\n    {\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        __syncwarp();\n\n        SmemAccessor<T, SmemLayoutQ> sQ{smem_Q};\n        if constexpr (!kUseSmemQ) {\n            // Load from shared memory using LDSM, rearrange to m16n8k16 atom layout\n            PRAGMA_UNROLL\n            for (int m = 0; m < K_M; ++m) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < K_K; k += 2) {\n                    const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q;\n                    const int di = lane_id / 16 * 8 + k * 8;\n                    ldsm_x4((Array<uint32_t, 4>&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di)));\n                }\n            }\n        }\n        else {\n            static_assert(!std::is_same_v<T, T>, \"not supported\");\n        }\n    }\n\n    struct StateQK {\n        SmemAccessor<T, SmemLayoutK> smem_K;\n\n        FragQ frag_Q;\n        FragK frag_K;\n\n        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_): smem_K{storage.K}\n        {\n            static_assert(!kUseSmemQ, \"not implemented\");\n            PRAGMA_UNROLL\n            for (int k = 0; k < K_K; ++k) {\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    frag_Q[k][m] = frag_Q_[k][m];\n                }\n            }\n        }\n\n        __device__ void Load(int k, int pipe_iter)\n        {\n            const int lane_id = threadIdx.x % WARP_SIZE;\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; n += 4) {  // Load (s32,d8) tiles\n                const int s = n * 8 + lane_id;\n                const int c = k * 8;\n                ldsm_x4((Array<uint32_t, 4>&)frag_K[k][n], cast_smem_ptr_to_uint(&smem_K(s, c)));\n            }\n        }\n\n        __device__ void Transform(int k) {}\n    };\n\n    
template<class Prefetch, class Preload>\n    __device__ static void\n    ComputeQK(StateQK& state_QK, FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        PRAGMA_UNROLL\n        for (int k = 0; k < K_K; ++k) {\n            if (k < K_K - 1) {\n                state_QK.Load(k + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n            PRAGMA_UNROLL\n            for (int m = 0; m < K_M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < K_N; ++n) {\n                    const int nn = n ^ 2;\n                    mma_m16n8k8_row_col(frag_S[m][nn], state_QK.frag_Q[k][m], state_QK.frag_K[k][nn], frag_S[m][nn]);\n                }\n            }\n        }\n    }\n\n    struct StatePV {\n        SmemAccessor<T, SmemLayoutV> smem_V;\n\n        FragP frag_P;\n        FragV frag_V;\n\n        __device__ StatePV(SharedStorage& storage, bool offset = true): smem_V{storage.V}\n        {\n            assert(offset);\n        }\n\n        __device__ void Load(int k, int pipe_iter)\n        {\n            const int lane_id = threadIdx.x % WARP_SIZE;\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; n += 4) {  // Load (d32,s8) tiles\n                const int si = k * 8 + lane_id % 8;\n                const int di = n * 8 + lane_id / 8 * 8;\n                ldsm_x4_trans((Array<uint32_t, 4>&)frag_V[k][n], cast_smem_ptr_to_uint(&smem_V(si, di)));\n            }\n        }\n\n        __device__ void Transform(int k) {}\n    };\n\n    template<class Prefetch, class Preload>\n    __device__ static void\n    ComputePV(StatePV& state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        PRAGMA_UNROLL\n        for (int k = 0; k < V_K; ++k) {\n            if (k < V_K - 1) {\n                state_PV.Load(k + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n            
PRAGMA_UNROLL\n            for (int m = 0; m < V_M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < V_N; ++n) {\n                    mma_m16n8k8_row_col(frag_O[m][n], state_PV.frag_P[m][k], state_PV.frag_V[k][n], frag_O[m][n]);\n                }\n            }\n        }\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/impl_81616.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/quantization.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/mma.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n#include <type_traits>\n\nnamespace turbomind::attention {\n\ntemplate<class T_,\n         class Tkv_,\n         int CTA_H_,\n         int CTA_Q_,\n         int CTA_S_,\n         int WARP_H_,\n         int WARP_Q,\n         int WARP_S,\n         int HeadDim,\n         int Stages>\nstruct Impl<MMA_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S, HeadDim, Stages> {\n    using T   = T_;\n    using Tkv = Tkv_;\n\n    static constexpr int kQuantKV = !std::is_same_v<T, Tkv>;\n\n    static constexpr bool MLA = HeadDim == 576;\n\n    static constexpr int CTA_H = CTA_H_;\n    static constexpr int CTA_Q = CTA_Q_;\n    static constexpr int CTA_S = CTA_S_;\n\n    static_assert(CTA_Q == 1);\n\n    static constexpr int WARP_H = WARP_H_;\n\n    static constexpr int kHeadDim = HeadDim;\n\n    static constexpr int kWarpCntH = CTA_H / WARP_H;\n    static constexpr int kWarpCntQ = CTA_Q / WARP_Q;\n    static constexpr int kWarpCntS = CTA_S / WARP_S;\n\n    static constexpr int kWarpCount = kWarpCntH * kWarpCntQ * kWarpCntS;\n\n    static constexpr int OP_M = 16;\n    static constexpr int OP_N = 8;\n    static constexpr int OP_K = 16;\n\n    static constexpr int K_M = WARP_S / OP_M;               // 1\n    static constexpr int K_N = (WARP_H + OP_N - 1) / OP_N;  // 1\n    static constexpr int K_K = HeadDim / OP_K;              // 8\n\n    static constexpr int V_M = HeadDim / OP_M;              // 8\n    static constexpr int V_N = (WARP_H + OP_N - 1) / OP_N;  // 1\n    static constexpr int V_K = WARP_S / OP_K;               // 
1\n\n    using FragK = Array<T, 8>[K_K][K_M];      // (s8,d4) (Dk,Sm) (d2,s2,d2)\n                                              //   1  2   16 16    8  8  1\n    using FragQ = Array<T, 4>[K_N][K_K];      // (q8,d4) (Qn,Dk) (d2,d2)\n                                              //   1  2    8,16    8  1\n    using FragS = Array<float, 4>[K_M][K_N];  // (s8,q4) (Sm,Qn) (s2,q2)\n                                              //   1  2   16  8    8  1\n    using FragV = Array<T, 8>[V_M][V_K];      // (d8,s4) (Dm,Sk) (s2,d2,s2)\n                                              //   1  2   16 16    8  8  1\n    using FragP = Array<T, 4>[V_K][V_N];      // (q8,s4) (Sk,Qn) (s2,s2)\n                                              //   1  2   16  8    8  1\n    using FragO = Array<float, 4>[V_M][V_N];  // (d8,q4) (Dm,Qn) (d2,q2)\n                                              //   1  2   16  8    8  1\n    using FragM = Array<float, 2>[K_N];       // (_8,q4)    (Qn)    (q2)\n                                              //      2       8       1\n\n    static constexpr int X = 16 / bitsof<Tkv>;\n\n    using DataK = Array<Tkv, 8 * X>[K_K / X][K_M];  // {s8,d4} [Dk/x,Sm] (d2,s2,dx,d2)\n                                                    //   1 2x    16x 16   8x  8  2  1\n    using ParamK = Array<T, 2>[K_M][2];             // {s8,_4} [     Sm] (   s2      )\n                                                    //   1  0        16       8\n    using DataV = Array<Tkv, 8 * X>[V_M / X][V_K];  // {s8,d4} [Dm/x,Sk] (s2,d2,dx,d2)\n                                                    //   1 2x    16x 16    8 8x  2  1\n    using ParamV = Array<T, 2>[V_K][2];             // {s8,_4} [     Sk] (s2         )\n                                                    //   1  0        16    8\n\n    using FragL = FragM;\n\n    using SmemM = Array<float, 2>[K_N][kWarpCntH][kWarpCntS][4];\n\n    using SmemO = Array<float, 4>[V_M][V_N][kWarpCntH][kWarpCntS][WARP_SIZE];\n\n    static constexpr bool kUseSmemQ = 
false;\n    static constexpr bool kUseSmemP = false;\n\n    static constexpr int CTA_H1 = (CTA_H + OP_N - 1) / OP_N * OP_N;\n\n    static constexpr auto _SmemLayoutKV(std::integral_constant<int, 16>)\n    {\n        return SmemLayoutV2<CTA_S, HeadDim, 16, 64, Swizzle<3, 3, 3>>{};\n    }\n    static constexpr auto _SmemLayoutKV(std::integral_constant<int, 8>)\n    {\n        return SmemLayoutV2<CTA_S, HeadDim, 32, 64, Swizzle<3, 4, 3>>{};\n    }\n    static constexpr auto _SmemLayoutKV(std::integral_constant<int, 4>)\n    {\n        return std::conditional_t<HeadDim % 128 == 0,\n                                  SmemLayoutV2<CTA_S, HeadDim, 32, 128, Swizzle<2, 5, 3>>,\n                                  SmemLayoutV2<CTA_S, HeadDim, 32, 64, Swizzle<3, 4, 3>>>{};\n    }\n\n    using SmemLayoutQ = SmemLayoutV2<CTA_H1, HeadDim, CTA_H1, HeadDim, Swizzle<3, 3, 4>>;\n    using SmemLayoutK = decltype(_SmemLayoutKV(bitsof<Tkv>));\n    using SmemLayoutV = decltype(_SmemLayoutKV(bitsof<Tkv>));\n\n    using SmemLayoutKVp = SmemLayoutV2<CTA_S, 2, CTA_S, 2, Identity>;\n\n    using PointerKV = get_pointer_type<Tkv>;\n\n    union SharedStorage {\n        __align__(16) T Q[SmemLayoutQ::kSize];\n\n        struct {\n            __align__(16) Array<Tkv, Stages * SmemLayoutK::kSize> KV;\n            __align__(16) T KVp[Stages * SmemLayoutKVp::kSize];\n        };\n\n        struct {\n            __align__(16) SmemM M;\n            __align__(16) SmemM L;\n            __align__(16) SmemO O;\n        };\n\n        __align__(16) float O1[CTA_H1][kHeadDim];\n    };\n\n    using ThreadMapQ  = RakedThreadMap<HeadDim, CTA_H1, 8, kWarpCount>;\n    using ThreadMapKV = RakedThreadMap<HeadDim, CTA_S, 128 / bitsof<Tkv>, kWarpCount>;\n    // `WARP_SIZE / WARP_S` is chosen to achieve minimum kIterS w/o introducing partial S iter\n    using ThreadMapKVp = RakedThreadMap<2, CTA_S, 2, kWarpCount, WARP_SIZE / WARP_S>;\n\n    static constexpr int kBatchK = ThreadMapKV::kIterS;\n    static constexpr int 
kBatchV = ThreadMapKV::kIterS;\n\n    static constexpr bool kDeferReduceL = true;\n\n    __device__ static void Sync()\n    {\n        if constexpr (kWarpCntH > 1) {\n            __syncthreads();\n        }\n        else if constexpr (kQuantKV) {  // Thread layout of KV & KVp is different within warp boundary\n            __syncwarp();\n        }\n    }\n\n    template<class GmemIterK, class GmemIterV>\n    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)\n    {\n        int pred = offset_kv;\n        if constexpr (kQuantKV) {\n            gmem_K.SetSmem(storage.KV.data(), storage.KVp);\n            gmem_V.SetSmem(storage.KV.data() + pred * SmemLayoutK::kSize, storage.KVp + pred * SmemLayoutKVp::kSize);\n        }\n        else {\n            gmem_K.SetSmem(storage.KV.data());\n            gmem_V.SetSmem(storage.KV.data() + pred * SmemLayoutK::kSize);\n        }\n    }\n\n    static __device__ int2 get_warp_ids()\n    {\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        if constexpr (kWarpCntH > 1) {\n            return {warp_id % kWarpCntS, warp_id / kWarpCntS};\n        }\n        else {\n            return {warp_id, 0};\n        }\n    }\n\n    template<class Fragment, class Func>\n    __device__ static void ForeachS(Fragment& S, Func&& func)\n    {\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int s = 0; s < 2; ++s) {\n                    PRAGMA_UNROLL\n                    for (int q = 0; q < 2; ++q) {\n                        const int si = m * OP_M + lane_id / 4 * 1 + s * 8 + warp_ids.x * WARP_S;\n                        const int hi = n * OP_N + lane_id % 4 * 2 + q * 1 + warp_ids.y * WARP_H;\n                        ((Func &&) func)(hi, 
/*qi*/ 0, si, /*ri*/ 0, S[m][n][s * 2 + q]);\n                    }\n                }\n            }\n        }\n    }\n\n    template<class Func>\n    __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func)\n    {\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        PRAGMA_UNROLL\n        for (int n = 0; n < K_N; ++n) {  // Q\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                const int hi = lane_id % 4 * 2 + n * OP_N + q * 1 + warp_ids.y * WARP_H;\n                const int ri = lane_id / 4 * 1;\n                ((Func &&) func)(hi, /*qi*/ 0, ri, frag_M[n][q], frag_L[n][q]);\n            }\n        }\n    }\n\n    __device__ static void TransformQ(T* smem_Q, FragQ& frag_Q)\n    {\n        static_assert(K_K % 2 == 0);\n        SmemAccessor<T, SmemLayoutQ> sQ{smem_Q};\n\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        if constexpr (!kQuantKV) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < K_K; k += 2) {  // 16x16 tile\n                    const int hi = n * OP_N + lane_id % 8 + warp_ids.y * WARP_H;\n                    const int di = k * OP_K + lane_id / 8 * 8;\n                    ldsm_x4((Array<uint32_t, 4>&)frag_Q[n][k], cast_smem_ptr_to_uint(&sQ(hi, di)));\n                }\n            }\n        }\n        else {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < K_K; k += X) {\n                    PRAGMA_UNROLL\n                    for (int x = 0; x < X; ++x) {\n                        PRAGMA_UNROLL\n                        for (int d = 0; d < 2; ++d) {  // (s8,d8)\n                            const int hi = n * OP_N + lane_id / 4 + warp_ids.y * WARP_H;\n                            const int di = k 
* OP_K + lane_id % 4 * 2 * X + x * 2 + d * 8 * X;\n                            Load((Array<T, 2>&)frag_Q[n][k + x][d * 2], &sQ(hi, di));\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    struct StateQK {\n        PointerKV smem_K;\n        T*        smem_K_param;\n        FragQ     frag_Q;\n        ParamK    param_K;\n        DataK     data_K;\n        FragK     frag_K;\n\n        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_)\n        {\n            smem_K       = storage.KV.data();\n            smem_K_param = storage.KVp;\n            static_assert(!kUseSmemQ, \"not implemented\");\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < K_K; ++k) {\n                    frag_Q[n][k] = frag_Q_[n][k];\n                }\n            }\n        }\n\n        __device__ void Load(int k, int pipe_iter)\n        {\n            const auto warp_ids = get_warp_ids();\n            const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n            if (kQuantKV && k == 0) {\n                static_assert(K_M == 1);\n                const int m = 0;\n                PRAGMA_UNROLL\n                for (int s = 0; s < 2; ++s) {\n                    const int si = m * 16 + lane_id / 4 * 1 + s * 8 + warp_ids.x * WARP_S;\n                    Lds(param_K[m][s], &smem_K_param[pipe_iter * SmemLayoutKVp::kSize + SmemLayoutKVp::apply(si, 0)]);\n                }\n            }\n\n            if (k % X == 0) {\n                const int offset_s = lane_id % 16 * 1 + warp_ids.x * WARP_S;\n                const int offset_c = lane_id / 16 * 8 * X;\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    const int s = m * 16 + offset_s;  // Q\n                    const int c = k * 16 + offset_c;  // D\n                    static_assert(sizeof(data_K[k / X][m]) == 16);\n                    
ldsm_x4((Array<uint32_t, 4>&)data_K[k / X][m],\n                            cast_smem_ptr_to_uint(&smem_K[pipe_iter * SmemLayoutK::kSize + SmemLayoutK::apply(s, c)]));\n                }\n            }\n        }\n\n        __device__ void Transform(int k)\n        {\n            if constexpr (!kQuantKV) {\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    frag_K[k][m] = data_K[k][m];\n                }\n            }\n            else {  // this also covers the non-quantized case, but it's too convoluted to read\n                static_assert(K_M == 1);\n                if (k % X == 0) {\n                    using Converter = ConvertKvCache<Tkv, T>;\n                    PRAGMA_UNROLL\n                    for (int s = 0; s < 2; ++s) {\n                        PRAGMA_UNROLL\n                        for (int d = 0; d < 2; ++d) {\n                            auto dx_d2 =\n                                Converter::convert((Array<Tkv, X * 2>&)data_K[k / X][0][d * 4 * X + s * 2 * X]);\n                            PRAGMA_UNROLL\n                            for (int x = 0; x < X; ++x) {\n                                (Array<short, 2>&)frag_K[k + x][0][d * 4 + s * 2] = (Array<short, 2>&)dx_d2[x * 2];\n                            }\n                        }\n                    }\n                }\n                PRAGMA_UNROLL\n                for (int s = 0; s < 2; ++s) {\n                    PRAGMA_UNROLL\n                    for (int d = 0; d < 2; ++d) {\n                        auto& d2 = (Array<T, 2>&)frag_K[k][0][d * 4 + s * 2];\n                        PRAGMA_UNROLL\n                        for (int i = 0; i < 2; ++i) {\n                            d2[i] = __hfma(d2[i], param_K[0][s][0], param_K[0][s][1]);\n                        }\n                    }\n                }\n            }\n        }\n    };\n\n    template<class Prefetch, class Preload>\n    __device__ static void\n    ComputeQK(StateQK state_QK, 
FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        if constexpr (K_K == 1) {\n            ((Prefetch &&) prefetch)(0);\n        }\n\n        PRAGMA_UNROLL\n        for (int k = 0; k < K_K; ++k) {\n            if (k < K_K - 1) {\n                state_QK.Load(k + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n\n            state_QK.Transform(k);\n\n            PRAGMA_UNROLL\n            for (int m = 0; m < K_M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < K_N; ++n) {\n                    mma_m16n8k16_row_col(frag_S[m][n], state_QK.frag_K[k][m], state_QK.frag_Q[n][k], frag_S[m][n]);\n                }\n            }\n            if (k < K_K - 1) {\n                ((Prefetch &&) prefetch)(k);\n            }\n            if (k == K_K - 2) {\n                ((Prefetch &&) prefetch)(K_K - 1);\n            }\n        }\n    }\n\n    struct StatePV {\n        PointerKV smem_V;\n        T*        smem_V_param;\n        ParamV    param_V;\n        DataV     data_V;\n        FragP     frag_P;\n        FragV     frag_V;\n\n        __device__ StatePV(SharedStorage& storage, bool offset = false)\n        {\n            smem_V       = storage.KV.data() + (offset ? SmemLayoutK::kSize : 0);\n            smem_V_param = storage.KVp + (offset ? 
SmemLayoutKVp::kSize : 0);\n        }\n\n        __device__ void Load(int m, int pipe_iter)\n        {\n            const auto warp_ids = get_warp_ids();\n            const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n            if (kQuantKV && m == 0) {\n                static_assert(V_K == 1);\n                const int k = 0;\n                PRAGMA_UNROLL\n                for (int s = 0; s < 2; ++s) {\n                    const int si = k * 16 + lane_id / 4 * 1 + s * 8 + warp_ids.x * WARP_S;\n                    Lds(param_V[k][s], &smem_V_param[pipe_iter * SmemLayoutKVp::kSize + SmemLayoutKVp::apply(si, 0)]);\n                }\n            }\n\n            if (m % X == 0) {\n                const int offset_s = lane_id / 16 * 8 + lane_id % 8 + warp_ids.x * WARP_S;\n                const int offset_c = lane_id % 16 / 8 * 8 * X;\n                PRAGMA_UNROLL\n                for (int k = 0; k < V_K; ++k) {\n                    const int s = k * 16 + offset_s;\n                    const int c = m * 16 + offset_c;\n                    static_assert(sizeof(data_V[m / X][k]) == 16);\n                    if constexpr (!kQuantKV) {\n                        ldsm_x4_trans(\n                            (Array<uint32_t, 4>&)data_V[m / X][k],\n                            cast_smem_ptr_to_uint(&smem_V[pipe_iter * SmemLayoutV::kSize + SmemLayoutV::apply(s, c)]));\n                    }\n                    else {\n                        ldsm_x4(\n                            (Array<uint32_t, 4>&)data_V[m / X][k],\n                            cast_smem_ptr_to_uint(&smem_V[pipe_iter * SmemLayoutV::kSize + SmemLayoutV::apply(s, c)]));\n                    }\n                }\n            }\n        }\n\n        __device__ void Transform(int m)\n        {\n            if constexpr (!kQuantKV) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < V_K; ++k) {\n                    frag_V[m][k] = data_V[m][k];\n                }\n            }\n            else {\n 
               static_assert(V_K == 1);\n                if (m % X == 0) {\n                    PRAGMA_UNROLL\n                    for (int s = 0; s < 2; ++s) {\n                        PRAGMA_UNROLL\n                        for (int d = 0; d < 2; ++d) {\n                            auto dx_d2 = ConvertKvCache<Tkv, T>::convert(\n                                (Array<Tkv, 2 * X>&)data_V[m / X][0][s * 4 * X + d * 2 * X]);\n                            PRAGMA_UNROLL\n                            for (int x = 0; x < X; ++x) {\n                                (Array<T, 2>&)frag_V[m + x][0][s * 4 + d * 2] = (Array<T, 2>&)dx_d2[x * 2];\n                            }\n                        }\n                    }\n                }\n                PRAGMA_UNROLL\n                for (int s = 0; s < 2; ++s) {\n                    PRAGMA_UNROLL\n                    for (int d = 0; d < 2; ++d) {\n                        auto& d2 = (Array<T, 2>&)frag_V[m][0][s * 4 + d * 2];\n                        PRAGMA_UNROLL\n                        for (int i = 0; i < 2; ++i) {\n                            d2[i] = __hfma(d2[i], param_V[0][s][0], param_V[0][s][1]);\n                        }\n                        (uint32_t&)d2 = transpose_m8n8_b16((uint32_t&)d2);\n                    }\n                }\n            }\n        }\n    };\n\n    template<class Prefetch, class Preload>\n    __device__ static void\n    ComputePV(StatePV state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            if (m < V_M - 1) {\n                state_PV.Load(m + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n\n            state_PV.Transform(m);\n\n            PRAGMA_UNROLL\n            for (int k = 0; k < V_K; ++k) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < V_N; ++n) {\n                    
mma_m16n8k16_row_col(frag_O[m][n], state_PV.frag_V[m][k], state_PV.frag_P[k][n], frag_O[m][n]);\n                }\n            }\n            if (m < V_M - 1) {\n                ((Prefetch &&) prefetch)(m);\n            }\n            if (m == V_M - 2) {\n                ((Prefetch &&) prefetch)(V_M - 1);\n            }\n        }\n    }\n\n    template<bool is_residue>\n    __device__ static void Softmax(FragS& frag_S, FragM& frag_M, FragM& frag_L, FragO& frag_O, float qk_scale)\n    {\n        FragM prev_M;\n        copy(frag_M, prev_M);\n\n        PRAGMA_UNROLL\n        for (int n = 0; n < K_N; ++n) {  // h\n            PRAGMA_UNROLL\n            for (int m = 0; m < K_M; ++m) {  // s\n                PRAGMA_UNROLL\n                for (int s = 0; s < 2; ++s) {\n                    PRAGMA_UNROLL\n                    for (int q = 0; q < 2; ++q) {\n                        frag_M[n][q] = fmaxf(frag_M[n][q], frag_S[m][n][s * 2 + q]);\n                    }\n                }\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int n = 0; n < K_N; ++n) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                frag_M[n][q] = fmaxf(frag_M[n][q], __shfl_xor_sync(uint32_t(-1), frag_M[n][q], 4));\n                frag_M[n][q] = fmaxf(frag_M[n][q], __shfl_xor_sync(uint32_t(-1), frag_M[n][q], 8));\n                frag_M[n][q] = fmaxf(frag_M[n][q], __shfl_xor_sync(uint32_t(-1), frag_M[n][q], 16));\n            }\n        }\n\n        FragM expdiff_M;\n        PRAGMA_UNROLL\n        for (int n = 0; n < K_N; ++n) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                expdiff_M[n][q] = exp2f((prev_M[n][q] - frag_M[n][q]) * qk_scale);\n                if (is_residue && frag_M[n][q] == -std::numeric_limits<float>::infinity()) {\n                    expdiff_M[n][q] = 0.f;\n                }\n                frag_L[n][q] *= expdiff_M[n][q];\n            }\n        }\n\n        PRAGMA_UNROLL\n        for 
(int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; ++n) {\n                PRAGMA_UNROLL\n                for (int d = 0; d < 2; ++d) {\n                    PRAGMA_UNROLL\n                    for (int q = 0; q < 2; ++q) {\n                        frag_O[m][n][d * 2 + q] *= expdiff_M[n][q];  // Rescale previous output\n                    }\n                }\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int n = 0; n < K_N; ++n) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                float tmp_L{};\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    PRAGMA_UNROLL\n                    for (int s = 0; s < 2; ++s) {\n                        float p = exp2f(frag_S[m][n][s * 2 + q] * qk_scale - frag_M[n][q] * qk_scale);\n                        if (is_residue && frag_M[n][q] == -std::numeric_limits<float>::infinity()) {\n                            p = 0.f;\n                        }\n                        tmp_L += p;\n                        frag_S[m][n][s * 2 + q] = p;\n                    }\n                }\n                if constexpr (!kDeferReduceL) {\n                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 4);\n                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 8);\n                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 16);\n                }\n                frag_L[n][q] += tmp_L;  // update L\n            }\n        }\n    }\n\n    __device__ static void ConvertStoP(FragS& frag_S, FragP& frag_P, SharedStorage&)\n    {\n        static_assert(K_M == V_K);\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int s = 0; s < 2; ++s) {\n                    Array<T, 2> tmp_P;\n                    PRAGMA_UNROLL\n                    for (int q = 0; q < 
2; ++q) {\n                        tmp_P[q] = static_cast<T>(frag_S[m][n][s * 2 + q]);\n                    }\n                    // (s8,q4),(s2,q2) -> (q8,s4),(s2,s2)\n                    //   1  2    8  1       1  2    8  1\n                    (uint32_t&)tmp_P = transpose_m8n8_b16((uint32_t&)tmp_P);\n\n                    (Array<T, 2>&)frag_P[m][n][s * 2] = tmp_P;\n                }\n            }\n        }\n    }\n\n    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, SharedStorage& storage)\n    {\n        if constexpr (kWarpCntS == 1 && !kDeferReduceL) {\n            __syncthreads();\n            return;\n        }\n\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        FragM prev_M;\n        copy(frag_M, prev_M);\n\n        __syncthreads();\n\n        /////////////////////////////////////////////////////////////////////////\n        //  global max\n        PRAGMA_UNROLL\n        for (int n = 0; n < K_N; ++n) {\n            if (lane_id < 4) {\n                Store((float*)&storage.M[n][warp_ids.y][warp_ids.x][lane_id], frag_M[n]);\n            }\n        }\n\n        __syncthreads();\n\n        PRAGMA_UNROLL\n        for (int n = 0; n < K_N; ++n) {\n            // Compute global maximum\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                PRAGMA_UNROLL\n                for (int w = 0; w < kWarpCntS - 1; ++w) {\n                    const int src_warp = (warp_ids.x + w + 1) % kWarpCntS;\n                    frag_M[n][q]       = fmaxf(frag_M[n][q], storage.M[n][warp_ids.y][src_warp][lane_id % 4][q]);\n                }\n                // if (lane_id < 4) {\n                //     printf(\"M %d %d %f\\n\", lane_id % 4 * 2 + q, warp_id, frag_M[n][q]);\n                // }\n            }\n        }\n\n        // if (threadIdx.x == 0) {\n        //     printf(\"M %d %f\\n\", 0, frag_M[0][0]);\n        // }\n\n        
///////////////////////////////////////////////////////////////////////////\n        //  rescale & global sum\n\n        FragM expdiff_M;\n        PRAGMA_UNROLL\n        for (int n = 0; n < V_N; ++n) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                expdiff_M[n][q] = exp2f((prev_M[n][q] - frag_M[n][q]) * qk_scale);\n                if (frag_M[n][q] == -std::numeric_limits<float>::infinity()) {\n                    expdiff_M[n][q] = 0.f;\n                }\n            }\n            PRAGMA_UNROLL\n            for (int m = 0; m < V_M; ++m) {\n                PRAGMA_UNROLL\n                for (int d = 0; d < 2; ++d) {\n                    PRAGMA_UNROLL\n                    for (int q = 0; q < 2; ++q) {\n                        frag_O[m][n][d * 2 + q] *= expdiff_M[n][q];\n                    }\n                }\n                Store((float*)&storage.O[m][n][warp_ids.y][warp_ids.x][lane_id], frag_O[m][n]);\n            }\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                frag_L[n][q] *= expdiff_M[n][q];\n                if constexpr (kDeferReduceL) {\n                    frag_L[n][q] += __shfl_xor_sync(uint32_t(-1), frag_L[n][q], 4);\n                    frag_L[n][q] += __shfl_xor_sync(uint32_t(-1), frag_L[n][q], 8);\n                    frag_L[n][q] += __shfl_xor_sync(uint32_t(-1), frag_L[n][q], 16);\n                }\n            }\n            if (lane_id < 4) {\n                Store((float*)&storage.L[n][warp_ids.y][warp_ids.x][lane_id], frag_L[n]);\n            }\n        }\n\n        __syncthreads();\n\n        clear(frag_O);\n        clear(frag_L);\n\n        PRAGMA_UNROLL\n        for (int n = 0; n < V_N; ++n) {\n            PRAGMA_UNROLL\n            for (int w = 0; w < kWarpCntS; ++w) {\n                using namespace ops;\n                PRAGMA_UNROLL\n                for (int m = 0; m < V_M; ++m) {\n                    Array<float, 4> tmp_O;\n                    Load(tmp_O, 
storage.O[m][n][warp_ids.y][w][lane_id].data());\n                    frag_O[m][n] = frag_O[m][n] + tmp_O;\n                }\n                frag_L[n] = frag_L[n] + storage.L[n][warp_ids.y][w][lane_id % 4];\n            }\n            // PRAGMA_UNROLL\n            // for (int q = 0; q < 2; ++q) {\n            //     if (lane_id < 4) {\n            //         printf(\"L %d %d %f\\n\", lane_id % 4 * 2 + q, warp_id, frag_L[n][q]);\n            //     }\n            // }\n\n            // if (threadIdx.x == 0) {\n            //     printf(\"L %d %f\\n\", 0, frag_L[0][0]);\n            // }\n        }\n    }\n\n    template<bool is_norm, class Func>\n    __device__ static void StoreO(FragO& frag_O, const FragL& frag_L, SharedStorage& storage, Func&& func)\n    {\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        FragL inv_L;\n        PRAGMA_UNROLL\n        for (int n = 0; n < V_N; ++n) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                inv_L[n][q] = fdividef(1.f, frag_L[n][q]);\n            }\n        }\n\n        __syncthreads();\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; m += X) {\n            PRAGMA_UNROLL\n            for (int x = 0; x < X; ++x) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < V_N; ++n) {\n                    PRAGMA_UNROLL\n                    for (int d = 0; d < 2; ++d) {\n                        if constexpr (is_norm) {\n                            using namespace ops;\n                            (Array<float, 2>&)frag_O[m + x][n][d * 2] =\n                                (Array<float, 2>&)frag_O[m + x][n][d * 2] * inv_L[n];\n                        }\n                        PRAGMA_UNROLL\n                        for (int q = 0; q < 2; ++q) {\n                            const int hi = n * OP_N + lane_id % 4 * 2 + q * 1 + warp_ids.y * WARP_H;\n                            // [43][2][10]\n                      
      //   2  1\n                            //   4  1\n                            //   8  1\n                            const int di = m * OP_M + lane_id / 4 % 2 + d * 8 * X + x * 2 + lane_id / 8 * X * 2;\n                            if (warp_ids.x == 0) {\n                                storage.O1[hi][di] = frag_O[m + x][n][d * 2 + q];\n                                // if (hi == 0) {\n                                //     printf(\"O %4d %4d %f\\n\", hi, di, frag_O[m][n][d * 2 + q]);\n                                // }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n\n        __syncthreads();\n\n        // For HeadDim=256, WarpThreadC needs to be explicitly specified to avoid exceeding WARP_SIZE\n        using Map = std::conditional_t<kHeadDim == 256,\n                                       RakedThreadMap<kHeadDim, CTA_H1, 4, kWarpCount, 8>,\n                                       RakedThreadMap<kHeadDim, CTA_H1, 4, kWarpCount>>;\n        Array<float, 4> tmp_O[Map::kIterS][Map::kIterC];\n\n        const int  warp_id = threadIdx.x / WARP_SIZE;\n        const int2 offset  = Map::get_offset(warp_id, lane_id);\n\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map::kIterS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                const int hi = offset.y + s * Map::kDeltaS;\n                const int di = offset.x + c * Map::kDeltaC;\n                Load(tmp_O[s][c], &storage.O1[hi][di]);\n                ((Func &&) func)(hi, 0, di, tmp_O[s][c]);\n            }\n        }\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/impl_884.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/mma.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n\n#include <cmath>\n#include <type_traits>\n\nnamespace turbomind::attention {\n\ntemplate<class T_, int CTA_H_, int CTA_Q_, int CTA_S_, int WARP_H_, int WARP_Q, int WARP_S, int HeadDim>\nstruct Impl<MMA_884, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S, HeadDim> {\n    using T   = T_;\n    using Tkv = T_;\n\n    static constexpr bool MLA = false;\n\n    static constexpr int CTA_H    = CTA_H_;\n    static constexpr int CTA_Q    = CTA_Q_;\n    static constexpr int CTA_S    = CTA_S_;\n    static constexpr int kHeadDim = HeadDim;\n\n    static constexpr int kWarpCntQ  = CTA_Q / WARP_Q;\n    static constexpr int kWarpCntS  = CTA_S / WARP_S;\n    static constexpr int kWarpCount = kWarpCntQ * kWarpCntS;\n\n    static constexpr int OP_M = 16;\n    static constexpr int OP_N = 16;\n    static constexpr int OP_K = 4;\n\n    static constexpr int K_M = WARP_Q / OP_M;   // 1\n    static constexpr int K_N = WARP_S / OP_N;   // 4\n    static constexpr int K_K = HeadDim / OP_K;  // 32\n\n    static constexpr int V_M = WARP_Q / OP_M;   // 1\n    static constexpr int V_N = HeadDim / OP_N;  // 8\n    static constexpr int V_K = WARP_S / OP_K;   // 16\n\n    //  +---+---+\n    //  | 0 | 1 |\n    //  +---+---+\n    //  | 2 | 3 |\n    //  +---+---+\n    using FragQ = Array<half, 4>[K_K][K_M];   //    (q2,q2,x2,q4) (Dk,Qm) (d4)\n                                              //      4  8  0  1    4 16    1\n    using FragK = Array<half, 4>[K_K][K_N];   //    (s2,x2,s2,s4) (Dk,Sn) (d4)\n                                              //      4  0  8  1    4 16    1\n    using FragS = Array<float, 8>[K_M][K_N];  // (q2,q2,s2,s2,q2) (Qm,Sn) 
(s2,q2,s2)\n                                              //   4  8  8  2  1   16 16    4  2  1\n    using FragP = Array<half, 4>[V_K][V_M];   //    (q2,q2,x2,q4) (Sk,Qm) (s4)\n                                              //      4  8  0  1    4 16    1\n    using FragV = Array<half, 4>[V_K][V_N];   //    (d2,x2,d2,s4) (Sk,Dn) (d4)       [row major]\n                                              //      4  0  8  1    4 16    1\n    using FragO = Array<float, 8>[V_M][V_N];  // (q2,q2,d2,d2,q2) (Qm,Dn) (d2,q2,d2)\n                                              //   4  8  8  2  1   16 16    4  2  1\n    using FragM = Array<float, 2>[V_M];       // (q2,q2,_2,_2,q2) (Qm)    (q2))\n    using FragL = FragM;\n\n    // using Swizzle = Identity;\n\n    struct SwizzleV {\n\n        __device__ static int apply(int offset)\n        {\n            // Rearrange for LDS.128 (also avoid bank-conflict along C)\n            // 6543210\n            // dDDDDdd\n            offset = ((offset & 8) << 2) ^ offset;                                     // x[5] ^= x[3]\n            offset = ((offset & ~20) | (((offset & 16) >> 2) | ((offset & 4) << 2)));  // swap(x[4], x[2])\n\n            // Shuffle C according S to avoid bank-conflict\n            // ssssSSdDDddd\n            offset = ((offset & (0x3 << 6)) >> 3) ^ offset;\n            return offset;\n        }\n\n        __device__ int operator()(int offset)\n        {\n            return apply(offset);\n        }\n    };\n\n    using SmemLayoutQ = SmemLayoutV2<CTA_Q, HeadDim + 4, 1, 1, Identity>;\n    using SmemLayoutP = SmemLayoutV2<CTA_Q, CTA_S + 4, 1, 1, Identity>;\n    using SmemLayoutK = SmemLayoutV2<CTA_S, HeadDim + 4, 1, 1, Identity>;\n    using SmemLayoutV = SmemLayoutV2<CTA_S, HeadDim, CTA_S, 64, SwizzleV>;\n\n    using SmemLayoutKVp = void;\n\n    struct SharedStorage {\n        union {\n            __align__(16) T Q[SmemLayoutQ::kSize];\n            struct {\n                __align__(16) T K[SmemLayoutK::kSize];\n              
  __align__(16) T V[SmemLayoutV::kSize];\n                __align__(16) T P[SmemLayoutP::kSize];\n            };\n        };\n    };\n\n    static constexpr bool kUseSmemQ = false;\n    static constexpr bool kUseSmemP = false;\n\n    // For HeadDim=256, WarpThreadC needs to be explicitly specified to avoid exceeding WARP_SIZE\n    using ThreadMapQ  = std::conditional_t<HeadDim == 256,\n                                          RakedThreadMap<HeadDim, CTA_Q, 4, kWarpCount, 8>,\n                                          RakedThreadMap<HeadDim, CTA_Q, 4, kWarpCount>>;\n    using ThreadMapKV = std::conditional_t<HeadDim == 256,\n                                           RakedThreadMap<HeadDim, CTA_S, 4, kWarpCount, 8>,\n                                           RakedThreadMap<HeadDim, CTA_S, 4, kWarpCount>>;\n\n    using ThreadMapKVp = void;\n\n    static constexpr bool kDeferReduceL = true;\n\n    __device__ static void Sync()\n    {\n        __syncthreads();\n    }\n\n    template<class Fragment, class Func>\n    __device__ static void ForeachS(Fragment& S, Func&& func)\n    {\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int s1 = 0; s1 < 2; ++s1) {\n                    PRAGMA_UNROLL\n                    for (int q = 0; q < 2; ++q) {\n                        PRAGMA_UNROLL\n                        for (int s0 = 0; s0 < 2; ++s0) {\n                            const int qi = m * OP_M + (lane_id & 8) + (lane_id & 1) + lane_id / 16 * 4 + q * 2;\n                            const int si = n * OP_N + (lane_id & 4) * 2 + (lane_id & 2) + s1 * 4 + s0;\n                            ((Func &&) func)(0, warp_id * WARP_Q + qi, si, /*ri*/ 0, S[m][n][s1 * 4 + q * 2 + s0]);\n                        }\n                    }\n             
   }\n            }\n        }\n    }\n\n    __device__ static void TransformQ(const T* smem_Q, FragQ& frag_Q)\n    {\n        if constexpr (!kUseSmemQ) {\n            const int warp_id = threadIdx.x / WARP_SIZE;\n            const int lane_id = threadIdx.x % WARP_SIZE;\n            PRAGMA_UNROLL\n            for (int k = 0; k < K_K; ++k) {\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    const int qi = m * OP_M + (lane_id & 8) + lane_id % 4 + lane_id / 16 * 4 + warp_id * WARP_Q;\n                    const int di = k * 4;\n                    Lds(frag_Q[k][m], &smem_Q[SmemLayoutQ::apply(qi, di)]);\n                }\n            }\n        }\n    }\n\n    template<class GmemIterK, class GmemIterV>\n    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)\n    {\n        gmem_K.SetSmem(storage.K);\n        gmem_V.SetSmem(storage.V);\n    }\n\n    struct StateQK {\n        SmemAccessor<T, SmemLayoutK> smem_K;\n\n        FragQ frag_Q;\n        FragK frag_K;\n\n        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_): smem_K{storage.K}\n        {\n            static_assert(!kUseSmemQ, \"not implemented\");\n            PRAGMA_UNROLL\n            for (int k = 0; k < K_K; ++k) {\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    frag_Q[k][m] = frag_Q_[k][m];\n                }\n            }\n        }\n\n        __device__ void Load(int k, int pipe_iter)\n        {\n            const int lane_id = threadIdx.x % WARP_SIZE;\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                const int s = n * 16 + lane_id / 16 * 4 + (lane_id & 4) * 2 + lane_id % 4;\n                const int c = k * 4;\n                Lds(frag_K[k][n], &smem_K(s, c));\n            }\n        }\n\n        __device__ void Transform(int k) {}\n    };\n\n    template<class Prefetch, class Preload>\n    
__device__ static void\n    ComputeQK(StateQK& state_QK, FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        PRAGMA_UNROLL\n        for (int k = 0; k < K_K; ++k) {\n            if (k < K_K - 1) {\n                state_QK.Load(k + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n            PRAGMA_UNROLL\n            for (int m = 0; m < K_M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < K_N; ++n) {\n                    const int nn = n ^ 1;\n                    mma_m8n8k4_row_col(frag_S[m][nn], state_QK.frag_Q[k][m], state_QK.frag_K[k][nn], frag_S[m][nn]);\n                }\n            }\n        }\n    }\n\n    struct StatePV {\n        T* smem_V;\n\n        static_assert(V_N % 2 == 0);\n        Array<int, V_N / 2> idxs_;\n\n        FragP frag_P;\n        FragV frag_V;\n\n        __device__ StatePV(SharedStorage& storage, bool offset): smem_V{storage.V}\n        {\n            assert(offset);\n            const int lane_id = threadIdx.x % WARP_SIZE;\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; n += 2) {\n                const int s  = 0 * 4 + lane_id % 4;\n                const int c  = n * 16 + lane_id / 16 * 4 + (lane_id & 4) * 2;\n                idxs_[n / 2] = SmemLayoutV::apply(s, c);\n            }\n        }\n\n        __device__ void Load(int k, int pipe_iter)\n        {\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; n += 2) {\n                const int idx = idxs_[n / 2] + k * 4 * SmemLayoutV::C0;\n                Lds((Array<half, 8>&)frag_V[k][n], &smem_V[idx]);\n            }\n        }\n\n        __device__ void Transform(int k) {}\n    };\n\n    template<class Prefetch, class Preload>\n    __device__ static void\n    ComputePV(StatePV& state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        PRAGMA_UNROLL\n        for (int k = 0; k < V_K; ++k) {\n            if 
(k < V_K - 1) {\n                state_PV.Load(k + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n            PRAGMA_UNROLL\n            for (int m = 0; m < V_M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < V_N; ++n) {\n                    mma_m8n8k4_row_row(frag_O[m][n], state_PV.frag_P[k][m], state_PV.frag_V[k][n], frag_O[m][n]);\n                }\n            }\n        }\n    }\n\n    template<bool is_residue>\n    __device__ static void Softmax(FragS& frag_S, FragM& frag_M, FragM& frag_L, FragO& frag_O, float qk_scale)\n    {\n        FragM prev_M;\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            prev_M[m] = frag_M[m];\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int s1 = 0; s1 < 2; ++s1) {\n                    PRAGMA_UNROLL\n                    for (int q = 0; q < 2; ++q) {\n                        PRAGMA_UNROLL\n                        for (int s0 = 0; s0 < 2; ++s0) {\n                            frag_M[m][q] =\n                                fmaxf(frag_M[m][q], frag_S[m][n][s1 * 4 + q * 2 + s0]);  // reduce over local quad\n                        }\n                    }\n                }\n            }\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {  // reduce over thread group within warp (within warp tiles)\n                frag_M[m][q] = fmaxf(frag_M[m][q], __shfl_xor_sync(uint32_t(-1), frag_M[m][q], 2));\n                frag_M[m][q] = fmaxf(frag_M[m][q], __shfl_xor_sync(uint32_t(-1), frag_M[m][q], 4));\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                // exp(M - M'), isinf(frag_M) => isnan(expdiff_M)\n                float 
expdiff_M = exp2f((prev_M[m][q] - frag_M[m][q]) * qk_scale);\n                if (is_residue && frag_M[m][q] == -std::numeric_limits<float>::infinity()) {\n                    expdiff_M = 0.f;\n                }\n                PRAGMA_UNROLL\n                for (int n = 0; n < V_N; ++n) {\n                    PRAGMA_UNROLL\n                    for (int s1 = 0; s1 < 2; ++s1) {\n                        PRAGMA_UNROLL\n                        for (int s0 = 0; s0 < 2; ++s0) {\n                            frag_O[m][n][s1 * 4 + q * 2 + s0] *= expdiff_M;  // Rescale previous output\n                        }\n                    }\n                }\n                frag_L[m][q] *= expdiff_M;\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                float tmp_L{};\n                PRAGMA_UNROLL\n                for (int n = 0; n < K_N; ++n) {\n                    PRAGMA_UNROLL\n                    for (int s1 = 0; s1 < 2; ++s1) {\n                        PRAGMA_UNROLL\n                        for (int s0 = 0; s0 < 2; ++s0) {\n                            // unnormalized prob, optimized to FFMA\n                            float p = exp2f(frag_S[m][n][s1 * 4 + q * 2 + s0] * qk_scale - frag_M[m][q] * qk_scale);\n                            if (is_residue && frag_M[m][q] == -std::numeric_limits<float>::infinity()) {\n                                p = 0.f;\n                            }\n                            tmp_L += p;\n                            frag_S[m][n][s1 * 4 + q * 2 + s0] = p;\n                        }\n                    }\n                }\n                if constexpr (!kDeferReduceL) {\n                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 2);\n                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 4);\n                }\n                frag_L[m][q] = frag_L[m][q] + tmp_L;  // update L\n            }\n       
 }\n    }\n\n    __device__ static void ConvertStoP(FragS& frag_S, FragP& frag_P, SharedStorage& storage)\n    {\n        ForeachS(frag_S,\n                 [&](int, int qi, int si, int ri, float p) { storage.P[SmemLayoutP::apply(qi, si)] = half(p); });\n\n        if constexpr (!kUseSmemP) {\n            const int warp_id = threadIdx.x / WARP_SIZE;\n            const int lane_id = threadIdx.x % WARP_SIZE;\n            PRAGMA_UNROLL\n            for (int k = 0; k < V_K; ++k) {\n                PRAGMA_UNROLL\n                for (int m = 0; m < V_M; ++m) {\n                    const int qi = m * OP_M + lane_id / 16 * 4 + (lane_id & 8) + lane_id % 4 + warp_id * WARP_Q;\n                    const int si = k * OP_K;\n                    Lds(frag_P[k][m], &storage.P[SmemLayoutP::apply(qi, si)]);\n                }\n            }\n        }\n    }\n\n    template<class Func>\n    __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func)\n    {\n        /// FIXME: implement this\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {  // Q,16\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {  // Q,2\n                const int qi = (lane_id & 1) * 1 + (lane_id & 16) / 4 + (lane_id & 8) + m * OP_M + q * 2;\n                const int ri = (lane_id & 2) / 2 + (lane_id & 4) / 2;\n                ((Func &&) func)(0, warp_id * WARP_Q + qi, ri, frag_M[m][q], frag_L[m][q]);\n            }\n        }\n    };\n\n    template<class Storage>\n    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, Storage& storage)\n    {\n        static_assert(kWarpCntS == 1);\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                if constexpr (kDeferReduceL) {\n                    frag_L[m][q] += 
__shfl_xor_sync(uint32_t(-1), frag_L[m][q], 2);\n                    frag_L[m][q] += __shfl_xor_sync(uint32_t(-1), frag_L[m][q], 4);\n                }\n            }\n        }\n    }\n\n    template<bool is_norm, class Func>\n    __device__ static void StoreO(FragO& frag_O, FragL& frag_L, SharedStorage& storage, Func&& func)\n    {\n        FragL inv_L;\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                inv_L[m][q] = fdividef(1.f, frag_L[m][q] + 1e-8f);\n            }\n        }\n\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        const int mm = lane_id / 16 * 4 + (lane_id & 8) + (lane_id & 1);\n        const int nn = (lane_id & 4) * 2 + (lane_id & 2);\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; ++n) {\n                PRAGMA_UNROLL\n                for (int d1 = 0; d1 < 2; ++d1) {\n                    PRAGMA_UNROLL\n                    for (int q = 0; q < 2; ++q) {\n                        const int qi = m * OP_M + mm + q * 2 + warp_id * WARP_Q;\n                        const int di = n * OP_N + nn + d1 * 4;\n                        if constexpr (is_norm) {\n                            PRAGMA_UNROLL\n                            for (int d0 = 0; d0 < 2; ++d0) {\n                                frag_O[m][n][d1 * 4 + q * 2 + d0] *= inv_L[m][q];\n                            }\n                        }\n                        ((Func &&) func)(0, qi, di, (Array<float, 2>&)frag_O[m][n][d1 * 4 + q * 2]);\n                    }\n                }\n            }\n        }\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/impl_m16n8.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n\nnamespace turbomind::attention {\n\ntemplate<class T, int WARP_H, int WARP_Q, int WARP_S, int HeadDim>\nstruct Impl_m16k8 {\n\n    static constexpr int OP_M = 16;\n    static constexpr int OP_N = 8;\n\n    static constexpr int K_M = WARP_Q / OP_M;  //  16 / 16 = 1\n    static constexpr int K_N = WARP_S / OP_N;  //  64 /  8 = 8\n\n    static constexpr int V_M = WARP_Q / OP_M;   //  16 / 16 = 1\n    static constexpr int V_N = HeadDim / OP_N;  // 128 /  8 = 16 -> D16\n\n    template<class S>\n    using FragS_ = Array<S, 4>[K_M][K_N];     // ((q8, s4), (Qm, Sn), (q2, s2))\n                                              //    1   2    16   8     8   1\n    using FragO = Array<float, 4>[V_M][V_N];  // ((q8, d4), (Qm, Dn), (q2, d2))\n                                              //    1   2    16   8     8   1\n    using FragM = Array<float, 2>[V_M];       // ((q8, _4), Qm, q2) => FragS with all S dim reduced\n                                              //    1   0   16   8\n    using FragS = FragS_<float>;\n    using FragL = FragM;\n\n    static constexpr bool kDeferReduceL = false;\n\n    template<class Func>\n    __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func)\n    {\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {  // Q\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                const int qi = lane_id / 4 * 1 + m * OP_M + q * 8 + warp_id * WARP_Q;\n                const int ri = lane_id % 4 * 1;\n                ((Func &&) func)(qi % WARP_H, qi / WARP_H, ri, frag_M[m][q], frag_L[m][q]);\n            }\n        }\n    }\n\n    template<class Fragment, class Func>\n    __device__ static void ForeachS(Fragment& S, Func&& func)\n    {\n        const int warp_id = 
threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {  // Q\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {  // KV\n                PRAGMA_UNROLL\n                for (int q = 0; q < 2; ++q) {\n                    PRAGMA_UNROLL\n                    for (int s = 0; s < 2; ++s) {\n                        const int qi = lane_id / 4 * 1 + m * OP_M + q * 8 + warp_id * WARP_Q;\n                        const int ki = lane_id % 4 * 2 + n * OP_N + s * 1;\n                        ((Func &&) func)(qi % WARP_H, qi / WARP_H, ki, /*ri*/ 0, S[m][n][q * 2 + s]);\n                    }\n                }\n            }\n        }\n    }\n\n    template<bool is_residue>\n    __device__ static void Softmax(FragS& frag_S, FragM& frag_M, FragM& frag_L, FragO& frag_O, float qk_scale)\n    {\n        FragM prev_M;\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            prev_M[m] = frag_M[m];\n        }\n\n        // maximum\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {  // Q\n            auto& row_M = frag_M[m];\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {  // KV\n                auto& C = frag_S[m][n];\n                PRAGMA_UNROLL\n                for (int q = 0; q < 2; ++q) {\n                    row_M[q] = fmaxf(row_M[q], fmaxf(C[q * 2 + 0], C[q * 2 + 1]));  // reduce over local pair\n                }\n            }\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {  // reduce over thread group within warp (within warp tiles)\n                row_M[q] = fmaxf(row_M[q], __shfl_xor_sync(uint32_t(-1), row_M[q], 1));\n                row_M[q] = fmaxf(row_M[q], __shfl_xor_sync(uint32_t(-1), row_M[q], 2));\n            }\n        }\n\n        FragM expdiff_M;\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n  
              // exp(M - M'), isinf(frag_M) => isnan(expdiff_M)\n                expdiff_M[m][q] = exp2f((prev_M[m][q] - frag_M[m][q]) * qk_scale);\n                if (is_residue && frag_M[m][q] == -std::numeric_limits<float>::infinity()) {\n                    expdiff_M[m][q] = 0.f;\n                }\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                frag_L[m][q] *= expdiff_M[m][q];\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                float tmp_L{};\n                PRAGMA_UNROLL\n                for (int n = 0; n < K_N; ++n) {\n                    PRAGMA_UNROLL\n                    for (int s = 0; s < 2; ++s) {\n                        // unnormalized prob\n                        float p = exp2f(frag_S[m][n][q * 2 + s] * qk_scale - frag_M[m][q] * qk_scale);\n                        if (is_residue && frag_M[m][q] == -std::numeric_limits<float>::infinity()) {\n                            p = 0.f;\n                        }\n                        tmp_L += p;\n                        frag_S[m][n][q * 2 + s] = p;\n                    }\n                }\n                if constexpr (!kDeferReduceL) {\n                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 1);\n                    tmp_L += __shfl_xor_sync(uint32_t(-1), tmp_L, 2);\n                }\n                frag_L[m][q] += tmp_L;  // update L\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; ++n) {\n                PRAGMA_UNROLL\n                for (int q = 0; q < 2; ++q) {\n                    PRAGMA_UNROLL\n                    for (int d = 0; d < 2; ++d) {\n                        frag_O[m][n][q * 2 + d] *= expdiff_M[m][q];  // Rescale 
previous output\n                    }\n                }\n            }\n        }\n    }\n\n    template<class FragP, class Storage>\n    __device__ static void ConvertStoP(FragS& frag_S, FragP& frag_P, Storage&)\n    {\n        FragS_<T>& frag_Ps = (FragS_<T>&)frag_P;\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int q = 0; q < 2; ++q) {\n                    PRAGMA_UNROLL\n                    for (int s = 0; s < 2; ++s) {\n                        frag_Ps[m][n][q * 2 + s] = static_cast<T>(frag_S[m][n][q * 2 + s]);\n                    }\n                }\n            }\n        }\n    }\n\n    template<class Storage>\n    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, Storage& storage)\n    {\n    }\n\n    template<bool is_norm, class Func, class Storage>\n    __device__ static void StoreO(FragO& frag_O, FragL& frag_L, Storage& storage, Func&& func)\n    {\n        FragL inv_L;\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                inv_L[m][q] = fdividef(1.f, frag_L[m][q]);\n            }\n        }\n\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int q = 0; q < 2; ++q) {\n                const int qi = lane_id / 4 * 1 + m * OP_M + q * 8 + warp_id * WARP_Q;\n                PRAGMA_UNROLL\n                for (int n = 0; n < V_N; ++n) {\n                    if constexpr (is_norm) {\n                        PRAGMA_UNROLL\n                        for (int d = 0; d < 2; ++d) {\n                            frag_O[m][n][q * 2 + d] *= inv_L[m][q];\n                        }\n                    }\n                    const 
int di = n * 8 + lane_id % 4 * 2;\n                    ((Func &&) func)(qi % WARP_H, qi / WARP_H, di, (Array<float, 2>&)frag_O[m][n][q * 2]);\n                }\n            }\n        }\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/impl_simt.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <limits>\n#include <numeric>\n#include <type_traits>\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/quantization.h\"\n\nnamespace turbomind::attention {\n\ntemplate<class T_,\n         class Tkv_,\n         int CTA_H_,\n         int CTA_Q_,\n         int CTA_S_,\n         int WARP_H_,\n         int WARP_Q,\n         int WARP_S,\n         int HeadDim,\n         int Stages>\nstruct Impl<MMA_SIMT, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S, HeadDim, Stages> {\n\n    using T   = T_;\n    using Tkv = Tkv_;\n\n    static constexpr int kQuantKV = !std::is_same_v<T, Tkv>;\n\n    static constexpr bool MLA = HeadDim == 576;\n\n    static constexpr int CTA_H = CTA_H_;\n    static constexpr int CTA_Q = CTA_Q_;\n    static constexpr int CTA_S = CTA_S_;\n\n    static constexpr int WARP_H = WARP_H_;\n\n    static constexpr int kHeadDim = HeadDim;\n\n    static constexpr int kWarpCntH = CTA_H / WARP_H;\n    static constexpr int kWarpCntQ = CTA_Q / WARP_Q;\n    static constexpr int kWarpCntS = CTA_S / WARP_S;\n\n    static constexpr int kWarpCount = kWarpCntH * kWarpCntQ * kWarpCntS;\n\n    static_assert(kWarpCntQ == 1);\n\n    static constexpr int VEC = 8;\n\n    static constexpr int T_D = 8;                // warp thread C\n    static constexpr int T_S = WARP_SIZE / T_D;  // warp thread S\n\n    // warp footprint (1x4x64)\n    static constexpr int OP_H = 1;\n    static constexpr int OP_S = T_S;\n    static constexpr int OP_D = VEC * T_D;\n\n    static constexpr int K_M = WARP_H / OP_H;   // 1\n    static constexpr int K_N = WARP_S / OP_S;   // 4\n    static constexpr int K_K = HeadDim / OP_D;  // 2\n\n    static 
constexpr int V_M = K_M;  // 1\n    static constexpr int V_N = K_K;  // 2\n    static constexpr int V_K = K_N;  // 4\n\n    static_assert(WARP_H % OP_H == 0);\n    static_assert(WARP_S % OP_S == 0);\n    static_assert(HeadDim % OP_D == 0);\n\n    using Tqk = std::conditional_t<sizeof(Tkv) == 2, float, T>;\n    using Tpv = Tqk;\n\n    struct RakedD {\n        static constexpr int S_D_thr = VEC * K_K;\n        static constexpr int S_S_thr = 1;\n        static constexpr int S_D     = VEC;\n        static constexpr int S_S     = T_S;\n        static constexpr int LDS     = std::gcd(16 / sizeof(Array<Tkv, VEC>), K_K);\n    };\n\n    struct LinearD {\n        static constexpr int S_D_thr = VEC;\n        static constexpr int S_S_thr = 1;\n        static constexpr int S_D     = VEC * T_D;\n        static constexpr int S_S     = T_S;\n        static constexpr int LDS     = 1;\n    };\n\n    using ThreadMap = std::conditional_t<sizeof(Tkv) == 2, LinearD, RakedD>;\n\n    // Strides of thread index\n    static constexpr int S_D_thr = ThreadMap::S_D_thr;\n    static constexpr int S_S_thr = ThreadMap::S_S_thr;\n    // Strides of array index\n    static constexpr int S_D = ThreadMap::S_D;\n    static constexpr int S_S = ThreadMap::S_S;\n    // LDS vec count\n    static constexpr int LDS_K = ThreadMap::LDS;\n    static constexpr int LDS_V = ThreadMap::LDS;\n\n    static_assert(LDS_K <= K_K);\n\n    using FragQ = Array<T, VEC>[K_M][K_K];      // (q4, d8), (Qm, Dk), (d8)\n    template<class Tk>                          //   0  16     1   8     1\n    using FragK_ = Array<Tk, VEC>[K_N][K_K];    // (s4, d8), (Sn, Dk), (d8)\n                                                //   4  16     1   8     1\n    using FragS = Array<float, 1>[K_M][K_N];    // (s4, d8), (Qm, Sn)\n                                                //   4  16     1   1\n                                                // (s4, _8), (Qm, Sn)       [after redsum]\n                                                //   4   0 
    1   1\n    using FragM = Array<float, 1>[K_M];         // (_4, _8), (Qm)\n                                                //   0   0     1\n    using FragP = Array<Tpv, 1>[V_M][V_K];      // (s4, _8), (Qm, Sk), (s1)\n    template<class Tv>                          //   4   0     1   1     1\n    using FragV_ = Array<Tv, VEC>[V_K][V_N];    // (s4, d8), (Sk, Dn), (d8)\n                                                //   4  16     1   8     1\n    using FragO = Array<float, VEC>[V_M][V_N];  // (s4, d8), (Qm, Dn), (d8)\n                                                //   1  16     1   8     1\n    using ParamK = Array<T, 2>[K_N];            // (s4, x8), (Sn)\n                                                //   4   0     1\n    using ParamV = Array<T, 2>[V_K];            // (s4, x8), (Sk)\n                                                //   4   0     1\n    using FragSp = Array<Tpv, 1>[K_M][K_N];\n\n    static_assert(sizeof(FragP) == sizeof(FragSp));\n\n    using DataK = FragK_<Tkv>;\n    using DataV = FragV_<Tkv>;\n\n    using FragK = FragK_<Tqk>;\n    using FragV = FragV_<Tpv>;\n\n    using FragL = FragM;\n\n    using SmemLayoutQ = SmemLayoutV2<CTA_S, HeadDim, 1, 1, Identity>;\n    using SmemLayoutP = SmemLayoutV2<CTA_H, CTA_S, 1, 1, Identity>;\n    using SmemLayoutK = SmemLayoutV2<CTA_S, HeadDim, CTA_S, HeadDim, Identity>;\n    using SmemLayoutV = SmemLayoutV2<CTA_S, HeadDim, CTA_S, HeadDim, Identity>;\n\n    using SmemLayoutKVp = SmemLayoutV2<CTA_S, 2, CTA_S, 2, Identity>;\n\n    using SmemM = float[K_M][kWarpCntH][kWarpCntS];\n    using SmemL = float[K_M][kWarpCntH][kWarpCntS];\n    using SmemO = Array<float, 4>[V_M][V_N][2][kWarpCntH][kWarpCntS][T_D];  // (Qm, Dn, d2, Hw, Sw, d8), (d4)\n                                                                            //   1  64   4  WH  WS   8     1\n\n    using PointerKV = get_pointer_type<Tkv>;\n\n    union SharedStorage {\n        __align__(16) T Q[SmemLayoutQ::kSize];\n\n        struct {\n            
__align__(16) Array<Tkv, Stages * SmemLayoutK::kSize> KV;\n            __align__(16) T KVp[Stages * SmemLayoutKVp::kSize];\n        };\n\n        struct {\n            __align__(16) SmemM M;\n            __align__(16) SmemL L;\n            __align__(16) SmemO O;\n        };\n    };\n\n    static constexpr bool kUseSmemQ = false;\n    static constexpr bool kUseSmemP = false;\n\n    using ThreadMapQ  = RakedThreadMap<HeadDim, CTA_H, 8, kWarpCount>;\n    using ThreadMapKV = RakedThreadMap<HeadDim, CTA_S, 128 / bitsof<Tkv>, kWarpCount>;\n    // `WARP_SIZE / WARP_S` is chosen to achieve minimum kIterS w/o introducing partial S iter\n    using ThreadMapKVp = RakedThreadMap<2, CTA_S, 2, kWarpCount, WARP_SIZE / WARP_S>;\n\n    static constexpr int kBatchK = ThreadMapKV::kIterS;\n    static constexpr int kBatchV = ThreadMapKV::kIterS;\n\n    __device__ static void Sync()\n    {\n        if constexpr (kWarpCntH > 1) {\n            __syncthreads();\n        }\n        if constexpr (kQuantKV) {  // Thread layout of KV & KVp is different within warp boundary\n            __syncwarp();\n        }\n    }\n\n    template<class GmemIterK, class GmemIterV>\n    __device__ static void SetSmemKV(GmemIterK& gmem_K, GmemIterV& gmem_V, SharedStorage& storage, bool offset_kv)\n    {\n        int pred = offset_kv;\n        if constexpr (kQuantKV) {\n            gmem_K.SetSmem(storage.KV.data(), storage.KVp);\n            gmem_V.SetSmem(storage.KV.data() + pred * SmemLayoutK::kSize, storage.KVp + pred * SmemLayoutKVp::kSize);\n        }\n        else {\n            gmem_K.SetSmem(storage.KV.data());\n            gmem_V.SetSmem(storage.KV.data() + pred * SmemLayoutK::kSize);\n        }\n    }\n\n    static __device__ int2 get_warp_ids()\n    {\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        if constexpr (kWarpCntH > 1) {\n            return {warp_id % kWarpCntS, warp_id / kWarpCntS};\n        }\n        else {\n            return {warp_id, 0};\n        }\n    }\n\n    
template<class Func>\n    __device__ static void ForeachML(FragM& frag_M, FragL& frag_L, Func&& func)\n    {\n        const auto warp_ids = get_warp_ids();\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {  // Q\n            const int hi = m * OP_H + warp_ids.y * WARP_H;\n            const int ri = threadIdx.x % (WARP_SIZE * kWarpCntS);\n            ((Func &&) func)(hi, 0, ri, frag_M[m][0], frag_L[m][0]);\n        }\n    }\n\n    template<class Fragment, class Func>\n    __device__ static void ForeachS(Fragment& S, Func&& func)\n    {\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                const int hi = m * OP_H + warp_ids.y * WARP_H;\n                const int si = lane_id / T_D * S_S_thr + n * S_S + warp_ids.x * WARP_S;\n                const int ri = lane_id % T_D;\n                ((Func &&) func)(hi, /*qi*/ 0, si, ri, S[m][n][0]);\n            }\n        }\n    }\n\n    __device__ static void TransformQ(T* smem_Q, FragQ& frag_Q)\n    {\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        __syncthreads();\n\n        SmemAccessor<T, SmemLayoutQ> sQ{smem_Q};\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < K_K; ++k) {\n                const int hi = m + warp_ids.y * WARP_H;\n                const int di = k * S_D + lane_id % T_D * S_D_thr;\n                Lds(frag_Q[m][k], &sQ(hi, di));\n            }\n        }\n    }\n\n    struct StateQK {\n        PointerKV smem_K;\n        T*        smem_K_param;\n        FragQ     frag_Q;\n        FragK     frag_K;\n        DataK     data_K;\n        ParamK    param_K;\n\n        __device__ StateQK(SharedStorage& storage, FragQ frag_Q_)\n        {\n            smem_K       = 
storage.KV.data();\n            smem_K_param = storage.KVp;\n            if constexpr (!kUseSmemQ) {\n                PRAGMA_UNROLL\n                for (int m = 0; m < K_M; ++m) {\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < K_K; ++k) {\n                        frag_Q[m][k] = frag_Q_[m][k];\n                    }\n                }\n            }\n        }\n\n        __device__ void Load(int n, int pipe_iter)\n        {\n            const auto warp_ids = get_warp_ids();\n            const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n            const int offset_s = lane_id / T_D * S_S_thr + warp_ids.x * WARP_S;\n            const int offset_c = lane_id % T_D * S_D_thr;\n\n            if (kQuantKV && n == 0) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < K_N; ++n) {\n                    const int si = n * S_S + offset_s;\n                    Lds(param_K[n], &smem_K_param[pipe_iter * SmemLayoutKVp::kSize + SmemLayoutKVp::apply(si, 0)]);\n                }\n            }\n\n            PRAGMA_UNROLL\n            for (int k = 0; k < K_K; k += LDS_K) {\n                const int si = n * S_S + offset_s;\n                const int di = k * S_D + offset_c;\n                Lds((Array<Tkv, VEC * LDS_K>&)data_K[n][k],\n                    &smem_K[pipe_iter * SmemLayoutK::kSize + SmemLayoutK::apply(si, di)]);\n            }\n        }\n\n        __device__ void Transform(int n)\n        {\n            PRAGMA_UNROLL\n            for (int k = 0; k < K_K; ++k) {\n                ConvertKvCache<Tkv, Tqk> convert(param_K[n][0], param_K[n][1]);\n                frag_K[n][k] = convert(data_K[n][k]);\n            }\n        }\n    };\n\n    template<class Prefetch, class Preload>\n    __device__ static void\n    ComputeQK(StateQK state_QK, FragS& frag_S, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        if constexpr (K_N == 1) {\n            ((Prefetch &&) prefetch)(0);\n        }\n\n        PRAGMA_UNROLL\n     
   for (int n = 0; n < K_N; ++n) {\n            if (n < K_N - 1) {\n                state_QK.Load(n + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n\n            state_QK.Transform(n);\n\n            PRAGMA_UNROLL\n            for (int m = 0; m < K_M; ++m) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < K_K; ++k) {\n                    PRAGMA_UNROLL\n                    for (int c = 0; c < 8; ++c) {\n                        frag_S[m][n][0] += static_cast<float>((Tqk)state_QK.frag_Q[m][k][c] * state_QK.frag_K[n][k][c]);\n                    }\n                }\n            }\n\n            if (n < K_N - 1) {\n                ((Prefetch &&) prefetch)(n);\n            }\n            if (n == K_N - 2) {\n                ((Prefetch &&) prefetch)(K_N - 1);\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                PRAGMA_UNROLL\n                for (int mask = 1; mask < T_D; mask *= 2) {\n                    frag_S[m][n][0] += __shfl_xor_sync(uint32_t(-1), frag_S[m][n][0], mask);\n                }\n            }\n        }\n    }\n\n    struct StatePV {\n        PointerKV smem_V;\n        T*        smem_V_param;\n        FragP     frag_P;\n        FragV     frag_V;\n        DataV     data_V;\n        ParamV    param_V;\n\n        __device__ StatePV(SharedStorage& storage, bool offset = false)\n        {\n            smem_V       = storage.KV.data() + (offset ? SmemLayoutK::kSize : 0);\n            smem_V_param = storage.KVp + (offset ? 
SmemLayoutKVp::kSize : 0);\n        }\n\n        __device__ void Load(int k, int pipe_iter)\n        {\n            const auto warp_ids = get_warp_ids();\n            const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n            const int offset_s = lane_id / T_D * S_S_thr + warp_ids.x * WARP_S;\n            const int offset_c = lane_id % T_D * S_D_thr;\n\n            if (kQuantKV && k == 0) {\n                PRAGMA_UNROLL\n                for (int k = 0; k < V_K; ++k) {\n                    const int si = k * S_S + offset_s;\n                    Lds(param_V[k], &smem_V_param[pipe_iter * SmemLayoutKVp::kSize + SmemLayoutKVp::apply(si, 0)]);\n                }\n            }\n\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; n += LDS_V) {\n                const int si = k * S_S + offset_s;\n                const int di = n * S_D + offset_c;\n                Lds((Array<Tkv, VEC * LDS_V>&)data_V[k][n],\n                    &smem_V[pipe_iter * SmemLayoutV::kSize + SmemLayoutV::apply(si, di)]);\n            }\n        }\n\n        __device__ void Transform(int k)\n        {\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; ++n) {\n                ConvertKvCache<Tkv, Tpv> convert(param_V[k][0], param_V[k][1]);\n                frag_V[k][n] = convert(data_V[k][n]);\n            }\n        }\n    };\n\n    template<class Prefetch, class Preload>\n    __device__ static void\n    ComputePV(StatePV state_PV, FragO& frag_O, int offset, Prefetch&& prefetch, Preload&& preload)\n    {\n        if constexpr (V_K == 1) {\n            ((Prefetch &&) prefetch)(0);\n        }\n\n        PRAGMA_UNROLL\n        for (int k = 0; k < V_K; ++k) {\n            if (k < V_K - 1) {\n                state_PV.Load(k + 1, offset);\n            }\n            else {\n                ((Preload &&) preload)();\n            }\n\n            state_PV.Transform(k);\n\n            PRAGMA_UNROLL\n            for (int m = 0; m < V_M; ++m) {\n                PRAGMA_UNROLL\n   
             for (int n = 0; n < V_N; ++n) {\n                    PRAGMA_UNROLL\n                    for (int d = 0; d < 8; ++d) {\n                        frag_O[m][n][d] += static_cast<float>((Tpv)state_PV.frag_P[m][k][0] * state_PV.frag_V[k][n][d]);\n                    }\n                }\n            }\n\n            if (k < V_K - 1) {\n                ((Prefetch &&) prefetch)(k);\n            }\n            if (k == V_K - 2) {\n                ((Prefetch &&) prefetch)(V_K - 1);\n            }\n        }\n    }\n\n    template<bool is_residue>\n    __device__ static void Softmax(FragS& frag_S, FragM& frag_M, FragL& frag_L, FragO& frag_O, float qk_scale)\n    {\n        FragM prev_M;\n        copy(frag_M, prev_M);\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                frag_M[m][0] = fmaxf(frag_M[m][0], frag_S[m][n][0]);\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            float expdiff_M = exp2f((prev_M[m][0] - frag_M[m][0]) * qk_scale);\n            if (is_residue && frag_M[m][0] == -std::numeric_limits<float>::infinity()) {\n                expdiff_M = 0.f;\n            }\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; ++n) {\n                using namespace ops;\n                frag_O[m][n] = frag_O[m][n] * expdiff_M;\n            }\n            frag_L[m][0] *= expdiff_M;\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            float tmp_L{};\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                float p = exp2f(frag_S[m][n][0] * qk_scale - frag_M[m][0] * qk_scale);\n                if (is_residue && frag_M[m][0] == -std::numeric_limits<float>::infinity()) {\n                    p = 0.f;\n                }\n                tmp_L += p;\n                frag_S[m][n][0] = p;\n            }\n            frag_L[m][0] += tmp_L;\n  
      }\n    }\n\n    __device__ static void ConvertStoP(FragS& frag_S, FragP& frag_P, SharedStorage&)\n    {\n        FragSp& frag_Sp = (FragSp&)frag_P;\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < K_N; ++n) {\n                frag_Sp[m][n][0] = static_cast<T>(frag_S[m][n][0]);\n            }\n        }\n    }\n\n    __device__ static void Merge(FragO& frag_O, FragM& frag_M, FragL& frag_L, float qk_scale, SharedStorage& storage)\n    {\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        FragM prev_M;\n        copy(frag_M, prev_M);\n\n        __syncthreads();\n\n        /////////////////////////////////////////////////////////////////////////\n        //  global max\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            for (int mask = T_D; mask < WARP_SIZE; mask *= 2) {\n                frag_M[m][0] = fmaxf(frag_M[m][0], __shfl_xor_sync(uint32_t(-1), frag_M[m][0], mask));\n            }\n            if (lane_id == 0) {\n                // printf(\"warp M %d %f\\n\", warp_id, frag_M[m][0]);\n                storage.M[m][warp_ids.y][warp_ids.x] = frag_M[m][0];\n            }\n        }\n\n        __syncthreads();\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            PRAGMA_UNROLL\n            for (int w = 0; w < kWarpCntS - 1; ++w) {\n                frag_M[m][0] = fmaxf(frag_M[m][0], storage.M[m][warp_ids.y][(warp_ids.x + w + 1) % kWarpCntS]);\n            }\n            // if (threadIdx.x == 0) {\n            //     printf(\"M %d %f\\n\", m * OP_H + blockIdx.x * CTA_H, frag_M[m][0]);\n            // }\n        }\n\n        ///////////////////////////////////////////////////////////////////////////\n        //  rescale & global sum\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            float expdiff_M = exp2f((prev_M[m][0] - frag_M[m][0]) * qk_scale);\n       
     if (frag_M[m][0] == -std::numeric_limits<float>::infinity()) {\n                expdiff_M = 0.f;\n            }\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; ++n) {\n                PRAGMA_UNROLL\n                for (int d = 0; d < 8; ++d) {\n                    frag_O[m][n][d] = frag_O[m][n][d] * expdiff_M;\n                    PRAGMA_UNROLL\n                    for (int mask = T_D; mask < WARP_SIZE; mask *= 2) {\n                        frag_O[m][n][d] += __shfl_xor_sync(uint32_t(-1), frag_O[m][n][d], mask);\n                    }\n                }\n                PRAGMA_UNROLL\n                for (int d = 0; d < 8; d += 4) {\n                    if (lane_id < T_D) {\n                        Store(storage.O[m][n][d / 4][warp_ids.y][warp_ids.x][lane_id].data(),\n                              (Array<float, 4>&)frag_O[m][n][d]);\n                    }\n                }\n            }\n            frag_L[m][0] *= expdiff_M;\n            PRAGMA_UNROLL\n            for (int mask = T_D; mask < WARP_SIZE; mask *= 2) {\n                frag_L[m][0] += __shfl_xor_sync(uint32_t(-1), frag_L[m][0], mask);\n            }\n            if (lane_id == 0) {\n                storage.L[m][warp_ids.y][warp_ids.x] = frag_L[m][0];\n            }\n        }\n\n        __syncthreads();\n\n        clear(frag_O);\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; ++n) {\n#if 0\n                static_assert(kWarpCntS % 4 == 0);\n                PRAGMA_UNROLL\n                for (int s = 0; s < kWarpCntS; s += 4) {\n                    PRAGMA_UNROLL\n                    for (int d = 0; d < 8; d += 4) {\n                        Array<float, 4> tmp_O;\n                        Lds(tmp_O, storage.O[m][n][d / 4][warp_ids.y][s + lane_id / 8][lane_id % T_D].data());\n                        using namespace ops;\n                        (Array<float, 4>&)frag_O[m][n][d] = (Array<float, 
4>&)frag_O[m][n][d] + tmp_O;\n                    }\n                }\n                PRAGMA_UNROLL\n                for (int d = 0; d < 8; ++d) {\n                    PRAGMA_UNROLL\n                    for (int mask = T_D; mask < WARP_SIZE; mask *= 2) {\n                        frag_O[m][n][d] += __shfl_xor_sync(uint32_t(-1), frag_O[m][n][d], mask);\n                    }\n                }\n#else\n                PRAGMA_UNROLL\n                for (int s = 0; s < kWarpCntS; ++s) {\n                    PRAGMA_UNROLL\n                    for (int d = 0; d < 8; d += 4) {\n                        Array<float, 4> tmp_O;\n                        Lds(tmp_O, storage.O[m][n][d / 4][warp_ids.y][s][lane_id % T_D].data());\n                        using namespace ops;\n                        (Array<float, 4>&)frag_O[m][n][d] = (Array<float, 4>&)frag_O[m][n][d] + tmp_O;\n                    }\n                }\n#endif\n            }\n            PRAGMA_UNROLL\n            for (int w = 0; w < kWarpCntS - 1; ++w) {\n                frag_L[m][0] += storage.L[m][warp_ids.y][(warp_ids.x + w + 1) % kWarpCntS];\n            }\n            // if (threadIdx.x == 0) {\n            //     printf(\"L %d %f\\n\", m * OP_H + blockIdx.x * CTA_H, frag_L[m][0]);\n            // }\n        }\n    }\n\n    template<bool is_norm, class Func>\n    __device__ static void StoreO(FragO& frag_O, const FragL& frag_L, SharedStorage& storage, Func&& func)\n    {\n        FragL inv_L;\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < K_M; ++m) {\n            inv_L[m][0] = fdividef(1.f, frag_L[m][0]);\n        }\n\n        const auto warp_ids = get_warp_ids();\n        const int  lane_id  = threadIdx.x % WARP_SIZE;\n\n        if (warp_ids.x != 0) {\n            return;\n        }\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < V_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < V_N; ++n) {\n                if constexpr (is_norm) {\n                    PRAGMA_UNROLL\n      
              for (int d = 0; d < 8; ++d) {\n                        frag_O[m][n][d] *= inv_L[m][0];\n                    }\n                }\n\n                if (lane_id < T_D) {\n                    const int hi = m * OP_H + warp_ids.y * WARP_H;\n                    const int di = n * S_D + lane_id * S_D_thr;\n                    // for (int i = 0; i < 8; ++i) {\n                    //     printf(\"O %4d %4d %f\\n\", hi + blockIdx.x * CTA_H, di + i, frag_O[m][n][i]);\n                    // }\n                    ((Func &&) func)(hi, 0, di, frag_O[m][n]);\n                }\n            }\n        }\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/iterator.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/pipe_iter.h\"\n#include <type_traits>\n\nnamespace turbomind {\n\ntemplate<class T, class Map, class SmemLayout>\nstruct BaseGmemIterator {\n    using ElementType = T;\n    using AccessType  = Array<T, Map::kAccessC>;\n    using Pointer     = get_pointer_type<T>;\n\n    static constexpr int kElementSize = sizeof(ElementType);\n    static constexpr int kAccessSize  = sizeof(AccessType);\n    static constexpr int kIterCount   = Map::kIterS * Map::kIterC;\n\n    using Fragment = Array<T, Map::kAccessC>[Map::kIterS][Map::kIterC];\n\n    Pointer smem_;\n\n    int src_offset_;\n    int offset_c_;\n    int offset_s_;\n\n    static constexpr std::integral_constant<bool, Map::kPartialC> partial_c_{};\n\n    std::conditional_t<partial_c_, bool, std::true_type> pred_c_;\n\n    __device__ BaseGmemIterator()\n    {\n        int  warp_id = threadIdx.x / WARP_SIZE;\n        int  lane_id = threadIdx.x % WARP_SIZE;\n        int2 offsets = Map::get_offset(warp_id, lane_id);\n        src_offset_  = offsets.x + offsets.y * Map::kDimC;\n        offset_c_    = offsets.x;\n        offset_s_    = offsets.y;\n        if constexpr (partial_c_) {\n            pred_c_ = offset_c_ < Map::kDimC;\n        }\n    }\n\n    __device__ void SetSmem(Pointer smem)\n    {\n        smem_ = smem;\n    }\n\n    __device__ void ClearSmem(int pipe_iter = 0)\n    {\n        SmemAccessor<T, SmemLayout> data{smem_};\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map::kIterS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                if (pred_c_) {\n                    Store(&data(offset_s_ + s * Map::kDeltaS,\n                                offset_c_ + c * Map::kDeltaC,\n                     
           pipe_iter * SmemLayout::kSize),\n                          Array<T, Map::kAccessC>{});\n                }\n            }\n        }\n    }\n};\n\ntemplate<class T, class Layout>\nstruct BaseSmemIterator {\n    static constexpr int kElemSize = sizeof(T);\n\n    using Accessor = SmemAccessor<T, Layout>;\n    T* smem_;\n\n    __device__ explicit BaseSmemIterator(T* smem): smem_{smem} {}\n};\n\ntemplate<class Iterator0, class Iterator1>\nstruct CombinedIterator {\n    Iterator0 iterator0_;\n    Iterator1 iterator1_;\n\n    struct Fragment {\n        typename Iterator0::Fragment frag0;\n        typename Iterator1::Fragment frag1;\n    };\n\n    // NOTE: can't use reference type here, nvcc does not support variadic templates well in device code\n    template<typename... Args>\n    __device__ void Prefetch(Args... args)\n    {\n        iterator0_.Prefetch(args...);\n        iterator1_.Prefetch(args...);\n    }\n\n    /// TODO: Load(bool_constant, CacheIter&) -> Fragment\n    template<bool is_residue, class CacheIter>\n    __device__ void Load(const CacheIter& cache_iter, Fragment& frag, int max_s)\n    {\n        iterator0_.Load<is_residue>(cache_iter, frag.frag0, max_s);\n        iterator1_.Load<is_residue>(cache_iter, frag.frag1, max_s);\n    }\n\n    __device__ void Save(const Fragment& frag)\n    {\n        iterator0_.Save(frag.frag0);\n        iterator1_.Save(frag.frag1);\n    }\n\n    __device__ void ClearSmem(int pipe_iter = 0)\n    {\n        iterator0_.ClearSmem(pipe_iter);\n        iterator1_.ClearSmem(pipe_iter);\n    }\n\n    template<class P0, class P1>\n    __device__ void SetSmem(P0 p0, P1 p1)\n    {\n        iterator0_.SetSmem(p0);\n        iterator1_.SetSmem(p1);\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/iterator_sm70.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"iterator.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n\nnamespace turbomind {\n\ntemplate<class T, class Map, class SmemLayout, int Idx>\nstruct Sm70GmemIterator: BaseGmemIterator<T, Map, SmemLayout> {\n    using Base = BaseGmemIterator<T, Map, SmemLayout>;\n\n    using typename Base::AccessType;\n    using typename Base::Fragment;\n\n    using Base::src_offset_;\n    using Base::offset_c_;\n    using Base::offset_s_;\n    using Base::smem_;\n\n    using Base::partial_c_;\n    using Base::pred_c_;\n\n    using Base::Base;\n\n    template<bool is_residue, class TileIter>\n    __device__ void Load(const TileIter& tile_iter, Fragment& rmem, int max_s)\n    {\n        auto src_data = tile_iter.OffsetPtr<Idx>(src_offset_);\n        int  offset_s = Map::get_offset(threadIdx.x / WARP_SIZE, threadIdx.x % WARP_SIZE).y;\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map::kIterS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                copy(Array<T, Map::kAccessC>{}, rmem[s][c]);\n                auto src = &src_data[s * Map::kDeltaS * Map::kDimC + c * Map::kDeltaC];\n                if constexpr (partial_c_) {  // Only quant params is partial C\n                    if (pred_c_) {\n                        Ldg(rmem[s][c], src);\n                    }\n                }\n                else if (!is_residue || offset_s + s * Map::kDeltaS < max_s) {\n                    Ldg(rmem[s][c], src);\n                }\n            }\n        }\n    }\n\n    __device__ void Save(const Fragment& rmem)\n    {\n        typename SmemLayout::Swizzle swizzle{};\n\n        SmemAccessor<T, SmemLayout> data{smem_};\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map::kIterS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                if (!partial_c_ || pred_c_) {\n                    
Store(&data(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC), rmem[s][c]);\n                }\n            }\n        }\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/iterator_sm80.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"iterator.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include <cassert>\n#include <type_traits>\n\nnamespace turbomind {\n\n// The `.L2::128B` prefetch-size qualifier on cp.async requires PTX ISA 7.4 (CUDA 11.4+)\n#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))\n#define L2_CACHEHINT(size) \".L2::\" #size \"B\"\n#else\n#define L2_CACHEHINT(size)\n#endif\n\ntemplate<class T, class Map, class SmemLayout, int Idx>\nstruct Sm80GmemIterator: BaseGmemIterator<T, Map, SmemLayout> {\n\n    using Base = BaseGmemIterator<T, Map, SmemLayout>;\n\n    using typename Base::AccessType;\n\n    using Base::Base;\n    using Base::kElementSize;\n    using Base::src_offset_;\n    using Base::offset_c_;\n    using Base::offset_s_;\n    using Base::smem_;\n\n    using Base::partial_c_;\n    using Base::pred_c_;\n\n    template<class PartialS, class TileIter>\n    __device__ void\n    Prefetch(PartialS partial_s, const TileIter& tile_iter, int s_begin, int s_count, int max_s, int pipe_iter)\n    {\n        // `src_data` may be `SubBytePtr`\n        auto src_data = tile_iter.OffsetPtr<Idx>(src_offset_);\n\n        SmemAccessor<T, SmemLayout> dst_data{smem_};\n\n        PRAGMA_UNROLL\n        for (int s = s_begin; s < s_begin + s_count && s < Map::kIterS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                auto dst = cast_smem_ptr_to_uint(&dst_data(offset_s_ + s * Map::kDeltaS,  //\n                                                           offset_c_ + c * Map::kDeltaC,\n                                                           pipe_iter * SmemLayout::kSize));\n                auto src = &src_data[s * Map::kDeltaS * Map::kDimC + c * Map::kDeltaC];\n\n                if constexpr (partial_c_) {\n                    CpAsync(std::true_type{}, dst, (const T*)src, pred_c_);\n                }\n                else {\n                    CpAsync(partial_s, dst, (const T*)src, offset_s_ + s * Map::kDeltaS < 
max_s);\n                }\n            }\n        }\n    }\n\n    template<class Partial, class TileIter>\n    __device__ void Prefetch(Partial partial, const TileIter& tile_iter, int max_s, int pipe_iter)\n    {\n        Prefetch(partial, tile_iter, 0, Map::kIterS, max_s, pipe_iter);\n    }\n\n    __device__ void CpAsync(std::true_type, int ptr, const T* __restrict__ src, bool mask)\n    {\n#if TURBOMIND_ARCH_SM80\n        constexpr int size = sizeof(AccessType);\n        // clang-format off\n        if constexpr (size == 16) {\n            asm volatile(\"{\\n\"\n                        \"  .reg .pred p;\\n\"\n                        \"  setp.ne.b32 p, %0, 0;\\n\"\n                        \"  @p cp.async.cg.shared.global\" L2_CACHEHINT(128) \" [%1], [%2], %3;\\n\"\n                        \"}\\n\" ::\"r\"((int)mask),\n                        \"r\"(ptr),\n                        \"l\"(src),\n                        \"n\"(size));\n        } else {\n            asm volatile(\"{\\n\"\n                        \"  .reg .pred p;\\n\"\n                        \"  setp.ne.b32 p, %0, 0;\\n\"\n                        \"  @p cp.async.ca.shared.global\" L2_CACHEHINT(128) \" [%1], [%2], %3;\\n\"\n                        \"}\\n\" ::\"r\"((int)mask),\n                        \"r\"(ptr),\n                        \"l\"(src),\n                        \"n\"(size));\n        }\n        // clang-format on\n#else\n        assert(TURBOMIND_ARCH_SM80);\n#endif\n    }\n\n    __device__ void CpAsync(std::false_type, int ptr, const T* __restrict__ src, bool)\n    {\n#if TURBOMIND_ARCH_SM80\n        constexpr int size = sizeof(AccessType);\n        if constexpr (size == 16) {\n            asm volatile(\n                \"cp.async.cg.shared.global\" L2_CACHEHINT(128) \" [%0], [%1], %2;\\n\" ::\"r\"(ptr), \"l\"(src), \"n\"(size));\n        }\n        else {\n            asm volatile(\n                \"cp.async.ca.shared.global\" L2_CACHEHINT(128) \" [%0], [%1], %2;\\n\" ::\"r\"(ptr), 
\"l\"(src), \"n\"(size));\n        }\n#else\n        assert(TURBOMIND_ARCH_SM80);\n#endif\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nadd_library(attention_kernels STATIC\n            ../utils.cc\n            ../reduce.cu\n            attention_sm70_64.cu\n            attention_sm70_128.cu\n            attention_sm70_256.cu\n            attention_sm70_576.cu\n            attention_sm75_64.cu\n            attention_sm75_128.cu\n            attention_sm75_256.cu\n            attention_sm75_576.cu\n            attention_sm80_64.cu\n            attention_sm80_128.cu\n            attention_sm80_192.cu\n            attention_sm80_256.cu\n            attention_sm80_576.cu\n            decoding_sm70_64.cu\n            decoding_sm70_128.cu\n            decoding_sm70_256.cu\n            decoding_sm70_576.cu\n            decoding_sm75_64.cu\n            decoding_sm75_128.cu\n            decoding_sm75_256.cu\n            decoding_sm75_576.cu\n            decoding_sm80_64.cu\n            decoding_sm80_128.cu\n            decoding_sm80_192.cu\n            decoding_sm80_256.cu\n            decoding_sm80_576.cu\n            )\nset_property(TARGET attention_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET attention_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\ntarget_compile_options(attention_kernels PRIVATE -O3\n    $<$<COMPILE_LANGUAGE:CUDA>:-use_fast_math --expt-relaxed-constexpr  -Xptxas=-v --threads 8>)\ntarget_link_libraries(attention_kernels PRIVATE nvidia::cutlass::cutlass)\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm70_128.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_884.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 128;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm70,\n    Mainloop<arch::Sm70, Impl<MMA_884, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) { c.add<KT<half>>(); });\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm70_256.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_884.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 256;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm70,\n    Mainloop<arch::Sm70, Impl<MMA_884, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) { c.add<KT<half>>(); });\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm70_576.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_884.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\n// HeadDim=576 on Sm70: kCTA_S=32, WARP_S=kCTA_S to fit within V100's 96 KB shared memory limit\nconstexpr int kHeadDim = 576;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 32;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm70,\n    Mainloop<arch::Sm70, Impl<MMA_884, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) { c.add<KT<half>>(); });\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm70_64.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_884.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 64;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm70,\n    Mainloop<arch::Sm70, Impl<MMA_884, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) { c.add<KT<half>>(); });\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm75_128.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_1688.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 128;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm75,\n    Mainloop<arch::Sm70, Impl<MMA_1688, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) { c.add<KT<half>>(); });\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm75_256.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_1688.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 256;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm75,\n    Mainloop<arch::Sm70, Impl<MMA_1688, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) { c.add<KT<half>>(); });\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm75_576.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_1688.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 576;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 32;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm75,\n    Mainloop<arch::Sm70, Impl<MMA_1688, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) { c.add<KT<half>>(); });\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm75_64.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_1688.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 64;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm75,\n    Mainloop<arch::Sm70, Impl<MMA_1688, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) { c.add<KT<half>>(); });\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm80_128.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_16816.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 128;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm80,\n    Mainloop<Sm80_CpAsync<kStages>, Impl<MMA_16816, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half>>();\n#if ENABLE_BF16\n    c.add<KT<nv_bfloat16>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm80_192.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_16816.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 192;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm80,\n    Mainloop<Sm80_CpAsync<kStages>, Impl<MMA_16816, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half>>();\n#if ENABLE_BF16\n    c.add<KT<nv_bfloat16>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm80_256.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_16816.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 256;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm80,\n    Mainloop<Sm80_CpAsync<kStages>, Impl<MMA_16816, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half>>();\n#if ENABLE_BF16\n    c.add<KT<nv_bfloat16>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm80_576.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_16816.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 576;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 32;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm80,\n    Mainloop<Sm80_CpAsync<kStages>, Impl<MMA_16816, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half>>();\n#if ENABLE_BF16\n    c.add<KT<nv_bfloat16>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/attention_sm80_64.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_16816.h\"\n#include \"src/turbomind/kernels/attention/linear_iterator.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\n// HeadDim=64 special case: kCTA_S=128, WARP_S=kCTA_S\nconstexpr int kHeadDim = 64;\nconstexpr int kCTA_Q   = 64;\nconstexpr int kCTA_S   = 128;\nconstexpr int kWARP_Q  = 16;\nconstexpr int kStages  = 2;\n\ntemplate<class T>\nusing KT = AttentionUniversal<\n    arch::Sm80,\n    Mainloop<Sm80_CpAsync<kStages>, Impl<MMA_16816, T, T, 1, kCTA_Q, kCTA_S, 1, kWARP_Q, kCTA_S, kHeadDim, kStages>>,\n    LinearIteratorFactory<T, kCTA_S, kHeadDim>,\n    AttentionCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half>>();\n#if ENABLE_BF16\n    c.add<KT<nv_bfloat16>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm70_128.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_simt.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 128;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\nconstexpr int kStages  = 2;\n\n// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 2 : 1)\n// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9}\ntemplate<class T, class Tkv, int kH>\nusing KT =\n    AttentionUniversal<arch::Sm70,\n                       Mainloop<arch::Sm70, Impl<MMA_SIMT, T, Tkv, kH, 1, kCTA_S, kH, 1, kWARP_S, kHeadDim, kStages>>,\n                       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>,\n                       DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half, 1>>();\n    c.add<KT<half, half, 2>>();\n    c.add<KT<half, half, 3>>();\n\n    c.add<KT<half, uint8_t, 1>>();\n    c.add<KT<half, uint8_t, 2>>();\n    c.add<KT<half, uint8_t, 3>>();\n\n    c.add<KT<half, uint4_t, 1>>();\n    c.add<KT<half, uint4_t, 2>>();\n    c.add<KT<half, uint4_t, 3>>();\n});\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm70_256.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_simt.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 256;\nconstexpr int kCTA_S   = 32;\nconstexpr int kWARP_S  = 16;\nconstexpr int kStages  = 2;\n\n// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 2 : 1)\n// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9}\ntemplate<class T, class Tkv, int kH>\nusing KT =\n    AttentionUniversal<arch::Sm70,\n                       Mainloop<arch::Sm70, Impl<MMA_SIMT, T, Tkv, kH, 1, kCTA_S, kH, 1, kWARP_S, kHeadDim, kStages>>,\n                       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>,\n                       DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half, 1>>();\n    c.add<KT<half, half, 2>>();\n    c.add<KT<half, half, 3>>();\n\n    c.add<KT<half, uint8_t, 1>>();\n    c.add<KT<half, uint8_t, 2>>();\n    c.add<KT<half, uint8_t, 3>>();\n\n    c.add<KT<half, uint4_t, 1>>();\n    c.add<KT<half, uint4_t, 2>>();\n    c.add<KT<half, uint4_t, 3>>();\n});\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm70_576.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_simt.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 576;\n\n// CTA_H=2, CTA_S=16, WARP_H=1, WARP_S=8, Stages=2\ntemplate<class T, class Tkv>\nusing KT = AttentionUniversal<arch::Sm70,\n                              Mainloop<arch::Sm70, Impl<MMA_SIMT, T, Tkv, 2, 1, 16, 1, 1, 8, kHeadDim, 2>>,\n                              GetBlockIterFactory<T, Tkv, 16, kHeadDim>,\n                              DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half>>();\n    c.add<KT<half, uint8_t>>();\n    c.add<KT<half, uint4_t>>();\n});\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm70_64.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_simt.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\nconstexpr int kStages  = 2;\n\n// kH = Qh%3==0 ? 3 : (Qh%2==0 ? 2 : 1)\n// kH=1 covers Qh ∈ {1,5,7}, kH=2 covers {2,4,8}, kH=3 covers {3,6,9}\ntemplate<class T, class Tkv, int kH>\nusing KT =\n    AttentionUniversal<arch::Sm70,\n                       Mainloop<arch::Sm70, Impl<MMA_SIMT, T, Tkv, kH, 1, kCTA_S, kH, 1, kWARP_S, kHeadDim, kStages>>,\n                       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>,\n                       DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half, 1>>();\n    c.add<KT<half, half, 2>>();\n    c.add<KT<half, half, 3>>();\n\n    c.add<KT<half, uint8_t, 1>>();\n    c.add<KT<half, uint8_t, 2>>();\n    c.add<KT<half, uint8_t, 3>>();\n\n    c.add<KT<half, uint4_t, 1>>();\n    c.add<KT<half, uint4_t, 2>>();\n    c.add<KT<half, uint4_t, 3>>();\n});\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm75_128.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_81616.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 128;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\nconstexpr int kStages  = 2;\n\n// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16\ntemplate<class T, class Tkv, int Qh>\nusing KT =\n    AttentionUniversal<arch::Sm75,\n                       Mainloop<arch::Sm70, Impl<MMA_81616, T, Tkv, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, kStages>>,\n                       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>,\n                       DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half, 8>>();\n    c.add<KT<half, half, 16>>();\n\n    c.add<KT<half, uint8_t, 8>>();\n    c.add<KT<half, uint8_t, 16>>();\n\n    c.add<KT<half, uint4_t, 8>>();\n    c.add<KT<half, uint4_t, 16>>();\n});\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm75_256.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_81616.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 256;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\nconstexpr int kStages  = 3;\n\n// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16\n// For 256 head dim, we use Qh=1 and Qh=9 (which maps to 16)\ntemplate<class T, class Tkv, int Qh>\nusing KT =\n    AttentionUniversal<arch::Sm75,\n                       Mainloop<arch::Sm70, Impl<MMA_81616, T, Tkv, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, kStages>>,\n                       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>,\n                       DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half, 1>>();\n    c.add<KT<half, half, 16>>();  // Qh=9 maps to 16\n\n    c.add<KT<half, uint8_t, 1>>();\n    c.add<KT<half, uint8_t, 16>>();  // Qh=9 maps to 16\n\n    c.add<KT<half, uint4_t, 1>>();\n    c.add<KT<half, uint4_t, 16>>();  // Qh=9 maps to 16\n});\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm75_576.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_81616.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 576;\n\n// MLA config for all Tkv: CTA_H=16, CTA_S=32, WARP_H=8, WARP_S=16, Stages=2\ntemplate<class T, class Tkv>\nusing KT = AttentionUniversal<arch::Sm75,\n                              Mainloop<arch::Sm70, Impl<MMA_81616, T, Tkv, 16, 1, 32, 8, 1, 16, kHeadDim, 2>>,\n                              GetBlockIterFactory<T, Tkv, 32, kHeadDim>,\n                              DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half>>();\n    c.add<KT<half, uint8_t>>();\n    c.add<KT<half, uint4_t>>();\n});\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm75_64.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_81616.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\nconstexpr int kStages  = 2;\n\n// Qh = (Qh_+7)/8*8: Qh_=1..8 → Qh=8, Qh_=9 → Qh=16\ntemplate<class T, class Tkv, int Qh>\nusing KT =\n    AttentionUniversal<arch::Sm75,\n                       Mainloop<arch::Sm70, Impl<MMA_81616, T, Tkv, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, kStages>>,\n                       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>,\n                       DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half, 8>>();\n    c.add<KT<half, half, 16>>();\n\n    c.add<KT<half, uint8_t, 8>>();\n    c.add<KT<half, uint8_t, 16>>();\n\n    c.add<KT<half, uint4_t, 8>>();\n    c.add<KT<half, uint4_t, 16>>();\n});\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm80_128.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_81616.h\"\n#include \"src/turbomind/kernels/attention/impl_simt.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 128;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\n\ntemplate<class Mainloop_, class CacheIter>\nusing KT = AttentionUniversal<arch::Sm80, Mainloop_, CacheIter, DecodingCtaMap>;\n\n// T==Tkv, Qh<=2: SIMT, stages=3\ntemplate<class T, int Qh>\nusing Decoding_SIMT = KT<Mainloop<Sm80_CpAsync<3>, Impl<MMA_SIMT, T, T, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, 3>>,\n                         GetBlockIterFactory<T, T, kCTA_S, kHeadDim>>;\n\n// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv\n// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16\ntemplate<class T, class Tkv, int Qh, int Stages>\nusing Decoding_MMA =\n    KT<Mainloop<Sm80_CpAsync<Stages>, Impl<MMA_81616, T, Tkv, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, Stages>>,\n       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<Decoding_SIMT<half, 1>>();\n    c.add<Decoding_SIMT<half, 2>>();\n    c.add<Decoding_MMA<half, half, 8, 3>>();\n    c.add<Decoding_MMA<half, half, 16, 3>>();\n    c.add<Decoding_MMA<half, uint8_t, 8, 5>>();\n    c.add<Decoding_MMA<half, uint8_t, 16, 5>>();\n    c.add<Decoding_MMA<half, uint4_t, 8, 5>>();\n    c.add<Decoding_MMA<half, uint4_t, 16, 5>>();\n\n#if ENABLE_BF16\n    c.add<Decoding_SIMT<nv_bfloat16, 
1>>();\n    c.add<Decoding_SIMT<nv_bfloat16, 2>>();\n    c.add<Decoding_MMA<nv_bfloat16, nv_bfloat16, 8, 3>>();\n    c.add<Decoding_MMA<nv_bfloat16, nv_bfloat16, 16, 3>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint8_t, 8, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint8_t, 16, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint4_t, 8, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint4_t, 16, 5>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm80_192.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_simt.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 192;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\nconstexpr int kStages  = 3;\nconstexpr int kQh      = 1;\n\n// HeadDim=192 uses SIMT+kStages for all Tkv (incl. uint8_t), kQh=1 only\ntemplate<class T, class Tkv>\nusing KT = AttentionUniversal<\n    arch::Sm80,\n    Mainloop<Sm80_CpAsync<kStages>, Impl<MMA_SIMT, T, Tkv, kQh, 1, kCTA_S, kQh, 1, kWARP_S, kHeadDim, kStages>>,\n    GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>,\n    DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<KT<half, half>>();\n    c.add<KT<half, uint8_t>>();\n\n#if ENABLE_BF16\n    c.add<KT<nv_bfloat16, nv_bfloat16>>();\n    c.add<KT<nv_bfloat16, uint8_t>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm80_256.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_81616.h\"\n#include \"src/turbomind/kernels/attention/impl_simt.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 256;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\n\ntemplate<class Mainloop_, class CacheIter>\nusing KT = AttentionUniversal<arch::Sm80, Mainloop_, CacheIter, DecodingCtaMap>;\n\n// T==Tkv, Qh<=2: SIMT, stages=3\ntemplate<class T, int Qh>\nusing Decoding_SIMT = KT<Mainloop<Sm80_CpAsync<3>, Impl<MMA_SIMT, T, T, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, 3>>,\n                         GetBlockIterFactory<T, T, kCTA_S, kHeadDim>>;\n\n// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv\n// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16\ntemplate<class T, class Tkv, int Qh, int Stages>\nusing Decoding_MMA =\n    KT<Mainloop<Sm80_CpAsync<Stages>, Impl<MMA_81616, T, Tkv, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, Stages>>,\n       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<Decoding_SIMT<half, 1>>();\n    c.add<Decoding_SIMT<half, 2>>();\n    c.add<Decoding_MMA<half, half, 8, 3>>();\n    c.add<Decoding_MMA<half, half, 16, 3>>();\n    c.add<Decoding_MMA<half, uint8_t, 8, 5>>();\n    c.add<Decoding_MMA<half, uint8_t, 16, 5>>();\n    c.add<Decoding_MMA<half, uint4_t, 8, 5>>();\n    c.add<Decoding_MMA<half, uint4_t, 16, 5>>();\n\n#if ENABLE_BF16\n    c.add<Decoding_SIMT<nv_bfloat16, 
1>>();\n    c.add<Decoding_SIMT<nv_bfloat16, 2>>();\n    c.add<Decoding_MMA<nv_bfloat16, nv_bfloat16, 8, 3>>();\n    c.add<Decoding_MMA<nv_bfloat16, nv_bfloat16, 16, 3>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint8_t, 8, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint8_t, 16, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint4_t, 8, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint4_t, 16, 5>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm80_576.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_81616.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 576;\n\n// Non-quant MLA config: CTA_H=16, CTA_S=32, WARP_H=8, WARP_S=16, Stages=2\ntemplate<class T>\nusing Decoding_F =\n    AttentionUniversal<arch::Sm80,\n                       Mainloop<Sm80_CpAsync<2>, Impl<MMA_81616, T, T, 16, 1, 32, 8, 1, 16, kHeadDim, 2>>,\n                       GetBlockIterFactory<T, T, 32, kHeadDim>,\n                       DecodingCtaMap>;\n\n// Quant config: CTA_H=8, CTA_S=64, WARP_H=8, WARP_S=16, Stages=5\ntemplate<class T, class Tkv>\nusing Decoding_Q =\n    AttentionUniversal<arch::Sm80,\n                       Mainloop<Sm80_CpAsync<5>, Impl<MMA_81616, T, Tkv, 8, 1, 64, 8, 1, 16, kHeadDim, 5>>,\n                       GetBlockIterFactory<T, Tkv, 64, kHeadDim>,\n                       DecodingCtaMap>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<Decoding_F<half>>();\n    c.add<Decoding_Q<half, uint8_t>>();\n    c.add<Decoding_Q<half, uint4_t>>();\n\n#if ENABLE_BF16\n    c.add<Decoding_F<nv_bfloat16>>();\n    c.add<Decoding_Q<nv_bfloat16, uint8_t>>();\n    c.add<Decoding_Q<nv_bfloat16, uint4_t>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel/decoding_sm80_64.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/attention_universal.h\"\n#include \"src/turbomind/kernels/attention/block_iterator.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/impl.h\"\n#include \"src/turbomind/kernels/attention/impl_81616.h\"\n#include \"src/turbomind/kernels/attention/impl_simt.h\"\n#include \"src/turbomind/kernels/attention/mainloop.h\"\n#include \"src/turbomind/kernels/attention/mainloop_sm80.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n\nnamespace turbomind::attention {\n\nconstexpr int kHeadDim = 64;\nconstexpr int kCTA_S   = 64;\nconstexpr int kWARP_S  = 16;\n\ntemplate<class Mainloop_, class CacheIter>\nusing KT = AttentionUniversal<arch::Sm80, Mainloop_, CacheIter, DecodingCtaMap>;\n\n// T==Tkv, Qh<=2: SIMT, stages=3\ntemplate<class T, int Qh>\nusing Decoding_SIMT = KT<Mainloop<Sm80_CpAsync<3>, Impl<MMA_SIMT, T, T, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, 3>>,\n                         GetBlockIterFactory<T, T, kCTA_S, kHeadDim>>;\n\n// Qh>2: MMA_81616; Stages=3 for T==Tkv, Stages=5 for quant Tkv\n// Qh = (Qh_+7)/8*8: Qh_=3..8→Qh=8, Qh_=9→Qh=16\ntemplate<class T, class Tkv, int Qh, int Stages>\nusing Decoding_MMA =\n    KT<Mainloop<Sm80_CpAsync<Stages>, Impl<MMA_81616, T, Tkv, Qh, 1, kCTA_S, Qh, 1, kWARP_S, kHeadDim, Stages>>,\n       GetBlockIterFactory<T, Tkv, kCTA_S, kHeadDim>>;\n\nnamespace {\nRegistrar reg([](Collector& c) {\n    c.add<Decoding_SIMT<half, 1>>();\n    c.add<Decoding_SIMT<half, 2>>();\n    c.add<Decoding_MMA<half, half, 8, 3>>();\n    c.add<Decoding_MMA<half, half, 16, 3>>();\n    c.add<Decoding_MMA<half, uint8_t, 8, 5>>();\n    c.add<Decoding_MMA<half, uint8_t, 16, 5>>();\n    c.add<Decoding_MMA<half, uint4_t, 8, 5>>();\n    c.add<Decoding_MMA<half, uint4_t, 16, 5>>();\n\n#if ENABLE_BF16\n    c.add<Decoding_SIMT<nv_bfloat16, 
1>>();\n    c.add<Decoding_SIMT<nv_bfloat16, 2>>();\n    c.add<Decoding_MMA<nv_bfloat16, nv_bfloat16, 8, 3>>();\n    c.add<Decoding_MMA<nv_bfloat16, nv_bfloat16, 16, 3>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint8_t, 8, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint8_t, 16, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint4_t, 8, 5>>();\n    c.add<Decoding_MMA<nv_bfloat16, uint4_t, 16, 5>>();\n#endif\n});\n}  // namespace\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/attention/desc.h\"\n\nnamespace turbomind::attention {\n\nclass Kernel {\npublic:\n    Kernel(): desc_{}, info_{} {}\n\n    virtual ~Kernel() = default;\n\n    virtual bool Launch(const void* params, int sm_count) const = 0;\n\n    const KernelDesc& desc() const noexcept\n    {\n        return desc_;\n    }\n\n    const KernelInfo& info() const noexcept\n    {\n        return info_;\n    }\n\n    int arch() const noexcept\n    {\n        return desc_.arch;\n    }\n\n    int smem_size() const noexcept\n    {\n        return info_.attr.sharedSizeBytes + info_.dynamic_smem_size;\n    }\n\n    const std::string& name() const\n    {\n        return info_.name;\n    }\n\nprotected:\n    KernelDesc desc_;\n    KernelInfo info_;\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kernel_impl.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <type_traits>\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/attention/attention_template.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/attention/decoding_template.h\"\n#include \"src/turbomind/kernels/attention/kernel.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n\nnamespace turbomind::attention {\n\ntemplate<class Tkv>\nconstexpr int kv_quant_from_type()\n{\n    if constexpr (std::is_same_v<Tkv, uint8_t>) {\n        return 8;\n    }\n    else if constexpr (std::is_same_v<Tkv, uint4_t>) {\n        return 4;\n    }\n    else {\n        return 0;\n    }\n}\n\ntemplate<class K>\nclass KernelImpl: public Kernel {\n    static constexpr bool kIsDecoding = std::is_same_v<typename K::CtaMap, DecodingCtaMap>;\n\npublic:\n    KernelImpl()\n    {\n        desc_.mode      = kIsDecoding ? AttnDesc::kDecoding : AttnDesc::kPrefill;\n        desc_.arch      = K::Arch::value;\n        desc_.head_dim  = K::kHeadDim;\n        desc_.data_type = data_type_v<typename K::T>;\n\n        if constexpr (kIsDecoding) {\n            desc_.kv_quant = kv_quant_from_type<typename K::Tkv>();\n            desc_.qh       = K::CTA_H;\n        }\n        else {\n            desc_.kv_quant = 0;\n            desc_.qh       = 1;\n        }\n\n        auto func               = &attention_kernel<K>;\n        info_.dynamic_smem_size = sizeof(typename K::SharedStorage);\n\n        cudaFuncGetAttributes(&info_.attr, func);\n\n        if (info_.dynamic_smem_size > (48 << 10)) {\n            cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, info_.dynamic_smem_size);\n        }\n\n        info_.num_warps = K::kWarpCount;\n        cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n            &info_.max_active_ctas, func, info_.num_warps * WARP_SIZE, info_.dynamic_smem_size);\n\n        info_.name = 
to_string(desc_);\n    }\n\n    bool Launch(const void* params, int sm_count) const override\n    {\n        const auto& p = *static_cast<const typename K::ParamType*>(params);\n        if constexpr (kIsDecoding) {\n            return invokeDecoding<K>(p, sm_count, info_.max_active_ctas);\n        }\n        else {\n            invokeAttention<K>(p, sm_count, info_.max_active_ctas);\n            return true;\n        }\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kv_cache_utils_v2.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <type_traits>\n\n#include \"src/turbomind/kernels/attention/block.h\"\n#include \"src/turbomind/kernels/attention/kv_cache_utils_v2.h\"\n#include \"src/turbomind/kernels/attention/quantization.h\"\n#include \"src/turbomind/kernels/attention/rotary_embedding.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\nusing cutlass::FastDivmod;\n\ntemplate<class Tkv, int CTA_S, int HeadDim, int WarpCnt, class T, class BlockLayout>\n__global__ void __launch_bounds__(128) ProcessKV_v2(char**          blocks,\n                                                    const T*        k,\n                                                    const T*        v,\n                                                    const T*        k_bias,\n                                                    const T*        v_bias,\n                                                    const int*      cu_q_len,\n                                                    const int*      cu_k_len,\n                                                    const int*      cu_block_num,\n                                                    RopeKernelParam rope_param,\n                                                    int64_t         stride_b,\n                                                    int64_t         stride_c,\n                                                    int64_t         stride_h,\n                                                    int64_t         stride_s,\n                                                    int             layer_id,\n                                                    int             cp_rank,\n                                                    FastDivmod      cp_size,\n                                                    
BlockLayout     block_layout)\n{\n\n    constexpr int kVecSize = sizeof(uint4) / sizeof(T);\n\n    using Vec = Array<T, kVecSize>;\n    using Map = RakedThreadMap<HeadDim, CTA_S, kVecSize, WarpCnt>;\n\n    constexpr int ITER_C = Map::kIterC;\n    constexpr int ITER_S = Map::kIterS;\n\n    constexpr bool HAS_V = !(typename BlockLayout::Config{}.is_share_kv());\n\n    const int token_idx = blockIdx.x * CTA_S;  // local offset into `input_length`\n    const int head_idx  = blockIdx.y;\n    const int batch_idx = blockIdx.z;\n\n    const int qi_beg = cu_q_len[batch_idx];\n    const int qi_end = cu_q_len[batch_idx + 1];\n    const int q_len  = qi_end - qi_beg;\n\n    const int k_len       = cu_k_len[batch_idx + 1] - cu_k_len[batch_idx];\n    const int history_len = k_len - q_len;\n\n    if (qi_beg + token_idx >= qi_end) {  // empty tile\n        return;\n    }\n\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    const int lane_id = threadIdx.x % WARP_SIZE;\n\n    const int2 offset = Map::get_offset(warp_id, lane_id);\n\n    Vec __align__(16) vec_K[ITER_S][ITER_C];\n    Vec __align__(16) vec_V[ITER_S][ITER_C];\n\n    Vec bias_V[ITER_C];\n    Vec bias_K[ITER_C];\n\n    if (k_bias) {\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            const int di = offset.x + c * Map::kDeltaC;\n            Ldg(bias_K[c], &k_bias[head_idx * HeadDim + di]);\n        }\n    }\n    if (v_bias && HAS_V) {\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            const int di = offset.x + c * Map::kDeltaC;\n            Ldg(bias_V[c], &v_bias[head_idx * HeadDim + di]);\n        }\n    }\n\n    PRAGMA_UNROLL\n    for (int s = 0; s < ITER_S; ++s) {\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            const int     qi = offset.y + s * Map::kDeltaS + token_idx;  // sequence local\n            const int     di = offset.x + c * Map::kDeltaC;\n            const int64_t index =\n                (batch_idx * stride_b + qi_beg 
* stride_c + qi * stride_s + head_idx * stride_h) * HeadDim + di;\n            if (qi < q_len) {\n                Ldg(vec_K[s][c], &k[index]);\n                if constexpr (HAS_V) {\n                    Ldg(vec_V[s][c], &v[index]);\n                }\n            }\n        }\n    }\n\n    if (k_bias) {\n        using namespace ops;\n        PRAGMA_UNROLL\n        for (int s = 0; s < ITER_S; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < ITER_C; ++c) {\n                vec_K[s][c] = vec_K[s][c] + bias_K[c];\n            }\n        }\n    }\n    if (v_bias && HAS_V) {\n        using namespace ops;\n        PRAGMA_UNROLL\n        for (int s = 0; s < ITER_S; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < ITER_C; ++c) {\n                vec_V[s][c] = vec_V[s][c] + bias_V[c];\n            }\n        }\n    }\n\n    if (rope_param.type != RopeType::kNull) {\n        FastRoPE rope(rope_param, batch_idx, std::integral_constant<int, kVecSize>{});\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            const int di = offset.x + c * Map::kDeltaC;\n            rope.init(di);\n            PRAGMA_UNROLL\n            for (int s = 0; s < ITER_S; ++s) {\n                const int ti = history_len + offset.y + s * Map::kDeltaS + token_idx;  // sequence local\n                rope.apply(vec_K[s][c], ti);\n            }\n        }\n    }\n\n    Array<T, 2> param_K[ITER_S];\n    Array<T, 2> param_V[ITER_S];\n\n    if constexpr (!std::is_same_v<T, Tkv>) {\n        warp_stats<Map::kWarpThreadC>(param_K, vec_K, bitsof<Tkv>);\n        if constexpr (HAS_V) {\n            warp_stats<Map::kWarpThreadC>(param_V, vec_V, bitsof<Tkv>);\n        }\n    }\n\n    Array<Tkv, kVecSize> out_K[ITER_S][ITER_C];\n    Array<Tkv, kVecSize> out_V[ITER_S][ITER_C];\n\n    PRAGMA_UNROLL\n    for (int s = 0; s < ITER_S; ++s) {\n        ConvertKvCache<T, Tkv> conv_K{param_K[s][0], param_K[s][1]};\n        ConvertKvCache<T, Tkv> conv_V{param_V[s][0], 
param_V[s][1]};\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            out_K[s][c] = conv_K(vec_K[s][c]);\n            if constexpr (HAS_V) {\n                out_V[s][c] = conv_V(vec_V[s][c]);\n            }\n        }\n    }\n\n    int local_ti, local_ti_rank;\n\n    blocks += cu_block_num[batch_idx];\n\n    block::Head<T, Tkv, BlockLayout> block_head{block_layout, layer_id, head_idx};\n\n    PRAGMA_UNROLL\n    for (int s = 0; s < ITER_S; ++s) {\n        const int qi = offset.y + s * Map::kDeltaS + token_idx;  // local offset into `input_length`\n        const int ti = history_len + qi;                         // timestep\n        local_ti     = cp_size.divmod(local_ti_rank, ti);\n        if (qi < q_len && local_ti_rank == cp_rank) {\n            block_head.with((char**)blocks, local_ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < ITER_C; ++c) {\n                    int di = offset.x + c * Map::kDeltaC;\n                    Store(&k_cache[di], out_K[s][c]);\n                    if constexpr (HAS_V) {\n                        Store(&v_cache[di], out_V[s][c]);\n                    }\n                }\n                if constexpr (!std::is_same_v<T, Tkv>) {\n                    if (offset.x == 0) {\n                        StoreQuantParam<Tkv>(k_param, param_K[s]);\n                        if constexpr (HAS_V) {\n                            StoreQuantParam<Tkv>(v_param, param_V[s]);\n                        }\n                        // if (ti == history_len) {\n                        // printf(\"src %d %f %f\\n\", ti, (float)param_K[s][0], (float)param_K[s][1]);\n                        // }\n                    }\n                }\n            });\n        }\n    }\n}\n\ntemplate<class T>\nvoid invokeProcessKV_v2(char**                 blocks,\n                        const T*               k,\n                        const T*               v,\n           
             const T*               k_bias,\n                        const T*               v_bias,\n                        const int*             cu_q_len,\n                        const int*             cu_k_len,\n                        const int*             cu_block_num,\n                        const RopeKernelParam& rope_param,\n                        int64_t                stride_b,\n                        int64_t                stride_c,\n                        int64_t                stride_h,\n                        int64_t                stride_s,\n                        int                    block_seq_len,\n                        int                    layer_id,\n                        int                    cp_rank,\n                        FastDivmod             cp_size,\n                        int                    max_q_len,\n                        int                    head_num,\n                        int                    head_dim,\n                        int                    batch_size,\n                        int                    quant_policy,\n                        cudaStream_t           stream)\n{\n\n    auto invoke = [&](auto tkv, const auto dim) {\n        using Tkv = decltype(tkv);\n\n        constexpr int  kHeadDim = dim;\n        constexpr bool kShareKV = kHeadDim == 576;\n\n        constexpr int WARPS = 4;\n        constexpr int CTA_S = kShareKV ? 
32 : 64;\n\n        int  block = WARPS * WARP_SIZE;\n        dim3 grid(cdiv(max_q_len, CTA_S), head_num, batch_size);\n\n        TM_CHECK_EQ(head_dim, kHeadDim);\n\n        block::Layout block_layout{block::Config<T, Tkv, kHeadDim, kShareKV>{head_num, block_seq_len}};\n\n        ProcessKV_v2<Tkv, CTA_S, kHeadDim, WARPS><<<grid, block, 0, stream>>>(blocks,\n                                                                              k,\n                                                                              v,\n                                                                              k_bias,\n                                                                              v_bias,\n                                                                              cu_q_len,\n                                                                              cu_k_len,\n                                                                              cu_block_num,\n                                                                              rope_param,\n                                                                              stride_b,\n                                                                              stride_c,\n                                                                              stride_h,\n                                                                              stride_s,\n                                                                              layer_id,\n                                                                              cp_rank,\n                                                                              cp_size,\n                                                                              block_layout);\n    };\n\n    auto dispatch = [&](auto tkv) {\n        if (0) {}\n        else if (head_dim == 64) {\n            return invoke(tkv, std::integral_constant<int, 64>{});\n        }\n        else if (head_dim == 128) {\n        
    return invoke(tkv, std::integral_constant<int, 128>{});\n        }\n        else if (head_dim == 192) {\n            return invoke(tkv, std::integral_constant<int, 192>{});\n        }\n        else if (head_dim == 256) {\n            return invoke(tkv, std::integral_constant<int, 256>{});\n        }\n        else if (head_dim == 576) {\n            return invoke(tkv, std::integral_constant<int, 576>{});\n        }\n        FT_CHECK(0);\n    };\n\n    if (quant_policy & QuantPolicy::kCacheKVInt8) {\n        dispatch(uint8_t{});\n    }\n    else if (quant_policy & QuantPolicy::kCacheKVInt4) {\n        dispatch(uint4_t{});\n    }\n    else {\n        dispatch(T{});\n    }\n}\n\n#define INSTANTIATE_invokeProcessKV_v2(type)                                                                           \\\n    template void invokeProcessKV_v2(char**                 blocks,                                                    \\\n                                     const type*            k,                                                         \\\n                                     const type*            v,                                                         \\\n                                     const type*            k_bias,                                                    \\\n                                     const type*            v_bias,                                                    \\\n                                     const int*             cu_q_len,                                                  \\\n                                     const int*             cu_k_len,                                                  \\\n                                     const int*             cu_block_num,                                              \\\n                                     const RopeKernelParam& rope_param,                                                \\\n                                     int64_t                stride_b,             
                                     \\\n                                     int64_t                stride_c,                                                  \\\n                                     int64_t                stride_h,                                                  \\\n                                     int64_t                stride_s,                                                  \\\n                                     int                    block_seq_len,                                             \\\n                                     int                    layer_id,                                                  \\\n                                     int                    cp_rank,                                                   \\\n                                     FastDivmod             cp_size,                                                   \\\n                                     int                    max_q_len,                                                 \\\n                                     int                    head_num,                                                  \\\n                                     int                    head_dim,                                                  \\\n                                     int                    batch_size,                                                \\\n                                     int                    quant_policy,                                              \\\n                                     cudaStream_t           stream);\n\nINSTANTIATE_invokeProcessKV_v2(half);\n#if ENABLE_BF16\nINSTANTIATE_invokeProcessKV_v2(nv_bfloat16);\n#endif\n\ntemplate<int CTA_S, int HeadDim, int WarpCnt, class T, class Tkv, class BlockLayout>\n__global__ void __launch_bounds__(128) flattenKV_v2(T*              k,\n                                                    T*              v,\n                                                    const Tkv**     
blocks,\n                                                    const int*      cu_k_len,\n                                                    const int*      cu_block_num,\n                                                    RopeKernelParam rope_param,\n                                                    int64_t         stride_b,\n                                                    int64_t         stride_c,\n                                                    int64_t         stride_h,\n                                                    int64_t         stride_s,\n                                                    int             layer_id,\n                                                    int             cp_rank,\n                                                    FastDivmod      cp_size,\n                                                    BlockLayout     block_layout)\n{\n    constexpr int kVecSize = sizeof(uint4) / sizeof(T);\n\n    using Map = RakedThreadMap<HeadDim, CTA_S, kVecSize, WarpCnt>;\n\n    constexpr int ITER_C = Map::kIterC;\n    constexpr int ITER_S = Map::kIterS;\n\n    constexpr bool HAS_V = !(typename BlockLayout::Config{}.is_share_kv());\n\n    const int token_idx = blockIdx.x * CTA_S;\n    const int head_idx  = blockIdx.y;\n    const int batch_idx = blockIdx.z;\n\n    const int ti_0   = cu_k_len[0];\n    const int ti_beg = cu_k_len[batch_idx] - ti_0;\n    const int ti_end = cu_k_len[batch_idx + 1] - ti_0;\n\n    const int seq_len = ti_end - ti_beg;\n\n    if (ti_beg + token_idx >= ti_end) {  // empty tile\n        return;\n    }\n\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    const int lane_id = threadIdx.x % WARP_SIZE;\n\n    const int2 offset = Map::get_offset(warp_id, lane_id);\n\n    Array<Tkv, kVecSize> __align__(16) vec_K[ITER_S][ITER_C];\n    Array<Tkv, kVecSize> __align__(16) vec_V[ITER_S][ITER_C];\n\n    Array<T, kVecSize> __align__(16) out_K[ITER_S][ITER_C];\n    Array<T, kVecSize> __align__(16) out_V[ITER_S][ITER_C];\n\n   
 blocks += cu_block_num[batch_idx];\n\n    block::Head<T, Tkv, BlockLayout> block_head{block_layout, layer_id, head_idx};\n\n    Array<T, 2> param_K[ITER_S];\n    Array<T, 2> param_V[ITER_S];\n\n    int local_ti, local_ti_rank;\n\n    PRAGMA_UNROLL\n    for (int s = 0; s < ITER_S; ++s) {\n        const int si = offset.y + s * Map::kDeltaS + token_idx;\n        local_ti     = cp_size.divmod(local_ti_rank, si);\n        if (si < seq_len && local_ti_rank == cp_rank) {\n            block_head.with((char**)blocks, local_ti, [&](auto k_cache, auto v_cache, T* k_param, T* v_param) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < ITER_C; ++c) {\n                    int di = offset.x + c * Map::kDeltaC;\n                    Ldg(vec_K[s][c], &k_cache[di]);\n                    if constexpr (HAS_V) {\n                        Ldg(vec_V[s][c], &v_cache[di]);\n                    }\n                }\n                if constexpr (!std::is_same_v<T, Tkv>) {\n                    Ldg(param_K[s], k_param);\n                    if constexpr (HAS_V) {\n                        Ldg(param_V[s], v_param);\n                    }\n                }\n            });\n        }\n    }\n\n    PRAGMA_UNROLL\n    for (int s = 0; s < ITER_S; ++s) {\n        ConvertKvCache<Tkv, T> conv_K{param_K[s][0], param_K[s][1]};\n        ConvertKvCache<Tkv, T> conv_V{param_V[s][0], param_V[s][1]};\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            out_K[s][c] = conv_K(vec_K[s][c]);\n            if constexpr (HAS_V) {\n                out_V[s][c] = conv_V(vec_V[s][c]);\n            }\n        }\n    }\n\n    if (rope_param.type != RopeType::kNull) {\n        FastRoPE rope(rope_param, batch_idx, std::integral_constant<int, kVecSize>{});\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            const int di = offset.x + c * Map::kDeltaC;\n            rope.init(di);\n            PRAGMA_UNROLL\n            for (int s = 0; s < ITER_S; ++s) 
{\n                const int ti = offset.y + s * Map::kDeltaS + token_idx;  // sequence local\n                rope.apply(out_K[s][c], ti);\n            }\n        }\n    }\n\n    PRAGMA_UNROLL\n    for (int s = 0; s < ITER_S; ++s) {\n        PRAGMA_UNROLL\n        for (int c = 0; c < ITER_C; ++c) {\n            const int si = offset.y + s * Map::kDeltaS + token_idx;\n            const int di = offset.x + c * Map::kDeltaC;\n            local_ti     = cp_size.divmod(local_ti_rank, si);\n            if (si < seq_len && local_ti_rank == cp_rank) {\n                const int64_t index =\n                    (batch_idx * stride_b + ti_beg * stride_c + local_ti * stride_s + head_idx * stride_h) * HeadDim\n                    + di;\n                Store(&k[index], out_K[s][c]);\n                if constexpr (HAS_V) {\n                    Store(&v[index], out_V[s][c]);\n                }\n            }\n        }\n    }\n}\n\ntemplate<class T>\nvoid invokeFlattenKV_v2(T*                     k,\n                        T*                     v,\n                        char**                 blocks,\n                        const int*             cu_k_len,\n                        const int*             cu_block_num,\n                        const RopeKernelParam& rope_param,\n                        int64_t                stride_b,\n                        int64_t                stride_c,\n                        int64_t                stride_h,\n                        int64_t                stride_s,\n                        int                    block_seq_len,\n                        int                    layer_id,\n                        int                    cp_rank,\n                        FastDivmod             cp_size,\n                        int                    max_seq_len,\n                        int                    head_num,\n                        int                    head_dim,\n                        int                    batch_size,\n      
                  int                    quant_policy,\n                        cudaStream_t           stream)\n{\n\n    auto invoke = [&](auto tkv, const auto dim) {\n        using Tkv = decltype(tkv);\n\n        constexpr int  kHeadDim = dim;\n        constexpr bool kShareKV = kHeadDim == 576;\n\n        constexpr int kWarpCnt = 4;\n        constexpr int CTA_S    = kShareKV ? 32 : 64;\n\n        constexpr int block = kWarpCnt * WARP_SIZE;\n        const dim3    grid((max_seq_len + CTA_S - 1) / CTA_S, head_num, batch_size);\n\n        TM_CHECK_EQ(head_dim, kHeadDim);\n\n        block::Layout block_layout{block::Config<T, Tkv, kHeadDim, kShareKV>{head_num, block_seq_len}};\n\n        flattenKV_v2<CTA_S, kHeadDim, kWarpCnt><<<grid, block, 0, stream>>>(k,\n                                                                            v,\n                                                                            (const Tkv**)blocks,\n                                                                            cu_k_len,\n                                                                            cu_block_num,\n                                                                            rope_param,\n                                                                            stride_b,\n                                                                            stride_c,\n                                                                            stride_h,\n                                                                            stride_s,\n                                                                            layer_id,\n                                                                            cp_rank,\n                                                                            cp_size,\n                                                                            block_layout);\n    };\n\n    auto dispatch = [&](auto tkv) {\n        if (0) {}\n        else if 
(head_dim == 64) {\n            return invoke(tkv, std::integral_constant<int, 64>{});\n        }\n        else if (head_dim == 128) {\n            return invoke(tkv, std::integral_constant<int, 128>{});\n        }\n        else if (head_dim == 192) {\n            return invoke(tkv, std::integral_constant<int, 192>{});\n        }\n        else if (head_dim == 256) {\n            return invoke(tkv, std::integral_constant<int, 256>{});\n        }\n        else if (head_dim == 576) {\n            return invoke(tkv, std::integral_constant<int, 576>{});\n        }\n        FT_CHECK(0);\n    };\n\n    if (quant_policy & QuantPolicy::kCacheKVInt8) {\n        dispatch(uint8_t{});\n    }\n    else if (quant_policy & QuantPolicy::kCacheKVInt4) {\n        dispatch(uint4_t{});\n    }\n    else {\n        dispatch(T{});\n    }\n}\n\n#define INSTANTIATE_invokeFlattenKV_v2(type)                                                                           \\\n    template void invokeFlattenKV_v2(type*                  k,                                                         \\\n                                     type*                  v,                                                         \\\n                                     char**                 blocks,                                                    \\\n                                     const int*             cu_k_len,                                                  \\\n                                     const int*             cu_block_num,                                              \\\n                                     const RopeKernelParam& rope_param,                                                \\\n                                     int64_t                stride_b,                                                  \\\n                                     int64_t                stride_c,                                                  \\\n                                     int64_t                
stride_h,                                                  \\\n                                     int64_t                stride_s,                                                  \\\n                                     int                    block_seq_len,                                             \\\n                                     int                    layer_id,                                                  \\\n                                     int                    cp_rank,                                                   \\\n                                     FastDivmod             cp_size,                                                   \\\n                                     int                    max_seq_len,                                               \\\n                                     int                    head_num,                                                  \\\n                                     int                    head_dim,                                                  \\\n                                     int                    batch_size,                                                \\\n                                     int                    quant_policy,                                              \\\n                                     cudaStream_t           stream);\n\nINSTANTIATE_invokeFlattenKV_v2(half);\n#if ENABLE_BF16\nINSTANTIATE_invokeFlattenKV_v2(nv_bfloat16);\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/kv_cache_utils_v2.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/attention/attention_params.h\"\n\nnamespace turbomind {\n\ntemplate<class T>\nvoid invokeProcessKV_v2(char**                 blocks,\n                        const T*               k,\n                        const T*               v,\n                        const T*               k_bias,\n                        const T*               v_bias,\n                        const int*             cu_q_len,\n                        const int*             cu_k_len,\n                        const int*             cu_block_num,\n                        const RopeKernelParam& rope_param,\n                        int64_t                stride_b,\n                        int64_t                stride_c,\n                        int64_t                stride_h,\n                        int64_t                stride_s,\n                        int                    block_seq_len,\n                        int                    layer_id,\n                        int                    cp_rank,\n                        cutlass::FastDivmod    cp_size,\n                        int                    max_q_len,\n                        int                    head_num,\n                        int                    head_dim,\n                        int                    batch_size,\n                        int                    quant_policy,\n                        cudaStream_t           stream = {});\n\ntemplate<class T>\nvoid invokeProcessKV_v2_(const AttentionParams<T>& params)\n{\n    invokeProcessKV_v2((char**)params.block_iter_params.block_ptrs,\n                       params.k,\n                       params.v,\n                       params.k_bias,\n                       params.v_bias,\n                       params.cu_q_len,\n                       params.cu_k_len,\n                       
params.block_iter_params.cu_block_nums,\n                       params.rope_param,\n                       0,                                     // stride b\n                       params.stride / params.size_per_head,  // stride c\n                       1,                                     // stride h\n                       params.stride / params.size_per_head,  // stride s\n                       params.block_iter_params.block_len,\n                       params.block_iter_params.layer_id,\n                       params.cp_rank,\n                       params.cp_size,\n                       params.max_q_len,\n                       params.num_kv_heads,\n                       params.size_per_head,\n                       params.batch_size,\n                       params.quant_policy,\n                       params.stream);\n}\n\ntemplate<class T>\nvoid invokeFlattenKV_v2(T*                     k,\n                        T*                     v,\n                        char**                 blocks,\n                        const int*             cu_k_len,\n                        const int*             cu_block_num,\n                        const RopeKernelParam& rope_param,\n                        int64_t                stride_b,\n                        int64_t                stride_c,\n                        int64_t                stride_h,\n                        int64_t                stride_s,\n                        int                    block_seq_len,\n                        int                    layer_id,\n                        int                    cp_rank,\n                        cutlass::FastDivmod    cp_size,\n                        int                    max_seq_len,\n                        int                    head_num,\n                        int                    head_dim,\n                        int                    batch_size,\n                        int                    quant_policy,\n                        
cudaStream_t           stream = {});\n\n/// TODO: remove `sum_k_len`\ntemplate<class T>\nvoid invokeFlattenKV_v2_(const AttentionParams<T>& params, int sum_k_len)\n{\n    // blocks -> [H, 2, sum_k_len, D]\n    invokeFlattenKV_v2((T*)params.linear_iter_params.kv_cache,\n                       (T*)params.linear_iter_params.kv_cache + params.linear_iter_params.key_to_val,\n                       (char**)params.block_iter_params.block_ptrs,\n                       params.cu_k_len,\n                       params.block_iter_params.cu_block_nums,\n                       RopeKernelParam{},\n                       0,\n                       1,\n                       params.linear_iter_params.stride_h / params.size_per_head,\n                       1,\n                       params.block_iter_params.block_len,\n                       params.block_iter_params.layer_id,\n                       params.cp_rank,\n                       params.cp_size,\n                       params.max_k_len,\n                       params.num_kv_heads,\n                       params.size_per_head,\n                       params.batch_size,\n                       params.quant_policy,\n                       params.stream);\n}\n\nsize_t\nget_cache_block_size(DataType dtype, DataType kvtype, int layer_num, int head_num, int head_dim, int block_seq_len);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/linear_iterator.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind {\n\ntemplate<class T, int CTA_S, int HeadDim>\nstruct LinearIterator {\n\n    const T* kv_cache_;\n    int      key_to_val_;\n\n    const T* key_ptr_{};\n    int      tile_id_{};\n\n    __device__ LinearIterator(const T* kv_cache, int key_to_val): kv_cache_{kv_cache}, key_to_val_{key_to_val} {}\n\n    __device__ void SetTile(int tile_id)\n    {\n        key_ptr_ = kv_cache_ + tile_id * CTA_S * HeadDim;\n        tile_id_ = tile_id;\n    }\n\n    __device__ void Advance()\n    {\n        --tile_id_;\n        if (tile_id_ >= 0) {\n            key_ptr_ -= CTA_S * HeadDim;\n        }\n    }\n\n    template<int Index>\n    __device__ const T* OffsetPtr(int offset) const\n    {\n        if constexpr (Index == 0) {\n            return key_ptr_ + offset;\n        }\n        else if constexpr (Index == 1) {\n            return key_ptr_ + offset + key_to_val_;\n        }\n        else {\n            static_assert(Index != Index, \"invalid index\");\n        }\n    }\n};\n\ntemplate<class Tkv_, int CTA_S, int HeadDim>\nstruct LinearIteratorFactory {\n    using Tkv = Tkv_;\n\n    const Tkv* kv_cache_;\n    const int* cu_ctx_len_;\n    int        stride_h_;\n    int        key_to_val_;\n\n    __device__ auto Create(int batch_idx, int head_idx)\n    {\n        int seq_ti = cu_ctx_len_[batch_idx] - cu_ctx_len_[0];\n        // `head_idx * stride_h_` may be larger than `INT_MAX`\n        const Tkv* kv_cache = kv_cache_ + head_idx * (int64_t)stride_h_ + seq_ti * HeadDim;\n\n        return LinearIterator<Tkv, CTA_S, HeadDim>{kv_cache, key_to_val_};\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/mainloop.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind::attention {\n\ntemplate<class Tag, class Attention>\nstruct Mainloop {\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/mainloop_sm70.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"arch.h\"\n#include \"iterator_sm70.h\"\n#include \"mainloop.h\"\n\nnamespace turbomind::attention {\n\ntemplate<class Impl_>\nstruct Mainloop<arch::Sm70, Impl_> {\n\n    using Impl = Impl_;\n\n    using T   = typename Impl::T;\n    using Tkv = typename Impl::Tkv;\n\n    using ThreadMapKV = typename Impl::ThreadMapKV;\n\n    using GmemIterK_ = Sm70GmemIterator<Tkv, ThreadMapKV, typename Impl::SmemLayoutK, 0>;\n    using GmemIterV_ = Sm70GmemIterator<Tkv, ThreadMapKV, typename Impl::SmemLayoutV, 1>;\n\n    /// TODO: hide this behind a SFINAE gate so that `*KVp` stuff won't be needed for non-quantized impls\n    using CombinedIterK =\n        CombinedIterator<GmemIterK_, Sm70GmemIterator<T, typename Impl::ThreadMapKVp, typename Impl::SmemLayoutKVp, 2>>;\n    using CombinedIterV =\n        CombinedIterator<GmemIterV_, Sm70GmemIterator<T, typename Impl::ThreadMapKVp, typename Impl::SmemLayoutKVp, 3>>;\n\n    using GmemIterK = std::conditional_t<std::is_same_v<T, Tkv>, GmemIterK_, CombinedIterK>;\n    using GmemIterV = std::conditional_t<std::is_same_v<T, Tkv>, GmemIterV_, CombinedIterV>;\n\n    using FragQ = typename Impl::FragQ;\n    using FragS = typename Impl::FragS;\n    using FragO = typename Impl::FragO;\n    using FragM = typename Impl::FragM;\n    using FragL = typename Impl::FragL;\n\n    using SharedStorage = typename Impl::SharedStorage;\n\n    static constexpr int CTA_S = Impl::CTA_S;\n\n    int cp_size_{1};\n    int cp_rank_{0};\n\n    __device__ void SetCpInfo(int cp_size, int cp_rank)\n    {\n        cp_size_ = cp_size;\n        cp_rank_ = cp_rank;\n    }\n\n    template<class CacheIter, class StoreS>\n    __device__ void operator()(FragQ&         frag_Q,\n                               CacheIter&     cache_iter,\n                               FragO&         frag_O,\n                               FragM&         frag_M,\n                               FragL&    
     frag_L,\n                               int            offset_Q,\n                               int            offset_K,\n                               int            max_step,\n                               int            tile_iter,\n                               int            mask_iter_back,\n                               int            mask_iter_front,\n                               int            window_size,\n                               float          qk_scale,\n                               SharedStorage& storage,\n                               const StoreS&  store_S)\n    {\n        GmemIterK gmem_K{};\n        GmemIterV gmem_V{};\n\n        Impl::SetSmemKV(gmem_K, gmem_V, storage, true);\n\n        typename GmemIterK::Fragment tmp_K;\n\n        typename Impl::StateQK state_QK{storage, frag_Q};\n        typename Impl::StatePV state_PV{storage, true};\n\n        Impl::Sync();\n\n        gmem_K.Load<true>(cache_iter, tmp_K, max_step - offset_K);\n        gmem_K.Save(tmp_K);\n\n        constexpr auto nop = [](int) {};\n\n        auto loop = [&](auto is_residue, auto is_mask) {\n            typename GmemIterV::Fragment tmp_V;\n\n            gmem_V.Load<is_residue>(cache_iter, tmp_V, is_residue ? 
max_step - offset_K : CTA_S);\n            cache_iter.Advance();\n\n            FragS frag_S{};\n\n            Impl::Sync();\n            state_QK.Load(0, 0);\n\n            Impl::ComputeQK(state_QK, frag_S, 0, nop, [&] {});\n\n            gmem_V.Save(tmp_V);\n\n            if (tile_iter > 0) {\n                gmem_K.Load<false>(cache_iter, tmp_K, CTA_S);\n            }\n\n            if constexpr (is_mask) {\n                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);\n            }\n\n            Impl::Softmax<is_mask>(frag_S, frag_M, frag_L, frag_O, qk_scale);\n\n            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);\n\n            Impl::Sync();\n            state_PV.Load(0, 0);\n\n            Impl::ComputePV(state_PV, frag_O, 0, nop, [&] {});\n\n            gmem_K.Save(tmp_K);\n\n            offset_K -= CTA_S;\n        };\n\n        for (int mask_iter = max(1, mask_iter_back); tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {\n            loop(std::true_type{}, std::true_type{});\n        }\n\n        PRAGMA_NO_UNROLL\n        for (; tile_iter > mask_iter_front; --tile_iter) {\n            loop(std::false_type{}, std::false_type{});\n        }\n\n        for (; tile_iter > 0; --tile_iter) {\n            loop(std::false_type{}, std::true_type{});\n        }\n    }\n\n    __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K, int window_size)\n    {\n        Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) {\n            int w = (offset_Q + qi) - ((offset_K + si) * cp_size_ + cp_rank_);\n            if (0 <= w && w < window_size) {}\n            else {\n                score -= std::numeric_limits<float>::infinity();\n            }\n        });\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/mainloop_sm80.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"iterator_sm80.h\"\n#include \"mainloop.h\"\n#include \"src/turbomind/kernels/core/pipe_iter.h\"\n#include <cuda_pipeline_primitives.h>\n#include <type_traits>\n\nnamespace turbomind::attention {\n\ntemplate<int Stages>\nstruct Sm80_CpAsync {\n};\n\ntemplate<int Stages, class Impl_>\nstruct Mainloop<Sm80_CpAsync<Stages>, Impl_> {\n\n    using Impl = Impl_;\n\n    using T   = typename Impl::T;\n    using Tkv = typename Impl::Tkv;\n\n    static constexpr std::false_type false_c{};\n    static constexpr std::true_type  true_c{};\n\n    static constexpr int CTA_S = Impl::CTA_S;\n\n    using ThreadMapKV = typename Impl::ThreadMapKV;\n\n    using GmemIterK_ = Sm80GmemIterator<Tkv, ThreadMapKV, typename Impl::SmemLayoutK, 0>;\n    using GmemIterV_ = Sm80GmemIterator<Tkv, ThreadMapKV, typename Impl::SmemLayoutV, 1>;\n\n    /// TODO: hide this behind a SFINAE gate so that `*KVp` stuff won't be needed for non-quantized impls\n    using CombinedIterK =\n        CombinedIterator<GmemIterK_, Sm80GmemIterator<T, typename Impl::ThreadMapKVp, typename Impl::SmemLayoutKVp, 2>>;\n    using CombinedIterV =\n        CombinedIterator<GmemIterV_, Sm80GmemIterator<T, typename Impl::ThreadMapKVp, typename Impl::SmemLayoutKVp, 3>>;\n\n    using GmemIterK = std::conditional_t<std::is_same_v<T, Tkv>, GmemIterK_, CombinedIterK>;\n    using GmemIterV = std::conditional_t<std::is_same_v<T, Tkv>, GmemIterV_, CombinedIterV>;\n\n    using FragQ = typename Impl::FragQ;\n    using FragS = typename Impl::FragS;\n    using FragO = typename Impl::FragO;\n    using FragM = typename Impl::FragM;\n    using FragL = typename Impl::FragL;\n\n    using SharedStorage = typename Impl::SharedStorage;\n\n    int cp_size_{1};\n    int cp_rank_{0};\n\n    __device__ void SetCpInfo(int cp_size, int cp_rank)\n    {\n        cp_size_ = cp_size;\n        cp_rank_ = cp_rank;\n    }\n\n    template<class... 
Args>\n    __device__ void operator()(Args&&... args)\n    {\n        Run(Sm80_CpAsync<Stages>{},\n            std::integral_constant<int, Impl::kHeadDim>{},\n            std::integral_constant<bool, Impl::MLA>{},\n            ((Args &&) args)...);\n    }\n\n    template<int Idx, class A, class B>\n    __device__ static decltype(auto) Select(A&& a, B&& b)\n    {\n        if constexpr (Idx) {\n            return (B &&) b;\n        }\n        else {\n            return (A &&) a;\n        }\n    }\n\n    template<int Batch, bool Advance, class GmemIter, class BlockIter>\n    __device__ static void Prefetch(GmemIter gmem_iter, BlockIter& block_iter, int k, int pipe_iter)\n    {\n        const int begin = k * Batch;\n        if (begin < ThreadMapKV::kIterS) {\n            gmem_iter.Prefetch(false_c, block_iter, begin, Batch, CTA_S, pipe_iter);\n        }\n        if (begin + Batch == ThreadMapKV::kIterS) {\n            if constexpr (Advance) {\n                block_iter.Advance();\n            }\n            __pipeline_commit();\n        }\n    }\n\n    template<int head_dim, class CacheIter, class StoreS, int Stages_>\n    __device__ void Run(Sm80_CpAsync<Stages_>,\n                        std::integral_constant<int, head_dim>,\n                        std::false_type,  // is MLA\n                        FragQ&         frag_Q,\n                        CacheIter&     cache_iter,\n                        FragO&         frag_O,\n                        FragM&         frag_M,\n                        FragL&         frag_L,\n                        int            offset_Q,\n                        int            offset_K,\n                        int            max_step,\n                        int            tile_iter,\n                        int            mask_iter_back,\n                        int            mask_iter_front,\n                        int            window_size,\n                        float          qk_scale,\n                        SharedStorage& 
storage,\n                        const StoreS&  store_S)\n    {\n        // multi-stage: pipe_iter * size\n        //   two-stage: constant offset\n\n        GmemIterK gmem_K{};\n        GmemIterV gmem_V{};\n\n        Impl::SetSmemKV(gmem_K, gmem_V, storage, false);\n\n        PipeIter<Stages> pipe_iter;\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < Stages; ++i) {\n            gmem_K.ClearSmem((++pipe_iter).w);\n        }\n\n        Impl::Sync();\n\n        // 0\n        gmem_K.Prefetch(true_c, cache_iter, max_step - offset_K, (++pipe_iter).w);\n        __pipeline_commit();\n\n        // 1\n        gmem_V.Prefetch(true_c, cache_iter, max_step - offset_K, (++pipe_iter).w);\n        __pipeline_commit();\n\n        cache_iter.Advance();\n\n        PRAGMA_UNROLL\n        for (int stages = 2; stages < Stages - 2; stages += 2) {\n            // 2 + 2X\n            gmem_K.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w);\n            __pipeline_commit();\n            // 3 + 2X\n            gmem_V.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w);\n            __pipeline_commit();\n\n            cache_iter.Advance();\n        }\n\n        if constexpr (Stages % 2 == 0) {\n            // 2 + 2Y\n            gmem_K.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w);\n            __pipeline_commit();\n        }\n\n        auto& gmem_0 = Select<Stages % 2>(gmem_V, gmem_K);\n        auto& gmem_1 = Select<Stages % 2>(gmem_K, gmem_V);\n\n        constexpr auto kBatch0 = Stages % 2 ? Impl::kBatchV : Impl::kBatchK;\n        constexpr auto kBatch1 = Stages % 2 ? 
Impl::kBatchK : Impl::kBatchV;\n\n        typename Impl::StateQK state_QK{storage, frag_Q};\n        typename Impl::StatePV state_PV{storage};\n\n        Wait();\n        state_QK.Load(0, (++pipe_iter).r);\n\n        auto loop = [&](auto is_mask) {\n            __align__(16) FragS frag_S{};\n\n            auto prefetch_0 = [&, pipe_iter](int k) {\n                Prefetch<kBatch0, Stages % 2 == 0>(gmem_0, cache_iter, k, pipe_iter.w);\n            };\n\n            Impl::ComputeQK(state_QK, frag_S, pipe_iter.r, prefetch_0, [&] {\n                Wait();\n                state_PV.Load(0, (++pipe_iter).r);\n            });\n\n            if constexpr (is_mask) {\n                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);\n            }\n\n            Impl::Softmax<is_mask>(frag_S, frag_M, frag_L, frag_O, qk_scale);\n\n            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);\n\n            auto prefetch_1 = [&, pipe_iter](int k) {\n                Prefetch<kBatch1, Stages % 2 != 0>(gmem_1, cache_iter, k, pipe_iter.w);\n            };\n\n            Impl::ComputePV(state_PV, frag_O, pipe_iter.r, prefetch_1, [&] {\n                Wait();\n                state_QK.Load(0, (++pipe_iter).r);\n            });\n\n            offset_K -= CTA_S;\n        };\n\n        for (int mask_iter = mask_iter_back; tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {\n            loop(true_c);\n        }\n\n        PRAGMA_NO_UNROLL\n        for (; tile_iter > mask_iter_front; --tile_iter) {\n            loop(false_c);\n        }\n\n        for (; tile_iter > 0; --tile_iter) {\n            loop(true_c);\n        }\n\n        __pipeline_commit();\n        __pipeline_wait_prior(0);\n    }\n\n    // #if 1\n    template<class CacheIter, class StoreS>\n    __device__ void Run(Sm80_CpAsync<2>,\n                        std::integral_constant<int, 192>,\n                        std::false_type,  // is MLA\n                        FragQ&         frag_Q,\n              
          CacheIter&     cache_iter,\n                        FragO&         frag_O,\n                        FragM&         frag_M,\n                        FragL&         frag_L,\n                        int            offset_Q,\n                        int            offset_K,\n                        int            max_step,\n                        int            tile_iter,\n                        int            mask_iter_back,\n                        int            mask_iter_front,\n                        int            window_size,\n                        float          qk_scale,\n                        SharedStorage& storage,\n                        const StoreS&  store_S)\n    {\n        GmemIterK gmem_K{};\n        GmemIterV gmem_V{};\n\n        Impl::SetSmemKV(gmem_K, gmem_V, storage, false);\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < Stages; ++i) {\n            gmem_K.ClearSmem(i);\n        }\n\n        gmem_K.Prefetch(true_c, cache_iter, max_step - offset_K, 0);\n        __pipeline_commit();\n\n        typename Impl::StateQK state_QK{storage, frag_Q};\n        typename Impl::StatePV state_PV{storage};\n\n        Wait();\n        state_QK.Load(0, 0);\n\n        constexpr auto _ = [](int) {};\n\n        auto loop = [&](auto is_residue, auto is_mask) {\n            __align__(16) FragS frag_S{};\n\n            auto prefetch_V = [&](int k) {\n                if (k == 0) {\n                    gmem_V.Prefetch(is_residue, cache_iter, max_step - offset_K, 1);\n                    __pipeline_commit();\n                }\n            };\n            prefetch_V(0);\n\n            Impl::ComputeQK(state_QK, frag_S, 0, _, [&] {\n                Wait();\n                state_PV.Load(0, 1);\n            });\n\n            cache_iter.Advance();\n\n            auto prefetch_K = [&](int k) {\n                if (k == 0) {\n                    gmem_K.Prefetch(false_c, cache_iter, CTA_S, 0);\n                    __pipeline_commit();\n                }\n   
         };\n            prefetch_K(0);\n\n            if constexpr (is_mask) {\n                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);\n            }\n\n            Impl::Softmax<is_mask>(frag_S, frag_M, frag_L, frag_O, qk_scale);\n\n            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);\n\n            Impl::ComputePV(state_PV, frag_O, 1, _, [&] {\n                Wait();\n                state_QK.Load(0, 0);\n            });\n\n            offset_K -= CTA_S;\n        };\n\n        for (int mask_iter = max(1, mask_iter_back); tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {\n            loop(true_c, true_c);\n        }\n\n        PRAGMA_NO_UNROLL\n        for (; tile_iter > mask_iter_front; --tile_iter) {\n            loop(false_c, false_c);\n        }\n\n        for (; tile_iter > 0; --tile_iter) {\n            loop(false_c, true_c);\n        }\n\n        __pipeline_commit();\n        __pipeline_wait_prior(0);\n    }\n\n#if 1\n    // Load      : K0,K1 | V0,K2,V1,K3 ...\n    // Compute   :    K0 | K1,V0,K2,V1 ...\n    // - more register consumption\n    // - more interleaved HMMA and FMA\n    // - slight performance gain\n    template<int head_dim, class CacheIter, class StoreS>\n    __device__ void Run(Sm80_CpAsync<2>,\n                        std::integral_constant<int, head_dim>,\n                        std::false_type,  // is MLA\n                        FragQ&         frag_Q,\n                        CacheIter&     cache_iter_,\n                        FragO&         frag_O,\n                        FragM&         frag_M,\n                        FragL&         frag_L,\n                        int            offset_Q,\n                        int            offset_K,\n                        int            max_step,\n                        int            tile_iter,\n                        int            mask_iter_back,\n                        int            mask_iter_front,\n                        int            
window_size,\n                        float          qk_scale,\n                        SharedStorage& storage,\n                        const StoreS&  store_S)\n    {\n        GmemIterK gmem_K{};\n        GmemIterV gmem_V{};\n\n        Impl::SetSmemKV(gmem_K, gmem_V, storage, false);\n\n        gmem_K.ClearSmem(0);\n        gmem_K.ClearSmem(1);\n\n        auto cache_iter_K = cache_iter_;\n        auto cache_iter_V = cache_iter_;\n\n        gmem_K.Prefetch(true_c, cache_iter_K, max_step - offset_K, 0);\n        __pipeline_commit();\n        cache_iter_K.Advance();\n\n        typename Impl::StateQK state_QK{storage, frag_Q};\n        typename Impl::StatePV state_PV{storage};\n\n        Wait();\n        state_QK.Load(0, 0);\n\n        FragS frag_S{};\n        auto  _ = [&](int k) {\n            if (k == 0) {\n                gmem_K.Prefetch(false_c, cache_iter_K, CTA_S, 1);\n                __pipeline_commit();\n            }\n        };\n        Impl::ComputeQK(state_QK, frag_S, 0, _, [&] {\n            Wait();\n            state_QK.Load(0, 1);\n        });\n        cache_iter_K.Advance();\n\n        auto loop = [&](auto is_residue, auto is_mask, auto is_last) {\n            if constexpr (is_mask) {\n                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);\n            }\n\n            Impl::Softmax<is_mask>(frag_S, frag_M, frag_L, frag_O, qk_scale);\n\n            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);\n\n            auto prefetch_V = [&](int k) {\n                if (k == 0) {\n                    gmem_V.Prefetch(is_residue, cache_iter_V, max_step - offset_K, 0);\n                    __pipeline_commit();\n                }\n            };\n            if constexpr (!is_last) {\n                clear(frag_S);\n                Impl::ComputeQK(state_QK, frag_S, 1, prefetch_V, [&] {\n                    Wait();\n                    state_PV.Load(0, 0);\n                });\n                cache_iter_V.Advance();\n            }\n         
   else {\n                prefetch_V(0);\n                Wait();\n                state_PV.Load(0, 0);\n            }\n\n            auto prefetch_K = [&](int k) {\n                if (k == 0) {\n                    gmem_K.Prefetch(false_c, cache_iter_K, CTA_S, 1);\n                    __pipeline_commit();\n                }\n            };\n            Impl::ComputePV(state_PV, frag_O, 0, prefetch_K, [&] {\n                Wait();\n                state_QK.Load(0, 1);\n            });\n            cache_iter_K.Advance();\n\n            offset_K -= CTA_S;\n        };\n\n        for (int mask_iter = max(1, mask_iter_back); tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {\n            loop(true_c, true_c, false_c);\n        }\n\n        mask_iter_front = max(1, mask_iter_front);\n\n        PRAGMA_NO_UNROLL\n        for (; tile_iter > mask_iter_front; --tile_iter) {\n            loop(false_c, false_c, false_c);\n        }\n\n        for (; tile_iter > 1; --tile_iter) {\n            loop(false_c, true_c, false_c);\n        }\n\n        if (tile_iter > 0) {\n            loop(false_c, true_c, true_c);\n        }\n\n        __pipeline_commit();\n        __pipeline_wait_prior(0);\n    }\n#endif\n\n    // Simplified MLA implementation\n    template<int head_dim, class CacheIter, class StoreS, int Stages_>\n    __device__ void Run(Sm80_CpAsync<Stages_>,\n                        std::integral_constant<int, head_dim>,\n                        std::true_type,  // is MLA\n                        FragQ&         frag_Q,\n                        CacheIter&     cache_iter,\n                        FragO&         frag_O,\n                        FragM&         frag_M,\n                        FragL&         frag_L,\n                        int            offset_Q,\n                        int            offset_K,\n                        int            max_step,\n                        int            tile_iter,\n                        int            mask_iter_back,\n   
                     int            mask_iter_front,\n                        int            window_size,\n                        float          qk_scale,\n                        SharedStorage& storage,\n                        const StoreS&  store_S)\n    {\n        GmemIterK gmem_KV{};\n\n        Impl::SetSmemKV(gmem_KV, gmem_KV, storage, false);\n\n        PipeIter<Stages> pipe_iter;\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < Stages; ++i) {\n            gmem_KV.ClearSmem((++pipe_iter).w);\n        }\n\n        Impl::Sync();\n\n        gmem_KV.Prefetch(true_c, cache_iter, max_step - offset_K, (++pipe_iter).w);\n        __pipeline_commit();\n        cache_iter.Advance();\n\n        PRAGMA_UNROLL\n        for (int stages = 1; stages < Stages - 1; ++stages) {\n            gmem_KV.Prefetch(false_c, cache_iter, CTA_S, (++pipe_iter).w);\n            __pipeline_commit();\n            cache_iter.Advance();\n        }\n\n        typename Impl::StateQK state_QK{storage, frag_Q};\n        typename Impl::StatePV state_PV{storage};\n\n        Wait();\n        state_QK.Load(0, (++pipe_iter).r);\n\n        auto loop = [&](auto is_mask) {\n            __align__(16) FragS frag_S{};\n\n            gmem_KV.Prefetch(false_c, cache_iter, CTA_S, pipe_iter.w);\n            __pipeline_commit();\n            cache_iter.Advance();\n\n            Impl::ComputeQK(\n                state_QK, frag_S, pipe_iter.r, [](int) {}, [] {});\n\n            if constexpr (is_mask) {\n                ApplyCasualMask(frag_S, offset_Q, offset_K, window_size);\n            }\n\n            Impl::Softmax<is_mask>(frag_S, frag_M, frag_L, frag_O, qk_scale);\n\n            Impl::ConvertStoP(frag_S, state_PV.frag_P, storage);\n\n            state_PV.Load(0, pipe_iter.r);\n            Impl::ComputePV(\n                state_PV, frag_O, pipe_iter.r, [](int) {}, [] {});\n\n            Wait();\n            state_QK.Load(0, (++pipe_iter).r);\n\n            offset_K -= CTA_S;\n        };\n\n        for 
(int mask_iter = mask_iter_back; tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {\n            loop(true_c);\n        }\n\n        PRAGMA_NO_UNROLL\n        for (; tile_iter > mask_iter_front; --tile_iter) {\n            loop(false_c);\n        }\n\n        for (; tile_iter > 0; --tile_iter) {\n            loop(true_c);\n        }\n\n        __pipeline_commit();\n        __pipeline_wait_prior(0);\n    }\n\n    __device__ void Wait()\n    {\n        __pipeline_wait_prior(Stages - 2);\n        Impl::Sync();\n    }\n\n    __device__ void ApplyCasualMask(FragS& frag_S, int offset_Q, int offset_K, int window_size)\n    {\n        Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float& score) {\n            int w = (offset_Q + qi) - ((offset_K + si) * cp_size_ + cp_rank_);\n            if (0 <= w && w < window_size) {}\n            else {\n                score -= std::numeric_limits<float>::infinity();\n            }\n        });\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/quantization.h",
    "content": "#pragma once\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n\n#include <cmath>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_fp8.h>\n\nnamespace turbomind {\n\n#define TM_ROUND_USE_CVT_RNI 1\n\ninline constexpr bool kFuseU4F16Dequant  = false;\ninline constexpr bool kForceIntZeroPoint = false;\n\ntemplate<class T>\n__device__ T Infinity()\n{\n    if constexpr (std::is_same_v<T, half>) {\n        return __ushort_as_half((unsigned short)0x7C00U);\n    }\n\n#if __CUDA_ARCH__ >= 800\n    if constexpr (std::is_same_v<T, nv_bfloat16>) {\n        return __ushort_as_bfloat16((unsigned short)0x7F80U);\n    }\n#endif\n\n    if constexpr (std::is_same_v<T, float>) {\n        return __int_as_float(0x7f800000U);\n    }\n\n    return T{};\n}\n\ntemplate<class T>\n__device__ constexpr T Max(T a, T b)\n{\n    if constexpr (std::is_same_v<T, half>) {\n        return __hmax(a, b);\n    }\n\n#if __CUDA_ARCH__ >= 800\n    if constexpr (std::is_same_v<T, nv_bfloat16>) {\n        return __hmax(a, b);\n    }\n#endif\n\n    if constexpr (std::is_same_v<T, float>) {\n        return fmaxf(a, b);\n    }\n\n    if constexpr (std::is_same_v<T, int>) {\n        return max(a, b);\n    }\n\n    return T{};\n}\n\ntemplate<class T>\n__device__ constexpr T Min(T a, T b)\n{\n    if constexpr (std::is_same_v<T, half>) {\n        return __hmin(a, b);\n    }\n\n#if __CUDA_ARCH__ >= 800\n    if constexpr (std::is_same_v<T, nv_bfloat16>) {\n        return __hmin(a, b);\n    }\n#endif\n\n    if constexpr (std::is_same_v<T, float>) {\n        return fminf(a, b);\n    }\n\n    if constexpr (std::is_same_v<T, int>) {\n        return min(a, b);\n    }\n\n    return T{};\n}\n\ntemplate<bool norm = true>\ninline __device__ Array<half, 4> cvt_f16x4_u8(const Array<uint8_t, 4>& src)\n{\n    static constexpr uint32_t f16_magic = 0x64000000;\n    // 01234567 01234567\n    // 
SEEEEEMM MMMMMMMM\n    //      1MM XXXXXXXX\n    // (1 + x/2^10) * 2^(e-15) -> e-15=10 -> e=25=16+8+1 -> 01100100b -> 0x64\n    Array<uint32_t, 2> dst;\n    dst[0] = __byte_perm((uint32_t&)src, f16_magic, 0x7170);\n    dst[1] = __byte_perm((uint32_t&)src, f16_magic, 0x7372);\n    if constexpr (norm) {\n        for (int i = 0; i < 4; ++i) {\n            ((Array<half, 4>&)dst)[i] -= __ushort_as_half(0x6400U);\n        }\n    }\n    return (Array<half, 4>&)dst;\n}\n\ntemplate<bool norm = true>\ninline __device__ Array<half, 4> cvt_f16x2x2_u8_trans(const Array<uint8_t, 4>& src)\n{\n    static constexpr uint32_t f16_magic = 0x64000000;\n    // 01234567 01234567\n    // SEEEEEMM MMMMMMMM\n    //      1MM XXXXXXXX\n    // (1 + x/2^10) * 2^(e-15) -> e-15=10 -> e=25=16+8+1 -> 01100100b -> 0x64\n    Array<uint32_t, 2> dst;\n    dst[0] = __byte_perm((uint32_t&)src, f16_magic, 0x7270);\n    dst[1] = __byte_perm((uint32_t&)src, f16_magic, 0x7371);\n    if constexpr (norm) {\n        for (int i = 0; i < 4; ++i) {\n            ((Array<half, 4>&)dst)[i] -= __ushort_as_half(0x6400U);\n        }\n    }\n    return (Array<half, 4>&)dst;\n}\n\ninline __device__ Array<nv_bfloat16, 4> cvt_bf16x4_u8(const Array<uint8_t, 4>& src)\n{\n    // 01234567 01234567 01234567 01234567\n    // SEEEEEEE EMMMMMMM MMMMMMMM MMMMMMMM\n    //          1MM...   
XXXXXXXX\n    // (1 + x/2^15) * 2^(e-127) -> e-127=15 -> e=142 -> 01000111 -> 0x47\n    static constexpr uint32_t f32_magic = 0x47000000;  // 32768\n\n    Array<uint32_t, 4> tmp;\n    tmp[0] = __byte_perm((uint32_t&)src, f32_magic, 0x7604);\n    tmp[1] = __byte_perm((uint32_t&)src, f32_magic, 0x7614);\n    tmp[2] = __byte_perm((uint32_t&)src, f32_magic, 0x7624);\n    tmp[3] = __byte_perm((uint32_t&)src, f32_magic, 0x7634);\n\n    auto& vec = (Array<float, 4>&)tmp;\n\n    Array<nv_bfloat16, 4> dst;\n    PRAGMA_UNROLL\n    for (int i = 0; i < 4; ++i) {\n        dst[i] = __float2bfloat16(vec[i] - 32768.f);\n    }\n    return dst;\n}\n\ninline __device__ Array<float, 4> cvt_f32x4_u8(const Array<uint8_t, 4>& src)\n{\n    // 01234567 01234567 01234567 01234567\n    // SEEEEEEE EMMMMMMM MMMMMMMM MMMMMMMM\n    //          1MM...   XXXXXXXX\n    // (1 + x/2^15) * 2^(e-127) -> e-127=15 -> e=142 -> 01000111 -> 0x47\n    static constexpr uint32_t f32_magic = 0x47000000;  // 32768\n\n    Array<uint32_t, 4> tmp;\n    tmp[0] = __byte_perm((uint32_t&)src, f32_magic, 0x7604);\n    tmp[1] = __byte_perm((uint32_t&)src, f32_magic, 0x7614);\n    tmp[2] = __byte_perm((uint32_t&)src, f32_magic, 0x7624);\n    tmp[3] = __byte_perm((uint32_t&)src, f32_magic, 0x7634);\n\n    auto& vec = (Array<float, 4>&)tmp;\n    PRAGMA_UNROLL\n    for (int i = 0; i < 4; ++i) {\n        vec[i] -= 32768.f;\n    }\n    return vec;\n}\n\ntemplate<bool norm = true>\ninline __device__ Array<nv_bfloat16, 8> cvt_bf16x8_u4(const Array<uint4_t, 8>& src)\n{\n#if __CUDA_ARCH__ >= 800\n    // 01234567 01234567\n    // SEEEEEEE EMMMMMMM\n    //          1...XXXX\n    // (1 + x/2^7) * 2^(e-127) -> e-127=7 -> e=134 -> 0100 0011 -> 0x43\n    static constexpr uint32_t TEMPLATE = 0x43004300;  // nv_bfloat162(128, 128)\n    static constexpr uint32_t MASK     = 0x000f000f;\n    static constexpr uint32_t immLut   = (0xf0 & 0xcc) | 0xaa;\n\n    Array<uint32_t, 4> h;\n\n    static_assert(sizeof(Array<nv_bfloat16, 8>) == 
sizeof(Array<uint32_t, 4>));\n\n    uint32_t const& i4s    = reinterpret_cast<uint32_t const&>(src);\n    const uint32_t  i4s_4  = i4s >> 4;\n    const uint32_t  i4s_8  = i4s >> 8;\n    const uint32_t  i4s_12 = i4s >> 12;\n\n    asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[0]) : \"r\"(i4s), \"n\"(MASK), \"n\"(TEMPLATE), \"n\"(immLut));\n    asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[1]) : \"r\"(i4s_4), \"n\"(MASK), \"n\"(TEMPLATE), \"n\"(immLut));\n    asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[2]) : \"r\"(i4s_8), \"n\"(MASK), \"n\"(TEMPLATE), \"n\"(immLut));\n    asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[3]) : \"r\"(i4s_12), \"n\"(MASK), \"n\"(TEMPLATE), \"n\"(immLut));\n\n    if constexpr (norm) {\n        auto result = reinterpret_cast<nv_bfloat16*>(h.data());\n        PRAGMA_UNROLL\n        for (int i = 0; i < 8; ++i) {\n            result[i] -= nv_bfloat16(128.f);\n        }\n    }\n    return (Array<nv_bfloat16, 8>&)h;\n#else\n    return {};\n#endif\n}\n\n#if TM_ROUND_USE_CVT_RNI\n\ntemplate<class T>\ninline __device__ T round(float x)\n{\n    uint32_t y{};\n    if constexpr (std::is_same_v<T, uint8_t>) {\n        asm(\"cvt.rni.sat.u8.f32 %0, %1;\\n\" : \"=r\"(y) : \"f\"(x));\n    }\n    else if constexpr (std::is_same_v<T, uint16_t>) {\n        asm(\"cvt.rni.sat.u16.f32 %0, %1;\\n\" : \"=r\"(y) : \"f\"(x));\n    }\n    else if constexpr (std::is_same_v<T, uint32_t>) {\n        asm(\"cvt.rni.sat.u32.f32 %0, %1;\\n\" : \"=r\"(y) : \"f\"(x));\n    }\n    else if constexpr (std::is_same_v<T, int32_t>) {\n        asm(\"cvt.rni.sat.s32.f32 %0, %1;\\n\" : \"=r\"(y) : \"f\"(x));\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>, \"not implemented\");\n    }\n    return y;\n}\n\ntemplate<class T>\ninline __device__ T round(half x)\n{\n    uint32_t y{};\n    if constexpr (std::is_same_v<T, uint8_t>) {\n        asm(\"cvt.rni.sat.u8.f16 %0, %1;\\n\" : \"=r\"(y) : \"h\"((uint16_t&)x));\n    
}\n    else if constexpr (std::is_same_v<T, uint16_t>) {\n        asm(\"cvt.rni.sat.u16.f16 %0, %1;\\n\" : \"=r\"(y) : \"h\"((uint16_t&)x));\n    }\n    else if constexpr (std::is_same_v<T, uint32_t>) {\n        asm(\"cvt.rni.sat.u32.f16 %0, %1;\\n\" : \"=r\"(y) : \"h\"((uint16_t&)x));\n    }\n    else if constexpr (std::is_same_v<T, int32_t>) {\n        asm(\"cvt.rni.sat.s32.f16 %0, %1;\\n\" : \"=r\"(y) : \"h\"((uint16_t&)x));\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>, \"not implemented\");\n    }\n    return y;\n}\n\n#else\n\ntemplate<class T>\ninline __device__ T round(float x)\n{\n    x += .5f;\n\n    uint32_t y{};\n    if constexpr (std::is_same_v<T, uint8_t>) {\n        asm(\"cvt.rzi.sat.u8.f32 %0, %1;\\n\" : \"=r\"(y) : \"f\"(x));\n    }\n    else if constexpr (std::is_same_v<T, uint16_t>) {\n        asm(\"cvt.rzi.sat.u16.f32 %0, %1;\\n\" : \"=r\"(y) : \"f\"(x));\n    }\n    else if constexpr (std::is_same_v<T, uint32_t>) {\n        asm(\"cvt.rzi.sat.u32.f32 %0, %1;\\n\" : \"=r\"(y) : \"f\"(x));\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>, \"not implemented\");\n    }\n    return y;\n}\n\ntemplate<class T>\ninline __device__ T round(half x)\n{\n    x += half(.5f);\n\n    uint32_t y{};\n    if constexpr (std::is_same_v<T, uint8_t>) {\n        asm(\"cvt.rzi.sat.u8.f16 %0, %1;\\n\" : \"=r\"(y) : \"h\"((uint16_t&)x));\n    }\n    else if constexpr (std::is_same_v<T, uint16_t>) {\n        asm(\"cvt.rzi.sat.u16.f16 %0, %1;\\n\" : \"=r\"(y) : \"h\"((uint16_t&)x));\n    }\n    else if constexpr (std::is_same_v<T, uint32_t>) {\n        asm(\"cvt.rzi.sat.u32.f16 %0, %1;\\n\" : \"=r\"(y) : \"h\"((uint16_t&)x));\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>, \"not implemented\");\n    }\n    return y;\n}\n\n#endif\n\ntemplate<class To, class Ti, class B>\ninline __device__ To quant(Ti x, B n_bits)\n{\n    auto y = round<To>(x);\n    if constexpr (n_bits < sizeof(To) * 8) {  // saturate operation for sub-byte 
type\n        return min(y, To((1 << n_bits) - 1));\n    }\n    else {\n        return y;\n    }\n}\n\ntemplate<int WarpThreadC, class T, int C>\n__device__ inline void warp_minmax(Array<T, 2>& stats, const Array<T, C>& x)\n{\n    PRAGMA_UNROLL\n    for (int i = 0; i < C; ++i) {\n        stats[0] = Min(stats[0], x[i]);\n        stats[1] = Max(stats[1], x[i]);\n    }\n    if constexpr (sizeof(T) == 2) {\n        PRAGMA_UNROLL\n        for (int mask = WarpThreadC / 2; mask > 0; mask /= 2) {\n            Array<T, 2> tmp;\n            (uint32_t&)tmp = __shfl_xor_sync(uint32_t(-1), (uint32_t&)stats, mask);\n            stats[0]       = Min(stats[0], tmp[0]);\n            stats[1]       = Max(stats[1], tmp[1]);\n        }\n    }\n    else {\n        PRAGMA_UNROLL\n        for (int mask = WarpThreadC / 2; mask > 0; mask /= 2) {\n            stats[0] = Min(stats[0], __shfl_xor_sync(uint32_t(-1), stats[0], mask));\n            stats[1] = Max(stats[1], __shfl_xor_sync(uint32_t(-1), stats[1], mask));\n        }\n    }\n}\n\ntemplate<int WarpThreadC, class P, class T, class B, int N, int C, int S>\n__device__ void warp_stats(Array<P, 2> (&param)[S], const Array<T, N> (&x)[S][C], B n_bits)\n{\n    PRAGMA_UNROLL\n    for (int s = 0; s < S; ++s) {\n        Array<T, 2> stats{Infinity<T>(), -Infinity<T>()};\n        PRAGMA_UNROLL\n        for (int c = 0; c < C; ++c) {\n            warp_minmax<WarpThreadC>(stats, x[s][c]);\n        }\n        const float inv_q_max = fdividef(1.f, float((1 << n_bits) - 1));\n        const float scale     = ((float)stats[1] - (float)stats[0]) * inv_q_max;\n        param[s][0]           = (P)scale;\n        param[s][1]           = (P)stats[0];\n\n        if constexpr (kForceIntZeroPoint) {\n#if TM_ROUND_USE_CVT_RNI\n            // rintf -> cvt.rni.f32.f32\n            param[s][1] = (P)(rintf((float)stats[0] / scale) * scale);\n#else\n            // roundf -> cvt.rzi.f32.f32(x + 0.5)\n            param[s][1] = (P)(roundf((float)stats[0] / scale) * 
scale);\n#endif\n        }\n    }\n}\n\ntemplate<class Q, class T, class P, class B, int N, int C, int S>\n__device__ void\nquantize(Array<Q, N> (&dst)[S][C], const Array<T, N> (&src)[S][C], const Array<P, 2> (&params)[S], B n_bits)\n{\n    PRAGMA_UNROLL\n    for (int s = 0; s < S; ++s) {\n        P inv_scale = (P)fdividef(1.f, (float)params[s][0]);\n        P zero      = params[s][1];\n        PRAGMA_UNROLL\n        for (int c = 0; c < C; ++c) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < N; ++i) {\n                const auto v = ((P)src[s][c][i] - zero) * inv_scale;\n                dst[s][c][i] = quant<Q>(v, n_bits);\n            }\n        }\n    }\n}\n\n//////////////////////////////////////////////////////////////////////////////////////////////////\n\n// generic case for floating point -> floating point / integer -> integer conversion\ntemplate<typename Ti, typename To, typename = void>\nstruct ConvertKvCache {\n    __device__ __host__ ConvertKvCache(float, float) {}\n    template<int N>\n    __device__ static auto convert(const Array<Ti, N>& vi)\n    {\n        Array<To, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            vo[i] = (To)vi[i];\n        }\n        return vo;\n    }\n    template<int N>\n    inline __device__ auto operator()(const Array<Ti, N>& vi) const -> Array<To, N>\n    {\n        return convert(vi);\n    }\n};\n\n// generic case for converting to same type, bypass\ntemplate<typename T>\nstruct ConvertKvCache<T, T> {\n    __device__ __host__ ConvertKvCache(float, float) {}\n    template<int N>\n    __device__ static auto convert(const Array<T, N>& v)\n    {\n        return v;\n    }\n    template<int N>\n    inline __device__ auto operator()(const Array<T, N>& v) const -> Array<T, N>\n    {\n        return convert(v);\n    }\n};\n\n//  floating point -> u8\ntemplate<class T>\nstruct ConvertKvCache<T, uint8_t> {\n    T          inv_scale_;\n    T          zero_;\n    __device__ ConvertKvCache(T 
scale, T zero): zero_{zero}\n    {\n        // NVCC complains if we put this in the member init list\n        inv_scale_ = (T)fdividef(1.f, (float)scale);\n    }\n\n    template<int N>\n    __device__ auto operator()(const Array<T, N>& vi) const\n    {\n        Array<uint8_t, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            vo[i] = quant<uint8_t>((vi[i] - zero_) * inv_scale_, std::integral_constant<int, 8>{});\n        }\n        return vo;\n    }\n};\n\ntemplate<class T>\nstruct ConvertKvCache<T, uint4_t> {\n    T          inv_scale_;\n    T          zero_;\n    __device__ ConvertKvCache(T scale, T zero): zero_{zero}\n    {\n        // NVCC complains if we put this in the member init list\n        inv_scale_ = (T)fdividef(1.f, (float)scale);\n    }\n\n    static __device__ Array<uint4_t, 8> pack(const Array<uint8_t, 8>& vi)\n    {\n        Array<uint32_t, 2> ui = (Array<uint32_t, 2>&)vi;\n\n        ui[0] |= (ui[0] >> 12);\n        ui[1] |= (ui[1] >> 12);\n\n        //  7 6 5 4 3 2 1 0\n        // _7_67564_3_23120\n        uint32_t uo = __byte_perm(ui[0], ui[1], 0x5140);\n\n        return (Array<uint4_t, 8>&)uo;\n    }\n\n    /// TODO: try cvt.pack.sat.u4\n    template<int N>\n    __device__ auto operator()(const Array<T, N>& vi) const\n    {\n        static_assert(N % 8 == 0);\n        Array<uint8_t, N> tmp;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            tmp[i] = quant<uint8_t>((vi[i] - zero_) * inv_scale_, std::integral_constant<int, 4>{});\n        }\n        Array<uint4_t, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 8) {\n            (Array<uint4_t, 8>&)vo[i] = pack((Array<uint8_t, 8>&)tmp[i]);\n        }\n        return vo;\n    }\n};\ntemplate<>\nstruct ConvertKvCache<uint4_t, half> {\n\n    half scale_;\n    half zero_;\n\n    __device__ ConvertKvCache(half scale, half zero)\n    {\n        scale_ = scale;\n        zero_  = zero;\n    }\n\n    static __device__ Array<half, 8> 
cvt_f16x8_u4(const Array<uint4_t, 8>& vi)\n    {\n        Array<half, 8>            result;\n        uint32_t*                 h           = reinterpret_cast<uint32_t*>(&result);\n        uint32_t const&           i4s         = reinterpret_cast<uint32_t const&>(vi);\n        static constexpr uint32_t immLut      = (0xf0 & 0xcc) | 0xaa;\n        static constexpr uint32_t BOT_MASK    = 0x000f000f;\n        static constexpr uint32_t TOP_MASK    = 0x00f000f0;\n        static constexpr uint32_t MAGIC_NUM_0 = 0x64006400;  // `1024`\n        static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;  // `64`\n        // const uint32_t            top_i4s     = i4s >> 8;\n        uint32_t top_i4s = __byte_perm(i4s, 0, 0x4321);\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[0]) : \"r\"(i4s), \"n\"(BOT_MASK), \"n\"(MAGIC_NUM_0), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[1]) : \"r\"(i4s), \"n\"(TOP_MASK), \"n\"(MAGIC_NUM_1), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[2]) : \"r\"(top_i4s), \"n\"(BOT_MASK), \"n\"(MAGIC_NUM_0), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[3]) : \"r\"(top_i4s), \"n\"(TOP_MASK), \"n\"(MAGIC_NUM_1), \"n\"(immLut));\n        asm(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[0]) : \"r\"(h[0]), \"r\"(MAGIC_NUM_0));\n        asm(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[1]) : \"r\"(h[1]), \"r\"(MAGIC_NUM_1));\n        asm(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[2]) : \"r\"(h[2]), \"r\"(MAGIC_NUM_0));\n        asm(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[3]) : \"r\"(h[3]), \"r\"(MAGIC_NUM_1));\n        return result;\n    }\n\n    static __device__ Array<half, 8> cvt_f16x8_u4_biased(const Array<uint4_t, 8>& vi)\n    {\n        Array<half, 8>            result;\n        uint32_t*                 h           = reinterpret_cast<uint32_t*>(&result);\n        uint32_t const&           i4s         = reinterpret_cast<uint32_t const&>(vi);\n        static constexpr 
uint32_t immLut      = (0xf0 & 0xcc) | 0xaa;\n        static constexpr uint32_t BOT_MASK    = 0x000f000f;\n        static constexpr uint32_t TOP_MASK    = 0x00f000f0;\n        static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;        // `64`\n        static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4;  // `64` >> 4\n        const uint32_t            top_i4s     = i4s >> 8;\n        // uint32_t top_i4s = __byte_perm(i4s, 0, 0x4321);\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[0]) : \"r\"(i4s), \"n\"(BOT_MASK), \"n\"(MAGIC_NUM_2), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[1]) : \"r\"(i4s), \"n\"(TOP_MASK), \"n\"(MAGIC_NUM_1), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[2]) : \"r\"(top_i4s), \"n\"(BOT_MASK), \"n\"(MAGIC_NUM_2), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[3]) : \"r\"(top_i4s), \"n\"(TOP_MASK), \"n\"(MAGIC_NUM_1), \"n\"(immLut));\n        h[0] <<= 4;\n        h[2] <<= 4;\n        return result;\n    }\n\n    template<int N>\n    __device__ static auto convert(const Array<uint4_t, N>& vi)\n    {\n        Array<half, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 8) {\n            auto& v = (Array<half, 8>&)vo[i];\n            if constexpr (kFuseU4F16Dequant) {\n                v = cvt_f16x8_u4_biased((Array<uint4_t, 8>&)vi[i]);\n            }\n            else {\n                v = cvt_f16x8_u4((Array<uint4_t, 8>&)vi[i]);\n            }\n        }\n        return vo;\n    }\n\n    template<int N>\n    __device__ auto operator()(const Array<uint4_t, N>& vi) const\n    {\n        auto vo = convert(vi);\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            vo[i] = vo[i] * scale_ + zero_;\n        }\n        return vo;\n    }\n};\n\ntemplate<>\nstruct ConvertKvCache<uint4_t, nv_bfloat16> {\n\n    nv_bfloat16 scale_;\n    nv_bfloat16 zero_;\n\n    __device__ ConvertKvCache(nv_bfloat16 scale, 
nv_bfloat16 zero)\n    {\n        scale_ = scale;\n        zero_  = zero;\n    }\n\n    template<int N>\n    __device__ static Array<nv_bfloat16, N> convert(const Array<uint4_t, N>& vi)\n    {\n        Array<nv_bfloat16, N> vo{};\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 8) {\n            auto& v = (Array<short, 8>&)vo[i];\n            auto  u = cvt_bf16x8_u4((Array<uint4_t, 8>&)vi[i]);\n            v       = (Array<short, 8>&)u;\n        }\n        return vo;\n    }\n\n    template<int N>\n    __device__ Array<nv_bfloat16, N> operator()(const Array<uint4_t, N>& vi) const\n    {\n        auto vo = convert(vi);\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            vo[i] = vo[i] * scale_ + zero_;\n        }\n        return (Array<nv_bfloat16, N>&)vo;\n    }\n};\n\ntemplate<>\nstruct ConvertKvCache<uint4_t, float> {\n\n#if 1\n    ConvertKvCache<uint4_t, half> impl_;\n\n    __device__ ConvertKvCache(float scale, float zero): impl_{scale, zero} {}\n\n    template<int N>\n    __device__ auto operator()(const Array<uint4_t, N>& vi) const\n    {\n        return cast<float>(impl_(vi));\n    }\n#else\n    static __device__ Array<half, 8> cvt_f16x8_u4_biased(const Array<uint4_t, 8>& vi)\n    {\n        Array<half, 8> result;\n        uint32_t* h = reinterpret_cast<uint32_t*>(&result);\n        uint32_t const& i4s = reinterpret_cast<uint32_t const&>(vi);\n        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;\n        static constexpr uint32_t BOT_MASK = 0x000f000f;\n        static constexpr uint32_t TOP_MASK = 0x00f000f0;\n        static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;        // `64`\n        static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4;  // `64` >> 4\n        const uint32_t top_i4s = i4s >> 8;\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[0]) : \"r\"(i4s), \"n\"(BOT_MASK), \"n\"(MAGIC_NUM_2), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[1]) : \"r\"(i4s), 
\"n\"(TOP_MASK), \"n\"(MAGIC_NUM_1), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[2]) : \"r\"(top_i4s), \"n\"(BOT_MASK), \"n\"(MAGIC_NUM_2), \"n\"(immLut));\n        asm(\"lop3.b32 %0, %1, %2, %3, %4;\\n\" : \"=r\"(h[3]) : \"r\"(top_i4s), \"n\"(TOP_MASK), \"n\"(MAGIC_NUM_1), \"n\"(immLut));\n        h[0] <<= 4;\n        h[2] <<= 4;\n        return result;\n    }\n    float scale_;\n    float zero_;\n    __device__ ConvertKvCache(float scale, float zero)\n    {\n        scale_ = scale;\n        zero_ = zero - scale * 64.f;\n    }\n    template<int N>\n    __device__ auto operator()(const Array<uint4_t, N>& vi) const\n    {\n        auto vo = cast<float>(cvt_f16x8_u4_biased(vi));\n        using namespace ops;\n        return vo * scale_ + zero_;\n    }\n#endif\n};\n\n// u8 -> f32/f16/bf16\ntemplate<class T>\nstruct ConvertKvCache<uint8_t, T> {\n    T          scale_;\n    T          zero_;\n    __device__ ConvertKvCache(T scale, T zero): scale_{scale}, zero_{zero} {}\n\n    template<int N>\n    __device__ static auto convert(const Array<uint8_t, N>& vi)\n    {\n        Array<T, N> vo;\n        PRAGMA_UNROLL\n        for (int n = 0; n < N; n += 4) {\n            auto& ui = (const Array<uint8_t, 4>&)vi[n];\n            auto& uo = (Array<T, 4>&)vo[n];\n\n            if constexpr (std::is_same_v<T, half>) {\n                uo = cvt_f16x4_u8<true>(ui);\n            }\n            else if constexpr (std::is_same_v<T, float>) {\n                uo = cvt_f32x4_u8(ui);\n            }\n#if __CUDA_ARCH__ >= 800\n            else if constexpr (std::is_same_v<T, nv_bfloat16>) {\n                uo = cvt_bf16x4_u8(ui);\n            }\n#endif\n        }\n        return vo;\n    }\n\n    template<int N>\n    __device__ auto operator()(const Array<uint8_t, N>& vi) const\n    {\n        auto vo = convert(vi);\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            vo[i] = vo[i] * scale_ + zero_;\n        }\n        return vo;\n    
}\n};\n\ntemplate<class T>\nstruct ConvertKvCache<fp4_e2m1_t, T> {\n\n    __device__ static Array<bfloat16_t, 8> cvt_bf16x8_e2m1(const Array<fp4_e2m1_t, 8>& vi)\n    {\n        const uint32_t& x = (const uint32_t&)vi;\n\n        constexpr uint32_t S  = 0x80008000U;\n        constexpr uint32_t EM = 0x01C001C0U;\n\n        Array<uint32_t, 4> vo;\n\n        // clang-format off\n        vo[0] = (x << 12 & S) | (x << 6 & EM);\n        vo[1] = (x <<  8 & S) | (x << 2 & EM);\n        vo[2] = (x <<  4 & S) | (x >> 2 & EM);\n        vo[3] = (x <<  0 & S) | (x >> 6 & EM);\n        // clang-format on\n\n        constexpr uint32_t e  = (127U - 1U + 127U) << 7U;\n        constexpr uint32_t ee = e << 16U | e;\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < 4; ++i) {\n#if TURBOMIND_ARCH_SM90\n            asm(\"mul.rn.bf16x2 %0, %1, %2;\" : \"=r\"(vo[i]) : \"r\"(vo[i]), \"r\"(ee));\n#else\n            asm(\"fma.rn.bf16x2 %0, %1, %2, %3;\" : \"=r\"(vo[i]) : \"r\"(vo[i]), \"r\"(ee), \"r\"(0));\n#endif\n        }\n\n        return (Array<bfloat16_t, 8>&)vo;\n    }\n\n    __device__ static Array<half, 8> cvt_f16x8_e2m1(const Array<fp4_e2m1_t, 8>& vi)\n    {\n        const uint32_t& x = (const uint32_t&)vi;\n\n        constexpr uint32_t S  = 0x80008000U;\n        constexpr uint32_t EM = 0x0E000E00U;\n\n        Array<uint32_t, 4> vo;\n\n        // clang-format off\n        vo[0] = (x << 12 & S) | (x << 9 & EM);\n        vo[1] = (x <<  8 & S) | (x << 5 & EM);\n        vo[2] = (x <<  4 & S) | (x << 1 & EM);\n        vo[3] = (x <<  0 & S) | (x >> 3 & EM);\n        // clang-format on\n\n        constexpr uint32_t e  = (15U - 1U + 15U) << 10U;\n        constexpr uint32_t ee = e << 16U | e;\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < 4; ++i) {\n            asm volatile(\"mul.f16x2 %0, %1, %2;\" : \"=r\"(vo[i]) : \"r\"(vo[i]), \"r\"(ee));\n        }\n\n        return (Array<half, 8>&)vo;\n    }\n\n    template<int N>\n    __device__ static auto convert(const Array<fp4_e2m1_t, 
N>& vi)\n    {\n        Array<T, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 8) {\n            auto& v = (Array<T, 8>&)vo[i];\n            if constexpr (std::is_same_v<T, bfloat16_t>) {\n                v = cvt_bf16x8_e2m1((Array<fp4_e2m1_t, 8>&)vi[i]);\n            }\n            else if constexpr (std::is_same_v<T, half_t>) {\n                v = cvt_f16x8_e2m1((Array<fp4_e2m1_t, 8>&)vi[i]);\n            }\n            else {\n                static_assert(N != N, \"not implemented\");\n            }\n        }\n        return vo;\n    }\n};\n\n__device__ inline Array<bfloat16_t, 4> cvt_bf16x4_e4m3(const Array<fp8_e4m3_t, 4>& vi)\n{\n    const uint32_t& x = (const uint32_t&)vi;\n\n    //    0   7   C   0\n    // SEEEEEEEEMMMMMMMSEEEEEEEEMMMMMMM\n    // SEEEEMMM        SEEEEMMM\n    //         SEEEEMMM        SEEEEMMM\n\n    constexpr uint32_t S  = 0x80008000U;\n    constexpr uint32_t EM = 0x07F007F0U;\n\n    Array<uint32_t, 2> vo;\n\n    vo[0] = (x << 8 & S) | (x << 4 & EM);\n    vo[1] = (x << 0 & S) | (x >> 4 & EM);\n\n    constexpr uint32_t e  = (127U - 7U + 127U) << 7U;\n    constexpr uint32_t ee = e << 16U | e;\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < 2; ++i) {\n#if TURBOMIND_ARCH_SM90\n        asm(\"mul.rn.bf16x2 %0, %1, %2;\" : \"=r\"(vo[i]) : \"r\"(vo[i]), \"r\"(ee));\n#else\n        asm(\"fma.rn.bf16x2 %0, %1, %2, %3;\" : \"=r\"(vo[i]) : \"r\"(vo[i]), \"r\"(ee), \"r\"(0));\n#endif\n    }\n\n    return (Array<bfloat16_t, 4>&)vo;\n}\n\n__device__ inline Array<half, 4> cvt_f16x4_e4m3(const Array<fp8_e4m3_t, 4>& vi)\n{\n    const uint32_t& x = (const uint32_t&)vi;\n\n    //    3   F   8   0\n    // SEEEEEMMMMMMMMMMSEEEEEMMMMMMMMMM\n    // SEEEEMMM        SEEEEMMM\n    //         SEEEEMMM        SEEEEMMM\n\n    constexpr uint32_t S  = 0x80008000U;\n    constexpr uint32_t EM = 0x3F803F80U;\n\n    Array<uint32_t, 2> vo;\n\n    vo[0] = (x << 8 & S) | (x << 7 & EM);\n    vo[1] = (x << 0 & S) | (x >> 1 & EM);\n\n    constexpr uint32_t e  
= (15U - 7U + 15U) << 10U;\n    constexpr uint32_t ee = e << 16U | e;\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < 2; ++i) {\n        asm(\"mul.rn.f16x2 %0, %1, %2;\" : \"=r\"(vo[i]) : \"r\"(vo[i]), \"r\"(ee));\n    }\n\n    return (Array<half, 4>&)vo;\n}\n\ntemplate<class T>\nstruct ConvertKvCache<fp8_e4m3_t, T> {\n\n    template<int N>\n    __device__ static auto convert(const Array<fp8_e4m3_t, N>& vi)\n    {\n        static_assert(N % 4 == 0);\n        Array<T, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 4) {\n            auto& v = (Array<T, 4>&)vo[i];\n            if constexpr (std::is_same_v<T, bfloat16_t>) {\n                v = cvt_bf16x4_e4m3((Array<fp8_e4m3_t, 4>&)vi[i]);\n            }\n            else if constexpr (std::is_same_v<T, half_t>) {\n                v = cvt_f16x4_e4m3((Array<fp8_e4m3_t, 4>&)vi[i]);\n            }\n            else {\n                static_assert(N != N, \"not implemented\");\n            }\n        }\n        return vo;\n    }\n};\n\ntemplate<class Q, class T>\ninline __device__ void StoreQuantParam(T* dst, Array<T, 2> src)\n{\n    Store(dst, src);\n}\n\ntemplate<>\ninline __device__ void StoreQuantParam<uint4_t, half>(half* dst, Array<half, 2> src)\n{\n    if constexpr (kFuseU4F16Dequant) {\n        src[1] = src[1] - src[0] * __ushort_as_half(0x5400);\n    }\n    Store(dst, src);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/reduce.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"cutlass/fast_math.h\"\n#include \"src/turbomind/kernels/attention/cta_map.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\n#include <type_traits>\n\nnamespace turbomind::attention {\n\ntemplate<int CP, int CTA_K, int HeadDim, int WarpCnt, bool First, class T>\n__global__ void reduce(T*         out,\n                       float*     partial_ML,\n                       float*     partial_O,\n                       const int* split_cnt_,\n                       int        max_split_cnt,\n                       int        query_num,\n                       int        head_num,\n                       float      exp_scale,\n                       int        cp_rank,\n                       int        stride_k,\n                       int        offset_k)\n{\n    __shared__ float s_out[WarpCnt][HeadDim];\n    __shared__ float s_ML[WarpCnt][2];\n    __shared__ float s_scale[CTA_K];\n\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    const int lane_id = threadIdx.x % WARP_SIZE;\n\n    const int head_idx  = ReduceCtaMap::head_idx();\n    const int query_idx = ReduceCtaMap::query_idx();\n    const int chunk_idx = ReduceCtaMap::split_idx();\n\n    offset_k *= chunk_idx;\n    const int split_cnt = (split_cnt_ != nullptr) ? split_cnt_[query_idx] : 1;\n    if (offset_k >= split_cnt) {  // out of bound\n        return;\n    }\n\n    // merge cp and k for the first time and merge k thereafter.\n    constexpr int kCpUb     = First ? CP : 1;\n    constexpr int kWarpIter = First ? (CP + WarpCnt - 1) / WarpCnt : 1;\n    float         ML[kWarpIter][2];\n\n    // frag_M of this cp_rank and lane\n    float frag_M = -std::numeric_limits<float>::infinity();\n\n    const int offset_r = cp_rank * query_num * head_num * max_split_cnt * 2;\n    const int offset_m = First ? 
0 : offset_r;\n    const int warp_m   = First ? cp_rank % WarpCnt : 0;\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < kWarpIter; ++i) {\n        int        cp_i = warp_id + i * WarpCnt;\n        int        ki   = lane_id * stride_k + offset_k;\n        const bool mask = cp_i < kCpUb && ki < split_cnt;  // cp, q, h, k, 2\n        const int  index =\n            offset_m + ((cp_i * query_num * head_num + (query_idx * head_num + head_idx)) * max_split_cnt + ki) * 2;\n\n        Array<float, 2> temp_ML = {-std::numeric_limits<float>::infinity(), 0.f};\n        if (mask) {\n            Load(temp_ML, &partial_ML[index]);\n        }\n        Store(&ML[i][0], temp_ML);\n\n        frag_M = (mask && warp_m == warp_id) ? ML[i][0] : frag_M;\n    }\n\n    float block_M = -std::numeric_limits<float>::infinity();\n    float block_L = 0.f;\n    PRAGMA_UNROLL\n    for (int i = 0; i < kWarpIter; ++i) {\n        block_M = fmaxf(block_M, ML[i][0]);\n    }\n\n    PRAGMA_UNROLL\n    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {\n        block_M = fmaxf(block_M, __shfl_xor_sync(uint32_t(-1), block_M, mask));\n    }\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < kWarpIter; ++i) {\n        block_L += (ML[i][0] == -std::numeric_limits<float>::infinity()) ?\n                       0.0f :\n                       exp2f((ML[i][0] - block_M) * exp_scale) * ML[i][1];\n    }\n\n    PRAGMA_UNROLL\n    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {\n        block_L += __shfl_xor_sync(uint32_t(-1), block_L, mask);\n    }\n\n    if constexpr (First && CP > 1) {\n        if (lane_id == 0) {\n            Store(&s_ML[warp_id][0], Array<float, 2>{block_M, block_L});\n        }\n        __syncthreads();\n\n        if (warp_id == 0 && lane_id == 0) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < WarpCnt; ++i) {\n                block_M = fmaxf(block_M, s_ML[i][0]);\n            }\n\n            block_L = 0.f;\n            PRAGMA_UNROLL\n            for (int i = 0; i < WarpCnt; 
++i) {\n                block_L += exp2f((s_ML[i][0] - block_M) * exp_scale) * s_ML[i][1];\n            }\n\n            Store(&s_ML[0][0], Array<float, 2>{block_M, block_L});\n        }\n        __syncthreads();\n\n        block_M = s_ML[0][0];\n        block_L = s_ML[0][1];\n    }\n\n    if (gridDim.z > 1 && warp_id == 0) {\n        int        ki    = lane_id * stride_k + offset_k;\n        const bool mask  = ki < split_cnt;  // q, h, k, 2\n        const int  index = offset_r + ((query_idx * head_num + head_idx) * max_split_cnt + ki) * 2;\n        if (mask) {\n            Store(&partial_ML[index], Array<float, 2>{block_M, block_L});\n        }\n    }\n\n    if (warp_id == warp_m) {\n        const float divisor = gridDim.z == 1 ? block_L : 1.0f;\n        s_scale[lane_id] =\n            frag_M == -std::numeric_limits<float>::infinity() ? 0.0f : exp2f((frag_M - block_M) * exp_scale) / divisor;\n    }\n\n    __syncthreads();\n\n    // HeadDim / WARP_SIZE\n    // 128     -> 4\n    // 64, 192 -> 2\n    constexpr int kVecSize = HeadDim % 128 == 0 ? 
4 : 2;\n\n    using Map = RakedThreadMap<HeadDim, WarpCnt, kVecSize, WarpCnt, WARP_SIZE>;\n    static_assert(Map::kIterS == 1);\n\n    constexpr int C = Map::kIterC;\n\n    using Vec = Array<float, kVecSize>;\n\n    Vec accu_O[C]{};\n    Vec frag_O[C];\n\n    const int2 d = Map::get_offset(warp_id, lane_id);\n\n    auto for_each = [&](auto fn) {\n        const int ki = d.y;\n        PRAGMA_UNROLL\n        for (int c = 0; c < C; ++c) {\n            const int di = d.x + c * Map::kDeltaC;\n            fn(c, ki, di);\n        }\n    };\n\n    PRAGMA_UNROLL\n    for (int k = 0; k < CTA_K; k += WarpCnt) {\n        for_each([&](int c, int ki, int di) {\n            using namespace ops;\n            ki += k;\n            const int  split_idx = offset_k + stride_k * ki;\n            const bool mask      = split_idx < split_cnt;\n            const int  index     = (query_idx * head_num + head_idx) * max_split_cnt + split_idx;\n            const int  offset    = index * HeadDim + di;\n            if (mask) {\n                Load(frag_O[c], &partial_O[offset]);\n                accu_O[c] = accu_O[c] + frag_O[c] * s_scale[ki];\n            }\n        });\n    }\n\n    for_each([&](int c, int ki, int di) {\n        Store(&s_out[ki][di], accu_O[c]);  //\n    });\n\n    PRAGMA_UNROLL\n    for (int w = WarpCnt / 2; w > 0; w /= 2) {\n        __syncthreads();\n        for_each([&](int c, int ki, int di) {\n            using namespace ops;\n            if (ki < w) {\n                (Vec&)s_out[ki][di] = (Vec&)s_out[ki][di] + (Vec&)s_out[w + ki][di];\n            }\n        });\n    }\n\n    for_each([&](int c, int ki, int di) {\n        if (ki == 0) {\n            if (gridDim.z == 1) {\n                const int offset = (query_idx * head_num + head_idx) * HeadDim + di;\n                Store(&out[offset], cast<T>((Vec&)s_out[ki][di]));\n            }\n            else {\n                const int offset = ((query_idx * head_num + head_idx) * max_split_cnt + offset_k) * HeadDim + 
di;\n                Store(&partial_O[offset], (Vec&)s_out[ki][di]);\n            }\n        }\n    });\n}\n\ntemplate<int HeadDim, class T>\nvoid invokeReduceV3(T*           out,\n                    float*       partial_ML,\n                    float*       partial_O,\n                    const int*   split_cnt,\n                    int          partial_len,\n                    int          max_split_cnt,\n                    int          cp_size,\n                    int          cp_rank,\n                    int          query_num,\n                    int          head_num,\n                    float        exp_scale,\n                    cudaStream_t stream)\n{\n    constexpr int CTA_K = 32;  // warp size\n\n    constexpr int    kWarpCnt  = 4;\n    constexpr size_t kSmemSize = sizeof(float) * (kWarpCnt * HeadDim + kWarpCnt * 2 + CTA_K);\n    static_assert(kSmemSize < (48 << 10), \"shared memory usage exceeds 48KB per block\");\n\n    partial_ML -= cp_rank * query_num * head_num * partial_len * 2;  // begin address of cp_rank0\n\n    auto invoke = [&](auto cp, auto is_first, int stride_k) {\n        const dim3 block = kWarpCnt * WARP_SIZE;\n        const dim3 grid  = ReduceCtaMap::get_grid_shape(query_num, head_num, max_split_cnt, CTA_K);\n\n        reduce<cp, CTA_K, HeadDim, kWarpCnt, is_first><<<grid, block, kSmemSize, stream>>>(  //\n            out,\n            partial_ML,\n            partial_O,\n            split_cnt,\n            partial_len,\n            query_num,\n            head_num,\n            exp_scale,\n            cp_rank,\n            stride_k,\n            stride_k * CTA_K);\n\n        sync_check_cuda_error();\n    };\n\n    auto dispatch_cp = [&](int stride_k, auto is_first) {\n        switch (cp_size) {\n#define LAUNCH_INVOKE(n)                                                                                               \\\n    case n:                                                                                                       
     \\\n        invoke(std::integral_constant<int, n>{}, is_first, stride_k);                                                  \\\n        break;\n            LAUNCH_INVOKE(1);\n            LAUNCH_INVOKE(2);\n            LAUNCH_INVOKE(4);\n            LAUNCH_INVOKE(8);\n            LAUNCH_INVOKE(16);\n            LAUNCH_INVOKE(32);\n            default:\n                TM_CHECK(false) << \"reduce does not support cp_size = \" << cp_size;\n#undef LAUNCH_INVOKE\n        }\n    };\n\n    int stride_k = 1;\n\n    dispatch_cp(stride_k, std::true_type{});\n    while (max_split_cnt > CTA_K) {\n        max_split_cnt = (max_split_cnt + CTA_K - 1) / CTA_K;\n        stride_k *= CTA_K;\n        dispatch_cp(stride_k, std::false_type{});\n    }\n}\n\n#define INSTANTIATE_invokeReduceV3(dim, type)                                                                          \\\n    template void invokeReduceV3<dim>(type * out,                                                                      \\\n                                      float*       partial_ML,                                                         \\\n                                      float*       partial_O,                                                          \\\n                                      const int*   split_cnt,                                                          \\\n                                      int          partial_len,                                                        \\\n                                      int          max_split_cnt,                                                      \\\n                                      int          cp_size,                                                            \\\n                                      int          cp_rank,                                                            \\\n                                      int          query_num,                                                          \\\n                     
                 int          head_num,                                                           \\\n                                      float        exp_scale,                                                          \\\n                                      cudaStream_t stream);\n\nINSTANTIATE_invokeReduceV3(64, half);\nINSTANTIATE_invokeReduceV3(128, half);\nINSTANTIATE_invokeReduceV3(192, half);\nINSTANTIATE_invokeReduceV3(256, half);\nINSTANTIATE_invokeReduceV3(576, half);\n\n#if ENABLE_BF16\nINSTANTIATE_invokeReduceV3(64, nv_bfloat16);\nINSTANTIATE_invokeReduceV3(128, nv_bfloat16);\nINSTANTIATE_invokeReduceV3(192, nv_bfloat16);\nINSTANTIATE_invokeReduceV3(256, nv_bfloat16);\nINSTANTIATE_invokeReduceV3(576, nv_bfloat16);\n#endif\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/reduce.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"cta_map.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/thread_map.h\"\n#include <cstddef>\n#include <cuda_runtime.h>\n#include <type_traits>\n\nnamespace turbomind::attention {\n\ntemplate<int HeadDim, class T>\nvoid invokeReduceV3(T*           out,\n                    float*       partial_ML,\n                    float*       partial_O,\n                    const int*   split_cnt,\n                    int          partial_len,\n                    int          max_split_cnt,\n                    int          cp_size,\n                    int          cp_rank,\n                    int          query_num,\n                    int          head_num,\n                    float        exp_scale,\n                    cudaStream_t stream);\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/reference.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"reference.h\"\n#include \"src/turbomind/kernels/attention/rotary_embedding.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/unfused_attention_kernels.h\"\n\nnamespace turbomind {\n\ntemplate<class T>\n__global__ void\ncreateCausalMasks(T* mask, const int* q_lens, const int* k_lens, int64_t max_q_len, int64_t max_k_len, int window_size)\n{\n    const int     bi      = blockIdx.x;\n    const int64_t q_len   = q_lens ? q_lens[bi] : max_q_len;\n    const int64_t k_len   = k_lens ? k_lens[bi] : max_k_len;\n    const int     history = k_len - q_len;\n    mask += bi * max_q_len * max_k_len;\n    for (int64_t i = threadIdx.x; i < max_q_len * max_k_len; i += blockDim.x) {\n        const int q = i / max_k_len;\n        const int k = i % max_k_len;\n        const int w = q - (k - history);\n\n        const bool is_valid = q < q_len && k < k_len && 0 <= w && w < window_size;\n\n        mask[i] = is_valid ? 
T{1.} : T{0.};\n    }\n}\n\n// [B, H, S, D]\ntemplate<class T>\n__global__ void\napplyRotaryEmbedding(T* k_cache, int max_k_len, int head_num, int head_dim, float rope_base, int rope_dim)\n{\n    const int    ti = blockIdx.x;\n    const size_t hi = blockIdx.y;\n    const size_t bi = blockIdx.z;\n\n    constexpr int kVecSize = 2;\n    const int     history  = 0;\n\n    for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {\n        const size_t idx =\n            bi * head_num * max_k_len * head_dim + hi * max_k_len * head_dim + (history + ti) * head_dim + d;\n\n        Array<T, kVecSize> vec_K;\n\n        Load(vec_K, &k_cache[idx]);\n\n        RotaryEmbedding<kVecSize> rope(rope_base, rope_dim, history + ti, {d, 0});\n\n        rope.apply(vec_K);\n\n        Store(&k_cache[idx], vec_K);\n    }\n}\n\ntemplate<class T>\nvoid invokeApplyRotaryEmbedding(T*           k_cache,\n                                int          max_k_len,\n                                int          head_num,\n                                int          head_dim,\n                                float        rope_base,\n                                int          rope_dim,\n                                int          batch_size,\n                                cudaStream_t stream)\n{\n    int  threads = 128;\n    dim3 blocks(max_k_len, head_num, batch_size);\n\n    applyRotaryEmbedding<<<blocks, threads, 0, stream>>>(k_cache, max_k_len, head_num, head_dim, rope_base, rope_dim);\n}\n\ntemplate void invokeApplyRotaryEmbedding(half*        k_cache,\n                                         int          max_k_len,\n                                         int          head_num,\n                                         int          head_dim,\n                                         float        rope_base,\n                                         int          rope_dim,\n                                         int          batch_size,\n                                
         cudaStream_t stream);\n#if ENABLE_BF16\ntemplate void invokeApplyRotaryEmbedding(nv_bfloat16* k_cache,\n                                         int          max_k_len,\n                                         int          head_num,\n                                         int          head_dim,\n                                         float        rope_base,\n                                         int          rope_dim,\n                                         int          batch_size,\n                                         cudaStream_t stream);\n#endif\n\ntemplate<class T>\n__global__ void processQKV(T*       q_out,     // [B, H, s, D]\n                           T*       k_cache,   // [B, H, S, D]\n                           T*       v_cache,   // [B, H, S, D]\n                           const T* qkv,       // [B, s, H, D]\n                           const T* qkv_bias,  // [Q; K; V]\n                           int      max_q_len,\n                           int      max_k_len,\n                           int      head_num,\n                           int      head_dim,\n                           int      kv_head_num,\n                           float    rope_theta,\n                           int      rope_dim)\n{\n    const int    ti = blockIdx.x;\n    const size_t hi = blockIdx.y;\n    const size_t bi = blockIdx.z;\n\n    const int history = max_k_len - max_q_len;\n\n    size_t qkv_head_num = head_num + 2 * kv_head_num;\n\n    auto q = qkv + (bi * max_q_len + ti) * qkv_head_num * head_dim;\n    auto k = q + head_num * head_dim;\n    auto v = k + kv_head_num * head_dim;\n\n    auto q_bias = qkv_bias ? qkv_bias + hi * head_dim : nullptr;\n    auto k_bias = qkv_bias ? q_bias + head_num * head_dim : nullptr;\n    auto v_bias = qkv_bias ? 
k_bias + kv_head_num * head_dim : nullptr;\n\n    constexpr int kVecSize = 2;\n\n    using namespace ops;\n\n    for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {\n        const auto         idx = bi * head_num * max_q_len * head_dim + hi * max_q_len * head_dim + ti * head_dim + d;\n        Array<T, kVecSize> vec;\n        Ldg(vec, &q[hi * head_dim + d]);\n        if (qkv_bias) {\n            Array<T, kVecSize> bias;\n            Load(bias, &q_bias[d]);\n            vec = vec + bias;\n        }\n        if (rope_theta) {\n            RotaryEmbedding<kVecSize> rope(rope_theta, rope_dim, history + ti, {d, 0});\n            rope.apply(vec);\n        }\n\n        Store(&q_out[idx], vec);\n    }\n\n    if (hi >= kv_head_num) {\n        return;\n    }\n\n    for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {\n        const auto idx =\n            bi * kv_head_num * max_k_len * head_dim + hi * max_k_len * head_dim + (history + ti) * head_dim + d;\n        Array<T, kVecSize> vec_K;\n        Array<T, kVecSize> vec_V;\n        Ldg(vec_K, &k[hi * head_dim + d]);\n        Ldg(vec_V, &v[hi * head_dim + d]);\n        if (qkv_bias) {\n            Array<T, kVecSize> bias_K;\n            Array<T, kVecSize> bias_V;\n            Load(bias_K, &k_bias[d]);\n            Load(bias_V, &v_bias[d]);\n            vec_K = vec_K + bias_K;\n            vec_V = vec_V + bias_V;\n        }\n        if (rope_theta) {\n            RotaryEmbedding<kVecSize> rope(rope_theta, rope_dim, history + ti, {d, 0});\n            rope.apply(vec_K);\n        }\n        Store(&k_cache[idx], vec_K);\n        Store(&v_cache[idx], vec_V);\n    }\n}\n\ntemplate<class T>\n__global__ void RepeatKVKernel(T*       keys,\n                               T*       vals,\n                               const T* k_cache,\n                               const T* v_cache,\n                               int      head_num,\n                               int      
max_k_len,\n                               int      head_dim,\n                               int      kv_head_num,\n                               int      n_reps)\n{\n    const int64_t ti = blockIdx.x;\n    const int64_t hi = blockIdx.y;\n    const int64_t bi = blockIdx.z;\n\n    const auto khi = hi / n_reps;\n\n    // clang-format off\n    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {\n        int64_t d_idx = bi *    head_num * max_k_len * head_dim +  hi * max_k_len * head_dim + ti * head_dim + d;\n        int64_t s_idx = bi * kv_head_num * max_k_len * head_dim + khi * max_k_len * head_dim + ti * head_dim + d;\n        keys[d_idx] = k_cache[s_idx];\n        vals[d_idx] = v_cache[s_idx];\n    }\n    // clang-format on\n}\n\ntemplate<class T>\nReference<T>::Reference(cudaStream_t stream): stream_(stream)\n{\n    cublasCreate(&cublas_);\n    cublasSetStream(cublas_, stream);\n}\n\ntemplate<class T>\nvoid Reference<T>::Reshape(size_t max_q_len,\n                           size_t max_k_len,\n                           size_t head_num,\n                           size_t head_dim,\n                           size_t kv_head_num,\n                           size_t batch_size,\n                           int    window_size)\n{\n    std::cout << max_q_len << \" \" << max_k_len << \" \" << head_num << \" \" << head_dim << \" \" << batch_size << \"\\n\";\n\n    q_.resize(batch_size * head_num * max_q_len * head_dim);\n    mask_.resize(batch_size * max_q_len * max_k_len);\n\n    std::cout << \"size of QK buf: \"\n              << ((batch_size * head_num * max_q_len * max_k_len * sizeof(float)) / float(1 << 30)) << \" GB\\n\";\n    qk_.resize(batch_size * head_num * max_q_len * max_k_len);\n    pr_.resize(batch_size * head_num * max_q_len * max_k_len);\n    out_.resize(batch_size * max_q_len * head_num * head_dim);\n\n    keys_.resize(batch_size * head_num * max_k_len * head_dim);\n    vals_.resize(batch_size * head_num * max_k_len * head_dim);\n\n    
cudaStreamSynchronize(0);\n\n    createCausalMasks<<<batch_size, 512, 0, stream_>>>(\n        mask_.data().get(), nullptr, nullptr, max_q_len, max_k_len, window_size);\n\n    max_q_len_   = max_q_len;\n    max_k_len_   = max_k_len;\n    head_num_    = head_num;\n    head_dim_    = head_dim;\n    kv_head_num_ = kv_head_num;\n    batch_size_  = batch_size;\n    window_size_ = window_size;\n}\n\ntemplate<class T>\nvoid Reference<T>::Execute(\n    T* output, T* k_cache, T* v_cache, const T* qkv, const T* qkv_bias, const T* sinks, float rope_base, int rope_dim)\n{\n    {\n        int  threads = 128;\n        dim3 blocks(max_q_len_, head_num_, batch_size_);\n        cudaDeviceSynchronize();\n\n        processQKV<<<blocks, threads, 0, stream_>>>(q_.data().get(),  //\n                                                    k_cache,\n                                                    v_cache,\n                                                    qkv,\n                                                    qkv_bias,\n                                                    max_q_len_,\n                                                    max_k_len_,\n                                                    head_num_,\n                                                    head_dim_,\n                                                    kv_head_num_,\n                                                    rope_base,\n                                                    rope_dim);\n\n        // std::cout << head_num_ << \" \" << kv_head_num_ << \" \" << head_dim_ / kv_head_num_ << \"\\n\";\n\n        blocks.x = max_k_len_;\n        RepeatKVKernel<<<blocks, threads, 0, stream_>>>(keys_.data().get(),\n                                                        vals_.data().get(),\n                                                        k_cache,\n                                                        v_cache,\n                                                        head_num_,\n                                
                        max_k_len_,\n                                                        head_dim_,\n                                                        kv_head_num_,\n                                                        head_num_ / kv_head_num_);\n\n        cudaDeviceSynchronize();\n    }\n\n    const cudaDataType data_type = std::is_same_v<T, half> ? CUDA_R_16F : CUDA_R_16BF;\n\n    float alpha = 1.f / sqrtf((float)head_dim_);\n    float beta  = 0.f;\n    cublasGemmStridedBatchedEx(cublas_,\n                               CUBLAS_OP_T,              // trans A\n                               CUBLAS_OP_N,              // trans B\n                               max_k_len_,               // m\n                               max_q_len_,               // n\n                               head_dim_,                // k\n                               &alpha,                   // alpha\n                               keys_.data().get(),       // A\n                               data_type,                // A type\n                               head_dim_,                // lda\n                               max_k_len_ * head_dim_,   // strideA\n                               q_.data().get(),          // B\n                               data_type,                // B type\n                               head_dim_,                // ldb\n                               max_q_len_ * head_dim_,   // stride B\n                               &beta,                    // beta\n                               qk_.data().get(),         // C\n                               CUDA_R_32F,               // C type\n                               max_k_len_,               // ldc\n                               max_q_len_ * max_k_len_,  // stride C\n                               batch_size_ * head_num_,  // batch count\n                               CUBLAS_COMPUTE_32F,       // compute type\n                               CUBLAS_GEMM_DEFAULT);\n\n    MaskedSoftmaxParam<T> 
params{};\n    params.attention_score = pr_.data().get();\n    params.qk              = qk_.data().get();\n    params.attention_mask  = mask_.data().get();\n    params.batch_size      = batch_size_;\n    params.q_length        = max_q_len_;\n    params.k_length        = max_k_len_;\n    params.num_heads       = head_num_;\n    params.sinks           = sinks;\n    invokeMaskedSoftmax(params, stream_);\n\n    alpha = 1.f;\n    cublasGemmStridedBatchedEx(cublas_,\n                               CUBLAS_OP_N,              // trans A\n                               CUBLAS_OP_N,              // trans B\n                               head_dim_,                // m\n                               max_q_len_,               // n\n                               max_k_len_,               // k\n                               &alpha,                   // alpha\n                               vals_.data().get(),       // A\n                               data_type,                // A type\n                               head_dim_,                // lda\n                               max_k_len_ * head_dim_,   // strideA\n                               pr_.data().get(),         // B\n                               data_type,                // B type\n                               max_k_len_,               // ldb\n                               max_q_len_ * max_k_len_,  // stride B\n                               &beta,                    // beta\n                               out_.data().get(),        // C [b, h, q, d]\n                               data_type,                // C type\n                               head_dim_,                // ldc\n                               max_q_len_ * head_dim_,   // stride C\n                               batch_size_ * head_num_,  // batch count\n                               CUBLAS_COMPUTE_32F,       // compute type\n                               CUBLAS_GEMM_DEFAULT);\n\n    // [B, H, Q, D] -> [B, Q, H, D]\n    
invokeTransposeAttentionOutRemovePadding(out_.data().get(),\n                                             output,\n                                             batch_size_ * max_q_len_,\n                                             batch_size_,\n                                             max_q_len_,\n                                             head_num_,\n                                             head_dim_,\n                                             nullptr,\n                                             nullptr,\n                                             0,\n                                             stream_);\n}\n\ntemplate class Reference<half>;\n\n#if ENABLE_BF16\ntemplate class Reference<nv_bfloat16>;\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/reference.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n\n#include <thrust/universal_vector.h>\n\n#include \"src/turbomind/kernels/unfused_attention_kernels.h\"\n\nnamespace turbomind {\n\ntemplate<class T>\nvoid invokeApplyRotaryEmbedding(T*           k_cache,\n                                int          max_k_len,\n                                int          head_num,\n                                int          head_dim,\n                                float        rope_base,\n                                int          rope_dim,\n                                int          batch_size,\n                                cudaStream_t stream = {});\n\ntemplate<class T>\nclass Reference {\npublic:\n    explicit Reference(cudaStream_t stream);\n\n    void Reshape(size_t max_q_len,\n                 size_t max_k_len,\n                 size_t head_num,\n                 size_t head_dim,\n                 size_t kv_head_num,\n                 size_t batch_size,\n                 int    window_size);\n\n    void Execute(T*       output,\n                 T*       k_cache,\n                 T*       v_cache,\n                 const T* qkv,\n                 const T* qkv_bias,\n                 const T* sinks,\n                 float    rope_base,\n                 int      rope_dim);\n\n    const float* qk() const\n    {\n        return qk_.data().get();\n    }\n\n    const T* pr() const\n    {\n        return pr_.data().get();\n    }\n\n    const T* mask() const\n    {\n        return mask_.data().get();\n    }\n\nprivate:\n    cudaStream_t                    stream_;\n    cublasHandle_t                  cublas_;\n    thrust::universal_vector<T>     mask_;\n    thrust::universal_vector<float> qk_;\n    thrust::universal_vector<T>     pr_;\n    thrust::universal_vector<T>     q_;\n    thrust::universal_vector<T>     out_;\n\n    thrust::universal_vector<T> keys_;\n    
thrust::universal_vector<T> vals_;\n\n    int max_q_len_{};\n    int max_k_len_{};\n    int head_num_{};\n    int head_dim_{};\n    int kv_head_num_{};\n    int batch_size_{};\n    int window_size_{};\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/registrar.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <functional>\n#include <memory>\n#include <vector>\n\n#include \"src/turbomind/kernels/attention/kernel_impl.h\"\n\nnamespace turbomind::attention {\n\nclass Collector {\npublic:\n    template<class T>\n    void add()\n    {\n        kernels_.emplace_back(std::make_unique<KernelImpl<T>>());\n        // std::cout << \"add kernel: \" << to_string(kernels_.back()->desc()) << std::endl;\n    }\n\n    std::vector<std::unique_ptr<Kernel>> release()\n    {\n        return std::move(kernels_);\n    }\n\nprivate:\n    std::vector<std::unique_ptr<Kernel>> kernels_;\n};\n\nusing RegisterFn = std::function<void(Collector&)>;\n\ninline std::vector<RegisterFn>& gKernelFactories()\n{\n    static std::vector<RegisterFn> v;\n    return v;\n}\n\nstruct Registrar {\n    explicit Registrar(RegisterFn fn)\n    {\n        gKernelFactories().push_back(std::move(fn));\n    }\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/registry.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/attention/registry.h\"\n\n#include <memory>\n#include <mutex>\n#include <tuple>\n#include <vector>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/kernels/attention/arch.h\"\n#include \"src/turbomind/kernels/attention/registrar.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\nnamespace turbomind::attention {\n\nnamespace {\n\nconstexpr float kMaxWasteRatio = 1.f;\n\n}  // namespace\n\nRegistry::Registry(std::shared_ptr<cudaDeviceProp> device_prop):\n    device_prop_{std::move(device_prop)}, arch_{device_prop_->major * 100 + device_prop_->minor * 10}\n{\n    for (auto& register_fn : gKernelFactories()) {\n        Collector collector;\n        register_fn(collector);\n        for (auto& k : collector.release()) {\n            Add(std::move(k));\n        }\n    }\n}\n\nbool Registry::Add(std::unique_ptr<Kernel> kernel)\n{\n    bool is_valid = true;\n\n    if (!arch::is_arch_compatible(kernel->arch(), arch_)) {\n        is_valid = false;\n    }\n\n    if ((int)device_prop_->sharedMemPerBlockOptin < kernel->smem_size()) {\n        is_valid = false;\n    }\n\n    if (is_valid) {\n        ptrs_.push_back(kernels_.emplace_back(std::move(kernel)).get());\n    }\n\n    return is_valid;\n}\n\nconst Kernel* Registry::Find(const AttnDesc& desc) const\n{\n    const int threshold = static_cast<int>(kMaxWasteRatio * desc.query_group_sz);\n\n    const Kernel*             best = nullptr;\n    std::tuple<int, int, int> cost{};\n\n    for (const auto* k : ptrs_) {\n        const auto& d = k->desc();\n        if (d.mode != desc.mode || d.head_dim != desc.head_dim  //\n            || d.data_type != desc.data_type || d.kv_quant != desc.kv_quant) {\n            continue;\n        }\n        if (desc.mode == AttnDesc::kDecoding) {\n            const int ctas  = cdiv(desc.query_group_sz, d.qh);\n            const int waste = d.qh * ctas - desc.query_group_sz;\n\n   
         const auto v = std::make_tuple(waste > threshold, ctas, waste);\n            if (!best || v < cost) {\n                best = k;\n                cost = v;\n            }\n        }\n        else {  // attention, return on first match\n            return k;\n        }\n    }\n    return best;\n}\n\nRegistry& Registry::instance()\n{\n    struct DeviceState {\n        std::unique_ptr<Registry> registry;\n        std::once_flag            flag;\n    };\n\n    static std::vector<std::unique_ptr<DeviceState>> states = [] {\n        int count{};\n        TM_CHECK_EQ(cudaGetDeviceCount(&count), cudaSuccess);\n        std::vector<std::unique_ptr<DeviceState>> vec(count);\n        for (auto& s : vec) {\n            s = std::make_unique<DeviceState>();\n        }\n        return vec;\n    }();\n\n    int device_id{};\n    TM_CHECK_EQ(cudaGetDevice(&device_id), cudaSuccess);\n\n    auto& state = *states.at(device_id);\n\n    std::call_once(state.flag, [&]() {\n        auto prop = std::make_shared<cudaDeviceProp>();\n        TM_CHECK_EQ(cudaGetDeviceProperties(prop.get(), device_id), cudaSuccess);\n        state.registry = std::make_unique<Registry>(std::move(prop));\n    });\n\n    return *TM_CHECK_NOTNULL(state.registry);\n}\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/registry.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <memory>\n#include <vector>\n\n#include \"src/turbomind/kernels/attention/kernel_impl.h\"\n\nnamespace turbomind::attention {\n\nclass Registry {\npublic:\n    explicit Registry(std::shared_ptr<cudaDeviceProp> device_prop);\n\n    template<class KernelType>\n    [[maybe_unused]] bool Add()\n    {\n        return Add(std::make_unique<KernelImpl<KernelType>>());\n    }\n\n    const Kernel* Find(const AttnDesc& desc) const;\n\n    [[nodiscard]] const std::vector<Kernel*>& kernels() const\n    {\n        return ptrs_;\n    }\n\n    int sm_count() const noexcept\n    {\n        return device_prop_->multiProcessorCount;\n    }\n\n    static Registry& instance();\n\nprivate:\n    bool Add(std::unique_ptr<Kernel> kernel);\n\n    std::shared_ptr<cudaDeviceProp>      device_prop_;\n    int                                  arch_;\n    std::vector<std::unique_ptr<Kernel>> kernels_;\n    std::vector<Kernel*>                 ptrs_;\n};\n\n}  // namespace turbomind::attention\n"
  },
  {
    "path": "src/turbomind/kernels/attention/rotary_embedding.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/models/llama/llama_rope.h\"\n\nnamespace turbomind {\n\ntemplate<int N>\n__device__ void init_default(Array<float, N / 2>& inv_freq, int idx, RopeKernelParam& param)\n{\n    auto scale_factor = param.scale_factor;\n    auto inv_factor   = param.inv_factor;\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; i += 2) {\n        inv_freq[i / 2] = inv_factor * exp2f((idx + i) * scale_factor);\n    }\n}\n\ntemplate<int N>\n__device__ void init_yarn(Array<float, N / 2>& inv_freq, int idx, RopeKernelParam& param)\n{\n    auto scale_factor            = param.scale_factor;\n    auto inv_factor              = param.inv_factor;\n    auto ramp_inv_factor_div_2   = param.yarn.ramp_inv_factor_div_2;\n    auto ramp_inv_factor_mul_min = param.yarn.ramp_inv_factor_mul_min;\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; i += 2) {\n        auto freq       = exp2f((idx + i) * scale_factor);\n        auto alpha      = (idx + i) * ramp_inv_factor_div_2 - ramp_inv_factor_mul_min;\n        alpha           = fmaxf(0.f, fminf(1.f, alpha));\n        inv_freq[i / 2] = freq - freq * alpha * (1.f - inv_factor);\n    }\n}\n\ntemplate<int N>\n__device__ void init_llama3(Array<float, N / 2>& inv_freq, int idx, RopeKernelParam& param)\n{\n    auto scale_factor = param.scale_factor;\n    auto inv_factor   = param.inv_factor;\n    auto alpha        = param.llama3.alpha;\n    auto beta         = param.llama3.beta;\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; i += 2) {\n        auto freq       = exp2f((idx + i) * scale_factor);\n        auto smooth     = fmaxf(0.f, fminf(1.f, alpha * freq - beta));\n        inv_freq[i / 2] = (1 - smooth) * freq * inv_factor + smooth * freq;\n    }\n}\n\ntemplate<int N>\nstruct FastRoPE {\n\n    static_assert(N % 2 == 0);\n\n    RopeKernelParam     param_;\n    Array<float, N / 2> inv_freq_;\n    bool          
      is_valid_;\n    float               attention_scaling_{1.f};\n    int                 idx_;\n\n    typedef void (*Func)(Array<float, N / 2>&, int, RopeKernelParam&);\n    Func fill_func_;\n\n    __device__ FastRoPE(const RopeKernelParam& param, int batch_idx, std::integral_constant<int, N>): param_(param)\n    {\n\n        if (param_.type == RopeType::kDynamic) {\n            float base          = param_.base[batch_idx];\n            param_.scale_factor = -log2f(base) / param_.dim;\n        }\n        else if (param_.type == RopeType::kYarn) {\n            attention_scaling_ = param_.yarn.attention_factor;\n        }\n        else if (param_.type == RopeType::kMrope) {\n            param_.mrope.position_ids += batch_idx * param_.mrope.stride;\n            param_.mrope.position_delta += batch_idx;\n            param_.mrope.length += batch_idx;\n        }\n    }\n\n    __device__ void init(int idx)\n    {\n        is_valid_ = idx < param_.dim;\n        idx_      = idx;\n        switch (param_.type) {\n            case RopeType::kDefault:\n            case RopeType::kLinear:\n            case RopeType::kDynamic:\n            case RopeType::kMrope:\n                init_default<N>(inv_freq_, idx, param_);\n                break;\n            case RopeType::kYarn:\n                init_yarn<N>(inv_freq_, idx, param_);\n                break;\n            case RopeType::kLlama3:\n                init_llama3<N>(inv_freq_, idx, param_);\n                break;\n        }\n    }\n\n    template<typename T>\n    __device__ void apply(Array<T, N>& x, float timestep)\n    {\n        if (param_.type == RopeType::kMrope) {\n            return apply_mrope(x, timestep);\n        }\n        // Most models apply rotary embedding in half precision\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 2) {\n            float c, s;\n            sincosf(timestep * inv_freq_[i / 2], &s, &c);\n            s *= attention_scaling_;\n            c *= attention_scaling_;\n         
   T tmp0 = (T)c * x[i] - (T)s * x[i + 1];\n            T tmp1 = (T)c * x[i + 1] + (T)s * x[i];\n            if (is_valid_) {\n                x[i]     = tmp0;\n                x[i + 1] = tmp1;\n            }\n        }\n    }\n\n    template<typename T>\n    __device__ void apply_mrope(Array<T, N>& x, float timestep)\n    {\n        int  tt, th, tw;\n        int3 section = param_.mrope.section;\n        if (timestep < *param_.mrope.length) {\n            const int* t = param_.mrope.position_ids + 3 * (int)timestep;\n            tt           = t[0];\n            th           = t[1];\n            tw           = t[2];\n        }\n        else {\n            tt = th = tw = (int)timestep + (*param_.mrope.position_delta);\n        }\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 2) {\n            if (i + idx_ < section.x) {\n                timestep = (float)tt;\n            }\n            else if (i + idx_ < section.y) {\n                timestep = (float)th;\n            }\n            else {\n                timestep = (float)tw;\n            }\n            float c, s;\n            sincosf(timestep * inv_freq_[i / 2], &s, &c);\n            T tmp0 = (T)c * x[i] - (T)s * x[i + 1];\n            T tmp1 = (T)c * x[i + 1] + (T)s * x[i];\n            if (is_valid_) {\n                x[i]     = tmp0;\n                x[i + 1] = tmp1;\n            }\n        }\n    }\n};\n\ntemplate<int N>\nstruct RotaryEmbedding {\n\n    static_assert(N % 2 == 0);\n\n    Array<float, N> cs_;\n\n    bool is_valid_;\n\n    __device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset)\n    {\n        const int idx = offset.x;\n        is_valid_     = idx < dims;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 2) {\n            const float2 tmp = get_coefficient(idx + i, dims, base, timestep);\n            cs_[i]           = tmp.x;\n            cs_[i + 1]       = tmp.y;\n        }\n    }\n\n    // ! 
depending on the context, this function may generate different result when inlined\n    static __device__ __noinline__ float2 get_coefficient(int idx, int dims, float base, int timestep)\n    {\n        const float inv_freq = timestep / powf(base, idx / (float)dims);\n        float2      cs;\n        sincosf(inv_freq, &cs.y, &cs.x);\n        return cs;\n    }\n\n    template<typename T>\n    __device__ void apply(Array<T, N>& x)\n    {\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 2) {\n            auto tmp0 = (T)cs_[i] * x[i] - (T)cs_[i + 1] * x[i + 1];\n            auto tmp1 = (T)cs_[i] * x[i + 1] + (T)cs_[i + 1] * x[i];\n            if (is_valid_) {\n                x[i]     = (T)tmp0;\n                x[i + 1] = (T)tmp1;\n            }\n        }\n    }\n};\ntemplate<class C, class T>\n__device__ void ApplyRotaryEmbedding(Array<T, 4>& x, float base, int dims, int ti, int di)\n{\n    PRAGMA_UNROLL\n    for (int d1 = 0; d1 < 2; ++d1) {\n        int    d        = d1 * 8 + di;\n        float  inv_freq = ti / powf(base, d / (float)dims);\n        float2 cs;\n        sincosf(inv_freq, &cs.y, &cs.x);\n        C x1          = (C)cs.x * (C)x[d1 * 2 + 0] - (C)cs.y * (C)x[d1 * 2 + 1];\n        C x2          = (C)cs.x * (C)x[d1 * 2 + 1] + (C)cs.y * (C)x[d1 * 2 + 0];\n        x[d1 * 2 + 0] = (T)x1;\n        x[d1 * 2 + 1] = (T)x2;\n    }\n}\n\ntemplate<int N, int C = 8>\nstruct RoPE {\n    Array<float, N> inv_freqs_;\n\n    RoPE() = default;\n    __device__ RoPE(float idx, float base, float dims)\n    {\n        for (int i = 0; i < N; ++i) {\n            inv_freqs_[i] = powf(base, idx / dims + (C / dims) * i);\n        }\n    }\n\n    template<class T>\n    __device__ void apply(Array<T, N * 2>& x, float timestep)\n    {\n        for (int i = 0; i < N; ++i) {\n            const float inv_freq = timestep * inv_freqs_[i];\n            float2      cs;\n            sincosf(inv_freq, &cs.y, &cs.x);\n            float tmp0   = cs.x * (float)x[i * 2] - cs.y * 
(float)x[i * 2 + 1];\n            float tmp1   = cs.x * (float)x[i * 2 + 1] + cs.y * (float)x[i * 2];\n            x[i * 2]     = (T)tmp0;\n            x[i * 2 + 1] = (T)tmp1;\n        }\n    }\n};\n\nstruct LogNScaling {\n\n    float scale_;\n\n    __device__ static float get_scale(int seq_len, int max_position_embeddings)\n    {\n        if (seq_len <= max_position_embeddings) {\n            return 1.f;\n        }\n        else {\n            return log2f(seq_len) / log2f(max_position_embeddings);\n        }\n    }\n\n    __device__ LogNScaling(int seq_len, int max_position_embeddings)\n    {\n        scale_ = get_scale(seq_len, max_position_embeddings);\n    }\n\n    template<typename T, int N>\n    __device__ void apply(Array<T, N>& x) const\n    {\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            x[i] = (T)((float)x[i] * scale_);\n        }\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/test_attention.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"attention.h\"\n#include \"block.h\"\n#include \"decoding.h\"\n#include \"kv_cache_utils_v2.h\"\n#include \"src/turbomind/kernels/attention/attention_params.h\"\n#include \"src/turbomind/kernels/attention/reference.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"test_utils.h\"\n#include <algorithm>\n#include <cmath>\n#include <iostream>\n#include <numeric>\n#include <random>\n#include <thrust/device_vector.h>\n#include <thrust/universal_vector.h>\n#include <utility>\n\nusing namespace turbomind;\n\n// [b, h, s, d] : current -> stride_h=s, stride_s=1, stride_b=hs\n// [cu_q, h, d] : qkvgemm -> stride_h=1, stride_s=h, stride_b=0\n// [h, cu_s, d] : prefill -> stride_h=s, stride_s=1, stride_b=0\n\ntemplate<class T, class Tkv>\nstruct Config {\n    int head_dim_;\n    int head_num_;\n    int block_len_;\n\n    TM_HOST_DEVICE constexpr int t_bits() const\n    {\n        if constexpr (std::is_same_v<T, Tkv>) {\n            return 0;\n        }\n        else {\n            return bitsof<T>;\n        }\n    }\n\n    TM_HOST_DEVICE constexpr int q_bits() const\n    {\n        return bitsof<Tkv>;\n    }\n\n    TM_HOST_DEVICE constexpr int head_dim() const\n    {\n        return head_dim_;\n    }\n\n    TM_HOST_DEVICE int head_num() const\n    {\n        return head_num_;\n    }\n\n    TM_HOST_DEVICE constexpr int block_len() const\n    {\n        return block_len_;\n    }\n\n    TM_HOST_DEVICE constexpr bool is_share_kv() const\n    {\n        return false;\n    }\n};\n\n// [S/S, H, S, D] <-> [S/b, H, b, D]\ntemplate<class Tkv, class T>\nvoid TestBlocks(const thrust::universal_vector<T>& k_cache,        // [B, H, S, D]\n                const thrust::universal_vector<T>& v_cache,        // [B, H, S, D]\n                thrust::universal_vector<char>&    blocks,         // block data\n                thrust::universal_vector<char*>&  
 k_ptrs,         // block ptrs\n                thrust::universal_vector<int>&     cu_block_cnts,  // cumulative block counts\n                const size_t                       head_num,\n                const size_t                       head_dim,\n                const size_t                       block_seq_len,\n                const size_t                       batch_size,\n                const int                          rope_dim,\n                int                                quant_policy)\n{\n    const size_t seq_len  = k_cache.size() / (head_dim * head_num * batch_size);\n    const size_t n_blocks = (seq_len + block_seq_len - 1) / block_seq_len;\n\n    Config<T, Tkv> config{(int)head_dim, (int)head_num, (int)block_seq_len};\n    block::Layout  layout{config};\n\n    dump(layout);\n\n    const size_t kHSD = head_num * seq_len * head_dim;\n\n    std::cout << \"batch_size = \" << batch_size << \", seq_len = \" << seq_len << \", block_size = \" << block_seq_len\n              << \", block_num = \" << n_blocks << \"\\n\";\n\n    thrust::universal_vector<T> kv_cache(k_cache.size() * 2);  // [B, 2, H, S, D]\n\n    {  // interleave K/V\n        auto k_src = k_cache.begin();\n        auto v_src = v_cache.begin();\n        auto dst   = kv_cache.begin();\n        for (size_t i = 0; i < batch_size; ++i) {\n            dst = thrust::copy_n(k_src, kHSD, dst);\n            dst = thrust::copy_n(v_src, kHSD, dst);\n            k_src += kHSD;\n            v_src += kHSD;\n        }\n    }\n\n    // const int kHsD = head_num * block_seq_len * head_dim;\n\n    // [B, S/s, 2, H, s, D]\n    // blocks.resize(batch_size * n_blocks * 2 * kHsD);\n    blocks.resize(batch_size * n_blocks * layout.block_size(1));\n    thrust::fill(blocks.begin(), blocks.end(), NAN);\n    k_ptrs.resize(batch_size * n_blocks + 1);  // +1 padding\n\n    std::vector<size_t> idxs(batch_size * n_blocks);\n    std::iota(idxs.begin(), idxs.end(), 0);\n\n    std::random_device rd;\n    std::mt19937       
g(rd());\n    std::shuffle(idxs.begin(), idxs.end(), g);\n\n    for (size_t i = 0; i < idxs.size(); ++i) {\n        // k_ptrs[i] = blocks.data().get() + idxs[i] * 2 * kHsD;\n        k_ptrs[i] = blocks.data().get() + idxs[i] * layout.block_size(1);\n    }\n\n    thrust::universal_vector<int> seq_lens(batch_size);\n    thrust::universal_vector<int> cu_seq_lens(batch_size + 1);\n    thrust::fill(seq_lens.begin(), seq_lens.end(), seq_len);\n    for (size_t i = 0; i <= batch_size; ++i) {\n        cu_seq_lens[i] = i * seq_len;\n    }\n\n    std::vector<int> n_blocks_vec(batch_size + 1, n_blocks);\n    cu_block_cnts.resize(batch_size + 1);\n    std::exclusive_scan(n_blocks_vec.begin(), n_blocks_vec.end(), cu_block_cnts.begin(), 0);\n\n    cudaDeviceSynchronize();\n\n    // [B, 2H, S, D] -> [B, S/s] x [2H, s, D]\n    for (int i = 0; i < 1; ++i) {\n        // (B, 2, H, S, D) -> blocks\n        invokeProcessKV_v2(k_ptrs.data().get(),\n                           kv_cache.data().get(),\n                           kv_cache.data().get() + head_num * seq_len * head_dim,\n                           (T*)nullptr,\n                           (T*)nullptr,\n                           cu_seq_lens.data().get(),\n                           cu_seq_lens.data().get(),\n                           cu_block_cnts.data().get(),\n                           RopeKernelParam{},\n                           2 * head_num * seq_len,\n                           0,\n                           seq_len,\n                           1,\n                           block_seq_len,\n                           0,  // layer_id\n                           0,  // cp_rank\n                           1,  // cp_size\n                           seq_len,\n                           head_num,\n                           head_dim,\n                           batch_size,\n                           quant_policy);\n    }\n\n    thrust::universal_vector<T> kv_cache_2(kv_cache.size());\n\n    // round trip test\n    for (int i = 
0; i < 1; ++i) {\n        // kv_cache_2 is [B, 2, H, S, D]\n        invokeFlattenKV_v2(kv_cache_2.data().get(),\n                           kv_cache_2.data().get() + head_num * seq_len * head_dim,\n                           k_ptrs.data().get(),\n                           cu_seq_lens.data().get(),\n                           cu_block_cnts.data().get(),\n                           RopeKernelParam{},\n                           2 * head_num * seq_len,\n                           0,\n                           seq_len,\n                           1,\n                           block_seq_len,\n                           0,  // layer_id\n                           0,  // cp_rank\n                           1,  // cp_size\n                           seq_len,\n                           head_num,\n                           head_dim,\n                           batch_size,\n                           quant_policy);\n    }\n\n    cudaDeviceSynchronize();\n\n    if (0) {\n        std::cout << \">>> Compare\\n\";\n        Compare(\n            kv_cache_2.data().get(), kv_cache.data().get(), head_dim, head_dim, batch_size * 2 * head_num * seq_len, 0);\n        std::cout << \"<<< Compare\\n\";\n    }\n}\n\ndouble get_memory_bandwidth()  // -> GB/s\n{\n    int clock_rate_khz{};\n    int bus_width_bits{};\n    cudaDeviceGetAttribute(&clock_rate_khz, cudaDevAttrMemoryClockRate, 0);\n    cudaDeviceGetAttribute(&bus_width_bits, cudaDevAttrGlobalMemoryBusWidth, 0);\n    return 2. 
* (double)clock_rate_khz / 1e6 * (double)bus_width_bits / 8.;\n}\n\n#define KV_INT8 0\n\n#define KV_INT4 0\n\n#define DECODING 0\n\n#define SINK 5\n\ntemplate<class T>\nint test_attention()\n{\n    AttentionParams<T> params{};\n\n    constexpr size_t kHeadDim    = 128;\n    constexpr int    kWindowSize = 128 << 20;\n\n#if DECODING\n    // constexpr size_t kHeadNum   = 32;\n    // constexpr size_t kBatchSize = 64;\n    constexpr size_t kHeadNum   = 64;\n    constexpr size_t KvHeadNum  = kHeadNum / 8;\n    constexpr size_t kBatchSize = 256;\n    constexpr size_t kInputLen  = 1;\n\n    constexpr size_t kSequenceLen = 1000;\n    // constexpr size_t kSequenceLen = 4095;\n    // constexpr size_t kSequenceLen = 511;\n    // constexpr size_t kSequenceLen = 2047;\n    // constexpr size_t kSequenceLen = 4095;\n    // constexpr size_t kSequenceLen = 8 * 1024 - 1;\n    // constexpr size_t kSequenceLen = 32767;\n    // constexpr size_t kSequenceLen = 65535;\n    // constexpr size_t kSequenceLen = 131071;\n    // constexpr size_t kSequenceLen = 200000;\n    // constexpr size_t kSequenceLen = 262143;\n    // constexpr size_t kSequenceLen = (1 << 20) - 1;  // 1M\n    // constexpr size_t kSequenceLen = (1 << 22) - 1;  // 4M\n    // constexpr size_t kSequenceLen = (1 << 24) - 1;  // 16M\n    // constexpr int kSequenceLen = 2047;\n    constexpr int kBlockSz   = 64;\n    constexpr int kMaxSplitK = 128;\n#else\n\n    // append\n    // constexpr size_t kHeadNum     = 32;\n    // constexpr size_t KvHeadNum    = kHeadNum;\n    // constexpr size_t kBatchSize   = 1;\n    // constexpr size_t kInputLen    = 128;\n    // constexpr size_t kSequenceLen = 65536;\n    // constexpr int    kMaxSplitK   = 128;\n\n    // constexpr size_t kHeadNum     = 1;\n    // constexpr size_t KvHeadNum    = kHeadNum;\n    // constexpr size_t kBatchSize   = 1;\n    // constexpr size_t kInputLen    = 64;\n    // constexpr size_t kSequenceLen = 65536;\n    // constexpr int    kMaxSplitK   = 1;\n\n    // prefill\n    
constexpr size_t kHeadNum     = 16;\n    constexpr size_t KvHeadNum    = kHeadNum / 8;\n    constexpr size_t kBatchSize   = 2;\n    constexpr size_t kInputLen    = 8192;\n    constexpr size_t kSequenceLen = 0;\n    constexpr int    kMaxSplitK   = 1;\n\n    constexpr int kBlockSz     = 64;\n\n#endif\n\n#if KV_INT8\n    using Tkv                  = uint8_t;\n    constexpr int kQuantPolicy = QuantPolicy::kCacheKVInt8;\n#elif KV_INT4\n    using Tkv                  = uint4_t;\n    constexpr int kQuantPolicy = QuantPolicy::kCacheKVInt4;\n#else\n    using Tkv                  = T;\n    constexpr int kQuantPolicy = 0;\n#endif\n\n    static_assert(KvHeadNum > 0);\n\n    constexpr size_t kContextLen = kSequenceLen + kInputLen;\n    constexpr size_t kTokenNum   = kBatchSize * kInputLen;\n    constexpr int    kTestIter   = 10;\n\n    constexpr float kRoPEBase = 10000.f;\n    constexpr int   kRoPEDim  = kHeadDim / 2;\n    constexpr int   kDump     = 0;\n\n    RNG rng{};\n\n    thrust::universal_vector<T> k_cache(kBatchSize * KvHeadNum * kContextLen * kHeadDim);\n    thrust::universal_vector<T> v_cache(kBatchSize * KvHeadNum * kContextLen * kHeadDim);\n\n    // flattened floating-point KV cache\n    thrust::device_vector<T> kv_cache(KvHeadNum * 2 * (kBatchSize * kContextLen + MAX_CTA_S) * kHeadDim);\n\n    thrust::universal_vector<T> qkv(kBatchSize * kInputLen * (kHeadNum + KvHeadNum * 2) * kHeadDim);\n    thrust::universal_vector<T> output(kBatchSize * kInputLen * kHeadNum * kHeadDim);\n\n    thrust::universal_vector<bool>  finished(kBatchSize);\n    thrust::universal_vector<int>   sequence_length(kBatchSize);\n    thrust::universal_vector<int>   input_length(kBatchSize);\n    thrust::universal_vector<int>   context_length(kBatchSize);\n    thrust::universal_vector<float> rope_base(kBatchSize);\n    thrust::universal_vector<int>   cu_seqlens(kBatchSize + 1);\n    thrust::universal_vector<int>   cu_kv_lens(kBatchSize + 1);\n\n    thrust::device_vector<float> partial_ML(kTokenNum * 
kHeadNum * kMaxSplitK * 2);\n    thrust::device_vector<float> partial_O(kTokenNum * kHeadNum * kMaxSplitK * kHeadDim);\n    thrust::device_vector<int>   split_cnt(kTokenNum);\n\n    thrust::universal_vector<float> qk_buf((size_t)kDump * kBatchSize * kHeadNum * kInputLen * kContextLen);\n    thrust::universal_vector<T>     pr_buf((size_t)kDump * kBatchSize * kHeadNum * kInputLen * kContextLen);\n\n    thrust::universal_vector<T> sinks(kHeadNum);\n\n    rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);\n\n    rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);\n    rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);\n\n    if (SINK) {\n        rng.GenerateUniform(sinks.data().get(), sinks.size(), 2 * SINK, -SINK);\n    }\n\n    if (0) {\n        // Set input range to zero\n        // (BH, SD)\n        cudaMemset2DAsync(k_cache.data().get() + kSequenceLen * kHeadDim,\n                          sizeof(T) * kContextLen * kHeadDim,\n                          0,\n                          sizeof(T) * kInputLen * kHeadDim,\n                          kBatchSize * KvHeadNum);\n        cudaMemset2DAsync(v_cache.data().get() + kSequenceLen * kHeadDim,\n                          sizeof(T) * kContextLen * kHeadDim,\n                          0,\n                          sizeof(T) * kInputLen * kHeadDim,\n                          kBatchSize * KvHeadNum);\n    }\n\n    invokeApplyRotaryEmbedding(k_cache.data().get(), kContextLen, KvHeadNum, kHeadDim, kRoPEBase, kRoPEDim, kBatchSize);\n\n    thrust::universal_vector<T> k_cache_ref = k_cache;\n    thrust::universal_vector<T> v_cache_ref = v_cache;\n\n    thrust::universal_vector<char>  blocks;\n    thrust::universal_vector<char*> k_ptrs;\n    thrust::universal_vector<int>   cu_block_cnts;\n\n    TestBlocks<Tkv>(k_cache,\n                    v_cache,\n                    blocks,\n                    k_ptrs,\n                    cu_block_cnts,\n  
                  KvHeadNum,\n                    kHeadDim,\n                    kBlockSz,\n                    kBatchSize,\n                    kRoPEDim,\n                    kQuantPolicy);\n\n    thrust::universal_vector<T>     output_ref = output;\n    thrust::universal_vector<void*> k_cache_ref_ptrs(kBatchSize);\n    thrust::universal_vector<void*> v_cache_ref_ptrs(kBatchSize);\n\n    thrust::universal_vector<T> bias_QKV(kHeadNum * kHeadDim + 2 * KvHeadNum * kHeadDim);\n\n    rng.GenerateNormal(bias_QKV.data().get(), bias_QKV.size(), 0.1f, 0.f);\n\n    cudaDeviceSynchronize();\n\n    for (size_t i = 0; i <= kBatchSize; ++i) {\n        cu_seqlens[i] = i * kInputLen;\n        cu_kv_lens[i] = i * kContextLen;\n    }\n\n    for (size_t i = 0; i < kBatchSize; ++i) {\n        input_length[i]     = kInputLen;\n        sequence_length[i]  = kSequenceLen;\n        context_length[i]   = kContextLen;\n        k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;\n        v_cache_ref_ptrs[i] = v_cache_ref.data().get() + i * v_cache_ref.size() / kBatchSize;\n        rope_base[i]        = kRoPEBase;\n    }\n\n    // getchar();\n\n    params.out = output_ref.data().get();\n    params.q   = qkv.data().get();\n    params.k   = params.q + kHeadNum * kHeadDim;\n    params.v   = params.k + KvHeadNum * kHeadDim;\n\n    params.q_bias = bias_QKV.data().get();\n    params.k_bias = params.q_bias + kHeadNum * kHeadDim;\n    params.v_bias = params.k_bias + KvHeadNum * kHeadDim;\n\n    params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim;\n\n    params.token_num  = kTokenNum;\n    params.batch_size = kBatchSize;\n    params.max_q_len  = kInputLen;\n    params.max_k_len  = kContextLen;\n\n    params.block_iter_params = BlockIteratorParams{k_ptrs.data().get(),  //\n                                                   cu_block_cnts.data().get(),\n                                                   0,\n                                                   
kBlockSz};\n\n    params.linear_iter_params = LinearIteratorParams{kv_cache.data().get(),  //\n                                                     int(2 * kBatchSize * kContextLen * kHeadDim),\n                                                     int(kBatchSize * kContextLen * kHeadDim)};\n\n    params.quant_policy = kQuantPolicy;\n\n    params.finished   = finished.data().get();\n    params.rope_theta = rope_base.data().get();\n    params.cu_q_len   = cu_seqlens.data().get();\n    params.cu_k_len   = cu_kv_lens.data().get();\n\n    params.num_heads     = kHeadNum;\n    params.num_kv_heads  = KvHeadNum;\n    params.size_per_head = kHeadDim;\n    params.window_size   = kWindowSize;\n    params.inv_sqrt_dh   = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head);\n\n    if (SINK) {\n        params.sinks       = sinks.data().get();\n        params.scale_sinks = 1. / std::sqrt((float)params.size_per_head);\n    }\n\n    float scale_factor = -std::log2f(kRoPEBase) / kRoPEDim;\n    params.rope_param  = RopeKernelParam{RopeType::kDefault, nullptr, kRoPEDim, scale_factor, 1.f};\n\n    params.split_cnt  = split_cnt.data().get();\n    params.partial_ML = partial_ML.data().get();\n    params.partial_O  = partial_O.data().get();\n\n    params.max_split_k = kMaxSplitK;\n    params.arch        = getSMVersion();\n\n    params.qk = qk_buf.data().get();\n    params.pr = pr_buf.data().get();\n\n    Reference<T> reference({});\n    reference.Reshape(kInputLen, kContextLen, kHeadNum, kHeadDim, KvHeadNum, kBatchSize, kWindowSize);\n\n    for (int i = 0; i < 1; ++i) {\n        reference.Execute(params.out,  //\n                          k_cache_ref.data().get(),\n                          v_cache_ref.data().get(),\n                          qkv.data().get(),\n                          bias_QKV.data().get(),\n                          SINK ? 
sinks.data().get() : nullptr,\n                          kRoPEBase,\n                          kRoPEDim);\n    }\n\n    cudaDeviceSynchronize();\n\n    if constexpr (kDump) {\n        for (size_t b = 0; b < kBatchSize; ++b) {\n            for (size_t h = 0; h < kHeadNum; ++h) {\n                for (size_t q = 0; q < kInputLen; ++q) {\n                    auto qk = reference.qk() + b * kHeadNum * kInputLen * kContextLen + h * kInputLen * kContextLen\n                              + q * kContextLen;\n                    for (size_t k = 0; k < kContextLen; ++k) {\n                        std::cout << qk[k] * params.inv_sqrt_dh << \" \";\n                    }\n                    std::cout << \"\\n\";\n                }\n                std::cout << \"\\n\";\n            }\n            std::cout << \"\\n\";\n        }\n    }\n\n    if (auto err = cudaGetLastError(); err != cudaSuccess) {\n        std::cout << cudaGetErrorString(err) << \"\\n\";\n        return -1;\n    }\n    std::cout << \"---------------------------------------------------\\n\";\n\n    params.out = output.data().get();\n\n    std::vector<thrust::universal_vector<T>> outputs;\n\n    std::vector<cudaEvent_t> ev_start(kTestIter);\n    std::vector<cudaEvent_t> ev_end(kTestIter);\n\n    for (int i = 0; i < kTestIter; ++i) {\n        cudaEventCreate(&ev_start[i]);\n        cudaEventCreate(&ev_end[i]);\n    }\n\n    for (int i = 0; i < std::max(kTestIter, 1); ++i) {\n\n#if DECODING\n        cudaEventRecord(ev_start[i]);\n        dispatchDecoding<T>(params);\n        cudaEventRecord(ev_end[i]);\n#else\n        // input -> blocked\n        invokeProcessKV_v2_(params);\n        // blocked -> linear\n        invokeFlattenKV_v2_(params, cu_kv_lens[kBatchSize]);\n\n        cudaEventRecord(ev_start[i]);\n        dispatchAttention(params);\n        cudaEventRecord(ev_end[i]);\n#endif\n\n        if (auto err = cudaGetLastError(); err != cudaSuccess) {\n            std::cout << cudaGetErrorString(err) << \"\\n\";\n 
           return -1;\n        }\n        if (1) {\n            outputs.push_back(output);\n        }\n    }\n\n    if (kDump) {\n        cudaDeviceSynchronize();\n        for (size_t b = 0; b < kBatchSize; ++b) {\n            for (size_t h = 0; h < kHeadNum; ++h) {\n                for (size_t q = 0; q < kInputLen; ++q) {\n                    auto ref = reference.qk() + b * kHeadNum * kInputLen * kContextLen + h * kInputLen * kContextLen\n                               + q * kContextLen;\n                    auto data = qk_buf.data().get() + b * kHeadNum * kInputLen * kContextLen\n                                + h * kInputLen * kContextLen + q * kContextLen;\n                    for (size_t k = 0; k < kContextLen; ++k) {\n                        // std::cout << std::max(0.f, std::abs(data[k] - (float)ref[k]) - 1e-5f) << \" \";\n                        std::cout << data[k] * params.inv_sqrt_dh << \" \";\n                        // std::cout << (float)data[k] << \" \";\n                    }\n                    std::cout << \"\\n\";\n                }\n                std::cout << \"\\n\";\n            }\n            std::cout << \"\\n\";\n        }\n    }\n\n    invokeFlattenKV_v2(k_cache.data().get(),  // [B, H, S, D]\n                       v_cache.data().get(),\n                       k_ptrs.data().get(),\n                       cu_kv_lens.data().get(),\n                       cu_block_cnts.data().get(),\n                       RopeKernelParam{},  // DECODING ? 
nullptr : params.rope_theta,\n                       KvHeadNum * kContextLen,\n                       0,\n                       kContextLen,\n                       1,\n                       kBlockSz,\n                       0,  // layer_id\n                       0,  // cp_rank\n                       1,  // cp_size\n                       kContextLen,\n                       KvHeadNum,\n                       kHeadDim,\n                       kBatchSize,\n                       kQuantPolicy);\n    cudaDeviceSynchronize();\n\n    const size_t nbytes = blocks.size() / kContextLen * std::min(kContextLen, (size_t)params.window_size);\n    const size_t ops =\n        2 * kInputLen * std::min(kContextLen, (size_t)params.window_size) * kHeadDim * kHeadNum * kBatchSize;\n\n    const float peak_bw = get_memory_bandwidth();\n\n    std::cout << \"Device peak global memory bandwidth: \" << peak_bw << \" GB/s\\n\";\n\n    for (int i = 0; i < kTestIter; ++i) {\n        float ms{};\n        cudaEventElapsedTime(&ms, ev_start[i], ev_end[i]);\n        const float bw      = nbytes / 1e9f / ms * 1000.f;\n        const float flops   = ops / 1e12f / ms * 1000.f;\n        const float percent = bw / peak_bw * 100.f;\n        printf(\"time %.3f ms, bw %.3f GB/s, %.3f %%, tflops %.3f \\n\", ms, bw, percent, flops);\n    }\n\n    if (outputs.size() > 1) {\n        std::cout << \"Evaluating consistency...\" << std::endl;\n        for (size_t i = 1; i < outputs.size(); ++i) {\n            Compare(outputs[i].data().get(), outputs[i - 1].data().get(), kHeadDim, kHeadDim, kHeadNum, 0, 0, 0);\n        }\n    }\n\n    std::cout << \"---------------------------------------------------\\n\";\n\n    // [B, S, H, D]\n    Compare(output.data().get(),  //\n            output_ref.data().get(),\n            kHeadNum * kHeadDim,\n            kHeadNum * kHeadDim,\n            kBatchSize * kInputLen,\n            0);\n\n    // [BH, SD]\n    Compare(k_cache.data().get() + kSequenceLen * kHeadDim,\n        
    k_cache_ref.data().get() + kSequenceLen * kHeadDim,\n            kContextLen * kHeadDim,\n            kInputLen * kHeadDim,\n            kBatchSize * KvHeadNum,\n            0);\n    Compare(v_cache.data().get() + kSequenceLen * kHeadDim,\n            v_cache_ref.data().get() + kSequenceLen * kHeadDim,\n            kContextLen * kHeadDim,\n            kInputLen * kHeadDim,\n            kBatchSize * KvHeadNum);\n\n    return 0;\n}\n\nint main(int argc, char* argv[])\n{\n    test_attention<half>();\n\n    // test_attention<nv_bfloat16>();\n}\n"
  },
  {
    "path": "src/turbomind/kernels/attention/test_quant.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"quantization.h\"\n#include \"src/turbomind/kernels/attention/test_utils.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/macro.h\"\n#include <cstdint>\n#include <iostream>\n#include <thrust/universal_vector.h>\n\nusing namespace turbomind;\n\ntemplate<int kVecSize, class T0, class T1>\n__global__ void convert(T1* dst, const T0* src, size_t n, float scale, float zero)\n{\n    auto v_src = (Array<T0, kVecSize>*)src;\n    auto v_dst = (Array<T1, kVecSize>*)dst;\n\n    const int v_n = n / kVecSize;\n\n    ConvertKvCache<T0, T1> converter{scale, zero};\n\n    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < v_n; i += blockDim.x * gridDim.x) {\n        Array<T0, kVecSize> vi;\n        Array<T1, kVecSize> vo;\n        Load(vi, (T0*)v_src[i].data());\n        vo = converter(vi);\n        Store((T1*)v_dst[i].data(), vo);\n    }\n}\n\ntemplate<class T0, class T1, int kVecSize>\nvoid round_trip_test(size_t n, float s1 = 1., float z1 = 0., float s2 = 1., float z2 = 0.)\n{\n    std::cout << __PRETTY_FUNCTION__ << std::endl;\n\n    using namespace thrust;\n\n    universal_vector<T0> src(n);\n    universal_vector<T0> dst(src.size());\n\n    universal_vector<Array<T1, kVecSize>> tmp(src.size() / kVecSize);\n\n    for (size_t i = 0; i < src.size(); ++i) {\n        src[i] = T0(float(rand() % (1 << bitsof<T1>)));\n    }\n\n    convert<kVecSize><<<256, 256>>>((T1*)tmp.data().get(), src.data().get(), n, s1, z1);\n    convert<kVecSize><<<256, 256>>>(dst.data().get(), (const T1*)tmp.data().get(), n, s2, z2);\n\n    cudaDeviceSynchronize();\n\n    Compare(dst.data().get(), src.data().get(), src.size(), src.size(), 1);\n}\n\nint main(int argc, char* argv[])\n{\n    round_trip_test<float, uint8_t, 4>(1 << 20);\n    round_trip_test<half, uint8_t, 4>(1 << 20);\n#if ENABLE_BF16\n    round_trip_test<nv_bfloat16, uint8_t, 4>(1 << 20);\n#endif\n\n    round_trip_test<float, 
uint4_t, 8>(1 << 20, 1, 0, 1, -64);\n    round_trip_test<half, uint4_t, 8>(1 << 20, 1, 0, 1, -64);\n#if ENABLE_BF16\n    round_trip_test<nv_bfloat16, uint4_t, 8>(1 << 20, 1, 0, 1, 0);\n#endif\n\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/kernels/attention/test_utils.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"test_utils.h\"\n#include <cublas_v2.h>\n#include <curand.h>\n#include <curand_kernel.h>\n#include <fstream>\n#include <iostream>\n\n#define _CG_ABI_EXPERIMENTAL\n#include <cooperative_groups.h>\n#include <cooperative_groups/memcpy_async.h>\n#include <cooperative_groups/reduce.h>\n\nnamespace turbomind {\n\ncublasHandle_t cublas_handle{};\ncudaStream_t   cublas_stream{};\n\ntemplate<typename T>\nvoid Compare(const T* src, const T* ref, size_t stride, int m, int n, bool show, float rtol, float atol)\n{\n    float asums{};\n    float rsums{};\n    int   outliers{};\n    for (int nn = 0; nn < n; ++nn) {\n        float abs_diff_sum{};\n        float rel_diff_sum{};\n        for (int mm = 0; mm < m; ++mm) {\n            auto x = float(src[nn * stride + mm]);\n            auto y = float(ref[nn * stride + mm]);\n            // if (show) {\n            //     std::cout << x << \"\\t\" << y << std::endl;\n            // }\n            auto abs_diff = std::abs(x - y);\n            auto rel_diff = abs_diff / std::abs(y + 1e-6f);\n            if (!(abs_diff <= atol + rtol * std::abs(y))) {\n                ++outliers;\n                if (show) {\n                    std::cout << nn << \",\" << mm << \"\\t\" << x << \"\\t\" << y << std::endl;\n                }\n            }\n            abs_diff_sum += abs_diff;\n            rel_diff_sum += rel_diff;\n        }\n        asums += abs_diff_sum / m;\n        rsums += rel_diff_sum / m;\n    }\n    std::cout << \"abs_diff = \" << asums / n << \" rel_diff = \" << rsums / n << \" outliers = \" << outliers / (float)n\n              << std::endl;\n}\n\ntemplate void Compare(const half* src, const half* ref, size_t stride, int m, int n, bool show, float rtol, float atol);\ntemplate void\nCompare(const float* src, const float* ref, size_t stride, int m, int n, bool show, float rtol, float atol);\n#if ENABLE_BF16\ntemplate void\nCompare(const nv_bfloat16* src, 
const nv_bfloat16* ref, size_t stride, int m, int n, bool show, float rtol, float atol);\n#endif\n\nvoid LoadBinary(const std::string& path, size_t size, void* dst)\n{\n    std::ifstream ifs(path, std::ios::binary | std::ios::in);\n    if (!ifs.is_open()) {\n        std::cerr << \"failed to open \" << path << \"\\n\";\n        std::abort();\n    }\n    ifs.seekg(0, ifs.end);\n    auto actual_size_in_bytes = ifs.tellg();\n    ifs.seekg(0, ifs.beg);\n    if (size != actual_size_in_bytes) {\n        std::cerr << \"[warning] file \" << path << \" has \" << actual_size_in_bytes << \" bytes, while \" << size\n                  << \" bytes is requested\\n\";\n    }\n    ifs.read((char*)dst, size);\n    std::cerr << \"[info] \" << path << \" \" << size << \"\\n\";\n}\n\nnamespace cg = cooperative_groups;\n\n__global__ void curand_init(curandState* state)\n{\n    auto tid = cg::this_grid().thread_rank();\n    curand_init(0xe4c45822e90461ddULL, tid, 0, state + tid);\n}\n\ntemplate<typename T>\n__global__ void curand_uniform(curandState* state, size_t count, T* result, float scale, float shift)\n{\n    auto grid = cg::this_grid();\n    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {\n        float tmp = curand_uniform(state + grid.thread_rank());\n        result[i] = T(scale * tmp + shift);\n    }\n}\n\ntemplate<typename T>\n__global__ void curand_normal(curandState* state, size_t count, T* result, float scale, float shift)\n{\n    auto grid = cg::this_grid();\n    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {\n        float tmp = curand_normal(state + grid.thread_rank());\n        result[i] = T(scale * tmp + shift);\n    }\n}\n\n__global__ void curand_bytes(curandState* state, size_t count, uint* result)\n{\n    auto grid = cg::this_grid();\n    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {\n        result[i] = curand(state + grid.thread_rank());\n    }\n}\n\nstruct RNG::Impl {\n\n    curandState* states{};\n\n    
Impl()\n    {\n        cudaMalloc(&states, sizeof(curandState) * 64 * 64);\n        curand_init<<<64, 64>>>(states);\n    }\n\n    ~Impl()\n    {\n        cudaFree(states);\n    }\n\n    void GenerateUInt(uint* out, size_t count)\n    {\n        curand_bytes<<<64, 64>>>(states, count, out);\n    }\n\n    template<typename T>\n    void GenerateUniform(T* out, size_t count, float scale, float shift)\n    {\n        curand_uniform<<<64, 64>>>(states, count, out, scale, shift);\n    }\n\n    template<typename T>\n    void GenerateNormal(T* out, size_t count, float scale, float shift)\n    {\n        curand_normal<<<64, 64>>>(states, count, out, scale, shift);\n    }\n};\n\nRNG::RNG(): impl_(std::make_unique<Impl>()) {}\n\nRNG::~RNG() = default;\n\nvoid RNG::GenerateUInt(uint* out, size_t count)\n{\n    impl_->GenerateUInt(out, count);\n}\n\ntemplate<typename T>\nvoid RNG::GenerateUniform(T* out, size_t count, float scale, float shift)\n{\n    std::cout << count << std::endl;\n    impl_->GenerateUniform(out, count, scale, shift);\n}\n\ntemplate<typename T>\nvoid RNG::GenerateNormal(T* out, size_t count, float scale, float shift)\n{\n    impl_->GenerateNormal(out, count, scale, shift);\n}\n\ntemplate void RNG::GenerateUniform(half* out, size_t count, float scale, float shift);\ntemplate void RNG::GenerateUniform(float* out, size_t count, float scale, float shift);\n#if ENABLE_BF16\ntemplate void RNG::GenerateUniform(nv_bfloat16* out, size_t count, float scale, float shift);\n#endif\n\ntemplate void RNG::GenerateNormal(half* out, size_t count, float scale, float shift);\ntemplate void RNG::GenerateNormal(float* out, size_t count, float scale, float shift);\n#if ENABLE_BF16\ntemplate void RNG::GenerateNormal(nv_bfloat16* out, size_t count, float scale, float shift);\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/test_utils.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"attention.h\"\n#include \"src/turbomind/macro.h\"\n#include <cuda_fp16.h>\n#include <memory>\n\nnamespace turbomind {\n\ntemplate<typename T>\nvoid Compare(\n    const T* src, const T* ref, size_t stride, int m, int n, bool show = false, float rtol = 1e-2, float atol = 1e-4);\n\nvoid LoadBinary(const std::string& path, size_t size, void* dst);\n\nclass RNG {\npublic:\n    RNG();\n    ~RNG();\n    void GenerateUInt(uint* out, size_t count);\n\n    template<typename T>\n    void GenerateUniform(T* out, size_t count, float scale = 1.f, float shift = 0.f);\n\n    template<typename T>\n    void GenerateNormal(T* out, size_t count, float scale = 1.f, float shift = 0.f);\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\ntemplate<typename T>\nvoid mmha_ft_reference(const AttentionParams<T>& params,\n                       T**                       per_sample_k_cache,\n                       T**                       per_sample_v_cache,\n                       const int*                sequence_length,\n                       int                       max_memory_len,\n                       cudaStream_t              st);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/utils.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"utils.h\"\n#include <cmath>\n#include <cstdio>\n#include <limits>\n#include <tuple>\n\nnamespace turbomind {\n\nint GetSplitCount(\n    int max_split_cnt, int grid_size, int max_active_ctas, int sm_count, int max_wave_cnt, float alpha, float beta)\n{\n\n    const float scale = (float)grid_size / (sm_count * max_active_ctas);\n\n    auto eval = [&](int s) -> std::tuple<float, float, int> {\n        float waves = std::ceil(scale * s);\n        float cost  = std::numeric_limits<float>::infinity();\n        if (s == 1 || waves <= max_wave_cnt) {\n            cost = (alpha / s + beta) * waves;\n        }\n        return {cost, scale * s, s};\n    };\n\n    std::tuple<float, float, int> best{std::numeric_limits<float>::infinity(), 0.f, 0};\n\n    auto print = [](auto& x) {  //\n        // printf(\"%d %f %f\\n\", std::get<2>(x), std::get<1>(x), std::get<0>(x));\n    };\n\n    for (int i = 1; i <= max_split_cnt; ++i) {\n        auto res = eval(i);\n        if (std::isinf(std::get<0>(res))) {\n            break;\n        }\n        print(res);\n        if (res < best) {\n            best = res;\n        }\n    }\n\n    print(best);\n\n    return std::get<int>(best);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/attention/utils.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind {\n\nint GetSplitCount(int   max_split_cnt,\n                  int   grid_size,\n                  int   max_active_ctas,\n                  int   sm_count,\n                  int   max_wave_cnt,\n                  float alpha = 1,\n                  float beta  = 1e-3);\n\n}\n"
  },
  {
    "path": "src/turbomind/kernels/ban_bad_words.cu",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/kernels/ban_bad_words.h\"\n#include <cfloat>\n// #include \"src/turbomind/kernels/reduce_kernel_utils.cuh\"\n// #include \"src/turbomind/utils/cuda_utils.h\"\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\nnamespace turbomind {\n\ntemplate<typename T>\n__device__ inline T getMaxValue();\n\ntemplate<>\n__device__ inline float getMaxValue<float>()\n{\n    return FLT_MAX;\n}\n\ntemplate<>\n__device__ inline half getMaxValue<half>()\n{\n    return __ushort_as_half((unsigned short)0x7BFFU);\n}\n\n#ifdef ENABLE_BF16\ntemplate<>\n__device__ inline __nv_bfloat16 getMaxValue<__nv_bfloat16>()\n{\n#if __CUDA_ARCH__ >= 800\n    return __ushort_as_bfloat16((unsigned short)0x7F7FU);\n#endif\n    return {};\n}\n#endif\n\ntemplate<class T>\n__global__ void BanBadWordsKernel(T*                logits,\n                                  const int* const* token_ids_ptrs,\n                                  const int*        sequence_length,\n                                  const int*        bad_words,\n                                  size_t            bad_words_len,\n                                  int               vocab_size)\n{\n    const int id        = blockIdx.x * blockDim.x + threadIdx.x;\n    const int batch_idx = blockIdx.y;\n\n    const int* base_bad_words         = bad_words + 
batch_idx * 2 * bad_words_len;\n    const int* base_bad_words_offsets = base_bad_words + bad_words_len;\n\n    if (id >= bad_words_len || base_bad_words_offsets[id] < 0) {\n        return;\n    }\n\n    const int item_end   = base_bad_words_offsets[id];\n    const int item_start = (id > 0) ? base_bad_words_offsets[id - 1] : 0;\n    const int item_size  = item_end - item_start;\n\n    const int  seq_len   = sequence_length[batch_idx];\n    const int* token_ids = token_ids_ptrs[batch_idx];\n\n    /* The single-token case unconditionally bans the token */\n    bool should_ban = item_size == 1;\n\n    /* Multi-token case and enough previously generated tokens to look for a match */\n    if (item_size > 1 && seq_len >= item_size - 1) {\n        should_ban = true;\n        for (int token_idx = item_size - 2, offset = seq_len - 1; token_idx >= 0; token_idx--, offset--) {\n            if (token_ids[offset] != base_bad_words[item_start + token_idx]) {\n                should_ban = false;\n                break;\n            }\n        }\n    }\n\n    logits += batch_idx * (int64_t)vocab_size;\n    if (should_ban) {\n        int banned_token = base_bad_words[item_end - 1];\n        if (0 < banned_token && banned_token < vocab_size) {\n            logits[banned_token] = -getMaxValue<T>();\n        }\n    }\n}\n\nvoid BanBadWords(Tensor&             logits,\n                 const Buffer_<int*> token_ids_ptrs,\n                 const Buffer_<int>& sequence_length,\n                 const Tensor_<int>& bad_words,\n                 cudaStream_t        stream)\n{\n\n    auto invoke = [&](auto dtype) {\n        using T = decltype(dtype);\n\n        const auto [bsz, vocab_size] = logits.shapes(0, 1);\n        const int bad_words_len      = bad_words.shape(2);\n\n        const int  block = std::min(round_up(bad_words_len, WARP_SIZE), 256);\n        const dim3 grid(cdiv(bad_words_len, block), bsz);\n\n        BanBadWordsKernel<<<grid, block, 0, stream>>>(logits.data<T>(),\n           
                                           token_ids_ptrs.data(),\n                                                      sequence_length.data(),\n                                                      bad_words.data(),\n                                                      bad_words_len,\n                                                      vocab_size);\n    };\n\n    invoke(float{});\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/ban_bad_words.h",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nvoid BanBadWords(Tensor&             logits,\n                 const Buffer_<int*> token_ids_ptrs,\n                 const Buffer_<int>& sequence_length,\n                 const Tensor_<int>& bad_words,\n                 cudaStream_t        stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/array.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/sub_byte_ptr.h\"\n\nnamespace turbomind {\n\ntemplate<typename T, int N>\nstruct Array {\n    using value_type      = T;\n    using size_type       = int;\n    using difference_type = int;\n    using reference       = value_type&;\n    using const_reference = const value_type&;\n    using pointer         = value_type*;\n    using const_pointer   = const value_type*;\n    using iterator        = pointer;\n    using const_iterator  = const_pointer;\n\n    static_assert(N > 0);\n\n    T __a[N];\n\n    TM_HOST_DEVICE constexpr reference operator[](size_type i) noexcept\n    {\n        return __a[i];\n    }\n\n    TM_HOST_DEVICE constexpr const_reference operator[](size_type i) const noexcept\n    {\n        return __a[i];\n    }\n\n    TM_HOST_DEVICE constexpr reference front() noexcept\n    {\n        return *begin();\n    }\n\n    TM_HOST_DEVICE constexpr const_reference front() const noexcept\n    {\n        return *begin();\n    }\n\n    TM_HOST_DEVICE constexpr reference back() noexcept\n    {\n        return *(end() - 1);\n    }\n\n    TM_HOST_DEVICE constexpr const_reference back() const noexcept\n    {\n        return *(end() - 1);\n    }\n\n    TM_HOST_DEVICE constexpr pointer data() noexcept\n    {\n        return &__a[0];\n    }\n\n    TM_HOST_DEVICE constexpr const_pointer data() const noexcept\n    {\n        return &__a[0];\n    }\n\n    TM_HOST_DEVICE constexpr iterator begin() noexcept\n    {\n        return data();\n    }\n\n    TM_HOST_DEVICE constexpr const_iterator begin() const noexcept\n    {\n        return data();\n    }\n\n    TM_HOST_DEVICE constexpr iterator end() noexcept\n    {\n        return data() + N;\n    }\n\n    TM_HOST_DEVICE constexpr const_iterator end() const 
noexcept\n    {\n        return data() + N;\n    }\n\n    TM_HOST_DEVICE static constexpr std::integral_constant<int, N> size() noexcept\n    {\n        return {};\n    }\n\n    TM_HOST_DEVICE static constexpr std::false_type empty() noexcept\n    {\n        return {};\n    }\n};\n\ntemplate<int N>\nstruct Array<uint4_t, N> {\n    using value_type      = detail::__uint4_t;\n    using size_type       = int;\n    using difference_type = int;\n    using reference       = value_type&;\n    using const_reference = const value_type&;\n    using pointer         = SubBytePtr<uint4_t>;\n    using const_pointer   = SubBytePtr<const uint4_t>;\n\n    // static_assert(N % 8 == 0);\n\n    detail::__uint4_t __a[N / 8];\n\n    TM_HOST_DEVICE constexpr reference operator[](size_type i) noexcept\n    {\n        return __a[i / 8];\n    }\n\n    TM_HOST_DEVICE constexpr const_reference operator[](size_type i) const noexcept\n    {\n        return __a[i / 8];\n    }\n\n    TM_HOST_DEVICE static constexpr std::integral_constant<int, N> size() noexcept\n    {\n        return {};\n    }\n\n    TM_HOST_DEVICE static constexpr std::false_type empty() noexcept\n    {\n        return {};\n    }\n\n    TM_HOST_DEVICE constexpr pointer data() noexcept\n    {\n        return {(char*)&__a[0]};\n    }\n};\n\nstatic_assert(sizeof(Array<uint4_t, 8>) == 4);\nstatic_assert(sizeof(Array<uint4_t, 16>) == 8);\nstatic_assert(sizeof(Array<uint4_t, 24>) == 12);\nstatic_assert(sizeof(Array<uint4_t, 32>) == 16);\n\ntemplate<int N>\nstruct Array<fp4_e2m1_t, N> {\n    using value_type      = detail::__uint4_t;\n    using size_type       = int;\n    using difference_type = int;\n    using reference       = value_type&;\n    using const_reference = const value_type&;\n    using pointer         = SubBytePtr<fp4_e2m1_t>;\n    using const_pointer   = SubBytePtr<const fp4_e2m1_t>;\n\n    // static_assert(N % 8 == 0);\n\n    detail::__uint4_t __a[N / 8];\n\n    TM_HOST_DEVICE constexpr reference operator[](size_type 
i) noexcept\n    {\n        return __a[i / 8];\n    }\n\n    TM_HOST_DEVICE constexpr const_reference operator[](size_type i) const noexcept\n    {\n        return __a[i / 8];\n    }\n\n    TM_HOST_DEVICE static constexpr std::integral_constant<int, N> size() noexcept\n    {\n        return {};\n    }\n\n    TM_HOST_DEVICE static constexpr std::false_type empty() noexcept\n    {\n        return {};\n    }\n\n    TM_HOST_DEVICE constexpr pointer data() noexcept\n    {\n        return {(char*)&__a[0]};\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/array_ops.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include <cassert>\n#include <type_traits>\n\nnamespace turbomind {\n\nnamespace ops {\n\ntemplate<typename T>\nstruct plus {\n    __device__ T operator()(T a, T b)\n    {\n        return a + b;\n    }\n};\n\ntemplate<typename T>\nstruct minus {\n    __device__ T operator()(T a, T b)\n    {\n        return a - b;\n    }\n};\n\ntemplate<typename T>\nstruct multiplies {\n    __device__ T operator()(T a, T b)\n    {\n        return a * b;\n    }\n};\n\ntemplate<typename T, int N, typename Op>\ninline __device__ Array<T, N> binary_op_vv(const Array<T, N>& a, const Array<T, N>& b, Op op)\n{\n    Array<T, N> c;\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; ++i) {\n        c[i] = op(a[i], b[i]);\n    }\n    return c;\n}\n\ntemplate<typename T, int N, typename Op>\ninline __device__ Array<T, N> binary_op_sv(const T& a, const Array<T, N>& b, Op op)\n{\n    Array<T, N> c;\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; ++i) {\n        c[i] = op(a, b[i]);\n    }\n    return c;\n}\n\ntemplate<typename T, int N, typename Op>\ninline __device__ Array<T, N> binary_op_vs(const Array<T, N>& a, const T& b, Op op)\n{\n    Array<T, N> c;\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; ++i) {\n        c[i] = op(a[i], b);\n    }\n    return c;\n}\n\ntemplate<typename T, int N>\ninline __device__ Array<T, N> operator+(const Array<T, N>& a, const Array<T, N>& b)\n{\n    return binary_op_vv(a, b, plus<T>{});\n}\n\ntemplate<typename T, int N>\ninline __device__ Array<T, N> operator*(const Array<T, N>& a, const Array<T, N>& b)\n{\n    return binary_op_vv(a, b, multiplies<T>{});\n}\n\ntemplate<typename T, int N>\ninline __device__ Array<T, N> operator*(const Array<T, N>& a, const T& b)\n{\n    return binary_op_vs(a, b, multiplies<T>{});\n}\n\ntemplate<typename T, int N>\ninline __device__ Array<T, N> 
operator+(const Array<T, N>& a, const T& b)\n{\n    return binary_op_vs(a, b, plus<T>{});\n}\n\ntemplate<typename T, int N>\ninline __device__ Array<T, N> operator-(const Array<T, N>& a, const T& b)\n{\n    return binary_op_vs(a, b, minus<T>{});\n}\n\n}  // namespace ops\n\ntemplate<typename To, typename From, int N>\ninline __device__ Array<To, N> cast(const Array<From, N>& src)\n{\n    Array<To, N> dst;\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; ++i) {\n        dst[i] = (To)src[i];\n    }\n    return dst;\n}\n\ntemplate<class T, int N>\ninline __device__ void fill(Array<T, N>& x, T val)\n{\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; ++i) {\n        x[i] = val;\n    }\n}\n\ntemplate<class T, int M, int N>\ninline __device__ void fill(Array<T, N> (&x)[M], T val)\n{\n    PRAGMA_UNROLL\n    for (int i = 0; i < M; ++i) {\n        fill(x[i], val);\n    }\n}\n\ntemplate<class T, int N>\ninline __device__ void clear(Array<T, N>& x)\n{\n    fill(x, T(0));\n}\n\ntemplate<class T, int M, int N>\ninline __device__ void clear(Array<T, N> (&x)[M])\n{\n    PRAGMA_UNROLL\n    for (int i = 0; i < M; ++i) {\n        clear(x[i]);\n    }\n}\n\ntemplate<class T, int M1, int M0, int N>\ninline __device__ void clear(Array<T, N> (&x)[M1][M0])\n{\n    PRAGMA_UNROLL\n    for (int m1 = 0; m1 < M1; ++m1) {\n        PRAGMA_UNROLL\n        for (int m0 = 0; m0 < M0; ++m0) {\n            clear(x[m1][m0]);\n        }\n    }\n}\n\ntemplate<class T, int N>\ninline __device__ void copy(const Array<T, N>& src, Array<T, N>& dst)\n{\n    dst = src;\n}\n\ntemplate<class T, int M, int N>\ninline __device__ void copy(const Array<T, N> (&src)[M], Array<T, N> (&dst)[M])\n{\n    PRAGMA_UNROLL\n    for (int m = 0; m < M; ++m) {\n        dst[m] = src[m];\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void Store(T* dst, const Array<T, N>& src)\n{\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        *(uint4*)dst = (const uint4&)src;\n    }\n    else if constexpr 
(sizeof(Array<T, N>) == sizeof(uint2)) {\n        *(uint2*)dst = (const uint2&)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint1)) {\n        *(uint1*)dst = (const uint1&)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(ushort)) {\n        *(ushort*)dst = (const ushort&)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(char)) {\n        *(char*)dst = (const char&)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) % sizeof(uint4) == 0) {  //  uncoalesced\n        static_assert(bitsof<T> % 8 == 0, \"raw pointer arithmetic of sub-byte types\");\n        constexpr int M = sizeof(Array<T, N>) / sizeof(uint4);\n        PRAGMA_UNROLL\n        for (int i = 0; i < M; ++i) {\n            *((uint4*)dst + i) = *((uint4*)&src + i);\n        }\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void Stcs(T* __restrict__ dst, const Array<T, N>& src)\n{\n    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));\n\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        __stcs((uint4*)dst, (const uint4&)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        __stcs((uint2*)dst, (const uint2&)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint1)) {\n        __stcs((uint*)dst, (const uint&)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint16_t)) {\n        __stcs((uint16_t*)dst, (const uint16_t&)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint8_t)) {\n        __stcs((uint8_t*)dst, (const uint8_t&)src);\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void Stcg(T* __restrict__ dst, const Array<T, N>& src)\n{\n    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));\n\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        __stcg((uint4*)dst, (const uint4&)src);\n  
  }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        __stcg((uint2*)dst, (const uint2&)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint1)) {\n        __stcg((uint*)dst, (const uint&)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint16_t)) {\n        __stcg((uint16_t*)dst, (const uint16_t&)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint8_t)) {\n        __stcg((uint8_t*)dst, (const uint8_t&)src);\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void Ldg(Array<T, N>& dst, const T* src)\n{\n    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));\n\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        (uint4&)dst = __ldg((const uint4*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        (uint2&)dst = __ldg((const uint2*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {\n        (uint&)dst = __ldg((const uint*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint16_t)) {\n        (uint16_t&)dst = __ldg((const uint16_t*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint8_t)) {\n        (uint8_t&)dst = __ldg((const uint8_t*)src);\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void Ldcs(Array<T, N>& dst, const T* src)\n{\n    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));\n\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        (uint4&)dst = __ldcs((const uint4*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        (uint2&)dst = __ldcs((const uint2*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {\n        (uint&)dst = __ldcs((const uint*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint16_t)) {\n        (uint16_t&)dst = 
__ldcs((const uint16_t*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint8_t)) {\n        (uint8_t&)dst = __ldcs((const uint8_t*)src);\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void Ldcg(Array<T, N>& dst, const T* src)\n{\n    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));\n\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        (uint4&)dst = __ldcg((const uint4*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        (uint2&)dst = __ldcg((const uint2*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {\n        (uint&)dst = __ldcg((const uint*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint16_t)) {\n        (uint16_t&)dst = __ldcg((const uint16_t*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint8_t)) {\n        (uint8_t&)dst = __ldcg((const uint8_t*)src);\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void Load(Array<T, N>& dst, const T* src)\n{\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        (uint4&)dst = *(const uint4*)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        (uint2&)dst = *(const uint2*)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {\n        (uint1&)dst = *(const uint1*)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint16_t)) {\n        (uint16_t&)dst = *(const uint16_t*)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint8_t)) {\n        (uint8_t&)dst = *(const uint8_t*)src;\n    }\n    else if constexpr (sizeof(Array<T, N>) % sizeof(uint4) == 0) {  //  uncoalesced\n        static_assert(bitsof<T> % 8 == 0, \"raw pointer arithmetic of sub-byte types\");\n        constexpr int M = sizeof(Array<T, N>) / sizeof(uint4);\n        
PRAGMA_UNROLL\n        for (int i = 0; i < M; ++i) {\n            *((uint4*)&dst + i) = *((uint4*)src + i);\n        }\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void Lds(Array<T, N>& dst, const T* src)\n{\n    Load(dst, src);\n}\n\ntemplate<typename T, int N>\ninline __device__ void LdShared(Array<T, N>& dst, uint32_t uintptr)\n{\n    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        uint4& p = (uint4&)dst;\n        // clang-format off\n        asm volatile(\"ld.shared.v4.b32 {%0,%1,%2,%3}, [%4];\\n\" : \"=r\"(p.x), \"=r\"(p.y), \"=r\"(p.z), \"=r\"(p.w) : \"r\"(uintptr));\n        // clang-format on\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        uint2& p = (uint2&)dst;\n        asm volatile(\"ld.shared.v2.b32 {%0,%1}, [%2];\\n\" : \"=r\"(p.x), \"=r\"(p.y) : \"r\"(uintptr));\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {\n        uint& p = (uint&)dst;\n        asm volatile(\"ld.shared.b32 %0, [%1];\\n\" : \"=r\"(p) : \"r\"(uintptr));\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<typename T, int N>\ninline __device__ void StShared(uint32_t uintptr, Array<T, N>& src)\n{\n    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        uint4& p = (uint4&)src;\n        // clang-format off\n        asm volatile(\"st.shared.v4.b32 [%0], {%1,%2,%3,%4};\\n\" :: \"r\"(uintptr), \"r\"(p.x), \"r\"(p.y), \"r\"(p.z), \"r\"(p.w) );\n        // clang-format on\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        uint2& p = (uint2&)src;\n        asm volatile(\"st.shared.v2.b32 [%0], {%1,%2};\\n\" ::\"r\"(uintptr), \"r\"(p.x), \"r\"(p.y));\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {\n        uint& p = (uint&)src;\n        
asm volatile(\"st.shared.b32  [%0], %1;\\n\" ::\"r\"(uintptr), \"r\"(p));\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<int kWarpCount, typename T, int N>\ninline __device__ Array<T, N> blockSum(Array<T, N> val, T* smem_red, int warp_id, int lane_id)\n{\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; ++i) {\n        PRAGMA_UNROLL\n        for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {\n            val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);\n        }\n        if (lane_id == 0) {\n            smem_red[i * kWarpCount + warp_id] = val[i];\n        }\n    }\n\n    __syncthreads();\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; ++i) {\n        val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : T{};\n        PRAGMA_UNROLL\n        for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {\n            val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);\n        }\n        val[i] = __shfl_sync((uint32_t)-1, val[i], 0);\n    }\n\n    return val;\n}\n\ntemplate<class T, int N>\n__device__ void CpAsync(T* dst, const Array<T, N>* __restrict__ src)\n{\n    const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);\n    constexpr int cp_size      = sizeof(Array<T, N>);\n#if TURBOMIND_ARCH_SM80\n    asm volatile(\"cp.async.ca.shared.global [%0], [%1], %2;\\n\" ::\"r\"(smem_int_ptr), \"l\"(src), \"n\"(cp_size));\n#else\n    assert(TURBOMIND_ARCH_SM80);\n#endif\n}\n\n__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value)\n{\n    const int lane_id  = threadIdx.x % WARP_SIZE;\n    int       src_lane = lane_id / 8 + lane_id % 4 * 8;\n    uint      u0       = __shfl_sync(0xffffffff, value, src_lane);\n    uint      u1       = __shfl_sync(0xffffffff, value, src_lane + 4);\n    short2    r;\n\n    if (lane_id % 8 < 4) {\n        r.x = ((short2&)u0).x;\n        r.y = ((short2&)u1).x;\n    }\n    else {\n        r.x = ((short2&)u0).y;\n        r.y = ((short2&)u1).y;\n    }\n    return 
(uint&)r;\n}\n\n#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))\n__inline__ __device__ uint transpose_m8n8_b16_movmatrix(uint a)\n{\n#if TURBOMIND_ARCH_SM75\n    uint d;\n    asm volatile(\"movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\\n\" : \"=r\"(d) : \"r\"(a));\n    return d;\n#else\n    assert(TURBOMIND_ARCH_SM75);\n    return 0;\n#endif\n}\n#endif\n\n__inline__ __device__ uint32_t transpose_m8n8_b16(uint32_t a)\n{\n    // `movmatrix` requires CUDA 11.8+ (PTX ISA 7.8) and sm_75+; otherwise fall back to warp shuffles\n#if ((__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) && TURBOMIND_ARCH_SM75\n    return transpose_m8n8_b16_movmatrix(a);\n#else\n    return transpose_m8n8_b16_warp_shuffle(a);\n#endif\n}\n\n__inline__ __device__ Array<uint32_t, 2> transpose_m8n8_b32(const Array<uint32_t, 2>& x)\n{\n    uint32_t lo = __byte_perm(x[0], x[1], 0x5410);\n    uint32_t hi = __byte_perm(x[0], x[1], 0x7632);\n\n    lo = transpose_m8n8_b16(lo);\n    hi = transpose_m8n8_b16(hi);\n\n    Array<uint32_t, 2> y;\n    y[0] = __byte_perm(lo, hi, 0x5410);\n    y[1] = __byte_perm(lo, hi, 0x7632);\n\n    return y;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/common.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))\n#define TURBOMIND_ARCH_SM70 1\n#else\n#define TURBOMIND_ARCH_SM70 0\n#endif\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))\n#define TURBOMIND_ARCH_SM75 1\n#else\n#define TURBOMIND_ARCH_SM75 0\n#endif\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))\n#define TURBOMIND_ARCH_SM80 1\n#else\n#define TURBOMIND_ARCH_SM80 0\n#endif\n\n#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))\n#define TURBOMIND_ARCH_SM90 1\n#else\n#define TURBOMIND_ARCH_SM90 0\n#endif\n\n#define TURBOMIND_ARCH_HAS_BF16 TURBOMIND_ARCH_SM80\n\n#define TURBOMIND_ARCH_HAS_FP8 TURBOMIND_ARCH_SM90\n\n#define TURBOMIND_ARCH_BF16_GUARD(type) (TURBOMIND_ARCH_HAS_BF16 || type != ::turbomind::kBfloat16)\n\n#define TURBOMIND_ARCH_FP8_GUARD(type)                                                                                 \\\n    (TURBOMIND_ARCH_HAS_FP8 || (type != ::turbomind::kFloat8_e4m3 && type != ::turbomind::kFloat8_e5m2))\n\n#define TURBOMIND_ARCH_DTYPE_GUARD(type) (TURBOMIND_ARCH_BF16_GUARD(type) && TURBOMIND_ARCH_FP8_GUARD(type))\n\n#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)\n#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))\n#define PRAGMA_UNROLL _Pragma(\"unroll\")\n#define PRAGMA_UNROLL_4 _Pragma(\"unroll 4\")\n#define PRAGMA_NO_UNROLL _Pragma(\"unroll 1\")\n\n#else\n#define PRAGMA_UNROLL #pragma unroll\n#define PRAGMA_UNROLL_4 #pragma unroll 4\n#define PRAGMA_NO_UNROLL #pragma unroll 1\n\n#endif\n#else\n#define PRAGMA_UNROLL\n#define PRAGMA_UNROLL_4\n#define PRAGMA_NO_UNROLL\n#endif\n\n#if defined(__CUDACC__)\n#define TM_HOST_DEVICE __forceinline__ __host__ __device__\n#define TM_DEVICE __forceinline__ __device__\n#define TM_HOST __forceinline__ __host__\n#else\n#define TM_HOST_DEVICE inline\n#define TM_DEVICE inline\n#define TM_HOST inline\n#endif\n\nconstexpr int WARP_SIZE = 32;\n\n#ifndef 
uint\nusing uint = unsigned int;\n#endif\n\n#ifndef ushort\nusing ushort = unsigned short int;\n#endif\n"
  },
  {
    "path": "src/turbomind/kernels/core/data_type.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cuda_fp16.h>\n#if ENABLE_BF16\n#include <cuda_bf16.h>\n#endif\n\n#include <cstdint>\n\n#include \"src/turbomind/core/data_type.h\"\n\nnamespace turbomind {\n\nnamespace detail {\n\nstruct __uint4_t {\n    uint32_t x;\n};\n\n}  // namespace detail\n\ntemplate<class T, class SFINAE = void>\nstruct get_pointer_type_t {\n    using type = T*;\n};\n\ntemplate<class T>\nusing get_pointer_type = typename get_pointer_type_t<T>::type;\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/floating_point.h",
    "content": "#pragma once\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/common.h\"\n\nnamespace turbomind {\n\ntemplate<int E, int M>\nstruct FloatingPoint {\n    static constexpr unsigned exponent_bits = E;\n    static constexpr unsigned mantissa_bits = M;\n    static constexpr unsigned exponent_bias = ((1 << exponent_bits) - 1) / 2;\n\n    static constexpr unsigned bits = 1 + exponent_bits + mantissa_bits;\n\n    static constexpr unsigned exponent_mask = (1 << exponent_bits) - 1;\n    static constexpr unsigned mantissa_mask = (1 << mantissa_bits) - 1;\n\n    // clang-format off\n    // `reinterpret_cast` is not allowed in constant expressions, so compute 2^e by repeated doubling\n    static constexpr float exp2(unsigned e) { float x = 1; for (; e > 0; --e) { x *= 2; } return x; }\n    // clang-format on\n\n    // The limits below assume no encodings are reserved for Inf/NaN (true for e2m1, e3m2 and e2m3)\n    static constexpr float max_normal =\n        ((1U << (mantissa_bits + 1U)) - 1U) * exp2(exponent_bias + 1) / exp2(mantissa_bits);\n    static constexpr float min_normal   = 1 / exp2(exponent_bias - 1);\n    static constexpr float max_denormal = mantissa_mask / exp2(exponent_bias - 1 + mantissa_bits);\n    static constexpr float min_denormal = 1 / exp2(exponent_bias - 1 + mantissa_bits);\n\n    // Modified from `__nv_cvt_double_to_fp8` in <cuda_fp8.hpp>\n    template<class R>\n    __device__ static unsigned from_f32(float x, R rbits)\n    {\n        constexpr bool stochastic = std::is_same_v<R, unsigned>;\n\n        // 1/2 LSB of the target format, positioned in single precision mantissa\n        constexpr int half_ulp = 1U << (23U - mantissa_bits - 1U);\n\n        auto absx = fabsf(x);\n\n        unsigned xbits = __float_as_uint(x);\n\n        unsigned sign     = (xbits >> 31U) << (bits - 1);\n        unsigned exp      = ((xbits >> 23U) & 0xFFU) - 127U + exponent_bias;\n        unsigned mantissa = (xbits >> (23U - mantissa_bits)) & mantissa_mask;\n\n        unsigned res;\n\n        if (absx <= min_denormal / 2.) 
{  // underflow\n            res = 0;\n        }\n        else if (absx > max_normal) {  // overflow\n            res = (exponent_mask << mantissa_bits) | mantissa_mask;\n        }\n        else if (absx >= min_normal) {  // normal\n            res = (exp << mantissa_bits) | mantissa;\n\n            unsigned round_mask = (half_ulp << 1U) - 1U;\n            // rounded-off bits\n            unsigned round = xbits & round_mask;\n            if constexpr (stochastic) {\n                // stochastic rounding (.rs) adjustment\n                if (round + (rbits & round_mask) > round_mask) {\n                    res += 1U;\n                }\n            }\n            else {\n                // round-to-nearest-even (.rn) adjustment\n                if ((round > half_ulp) || ((round == half_ulp) && (mantissa & 1U))) {\n                    res += 1U;\n                }\n            }\n        }\n        else {  // denormal\n            unsigned shift = 1U - exp;\n            // add implicit leading bit\n            mantissa |= 1U << mantissa_bits;\n            // additional round-off due to denormalization\n            res = mantissa >> shift;\n\n            unsigned round_mask = (half_ulp << (shift + 1U)) - 1U;\n            // rounded-off bits, including implicit leading bit\n            unsigned round = (xbits | (1U << 23U)) & round_mask;\n            if constexpr (stochastic) {\n                // stochastic rounding (.rs) adjustment\n                if (round + (rbits & round_mask) > round_mask) {\n                    res += 1U;\n                }\n            }\n            else {\n                // round-to-nearest-even (.rn) adjustment\n                if ((round > (half_ulp << shift)) || ((round == (half_ulp << shift)) && (res & 1U))) {\n                    res += 1U;\n                }\n            }\n        }\n\n        res |= sign;  // preserve sign\n\n        return res;\n    }\n\n    __device__ static float to_f32(unsigned x)\n    {\n        unsigned u = 
(x >> (bits - 1U)) << 31U;\n        u |= (x & ((1U << (bits - 1U)) - 1U)) << (23U - mantissa_bits);\n\n        unsigned e = (127U - exponent_bias + 127U) << 23U;\n\n        float res;\n        /// ! force non-FTZ multiplication\n        asm(\"mul.f32 %0, %1, %2;\" : \"=f\"(res) : \"r\"(u), \"r\"(e));\n\n        return res;\n    }\n};\n\nstatic_assert(FloatingPoint<2, 1>::max_normal == 6);\nstatic_assert(FloatingPoint<2, 1>::min_normal == 1);\nstatic_assert(FloatingPoint<2, 1>::max_denormal == .5);\nstatic_assert(FloatingPoint<2, 1>::min_denormal == .5);\n\nstatic_assert(FloatingPoint<3, 2>::max_normal == 28.0);\nstatic_assert(FloatingPoint<3, 2>::min_normal == 0.25);\nstatic_assert(FloatingPoint<3, 2>::max_denormal == 0.1875);\nstatic_assert(FloatingPoint<3, 2>::min_denormal == 0.0625);\n\nstatic_assert(FloatingPoint<2, 3>::max_normal == 7.5);\nstatic_assert(FloatingPoint<2, 3>::min_normal == 1.0);\nstatic_assert(FloatingPoint<2, 3>::max_denormal == 0.875);\nstatic_assert(FloatingPoint<2, 3>::min_denormal == 0.125);\n\n// FloatingPoint<4, 3>::max_normal;\n// FloatingPoint<4, 3>::min_normal;\n// FloatingPoint<4, 3>::max_denormal;\n// FloatingPoint<4, 3>::min_denormal;\n\n// FloatingPoint<5, 2>::max_normal;\n// FloatingPoint<5, 2>::min_normal;\n// FloatingPoint<5, 2>::max_denormal;\n// FloatingPoint<5, 2>::min_denormal;\n\n#if 0\n__device__ int cvt_rn_sat_e2m1_f32(float x)\n{\n    // 0000  0.0\n    // 0001  0.5\n    // 0010  1.0\n    // 0011  1.5\n    // 0100  2.0\n    // 0101  3.0\n    // 0110  4.0\n    // 0111  6.0\n\n    float z = fabs(x);\n    //   0.25  0.75   1.25  1.75  2.5   3.5    5.0\n    // 0.0   0.5   1.0   1.5   2.0   3.0   4.0   6.0\n    // 0000  0001  0010  0011  0100  0101  0110  0111\n    //   *           *           *           *\n    auto f = [](float z) {\n        if (z <= .25f) {\n            return 0;\n        }\n        else if (z < .75f) {\n            return 1;  // 0.5\n        }\n        else if (z <= 1.25f) {\n            return 2;  // 
1.0\n        }\n        else if (z < 1.75f) {\n            return 3;  // 1.5\n        }\n        else if (z <= 2.5f) {\n            return 4;  // 2.0\n        }\n        else if (z < 3.5f) {\n            return 5;  // 3.0\n        }\n        else if (z <= 5.f) {\n            return 6;  // 4.0\n        }\n        else {\n            return 7;  // 6.0\n        }\n    };\n\n    return f(z) | ((__float_as_uint(x) >> 31) << 3);\n}\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/layout.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/data_type.h\"\nnamespace turbomind {\n\ntemplate<int Bits, int Base, int Shift>\nstruct Swizzle {\n\n    using bit_mask = std::integral_constant<int, (1 << Bits) - 1>;\n    using yyy_mask = std::integral_constant<int, bit_mask{} << (Base + Shift)>;\n    using shift    = std::integral_constant<int, Shift>;\n\n    template<class Offset>\n    __host__ __device__ constexpr static auto apply(Offset offset)\n    {\n        return offset ^ ((offset & yyy_mask{}) >> shift{});\n    }\n\n    template<class Offset>\n    __host__ __device__ constexpr auto operator()(Offset offset)\n    {\n        return apply(offset);\n    }\n};\n\nstruct Identity {\n\n    template<class Offset>\n    __device__ constexpr static auto apply(Offset offset)\n    {\n        return offset;\n    }\n\n    template<class Offset>\n    __device__ Offset operator()(Offset offset)\n    {\n        return apply(offset);\n    }\n\n    template<int D>\n    __device__ int AdvanceS(int offset, int s0, int s1)\n    {\n        return offset;\n    }\n};\n\ntemplate<int S_, int C_, int S0_ = -1, int C0_ = -1, class Swizzle_ = Identity>\nstruct SmemLayoutV2 {\n\n    // (C0,S0),(   C1,       S1)\n    // ( 1,C0),(C0*S0, C0*S0*C1)\n\n    static constexpr int S = S_;\n    static constexpr int C = C_;\n\n    static constexpr int S0 = S0_ < 0 ? S : S0_;\n    static constexpr int C0 = C0_ < 0 ? 
C : C0_;\n\n    static_assert(S % S0 == 0);\n    static_assert(C % C0 == 0);\n\n    static constexpr int S1 = S / S0;\n    static constexpr int C1 = C / C0;\n\n    static constexpr int kSize = S * C;\n\n    static constexpr int kSize0 = S0 * C0;\n    static constexpr int kSize1 = S1 * C1;\n\n    using Swizzle = Swizzle_;\n\n    static constexpr int kIsTrivial = S == S0 && C == C0 && std::is_same_v<Swizzle, Identity>;\n\n    __forceinline__ __device__ static int apply(int s, int c, int offset = 0)\n    {\n        int s1 = s / S0;\n        int s0 = s % S0;\n        int c1 = c / C0;\n        int c0 = c % C0;\n        //            variable             | uniform |         constant\n        // return Swizzle::apply(s0 * C0 + c0) + offset + (s1 * C1 + c1) * kSize0;\n\n        // return offset + Swizzle::apply(s0 * C0 + c0) + (s1 * C1 + c1) * kSize0;\n\n        return Swizzle::apply(s0 * C0 + c0) + (s1 * C1 + c1) * kSize0 + offset;\n    }\n\n    __forceinline__ __device__ int operator()(int s, int c, int offset = 0)\n    {\n        return apply(s, c, offset);\n    }\n};\n\nstruct Offset {\n    __device__ explicit Offset(int value): value_{value} {};\n    __device__ int& operator()()\n    {\n        return value_;\n    }\n    __device__ const int& operator()() const\n    {\n        return value_;\n    }\n    int value_;\n};\n\ntemplate<class T, class Layout>\nstruct SmemAccessor {\n    using Pointer = get_pointer_type<T>;\n    Pointer ptr_;\n    Layout  layout_;\n\n    __device__ SmemAccessor(Pointer ptr): ptr_{ptr} {}\n\n    __device__ T& operator()(int s, int c)\n    {\n        return ptr_[layout_(s, c)];\n    }\n\n    __device__ T& operator()(int s, int c, int offset)\n    {\n        return ptr_[layout_(s, c, offset)];\n    }\n\n    __device__ T& operator()(int idx)\n    {\n        return ptr_[idx];\n    }\n};\n\ntemplate<class T0, class T1>\nstruct Stride {\n    T0 v0;\n    T1 v1;\n\n    // CTAD\n    __host__ __device__ Stride(T0 v0, T1 v1): v0{v0}, v1{v1} {}\n\n    
template<class I0, class I1>\n    __host__ __device__ constexpr auto operator()(I0 i0, I1 i1) const\n    {\n        return v0 * i0 + v1 * i1;\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/math.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include <cassert>\n#include <cstdint>\n#include <type_traits>\n\nnamespace turbomind {\n\ntemplate<class T>\nTM_HOST_DEVICE constexpr T ceil_div(T a, T b)\n{\n    return (a + b - 1) / b;\n}\n\ntemplate<class T>\nTM_HOST_DEVICE constexpr T cdiv(T a, T b)\n{\n    return (a + b - 1) / b;\n}\n\ntemplate<class T>\nTM_HOST_DEVICE constexpr T round_up(T a, T b)\n{\n    return (a + b - 1) / b * b;\n}\n\ntemplate<class T>\nTM_HOST_DEVICE constexpr T log2(T x)\n{\n    T n = 0;\n    while (x != 1) {\n        x /= 2;\n        ++n;\n    }\n    return n;\n}\n\n// static_assert(log2(65536) == 16);\n// static_assert(log2(32) == 5);\n// static_assert(log2(1) == 0);\n\ntemplate<class T>\nTM_HOST_DEVICE constexpr T lowbit(T x)\n{\n    const std::make_signed_t<T> s = x;\n    return static_cast<T>(s & -s);\n}\n\n// https://arxiv.org/abs/1902.01961\ntemplate<class T>\nstruct FastDivMod {\n};\n\ntemplate<>\nstruct FastDivMod<uint16_t> {\n    uint32_t c_;  // cdiv(2^32,d) = (2^32+d-1)/d = (2^32-1)/d+1\n    uint32_t d_;\n\n    TM_HOST_DEVICE constexpr FastDivMod(uint16_t d): c_{0xFFFFFFFF / d + 1}, d_{d} {}\n\n    template<class T>\n    TM_HOST_DEVICE friend constexpr uint16_t operator/(T a, FastDivMod b)\n    {\n        return (a * (uint64_t)b.c_) >> 32;\n    }\n\n    template<class T>\n    TM_HOST_DEVICE friend constexpr uint16_t operator%(T a, FastDivMod b)\n    {\n        uint64_t lowbits = (a * (uint64_t)b.c_) & 0xFFFFFFFF;\n        return (lowbits * b.d_) >> 32;\n    }\n\n    TM_HOST_DEVICE constexpr operator uint16_t() const noexcept\n    {\n        return d_;\n    }\n};\n\nstatic_assert(32 / FastDivMod<uint16_t>{5} == 6);\nstatic_assert(32 % FastDivMod<uint16_t>{5} == 2);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/meta.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind {\n\ntemplate<class T>\nstruct basic_type {\n    using type = T;\n};\n\ntemplate<class T>\nconstexpr basic_type<T> type_c{};\n\ntemplate<auto v>\nstruct constant {\n    using type       = constant;\n    using value_type = decltype(v);\n\n    static constexpr value_type value = v;\n\n    constexpr value_type operator()() const noexcept\n    {\n        return v;\n    }\n    constexpr operator value_type() const noexcept\n    {\n        return v;\n    }\n};\n\ntemplate<auto u, auto v>\nstruct pair {\n};\n\ntemplate<auto u, auto v>\nconstexpr auto first(pair<u, v>)\n{\n    return u;\n}\n\ntemplate<auto u, auto v>\nconstexpr auto second(pair<u, v>)\n{\n    return v;\n}\n\ntemplate<auto u, auto v, auto w>\nstruct triplet {\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/mma.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include <cassert>\n\nnamespace turbomind {\n\n__inline__ __device__ void\nmma_m8n8k4_row_col(Array<float, 8>& d, const Array<half, 4>& a, const Array<half, 4>& b, Array<float, 8>& c)\n{\n#if TURBOMIND_ARCH_SM70\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    // clang-format off\n    asm volatile(\n        \"mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32\"\n        \"{%0,  %1,  %2,  %3,  %4,  %5,  %6,  %7},\"\n        \"{%8,  %9},\"\n        \"{%10, %11},\"\n        \"{%12, %13, %14, %15, %16, %17, %18, %19};\"\n        : \"=f\"(d[0]), \"=f\"(d[1]), \"=f\"(d[2]), \"=f\"(d[3]), \"=f\"(d[4]), \"=f\"(d[5]), \"=f\"(d[6]), \"=f\"(d[7])\n        : \"r\"(A[0]), \"r\"(A[1]),\n          \"r\"(B[0]), \"r\"(B[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]), \"f\"(c[4]), \"f\"(c[5]), \"f\"(c[6]), \"f\"(c[7]));\n// clang-format on\n#endif\n}\n\n__inline__ __device__ void\nmma_m8n8k4_row_row(Array<float, 8>& d, const Array<half, 4>& a, const Array<half, 4>& b, Array<float, 8>& c)\n{\n#if TURBOMIND_ARCH_SM70\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    // clang-format off\n    asm volatile(\n        \"mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32\"\n        \"{%0,  %1,  %2,  %3,  %4,  %5,  %6,  %7},\"\n        \"{%8,  %9},\"\n        \"{%10, %11},\"\n        \"{%12, %13, %14, %15, %16, %17, %18, %19};\"\n        : \"=f\"(d[0]), \"=f\"(d[1]), \"=f\"(d[2]), \"=f\"(d[3]), \"=f\"(d[4]), \"=f\"(d[5]), \"=f\"(d[6]), \"=f\"(d[7])\n        : \"r\"(A[0]), \"r\"(A[1]),\n          \"r\"(B[0]), \"r\"(B[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]), \"f\"(c[4]), \"f\"(c[5]), \"f\"(c[6]), 
\"f\"(c[7]));\n// clang-format on\n#endif\n}\n\n__inline__ __device__ void\nmma_m16n8k8_row_col(Array<float, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<float, 4>& c)\n{\n#if TURBOMIND_ARCH_SM75\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    float const*    C = reinterpret_cast<float const*>(&c);\n    float*          D = reinterpret_cast<float*>(&d);\n    asm volatile(\"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32  {%0,%1,%2,%3}, \"\n                 \"{%4,%5}, {%6}, {%7,%8,%9,%10};\\n\"\n                 : \"=f\"(D[0]), \"=f\"(D[1]), \"=f\"(D[2]), \"=f\"(D[3])\n                 : \"r\"(A[0]), \"r\"(A[1]), \"r\"(B[0]), \"f\"(C[0]), \"f\"(C[1]), \"f\"(C[2]), \"f\"(C[3]));\n#else\n    assert(TURBOMIND_ARCH_SM75);\n#endif\n}\n\n__inline__ __device__ void\nmma_m16n8k8_row_col(Array<half, 4>& d, const Array<half, 4>& a, const Array<half, 2>& b, Array<half, 4>& c)\n{\n#if TURBOMIND_ARCH_SM75\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    uint32_t const* C = reinterpret_cast<uint32_t const*>(&c);\n    uint32_t*       D = reinterpret_cast<uint32_t*>(&d);\n    asm volatile(\"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16  {%0,%1}, \"\n                 \"{%2,%3}, {%4}, {%5,%6};\\n\"\n                 : \"=r\"(D[0]), \"=r\"(D[1])\n                 : \"r\"(A[0]), \"r\"(A[1]), \"r\"(B[0]), \"r\"(C[0]), \"r\"(C[1]));\n#else\n    assert(TURBOMIND_ARCH_SM75);\n#endif\n}\n\n__inline__ __device__ void mma_m16n8k8_row_col(Array<float, 4>&             d,\n                                               const Array<nv_bfloat16, 4>& a,\n                                               const Array<nv_bfloat16, 2>& b,\n                                               Array<float, 4>&             c)\n{\n#if TURBOMIND_ARCH_SM80\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    float const*    C = reinterpret_cast<float const*>(&c);\n    float*          D = reinterpret_cast<float*>(&d);\n    asm volatile(\"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32  {%0,%1,%2,%3}, \"\n                 \"{%4,%5}, {%6}, {%7,%8,%9,%10};\\n\"\n                 : \"=f\"(D[0]), \"=f\"(D[1]), \"=f\"(D[2]), \"=f\"(D[3])\n                 : \"r\"(A[0]), \"r\"(A[1]), \"r\"(B[0]), \"f\"(C[0]), \"f\"(C[1]), \"f\"(C[2]), \"f\"(C[3]));\n#else\n    assert(TURBOMIND_ARCH_SM80);\n#endif\n}\n\n__inline__ __device__ void mma_m16n8k8_row_col(Array<nv_bfloat16, 4>&       d,\n                                               const Array<nv_bfloat16, 4>& a,\n                                               const Array<nv_bfloat16, 2>& b,\n                                               Array<nv_bfloat16, 4>&       c)\n{\n#if TURBOMIND_ARCH_SM80\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    uint32_t const* C = reinterpret_cast<uint32_t const*>(&c);\n    uint32_t*       D = reinterpret_cast<uint32_t*>(&d);\n    asm volatile(\"mma.sync.aligned.m16n8k8.row.col.bf16.bf16.bf16.bf16  {%0,%1}, \"\n                 \"{%2,%3}, {%4}, {%5,%6};\\n\"\n                 : \"=r\"(D[0]), \"=r\"(D[1])\n                 : \"r\"(A[0]), \"r\"(A[1]), \"r\"(B[0]), \"r\"(C[0]), \"r\"(C[1]));\n#else\n    assert(TURBOMIND_ARCH_SM80);\n#endif\n}\n\n__inline__ __device__ void\nmma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)\n{\n#if TURBOMIND_ARCH_SM80\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    float const*    C = reinterpret_cast<float const*>(&c);\n    float*          D = reinterpret_cast<float*>(&d);\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32  {%0,%1,%2,%3}, \"\n      
  \"{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(D[0]), \"=f\"(D[1]), \"=f\"(D[2]), \"=f\"(D[3])\n        : \"r\"(A[0]), \"r\"(A[1]), \"r\"(A[2]), \"r\"(A[3]), \"r\"(B[0]), \"r\"(B[1]), \"f\"(C[0]), \"f\"(C[1]), \"f\"(C[2]), \"f\"(C[3]));\n#else\n    const Array<half, 4>* _a = (const Array<half, 4>*)&a;\n    const Array<half, 2>* _b = (const Array<half, 2>*)&b;\n    mma_m16n8k8_row_col(d, _a[0], _b[0], c);\n    mma_m16n8k8_row_col(d, _a[1], _b[1], d);\n#endif\n}\n\n__inline__ __device__ void\nmma_m16n8k16_row_col(Array<half, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<half, 4>& c)\n{\n#if TURBOMIND_ARCH_SM80\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    uint32_t const* C = reinterpret_cast<uint32_t const*>(&c);\n    uint32_t*       D = reinterpret_cast<uint32_t*>(&d);\n    asm volatile(\"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16  {%0,%1}, \"\n                 \"{%2,%3,%4,%5}, {%6,%7}, {%8,%9};\\n\"\n                 : \"=r\"(D[0]), \"=r\"(D[1])\n                 : \"r\"(A[0]), \"r\"(A[1]), \"r\"(A[2]), \"r\"(A[3]), \"r\"(B[0]), \"r\"(B[1]), \"r\"(C[0]), \"r\"(C[1]));\n#else\n    const Array<half, 4>* _a = (const Array<half, 4>*)&a;\n    const Array<half, 2>* _b = (const Array<half, 2>*)&b;\n    mma_m16n8k8_row_col(d, _a[0], _b[0], c);\n    mma_m16n8k8_row_col(d, _a[1], _b[1], d);\n#endif\n}\n\n__inline__ __device__ void mma_m16n8k16_row_col(Array<float, 4>&             d,\n                                                const Array<nv_bfloat16, 8>& a,\n                                                const Array<nv_bfloat16, 4>& b,\n                                                Array<float, 4>&             c)\n{\n#if TURBOMIND_ARCH_SM80\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    float const*    C = reinterpret_cast<float const*>(&c);\n    
float*          D = reinterpret_cast<float*>(&d);\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32  {%0,%1,%2,%3}, \"\n        \"{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(D[0]), \"=f\"(D[1]), \"=f\"(D[2]), \"=f\"(D[3])\n        : \"r\"(A[0]), \"r\"(A[1]), \"r\"(A[2]), \"r\"(A[3]), \"r\"(B[0]), \"r\"(B[1]), \"f\"(C[0]), \"f\"(C[1]), \"f\"(C[2]), \"f\"(C[3]));\n#else\n    const Array<nv_bfloat16, 4>* _a = (const Array<nv_bfloat16, 4>*)&a;\n    const Array<nv_bfloat16, 2>* _b = (const Array<nv_bfloat16, 2>*)&b;\n    mma_m16n8k8_row_col(d, _a[0], _b[0], c);\n    mma_m16n8k8_row_col(d, _a[1], _b[1], d);\n#endif\n}\n\n__inline__ __device__ void mma_m16n8k16_row_col(Array<nv_bfloat16, 4>&       d,\n                                                const Array<nv_bfloat16, 8>& a,\n                                                const Array<nv_bfloat16, 4>& b,\n                                                Array<nv_bfloat16, 4>&       c)\n{\n#if TURBOMIND_ARCH_SM80\n    uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);\n    uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);\n    uint32_t const* C = reinterpret_cast<uint32_t const*>(&c);\n    uint32_t*       D = reinterpret_cast<uint32_t*>(&d);\n    asm volatile(\"mma.sync.aligned.m16n8k16.row.col.bf16.bf16.bf16.bf16  {%0,%1}, \"\n                 \"{%2,%3,%4,%5}, {%6,%7}, {%8,%9};\\n\"\n                 : \"=r\"(D[0]), \"=r\"(D[1])\n                 : \"r\"(A[0]), \"r\"(A[1]), \"r\"(A[2]), \"r\"(A[3]), \"r\"(B[0]), \"r\"(B[1]), \"r\"(C[0]), \"r\"(C[1]));\n#else\n    const Array<nv_bfloat16, 4>* _a = (const Array<nv_bfloat16, 4>*)&a;\n    const Array<nv_bfloat16, 2>* _b = (const Array<nv_bfloat16, 2>*)&b;\n    mma_m16n8k8_row_col(d, _a[0], _b[0], c);\n    mma_m16n8k8_row_col(d, _a[1], _b[1], d);\n#endif\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/pipe_iter.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind {\n\ntemplate<int Stages, int Step = 1>\nstruct PipeIter {\n    static constexpr int kMaxStep = Stages * Step;\n\n    int r = 0;\n    int w = kMaxStep - Step;\n\n    __inline__ __device__ PipeIter& operator++()\n    {\n        w = r;\n        r += Step;\n        if (r == kMaxStep) {\n            r -= kMaxStep;\n        }\n        return *this;\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/smem.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include <cassert>\n\nnamespace turbomind {\n\n__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)\n{\n    return (uint32_t)__cvta_generic_to_shared(ptr);\n}\n\n__inline__ __device__ void ldmatrix_m8n8_x4_b16(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)\n{\n#if TURBOMIND_ARCH_SM75\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n                 : \"=r\"(d0), \"=r\"(d1), \"=r\"(d2), \"=r\"(d3)\n                 : \"r\"(smem_int_ptr));\n#else\n    assert(TURBOMIND_ARCH_SM75);\n#endif\n}\n\n__inline__ __device__ void ldsm_x4_trans(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)\n{\n#if TURBOMIND_ARCH_SM75\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n                 : \"=r\"(d0), \"=r\"(d1), \"=r\"(d2), \"=r\"(d3)\n                 : \"r\"(smem_int_ptr));\n#else\n    assert(TURBOMIND_ARCH_SM75);\n#endif\n}\n\n__inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t smem_int_ptr)\n{\n#if TURBOMIND_ARCH_SM75\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\\n\" : \"=r\"(d0), \"=r\"(d1) : \"r\"(smem_int_ptr));\n#else\n    assert(TURBOMIND_ARCH_SM75);\n#endif\n}\n\n__inline__ __device__ void ldsm_x2_trans(uint& d0, uint& d1, uint32_t smem_int_ptr)\n{\n#if TURBOMIND_ARCH_SM75\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0,%1}, [%2];\\n\"\n                 : \"=r\"(d0), \"=r\"(d1)\n                 : \"r\"(smem_int_ptr));\n#else\n    assert(TURBOMIND_ARCH_SM75);\n#endif\n}\n\n__inline__ __device__ void ldmatrix_m8n8_x1_b16(uint& d0, uint32_t smem_int_ptr)\n{\n#if TURBOMIND_ARCH_SM75\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.shared.b16 %0, [%1];\\n\" : \"=r\"(d0) : 
\"r\"(smem_int_ptr));\n#else\n    assert(TURBOMIND_ARCH_SM75);\n#endif\n}\n\n__inline__ __device__ void ldsm_x1_trans(uint& d0, uint32_t smem_int_ptr)\n{\n#if TURBOMIND_ARCH_SM75\n    asm volatile(\"ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 %0, [%1];\\n\" : \"=r\"(d0) : \"r\"(smem_int_ptr));\n#else\n    assert(TURBOMIND_ARCH_SM75);\n#endif\n}\n\n__inline__ __device__ void ldsm_x4(Array<uint32_t, 4>& d, uint32_t smem_int_ptr)\n{\n    ldmatrix_m8n8_x4_b16(d[0], d[1], d[2], d[3], smem_int_ptr);\n}\n\n__inline__ __device__ void ldsm_x2(Array<uint32_t, 2>& d, uint32_t smem_int_ptr)\n{\n    ldmatrix_m8n8_x2_b16(d[0], d[1], smem_int_ptr);\n}\n\n__inline__ __device__ void ldsm_x1(Array<uint32_t, 1>& d, uint32_t smem_int_ptr)\n{\n    ldmatrix_m8n8_x1_b16(d[0], smem_int_ptr);\n}\n\n__inline__ __device__ void ldsm_x4_trans(Array<uint32_t, 4>& d, uint32_t smem_int_ptr)\n{\n    ldsm_x4_trans(d[0], d[1], d[2], d[3], smem_int_ptr);\n}\n\n__inline__ __device__ void ldsm_x2_trans(Array<uint32_t, 2>& d, uint32_t smem_int_ptr)\n{\n    ldsm_x2_trans(d[0], d[1], smem_int_ptr);\n}\n\n__inline__ __device__ void ldsm_x1_trans(Array<uint32_t, 1>& d, uint32_t smem_int_ptr)\n{\n    ldsm_x1_trans(d[0], smem_int_ptr);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/sub_byte_ptr.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/data_type.h\"\n\nnamespace turbomind {\n\ntemplate<class T>\nstruct SubBytePtr {\n\n    constexpr SubBytePtr() = default;\n\n    constexpr __host__ __device__ explicit SubBytePtr(T* ptr): ptr_((char*)ptr) {}\n\n    constexpr __host__ __device__ SubBytePtr(char* ptr): ptr_(ptr) {}\n\n    __host__ __device__ T& operator[](int i)\n    {\n        return *reinterpret_cast<T*>(ptr_ + i * bitsof<T> / bitsof<char>);\n    }\n\n    friend __host__ __device__ SubBytePtr operator+(const SubBytePtr a, int n)\n    {\n        return SubBytePtr{a.ptr_ + n * bitsof<T> / bitsof<char>};\n    }\n\n    friend __host__ __device__ SubBytePtr operator+(int n, const SubBytePtr a)\n    {\n        return a + n;\n    }\n\n    friend __host__ __device__ bool operator==(const SubBytePtr& a, const SubBytePtr& b)\n    {\n        return a.ptr_ == b.ptr_;\n    }\n\n    __host__ __device__ explicit operator T*() const\n    {\n        return (T*)ptr_;\n    }\n\n    char* ptr_;\n};\n\ntemplate<class T>\nstruct get_pointer_type_t<T, std::enable_if_t<bitsof<T> % 8 != 0>> {\n    using type = SubBytePtr<T>;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/sync.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind {\n\n__inline__ __device__ int sem_fetch(int* lock, bool pred)\n{\n    int state{};\n    if (pred) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700\n        asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n#else\n        asm volatile(\"ld.global.cg.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n#endif\n    }\n    return state;\n}\n\n__inline__ __device__ void sem_wait(int* lock, int status, bool pred)\n{\n    int state = 0;\n    while (__syncthreads_and(state != status)) {\n        state = sem_fetch(lock, pred);\n    }\n\n    __syncthreads();  // memory fence\n}\n\n__inline__ __device__ void sem_wait_many(int* lock, int count, bool pred)\n{\n    int state = 0;\n    while (__syncthreads_count(state) != count) {\n        state = sem_fetch(lock, pred);\n    }\n\n    __syncthreads();  // memory fence\n}\n\n__inline__ __device__ void sem_post(int* lock, int status, bool pred)\n{\n    __syncthreads();  // memory fence\n\n    if (pred) {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700\n        asm volatile(\"st.global.release.gpu.b32 [%0], %1;\\n\" : : \"l\"(lock), \"r\"(status));\n#else\n        asm volatile(\"st.global.cg.b32 [%0], %1;\\n\" : : \"l\"(lock), \"r\"(status));\n#endif\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/core/thread_map.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\n#include <iostream>\n\nnamespace turbomind {\n\ntemplate<int C, int S, int AccessC, int WarpCount>\nstruct ThreadMapQ {\n    static constexpr int kWarpCount = WarpCount;\n    static constexpr int kAccessC   = AccessC;\n\n    static constexpr int kWarpThreadC = C / kAccessC;\n    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;\n\n    static_assert(kWarpThreadC <= WARP_SIZE);\n\n    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;  // C\n    static constexpr int kWarpAccessS = kWarpThreadS;\n\n    static constexpr int kWarpIterC = C / kWarpAccessC;  // 1\n    static constexpr int kWarpIterS = S / kWarpAccessS;\n\n    static constexpr int kWarpC = 1;\n    static constexpr int kWarpS = kWarpCount;\n\n    static constexpr int kIterC = kWarpIterC / kWarpC;  // 1\n    static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);\n\n    static constexpr int kFootprintC = kWarpAccessC * kIterC;  // C\n    static constexpr int kFootprintS = kWarpAccessS * kIterS;\n\n    static constexpr int kDeltaC = kWarpAccessC;\n    static constexpr int kDeltaS = kWarpAccessS;\n\n    __device__ static int2 get_offset(int warp_id, int lane_id)\n    {\n        int warp_offset_c = warp_id % kWarpC;\n        int warp_offset_s = warp_id / kWarpC;\n\n        int warp_thread_offset_c = lane_id % kWarpThreadC;\n        int warp_thread_offset_s = lane_id / kWarpThreadC;\n\n        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;\n        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;\n\n        return {cta_thread_offset_c, cta_thread_offset_s};\n    }\n};\n\ntemplate<int DimC, int DimS, int AccessC, int WarpCount, int WarpThreadC = lowbit(DimC) / AccessC, int WarpC = 1>\nstruct RakedThreadMap {\n    static constexpr int 
kDimC = DimC;\n    static constexpr int kDimS = DimS;\n\n    static constexpr int kWarpCount = WarpCount;\n    static constexpr int kAccessC   = AccessC;\n\n    static constexpr int kWarpThreadC = WarpThreadC;\n    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;\n\n    static_assert(WARP_SIZE % kWarpThreadC == 0);\n\n    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;\n    static constexpr int kWarpAccessS = kWarpThreadS;\n\n    static constexpr int kWarpIterC = cdiv(kDimC, kWarpAccessC);\n    static constexpr int kWarpIterS = cdiv(kDimS, kWarpAccessS);\n\n    static constexpr int kWarpC = WarpC;\n    static constexpr int kWarpS = kWarpCount / kWarpC;\n\n    static_assert(kWarpCount % kWarpC == 0);\n\n    static constexpr int kIterC = cdiv(kWarpIterC, kWarpC);\n    static constexpr int kIterS = cdiv(kWarpIterS, kWarpS);\n\n    // Allow partial tile when there is ONLY 1 iteration\n    static_assert(kDimC % kWarpAccessC == 0 || kIterC == 1);\n\n    static constexpr bool kPartialC = kDimC % kWarpAccessC != 0;\n\n    static constexpr int kFootprintC = kWarpAccessC * kIterC;\n    static constexpr int kFootprintS = kWarpAccessS * kIterS;\n\n    static constexpr int kDeltaC = kWarpAccessC;\n    static constexpr int kDeltaS = kWarpAccessS;\n\n    // static constexpr int kDeltaC = kWarpAccessC * kWarpC;\n    // static constexpr int kDeltaS = kWarpAccessS * kWarpS;\n\n    __device__ static int2 get_offset(int warp_id, int lane_id)\n    {\n        int warp_offset_c = warp_id % kWarpC;\n        int warp_offset_s = warp_id / kWarpC;\n\n        int warp_thread_offset_c = lane_id % kWarpThreadC;\n        int warp_thread_offset_s = lane_id / kWarpThreadC;\n\n        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;\n        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;\n\n        // int cta_thread_offset_c = kWarpAccessC * warp_offset_c + warp_thread_offset_c * kAccessC;\n        // 
int cta_thread_offset_s = kWarpAccessS * warp_offset_s + warp_thread_offset_s;\n\n        return {cta_thread_offset_c, cta_thread_offset_s};\n    }\n};\n\nnamespace {\n\ntemplate<class TMap>\nvoid Print(TMap)\n{\n    std::cout << \"     warps: \" << TMap::kWarpCount << \"\\n\";\n    std::cout << \"     shape: (\" << TMap::kDimC << \", \" << TMap::kDimS << \")\\n\";\n    std::cout << \"    access: (\" << TMap::kAccessC << \", \" << 1 << \")\\n\";\n    std::cout << \"warpThread: (\" << TMap::kWarpThreadC << \", \" << TMap::kWarpThreadS << \")\\n\";\n    std::cout << \"warpAccess: (\" << TMap::kWarpAccessC << \", \" << TMap::kWarpAccessS << \")\\n\";\n    std::cout << \"  warpIter: (\" << TMap::kWarpIterC << \", \" << TMap::kWarpIterS << \")\\n\";\n    std::cout << \"      warp: (\" << TMap::kWarpC << \", \" << TMap::kWarpS << \")\\n\";\n    std::cout << \"      iter: (\" << TMap::kIterC << \", \" << TMap::kIterS << \")\\n\";\n    std::cout << \" footprint: (\" << TMap::kFootprintC << \", \" << TMap::kFootprintS << \")\\n\";\n    std::cout << \"     delta: (\" << TMap::kDeltaC << \", \" << TMap::kDeltaS << \")\\n\";\n    std::cout << \"  partialC: \" << TMap::kPartialC << \"\\n\";\n}\n\n}  // namespace\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/decoding_kernels.cu",
    "content": "/*\n * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/kernels/decoding_kernels.h\"\n#include \"src/turbomind/kernels/reduce_kernel_utils.cuh\"\n#include \"src/turbomind/utils/cuda_type_utils.cuh\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\n// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts\ntemplate<typename T>\n__global__ void embeddingLookupPosEncoding(T*            from_tensor,\n                                           const T*      embedding_table,\n                                           const T*      position_encoding,\n                                           const int*    all_ids,\n                                           const int*    padding_count,\n                                           const int*    input_lengths,\n                                           const int     local_token_num,\n                                           const int64_t hidden_units,\n                                           const int     step,\n                                           const int     max_input_length,\n                                           const int     token_num,\n                                           const int     ite,\n                                           const T       scale)\n{\n    // 1. lookup from embedding table\n    // 2. 
multiply scale\n    // 3. add the position encoding\n    const int id_offset = step * token_num + ite * local_token_num;\n\n    const bool use_padding_count = padding_count != nullptr;\n    const bool use_input_len     = input_lengths != nullptr;\n\n    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;\n         index += blockDim.x * gridDim.x) {\n        const int row_index   = index / hidden_units;\n        const int col_index   = index % hidden_units;\n        int       step_offset = step;\n        if (use_padding_count) {\n            step_offset -= padding_count[row_index];\n        }\n        else if (use_input_len) {\n            step_offset -= max_input_length - input_lengths[row_index];\n        }\n        step_offset *= hidden_units;\n\n        T val = embedding_table[all_ids[id_offset + row_index] * hidden_units + col_index] * scale;\n        val   = val + position_encoding[step_offset + col_index];\n\n        from_tensor[index] = val;\n    }\n}\n\n// No absolute position embedding\n// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts\ntemplate<typename T, int PROMPT_SRC>\n__global__ void embeddingLookup(T*                    from_tensor,\n                                const T*              embedding_table,\n                                const int*            all_ids,\n                                pPromptTuningParam<T> prompt_param,\n                                const int             local_token_num,\n                                const int64_t         hidden_units,\n                                const int             step,\n                                const int             token_num,\n                                const int             ite,\n                                const int             seq_len,\n                                const T               scale)\n{\n    // 1. lookup from embedding table\n    // 2. 
multiply scale\n    const int id_offset = step * token_num + ite * local_token_num;\n\n    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;\n         index += blockDim.x * gridDim.x) {\n\n        const int word_index     = index / hidden_units;\n        const int word_index_row = word_index / seq_len;  // batch_id\n        const int col_index      = index % hidden_units;\n        const int input_id       = all_ids == nullptr ? word_index : all_ids[id_offset + word_index];\n        const int prompt_id      = input_id - prompt_param.p_prompt_tuning_id_start;\n        T         embedding      = (T)0.0f;\n        if (PROMPT_SRC > 0 && prompt_id >= 0) {\n            if (PROMPT_SRC == 1) {\n                // from loaded prompt embedding tables\n                embedding =\n                    prompt_param.p_prompt_tuning_batch_weights[word_index_row][prompt_id * hidden_units + col_index];\n            }\n            else {\n                // from request prompt embedding\n                embedding =\n                    prompt_param\n                        .request_prompt_embedding[word_index_row * prompt_param.request_prompt_max_length * hidden_units\n                                                  + prompt_id * hidden_units + col_index];\n            }\n        }\n        else {\n            embedding = embedding_table[input_id * hidden_units + col_index];\n        }\n        from_tensor[index] = embedding * scale;\n    }\n}\n\n#define EMBEDDING_LOOKUP(PROMPT_SRC)                                                                                   \\\n    embeddingLookup<T, PROMPT_SRC><<<grid, block, 0, stream>>>(from_tensor,                                            \\\n                                                               embedding_table,                                        \\\n                                                               all_ids,                                                \\\n    
                                                           prompt_param,                                           \\\n                                                               local_token_num,                                        \\\n                                                               hidden_units,                                           \\\n                                                               step,                                                   \\\n                                                               token_num,                                              \\\n                                                               ite,                                                    \\\n                                                               seq_len,                                                \\\n                                                               scale);\n\n/* Adapter function for invokeEmbeddingLookupPosEncoding{PadCount,InputLen} */\ntemplate<typename T>\nvoid invokeEmbeddingLookupPosEncoding(T*                    from_tensor,\n                                      const T*              embedding_table,\n                                      const T*              position_encoding,\n                                      const int*            all_ids,\n                                      const int*            padding_count,\n                                      const int*            input_lengths,\n                                      pPromptTuningParam<T> prompt_param,\n                                      const int             local_token_num,\n                                      const int             hidden_units,\n                                      const T               scale,\n                                      const int             step,\n                                      const int             max_input_length,\n                                      const int             
token_num,\n                                      const int             ite,\n                                      const int             seq_len,\n                                      cudaStream_t          stream)\n{\n    dim3 grid(min(local_token_num, 65536));\n    dim3 block(min(hidden_units, 1024));\n    if (position_encoding != nullptr) {\n        FT_CHECK_WITH_INFO(prompt_param.use_request_p_prompt_embedding == false\n                               && prompt_param.p_prompt_tuning_batch_weights == nullptr,\n                           fmtstr(\"embeddingLookupPosEncoding still not support prompt tuning\"));\n        embeddingLookupPosEncoding<T><<<grid, block, 0, stream>>>(from_tensor,\n                                                                  embedding_table,\n                                                                  position_encoding,\n                                                                  all_ids,\n                                                                  padding_count,\n                                                                  input_lengths,\n                                                                  local_token_num,\n                                                                  hidden_units,\n                                                                  step,\n                                                                  max_input_length,\n                                                                  token_num,\n                                                                  ite,\n                                                                  scale);\n    }\n    else {\n        if (prompt_param.use_request_p_prompt_embedding) {\n            EMBEDDING_LOOKUP(2);\n        }\n        else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) {\n            EMBEDDING_LOOKUP(1);\n        }\n        else {\n            EMBEDDING_LOOKUP(0);\n        }\n    }\n}\n\n#undef 
EMBEDDING_LOOKUP\n\ntemplate<typename T>\nvoid invokeEmbeddingLookupPosEncodingPadCount(T*                    from_tensor,\n                                              const T*              embedding_table,\n                                              const T*              position_encoding,\n                                              const int*            all_ids,\n                                              const int*            pad_count,\n                                              pPromptTuningParam<T> prompt_param,\n                                              const int             local_token_num,\n                                              const int             hidden_units,\n                                              const T               scale,\n                                              const int             step,\n                                              const int             token_num,\n                                              const int             ite,\n                                              const int             seq_len,\n                                              cudaStream_t          stream)\n{\n    invokeEmbeddingLookupPosEncoding<T>(from_tensor,\n                                        embedding_table,\n                                        position_encoding,\n                                        all_ids,\n                                        pad_count,\n                                        nullptr,\n                                        prompt_param,\n                                        local_token_num,\n                                        hidden_units,\n                                        scale,\n                                        step,\n                                        0,\n                                        token_num,\n                                        ite,\n                                        seq_len,\n                                        
stream);\n}\n\n#define INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(T)                                                                   \\\n    template void invokeEmbeddingLookupPosEncodingPadCount(T*                    from_tensor,                          \\\n                                                           const T*              embedding_table,                      \\\n                                                           const T*              position_encoding,                    \\\n                                                           const int*            all_ids,                              \\\n                                                           const int*            pad_count,                            \\\n                                                           pPromptTuningParam<T> prompt_param,                         \\\n                                                           const int             local_token_num,                      \\\n                                                           const int             hidden_units,                         \\\n                                                           const T               scale,                                \\\n                                                           const int             step,                                 \\\n                                                           const int             token_num,                            \\\n                                                           const int             ite,                                  \\\n                                                           const int             seq_len,                              \\\n                                                           cudaStream_t          stream)\n#ifdef ENABLE_FP32\nINSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(float);\n#endif\nINSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(half);\n#ifdef 
ENABLE_BF16\nINSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(__nv_bfloat16);\n#endif\n#undef INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT\n\ntemplate<typename T>\n__global__ void paddingEmbedding(T*            padded_embedding_kernel,\n                                 T*            padded_embedding_bias,\n                                 const T*      embedding_kernel,\n                                 const T*      embedding_bias,\n                                 const int64_t hidden_unit,\n                                 const int64_t vocab_size,\n                                 const int64_t vocab_size_padded)\n{\n    for (int64_t id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;\n         id += blockDim.x * gridDim.x) {\n        int row_id = id / vocab_size_padded;\n        int col_id = id % vocab_size_padded;\n        if (col_id < vocab_size) {\n            padded_embedding_kernel[id] = embedding_kernel[row_id * vocab_size + col_id];\n        }\n        else {\n            padded_embedding_kernel[id] = (T)(0.0f);\n        }\n    }\n\n    for (int id = threadIdx.x + blockIdx.x * blockDim.x; id < vocab_size_padded; id += blockDim.x * gridDim.x) {\n        if (id < vocab_size) {\n            padded_embedding_bias[id] = embedding_bias[id];\n        }\n        else {\n            padded_embedding_bias[id] = (T)(0.0f);\n        }\n    }\n}\n\ntemplate<typename T>\nvoid invokePaddingEmbedding(T*           padded_embedding_kernel,\n                            T*           padded_embedding_bias,\n                            const T*     embedding_kernel,\n                            const T*     embedding_bias,\n                            const int    hidden_unit,\n                            const int    vocab_size,\n                            const int    vocab_size_padded,\n                            cudaStream_t stream)\n{\n    dim3 block(512);\n    dim3 grid((int)(ceil(hidden_unit * vocab_size_padded / 512.)));\n    
paddingEmbedding<<<grid, block, 0, stream>>>(padded_embedding_kernel,\n                                                 padded_embedding_bias,\n                                                 embedding_kernel,\n                                                 embedding_bias,\n                                                 hidden_unit,\n                                                 vocab_size,\n                                                 vocab_size_padded);\n}\n\n// template void invokePaddingEmbedding(float*       padded_embedding_kernel,\n//                                      float*       padded_embedding_bias,\n//                                      const float* embedding_kernel,\n//                                      const float* embedding_bias,\n//                                      const int    hidden_unit,\n//                                      const int    vocab_size,\n//                                      const int    vocab_size_padded,\n//                                      cudaStream_t stream);\n\n// template void invokePaddingEmbedding(half*        padded_embedding_kernel,\n//                                      half*        padded_embedding_bias,\n//                                      const half*  embedding_kernel,\n//                                      const half*  embedding_bias,\n//                                      const int    hidden_unit,\n//                                      const int    vocab_size,\n//                                      const int    vocab_size_padded,\n//                                      cudaStream_t stream);\n// #ifdef ENABLE_BF16\n// template void invokePaddingEmbedding(__nv_bfloat16*       padded_embedding_kernel,\n//                                      __nv_bfloat16*       padded_embedding_bias,\n//                                      const __nv_bfloat16* embedding_kernel,\n//                                      const __nv_bfloat16* embedding_bias,\n//                               
       const int            hidden_unit,\n//                                      const int            vocab_size,\n//                                      const int            vocab_size_padded,\n//                                      cudaStream_t         stream);\n// #endif\n\ntemplate<typename T>\n__global__ void paddingEmbeddingKernel(T*        padded_embedding_kernel,\n                                       const T*  embedding_kernel,\n                                       const int hidden_unit,\n                                       const int vocab_size,\n                                       const int vocab_size_padded)\n{\n    for (int id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;\n         id += blockDim.x * gridDim.x) {\n        int row_id = id / hidden_unit;\n        int col_id = id % hidden_unit;\n        if (row_id < vocab_size) {\n            padded_embedding_kernel[id] = embedding_kernel[row_id * hidden_unit + col_id];\n        }\n        else {\n            padded_embedding_kernel[id] = (T)(0.0f);\n        }\n    }\n}\n\ntemplate<typename T>\nvoid invokePaddingEmbeddingKernel(T*           padded_embedding_kernel,\n                                  const T*     embedding_kernel,\n                                  const int    hidden_unit,\n                                  const int    vocab_size,\n                                  const int    vocab_size_padded,\n                                  cudaStream_t stream)\n{\n    dim3 block(512);\n    dim3 grid((int)(ceil(hidden_unit * vocab_size_padded / 512.)));\n    paddingEmbeddingKernel<<<grid, block, 0, stream>>>(\n        padded_embedding_kernel, embedding_kernel, hidden_unit, vocab_size, vocab_size_padded);\n}\n\n// template void invokePaddingEmbeddingKernel(float*       padded_embedding_kernel,\n//                                            const float* embedding_kernel,\n//                                            const int    hidden_unit,\n//           
                                 const int    vocab_size,\n//                                            const int    vocab_size_padded,\n//                                            cudaStream_t stream);\n\n// template void invokePaddingEmbeddingKernel(half*        padded_embedding_kernel,\n//                                            const half*  embedding_kernel,\n//                                            const int    hidden_unit,\n//                                            const int    vocab_size,\n//                                            const int    vocab_size_padded,\n//                                            cudaStream_t stream);\n\n// #ifdef ENABLE_BF16\n// template void invokePaddingEmbeddingKernel(__nv_bfloat16*       padded_embedding_kernel,\n//                                            const __nv_bfloat16* embedding_kernel,\n//                                            const int            hidden_unit,\n//                                            const int            vocab_size,\n//                                            const int            vocab_size_padded,\n//                                            cudaStream_t         stream);\n// #endif\n\ntemplate<typename T>\n__global__ void plusScalar(T* buf, const T val, const int size)\n{\n    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += blockDim.x * gridDim.x) {\n        buf[i] += val;\n    }\n}\n\ntemplate<typename T>\nvoid invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream)\n{\n    dim3 block(min(256, size));\n    dim3 grid(ceil(size / 256.));\n    plusScalar<<<grid, block, 0, stream>>>(buf, val, size);\n}\n\ntemplate void invokePlusScalar(int* buf, const int val, const int size, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/decoding_kernels.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include \"gpt_kernels.h\"\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\nnamespace turbomind {\n\n// get token from all_ids at step, then lookup from the embedding table\n// by the token\ntemplate<typename T>\nvoid invokeEmbeddingLookupPosEncodingPadCount(T*                    from_tensor,\n                                              const T*              embedding_table,\n                                              const T*              position_encoding,\n                                              const int*            all_ids,\n                                              const int*            padding_count,\n                                              pPromptTuningParam<T> prompt_param,\n                                              const int             local_token_num,\n                                              const int             hidden_units,\n                                              const T               scale,\n                                              const int             step,\n                                              const int             token_num,\n                                              const int             ite,\n                                              const int             seq_len,\n                                         
     cudaStream_t          stream);\n\ntemplate<typename T>\nvoid invokeEmbeddingLookupPosEncodingPadCount(T*           from_tensor,\n                                              const T*     embedding_table,\n                                              const T*     position_encoding,\n                                              const int*   all_ids,\n                                              const int*   padding_count,\n                                              const int    local_token_num,\n                                              const int    hidden_units,\n                                              const T      scale,\n                                              const int    step,\n                                              const int    token_num,\n                                              const int    ite,\n                                              cudaStream_t stream)\n{\n    invokeEmbeddingLookupPosEncodingPadCount(from_tensor,\n                                             embedding_table,\n                                             position_encoding,\n                                             all_ids,\n                                             padding_count,\n                                             {(const T**)nullptr, 0, 0, false, nullptr},\n                                             local_token_num,\n                                             hidden_units,\n                                             scale,\n                                             step,\n                                             token_num,\n                                             ite,\n                                             0,\n                                             stream);\n}\n\ntemplate<typename T>\nvoid invokePaddingEmbedding(T*           padded_embedding_kernel,\n                            T*           padded_embedding_bias,\n                            const T*     embedding_kernel,\n                       
     const T*     embedding_bias,\n                            const int    hidden_unit,\n                            const int    vocab_size,\n                            const int    vocab_size_padded,\n                            cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokePaddingEmbeddingKernel(T*           padded_embedding_kernel,\n                                  const T*     embedding_kernel,\n                                  const int    hidden_unit,\n                                  const int    vocab_size,\n                                  const int    vocab_size_padded,\n                                  cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nadd_library(gemm2\n        gemm.cu\n        kernel.cu\n        registry.cu\n        dispatch_cache.cu\n        gpu_metric.cu\n        convert_v3.cu\n        cast.cu\n        unpack.cu\n        context.cu\n        tma.cu\n        tuner/cache_utils.cu\n        tuner/measurer.cu\n        tuner/sampler.cu\n        tuner/stopping_criterion.cc\n        tuner/params.cc\n        kernel/sm90_16816_4.cu\n        kernel/sm90_16816_8.cu\n        kernel/sm90_16816_16.cu\n        kernel/sm80_16816_4.cu\n        kernel/sm80_16816_8.cu\n        kernel/sm80_16816_16.cu\n        kernel/sm75_16816_4.cu\n        kernel/sm75_16816_8.cu\n        kernel/sm75_16816_16.cu\n        kernel/sm70_884_4.cu\n        kernel/sm70_884_8.cu\n        kernel/sm70_884_16.cu\n        kernel/sm90_64n32_8.cu\n        cublas.cu\n        moe_utils_v2.cu\n        test/test_utils.cu\n)\n\ntarget_link_libraries(gemm2 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)\n\n\ntarget_compile_definitions(gemm2 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)\n\ntarget_compile_options(gemm2 PRIVATE\n        $<$<COMPILE_LANGUAGE:CUDA>:\n                -Xptxas=-v\n                --generate-line-info\n                --threads 16>\n)\nset_property(TARGET gemm2 PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET gemm2 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\n\n\n\nif (BUILD_TEST)\n        add_executable(test_gemm_v2\n                test/test_gemm_v2.cc\n                ../../models/llama/LlamaLinear.cu\n                ../../models/llama/LlamaDenseWeight.cc\n                test/reference.cu)\n        target_link_libraries(test_gemm_v2 PRIVATE gemm2 core cublas quantization_kernels gpt_kernels)\n\n        add_executable(test_moe_utils test/test_moe_utils.cu test/test_utils.cu)\n        target_link_libraries(test_moe_utils PRIVATE gemm2 core cublas)\n\n        # if (NOT MSVC)\n        #         FetchContent_Declare(\n        #         
repo-nvbench\n        #         GIT_REPOSITORY https://github.com/NVIDIA/nvbench.git\n        #         GIT_TAG        d8dced8a64d9ce305add92fa6d274fd49b569b7e\n        #         )\n\n        #         set(NVBench_ENABLE_EXAMPLES OFF)\n        #         set(NVBench_ENABLE_TESTING OFF)\n        #         set(BUILD_SHARED_LIBS OFF)\n\n        #         FetchContent_MakeAvailable(repo-nvbench)\n\n        #         add_executable(gemm_bench\n        #                 test/gemm_bench.cu\n        #                 # test/test_utils.cu\n        #                 test/quantization.cu\n        #                 test/reference.cu)\n        #         target_link_libraries(gemm_bench PRIVATE gemm2 core nvbench::nvbench cublas)\n        # endif ()\nendif ()\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/config_simt.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/mma_simt.h\"\n#include \"src/turbomind/kernels/gemm/arch/operand_simt.h\"\n#include \"src/turbomind/kernels/gemm/cta_map.h\"\n#include \"src/turbomind/kernels/gemm/gemm_universal.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm70.h\"\n#include \"src/turbomind/kernels/gemm/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/gemm/thread_group_map.h\"\n#include \"src/turbomind/kernels/gemm/tiled_mma.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nnamespace simt {\n\ntemplate<class A,\n         class TransformA,\n         class U,\n         class B,\n         class TransformB,\n         class V,\n         Order order_C,\n         class Tc,\n         Striding mode_A,\n         Striding mode_B,\n         Striding mode_C,\n         class CtaMap_>\nstruct Sm75_Simt {\n\n    static_assert(A::SmemCopyAtom::K == B::SmemCopyAtom::K);\n\n    static constexpr int SMEM_M = A::SmemCopyAtom::M / A::SmemCopyAtom::kFragNum;\n    static constexpr int SMEM_N = B::SmemCopyAtom::M / B::SmemCopyAtom::kFragNum;\n    static constexpr int SMEM_K = A::SmemCopyAtom::K;\n\n    template<int CTA_M,\n             int CTA_N,\n             int CTA_K,\n             int TG_M,\n             int TG_N,\n             int TG_K,\n             class PolicyA,\n             class PolicyB,\n             int  Stages,\n             bool SplitK,\n             int  GroupSizeU = 1,\n             int  GroupSizeV = 1,\n             int  TILE_C_M_  = -1,\n             int  TILE_C_N_  = -1>\n    struct Type {\n\n        // (TM, TN, TK) = R(MMA_Atom, SmemCopy_Atom)\n        using MMA_Atom = MMA_SIMT<half>;\n\n        static constexpr int TM = MMA_Atom::M;\n        static constexpr int TN = MMA_Atom::N;\n        static constexpr int TK = MMA_Atom::K;\n\n        using Partition = Blocked<TG_M, TG_N, 
kColMajor>;\n\n        using MMA_Map = MMA_Map<CTA_M, CTA_N, CTA_K, SMEM_M, SMEM_N, SMEM_K, Partition, TG_K>;\n        using MMA     = Tiled_MMA_v2<MMA_Atom, MMA_Map>;\n\n        // using MMA_Map = RakedThreadGroupMap<CTA_M, CTA_N, CTA_K, TM, TN, TK, WARP_CNT_M, WARP_CNT_N, WARP_CNT_K>;\n\n        using Mainloop = MainloopSm70<MMA,\n                                      A,\n                                      IteratorSm70<mode_A, PolicyA>,\n                                      TransformA,\n                                      U,\n                                      GroupSizeU,\n                                      B,\n                                      IteratorSm70<mode_B, PolicyB>,\n                                      TransformB,\n                                      V,\n                                      GroupSizeV,\n                                      Stages,\n                                      true>;\n\n        static constexpr int TILE_C_M = TILE_C_M_ == -1 ? CTA_M : TILE_C_M_;\n        static constexpr int TILE_C_N = TILE_C_N_ == -1 ? CTA_N : TILE_C_N_;\n\n        using Epilogue = gemm::Epilogue_<Tc,\n                                         CTA_M,\n                                         CTA_N,\n                                         TILE_C_M,\n                                         TILE_C_N,\n                                         MMA::kThreadCount,\n                                         Rearrange<MMA>,\n                                         Operand_C<float, order_C>,\n                                         mode_C,\n                                         SplitK>;\n\n        using Kernel = GemmUniversal<Sm75, Mainloop, Epilogue, CtaMap_>;\n    };\n};\n\n}  // namespace simt\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/config_sm70_s884.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <numeric>\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/mma_sm70.h\"\n#include \"src/turbomind/kernels/gemm/arch/operand_sm70_s884.h\"\n#include \"src/turbomind/kernels/gemm/epilogue.h\"\n#include \"src/turbomind/kernels/gemm/gemm_universal.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm70.h\"\n#include \"src/turbomind/kernels/gemm/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/gemm/scheduler_sm70.cuh\"\n#include \"src/turbomind/kernels/gemm/thread_group_map.h\"\n#include \"src/turbomind/kernels/gemm/tiled_mma.h\"\n#include \"src/turbomind/kernels/gemm/transform.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm::sm70_s884 {\n\ntemplate<class A,\n         class TransformA,\n         class U,\n         class B,\n         class TransformB,\n         class V,\n         Order order_C,\n         class Tc,\n         Order raster_order,\n         int   group_axis>\nstruct Sm70_s884 {\n\n    static_assert(A::SmemCopyAtom::K == B::SmemCopyAtom::K);\n\n    static constexpr int SMEM_M = A::SmemCopyAtom::M / A::SmemCopyAtom::kFragNum;\n    static constexpr int SMEM_N = B::SmemCopyAtom::M / B::SmemCopyAtom::kFragNum;\n    static constexpr int SMEM_K = A::SmemCopyAtom::K;\n\n    static constexpr auto MODE_ = group_axis >= 0 ? Striding::kBlocked : Striding::kFlat;\n\n    static constexpr auto MODE_A = group_axis == 0 ? Striding::kIndexed : MODE_;\n    static constexpr auto MODE_B = group_axis == 1 ? 
Striding::kIndexed : MODE_;\n    static constexpr auto MODE_C = MODE_;\n\n    template<int CTA_M,\n             int CTA_N,\n             int CTA_K,\n             int TG_M,\n             int TG_N,\n             int TG_K,\n             class PolicyA,\n             class PolicyB,\n             int  Stages,\n             bool SplitK,\n             int  GroupSizeU = 1,\n             int  GroupSizeV = 1,\n             int  TILE_C_M_  = -1,\n             int  TILE_C_N_  = -1>\n    struct Type {\n\n        // (TM, TN, TK) = R(MMA_Atom, SmemCopy_Atom)\n        using MMA_Atom = SM70_MMA_884;\n\n        using Partition = Blocked<TG_M, TG_N, kColMajor>;\n        using MMA_Map   = MMA_Map<CTA_M, CTA_N, CTA_K, SMEM_M, SMEM_N, SMEM_K, Partition, TG_K>;\n\n        using MMA = Tiled_MMA_v2<MMA_Atom, MMA_Map>;\n\n        using Mainloop = MainloopSm70<MMA,\n                                      A,\n                                      IteratorSm70<MODE_A, PolicyA>,\n                                      TransformA,\n                                      U,\n                                      GroupSizeU,\n                                      B,\n                                      IteratorSm70<MODE_B, PolicyB>,\n                                      TransformB,\n                                      V,\n                                      GroupSizeV,\n                                      Stages,\n                                      true>;  // FusePrefetch_\n\n        static constexpr int CHUNK_K = std::lcm(std::lcm(GroupSizeU, GroupSizeV), CTA_K);\n\n        using Scheduler = SchedulerSm70<raster_order, CTA_M, CTA_N, CTA_K, CHUNK_K, SplitK, group_axis>;\n\n        static constexpr int TILE_C_M = TILE_C_M_ == -1 ? CTA_M : TILE_C_M_;\n        static constexpr int TILE_C_N = TILE_C_N_ == -1 ? 
CTA_N : TILE_C_N_;\n\n        using Epilogue = gemm::Epilogue_<Tc,\n                                         CTA_M,\n                                         CTA_N,\n                                         TILE_C_M,\n                                         TILE_C_N,\n                                         MMA::kThreadCount,\n                                         Rearrange<MMA>,\n                                         Operand_C<float, order_C>,\n                                         MODE_C,\n                                         SplitK>;\n\n        using Kernel = GemmUniversal<Sm70, Mainloop, Epilogue, Scheduler>;\n    };\n};\n\ntemplate<Order raster_order>\nusing Config_U4_d = Sm70_s884<typename GetOperand<HMMA_884, OPERAND_A, half, kRowMajor, false>::Operand,\n                              Transform_Default,\n                              VoidOperand,\n                              typename GetOperand<HMMA_884, OPERAND_B, uint4_t, kRowMajor, true>::Operand,\n                              Transform_HMMA_SIMT_B,\n                              typename GetOperand<HMMA_884, OPERAND_V, uint32_t, kColMajor, true>::Operand,\n                              kRowMajor,\n                              half,\n                              raster_order,\n                              -1>;\n\ntemplate<Order raster_order>\nusing Config_U4_g = Sm70_s884<Operand_A<half>,           // A\n                              Transform_Default,         // tarnsform A\n                              VoidOperand,               // U\n                              Operand_B_Pack<uint4_t>,   // B\n                              Transform_HMMA_SIMT_B,     // transform B,\n                              Operand_V_Pack<uint32_t>,  // V\n                              kRowMajor,                 // order_C\n                              half,                      // Tc\n                              raster_order,\n                              0>;\n\ntemplate<Order raster_order, int 
group_axis = -1>\nusing Config_MXF4 = Sm70_s884<Operand_A<half>,             // A\n                              Transform_Default,           // tarnsform A\n                              VoidOperand,                 // U\n                              Operand_B_Pack<fp4_e2m1_t>,  // B\n                              Transform_HMMA_SIMT_B,       // transform B,\n                              Operand_V_Pack<uint8_t>,     // V\n                              kRowMajor,                   // order_C\n                              half,                        // Tc\n                              raster_order,\n                              group_axis>;\n\ntemplate<Order raster_order, int group_axis = -1>\nusing Config_E4M3 = Sm70_s884<Operand_A<half>,             // A\n                              Transform_Default,           // tarnsform A\n                              VoidOperand,                 // U\n                              Operand_B_Pack<fp8_e4m3_t>,  // B\n                              Transform_HMMA_SIMT_B,       // transform B,\n                              Operand_V_Pack<uint16_t>,    // V\n                              kRowMajor,                   // order_C\n                              half,                        // Tc\n                              raster_order,\n                              group_axis>;\n\ntemplate<Order raster_order, int group_axis = -1>\nusing Config_F16 = Sm70_s884<Operand_A<half>,       // A\n                             Transform_Default,     // tarnsform A\n                             VoidOperand,           // U\n                             Operand_B_Pack<half>,  // B\n                             Transform_Default,     // transform B\n                             VoidOperand,           // V\n                             kRowMajor,             // order_C\n                             half,                  // Tc\n                             raster_order,\n                             group_axis>;\n\n}  // namespace 
turbomind::gemm::sm70_s884\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/config_sm75_s16816.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <numeric>\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/mma_sm80.h\"\n#include \"src/turbomind/kernels/gemm/arch/operand_sm80_s16816.h\"\n#include \"src/turbomind/kernels/gemm/epilogue.h\"\n#include \"src/turbomind/kernels/gemm/gemm_universal.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm70.h\"\n#include \"src/turbomind/kernels/gemm/mainloop_sm70.h\"\n#include \"src/turbomind/kernels/gemm/scheduler_sm70.cuh\"\n#include \"src/turbomind/kernels/gemm/thread_group_map.h\"\n#include \"src/turbomind/kernels/gemm/tiled_mma.h\"\n#include \"src/turbomind/kernels/gemm/transform.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nnamespace sm75_s16816 {\n\nusing namespace sm80_s16816;\n\ntemplate<Order mma_iter_order,\n         class A,\n         class TransformA,\n         class U,\n         class B,\n         class TransformB,\n         class V,\n         Order order_C,\n         class Tc,\n         Order raster_order,\n         int   group_axis>\nstruct Sm75_s16816 {\n\n    static_assert(A::SmemCopyAtom::K == B::SmemCopyAtom::K);\n\n    static constexpr int SMEM_M = A::SmemCopyAtom::M / A::SmemCopyAtom::kFragNum;\n    static constexpr int SMEM_N = B::SmemCopyAtom::M / B::SmemCopyAtom::kFragNum;\n    static constexpr int SMEM_K = A::SmemCopyAtom::K;\n\n    static constexpr auto MODE_ = group_axis >= 0 ? Striding::kBlocked : Striding::kFlat;\n\n    static constexpr auto MODE_A = group_axis == 0 ? Striding::kIndexed : MODE_;\n    static constexpr auto MODE_B = group_axis == 1 ? 
Striding::kIndexed : MODE_;\n    static constexpr auto MODE_C = MODE_;\n\n    template<int CTA_M,\n             int CTA_N,\n             int CTA_K,\n             int TG_M,\n             int TG_N,\n             int TG_K,\n             class PolicyA,\n             class PolicyB,\n             int  Stages,\n             bool SplitK,\n             int  GroupSizeU = 1,\n             int  GroupSizeV = 1,\n             int  TILE_C_M_  = -1,\n             int  TILE_C_N_  = -1>\n    struct Type {\n        // Raked partition dont support `Pack_M > 1`\n        using Partition = Blocked<TG_M, TG_N, kColMajor>;\n        using MMA_Map   = MMA_Map<CTA_M, CTA_N, CTA_K, SMEM_M, SMEM_N, SMEM_K, Partition, TG_K>;\n        using MMA       = Tiled_MMA_v2<SM80_MMA_16x8x16_F32_F16_F16_F32_TN<half>, MMA_Map, mma_iter_order>;\n\n        using Mainloop = MainloopSm70<MMA,\n                                      A,\n                                      IteratorSm70<MODE_A, PolicyA>,\n                                      TransformA,\n                                      U,\n                                      GroupSizeU,\n                                      B,\n                                      IteratorSm70<MODE_B, PolicyB>,\n                                      TransformB,\n                                      V,\n                                      GroupSizeV,\n                                      Stages,\n                                      true>;  // FusePrefetch_\n\n        static constexpr int CHUNK_K = std::lcm(std::lcm(GroupSizeU, GroupSizeV), CTA_K);\n\n        using Scheduler = SchedulerSm70<raster_order, CTA_M, CTA_N, CTA_K, CHUNK_K, SplitK, group_axis>;\n\n        static constexpr int TILE_C_M = TILE_C_M_ == -1 ? CTA_M : TILE_C_M_;\n        static constexpr int TILE_C_N = TILE_C_N_ == -1 ? 
CTA_N : TILE_C_N_;\n\n        using Epilogue = gemm::Epilogue_<Tc,\n                                         CTA_M,\n                                         CTA_N,\n                                         TILE_C_M,\n                                         TILE_C_N,\n                                         MMA::kThreadCount,\n                                         Rearrange<MMA>,\n                                         Operand_C<float, order_C>,\n                                         MODE_C,\n                                         SplitK>;\n\n        using Kernel = GemmUniversal<Sm75, Mainloop, Epilogue, Scheduler>;\n    };\n};\n\n// mma_iter_order has no effect yet\n\ntemplate<Order raster_order>  // kColMajor\nusing Config_U4_d = Sm75_s16816<kColMajor,\n                                Operand_A<half, kRowMajor>,\n                                Transform_Default,\n                                VoidOperand,\n                                Operand_B_Pack<uint4_t, kColMajor, 2>,\n                                Transform_HMMA_16816<1, 0>,\n                                Operand_UV_Pack<uint32_t, true>,\n                                kRowMajor,\n                                half,\n                                raster_order,\n                                -1>;\n\ntemplate<Order raster_order>  // kColMajor\nusing Config_U4_g = Sm75_s16816<kColMajor,\n                                Operand_A<half, kRowMajor>,             // A\n                                Transform_Default,                      // tarnsform A\n                                VoidOperand,                            // U\n                                Operand_B_Pack<uint4_t, kRowMajor, 2>,  // B\n                                Transform_HMMA_16816<1, 0>,             // transform B,\n                                Operand_UV_Pack<uint32_t, true>,        // V\n                                kRowMajor,                              // order_C\n                                
half,                                   // Tc\n                                raster_order,\n                                0>;\n\ntemplate<Order raster_order, int group_axis = -1>\nusing Config_MXF4 = Sm75_s16816<kColMajor,\n                                Operand_A_Pack<fp4_e2m1_t, kColMajor, 1>,  // A\n                                Transform_HMMA_16816<0, 1>,                // tarnsform A\n                                Operand_UV_Pack<uint8_t, false>,           // U\n                                Operand_B<half_t, kRowMajor>,              // B\n                                Transform_Default,                         // transform B\n                                VoidOperand,                               // V\n                                kColMajor,                                 // order_C\n                                half_t,                                    // Tc\n                                raster_order,\n                                group_axis>;\n\ntemplate<Order raster_order, int group_axis = -1>\nusing Config_E4M3 = Sm75_s16816<kColMajor,\n                                Operand_A_Pack<fp8_e4m3_t, kColMajor, 1>,  // A\n                                Transform_HMMA_16816<0, 1>,                // tarnsform A\n                                Operand_UV_Pack<uint16_t, false>,          // U\n                                Operand_B<half_t, kRowMajor>,              // B\n                                Transform_Default,                         // transform B\n                                VoidOperand,                               // V\n                                kColMajor,                                 // order_C\n                                half_t,                                    // Tc\n                                raster_order,\n                                group_axis>;\n\ntemplate<Order raster_order, int group_axis = -1>\nusing Config_F16 = Sm75_s16816<kColMajor,\n                               
Operand_A<half, kRowMajor>,          // A\n                               Transform_Default,                   // transform A\n                               VoidOperand,                         // U\n                               Operand_B_Pack<half, kRowMajor, 1>,  // B\n                               Transform_Default,                   // transform B\n                               VoidOperand,                         // V\n                               kRowMajor,                           // order_C\n                               half,                                // Tc\n                               raster_order,\n                               group_axis>;\n\n}  // namespace sm75_s16816\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/config_sm80_s16816.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <numeric>\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/mma_sm80.h\"\n#include \"src/turbomind/kernels/gemm/arch/operand_sm80_s16816.h\"\n#include \"src/turbomind/kernels/gemm/epilogue.h\"\n#include \"src/turbomind/kernels/gemm/gemm_universal.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm80.h\"\n#include \"src/turbomind/kernels/gemm/mainloop_sm80_v2.h\"\n#include \"src/turbomind/kernels/gemm/scheduler_sm70.cuh\"\n#include \"src/turbomind/kernels/gemm/thread_group_map.h\"\n#include \"src/turbomind/kernels/gemm/tiled_mma.h\"\n#include \"src/turbomind/kernels/gemm/transform.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm::sm80_s16816 {\n\ntemplate<class Arch,\n         class Dtype,\n         Order mma_iter_order,\n         class A,\n         class TransformA,\n         class U,\n         class B,\n         class TransformB,\n         class V,\n         Order order_C,\n         class Tc,\n         Order raster_order,\n         int   group_axis>\nstruct Sm80_s16816 {\n\n    static_assert(A::SmemCopyAtom::K == B::SmemCopyAtom::K);\n\n    static constexpr int SMEM_M = A::SmemCopyAtom::M / A::SmemCopyAtom::kFragNum;\n    static constexpr int SMEM_N = B::SmemCopyAtom::M / B::SmemCopyAtom::kFragNum;\n    static constexpr int SMEM_K = A::SmemCopyAtom::K;\n\n    static constexpr auto MODE_ = group_axis >= 0 ? Striding::kBlocked : Striding::kFlat;\n\n    static constexpr auto MODE_A = group_axis == 0 ? Striding::kIndexed : MODE_;\n    static constexpr auto MODE_B = group_axis == 1 ? 
Striding::kIndexed : MODE_;\n    static constexpr auto MODE_C = MODE_;\n\n    template<int CTA_M,\n             int CTA_N,\n             int CTA_K,\n             int TG_M,\n             int TG_N,\n             int TG_K,\n             class PolicyA,\n             class PolicyB,\n             int  Stages,\n             bool SplitK,\n             int  GroupSizeU   = 1,\n             int  GroupSizeV   = 1,\n             int  TILE_C_M_    = -1,\n             int  TILE_C_N_    = -1,\n             bool FusePrefecth = true>\n\n    struct Type {\n\n        // Raked partition dont support `Pack_M > 1`\n        using Partition = Blocked<TG_M, TG_N, kColMajor>;\n        using MMA_Map   = MMA_Map<CTA_M, CTA_N, CTA_K, SMEM_M, SMEM_N, SMEM_K, Partition, TG_K>;\n        using MMA       = Tiled_MMA_v2<SM80_MMA_16x8x16_F32_F16_F16_F32_TN<Dtype>, MMA_Map, mma_iter_order>;\n\n        using Mainloop = MainloopSm80_v2<MMA,\n                                         A,\n                                         IteratorSm80<MODE_A, PolicyA>,\n                                         TransformA,\n                                         U,\n                                         GroupSizeU,\n                                         B,\n                                         IteratorSm80<MODE_B, PolicyB>,\n                                         TransformB,\n                                         V,\n                                         GroupSizeV,\n                                         Stages,\n                                         FusePrefecth>;\n\n        static constexpr int CHUNK_K = std::lcm(std::lcm(GroupSizeU, GroupSizeV), CTA_K);\n\n        using Scheduler = SchedulerSm70<raster_order, CTA_M, CTA_N, CTA_K, CHUNK_K, SplitK, group_axis>;\n\n        static constexpr int TILE_C_M = TILE_C_M_ == -1 ? CTA_M : TILE_C_M_;\n        static constexpr int TILE_C_N = TILE_C_N_ == -1 ? 
CTA_N : TILE_C_N_;\n\n        using Epilogue = gemm::Epilogue_<Tc,\n                                         CTA_M,\n                                         CTA_N,\n                                         TILE_C_M,\n                                         TILE_C_N,\n                                         MMA::kThreadCount,\n                                         Rearrange<MMA>,\n                                         Operand_C<float, order_C>,\n                                         MODE_C,\n                                         SplitK>;\n\n        using Kernel = GemmUniversal<Arch, Mainloop, Epilogue, Scheduler>;\n    };\n};\n\ntemplate<class Arch, class T, Order raster_order>  // kColMajor\nusing Config_U4_d = Sm80_s16816<Arch,\n                                T,                                      // mma dtype\n                                kColMajor,                              // mma iter order\n                                Operand_A<half, kRowMajor>,             // A\n                                Transform_Default,                      // transform A\n                                VoidOperand,                            // U\n                                Operand_B_Pack<uint4_t, kColMajor, 2>,  // B\n                                Transform_HMMA_16816<1, 0>,             // transform B\n                                Operand_UV_Pack<uint32_t, true>,        // V\n                                kRowMajor,                              // order_C\n                                half,                                   // Tc\n                                raster_order,                           // raster order\n                                -1>;                                    // group axis\n\ntemplate<class Arch, class T, Order raster_order>  // kColMajor\nusing Config_U4_g = Sm80_s16816<Arch,\n                                T,                                      // mma dtype\n                                kColMajor,       
                       // mma iter order\n                                Operand_A<T, kRowMajor>,                // A\n                                Transform_Default,                      // transform A\n                                VoidOperand,                            // U\n                                Operand_B_Pack<uint4_t, kRowMajor, 2>,  // B\n                                Transform_HMMA_16816<1, 0>,             // transform B\n                                Operand_UV_Pack<uint32_t, true>,        // V\n                                kRowMajor,                              // order_C\n                                T,                                      // Tc\n                                raster_order,                           // raster order\n                                0>;                                     // group axis\n\ntemplate<class Arch, class T, int N, Order raster_order, int group_axis = -1>\nusing Config_MXF4 = Sm80_s16816<Arch,\n                                T,                                         // mma dtype\n                                kRowMajor,                                 // mma iter order\n                                Operand_A_Pack<fp4_e2m1_t, kColMajor, 1>,  // A\n                                Transform_HMMA_16816<0, 1>,                // transform A\n                                Operand_UV_Pack<uint8_t, false>,           // U\n                                Operand_B<T, kRowMajor, N>,                // B\n                                Transform_Default,                         // transform B\n                                VoidOperand,                               // V\n                                kColMajor,                                 // order_C\n                                T,                                         // Tc\n                                raster_order,                              // raster order\n                                group_axis>;                
               // group axis\n\ntemplate<class Arch, class T, int N, Order raster_order, int group_axis = -1>\nusing Config_E4M3 = Sm80_s16816<Arch,\n                                T,                                         // mma dtype\n                                kRowMajor,                                 // mma iter order\n                                Operand_A_Pack<fp8_e4m3_t, kColMajor, 1>,  // A\n                                Transform_HMMA_16816<0, 1>,                // transform A\n                                Operand_UV_Pack<uint16_t, false>,          // U\n                                Operand_B<T, kRowMajor, N>,                // B\n                                Transform_Default,                         // transform B\n                                VoidOperand,                               // V\n                                kColMajor,                                 // order_C\n                                T,                                         // Tc\n                                raster_order,                              // raster order\n                                group_axis>;                               // group axis\n\ntemplate<class Arch, class T, Order raster_order>\nusing Config_F16_g = Sm80_s16816<Arch,\n                                 T,                                // mma dtype\n                                 kColMajor,                        // mma iter order\n                                 Operand_A<T, kRowMajor>,          // A\n                                 Transform_Default,                // transform A\n                                 VoidOperand,                      // U\n                                 Operand_B_Pack<T, kRowMajor, 1>,  // B\n                                 Transform_Default,                // transform B\n                                 VoidOperand,                      // V\n                                 kRowMajor,                        // order_C\n             
                    T,                                // Tc\n                                 raster_order,                     // raster order\n                                 0>;                               // group axis\n\n}  // namespace turbomind::gemm::sm80_s16816\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/mma_simt.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/simt.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class T>\nstruct MMA_SIMT {\n    static constexpr int M = simt::OP_M;\n    static constexpr int N = simt::OP_N;\n    static constexpr int K = simt::OP_K;\n\n    static constexpr int kThreadCount = 32;\n\n    static constexpr auto kOpClass = OpClass::kSIMT;\n\n    using FragA = Array<T, K>;\n    using FragB = Array<T, K>;\n    using FragC = Array<float, 1>;\n\n    using OffsetC = Array<int2, 1>;\n    using FragC_  = FragC[1];\n\n    __device__ static void fma(FragC& d, const FragA& a, const FragB& b, const FragC& c)\n    {\n        PRAGMA_UNROLL\n        for (int k = 0; k < K; ++k) {\n            d[0] = c[0] + float(a[k]) * float(b[k]);\n        }\n\n        // PRAGMA_UNROLL\n        // for (int k = 0; k < K; ++k) {\n        //     d[0] = c[0] + float(a[k] * b[k]);\n        // }\n\n        // T acc{};\n        // PRAGMA_UNROLL\n        // for (int k = 0; k < K; ++k) {\n        //     acc += a[k] * b[k];\n        // }\n        // d[0] = c[0] + float(acc);\n    }\n\n    __device__ static constexpr OffsetC static_offset_C()\n    {\n        return {};\n    }\n\n    __device__ static int2 thread_offset_C()  // -> (m,n)\n    {\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        return {lane_id / N, lane_id % N};\n    }\n\n    __device__ static void ReshapeC(const FragC& c, FragC_& c_)\n    {\n        c_[0] = c;\n    }\n\n    __device__ static int get_group_id(int thread_idx)\n    {\n        return thread_idx / WARP_SIZE;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/mma_sm70.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/mma.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n\nnamespace turbomind::gemm {\n\nstruct SM70_MMA_884 {\n    // static constexpr int M = 16;\n    // static constexpr int N = 16;\n    static constexpr int M = 8;\n    static constexpr int N = 32;\n    static constexpr int K = 8;\n\n    static constexpr int kThreadCount = 32;\n\n    static constexpr auto kOpClass = OpClass::kMMA_s884;\n\n    using FragA = Array<half, K>;\n    using FragB = Array<half, K>;\n    using FragC = Array<float, 8>;\n\n    using OffsetC = Array<int2, 4>;\n    using FragC_  = Array<float, 2>[4];\n\n    __device__ static void fma(FragC& d, const FragA& a, const FragB& b, const FragC& c)\n    {\n        mma_m8n8k4_row_col(d, (const Array<half, 4>&)a[0], (const Array<half, 4>&)b[0], (FragC&)c);\n        if constexpr (K == 8) {\n            mma_m8n8k4_row_col(d, (const Array<half, 4>&)a[4], (const Array<half, 4>&)b[4], (FragC&)d);\n        }\n    }\n\n    __device__ static constexpr OffsetC static_offset_C()\n    {\n        OffsetC r{};\n        PRAGMA_UNROLL\n        for (int n = 0; n < 2; ++n) {\n            PRAGMA_UNROLL\n            for (int m = 0; m < 2; ++m) {\n                r[n * 2 + m] = int2{m * 2, n * 4};\n            }\n        }\n        return r;\n    }\n\n    __device__ static int2 thread_offset_C()  // -> (m,n)\n    {\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        // return {\n        //     (lane_id & 8) * 1 + (lane_id & 1) + lane_id / 16 * 4,\n        //     (lane_id & 4) * 2 + (lane_id & 2),\n        // };\n        return {(lane_id & 1) + (lane_id / 16) * 4,  //\n                (lane_id & 2) + (lane_id & 12) * 2};\n    }\n\n    __device__ static void ReshapeC(const FragC& c, FragC_& c_)\n    {\n        PRAGMA_UNROLL\n        for (int m = 
0; m < 4; ++m) {\n            c_[m] = (Array<float, 2>&)c[m * 2];\n        }\n    }\n\n    __device__ static int get_group_id(int thread_idx)\n    {\n        return thread_idx / WARP_SIZE;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/mma_sm80.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/mma.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class T>\nstruct SM80_MMA_16x8x16_F32_F16_F16_F32_TN {\n    static constexpr int M = 16;\n    static constexpr int N = 8;\n    static constexpr int K = 16;\n\n    static constexpr int kThreadCount = 32;\n\n    static constexpr auto kOpClass = OpClass::kMMA_s16816;\n\n    using FragA = Array<T, 8>;\n    using FragB = Array<T, 4>;\n    using FragC = Array<float, 4>;\n\n    using OffsetC = Array<int2, 2>;  // (m, n)\n    using FragC_  = Array<float, 2>[2];\n\n    __device__ static void fma(FragC& d, const FragA& a, const FragB& b, const FragC& c)\n    {\n        mma_m16n8k16_row_col(d, a, b, (FragC&)c);\n    }\n\n    __device__ static constexpr OffsetC static_offset_C()\n    {\n        return {int2{0, 0}, int2{8, 0}};\n    }\n\n    __device__ static int2 thread_offset_C()  // -> (m,n)\n    {\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        return {lane_id / 4, lane_id % 4 * 2};\n    }\n\n    __device__ static void ReshapeC(const FragC& c, FragC_& c_)\n    {\n        PRAGMA_UNROLL\n        for (int m = 0; m < 2; ++m) {\n            c_[m] = (Array<float, 2>&)c[m * 2];\n        }\n    }\n\n    __device__ static int get_group_id(int thread_idx)\n    {\n        return thread_idx / WARP_SIZE;\n    }\n};\n\n// This is not used yet\ntemplate<class T>\nstruct SM75_MMA_16x8x8_F32_F16_F16_F32_TN: SM80_MMA_16x8x16_F32_F16_F16_F32_TN<T> {\n    static constexpr int M = 16;\n    static constexpr int N = 8;\n    static constexpr int K = 8;\n\n    using FragA = Array<T, 4>;\n    using FragB = Array<T, 2>;\n    using FragC = Array<float, 4>;\n\n    __device__ static void fma(FragC& d, const FragA& a, const FragB& b, const FragC& c)\n    {\n        
mma_m16n8k8_row_col(d, a, b, (FragC&)c);\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/operand_simt.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/arch/smem_copy_simt.h\"\n#include \"src/turbomind/kernels/gemm/iterator.h\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/simt.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nnamespace simt {\n\nstruct GetSmemLayout {\n    template<int M, int K>\n    static constexpr auto apply(pair<M, K>)\n    {\n        return SmemLayoutV2<M, K>{};\n    }\n};\n\ntemplate<class T, int K>\nstruct Operand_A {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = kRowMajor;\n\n    using SmemCopyAtom = SmemCopy_MMA_SIMT_A<T, K>;\n\n    using GetSmemLayout = GetSmemLayout;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<class T, int K>\nstruct Operand_B {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = kRowMajor;\n\n    using SmemCopyAtom = SmemCopy_MMA_SIMT_B<T, K>;\n\n    using GetSmemLayout = GetSmemLayout;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<Order order>\nstruct _GetSmemLayoutC {\n    template<int M, int N>\n    static constexpr auto apply(pair<M, N>)\n    {\n        constexpr auto cs = mk2cs<order>(M, N);\n        return SmemLayoutV2<cs.y, cs.x, 1, 1>{};\n    }\n};\n\ntemplate<Order order>\nstruct _GetThreadMapC {\n    template<int M, int N, int THREADS>\n    static constexpr auto apply(pair<M, N>, constant<THREADS>)\n    {\n        constexpr auto cs    = mk2cs<order>(M, N);\n        constexpr int  WARPS = THREADS / WARP_SIZE;\n\n        return ThreadMap_V2<cs.x, cs.y, 4, Raked, WARPS>{};\n    }\n};\n\ntemplate<class T, Order order>\nstruct Operand_C {\n    using Dtype = T;\n\n    static constexpr Order 
kOrder = order;\n\n    using GetSmemLayout = _GetSmemLayoutC<order>;\n    using GetThreadMap  = _GetThreadMapC<order>;\n};\n\ntemplate<class T>\nstruct Operand_V {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = kColMajor;\n\n    using SmemCopyAtom = SmemCopy_MMA_SIMT_V<T, 1>;\n\n    struct GetSmemLayout {  // m-major\n        template<int M, int K>\n        static constexpr auto apply(pair<M, K>)\n        {\n            return SmemLayoutV2<K, M>{};\n        }\n    };\n\n    using GetGmemIter = GetGmemIter;\n};\n\nstruct GetSmemLayout_Pack {\n    template<int M, int K>\n    static constexpr auto apply(pair<M, K>)\n    {\n        return SmemLayoutV2<M, K>{};\n    }\n};\n\ntemplate<class T, int K>\nstruct Operand_B_Pack {\n    using Dtype = T;\n\n    static constexpr int Pack_M = 1;\n\n    static constexpr Pack  kPack  = HMMA_SIMT | OPERAND_B | Pack_M;\n    static constexpr Order kOrder = kRowMajor;\n\n    using SmemCopyAtom  = SmemCopyAtom_Pack_v3<T, typename Operand_B<T, K>::SmemCopyAtom, kRowMajor, Pack_M>;\n    using GetSmemLayout = GetSmemLayout_Pack;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<class T>\nstruct Operand_V_Pack {\n    using Dtype = T;\n\n    static constexpr int Pack_M = 1;\n\n    static constexpr Pack  kPack  = HMMA_SIMT | OPERAND_V | Pack_M;\n    static constexpr Order kOrder = kColMajor;\n\n    using SmemCopyAtom = SmemCopyAtom_Pack_v3<T, SmemCopy_MMA_SIMT_V<T, OP_K>, kColMajor, Pack_M>;\n\n    struct GetSmemLayout {  // m-major\n        template<int M, int K>\n        static constexpr auto apply(pair<M, K>)\n        {\n            return SmemLayoutV2<K, M>{};\n        }\n    };\n\n    using GetGmemIter = GetGmemIter;\n};\n\n}  // namespace simt\n\ntemplate<class T>\nstruct GetOperand<HMMA_SIMT, OPERAND_A, T, kRowMajor, false>: std::true_type {\n    using Operand = simt::Operand_A<T, simt::OP_K>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_SIMT, OPERAND_B, T, kRowMajor, 
false>: std::true_type {\n    using Operand = simt::Operand_B<T, simt::OP_K>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_SIMT, OPERAND_V, T, kColMajor, false>: std::true_type {\n    using Operand = simt::Operand_V<T>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_SIMT, OPERAND_B, T, kRowMajor, true>: std::true_type {\n    using Operand = simt::Operand_B_Pack<T, simt::OP_K>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_SIMT, OPERAND_V, T, kColMajor, true>: std::true_type {\n    using Operand = simt::Operand_V_Pack<T>;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/operand_sm70_s884.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/arch/smem_copy_sm70.h\"\n#include \"src/turbomind/kernels/gemm/iterator.h\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nnamespace sm70_s884 {\n\ntemplate<Order order>\nstruct GetSmemLayout {\n    template<int M, int K>\n    static constexpr auto apply(pair<M, K>)\n    {\n        constexpr int2 cs = mk2cs<order>(M, K);\n        return SmemLayoutV2<cs.y, cs.x, 1, 1>{};\n    }\n};\n\ntemplate<class T>\nstruct Operand_A {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = kRowMajor;\n\n    using SmemCopyAtom = SmemCopy_MMA_884_A<T>;\n\n    using GetSmemLayout = GetSmemLayout<kOrder>;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<class T>\nstruct Operand_B {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = kRowMajor;  // (n,k)\n\n    using SmemCopyAtom = SmemCopy_MMA_884_B<T>;\n\n    using GetSmemLayout = GetSmemLayout<kOrder>;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<class T>\nstruct Operand_V {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = kColMajor;  // (n,k)\n\n    using SmemCopyAtom = SmemCopy_MMA_884_V<T, 1>;\n\n    struct GetSmemLayout {  // m-major\n        template<int M, int K>\n        static constexpr auto apply(pair<M, K>)\n        {\n            return SmemLayoutV2<K, M>{};\n        }\n    };\n\n    using GetGmemIter = GetGmemIter;\n};\n\ntemplate<Order order>\nstruct _GetSmemLayoutC {\n    template<int M, int N>\n    static constexpr auto apply(pair<M, N>)\n    {\n        constexpr auto cs = mk2cs<order>(M, N);\n        
return SmemLayoutV2<cs.y, cs.x, 1, 1>{};\n    }\n};\n\ntemplate<Order order>\nstruct _GetThreadMapC {\n    template<int M, int N, int THREADS>\n    static constexpr auto apply(pair<M, N>, constant<THREADS>)\n    {\n        constexpr auto cs    = mk2cs<order>(M, N);\n        constexpr int  WARPS = THREADS / WARP_SIZE;\n\n        return ThreadMap_V2<cs.x, cs.y, 4, Raked, WARPS>{};\n    }\n};\n\ntemplate<class T, Order order>\nstruct Operand_C {\n    using Dtype = T;\n\n    static constexpr Order kOrder = order;\n\n    using GetSmemLayout = _GetSmemLayoutC<order>;\n    using GetThreadMap  = _GetThreadMapC<order>;\n};\n\ntemplate<class T>\nstruct Operand_B_Pack {\n    using Dtype = T;\n\n    static constexpr int Pack_M = 1;\n\n    static constexpr Pack  kPack  = HMMA_884 | OPERAND_B | Pack_M;\n    static constexpr Order kOrder = kRowMajor;\n\n    using SmemCopyAtom = SmemCopyAtom_Pack_v3<T, SmemCopy_MMA_884_B<T>, kOrder, Pack_M>;\n\n    using GetSmemLayout = GetSmemLayout<kOrder>;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<class T>\nstruct Operand_V_Pack {\n    using Dtype = T;\n\n    static constexpr int Pack_M = 1;\n\n    static constexpr Pack  kPack  = HMMA_884 | OPERAND_V | Pack_M;\n    static constexpr Order kOrder = kColMajor;\n\n    using SmemCopyAtom = SmemCopyAtom_Pack_v3<T, SmemCopy_MMA_884_V<T, 8>, kColMajor, Pack_M>;\n\n    struct GetSmemLayout {  // m-major\n        template<int M, int K>\n        static constexpr auto apply(pair<M, K>)\n        {\n            return SmemLayoutV2<K, M>{};\n        }\n    };\n\n    using GetGmemIter = GetGmemIter;\n};\n\n}  // namespace sm70_s884\n\ntemplate<class T>\nstruct GetOperand<HMMA_884, OPERAND_A, T, kRowMajor, false>: std::true_type {\n    using Operand = sm70_s884::Operand_A<T>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_884, OPERAND_B, T, kRowMajor, false>: std::true_type {\n    using Operand = sm70_s884::Operand_B<T>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_884, OPERAND_V, T, 
kColMajor, false>: std::true_type {\n    using Operand = sm70_s884::Operand_V<T>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_884, OPERAND_B, T, kRowMajor, true>: std::true_type {\n    using Operand = sm70_s884::Operand_B_Pack<T>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_884, OPERAND_V, T, kColMajor, true>: std::true_type {\n    using Operand = sm70_s884::Operand_V_Pack<T>;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/operand_sm80_s16816.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/arch/smem_copy_sm80.h\"\n#include \"src/turbomind/kernels/gemm/iterator.h\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include <type_traits>\n\nnamespace turbomind::gemm {\n\nnamespace sm80_s16816 {\n\nnamespace detail {\n\nstruct GetSmemLayout {\n    template<int S, int C>\n    static constexpr auto apply(pair<S, C>)\n    {\n        // constexpr int S0 = S >= 16 ? 16 : 8;\n        constexpr int S0 = 8;\n        constexpr int C0 = C >= 64 ? 64 : (C >= 32 ? 32 : 16);\n        using _Small     = std::conditional_t<C0 == 32, Swizzle<2, 3, 3>, Swizzle<1, 3, 3>>;\n        using Swizzle    = std::conditional_t<C0 == 64, Swizzle<3, 3, 3>, _Small>;\n        return SmemLayoutV2<S, C, S0, C0, Swizzle>{};\n    }\n};\n\n}  // namespace detail\n\ntemplate<Order order>\nstruct GetSmemLayoutV2 {\n    template<int M, int K>\n    static constexpr auto apply(pair<M, K>)\n    {\n        constexpr int2 cs = mk2cs<order>(M, K);\n        return detail::GetSmemLayout::apply(pair<cs.y, cs.x>{});\n    }\n};\n\n// (m, k)\ntemplate<class T, Order order>\nstruct Operand_A {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = order;\n\n    // using SmemCopyAtom =\n    //     std::conditional_t<order == kRowMajor, SmemCopy_MMA_16816_A<T, false>, SmemCopy_MMA_16816_B<T, true>>;\n\n    // using SmemCopyAtom = std::conditional_t<order == kRowMajor,\n    //                                         LDSM_SM75_8x8<T, 16, 16, kColMajor, kRowMajor>,\n    //                                         LDSM_SM75_8x8<T, 16, 16, kRowMajor, kColMajor>>;\n\n    using SmemCopyAtom = 
LDSM_SM75_8x8<T, 16, 16, ~order, order>;\n\n    using GetSmemLayout = GetSmemLayoutV2<kOrder>;\n    using GetGmemIter   = GetGmemIter;\n};\n\n// (n, k)\ntemplate<class T, Order order, int N = 16>\nstruct Operand_B {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = order;\n\n    // using SmemCopyAtom =\n    //     std::conditional_t<order == kRowMajor, SmemCopy_MMA_16816_B<T, false>, SmemCopy_MMA_16816_A<T, true>>;\n    // using SmemCopyAtom = std::conditional_t<order == kRowMajor,  //\n    //                                         LDSM_SM75_8x8<T, 16, 16, kRowMajor, kRowMajor>,\n    //                                         LDSM_SM75_8x8<T, 16, 16, kColMajor, kColMajor>>;\n\n    using SmemCopyAtom = LDSM_SM75_8x8<T, N, 16, order, order>;\n\n    using GetSmemLayout = GetSmemLayoutV2<kOrder>;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<Order order>\nstruct _GetSmemLayoutC {\n    template<int M, int N>\n    static constexpr auto apply(pair<M, N>)\n    {\n        if constexpr (order == kRowMajor) {\n            // x01  23\n            // cccccss\n            //                                    bits base shift\n            return SmemLayoutV2<M, N, 8, 32, Swizzle<2, 3, 2>>{};\n        }\n        else {\n            // 234  x01\n            // 23401x\n            // cccccsss\n            // so that x is not part of swizzling\n            return SmemLayoutV2<N, M, 8, 32, Swizzle<2, 3, 3>>{};\n        }\n    }\n};\n\ntemplate<Order order>\nstruct _GetThreadMapC {\n    template<int M, int N, int THREADS>\n    static constexpr auto apply(pair<M, N>, constant<THREADS>)\n    {\n        constexpr auto cs    = mk2cs<order>(M, N);\n        constexpr int  WARPS = THREADS / WARP_SIZE;\n\n        return ThreadMap_V2<cs.x, cs.y, 4, Raked, WARPS>{};\n    }\n};\n\ntemplate<class T, Order order>\nstruct Operand_C {\n    using Dtype = T;\n\n    static constexpr Order kOrder = order;\n\n    using GetSmemLayout = 
_GetSmemLayoutC<order>;\n    using GetThreadMap  = _GetThreadMapC<order>;\n};\n\ntemplate<class T>\nstruct Operand_UV {\n    using Dtype = T;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = kColMajor;\n\n    using SmemCopyAtom = SmemCopy_MMA_16816_U<T>;\n\n    struct GetSmemLayout {\n        template<int M, int K>\n        static constexpr auto apply(pair<M, K>)\n        {\n            return SmemLayoutV2<K, M>{};\n        }\n    };\n    using GetGmemIter = GetGmemIter;\n};\n\ntemplate<Order order>\nstruct GetSmemLayout_Pack {\n    template<int M, int K>\n    static constexpr auto apply(pair<M, K>)\n    {\n        constexpr int2 CS = mk2cs<order>(M, K);\n        return SmemLayoutV2<CS.y, CS.x, 1, 1>{};\n    }\n};\n\ntemplate<class T, Order order, int Pack_M_>\nstruct Operand_A_Pack {\n    using Dtype = T;\n\n    static constexpr int Pack_M = Pack_M_;\n\n    static constexpr Pack  kPack  = HMMA_16816 | OPERAND_A | Pack_M;\n    static constexpr Order kOrder = order;\n\n    // using SmemCopyAtom = SmemCopyAtom_Pack_v2<T, kOrder, 16 * Pack_M, 16, 8, Pack_M>;\n    using _SCp         = typename Operand_A<T, order>::SmemCopyAtom;\n    using SmemCopyAtom = SmemCopyAtom_Pack_v3<T, _SCp, order, Pack_M>;\n\n    using GetSmemLayout = GetSmemLayout_Pack<kOrder>;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<class T, Order order, int Pack_M_>\nstruct Operand_B_Pack {\n    using Dtype = T;\n\n    static constexpr int Pack_M = Pack_M_;\n\n    static constexpr Pack  kPack  = HMMA_16816 | OPERAND_B | Pack_M;\n    static constexpr Order kOrder = order;\n\n    using SmemCopyAtom = SmemCopyAtom_Pack_v2<T, kOrder, 16 * Pack_M, 16, 8, Pack_M>;\n\n    using GetSmemLayout = GetSmemLayout_Pack<kOrder>;\n    using GetGmemIter   = GetGmemIter;\n};\n\ntemplate<class T, bool is_V>\nstruct Operand_UV_Pack {\n    using Dtype = T;\n\n    static constexpr int Pack_M = 1;\n\n    static constexpr Pack  kPack  = HMMA_16816 | (is_V ? 
OPERAND_V : OPERAND_U) | Pack_M;\n    static constexpr Order kOrder = Order::kColMajor;\n\n    using _SCp         = typename Operand_UV<T>::SmemCopyAtom;\n    using SmemCopyAtom = SmemCopyAtom_Pack_v3<T, _SCp, kOrder, Pack_M>;\n\n    using GetSmemLayout = GetSmemLayout_Pack<kOrder>;\n    using GetGmemIter   = GetGmemIter;\n};\n\n}  // namespace sm80_s16816\n\ntemplate<class T, Order order>\nstruct GetOperand<HMMA_16816, OPERAND_A, T, order, false>: std::true_type {\n    using Operand = sm80_s16816::Operand_A<T, order>;\n};\n\ntemplate<class T, Order order>\nstruct GetOperand<HMMA_16816, OPERAND_B, T, order, false>: std::true_type {\n    using Operand = sm80_s16816::Operand_B<T, order, 16>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_16816, OPERAND_U, T, kColMajor, false>: std::true_type {\n    using Operand = sm80_s16816::Operand_UV<T>;\n};\n\ntemplate<class T>\nstruct GetOperand<HMMA_16816, OPERAND_V, T, kColMajor, false>: std::true_type {\n    using Operand = sm80_s16816::Operand_UV<T>;\n};\n\n// template<class T>\n// struct GetOperand<HMMA_16816, OPERAND_A, T, kColMajor, true>: std::true_type {\n//     using Operand = sm80_s16816::Operand_A_Pack<T, kColMajor>;\n// };\n\n// template<class T>\n// struct GetOperand<HMMA_16816, OPERAND_B, T, kColMajor, true>: std::true_type {\n//     using Operand = sm80_s16816::Operand_B_Pack<T, kColMajor>;\n// };\n\n// template<>\n// struct GetOperand<HMMA_16816, OPERAND_U, uint32_t, kColMajor, true>: std::true_type {\n//     using Operand = sm80_s16816::Operand_U_Pack<uint32_t>;\n// };\n\n// template<>\n// struct GetOperand<HMMA_16816, OPERAND_V, uint32_t, kColMajor, true>: std::true_type {\n//     using Operand = sm80_s16816::Operand_U_Pack<uint32_t>;\n// };\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/smem_copy_simt.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/gemm/simt.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class T, int K_>\nstruct SmemCopy_MMA_SIMT_A {\n    static constexpr int M = simt::OP_M;\n    static constexpr int K = simt::OP_K;\n\n    static constexpr int OP_N = simt::OP_N;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, K>;\n\n    __device__ static int2 get_offset(int thread_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        return {lane_id / OP_N, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)  // -> (m, k)\n    {\n        Lds(*(Frag*)dst_ptr, (S &&) src_ptr);\n    }\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)  // -> (unique id, repeat id)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        return {pack_idx * M + lane_id / OP_N, lane_id % OP_N};\n    }\n};\n\ntemplate<class T, int K_>\nstruct SmemCopy_MMA_SIMT_B {\n    static constexpr int M = simt::OP_N;\n    static constexpr int K = simt::OP_K;\n\n    static constexpr int OP_N = simt::OP_N;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, K>;\n\n    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        return {lane_id % OP_N, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)\n    {\n        Lds(*(Frag*)dst_ptr, (S &&) src_ptr);\n    }\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)  // -> (unique id, repeat id)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        return {pack_idx * OP_N + lane_id 
% OP_N, lane_id / OP_N};\n    }\n};\n\ntemplate<class T, int K_>\nstruct SmemCopy_MMA_SIMT_V {\n    static constexpr int M = simt::OP_N;\n    static constexpr int K = K_;\n\n    static constexpr int OP_N = simt::OP_N;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, 1>;\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        return {pack_idx * OP_N + lane_id % OP_N, lane_id / OP_N};\n    }\n\n    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)\n    {\n        return {unique(thread_idx, 0).x, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool mask)\n    {\n        Lds(*(Frag*)dst_ptr, src_ptr);\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/smem_copy_sm70.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class T>\nstruct SmemCopy_MMA_884_A {\n    // static constexpr int M = 16;\n    // static constexpr int K = 8;\n    static constexpr int M = 8;\n    static constexpr int K = 8;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, K>;\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        //                   4                3               01\n        // const int m = lane_id / 16 * 4 + (lane_id & 8) + lane_id % 4;\n        // return {pack_idx * M + m, (lane_id & 4) >> 2};\n\n        //                   4                01\n        const int m = lane_id / 16 * 4 + lane_id % 4;\n        return {pack_idx * M + m, (lane_id & 12) >> 2};\n    }\n\n    __device__ static int2 get_offset(int thread_idx)\n    {\n        return int2{unique(thread_idx, 0).x, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)\n    {\n        Lds(*(Frag*)dst_ptr, src_ptr);\n    }\n};\n\ntemplate<class T>\nstruct SmemCopy_MMA_884_B {\n    // static constexpr int M = 16;\n    // static constexpr int K = 8;\n    static constexpr int M = 32;\n    static constexpr int K = 8;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, K>;\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        //                4                     2                 01\n        // const int m = lane_id / 16 * 4 + (lane_id & 4) * 2 + lane_id % 4;\n        // return {pack_idx * M + m, (lane_id & 8) >> 3};\n\n        //                  4                  23                  01\n        const int m = lane_id / 16 * 4 + (lane_id & 12) * 2 + lane_id 
% 4;\n        return {pack_idx * M + m, 0};\n    }\n\n    __device__ static int2 get_offset(int thread_idx)\n    {\n        return int2{unique(thread_idx, 0).x, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)\n    {\n        Lds(*(Frag*)dst_ptr, src_ptr);\n    }\n};\n\ntemplate<class T, int K_>\nstruct SmemCopy_MMA_884_V {\n    // static constexpr int M = 16;\n    static constexpr int M = 32;\n    static constexpr int K = K_;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, 1>;\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        //                4                     2                 01\n        // const int m = lane_id / 16 * 4 + (lane_id & 4) * 2 + lane_id % 4;\n        // return {pack_idx * 16 + m, (lane_id & 8) >> 3};\n\n        const int m = lane_id / 16 * 4 + (lane_id & 12) * 2 + lane_id % 4;\n        return {pack_idx * M + m, 0};\n    }\n\n    __device__ static int2 get_offset(int thread_idx)\n    {\n        return int2{unique(thread_idx, 0).x, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)\n    {\n        Lds(*(Frag*)dst_ptr, src_ptr);\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch/smem_copy_sm80.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<bool trans>\nstruct LDSM_x4 {\n    template<class S, class D>\n    __device__ static void apply(S src_ptr, D dst_ptr)\n    {\n        const uint32_t uint_ptr = cast_smem_ptr_to_uint(src_ptr);\n        if constexpr (trans) {\n            ldsm_x4_trans(*(Array<uint32_t, 4>*)dst_ptr, uint_ptr);\n        }\n        else {\n            ldsm_x4(*(Array<uint32_t, 4>*)dst_ptr, uint_ptr);\n        }\n    }\n};\n\ntemplate<bool trans>\nstruct LDSM_x2 {\n    template<class S, class D>\n    __device__ static void apply(S src_ptr, D dst_ptr)\n    {\n        const uint32_t uint_ptr = cast_smem_ptr_to_uint(src_ptr);\n        if constexpr (trans) {\n            ldsm_x2_trans(*(Array<uint32_t, 2>*)dst_ptr, uint_ptr);\n        }\n        else {\n            ldsm_x2(*(Array<uint32_t, 2>*)dst_ptr, uint_ptr);\n        }\n    }\n};\n\ntemplate<bool trans>\nstruct LDSM_x1 {\n    template<class S, class D>\n    __device__ static void apply(S src_ptr, D dst_ptr)\n    {\n        const uint32_t uint_ptr = cast_smem_ptr_to_uint(src_ptr);\n        if constexpr (trans) {\n            ldsm_x1_trans(*(Array<uint32_t, 1>*)dst_ptr, uint_ptr);\n        }\n        else {\n            ldsm_x1(*(Array<uint32_t, 1>*)dst_ptr, uint_ptr);\n        }\n    }\n};\n\ntemplate<class T, bool trans>\nstruct SmemCopy_MMA_16816_A {\n    static constexpr int M = 16;\n    static constexpr int K = 16;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, 8>;\n\n    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n\n        const int c = lane_id / 16 * 8;\n        const int s = lane_id % 16;\n\n        return 
trans ? int2{c, s} : int2{s, c};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)\n    {\n        LDSM_x4<trans>::apply((S &&) src_ptr, (D &&) dst_ptr);\n    }\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        return {pack_idx * WARP_SIZE + thread_idx % WARP_SIZE, 0};\n    }\n};\n\ntemplate<class T, bool trans>\nstruct SmemCopy_MMA_16816_B {\n    static constexpr int M = 16;\n    static constexpr int K = 16;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, 8>;\n\n    __device__ static int2 get_offset(int thread_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n\n        const int c = lane_id / 8 * 8 % 16;\n        const int s = lane_id % 8 + lane_id / 16 * 8;\n\n        return trans ? int2{c, s} : int2{s, c};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)\n    {\n        LDSM_x4<trans>::apply((S &&) src_ptr, (D &&) dst_ptr);\n    }\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        return {pack_idx * WARP_SIZE + thread_idx % WARP_SIZE, 0};\n    }\n};\n\ntemplate<class T, int M_, int K_, Order mat_order, Order thr_order>\nstruct LDSM_SM75_8x8 {\n    static constexpr int M = M_;\n    static constexpr int K = K_;\n\n    static constexpr int iM = M / 8;\n    static constexpr int iK = K / 8;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, 2 * iM * iK>;\n\n    __device__ static int2 get_offset(int thread_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        int       c, s;\n        if constexpr (mat_order == kColMajor) {\n            s = lane_id % 16;\n            c = lane_id / 16 * 8;\n        }\n        else {\n            s = lane_id / 16 * 8 + lane_id % 8;\n            c = lane_id & 8;\n        }\n        int2 mk = cs2mk<thr_order>(c, s);\n#if __CUDA_ARCH__ <= 750  // wrap ptrs around for sm_75\n        mk.x 
%= M;\n        mk.y %= K;\n#endif\n        return mk;\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool)\n    {\n        constexpr bool trans = thr_order != kRowMajor;\n        if constexpr (sizeof(Frag) == 16) {\n            LDSM_x4<trans>::apply((S &&) src_ptr, (D &&) dst_ptr);\n        }\n        else if constexpr (sizeof(Frag) == 8) {\n            LDSM_x2<trans>::apply((S &&) src_ptr, (D &&) dst_ptr);\n        }\n        else if constexpr (sizeof(Frag) == 4) {\n            LDSM_x1<trans>::apply((S &&) src_ptr, (D &&) dst_ptr);\n        }\n        else {\n            static_assert(sizeof(S) != sizeof(S), \"not implemented\");\n        }\n    }\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        return {pack_idx * WARP_SIZE + thread_idx % WARP_SIZE, 0};\n    }\n};\n\ntemplate<class T>\nstruct SmemCopy_MMA_16816_U {  // (M, K)\n    static constexpr int M = 16;\n    static constexpr int K = 1;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<T, 2>;\n\n    __device__ static int2 get_offset(int thread_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        // Note: this forbids sub-tile group sizes\n        return {lane_id / 4, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S&& src_ptr, D&& dst_ptr, bool mask)\n    {\n        PRAGMA_UNROLL\n        for (int i = 0; i < 2; ++i) {\n            Lds(*((Array<T, 1>*)dst_ptr + i), src_ptr + i * 8);\n        }\n    }\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n        return {pack_idx * 8 + lane_id / 4, lane_id % 4};\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/arch.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind::gemm {\n\n// tags for dispatching & conditional codegen\n\ntemplate<int Begin, int End = -1>\nstruct Arch {\n    static constexpr bool is_compatible(int arch)\n    {\n        return Begin <= arch && (End == -1 || arch < End);\n    }\n};\n\nstruct Sm70: Arch<700, 750> {\n    static constexpr int value = 700;\n};\n\nstruct Sm75: Arch<750, 800> {\n    static constexpr int value = 750;\n};\n\nstruct Sm80: Arch<800, 900> {\n    static constexpr int value = 800;\n};\n\nstruct Sm90: Arch<900> {\n    static constexpr int value = 900;\n};\n\ninline bool is_arch_compatible(int karch, int darch)\n{\n    switch (karch) {\n        case 0:\n            return true;\n        case 700:\n            return Sm70::is_compatible(darch);\n        case 750:\n            return Sm75::is_compatible(darch);\n        case 800:\n            return Sm80::is_compatible(darch);\n        case 900:\n            return Sm90::is_compatible(darch);\n        default:\n            return false;\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/cast.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/cast.h\"\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\nnamespace turbomind {\n\ntemplate<class Ti, class To>\nstruct Cast {\n    template<int N>\n    __device__ static Array<To, N> apply(const Array<Ti, N>& vi)\n    {\n        Array<To, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            vo[i] = static_cast<To>(vi[i]);\n        }\n        return vo;\n    }\n};\n\ntemplate<class Ti>\nstruct Cast<Ti, uint4_t> {\n    template<int N>\n    __device__ static Array<uint4_t, N> apply(const Array<Ti, N>& vi)\n    {\n        static_assert(N % 8 == 0);\n        Array<uint4_t, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 8) {\n            uint32_t& v = (uint32_t&)vo[i];\n            v           = 0;\n            PRAGMA_UNROLL\n            for (int j = 7; j >= 0; --j) {\n                v = (v << 4) | vi[i + j];\n            }\n        }\n        return vo;\n    }\n};\n\ntemplate<class To>\nstruct Cast<uint4_t, To> {\n    template<int N>\n    __device__ static Array<To, N> apply(const Array<uint4_t, N>& vi)\n    {\n        static_assert(N % 8 == 0);\n        Array<To, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 8) {\n            uint32_t v = (const uint32_t&)vi[i];\n            PRAGMA_UNROLL\n            for (int j = 0; j < 8; ++j) {\n                vo[i + j] = (v & 0xf);\n                v >>= 4;\n            }\n        }\n        return vo;\n    }\n};\n\ntemplate<>\nstruct Cast<uint4_t, uint4_t> {\n    template<int N>\n    __device__ static Array<uint4_t, N> apply(const Array<uint4_t, N>& vi)\n    {\n        return vi;\n    }\n};\n\ntemplate<int VecSize, class Ti, class To>\n__global__ void cast_kernel(To* dst, const Ti* src, size_t n)\n{\n    n 
/= VecSize;\n\n    auto p_src = (const Array<Ti, VecSize>*)src;\n    auto p_dst = (Array<To, VecSize>*)dst;\n\n    for (size_t p = threadIdx.x + blockDim.x * blockIdx.x; p < n; p += blockDim.x * gridDim.x) {\n        Array<Ti, VecSize> vi;\n        Ldg(vi, (const Ti*)&p_src[p]);\n\n        Array<To, VecSize> vo = Cast<Ti, To>::apply(vi);\n\n        Store((To*)&p_dst[p], vo);\n    }\n}\n\ntemplate<int VecSize, class Ti, class To>\nvoid invokeCast(To* dst, const Ti* src, size_t n, cudaStream_t st)\n{\n    cast_kernel<VecSize><<<256, 256, 0, st>>>(dst, src, n);\n}\n\nvoid extend_to_u8(uint8_t* dst, const uint4_t* src, size_t n, cudaStream_t st)\n{\n    invokeCast<8>(dst, src, n, st);\n}\n\nvoid compact_to_u4(uint4_t* dst, const uint8_t* src, size_t n, cudaStream_t st)\n{\n    invokeCast<8>(dst, src, n, st);\n}\n\nvoid extend_to_u16(uint16_t* dst, const uint4_t* src, size_t n, cudaStream_t st)\n{\n    invokeCast<8>(dst, src, n, st);\n}\n\nnamespace {\n\n__global__ void extend_u16_u8(uint16_t* dst, const uint8_t* src, size_t n)\n{\n    int64_t idx = threadIdx.x + (int64_t)blockDim.x * blockIdx.x;\n    if (idx < n) {\n        dst[idx] = src[idx];\n    }\n}\n\n}  // namespace\n\nvoid extend_to_u16(uint16_t* dst, const uint8_t* src, size_t n, cudaStream_t st)\n{\n    extend_u16_u8<<<(n + 511) / 512, 512, 0, st>>>(dst, src, n);\n}\n\ntemplate<int VecSize, class T>\n__global__ void fuse_scales_and_zeros_kernel(T* fused, const T* scales, T* zeros, size_t n)\n{\n    n /= VecSize;\n\n    auto p_scales = (const Array<T, VecSize>*)scales;\n    auto p_zeros  = (const Array<T, VecSize>*)zeros;\n\n    auto p_fused = (Array<T, VecSize * 2>*)fused;\n\n    for (size_t p = threadIdx.x + blockDim.x * blockIdx.x; p < n; p += blockDim.x * gridDim.x) {\n        Array<T, VecSize> vs;\n        Ldg(vs, (const T*)&p_scales[p]);\n        Array<T, VecSize> vz{};\n        if (zeros) {\n            Ldg(vz, (const T*)&p_zeros[p]);\n        }\n        Array<T, VecSize * 2> vf;\n        
PRAGMA_UNROLL\n        for (int i = 0; i < VecSize; ++i) {\n            vf[i * 2]     = vs[i];\n            vf[i * 2 + 1] = -vz[i] * vs[i];\n        }\n        Store((T*)&p_fused[p], vf);\n    }\n}\n\nvoid fuse_scales_and_zeros(half* fused, const half* scales, half* zeros, size_t n, cudaStream_t st)\n{\n    fuse_scales_and_zeros_kernel<4><<<256, 256, 0, st>>>(fused, scales, zeros, n);\n}\n\ntemplate<int VecSize, class T>\n__global__ void\ninterleave_output_dims_kernel(T* __restrict__ fused, const T* __restrict__ a, const T* __restrict__ b, int m, int k)\n{\n    using Vec1 = Array<T, VecSize>;\n\n    const int ki = blockIdx.y;\n\n    auto p_a = reinterpret_cast<const Vec1*>(a + ki * m);\n    auto p_b = reinterpret_cast<const Vec1*>(b + ki * m);\n\n    using Vec2 = Array<T, VecSize * 2>;\n\n    auto p_f = reinterpret_cast<Vec2*>(fused + ki * m * 2);\n\n    m /= VecSize;\n\n    const int tidx = threadIdx.x + blockIdx.x * blockDim.x;\n\n    for (int64_t mi = tidx; mi < m; mi += blockDim.x * gridDim.x) {\n        Vec1 va;\n        Vec1 vb;\n        Ldg(va, (const T*)&p_a[mi]);\n        Ldg(vb, (const T*)&p_b[mi]);\n        Vec2 vc;\n        PRAGMA_UNROLL\n        for (int i = 0; i < VecSize; ++i) {\n            vc[i * 2]     = va[i];\n            vc[i * 2 + 1] = vb[i];\n        }\n        Store((T*)&p_f[mi], vc);\n    }\n}\n\ntemplate<class T>\nvoid interleave_output_dims_impl(T* fused, const T* a, const T* b, int m, int k, cudaStream_t st)\n{\n    constexpr int kVecSize = std::min(8, 128 / (bitsof<T> * 2));\n\n    constexpr int block = 256;\n    const dim3    grid(1, k);  // x is a grid stride loop\n\n    interleave_output_dims_kernel<kVecSize><<<grid, block, 0, st>>>(fused, a, b, m, k);\n}\n\ntemplate void\ninterleave_output_dims_impl(uint8_t* fused, const uint8_t* a, const uint8_t* b, int m, int k, cudaStream_t st);\ntemplate void\ninterleave_output_dims_impl(uint16_t* fused, const uint16_t* a, const uint16_t* b, int m, int k, cudaStream_t st);\ntemplate 
void\ninterleave_output_dims_impl(uint32_t* fused, const uint32_t* a, const uint32_t* b, int m, int k, cudaStream_t st);\n\n__global__ void adjust_ue8m0_scale_for_half_kernel(uint8_t* data, int n)\n{\n    int64_t idx = threadIdx.x + (int64_t)blockDim.x * blockIdx.x;\n    if (idx < n) {\n        /// TODO: saturate the quantized values accordingly\n        data[idx] = max(0, min(30, (int)data[idx] + 15 - 127));  // exponent 31 is INF in half\n    }\n}\n\nvoid AdjustUe8m0ScaleForHalf(uint8_t* data, int n, cudaStream_t st)\n{\n    constexpr int block = 512;\n    const int     grid  = cdiv(n, block);\n    adjust_ue8m0_scale_for_half_kernel<<<grid, block, 0, st>>>(data, n);\n}\n\ntemplate<class T0, class T1>\n__global__ void BlockscaleToGroupscale_Kernel(T1* dst, const T0* src, int64_t n, int block_size)\n{\n    int64_t idx = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;\n    if (idx < n) {\n        dst[idx] = (T1)src[idx / block_size];\n    }\n}\n\nTensor BlockscaleToGroupscale(const Tensor& scales, DataType data_type, int block_size)\n{\n    TM_CHECK_EQ(scales.dtype(), kFloat32);\n\n    Tensor ret{{scales.shape(0), scales.shape(1) * block_size}, data_type, kDEVICE};\n\n    auto stream = core::Context::stream().handle();\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        BlockscaleToGroupscale_Kernel<<<(ret.size() + 511) / 512, 512, 0, stream>>>(\n            ret.data<T>(), scales.data<float>(), ret.size(), block_size);\n    };\n\n    TM_DISPATCH_DTYPES(data_type, invoke, half_t, bfloat16_t);\n\n    return ret;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/cast.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n\nnamespace turbomind {\n\nvoid extend_to_u8(uint8_t* dst, const uint4_t* src, size_t n, cudaStream_t st = {});\n\nvoid extend_to_u16(uint16_t* dst, const uint4_t* src, size_t n, cudaStream_t st = {});\n\nvoid extend_to_u16(uint16_t* dst, const uint8_t* src, size_t n, cudaStream_t st);\n\nvoid compact_to_u4(uint4_t* dst, const uint8_t* src, size_t n, cudaStream_t st = {});\n\nvoid transpose_u4(uint4_t* dst, const uint4_t* src, int s, int c, cudaStream_t st = {});\n\nvoid fuse_scales_and_zeros(half* fused, const half* scales, half* zeros, size_t n, cudaStream_t st = {});\n\ntemplate<class T>\nvoid interleave_output_dims_impl(T* fused, const T* a, const T* b, int m, int k, cudaStream_t st);\n\ntemplate<class T>\ninline void interleave_output_dims(T* fused, const T* a, const T* b, int m, int k, cudaStream_t st)\n{\n    auto dispatch = [&](auto u) {\n        using U = decltype(u);\n        return interleave_output_dims_impl((U*)fused, (const U*)a, (const U*)b, m, k, st);\n    };\n    if constexpr (bitsof<T> == 8) {\n        return dispatch(uint8_t{});\n    }\n    else if constexpr (bitsof<T> == 16) {\n        return dispatch(uint16_t{});\n    }\n    else if constexpr (bitsof<T> == 32) {\n        return dispatch(uint32_t{});\n    }\n}\n\nvoid AdjustUe8m0ScaleForHalf(uint8_t* data, int n, cudaStream_t st);\n\nTensor BlockscaleToGroupscale(const Tensor& scales, DataType data_type, int block_size);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/context.cu",
    "content": "\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/gemm/context.h\"\n#include \"src/turbomind/kernels/gemm/cta_map.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/moe_utils_v2.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n#include \"src/turbomind/utils/monotonic.h\"\n#include <algorithm>\n#include <cub/block/block_reduce.cuh>\n#include <iostream>\n#include <tuple>\n\nnamespace turbomind::gemm {\n\nstatic std::optional<GemmDesc> get_gemm_desc(const Operation&    operation,\n                                             const MatrixLayout& Adesc,\n                                             const MatrixLayout& Udesc,\n                                             const MatrixLayout& Bdesc,\n                                             const MatrixLayout& Vdesc,\n                                             const MatrixLayout& Cdesc,\n                                             const MatrixLayout& Ddesc,\n                                             int                 arch)\n{\n\n    // Constant dimensions are set to the exact value\n    // Variable dimensions are set to sum of the values\n\n    const int m0 = Adesc.rows, k0 = Adesc.cols;\n    const int k1 = Bdesc.rows, n0 = Bdesc.cols;\n    const int m1 = Ddesc.rows, n1 = Ddesc.cols;\n\n    const int l0 = Adesc.num, l1 = Bdesc.num, l2 = Ddesc.num;\n\n    if (m0 != m1 || n0 != n1 || k0 != k1 || l0 != l1 || l0 != l2) {\n        fprintf(stderr, \"%d %d %d %d %d %d %d %d %d\\n\", m0, m1, n0, n1, k0, k1, l0, l1, l2);\n        return {};\n    }\n\n    GemmDesc desc{arch,\n                  Adesc.type,\n                  Bdesc.type,\n                  Ddesc.type,\n                  Adesc.order,\n                  Bdesc.order,\n                  Ddesc.order,\n                  
get_mode(Adesc),\n                  get_mode(Bdesc),\n                  get_mode(Ddesc),\n                  Adesc.pack,\n                  Bdesc.pack,\n                  Udesc.pack,\n                  Vdesc.pack,\n                  operation.quant_a,\n                  operation.quant_b,\n                  operation.epilogue,\n                  operation.batch_dim,\n                  -1};\n\n    desc.m   = m0;\n    desc.n   = n0;\n    desc.k   = k0;\n    desc.num = std::max(l0, 1);\n\n    if (desc.num > 1) {\n        desc.group_axis = operation.batch_dim;\n    }\n\n    return desc;\n}\n\nstd::vector<LaunchSpec> get_swizzle(const int4& shape, const LaunchSpec& spec, const std::vector<int>& swizzle)\n{\n    std::vector<int> vec;\n    const int        max_swizzle = spec.kernel->GetMaxSwizzle(shape);\n    for (const auto& s : swizzle) {\n        if (s <= max_swizzle && std::find(vec.begin(), vec.end(), s) == vec.end()) {\n            vec.push_back(s);\n        }\n    }\n    std::vector<LaunchSpec> ret;\n    for (const auto& s : vec) {\n        auto tmp    = spec;\n        tmp.swizzle = s;\n        ret.push_back(tmp);\n    }\n    return ret;\n}\n\nContext::Context(const cudaDeviceProp& prop)\n{\n    arch_     = prop.major * 100 + prop.minor * 10;\n    sm_count_ = prop.multiProcessorCount;\n}\n\nbool Context::Init(const Operation&    operation,\n                   const MatrixLayout& Adesc,\n                   const MatrixLayout& Udesc,\n                   const MatrixLayout& Bdesc,\n                   const MatrixLayout& Vdesc,\n                   const MatrixLayout& Cdesc,\n                   const MatrixLayout& Ddesc)\n{\n    auto desc = get_gemm_desc(operation, Adesc, Udesc, Bdesc, Vdesc, Cdesc, Ddesc, arch_);\n    if (!desc) {\n        return false;\n    }\n\n    desc_       = *desc;\n    desc_trans_ = transpose(desc_);\n\n    return true;\n}\n\nstd::vector<Kernel*> Context::Filter(const std::vector<Kernel*>& kernels) const\n{\n    std::vector<std::pair<Kernel*, 
int>> feasible;\n    auto get_batch_dim  = [](auto k, auto& g) { return g.batch_dim ? k->desc().cta_tile.y : k->desc().cta_tile.x; };\n    int  max_batch_size = 0;  // max batch size of single CTA tile\n\n    for (auto& k : kernels) {\n        auto& g = get_desc(*k);\n        if (k->is_feasible(g)) {\n            auto bsz = get_batch_dim(k, g);\n            feasible.emplace_back(k, bsz);\n            max_batch_size = std::max(bsz, max_batch_size);\n        }\n    }\n\n    // Batch size of the GEMM problem\n    const int batch_size = desc_.batch_dim ? desc_.n : desc_.m;\n    // std::cout << \"BATCH SIZE: \" << batch_size << \"\\n\";\n\n    // Find smallest kernel the problem can fit into (may not exist)\n    for (const auto& [k, bsz] : feasible) {\n        if (bsz >= batch_size) {\n            max_batch_size = std::min(max_batch_size, bsz);\n        }\n    }\n\n    const auto pred = [&](auto k) {  //\n        return k.second > max_batch_size;\n    };\n    feasible.erase(std::remove_if(feasible.begin(), feasible.end(), pred), feasible.end());\n\n    std::vector<Kernel*> ret;\n    for (auto& [k, bsz] : feasible) {\n        // std::cout << \"KERNEL: \" << k->name() << \", BSZ: \" << bsz << std::endl;\n        ret.push_back(k);\n    }\n\n    return ret;\n}\n\nstd::vector<LaunchSpec> Context::Populate(const Kernel& kernel, const PopulateParam& param) const\n{\n    // early exit for cuBLAS backend\n    if (kernel.desc().backend) {\n        return {LaunchSpec{const_cast<Kernel*>(&kernel), 0, 1}};\n    }\n\n    const auto& gemm = get_desc(kernel);\n\n    const int m = gemm.m, n = gemm.n, k = gemm.k, num = std::max(1, gemm.num);\n\n    const auto& desc = kernel.desc();\n    const auto& info = kernel.info();\n\n    const int64_t tiled_shape_m = cdiv(m, desc.cta_tile.x * (desc.group_axis == 0 ? num : 1));\n    const int64_t tiled_shape_n = cdiv(n, desc.cta_tile.y * (desc.group_axis == 1 ? 
num : 1));\n    const int     chunk_cnt_k   = cdiv(k, kernel.chunk_size_k());\n\n    // Although we only have sm_count * constant tensor cores, this is the granularity for scheduling\n    const int   concurrency     = sm_count_ * kernel.info().max_active_ctas;\n    const float waves_per_split = float(tiled_shape_m * tiled_shape_n) / concurrency;\n    const float splits_per_wave = 1.f / waves_per_split;\n\n    // Tile quantization\n    const int64_t ceil_m = tiled_shape_m * desc.cta_tile.x;\n    const int64_t ceil_n = tiled_shape_n * desc.cta_tile.y;\n\n    // int max_splits = kernel.GetMaxSplits(m, n, k, param.barriers_size, param.partials_size);\n    int max_splits = kernel.GetMaxSplits({m, n, k, num}, 0, param.barriers_size, param.partials_size);\n\n    // std::cout << \"max_splits: \" << max_splits << std::endl;\n\n    max_splits = std::min(param.max_splits, max_splits);\n\n    std::vector<LaunchSpec> specs;\n\n    /// TODO: revise this according to the latest scheduler\n    for (int splits = 1; splits <= max_splits; ++splits) {\n        // Split quantization, penalize uneven splits\n        const int64_t split_ceil_k = cdiv(chunk_cnt_k, splits) * kernel.chunk_size_k();\n        // Footprint for single split\n        const int64_t split_mma_cost = ceil_m * ceil_n * split_ceil_k;\n        // Footprint for single wave\n        const int64_t wave_mma_cost = split_mma_cost * splits_per_wave;\n\n        // Wave quantization\n        // const int waves = (int)std::ceil(wave_per_split * splits);\n\n        // Bold simulation of thread block scheduling\n        const int   grid_size    = tiled_shape_m * tiled_shape_n * splits * num;\n        const int   full_waves   = grid_size / concurrency;\n        const int   residue      = grid_size % concurrency;\n        const float partial_wave = (float)cdiv(residue, sm_count_) / info.max_active_ctas;\n        const float waves        = full_waves + partial_wave;\n\n        if (splits > 1 && waves > param.max_waves) {\n          
  break;\n        }\n        // ceil(tiled_mn / C * splits) * C / tiled_mn * ceil_m * ceil_n * split_ceil_k\n        const int64_t mma_cost = wave_mma_cost * waves;\n\n        // IO has less severe quantization effect\n        const int64_t mio_cost_a = byte_size(desc.type_a, tiled_shape_n * m * split_ceil_k) * splits * num;\n        const int64_t mio_cost_b = byte_size(desc.type_b, tiled_shape_m * n * split_ceil_k) * splits * num;\n        /// TODO: read type from `desc_.accum` when added\n        const int64_t mio_cost_c = byte_size(desc.type_c, (int64_t)m * n) * (splits - 1) * 2 * num;\n        const int64_t mio_cost   = mio_cost_a + mio_cost_b + mio_cost_c;\n\n        // std::cout << kernel.name() << \" \" << splits << \" \" << waves << \" \" << (float)mio_cost << \" \" <<\n        // (float)mma_cost\n        //           << \"\\n\";\n\n        // metrics.emplace_back(splits, KernelMetric{mio_cost, mma_cost});\n\n        LaunchSpec spec{};\n        spec.kernel    = const_cast<Kernel*>(&kernel);\n        spec.splits    = splits;\n        spec.swizzle   = param.swizzle;\n        spec.estimated = {mio_cost, mma_cost};\n        specs.push_back(spec);\n    }\n\n    return specs;\n}\n\nstd::vector<LaunchSpec> Context::Swizzle(const LaunchSpec& spec, const std::vector<int>& swizzle) const\n{\n    auto& desc = get_desc(*spec.kernel);\n    return get_swizzle({desc.m, desc.n, desc.k, desc.num}, spec, swizzle);\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/context.h",
    "content": "#pragma once\n\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include <optional>\n\nnamespace turbomind::gemm {\n\nstruct PopulateParam {\n    int    max_splits;\n    int    max_waves;\n    int    swizzle;\n    size_t barriers_size;\n    size_t partials_size;\n};\n\nclass Context {\npublic:\n    explicit Context(const cudaDeviceProp& prop);\n\n    bool Init(const Operation&    operation,\n              const MatrixLayout& Adesc,\n              const MatrixLayout& Udesc,\n              const MatrixLayout& Bdesc,\n              const MatrixLayout& Vdesc,\n              const MatrixLayout& Cdesc,\n              const MatrixLayout& Ddesc);\n\n    std::vector<Kernel*> Filter(const std::vector<Kernel*>& kernels) const;\n\n    std::vector<LaunchSpec> Populate(const Kernel& kernel, const PopulateParam& param) const;\n\n    std::vector<LaunchSpec> Swizzle(const LaunchSpec& spec, const std::vector<int>& swizzle) const;\n\n    const GemmDesc& desc() const\n    {\n        return desc_;\n    }\n\n    const GemmDesc& get_desc(const Kernel& kernel) const\n    {\n        return kernel.desc().transpose ? desc_trans_ : desc_;\n    }\n\n    // Alignment\n    // (align_m, align_n, align_k) -> is_aligned\n    //  gcd_mnk needs to be part of gemm desc\n\n    // Max splits\n    // (max_mn_tiles, max_k_tiles) -> max_splits\n\n    // CTA Swizzling\n    // - GemmScheduler: return get_log_tile\n    // - DynamicScheduler: bypass\n\n    // Cost estimation\n    //\n\nprotected:\n    int arch_{};\n    int sm_count_{};\n\n    GemmDesc desc_{};\n    GemmDesc desc_trans_{};\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/convert.cuh",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cuda_pipeline_primitives.h>\n\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\n#include \"src/turbomind/kernels/attention/quantization.h\"\n\n#include \"src/turbomind/kernels/gemm/cp_async.h\"\n#include \"src/turbomind/kernels/gemm/format.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm70.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\ntemplate<class T>\n__device__ void print_type(T)\n{\n    if (threadIdx.x == 0) {\n        printf(\"%s\\n\", __PRETTY_FUNCTION__);\n    }\n}\n\nnamespace turbomind::gemm {\n\ntemplate<int M_, int K_, int Pack_M, class Operand_, class Td, class Converter>\nstruct ConvertOperand {\n\n    static constexpr int M = M_;\n    static constexpr int K = K_;\n\n    using Operand = MakeOperand<Operand_, IteratorSm70<Striding::kFlat, cache_policy::Default>, M_, K_, 1>;\n\n    using Ts         = typename Operand::Dtype;\n    using SmemLayout = typename Operand::SmemLayout;\n    using GmemIter   = typename Operand::GmemIter;\n\n    using Atom = typename Operand::SmemCopyAtom;\n\n    using SmemCopy = SmemCopy<Operand, M_ / Atom::M, K_ / Atom::K, Atom::M, Atom::K>;\n\n    using Accessor = SmemAccessor<Ts, SmemLayout>;\n\n    static constexpr auto kOrderS = Operand::kOrder;\n\n    static constexpr int ITER_K = ceil_div(K, Atom::K);\n\n    /// TODO: generailize this\n    static constexpr int WARP_CNT = 1;\n\n    using PtrD = get_pointer_type<Td>;\n\n    struct Param {\n        int         m;\n        int         k;\n        MatrixParam src;\n        MatrixParam dst;\n    };\n\n    using SharedStorage = Array<Ts, SmemLayout::kSize>;\n\n    template<class T, int N, int M>\n    static 
constexpr int get_fragment_size(Array<T, N> (&)[M])\n    {\n        return N;\n    }\n\n    template<class T, int N, int M>\n    static constexpr int get_fragment_num(Array<T, N> (&)[M])\n    {\n        return M;\n    }\n\n    __device__ constexpr int2 _mk2cs(int m, int k)\n    {\n        return mk2cs<kOrderS>(m, k);\n    }\n\n    __device__ void operator()(const Param& param, char* smem_buf)\n    {\n        Ts* smem = (Ts*)smem_buf;\n\n        const int cta_cnt_m = ceil_div(param.m, M);\n        const int cta_cnt_k = ceil_div(param.k, K);\n\n        const int cta_idx_m = blockIdx.x;\n\n        const int cta_offset_m = cta_idx_m * M;\n        const int residue_m    = min(M, param.m - cta_offset_m);\n\n        const int warp_id = threadIdx.x / WARP_SIZE;\n\n        const int warp_offset_m = 0;\n\n        Converter converter{};\n\n        typename SmemCopy::Frag data;\n\n        constexpr int kFragSize = get_fragment_size(data);\n        constexpr int kFragNum  = get_fragment_num(data);\n        constexpr int kPackSize = kFragSize * Pack_M;\n\n        const int pack_cnt_k = ceil_div(param.k, Atom::K);\n        const int pack_cnt_m = ceil_div(param.m, Atom::M * Pack_M);\n\n        if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) {\n            // printf(\"m=%d, k=%d, lds = %d\\n\", param.m, param.k, param.lds);\n            // printf(\n            //     \"CTA_M=%d, CTA_K=%d, cta_cnt_m=%d, cta_cnt_k=%d, cta_idx_m=%d, ITER_K=%d, pack_cnt_m=%d,\n            //     pack_cnt_k=%d\\n\", M_, K_, cta_cnt_m, cta_cnt_k, cta_idx_m, ITER_K, pack_cnt_m, pack_cnt_k);\n            // printf(\"frag_size=%d, frag_num=%d, pack_size=%d\\n\", kFragSize, kFragNum, kPackSize);\n        }\n\n        const int cta_offset_k = (cta_cnt_k - 1) * K;\n        const int residue_k    = min(K, param.k - cta_offset_k);\n\n        const auto mat_S = resolve<Ts, Striding::kFlat>(param.src, 0);\n        const auto mat_D = resolve<Td, Striding::kFlat>(param.dst, 0);\n\n        // Handle 
residue k first\n        GmemIter gmem{mat_S, {cta_offset_m, cta_offset_k}, {residue_m, residue_k}};\n\n        gmem.smem_data_ = smem;\n        gmem.ClearSmem();\n\n        __syncthreads();\n\n        // gmem.Prefetch(true);\n\n        typename GmemIter::Fragments fragments{};\n        gmem.Fetch(fragments, true);\n        gmem.Store(fragments);\n\n        // Rest full k tiles\n        gmem            = GmemIter{mat_S, {cta_offset_m, 0}, {residue_m, K}};\n        gmem.smem_data_ = smem;\n\n        SmemCopy smem_copy({warp_offset_m, 0});\n\n        // last, 0, 1, 2, 3, ..., last - 1\n        int cta_idx_k = cta_cnt_k - 1;\n\n        get_pointer_type<Td> mat_D_ptr{(Td*)mat_D.ptr.ptr};\n\n        for (int k_stage = 0; k_stage < cta_cnt_k; ++k_stage) {\n            __syncthreads();\n\n            PRAGMA_UNROLL\n            for (int k = 0; k < ITER_K; ++k) {\n                // Assuming `SmemCopy` is a warp-level operation\n                // Load from smem as we are doing GEMMs\n                // SmemCopy::copy(smem, data, int2{warp_offset_m, 0}, k);\n                smem_copy(smem, data, k);\n\n                PRAGMA_UNROLL\n                for (int m = 0; m < kFragNum; m += Pack_M) {\n                    // Convert and pack rmem data\n                    Array<Td, kPackSize> packed = converter((Array<Ts, kPackSize>&)data[m]);\n\n                    // Logical pack coords\n                    const int pack_idx_k = cta_idx_k * ITER_K + k;\n                    const int pack_idx_m = ((cta_idx_m * WARP_CNT + warp_id) * kFragNum + m) / Pack_M;\n\n                    // Linear pack index\n                    const int pack_index = cs2idx(_mk2cs(pack_idx_m, pack_idx_k),  //\n                                                  _mk2cs(pack_cnt_m, pack_cnt_k).x);\n\n                    auto [unique_id, repeat_id] = Atom::unique(threadIdx.x, pack_index);\n\n                    // Store in [pack_id, lane_id], static cast is needed to decay SubBytePtr<T> to T*\n                  
  auto dst_ptr = static_cast<Td*>(mat_D_ptr + unique_id * kPackSize);\n\n                    if (pack_idx_m < pack_cnt_m && pack_idx_k < pack_cnt_k && repeat_id == 0) {\n                        Store(dst_ptr, packed);\n                    }\n                }\n            }\n\n            __syncthreads();\n\n            if (k_stage == cta_cnt_k - 1) {\n                break;\n            }\n\n            // gmem.Prefetch(true);\n            gmem.Fetch(fragments, true);\n            gmem.Store(fragments);\n            gmem.Advance();\n\n            cta_idx_k = k_stage;\n        }\n    }\n\n    __device__ void print(...) {}\n\n    __device__ void print(Array<uint32_t, 2> _x)\n    {\n        auto& x = (const Array<half, 4>&)_x;\n        printf(\"tidx=%d, %f %f %f %f\\n\", (int)threadIdx.x, (float)x[0], (float)x[1], (float)x[2], (float)x[3]);\n    }\n};\n\nextern __shared__ char smem_buf[];\n\ntemplate<class Kernel>\n__global__ void convert_kernel(typename Kernel::Param param)\n{\n    Kernel kernel;\n    kernel(param, smem_buf);\n}\n\nconstexpr bool is_AB(Op_Tag op)\n{\n    if (op == OPERAND_A || op == OPERAND_B) {\n        return true;\n    }\n    else {\n        return false;\n    }\n}\n\nconstexpr bool is_UV(Op_Tag op)\n{\n    return !is_AB(op);\n}\n\ntemplate<class Dtype>\nconstexpr int unit_size(basic_type<Dtype>)\n{\n    return 1;\n}\n\nconstexpr int unit_size(basic_type<uint8_t>)\n{\n    return 4;\n}\n\nconstexpr int unit_size(basic_type<uint4_t>)\n{\n    return 8;\n}\n\n// MMA     : H_16816, H_1688, H_884, H_SIMT\n// Operand : A, B, U, V\n// Order   : row, col\n// Dtype   : u16, u8, u4 (u6, u3)\n// PackNum : 1, 2, 4\n\ntemplate<class Operand, class Dtype_, int PackNum>\nstruct Config {\n    static constexpr int CTA_M = 64;\n    static constexpr int CTA_K = 32;\n\n    static constexpr int BLOCK_SIZE = 32;\n\n    using Stype = typename Operand::Dtype;\n    using Dtype = Dtype_;\n\n    using Kernel = ConvertOperand<CTA_M, CTA_K, PackNum, Operand, Dtype, 
Converter<Stype, Dtype>>;\n};\n\ntemplate<class Config>\nvoid Convert_v2_Impl(const void* S, const MatrixLayout& Sdesc, void* D, const MatrixLayout& Ddesc, cudaStream_t stream)\n{\n    using Kernel = typename Config::Kernel;\n    using Stype  = typename Config::Stype;\n    using Dtype  = typename Config::Dtype;\n\n    constexpr int CTA_M = Config::CTA_M;\n\n    static constexpr int kSmemSize = sizeof(typename Kernel::SharedStorage);\n\n    if (kSmemSize > (48 << 10)) {\n        cudaFuncSetAttribute(convert_kernel<Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);\n    }\n\n    typename Kernel::Param param{Sdesc.rows, Sdesc.cols, to_param((void*)S, Sdesc), to_param((void*)D, Ddesc)};\n\n    constexpr int threads = Config::BLOCK_SIZE;\n    const int     blocks  = ceil_div(Sdesc.rows, CTA_M);\n\n    // std::cout << __PRETTY_FUNCTION__ << std::endl;\n    // std::cout << __PRETTY_FUNCTION__ << \"\\nThreadMap:\\n\";\n    // Print(typename Kernel::GmemIter::ThreadMap{});\n\n    convert_kernel<Kernel><<<blocks, threads, kSmemSize, stream>>>(param);\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/convert.h",
    "content": "#pragma once\n\n#include <array>\n#include <utility>\n#include <vector>\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nstruct LayoutConverter {\n\n    Order order;\n    Pack  pack;\n\n    virtual int Convert(const void*         S,  //\n                        const MatrixLayout& Sdesc,\n                        void*               D,\n                        MatrixLayout&       Ddesc,\n                        cudaStream_t        stream) const = 0;\n};\n\n// Pointers to singletons\nstd::array<const LayoutConverter*, 2> GetConverters(DataType data_type,\n                                                    DataType weight_type,  //\n                                                    DataType input_type,\n                                                    bool     grouped,\n                                                    int      sm);\n\n// Free with `cudaFree`\nvoid* MakeStridedPtrs(const std::vector<std::pair<void*, int>>& ptrs, cudaStream_t stream);\nvoid* MakeBlockedPtrs(const std::vector<std::pair<void*, int>>& ptrs, cudaStream_t stream);\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/convert_v3.cu",
    "content": "\n#include <array>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/convert.cuh\"\n#include \"src/turbomind/kernels/gemm/convert.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\n#include \"src/turbomind/kernels/gemm/arch/operand_simt.h\"\n#include \"src/turbomind/kernels/gemm/arch/operand_sm70_s884.h\"\n#include \"src/turbomind/kernels/gemm/arch/operand_sm80_s16816.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class Arch, Order order_, MMA_Tag mma_tag, Op_Tag op_tag, int pack_num, class Stype, class Dtype>\nstruct LayoutConverterImpl: public LayoutConverter {\n\n    LayoutConverterImpl(): LayoutConverter{}\n    {\n        this->order = order_;\n        this->pack  = mma_tag | op_tag | pack_num;\n    }\n\n    int Convert(const void*         S,\n                const MatrixLayout& Sdesc_,  // (m,k) / (n,k)\n                void*               D,\n                MatrixLayout&       Ddesc,  // (m,k) / (n,k)\n                cudaStream_t        stream) const override\n    {\n        // TM_CHECK_EQ(Sdesc.pack, 0U) << \"Source must be non-packed format\";\n\n        const bool trans = op_tag == OPERAND_B || op_tag == OPERAND_V;\n        // (k, n) -> (n, k)\n        MatrixLayout Sdesc = trans ? transpose(Sdesc_) : Sdesc_;\n        // MatrixLayout Ddesc = trans ? 
transpose(Ddesc_) : Ddesc_;\n\n        TM_CHECK_NOTNULL(S);\n        TM_CHECK_NOTNULL(D);\n\n        using Operand = typename GetOperand<mma_tag, op_tag, Stype, order_, false>::Operand;\n\n        Convert_v2_Impl<Config<Operand, Dtype, pack_num>>(S, Sdesc, D, Ddesc, stream);\n\n        constexpr Pack pack = mma_tag | op_tag | pack_num;\n\n        // Update leading dimension\n        Ddesc.ld = mk2cs<order_>(Packing_v2<pack, order_>::apply({Sdesc.rows, Sdesc.cols})).x;\n\n        return 0;\n    }\n};\n\ntemplate<class Arch, Order order, uint32_t pack, class Stype, class Dtype>\nstatic LayoutConverter* GetImpl()\n{\n    constexpr auto mma      = get_mma_tag(pack);\n    constexpr auto operand  = get_operand_tag(pack);\n    constexpr auto pack_num = get_pack_num(pack);\n\n    static LayoutConverterImpl<Arch, order, mma, operand, pack_num, Stype, Dtype> impl{};\n\n    return &impl;\n}\n\ntemplate<class Stype, class Dtype>\nstruct Cvt {\n    template<class Arch, Order order, Pack pack>\n    LayoutConverter* operator()(Arch, constant<order>, constant<pack>) const\n    {\n        return GetImpl<Arch, order, pack, Stype, Dtype>();\n    }\n};\n\nconstexpr constant<(Pack)HMMA_16816> s16816h{};\nconstexpr constant<(Pack)HMMA_884>   s884h{};\n\ntemplate<auto a, auto b>\nconstexpr auto operator|(constant<a>, constant<b>)\n{\n    return constant<a | b>{};\n}\n\nstd::array<const LayoutConverter*, 2> GetConverters(DataType data_type,\n                                                    DataType weight_type,  //\n                                                    DataType input_type,\n                                                    bool     grouped,\n                                                    int      sm)\n{\n    constexpr constant<kRowMajor> kRow{};\n    constexpr constant<kColMajor> kCol{};\n\n    constexpr constant<OPERAND_A> A{};\n    constexpr constant<OPERAND_B> B{};\n    constexpr constant<OPERAND_U> U{};\n    constexpr constant<OPERAND_V> V{};\n\n    constexpr 
constant<1> _1{};\n    constexpr constant<2> _2{};\n\n    constexpr Arch<80> sm8_{};\n    constexpr Sm75     sm75{};\n    constexpr Sm70     sm70{};\n\n    if (weight_type == kHalf || weight_type == kBfloat16) {\n        constexpr Cvt<uint16_t, uint16_t> W;\n        if (grouped) {\n            // clang-format off\n            if (sm >= 80) return {W(sm8_, kRow, s16816h | B | _1), {}};\n            if (sm == 75) return {W(sm75, kRow, s16816h | B | _1), {}};\n            if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), {}};\n            // clang-format on\n        }\n        else {\n            return {};  //  trivial case: dense floating point\n        }\n    }\n\n    // For performance reasons, u4 uses different layouts for grouped/non-grouped GEMM\n    if (weight_type == kUint4) {\n        constexpr Cvt<uint16_t, uint4_t>  W;  // u4       weight\n        constexpr Cvt<uint32_t, uint32_t> S;  // f16/bf16 scales&zeros\n        if (grouped) {\n            // clang-format off\n            if (sm >= 80) return {W(sm8_, kRow, s16816h | B | _2), S(sm8_, kCol, s16816h | V | _1)};\n            if (sm == 75) return {W(sm75, kRow, s16816h | B | _2), S(sm75, kCol, s16816h | V | _1)};\n            if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), S(sm70, kCol,   s884h | V | _1)};\n            // clang-format on\n        }\n        else {\n            // clang-format off\n            if (sm >= 80) return {W(sm8_, kCol, s16816h | B | _2), S(sm8_, kCol, s16816h | V | _1)};\n            if (sm == 75) return {W(sm75, kCol, s16816h | B | _2), S(sm75, kCol, s16816h | V | _1)};\n            if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), S(sm70, kCol,   s884h | V | _1)};\n            // clang-format on\n        }\n    }\n\n    if (weight_type == kFloat4_e2m1) {\n        constexpr Cvt<uint16_t, uint4_t> W;  // e2m1  weight\n        constexpr Cvt<uint8_t, uint8_t>  S;  // ue8m0 scales\n        // clang-format off\n        if (sm >= 80) return {W(sm8_, kCol, s16816h | A 
| _1), S(sm8_, kCol, s16816h | U | _1)};\n        if (sm == 75) return {W(sm75, kCol, s16816h | A | _1), S(sm75, kCol, s16816h | U | _1)};\n        if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), S(sm70, kCol,   s884h | V | _1)};\n        // clang-format on\n    }\n\n    if (weight_type == kFloat8_e4m3) {\n        constexpr Cvt<uint16_t, uint8_t>  W;  // e4m3     weight\n        constexpr Cvt<uint16_t, uint16_t> S;  // f16/bf16 scales\n        // clang-format off\n        if (sm >= 80) return {W(sm8_, kCol, s16816h | A | _1), S(sm8_, kCol, s16816h | U | _1)};\n        if (sm == 75) return {W(sm75, kCol, s16816h | A | _1), S(sm75, kCol, s16816h | U | _1)};\n        if (sm >= 70) return {W(sm70, kRow,   s884h | B | _1), S(sm70, kCol,   s884h | V | _1)};\n        // clang-format on\n    }\n\n    TM_CHECK(0) << \"Invalid combination: \" << sm << \" \" << data_type << \" \" << weight_type << \" \" << input_type << \" \"\n                << grouped;\n\n    return {};\n}\n\nnamespace {\n\ntemplate<int N>\nstruct Param {\n    StridedPtr  data[N];\n    StridedPtr* ptr;\n    int         n;\n};\n\ntemplate<int N>\n__global__ void fill_strided_ptrs(Param<N> param)\n{\n    const int idx = threadIdx.x + blockIdx.x * blockDim.x;\n    if (idx < param.n) {\n        param.ptr[idx] = param.data[idx];\n    }\n}\n\n}  // namespace\n\nvoid* MakeStridedPtrs(const std::vector<std::pair<void*, int>>& ptrs, cudaStream_t stream)\n{\n    constexpr int N = 64;\n    Param<N>      param{};\n    static_assert(sizeof(param) <= 4096);  // max parameter size for cuda11\n    StridedPtr* ptr{};\n    cudaMallocAsync(&ptr, sizeof(StridedPtr) * ptrs.size(), stream);\n    param.ptr = ptr;\n    for (int i = 0; i < (int)ptrs.size(); i += N) {\n        const int n = std::min<int>(ptrs.size() - i, N);\n        for (int j = 0; j < n; ++j) {\n            auto& [p, s]  = ptrs[i + j];\n            param.data[j] = StridedPtr{p, s};\n        }\n        param.n = n;\n        fill_strided_ptrs<<<1, N, 0, 
stream>>>(param);\n        param.ptr += N;\n    }\n    return ptr;\n}\n\nnamespace {\n\ntemplate<int N>\n__global__ void fill_blocked_ptrs(Array<void*, N> src, void** dst, int n)\n{\n    const int idx = threadIdx.x + blockIdx.x * blockDim.x;\n    if (idx < n) {\n        dst[idx] = src[idx];\n    }\n}\n\n}  // namespace\n\nvoid* MakeBlockedPtrs(const std::vector<std::pair<void*, int>>& ptrs, cudaStream_t stream)\n{\n    constexpr int   N = 64;\n    Array<void*, N> src{};\n    static_assert(sizeof(src) <= 4096);  // max parameter size for cuda11\n    void** dst{};\n    cudaMallocAsync(&dst, sizeof(void*) * ptrs.size(), stream);\n    for (int i = 0; i < (int)ptrs.size(); i += N) {\n        const int n = std::min<int>(ptrs.size() - i, N);\n        for (int j = 0; j < n; ++j) {\n            auto& [p, s] = ptrs[i + j];\n            src[j]       = p;\n        }\n        fill_blocked_ptrs<<<1, N, 0, stream>>>(src, dst, n);\n        dst += n;\n    }\n    return dst - ptrs.size();\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/cp_async.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <type_traits>\n\n// Require nvcc >= 11.4 (compare the major version first so that 12.0-12.3 also qualify)\n#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))\n#define L2_CACHEHINT(size) \".L2::\" #size \"B\"\n#else\n#define L2_CACHEHINT(size)\n#endif\n\nnamespace turbomind {\n\nenum class CacheOp\n{\n    kDefault,  // use global when possible\n    kAlways,\n    kGlobal,\n};\n\ntemplate<CacheOp cache_op, int size>\nstruct GetCacheOp {\n    static constexpr auto value = cache_op;\n};\n\ntemplate<>\nstruct GetCacheOp<CacheOp::kDefault, 16> {\n    static constexpr auto value = CacheOp::kGlobal;\n};\n\ntemplate<int size>\nstruct GetCacheOp<CacheOp::kDefault, size> {\n    static constexpr auto value = CacheOp::kAlways;\n};\n\nenum class EvictPolicy\n{\n    kEvictNormal,\n    kEvictFirst,\n    kEvictLast,\n};\n\nnamespace cache_policy {\n\nstruct Default {\n    static constexpr auto kCacheOp     = CacheOp::kDefault;\n    static constexpr auto kEvictPolicy = EvictPolicy::kEvictNormal;\n};\n\nstruct Stream {\n    static constexpr auto kCacheOp     = CacheOp::kDefault;\n    static constexpr auto kEvictPolicy = EvictPolicy::kEvictFirst;\n};\n\nstruct Reuse {\n    static constexpr auto kCacheOp     = CacheOp::kAlways;\n    static constexpr auto kEvictPolicy = EvictPolicy::kEvictNormal;\n};\n\n}  // namespace cache_policy\n\ntemplate<CacheOp, int size, int prefetch_size>\nstruct CP_ASYNC {\n};\n\ntemplate<int prefetch_size>\nstruct CP_ASYNC<CacheOp::kGlobal, 16, prefetch_size> {\n    // clang-format off\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.cg.shared.global [%1], [%2], 16;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src));\n    }\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)\n    {\n        asm 
volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.cg.shared.global.L2::cache_hint [%1], [%2], 16, %3;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"l\"(cache_policy));\n    }\n    // clang-format on\n};\n\ntemplate<>\nstruct CP_ASYNC<CacheOp::kGlobal, 16, 64> {\n    // clang-format off\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.cg.shared.global\" L2_CACHEHINT(64) \" [%1], [%2], 16;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src));\n    }\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.cg.shared.global.L2::cache_hint\" L2_CACHEHINT(64) \" [%1], [%2], 16, %3;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"l\"(cache_policy));\n    }\n    // clang-format on\n};\n\ntemplate<>\nstruct CP_ASYNC<CacheOp::kGlobal, 16, 128> {\n    // clang-format off\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.cg.shared.global\" L2_CACHEHINT(128) \" [%1], [%2], 16;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src));\n    }\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.cg.shared.global.L2::cache_hint\" L2_CACHEHINT(128) \" [%1], [%2], 16, %3;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), 
\"r\"(smem_ptr), \"l\"(src), \"l\"(cache_policy));\n    }\n    // clang-format on\n};\n\ntemplate<>\nstruct CP_ASYNC<CacheOp::kGlobal, 16, 256> {\n    // clang-format off\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.cg.shared.global\" L2_CACHEHINT(256) \" [%1], [%2], 16;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src));\n    }\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.cg.shared.global.L2::cache_hint\" L2_CACHEHINT(256) \" [%1], [%2], 16, %3;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"l\"(cache_policy));\n    }\n    // clang-format on\n};\n\ntemplate<int size, int prefetch_size>\nstruct CP_ASYNC<CacheOp::kAlways, size, prefetch_size> {\n    // clang-format off\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.ca.shared.global [%1], [%2], %3;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"n\"(size));\n    }\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.ca.shared.global.L2::cache_hint [%1], [%2], %3, %4;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"n\"(size), \"l\"(cache_policy));\n    }\n    // clang-format on\n};\n\ntemplate<int size>\nstruct CP_ASYNC<CacheOp::kAlways, size, 64> {\n    // clang-format off\n    __device__ 
static void apply(int smem_ptr, const void* __restrict__ src, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.ca.shared.global\" L2_CACHEHINT(64) \" [%1], [%2], %3;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"n\"(size));\n    }\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.ca.shared.global.L2::cache_hint\" L2_CACHEHINT(64) \" [%1], [%2], %3, %4;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"n\"(size), \"l\"(cache_policy));\n    }\n    // clang-format on\n};\n\ntemplate<int size>\nstruct CP_ASYNC<CacheOp::kAlways, size, 128> {\n    // clang-format off\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.ca.shared.global\" L2_CACHEHINT(128) \" [%1], [%2], %3;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"n\"(size));\n    }\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.ca.shared.global.L2::cache_hint\" L2_CACHEHINT(128) \" [%1], [%2], %3, %4;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"n\"(size), \"l\"(cache_policy));\n    }\n    // clang-format on\n};\n\ntemplate<int size>\nstruct CP_ASYNC<CacheOp::kAlways, size, 256> {\n    // clang-format off\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, 
%0, 0;\\n\"\n                     \"  @p cp.async.ca.shared.global\" L2_CACHEHINT(256) \" [%1], [%2], %3;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"n\"(size));\n    }\n    __device__ static void apply(int smem_ptr, const void* __restrict__ src, uint64_t cache_policy, bool mask)\n    {\n        asm volatile(\"{\\n  .reg .pred p;\\n  setp.ne.b32 p, %0, 0;\\n\"\n                     \"  @p cp.async.ca.shared.global.L2::cache_hint\" L2_CACHEHINT(256) \" [%1], [%2], %3, %4;\\n\"\n                     \"}\\n\" ::\"r\"((int)mask), \"r\"(smem_ptr), \"l\"(src), \"n\"(size), \"l\"(cache_policy));\n    }\n    // clang-format on\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/cta_map.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nTM_HOST_DEVICE constexpr int get_log_tile(int size, int tile_size)\n{\n    if (tile_size >= 32 && size >= 24)\n        return 5;\n    if (tile_size >= 16 && size >= 12)\n        return 4;\n    if (tile_size >= 8 && size >= 6)\n        return 3;\n    if (tile_size >= 4 && size >= 3)\n        return 2;\n    if (tile_size >= 2 && size >= 2)\n        return 1;\n    return 0;\n}\n\nTM_HOST_DEVICE constexpr int2 get_tiled_shape(int m, int n, int cta_m, int cta_n)\n{\n    return {ceil_div(m, cta_m), ceil_div(n, cta_n)};\n}\n\nstruct CtaMap_ {\n\n    TM_HOST_DEVICE static int3 get_tiled_shape(int m, int n, int k, int cta_m, int cta_n, int split_cnt)\n    {\n        return {(m + cta_m - 1) / cta_m, (n + cta_n - 1) / cta_n, split_cnt};\n    }\n\n    TM_HOST_DEVICE static int get_log_tile(int2 tiled_mn, int N)\n    {\n        return gemm::get_log_tile(tiled_mn.y, N);\n    }\n\n    TM_HOST_DEVICE static dim3 get_grid_shape(int3 tiled_shape, int log_tile)\n    {\n        int tile = 1 << log_tile;\n        return {static_cast<unsigned>(tiled_shape.x * tile),\n                static_cast<unsigned>((tiled_shape.y + tile - 1) / tile),\n                static_cast<unsigned>(tiled_shape.z)};\n    }\n\n    TM_DEVICE static int3 get_tile_offset(int log_tile)\n    {\n        int block_idx_x = blockIdx.x;\n        int block_idx_y = blockIdx.y;\n        int block_idx_z = blockIdx.z;\n        return {(block_idx_x >> log_tile),  //\n                (block_idx_y << log_tile) + (block_idx_x & ((1 << log_tile) - 1)),\n                block_idx_z};\n    }\n};\n\ntemplate<Order order_>\nclass GemmScheduler {\n\n    static constexpr auto order = order_;\n\n    int4 gemm_shape_;\n    int4 tiled_shape_;\n    int  log_tile_;\n\n    int 
chunk_offset_;\n    int chunks_per_split_;\n    int iter_k_per_chunk_;\n\n    int4 tile_offset_;\n    int2 iter_k_range_;\n\npublic:\n    TM_HOST_DEVICE\n    GemmScheduler(int4 gemm_shape, int2 tiled_mn, int splits, int log_tile, int cta_k, int chunk_size):\n        gemm_shape_{gemm_shape}, tiled_shape_{tiled_mn.x, tiled_mn.y, splits}, log_tile_{log_tile}\n    {\n        const int chunk_cnt = cdiv(gemm_shape_.z, chunk_size);\n\n        iter_k_per_chunk_ = chunk_size / cta_k;\n        chunks_per_split_ = chunk_cnt / splits;\n        chunk_offset_     = splits - chunk_cnt % splits;\n    }\n\n    TM_HOST_DEVICE static int get_log_tile(int2 tiled_mn, int tile_size)\n    {\n        return gemm::get_log_tile(order == kColMajor ? tiled_mn.y : tiled_mn.x, tile_size);\n    }\n\n    TM_HOST_DEVICE static dim3 get_grid_shape(int4 tiled_shape, int log_tile)\n    {\n        const int tile = 1 << log_tile;\n        if constexpr (order == kColMajor) {\n            return {(unsigned)(tiled_shape.x * tile), (unsigned)(cdiv(tiled_shape.y, tile)), (unsigned)(tiled_shape.z)};\n        }\n        else {\n            return {(unsigned)(tiled_shape.y * tile), (unsigned)(cdiv(tiled_shape.x, tile)), (unsigned)(tiled_shape.z)};\n        }\n    }\n\n    TM_HOST_DEVICE dim3 get_grid_shape() const\n    {\n        return get_grid_shape(tiled_shape_, log_tile_);\n    }\n\n    TM_HOST_DEVICE std::true_type init(int block_idx_x, int block_idx_y, int block_idx_z)\n    {\n        if constexpr (order == kColMajor) {\n            tile_offset_ = {(block_idx_x >> log_tile_),\n                            (block_idx_y << log_tile_) + (block_idx_x & ((1 << log_tile_) - 1)),\n                            (block_idx_z)};\n        }\n        else {\n            tile_offset_ = {(block_idx_y << log_tile_) + (block_idx_x & ((1 << log_tile_) - 1)),\n                            (block_idx_x >> log_tile_),\n                            (block_idx_z)};\n        }\n        tile_offset_.w       = 0;\n        const int 
chunk_id   = tile_offset_.z * chunks_per_split_ + max(tile_offset_.z - chunk_offset_, 0);\n        const int iter_k_beg = chunk_id * iter_k_per_chunk_;\n        const int iter_k_cnt = (chunks_per_split_ + int(tile_offset_.z >= chunk_offset_)) * iter_k_per_chunk_;\n        iter_k_range_        = {iter_k_beg, iter_k_beg + iter_k_cnt};\n\n        return {};\n    }\n\n    TM_DEVICE std::true_type init()\n    {\n        return init(blockIdx.x, blockIdx.y, blockIdx.z);\n    }\n\n    TM_DEVICE int4 gemm_shape() const\n    {\n        return gemm_shape_;\n    }\n\n    TM_DEVICE int4 tiled_shape() const\n    {\n        return tiled_shape_;\n    }\n\n    TM_DEVICE int4 tile_offset() const\n    {\n        return tile_offset_;\n    }\n\n    TM_DEVICE int2 iter_k_range() const\n    {\n        return iter_k_range_;\n    }\n\n    TM_DEVICE int tile_id() const\n    {\n        return tile_offset_.x * tiled_shape_.y + tile_offset_.y;\n    }\n};\n\ntemplate<Order order_>\nclass DynamicScheduler {\n\n    static constexpr auto order = order_;\n\n    int ctas_;\n\n    const int4* __restrict__ gemm_shapes_;    // [group_num]\n    const int4* __restrict__ tiled_shapes_;   // [group_num]\n    const int2* __restrict__ offsets_mn_;     // [group_num]\n    const int4* __restrict__ tile_offsets_;   // [ctas]\n    const int2* __restrict__ iter_k_ranges_;  // [ctas]\n    const int* __restrict__ tile_ids_;        // [ctas]\n\n    int4 gemm_shape_;\n    int4 tiled_shape_;\n    int4 tile_offset_;\n    int2 iter_k_range_;\n    int2 base_mn_;\n\npublic:\n    DynamicScheduler(const Tape& tape):\n        ctas_{tape.ctas},\n        gemm_shapes_{tape.gemm_shapes},\n        tiled_shapes_{tape.tiled_shapes},\n        tile_offsets_{tape.tile_offsets},\n        iter_k_ranges_{tape.iter_k_ranges},\n        tile_ids_{tape.tile_ids}\n    {\n    }\n\n    TM_HOST_DEVICE static int get_log_tile(int2 tiled_mn, int tile_size)\n    {\n        return gemm::get_log_tile(order == kColMajor ? 
tiled_mn.y : tiled_mn.x, tile_size);\n    }\n\n    TM_HOST_DEVICE dim3 get_grid_shape()\n    {\n        return {(unsigned)ctas_, 1, 1};\n    }\n\n    TM_DEVICE bool init()\n    {\n        const int block_idx = blockIdx.x;\n\n        const auto [cta_m_id, cta_n_id, cta_k_id, group_id] = __ldg(tile_offsets_ + block_idx);\n\n        if (group_id < 0) {\n            return false;\n        }\n\n        gemm_shape_  = __ldg(gemm_shapes_ + group_id);\n        tiled_shape_ = __ldg(tiled_shapes_ + group_id);\n        base_mn_     = __ldg(offsets_mn_ + group_id);\n\n        tile_offset_ = {cta_m_id, cta_n_id, cta_k_id, group_id};\n\n        iter_k_range_ = __ldg(iter_k_ranges_ + block_idx);\n\n        return true;\n    }\n\n    TM_DEVICE int4 gemm_shape() const\n    {\n        return gemm_shape_;\n    }\n\n    TM_DEVICE int4 tiled_shape() const\n    {\n        return tiled_shape_;\n    }\n\n    TM_DEVICE int4 tile_offset() const\n    {\n        return tile_offset_;\n    }\n\n    TM_DEVICE int2 iter_k_range() const\n    {\n        return iter_k_range_;\n    }\n\n    TM_DEVICE int tile_id() const\n    {\n        return tile_ids_[blockIdx.x];\n    }\n};\n\ntemplate<class S>\nstruct is_dynamic_scheduler: std::false_type {\n};\n\ntemplate<Order order>\nstruct is_dynamic_scheduler<DynamicScheduler<order>>: std::true_type {\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/cublas.cu",
    "content": "#include <cublas_v2.h>\n\n#include \"src/turbomind/core/cuda_data_type.h\"\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nclass CublasKernel: public Kernel {\npublic:\n    explicit CublasKernel(): cublas_{}\n    {\n        cublasCreate(&cublas_);\n        if (0) {\n            cublasSetMathMode(cublas_, CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);\n        }\n\n        desc_.backend    = 1;\n        desc_.group_axis = -1;\n\n        info_.chunk_size_k      = 1;\n        info_.dynamic_smem_size = 0;\n\n        info_.name = GetName();\n    }\n\n    ~CublasKernel() override\n    {\n        cublasDestroy(cublas_);\n        cublas_ = {};\n    }\n\n    int Launch(const Operation&    operation,\n               float               alpha,\n               const void*         A,\n               const MatrixLayout& Adesc,\n               const void*         U,\n               const MatrixLayout& Udesc,\n               const void*         B,\n               const MatrixLayout& Bdesc,\n               const void*         V,\n               const MatrixLayout& Vdesc,\n               float               beta,\n               const void*         C,\n               const MatrixLayout& Cdesc,\n               void*               D,\n               const MatrixLayout& Ddesc,\n               int                 swizzle,\n               int                 splits,\n               Workspace&          workspace,\n               cudaStream_t        stream) override\n    {\n        cublasOperation_t transa = Adesc.order == kColMajor ? CUBLAS_OP_N : CUBLAS_OP_T;\n        cublasOperation_t transb = Bdesc.order == kColMajor ? 
CUBLAS_OP_N : CUBLAS_OP_T;\n\n        const int m = Adesc.rows;\n        const int n = Bdesc.cols;\n        const int k = Adesc.cols;\n\n        TM_CHECK_EQ(Bdesc.rows, k);\n        TM_CHECK_EQ(Ddesc.rows, m);\n        TM_CHECK_EQ(Ddesc.cols, n);\n\n        TM_CHECK(C == nullptr || C == D);\n\n        if (stream_ != stream) {\n            cublasSetStream(cublas_, stream);\n            stream_ = stream;\n        }\n\n        if (workspace_ != workspace.partials || workspace_size_ != workspace.partials_size) {\n            cublasSetWorkspace(cublas_, workspace.partials, workspace.partials_size);\n            workspace_      = workspace.partials;\n            workspace_size_ = workspace.partials_size;\n        }\n\n        auto ec = cublasGemmEx(cublas_,\n                               transa,\n                               transb,\n                               m,\n                               n,\n                               k,\n                               &alpha,\n                               A,\n                               to_cuda_dtype(Adesc.type),\n                               Adesc.ld,\n                               B,\n                               to_cuda_dtype(Bdesc.type),\n                               Bdesc.ld,\n                               &beta,\n                               D,\n                               to_cuda_dtype(Ddesc.type),\n                               Ddesc.ld,\n                               CUDA_R_32F,\n                               CUBLAS_GEMM_DEFAULT_TENSOR_OP);\n\n        return ec == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n    }\n\n    bool is_feasible(const GemmDesc& desc) const noexcept override\n    {\n        constexpr std::tuple flat3{Striding::kFlat, Striding::kFlat, Striding::kFlat};\n\n        if (std::tie(desc.striding_a, desc.striding_b, desc.striding_c) != flat3) {\n            return false;\n        }\n        if (std::tie(desc.pack_a, desc.pack_b, desc.pack_u, desc.pack_v) != std::tuple{0, 0, 0, 0}) {\n            return false;\n        }\n        if (desc.epilogue != Epilogue::kNone) {\n            return false;\n        }\n        if (desc.num > 1) {\n            return false;\n        }\n        if (desc.quant_a || desc.quant_b) {\n            return false;\n        }\n        if (desc.group_axis >= 0) {\n            return false;\n        }\n        if (desc.order_c != kColMajor) {\n            return false;\n        }\n        if (desc.type_a != kHalf && desc.type_a != kBfloat16 && desc.type_a != kFloat) {\n            return false;\n        }\n        if (desc.type_b != desc.type_a) {\n            return false;\n        }\n        if (desc.type_c != desc.type_a && desc.type_c != kFloat) {\n            return false;\n        }\n        return true;\n    }\n\n    int GetMaxSwizzle(const int4&) const override\n    {\n        return 0;\n    }\n\n    int GetMaxSplits(const int4&, int, size_t, size_t) const override\n    {\n        return 1;\n    }\n\nprivate:\n    cublasHandle_t cublas_{};\n    cudaStream_t   stream_{};\n    void*          workspace_{};\n    size_t         workspace_size_{};\n};\n\nvoid Registry::cublas_float()\n{\n    Add(std::make_unique<CublasKernel>());\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/desc.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <array>\n#include <tuple>\n#include <type_traits>\n\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\n// aggregate that uniquely identifies a GEMM problem\nstruct GemmDesc {\n    int       arch;\n    DataType  type_a;\n    DataType  type_b;\n    DataType  type_c;\n    Order     order_a;\n    Order     order_b;\n    Order     order_c;\n    Striding  striding_a;\n    Striding  striding_b;\n    Striding  striding_c;\n    Pack      pack_a;\n    Pack      pack_b;\n    Pack      pack_u;\n    Pack      pack_v;\n    QuantDesc quant_a;\n    QuantDesc quant_b;\n    Epilogue  epilogue;\n    int       batch_dim;\n    int       group_axis;\n    int       m;\n    int       n;\n    int       k;\n    int       num;\n};\n\nstatic_assert(std::is_trivially_copyable_v<GemmDesc>);\n\ninline GemmDesc transpose(GemmDesc d)\n{\n    std::swap(d.type_a, d.type_b);\n    std::swap(d.order_a, d.order_b);\n    d.order_a = ~d.order_a;\n    d.order_b = ~d.order_b;\n    d.order_c = ~d.order_c;\n    std::swap(d.striding_a, d.striding_b);\n    std::swap(d.pack_a, d.pack_b);\n    std::swap(d.pack_u, d.pack_v);\n    std::swap(d.quant_a, d.quant_b);\n    std::swap(d.m, d.n);\n    d.batch_dim = 1 - d.batch_dim;\n    if (d.group_axis >= 0) {\n        d.group_axis = 1 - d.group_axis;\n    }\n    return d;\n}\n\ninline std::string to_string(const GemmDesc& d)\n{\n    std::stringstream ss;\n    ss << \"sm\" << d.arch / 10;\n    ss << \"_\" << to_string(d.type_a);  //\n    if (d.quant_a) {\n        ss << to_string(d.quant_a);\n    }\n    ss << \"_\" << to_string(d.type_b);  //\n    if (d.quant_b) {\n        ss << to_string(d.quant_b);\n    }\n    ss << \"_\" << to_string(d.type_c);\n    ss << \"_\"                                    //\n       << (d.order_a == kColMajor ? 'n' : 't')   //\n       << (d.order_b == kColMajor ? 
'n' : 't')   //\n       << (d.order_c == kColMajor ? 'n' : 't');  //\n    ss << \"_\"                                    //\n       << to_string(d.striding_a)                //\n       << to_string(d.striding_b)                //\n       << to_string(d.striding_c);\n    ss << \"_\" << d.m << \"x\" << d.n << \"x\" << d.k;\n    ss << \"_\" << d.num;\n    return ss.str();\n}\n\nenum class OpClass\n{\n    kSIMT,\n    kMMA_s884,\n    kMMA_s16816,\n    kGMMA_s64n16\n};\n\ninline const char* to_string(OpClass op)\n{\n    switch (op) {\n        case OpClass::kSIMT:\n            return \"simt\";\n        case OpClass::kMMA_s884:\n            return \"s884\";\n        case OpClass::kMMA_s16816:\n            return \"s16816\";\n        default:\n            return \"unknown_op_cls\";\n    }\n}\n\n// aggregate that uniquely identifies a kernel\nstruct KernelDesc {\n    int       arch;\n    OpClass   op_class;\n    DataType  type_a;\n    DataType  type_b;\n    DataType  type_c;\n    Order     order_a;\n    Order     order_b;\n    Order     order_c;\n    Striding  striding_a;\n    Striding  striding_b;\n    Striding  striding_c;\n    Pack      pack_a;\n    Pack      pack_b;\n    Pack      pack_u;\n    Pack      pack_v;\n    QuantDesc quant_a;\n    QuantDesc quant_b;\n    int       policy_a;\n    int       policy_b;\n    int3      cta_tile;\n    int3      mma_tile;\n    int2      cluster_shape;\n    int3      align;\n    int2      c_tile;\n    int       stages;\n    bool      split_k;\n    int       group_axis;\n    int       backend;\n    bool      transpose;\n};\n\nstatic_assert(std::is_trivially_copyable_v<KernelDesc>);\n\nstruct KernelInfo {\n    int dynamic_smem_size;\n    int max_active_ctas;\n    int chunk_size_k;\n\n    std::string name;\n\n    cudaFuncAttributes attr;\n};\n\ninline KernelDesc transpose(const KernelDesc& d)\n{\n    KernelDesc k{d};\n\n    k.arch     = d.arch;\n    k.op_class = d.op_class;\n\n    k.order_a = ~d.order_b;\n    k.order_b = ~d.order_a;\n    
k.order_c = ~d.order_c;\n\n    k.type_a = d.type_b;\n    k.type_b = d.type_a;\n\n    k.striding_a = d.striding_b;\n    k.striding_b = d.striding_a;\n\n    k.pack_a = d.pack_b;\n    k.pack_b = d.pack_a;\n    k.pack_u = d.pack_v;\n    k.pack_v = d.pack_u;\n\n    k.quant_a = d.quant_b;\n    k.quant_b = d.quant_a;\n\n    k.policy_a = d.policy_b;\n    k.policy_b = d.policy_a;\n\n    auto swap = [](auto& v) { std::swap(v.x, v.y); };\n\n    swap(k.cta_tile);\n    swap(k.mma_tile);\n    swap(k.cluster_shape);\n    swap(k.align);\n    swap(k.c_tile);\n\n    return k;\n}\n\nclass Kernel;\nstruct LaunchSpec {\n    Kernel* kernel;\n    int     swizzle;\n    int     splits;\n    float   measured;\n\n    std::array<int64_t, 2> estimated;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/dispatch_cache.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/dispatch_cache.h\"\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include <algorithm>\n#include <iostream>\n#include <map>\n#include <memory>\n#include <ostream>\n#include <sstream>\n#include <vector>\n\nstatic inline bool operator==(const int3& a, const int3& b)\n{\n    return a.x == b.x && a.y == b.y && a.z == b.z;\n}\n\nstatic inline bool operator==(const int2& a, const int2& b)\n{\n    return a.x == b.x && a.y == b.y;\n}\n\nnamespace turbomind::gemm {\n\nstatic inline decltype(auto) as_tuple(const KernelDesc& d)\n{\n    return std::tie(d.arch,\n                    d.op_class,\n                    d.type_a,\n                    d.type_b,\n                    d.type_c,\n                    d.order_a,\n                    d.order_b,\n                    d.order_c,\n                    d.striding_a,\n                    d.striding_b,\n                    d.striding_c,\n                    d.pack_a,\n                    d.pack_b,\n                    d.pack_u,\n                    d.pack_v,\n                    d.quant_a,\n                    d.quant_b,\n                    d.policy_a,\n                    d.policy_b,\n                    d.cta_tile,\n                    d.mma_tile,\n                    d.cluster_shape,\n                    d.align,\n                    d.c_tile,\n                    d.stages,\n                    d.split_k,\n                    d.backend,\n                    d.transpose,\n                    d.group_axis);\n}\n\nstatic inline bool operator==(const QuantDesc& a, const QuantDesc& b)\n{\n    return a.type == b.type && a.group_size == b.group_size;\n}\n\nstatic inline bool operator==(const KernelDesc& a, const KernelDesc& b)\n{\n    return as_tuple(a) == as_tuple(b);\n}\n\nnamespace {\n\nstruct Record {\n    GemmDesc   gemm;\n    
KernelDesc kernel;\n\n    int swizzle;\n    int splits;\n};\n\n}  // namespace\n\nvoid ExportDispatchCache(std::ostream& os, const std::vector<std::pair<GemmDesc, LaunchSpec>>& entries)\n{\n\n    for (const auto& [g, spec] : entries) {\n        Record record{};\n        record.gemm    = g;\n        record.kernel  = spec.kernel->desc();\n        record.splits  = spec.splits;\n        record.swizzle = spec.swizzle;\n\n        os.write((const char*)&record, sizeof(record));\n    }\n}\n\nvoid ImportDispatchCache(std::istream&                                 is,\n                         std::vector<std::pair<GemmDesc, LaunchSpec>>& entries,\n                         const std::vector<Kernel*>&                   kernels)\n{\n    is.seekg(0, is.end);\n    const auto size_in_bytes = is.tellg();\n    is.seekg(0, is.beg);\n\n    if (size_in_bytes % sizeof(Record)) {\n        std::cerr << \"File size is not a multiple of record size, failed to import records.\\n\";\n    }\n\n    const int n = size_in_bytes / sizeof(Record);\n\n    for (int i = 0; i < n; ++i) {\n        Record record;\n        is.read((char*)&record, sizeof(Record));\n\n        LaunchSpec spec{};\n        spec.splits  = record.splits;\n        spec.swizzle = record.swizzle;\n\n        for (const auto& p : kernels) {\n            if (p->desc() == record.kernel) {\n                spec.kernel = p;\n                break;\n            }\n        }\n        if (spec.kernel) {\n            entries.emplace_back(record.gemm, spec);\n        }\n        else {\n            std::cerr << \"No kernel found for entry \" << i << \"\\n\";\n        }\n    }\n}\n\nnamespace {\n\ninline decltype(auto) as_tuple(const GemmDesc& d)\n{\n    return std::tie(d.arch,\n                    d.type_a,\n                    d.type_b,\n                    d.type_c,\n                    d.order_a,\n                    d.order_b,\n                    d.order_c,\n                    d.striding_a,\n                    d.striding_b,\n             
       d.striding_c,\n                    d.pack_a,\n                    d.pack_b,\n                    d.pack_u,\n                    d.pack_v,\n                    d.quant_a.type,\n                    d.quant_a.group_size,\n                    d.quant_b.type,\n                    d.quant_b.group_size,\n                    d.batch_dim,\n                    d.group_axis,\n                    d.m,\n                    d.n,\n                    d.k,\n                    d.num);\n    // Note: `d.epilogue` is not used yet\n}\n\n}  // namespace\n\ninline bool operator<(const GemmDesc& a, const GemmDesc& b)\n{\n    return as_tuple(a) < as_tuple(b);\n}\n\nint extract_batch_size(GemmDesc& desc)\n{\n    return std::exchange(desc.batch_dim == 0 ? desc.m : desc.n, 0);\n}\n\nvoid set_batch_size(GemmDesc& desc, int batch_size)\n{\n    (desc.batch_dim == 0 ? desc.m : desc.n) = batch_size;\n}\n\nstruct DispatchCache::Impl {\n\n    struct Flat {\n        std::vector<std::pair<int, int>> idxs;\n        std::vector<LaunchSpec>          specs;\n    };\n\n    const std::vector<Kernel*> kernels_;\n    std::map<GemmDesc, Flat>   cache_;\n\n    Impl(std::vector<Kernel*> kernels): kernels_(std::move(kernels)) {}\n\n    std::optional<LaunchSpec> Find(GemmDesc desc, bool exact) const\n    {\n        const int batch_size = extract_batch_size(desc);\n        // std::cerr << batch_size << \" \" << desc.m << \" \" << desc.n << \" \" << desc.k << \" \" << std::boolalpha << exact\n        //           << \"\\n\";\n        const auto it = cache_.find(desc);\n        if (it != cache_.end()) {\n            const auto& [idxs, specs] = it->second;\n            // Find index via key\n            const auto p =\n                std::lower_bound(idxs.begin(), idxs.end(), std::make_pair(batch_size, 0), [](auto& a, auto& b) {  //\n                    return a.first < b.first;\n                });\n            // std::cout << it->second.specs.size() << std::endl;\n            if (p != idxs.end() && (!exact 
|| p->first == batch_size)) {\n                // std::cerr << p->first << \" \" << p->second << \"\\n\";\n                return specs[p->second];\n            }\n        }\n        return {};\n    }\n\n    bool Insert(GemmDesc desc, const LaunchSpec& spec)\n    {\n        const int batch_size = extract_batch_size(desc);\n\n        auto it = cache_.find(desc);\n        if (it == cache_.end()) {\n            it = cache_.emplace_hint(it, desc, Flat{});\n        }\n        auto& [idxs, specs] = it->second;\n        // Find index via key\n        const auto p =\n            std::lower_bound(idxs.begin(), idxs.end(), std::make_pair(batch_size, 0), [](auto& a, auto& b) {  //\n                return a.first < b.first;\n            });\n        // Exact match, skip\n        if (p != idxs.end() && p->first == batch_size) {\n            return false;\n        }\n        // Insert\n        idxs.insert(p, {batch_size, (int)specs.size()});\n        specs.push_back(spec);\n        return true;\n    }\n\n    int Export(std::ostream& os) const\n    {\n        std::vector<std::pair<GemmDesc, LaunchSpec>> entries;\n        for (const auto& [desc, flat] : cache_) {\n            auto tmp = desc;\n            for (const auto& [batch_size, index] : flat.idxs) {\n                set_batch_size(tmp, batch_size);\n                entries.emplace_back(tmp, flat.specs[index]);\n            }\n        }\n        Summary(entries);\n        ExportDispatchCache(os, entries);\n        return entries.size();\n    }\n\n    int Import(std::istream& is)\n    {\n        std::vector<std::pair<GemmDesc, LaunchSpec>> entries;\n        ImportDispatchCache(is, entries, kernels_);\n        Summary(entries);\n        for (auto [desc, spec] : entries) {\n            const int batch_size = extract_batch_size(desc);\n            auto      it         = cache_.find(desc);\n            if (it == cache_.end()) {\n                it = cache_.emplace_hint(it, desc, Flat{});\n            }\n            auto& [idxs, 
specs] = it->second;\n            // Order is not maintained at this point\n            idxs.emplace_back(batch_size, (int)specs.size());\n            specs.push_back(spec);\n        }\n        // Sort indices and deduplicate\n        for (auto& [desc, flat] : cache_) {\n            auto& [idxs, specs] = flat;\n            std::stable_sort(idxs.begin(), idxs.end(), [](auto a, auto b) { return a.first < b.first; });\n            idxs.erase(std::unique(idxs.begin(), idxs.end(), [](auto a, auto b) { return a.first == b.first; }),\n                       idxs.end());\n            // Remove unreferenced specs and update spec indices\n            std::vector<LaunchSpec> tmp;\n            for (auto& [key, val] : idxs) {\n                int old = std::exchange(val, tmp.size());\n                tmp.push_back(specs[old]);\n            }\n            specs = std::move(tmp);\n        }\n        return entries.size();\n    }\n\n    // Print a summary of how many cases a kernel is used\n    void Summary(const std::vector<std::pair<GemmDesc, LaunchSpec>>& entries) const\n    {\n        std::vector<Kernel*> uses{nullptr};\n        std::copy(kernels_.begin(), kernels_.end(), std::back_inserter(uses));\n\n        for (const auto& [_, s] : entries) {\n            uses.push_back(s.kernel);\n        }\n        std::sort(uses.begin(), uses.end());\n        std::vector<std::pair<int, Kernel*>> count;\n        for (size_t i = 1; i < uses.size(); ++i) {\n            if (uses[i] != uses[i - 1]) {\n                count.emplace_back(-1, uses[i]);\n            }\n            ++count.back().first;\n        }\n        std::sort(count.begin(), count.end(), std::greater<>{});\n        for (const auto& [n, k] : count) {\n            std::cout << k->name() << \": \" << n << \"\\n\";\n        }\n    }\n};\n\nDispatchCache::DispatchCache(std::vector<Kernel*> kernels): impl_(std::make_unique<Impl>(std::move(kernels))) {}\n\nDispatchCache::~DispatchCache() = default;\n\nstd::optional<LaunchSpec> 
DispatchCache::Find(const GemmDesc& desc) const\n{\n    return impl_->Find(desc, true);\n}\n\nstd::optional<LaunchSpec> DispatchCache::LowerBound(const GemmDesc& desc) const\n{\n    return impl_->Find(desc, false);\n}\n\nbool DispatchCache::Insert(const GemmDesc& desc, const LaunchSpec& spec)\n{\n    return impl_->Insert(desc, spec);\n}\n\nint DispatchCache::Export(std::ostream& os) const\n{\n    return impl_->Export(os);\n}\n\nint DispatchCache::Import(std::istream& is)\n{\n    return impl_->Import(is);\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/dispatch_cache.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/desc.h\"\n\n#include <iosfwd>\n#include <memory>\n#include <optional>\n#include <vector>\n\nnamespace turbomind::gemm {\n\nclass DispatchCache {\npublic:\n    DispatchCache(std::vector<Kernel*> kernels);\n\n    ~DispatchCache();\n\n    std::optional<LaunchSpec> LowerBound(const GemmDesc& desc) const;\n\n    std::optional<LaunchSpec> Find(const GemmDesc& desc) const;\n\n    bool Insert(const GemmDesc& desc, const LaunchSpec& spec);\n\n    int Export(std::ostream& os) const;\n\n    int Import(std::istream& is);\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/epilogue.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/core/sync.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/predicate.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class Tc>\nstruct ChannelCombination_v3 {\n    const Tc* __restrict__ scale_bias_ptr;\n\n    template<class T, int V, int S, int C, int delta_c, int delta_s, class Pred>\n    __device__ void operator()(Array<T, V> (&x)[S][C], int2 cs0, pair<delta_c, delta_s>, Pred& pred) const\n    {\n        __align__(16) Array<Tc, 2> scale_bias[S];\n\n        if (scale_bias_ptr) {\n            constexpr int ds  = sizeof(Tc) * delta_s;\n            auto          ptr = reinterpret_cast<const char*>(scale_bias_ptr + cs0.y);\n            PRAGMA_UNROLL\n            for (int s = 0; s < S; ++s) {\n                if (pred(s, 0)) {\n                    Ldg(scale_bias[s], reinterpret_cast<const Tc*>(ptr));\n                }\n                ptr += ds;\n            }\n            PRAGMA_UNROLL\n            for (int s = 0; s < S; ++s) {\n                auto tmp = cast<T>(scale_bias[s]);\n                PRAGMA_UNROLL\n                for (int c = 0; c < C; ++c) {\n                    using namespace ops;\n                    x[s][c] = x[s][c] * tmp[0] + tmp[1];\n                }\n            }\n        }\n    }\n};\n\ntemplate<bool     scale_S,\n         bool     scale_C,\n         Striding mode_S,\n         Striding mode_C,\n         class T,\n         int N,\n         int S,\n         int C,\n         int delta_C,\n         int delta_S,\n         class 
Pred>\n__device__ void Scale(pair<scale_S, scale_C>,\n                      pair<mode_S, mode_C>,\n                      pair<delta_C, delta_S>,\n                      Array<T, N> (&x)[S][C],\n                      const MatrixParam& param_S,\n                      const MatrixParam& param_C,\n                      int                gemm_id,\n                      int2               cs0,\n                      Pred&              pred)\n{\n    if (scale_S && param_S.ptr) {\n        const auto mat = resolve<T, mode_S>(param_S, gemm_id);\n        const T*   ptr = (const T*)mat.ptr.ptr;\n        T          param[S];\n        PRAGMA_UNROLL\n        for (int s = 0; s < S; ++s) {\n            const int ss  = cs0.y + s * delta_S;\n            const int idx = mat.idxs ? __ldg(mat.idxs + ss) : ss;\n            if (pred(s, 0)) {\n                param[s] = __ldg((const T*)(ptr + idx));\n            }\n            PRAGMA_UNROLL\n            for (int c = 0; c < C; ++c) {\n                using namespace ops;\n                x[s][c] = x[s][c] * param[s];\n            }\n        }\n    }\n\n    if (scale_C && param_C.ptr) {\n        const T*      ptr = (const T*)resolve<T, mode_C>(param_C, gemm_id).ptr.ptr + cs0.x;\n        constexpr int dc  = sizeof(Array<T, N>) * delta_C;\n        Array<T, N>   param[C];\n        PRAGMA_UNROLL\n        for (int c = 0; c < C; ++c) {\n            if (pred(0, c)) {\n                Ldg(param[c], (const T*)(ptr + dc * c));\n            }\n            PRAGMA_UNROLL\n            for (int s = 0; s < S; ++s) {\n                using namespace ops;\n                x[s][c] = x[s][c] * param[c];\n            }\n        }\n    }\n}\n\nstruct MatrixCombination_v3 {\n\n    MatrixParam param_c;\n    float       alpha;\n    float       beta;\n\n    template<class Tc, Striding mode, class T, int N, int S, int C, int delta_c, int delta_s, class Pred>\n    __device__ void operator()(Tc*,  //\n                               constant<mode>,\n                     
          Array<T, N> (&x)[S][C],\n                               int2 cs0,\n                               int  gemm_id,\n                               pair<delta_c, delta_s>,\n                               Pred& pred) const\n    {\n        if (beta) {\n            const auto c = resolve<Tc, mode>(param_c, gemm_id);\n\n            Array<Tc, N>  frag[S][C];\n            constexpr int dc  = sizeof(Tc) * delta_c;\n            const int     ds  = sizeof(Tc) * delta_s * c.ptr.stride;\n            const char*   ptr = (const char*)c.ptr.ptr + sizeof(Tc) * dot(cs0, long2{1, c.ptr.stride});\n            PRAGMA_UNROLL\n            for (int s = 0; s < S; ++s) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < C; ++c) {\n                    if (pred(s, c)) {\n                        Load(frag[s][c], reinterpret_cast<const Tc*>(ptr));\n                        using namespace ops;\n                        x[s][c] = x[s][c] * alpha + cast<T>(frag[s][c]) * beta;\n                    }\n                    ptr += dc;\n                }\n                ptr -= dc * C;\n                ptr += ds;\n            }\n        }\n        else if (alpha != 1.f) {\n            PRAGMA_UNROLL\n            for (int s = 0; s < S; ++s) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < C; ++c) {\n                    using namespace ops;\n                    x[s][c] = x[s][c] * alpha;\n                }\n            }\n        }\n    }\n};\n\ntemplate<class Act>\nstruct GatedActivation {\n    template<class T, int N>\n    __device__ static void apply(Array<T, N>& x)\n    {\n        static_assert(N % 2 == 0);\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 2) {\n            x[i / 2] = static_cast<T>(Act::apply(x[i]) * x[i + 1]);\n        }\n    }\n};\n\nstruct Silu {\n    __device__ static float apply(float x)\n    {\n        return fdividef(x, 1.f + expf(-x));\n    }\n};\n\nstruct EpilogueParam {\n    MatrixParam c;\n    MatrixParam partials;\n   
 int*        locks;\n\n    // MatrixParam scale_S;\n    // MatrixParam scale_C;\n\n    MatrixCombination_v3 combine_mat;\n\n    bool silu_act;\n};\n\ntemplate<class Tc_,\n         int M,\n         int N,\n         int TM_,\n         int TN_,\n         int THREADS,\n         class RearrangeC,\n         class OperandC,\n         Striding mode_C,\n         bool     SplitK_>\nstruct Epilogue_ {\n\n    using Dtype = typename OperandC::Dtype;\n\n    static constexpr auto kOrder = OperandC::kOrder;\n    static constexpr auto kMode  = mode_C;\n    static constexpr bool SplitK = SplitK_;\n\n    using Tc = Tc_;\n\n    static constexpr int TM = TM_;\n    static constexpr int TN = TN_;\n\n    using SmemLayout = decltype(OperandC::GetSmemLayout::apply(pair<TM, TN>{}));\n\n    using SmemAccessorV2 = SmemAccessorV2<Dtype, SmemLayout, kOrder>;\n\n    using SharedStorage = Array<Dtype, SmemLayout::kSize>;\n\n    using Map = decltype(OperandC::GetThreadMap::apply(pair<M, N>{}, constant<THREADS>{}));\n\n    static constexpr int S       = Map::kIterS;\n    static constexpr int C       = Map::kIterC;\n    static constexpr int kAccess = Map::kAccessC;\n\n    template<class T>\n    using OutputC = Array<T, kAccess>;\n\n    template<class FragC>\n    __device__ void Rearrange(FragC& frag_C, SharedStorage& storage, OutputC<Dtype> (&out)[S][C])\n    {\n        SmemAccessorV2 smem_C{storage.data()};\n\n        const int2 thr_cs = Map::get_offset(threadIdx.x / WARP_SIZE, threadIdx.x % WARP_SIZE);\n\n        constexpr int kPeriodC = ceil_div(SmemLayout::C0, Map::kDeltaC);\n        constexpr int kPeriodS = ceil_div(SmemLayout::S0, Map::kDeltaS);\n\n        int phases[kPeriodS][kPeriodC];\n        PRAGMA_UNROLL\n        for (int s = 0; s < kPeriodS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < kPeriodC; ++c) {\n                phases[s][c] = SmemLayout::apply(s * Map::kDeltaS + thr_cs.y, c * Map::kDeltaC + thr_cs.x);\n            }\n        }\n\n        constexpr bool 
kRaked = true;\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < M; m += TM) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < N; n += TN) {\n                // Store to shared memory\n                RearrangeC::apply(frag_C, smem_C, {m, n}, pair<TM, TN>{});\n\n                // Load from shared memory\n                PRAGMA_UNROLL\n                for (int s = 0; s < S; ++s) {\n                    PRAGMA_UNROLL\n                    for (int c = 0; c < C; ++c) {\n                        const int cc = c * Map::kDeltaC + thr_cs.x;\n                        const int ss = s * Map::kDeltaS + thr_cs.y;\n\n                        const int2 mn =\n                            kRaked ? cs2mk<kOrder>(c * Map::kDeltaC, s * Map::kDeltaS) : cs2mk<kOrder>(cc, ss);\n                        const int  mm   = mn.x - m;\n                        const int  nn   = mn.y - n;\n                        const bool mask = (M <= TM || (0 <= mm && mm < TM)) && ((N <= TN) || (0 <= nn && nn < TN));\n\n                        const int2 _cs      = mk2cs<kOrder>(m, n);\n                        const int  offset_0 = SmemLayout::apply(  //\n                            s / kPeriodS * kPeriodS * Map::kDeltaS - _cs.y,\n                            c / kPeriodC * kPeriodC * Map::kDeltaC - _cs.x);\n                        const int  offset_p = phases[s % kPeriodS][c % kPeriodC];\n\n                        if (mask) {\n                            Load(out[s][c], &storage[offset_0 + offset_p]);\n                        }\n                    }\n                }\n                __syncthreads();\n            }\n        }\n    }\n\n    template<class T, class VecC, class Pred>\n    __device__ void StoreC(const VecC& vec_C, const MatrixData& c, int2 cs0, Pred& pred)\n    {\n        constexpr int dc  = sizeof(T) * Map::kDeltaC;\n        const int     ds  = sizeof(T) * Map::kDeltaS * c.ptr.stride;\n        char*         ptr = (char*)c.ptr.ptr + sizeof(T) * dot(cs0, long2{1, c.ptr.stride});\n 
       PRAGMA_UNROLL\n        for (int s = 0; s < S; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < C; ++c) {\n                const auto tmp = cast<T>(vec_C[s][c]);\n                if (pred(s, c)) {\n                    Store(reinterpret_cast<T*>(ptr), tmp);\n                }\n                ptr += dc;\n            }\n            ptr -= dc * C;\n            ptr += ds;\n        }\n    }\n\n#if 0\n    template<class FragC, class Pred>\n    __device__ void\n    Reduce(FragC& frag_C, int splits, int64_t split_size, const int2& cta_cs, Pred& pred, const EpilogueParam& param)\n    {\n        using Vec         = OutputC<Dtype>;\n        const int2 thr_cs = Map::get_offset(threadIdx.x / WARP_SIZE, threadIdx.x % WARP_SIZE);\n        for (int k = 0; k < splits; ++k) {\n            PRAGMA_UNROLL\n            for (int s = 0; s < S; ++s) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < C; ++c) {\n                    const int     ss  = thr_cs.y + s * Map::kDeltaS;\n                    const int     cc  = thr_cs.x + c * Map::kDeltaC;\n                    const int64_t idx = k * split_size + (cta_cs.y + ss) * param.partial_C_ld + (cta_cs.x + cc);\n                    if (true) {\n                        Vec tmp;\n                        Load(tmp, &param.partial_C[idx]);\n                        using namespace ops;\n                        frag_C[s][c] = frag_C[s][c] + tmp;\n                    }\n                }\n            }\n        }\n    }\n#endif\n\n    template<class FragC, class Pred>\n    __device__ void Reduce(FragC& frag_C, const MatrixData& p, bool is_first, bool is_last, int2 cs0, Pred& pred)\n    {\n        constexpr int dc = sizeof(Dtype) * Map::kDeltaC;\n        const int     ds = sizeof(Dtype) * Map::kDeltaS * p.ptr.stride;\n\n        char* ptr = (char*)p.ptr.ptr + sizeof(Dtype) * dot(cs0, long2{1, p.ptr.stride});\n\n        Pred ld_mask = is_first ? Pred{} : pred;\n        Pred st_mask = is_last ? 
Pred{} : pred;\n\n        PRAGMA_UNROLL\n        for (int s = 0; s < S; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < C; ++c) {\n                OutputC<Dtype> tmp{};  // ! ZERO-filled\n                if (ld_mask(s, c)) {\n                    Load(tmp, reinterpret_cast<Dtype*>(ptr));\n                }\n                if (1) {\n                    using namespace ops;\n                    frag_C[s][c] = frag_C[s][c] + tmp;\n                }\n                if (st_mask(s, c)) {\n                    Store(reinterpret_cast<Dtype*>(ptr), frag_C[s][c]);\n                }\n                ptr += dc;\n            }\n            ptr -= dc * C;\n            ptr += ds;\n        }\n    }\n\n    template<class FragC>\n    __device__ void operator()(FragC&               frag_C,\n                               const int4&          tile_offset,\n                               const int2&          extents,\n                               int                  splits,\n                               int                  tile_id,\n                               bool                 is_last,\n                               const EpilogueParam& param,\n                               SharedStorage&       storage)\n    {\n        const int2 cta_cs = mk2cs<kOrder>(tile_offset.x * M, tile_offset.y * N);\n        const int2 end_cs = mk2cs<kOrder>(extents);\n\n        OutputC<Dtype> tmp_C[S][C];\n\n        Rearrange(frag_C, storage, tmp_C);\n\n        Predicate<S, C, false, false> pred{};  //  1 regs\n\n        const int2 thr_cs = Map::get_offset(threadIdx.x / WARP_SIZE, threadIdx.x % WARP_SIZE);\n        const int2 cs0    = {cta_cs.x + thr_cs.x, cta_cs.y + thr_cs.y};\n\n        PRAGMA_UNROLL\n        for (int s = 0; s < S; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < C; ++c) {\n                const int ss = thr_cs.y + s * Map::kDeltaS;\n                const int cc = thr_cs.x + c * Map::kDeltaC;\n                if (ss < end_cs.y && cc < 
end_cs.x) {\n                    pred.set(s, c);\n                }\n            }\n        }\n\n        if (SplitK_ && splits > 1) {\n            int* barrier = &param.locks[tile_id];\n\n            sem_wait(barrier, tile_offset.z, threadIdx.x == 0);\n\n            const MatrixData p = resolve<Dtype, kMode>(param.partials, tile_offset.w);\n\n            Reduce(tmp_C, p, tile_offset.z == 0, is_last, cs0, pred);\n\n            const int post_id = is_last ? 0 : tile_offset.z + 1;\n            sem_post(barrier, post_id, threadIdx.x == 0);\n\n            if (!is_last) {\n                return;\n            }\n        }\n\n        constexpr pair<Map::kDeltaC, Map::kDeltaS> delta_cs{};\n\n        // opt-in scaling\n        // Scale(scale_SC{}, mode_SC{}, delta_cs, tmp_C, param.scale_S, param.scale_C, tile_offset.w, cs0, pred);\n\n        param.combine_mat((Tc*)0, constant<kMode>{}, tmp_C, cs0, tile_offset.w, delta_cs, pred);\n\n        const MatrixData c = resolve<Tc, kMode>(param.c, tile_offset.w);\n\n        if (param.silu_act) {\n            constexpr int dc  = sizeof(Tc) * Map::kDeltaC / 2;\n            const int     ds  = sizeof(Tc) * Map::kDeltaS * c.ptr.stride;\n            auto          ptr = (char*)c.ptr.ptr + sizeof(Tc) * dot({cs0.x / 2, cs0.y}, long2{1, c.ptr.stride});\n            PRAGMA_UNROLL\n            for (int s = 0; s < S; ++s) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < C; ++c) {\n                    GatedActivation<Silu>::apply(tmp_C[s][c]);\n                    if (pred(s, c)) {\n                        const auto tmp = cast<Tc>((Array<Dtype, kAccess / 2>&)tmp_C[s][c]);\n                        Store(reinterpret_cast<Tc*>(ptr), tmp);\n                    }\n                    ptr += dc;\n                }\n                ptr -= dc * C;\n                ptr += ds;\n            }\n        }\n        else {\n            StoreC<Tc>(tmp_C, c, cs0, pred);\n        }\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/format.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class Tin, class Tout>\nstruct Converter {\n};\n\ntemplate<class T>\nstruct Converter<T, T> {\n    template<int N>\n    __device__ Array<T, N> operator()(Array<T, N> x)\n    {\n        return x;\n    }\n};\n\ntemplate<>\nstruct Converter<uint16_t, uint4_t> {\n\n    static __device__ Array<uint4_t, 8> pack(const Array<uint8_t, 8>& vi)\n    {\n        Array<uint32_t, 2> ui = (Array<uint32_t, 2>&)vi;\n\n        ui[0] |= (ui[0] >> 12);\n        ui[1] |= (ui[1] >> 12);\n\n        //  7 6 5 4 3 2 1 0\n        // _7_67564_3_23120\n        uint32_t uo = __byte_perm(ui[0], ui[1], 0x5140);\n\n        return (Array<uint4_t, 8>&)uo;\n    }\n\n    template<class U, int N>\n    __device__ Array<uint4_t, N> operator()(const Array<U, N>& x)\n    {\n        static_assert(sizeof(U) == 2);\n        auto&             vi = (const Array<uint16_t, N>&)x;\n        Array<uint8_t, N> tmp;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            tmp[i] = static_cast<uint8_t>(vi[i]);\n        }\n        Array<uint4_t, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 8) {\n            (Array<uint4_t, 8>&)vo[i] = pack((Array<uint8_t, 8>&)tmp[i]);\n        }\n        return vo;\n    }\n};\n\ntemplate<>\nstruct Converter<uint16_t, uint8_t> {\n    template<int N>\n    __device__ Array<uint8_t, N> operator()(const Array<uint16_t, N>& x)\n    {\n        static_assert(N % 4 == 0);\n        Array<uint8_t, N> vo;\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; i += 4) {\n            // 3120\n            vo[i + 0] = (uint8_t)x[i + 0];\n            vo[i + 1] = (uint8_t)x[i + 2];\n            vo[i + 2] = (uint8_t)x[i + 1];\n            vo[i + 3] = (uint8_t)x[i + 3];\n        }\n        return vo;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gemm.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/kernels/gemm/context.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/dispatch_cache.h\"\n#include \"src/turbomind/kernels/gemm/gemm.h\"\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/tuner/params.h\"\n#include \"src/turbomind/kernels/gemm/tuner/sampler.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include <algorithm>\n#include <iterator>\n#include <memory>\n#include <numeric>\n#include <optional>\n#include <vector>\n\nnamespace turbomind::gemm {\n\nvoid ExportDispatchCache(std::ostream& os, const std::vector<std::pair<GemmDesc, LaunchSpec>>& entries);\n\nvoid ImportDispatchCache(std::istream&                                 is,\n                         std::vector<std::pair<GemmDesc, LaunchSpec>>& entries,\n                         const std::vector<std::unique_ptr<Kernel>>&   kernels);\n\nnamespace {\n\ntemplate<class Cmp>\nstd::vector<int> ArgSort(size_t size, const Cmp& cmp)\n{\n    std::vector<int> idxs(size);\n    std::iota(idxs.begin(), idxs.end(), 0);\n    std::stable_sort(idxs.begin(), idxs.end(), cmp);\n    return idxs;\n}\n\n}  // namespace\n\nstruct Gemm::Impl {\n\n    Impl():\n        props_{GetCudaDeviceProps()},\n        arch_{props_->major * 100 + props_->minor * 10},\n        registry_{props_},\n        cache_{registry_.kernels()}\n    {\n        if (auto str = std::getenv(\"TM_GEMM_TUNE\")) {\n            try {\n                ParseTuningParams(tuning_, str);\n            }\n            catch (...) 
{\n                std::cerr << \"[Gemm2] Failed to parse `TM_GEMM_TUNE`, default value will be used.\\n\";\n                tuning_ = {};\n            }\n        }\n        if (std::getenv(\"TM_GEMM_WARN_CACHE_MISS\")) {\n            warn_cache_miss_ = true;\n        }\n        measurer_.emplace(CreateStoppingCriterion(tuning_.min_iter, tuning_.max_iter, tuning_.max_time));\n    }\n\n    // find launch spec in dispatch cache, dispatch by heuristic on cache miss\n    LaunchSpec Dispatch(Context& ctx, DispatchPolicy policy, size_t barriers_size, size_t partials_size)\n    {\n        const auto& desc = ctx.desc();\n        if (policy & DispatchPolicy::kReuse) {\n            if (auto spec = cache_.LowerBound(desc)) {\n                return *spec;\n            }\n            if (warn_cache_miss_) {\n                std::cerr << \"Failed to find a feasible kernel in the cache, will dispatch by heuristic: \"\n                          << to_string(ctx.desc()) << std::endl;\n            }\n        }\n\n        if (auto spec = cache_.Find(desc)) {\n            return *spec;\n        }\n\n        auto specs = Find(ctx, barriers_size, partials_size, 1);\n        if (!specs.empty()) {\n            cache_.Insert(desc, specs.front());\n            return specs.front();\n        }\n        return {};\n    }\n\n    std::vector<LaunchSpec> Find(Context& ctx, size_t barrier_size, size_t partials_size, int top_k)\n    {\n        std::vector<Kernel*> feasible = ctx.Filter(registry_.kernels());\n\n        std::vector<std::vector<LaunchSpec>> clusters;\n        {\n            std::vector<LaunchSpec> tmp;\n            tmp.reserve(feasible.size());\n            for (const auto& k : feasible) {\n                LaunchSpec spec{k};\n                tmp.push_back(spec);\n            }\n            clusters = Cluster(tmp, ClusteringParam{false, true});\n        }\n        std::vector<Kernel*> proxies;\n        proxies.reserve(clusters.size());\n\n        for (const auto& c : clusters) {\n   
         proxies.push_back(c.front().kernel);\n        }\n\n        std::vector<std::pair<int, LaunchSpec>> specs;\n\n        PopulateParam param{};\n        param.max_splits    = tuning_.max_splits;\n        param.max_waves     = tuning_.max_waves;\n        param.swizzle       = tuning_.swizzle.at(0);\n        param.barriers_size = barrier_size;\n        param.partials_size = partials_size;\n\n        for (int cluster_id = 0; cluster_id < (int)proxies.size(); ++cluster_id) {\n            auto& kernel = *proxies[cluster_id];\n\n            auto tmp = ctx.Populate(kernel, param);\n            for (const auto& s : tmp) {\n                specs.emplace_back(cluster_id, s);\n            }\n        }\n\n        // std::cerr << \"#kernel: \" << kernels.size() << \", #cluster: \" << clusters.size()\n        //           << \", #metric: \" << metrics.size() << \"\\n\";\n\n        int64_t mio_max = 0;\n        int64_t mma_max = 0;\n        for (const auto& [_, s] : specs) {\n            auto& [mio, mma] = s.estimated;\n            mio_max          = std::max(mio_max, mio);\n            mma_max          = std::max(mma_max, mma);\n        }\n        std::vector<float> mio_ratio;\n        std::vector<float> mma_ratio;\n        std::vector<float> avg_ratio;\n        for (const auto& [_, s] : specs) {\n            auto& [mio, mma] = s.estimated;\n            mio_ratio.push_back((float)mio / mio_max);\n            mma_ratio.push_back((float)mma / mma_max);\n            avg_ratio.push_back(.5 * (mio_ratio.back() + mma_ratio.back()));\n        }\n        auto idxs = ArgSort(specs.size(), [&](int i, int j) {  //\n            return avg_ratio[i] < avg_ratio[j];\n        });\n\n        // for (const auto& i : idxs) {\n        //     auto [cid, s, m] = metrics[i];\n        //     std::cout << clusters[cid].front().kernel->name() << \" s\" << s << \" \" << avg_ratio[i] << \" \" <<\n        //     mio_ratio[i]\n        //               << \" \" << mma_ratio[i] << \" \" << m.mio_cost << 
\" \" << m.mma_cost << \"\\n\";\n        // }\n\n        top_k = top_k > 0 ? std::min<int>(idxs.size(), top_k) : (int)idxs.size();\n        std::vector<LaunchSpec> ret;\n        ret.reserve(top_k);\n        for (int i = 0; i < top_k; ++i) {\n            const auto& [cluster_id, spec] = specs[idxs[i]];\n            // Apply `splits` to all kernels in the cluster\n            for (const auto& s : clusters[cluster_id]) {\n                auto tmp   = spec;\n                tmp.kernel = s.kernel;\n                ret.push_back(tmp);\n            }\n        }\n\n        return ret;\n    }\n\n    template<class LaunchFunc>\n    int Measure(\n        Context& ctx, size_t barriers_size, size_t partials_size, int top_k, LaunchFunc launch_func, cudaStream_t st)\n    {\n        // Early exit on exact match\n        if (cache_.Find(ctx.desc())) {\n            return 0;\n        }\n        // std::cerr << \"GEMM: \" << desc.m << \"x\" << desc.n << \"x\" << desc.k << \"\\n\";\n\n        const auto tmp = Find(ctx, barriers_size, partials_size, tuning_.top_k);\n\n        std::vector<LaunchSpec> specs;\n        for (const auto& spec : tmp) {\n            // populate swizzle parameters\n            const auto swis = ctx.Swizzle(spec, tuning_.swizzle);\n            specs.insert(specs.end(), swis.begin(), swis.end());\n        }\n\n        specs = Sampler{*measurer_, tuning_.clusters}.Run(specs, launch_func, st);\n\n        // for (const auto& s : specs) {\n        //     std::cout << s.kernel->name()          //\n        //               << \" swizzle=\" << s.swizzle  //\n        //               << \", splits=\" << s.splits   //\n        //               << \", measured=\" << s.measured << \"ms\\n\";\n        //     break;\n        // }\n\n        if (!specs.empty()) {\n            cache_.Insert(ctx.desc(), specs.front());\n        }\n        else {\n            std::cerr << \"No valid kernel found for the problem\\n\";\n            return -1;\n        }\n\n        return 0;\n    
}\n\n    /// TODO: move to cuda utils\n    static std::unique_ptr<cudaDeviceProp> GetCudaDeviceProps()\n    {\n        auto props     = std::make_unique<cudaDeviceProp>();\n        int  device_id = -1;\n        cudaGetDevice(&device_id);\n        cudaGetDeviceProperties(props.get(), device_id);\n        return props;\n    }\n\n    std::shared_ptr<cudaDeviceProp> props_;\n\n    int arch_;\n\n    Registry registry_;\n\n    TuningParams tuning_;\n\n    bool warn_cache_miss_{};\n\n    std::optional<Measurer> measurer_;\n\n    DispatchCache cache_;\n};\n\n// implementation of GEMM interfaces\n\nGemm::Gemm(): impl_{new Impl{}} {}\n\nGemm::~Gemm() = default;\n\nint Gemm::Run(const Operation&    operation,\n              float               alpha,\n              const void*         A,\n              const MatrixLayout& Adesc,\n              const void*         U,\n              const MatrixLayout& Udesc,\n              const void*         B,\n              const MatrixLayout& Bdesc,\n              const void*         V,\n              const MatrixLayout& Vdesc,\n              float               beta,\n              const void*         C,\n              const MatrixLayout& Cdesc,\n              void*               D,\n              const MatrixLayout& Ddesc,\n              const Workspace&    workspace,\n              cudaStream_t        stream)\n{\n\n    Context context{*impl_->props_};\n\n    const auto desc = context.Init(operation, Adesc, Udesc, Bdesc, Vdesc, Cdesc, Ddesc);\n\n    if (!desc) {\n        fprintf(stderr, \"invalid argument.\\n\");\n        TM_CHECK(0);\n        return -1;\n    }\n\n    const auto launch = [=](LaunchSpec spec, cudaStream_t st) {\n        auto _workspace = workspace;\n        return spec.kernel->Launch(operation,\n                                   alpha,\n                                   A,\n                                   Adesc,\n                                   U,\n                                   Udesc,\n                        
           B,\n                                   Bdesc,\n                                   V,\n                                   Vdesc,\n                                   beta,\n                                   C,\n                                   Cdesc,\n                                   D,\n                                   Ddesc,\n                                   spec.swizzle,\n                                   spec.splits,\n                                   _workspace,\n                                   st);\n    };\n\n#if 0\n    if (operation.reserved) {\n        auto specs = impl_->Find(context, workspace.barriers_size, workspace.partials_size, 0);\n        auto cases = (std::vector<std::function<LaunchSpec()>>*)operation.reserved;\n        for (const auto& spec : specs) {\n            cases->push_back([=] {\n                launch(spec, stream);\n                return spec;\n            });\n        }\n        return -1;\n    }\n#endif\n\n    LaunchSpec spec{};\n\n    if (operation.dispatch & DispatchPolicy::kMeasure) {\n        impl_->Measure(context, workspace.barriers_size, workspace.partials_size, 1, launch, stream);\n    }\n\n    spec = impl_->Dispatch(context, operation.dispatch, workspace.barriers_size, workspace.partials_size);\n\n    if (spec.kernel) {\n        // std::cout << \"[Gemm] dispatch: \" << spec.kernel->name()  //\n        //           << \" split_k=\" << spec.splits                  //\n        //           << \" swizzle=\" << spec.swizzle << std::endl;\n        return launch(spec, stream);\n    }\n\n    TM_CHECK(0) << \"No feasible kernel found for the problem: \" << to_string(context.desc());\n\n    return -1;\n}\n\nint Gemm::Export(std::ostream& os)\n{\n    return impl_->cache_.Export(os);\n}\n\nint Gemm::Import(std::istream& is)\n{\n    return impl_->cache_.Import(is);\n}\n\nstd::vector<int> Gemm::GetTuningSeq() const\n{\n    return impl_->tuning_.seq;\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gemm.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <memory>\n#include <vector>\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nclass Gemm {\npublic:\n    static constexpr size_t kBarriersSize = 1 << 20;\n    static constexpr size_t kPartialsSize = 32 << 20;\n\n    Gemm();\n\n    ~Gemm();\n\n    [[nodiscard]] int Run(const Operation&    operation,\n                          float               alpha,\n                          const void*         A,\n                          const MatrixLayout& Adesc,\n                          const void*         U,\n                          const MatrixLayout& Udesc,\n                          const void*         B,\n                          const MatrixLayout& Bdesc,\n                          const void*         V,\n                          const MatrixLayout& Vdesc,\n                          float               beta,\n                          const void*         C,\n                          const MatrixLayout& Cdesc,\n                          void*               D,\n                          const MatrixLayout& Ddesc,\n                          const Workspace&    workspace,\n                          cudaStream_t        stream);\n\n    [[maybe_unused]] int Export(std::ostream& os);\n\n    [[maybe_unused]] int Import(std::istream& is);\n\n    [[nodiscard]] std::vector<int> GetTuningSeq() const;\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gemm_universal.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <climits>\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\n#include \"src/turbomind/kernels/gemm/cta_map.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/epilogue.h\"\n#include \"src/turbomind/kernels/gemm/thread_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\nstruct GemmParam {\n    MatrixParam a;\n    MatrixParam b;\n    MatrixParam u;\n    MatrixParam v;\n};\n\ntemplate<class Op>\n__inline__ __device__ MatrixData resolve_op(const MatrixParam& param, int gemm_id)\n{\n    return resolve<typename Op::Dtype, Op::GmemIter::kMode>(param, gemm_id);\n}\n\ntemplate<class Arch_, class Mainloop, class Epilogue_, class Scheduler_>\nstruct GemmUniversal {\n\n    // using Impl = typename Mainloop::Impl;\n    using Impl = Mainloop;\n\n    using Ta = typename Impl::Ta;\n    using Tb = typename Impl::Tb;\n    using Tu = typename Impl::Tu;\n    using Tv = typename Impl::Tv;\n\n    using Arch      = Arch_;\n    using Scheduler = Scheduler_;\n    using Epilogue  = Epilogue_;\n\n    using Tc = typename Epilogue::Tc;\n\n    // col major == M-major (A)\n    // row major == N-major (B)\n    static constexpr Order kOrderC = Epilogue::kOrder;\n\n    static constexpr int CTA_M = Impl::CTA_M;\n    static constexpr int CTA_N = Impl::CTA_N;\n    static constexpr int CTA_K = Impl::CTA_K;\n\n    static constexpr bool kDynamicSched = Scheduler::group_axis >= 0;\n    static constexpr bool kSplitK       = Epilogue::SplitK;\n\n    using FragC = typename Impl::FragC;\n\n    static constexpr int WARP_CNT = Impl::WARPS;\n\n    using OperandA = typename 
Mainloop::OperandA;\n    using OperandB = typename Mainloop::OperandB;\n    using OperandU = typename Mainloop::OperandU;\n    using OperandV = typename Mainloop::OperandV;\n\n    static constexpr int kChunkSizeK = std::max(CTA_K, std::max(OperandU::kGroupSize, OperandV::kGroupSize));\n\n    static constexpr int kGSizeU = OperandU::kGroupSize;\n    static constexpr int kGSizeV = OperandV::kGroupSize;\n\n    struct SharedStorage {\n        union {\n            typename Mainloop::SharedStorage mainloop;\n            typename Epilogue::SharedStorage epilogue;\n        };\n        typename Scheduler::SharedStorage sched;\n    };\n\n    static constexpr Order kOrderA = OperandA::kOrder;\n    static constexpr Order kOrderB = OperandB::kOrder;\n    static constexpr Order kOrderU = OperandU::kOrder;\n    static constexpr Order kOrderV = OperandV::kOrder;\n\n    static constexpr Pack kPackA = OperandA::kPack;\n    static constexpr Pack kPackB = OperandB::kPack;\n\n    using Param = GemmParam;\n\n    __device__ void operator()(const Param& param, const EpilogueParam& epi_param, Scheduler& sched, char* smem_buf)\n    {\n        SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        typename Scheduler::Tile tile;\n\n        if (!sched.init(tile, storage.sched, std::false_type{})) {\n            return;\n        }\n\n        const auto& [M, N, K] = tile.shape.__a;\n\n        const auto tile_id = tile.tile_id;\n\n        const int offset_m = tile_id[0] * CTA_M;\n        const int offset_n = tile_id[1] * CTA_N;\n\n        const int offset_k = tile.k_iters[0] * CTA_K;\n\n        if (offset_m >= M || offset_n >= N || offset_k >= K) {  // empty tile\n            return;\n        }\n\n        const int extent_m = min(CTA_M, M - offset_m);\n        const int extent_n = min(CTA_N, N - offset_n);\n\n        // Is 8 enough?\n        __align__(8) FragC frag_C{};\n\n        int tile_iter = tile.k_iters[1];\n\n        const int g = tile.group_id;\n\n        const 
auto mat_A = resolve_op<OperandA>(param.a, g);\n        const auto mat_B = resolve_op<OperandB>(param.b, g);\n        const auto mat_U = resolve_op<OperandU>(param.u, g);\n        const auto mat_V = resolve_op<OperandV>(param.v, g);\n\n        typename OperandA::GmemIter gmem_A{mat_A, {offset_m, offset_k}, {extent_m, CTA_K}};\n        typename OperandB::GmemIter gmem_B{mat_B, {offset_n, offset_k}, {extent_n, CTA_K}};\n\n        const int2 offset_U{offset_m, cdiv(offset_k, kGSizeU)}, extent_U{extent_m, cdiv(CTA_K, kGSizeU)};\n        typename OperandU::GmemIter gmem_U{mat_U, offset_U, extent_U};\n\n        const int2 offset_V{offset_n, cdiv(offset_k, kGSizeV)}, extent_V{extent_n, cdiv(CTA_K, kGSizeV)};\n        typename OperandV::GmemIter gmem_V{mat_V, offset_V, extent_V};\n\n        Mainloop mainloop{};\n        mainloop(gmem_A, gmem_B, gmem_U, gmem_V, frag_C, tile_iter, storage.mainloop);\n\n        {\n            sched.init(tile, storage.sched, std::true_type{});\n\n            const auto [M, N, K] = tile.shape.__a;\n\n            int4 tile_offset{tile.tile_id[0], tile.tile_id[1], tile.tile_id[2], tile.group_id};\n\n            const int2 extents = {min(CTA_M, M - tile_offset.x * CTA_M), min(CTA_N, N - tile_offset.y * CTA_N)};\n\n            const bool is_last = (tile.k_iters[0] + tile.k_iters[1]) * CTA_K == K;\n\n            Epilogue epilogue{};\n            epilogue(frag_C,  //\n                     tile_offset,\n                     extents,\n                     sched.tiles_[2],\n                     tile.linear_tile_id,\n                     is_last,\n                     epi_param,\n                     storage.epilogue);\n        }\n    }\n};\n\nextern __shared__ char smem_buf[];\n\ntemplate<class Kernel, class Param, class EpilogueParam, class Scheduler>\n__global__ void gemm_kernel(Param param, EpilogueParam epi_param, Scheduler sched)\n{\n#if __CUDA_ARCH__\n    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {\n        Kernel kernel;\n        
kernel(param, epi_param, sched, smem_buf);\n    }\n#endif\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gemm_universal_sm90.h",
    "content": "#pragma once\n\n#include <utility>\n\n#include <cuda_fp8.h>\n\n#include \"cute/arch/cluster_sm90.hpp\"\n#include \"cute/arch/copy_sm90.hpp\"\n#include \"cute/arch/copy_sm90_tma.hpp\"\n#include \"cute/arch/mma_sm90_desc.hpp\"\n#include \"cute/arch/mma_sm90_gmma.hpp\"\n#include \"cute/atom/mma_traits.hpp\"\n#include \"cute/atom/mma_traits_sm90_gmma.hpp\"\n\n#include \"cutlass/arch/barrier.h\"\n#include \"cutlass/arch/reg_reconfig.h\"\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/pipeline/sm90_pipeline.hpp\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm70.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm90.h\"\n#include \"src/turbomind/kernels/gemm/scheduler.cuh\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nnamespace GMMA = cute::SM90::GMMA;\n\ninline __device__ cute::GmmaDescriptor make_smem_desc(void* smem_ptr, int layout_type)\n{\n    auto uint_ptr = cast_smem_ptr_to_uint(smem_ptr);\n\n    cute::GmmaDescriptor desc{};\n    desc.bitfield.start_address_       = uint_ptr >> 4;\n    desc.bitfield.layout_type_         = layout_type;\n    desc.bitfield.leading_byte_offset_ = 0;\n    desc.bitfield.stride_byte_offset_  = 1024 >> 4;\n    desc.bitfield.base_offset_         = 0;\n\n    return desc;\n}\n\ntemplate<int Stages, int Step>\nstruct SmemDescIterV2 {\n    union {\n        uint32_t u32_[2];\n        uint64_t u64_;\n    };\n\n    uint32_t base_;\n\n    __device__ SmemDescIterV2(uint64_t desc): u64_{desc}, base_{u32_[0]} {}\n\n    __device__ void Advance(int stage)\n    {\n        u32_[0] += Step;\n        if (stage == Stages - 1) {\n            u32_[0] = base_;\n        }\n    }\n\n    __device__ SmemDescIterV2& operator+=(int offset)\n    {\n        u32_[0] += offset;\n        return *this;\n    }\n\n    __device__ SmemDescIterV2& operator-=(int 
offset)\n    {\n        u32_[0] -= offset;\n        return *this;\n    }\n\n    __device__ operator uint64_t()\n    {\n        return u64_;\n    }\n};\n\ntemplate<class MMA_Atom, size_t... Is>\ninline __device__ void\nwgmma_impl(uint64_t desc_a, uint64_t desc_b, float* frag_C, bool clear, std::index_sequence<Is...>)\n{\n    return MMA_Atom::fma(desc_a, desc_b, frag_C[Is]..., clear ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One);\n}\n\ntemplate<class MMA_Atom, int N>\ninline __device__ void wgmma(uint64_t desc_a, uint64_t desc_b, float (&frag_C)[N], bool clear)\n{\n    return wgmma_impl<MMA_Atom>(desc_a, desc_b, frag_C, clear, std::make_index_sequence<N>{});\n}\n\ninline __device__ void warpgroup_fence_operand(float& reg)\n{\n    asm volatile(\"\" : \"+f\"(reg)::\"memory\");\n}\n\ntemplate<int M, int N, int K>\ninline __device__ void warpgroup_fence_operand(float (&x)[M][N][K])\n{\n    PRAGMA_UNROLL\n    for (int m = 0; m < M; ++m) {\n        PRAGMA_UNROLL\n        for (int n = 0; n < N; ++n) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < K; ++k) {\n                warpgroup_fence_operand(x[m][n][k]);\n            }\n        }\n    }\n}\n\ntemplate<class Arch_>\nstruct GemmUniversalSm90 {\n\n    // using MMA_Atom = GMMA::MMA_64x128x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>;\n    using MMA_Atom = GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN<>;\n    static constexpr typename cute::MMA_Traits<MMA_Atom>::Shape_MNK MMA_Shape{};\n\n    static constexpr int MMA_ATOM_M = cute::get<0>(MMA_Shape);\n    static constexpr int MMA_ATOM_N = cute::get<1>(MMA_Shape);\n    static constexpr int MMA_ATOM_K = cute::get<2>(MMA_Shape);\n\n    static constexpr int kWorkGroupM = 1;\n    static constexpr int kWorkGroupN = 2;\n\n    static constexpr int CTA_M = 128;\n    static constexpr int CTA_N = MMA_ATOM_N * kWorkGroupN;\n    static constexpr int CTA_K = 128;\n\n    static constexpr int WARPGORUPS = kWorkGroupM * kWorkGroupN;\n\n    static constexpr int MMA_M = MMA_ATOM_M * 
kWorkGroupM;\n    static constexpr int MMA_N = MMA_ATOM_N * kWorkGroupN;\n    static constexpr int MMA_K = MMA_ATOM_K;\n\n    static constexpr int MMA_ITER_M = CTA_M / MMA_M;  // 2\n    static constexpr int MMA_ITER_N = CTA_N / MMA_N;  // 1\n    static constexpr int MMA_ITER_K = CTA_K / MMA_K;  // 4\n\n    static constexpr int kMulticastA = 1;\n    static constexpr int kMulticastB = 2;\n\n    static constexpr int kClusterSize = kMulticastA * kMulticastB;\n\n    static constexpr int Stages = 3;\n\n    static constexpr bool kSplitK     = false;\n    static constexpr int  kChunkSizeK = CTA_K;\n\n    static constexpr int WARPGROUP_SIZE = 128;\n\n    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);\n\n    using Ta = __nv_fp8_e4m3;\n    using Tb = __nv_fp8_e4m3;\n    using Tc = nv_bfloat16;\n\n    using Tu = float;\n    using Tv = float;\n\n    using Arch      = Arch_;\n    using Scheduler = TileScheduler<kRowMajor, kMulticastB, kMulticastA>;\n\n    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;\n    using ConsumerBar = cutlass::arch::ClusterBarrier;\n\n    static constexpr int CTA_M_U = cdiv(CTA_M, 128);\n    static constexpr int CTA_K_U = cdiv(CTA_K, 128);\n    static constexpr int CTA_K_V = cdiv(CTA_K, 128);\n    static constexpr int CTA_N_V = cdiv(CTA_N, 1);\n\n    static constexpr int kTmaTxBytes =\n        sizeof(Ta) * (CTA_M * CTA_K) + sizeof(Tb) * (CTA_K * CTA_N) + sizeof(Tv) * CTA_N_V * CTA_K_V;\n\n    struct SharedStorage {\n        struct Source {\n            __align__(128) Array<Ta, Stages * CTA_M * CTA_K> A;\n            __align__(128) Array<Tb, Stages * CTA_K * CTA_N> B;\n            __align__(128) Tu U[Stages][round_up(CTA_M_U * CTA_K_U, 32)];\n            __align__(128) Tv V[Stages][round_up(CTA_N_V * CTA_K_V, 32)];  // (k1,n256)\n        };\n        Source source;\n        __align__(128) Array<Tc, CTA_M * CTA_N> C;\n        __align__(128) float UV[WARPGORUPS][round_up(CTA_M_U * CTA_N_V, 32)];\n        __align__(128) 
uint64_t producer_bar[Stages];\n        __align__(128) uint64_t consumer_bar[Stages];\n    };\n\n    __device__ void operator()(const CUtensorMap& tm_a,\n                               const CUtensorMap& tm_b,\n                               const CUtensorMap& tm_c,\n                               const CUtensorMap& tm_u,\n                               const CUtensorMap& tm_v,\n                               const void*        U_,\n                               int                ldU,\n                               const void*        V_,\n                               int                ldV,\n                               Scheduler          sched,\n                               char*              smem_buf)\n    {\n        sched.grid_init();\n\n        SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        uint64_t* producer_bar = storage.producer_bar;\n        uint64_t* consumer_bar = storage.consumer_bar;\n\n        if (threadIdx.x == 0) {\n            PRAGMA_UNROLL\n            for (int s = 0; s < Stages; ++s) {\n                ProducerBar::init(&producer_bar[s], 1);\n                ConsumerBar::init(&consumer_bar[s], kClusterSize * WARPGORUPS);\n            }\n            cutlass::arch::fence_view_async_shared();\n            if constexpr (kClusterSize > 1) {\n                cutlass::arch::fence_barrier_init();\n            }\n        }\n\n        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();\n\n        const int warpgroup_id = cutlass::canonical_warp_group_idx();\n\n        if (warpgroup_id == WARPGORUPS) {\n            cutlass::arch::warpgroup_reg_dealloc<32>();\n\n            static_assert(CTA_M % kMulticastA == 0);\n            static_assert(CTA_N % kMulticastB == 0);\n\n            const int cta_id = cute::block_id_in_cluster().x;\n\n            const int mc_offset_m = kMulticastA > 1 ? cta_id * (CTA_M / kMulticastA) : 0;\n            const int mc_offset_n = kMulticastB > 1 ? 
cta_id * (CTA_N / kMulticastB) : 0;\n\n            auto  smem_A = storage.source.A.data() + mc_offset_m * CTA_K;\n            auto  smem_B = storage.source.B.data() + mc_offset_n * CTA_K;\n            auto& smem_U = storage.source.U;\n            auto& smem_V = storage.source.V;\n\n            if (threadIdx.x == WARPGORUPS * WARPGROUP_SIZE) {\n                cutlass::PipelineState<Stages> write_state{0, 1, 0};\n                while (sched.next()) {\n                    auto [valid_cta_tile_p, cluster_tile_p] = sched.is_valid_tile();\n\n                    if (!cluster_tile_p) {\n                        // OOB tile caused by swizzle pattern\n                        continue;\n                    }\n\n                    const auto tile_offset              = sched.tile_offset();\n                    const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();\n\n                    const int offset_m = tile_offset.x * CTA_M;\n                    const int offset_n = tile_offset.y * CTA_N;\n                    const int offset_k = 0 * CTA_K;\n\n                    int k_iter = iter_k_end - iter_k_beg;\n\n                    GmemIteratorSm90<kMulticastA> gmem_A{&tm_a, {offset_k, offset_m + mc_offset_m}, {CTA_K, 0}};\n                    GmemIteratorSm90<kMulticastB> gmem_B{&tm_b, {offset_k, offset_n + mc_offset_n}, {CTA_K, 0}};\n\n                    // column-major\n                    GmemIteratorSm90<false> gmem_V{&tm_v, {offset_n, offset_k / 128}, {0, 1}};\n\n                    // auto gmem_U = (const Tu*)U_ + (offset_m / 128) * ldU + (offset_k / 128);\n                    // auto step_U = 1;\n\n                    // auto gmem_V = (const Tv*)V_ + offset_n + (offset_k / 128) * ldV;\n                    // auto step_V = ldV;\n\n                    while (k_iter > 0) {\n                        int pipe = write_state.index();\n                        ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());\n                        
ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);\n                        gmem_A.Load(&producer_bar[pipe], &smem_A[pipe * CTA_M * CTA_K]);\n                        // {\n                        //     // printf(\"%f\\n\", *gmem_U);\n                        // smem_U[pipe][0] = *gmem_U;\n                        // gmem_U += step_U;\n                        // }\n                        gmem_B.Load(&producer_bar[pipe], &smem_B[pipe * CTA_N * CTA_K]);\n                        gmem_V.Load(&producer_bar[pipe], &smem_V[pipe][0]);\n\n                        ++write_state;\n                        --k_iter;\n                    }\n                }\n            }\n        }\n        else {\n            cutlass::arch::warpgroup_reg_alloc<232>();\n\n            auto& smem_A  = storage.source.A;\n            auto& smem_B  = storage.source.B;\n            auto& smem_U  = storage.source.U;\n            auto& smem_V  = storage.source.V;\n            auto& smem_UV = storage.UV[warpgroup_id];\n\n            const int warp_group_id_m = warpgroup_id % kWorkGroupM;\n            const int warp_group_id_n = warpgroup_id / kWorkGroupM;\n\n            auto smem_desc_A = make_smem_desc(&smem_A[warp_group_id_m * MMA_ATOM_M * CTA_K], 1);\n            auto smem_desc_B = make_smem_desc(&smem_B[warp_group_id_n * MMA_ATOM_N * CTA_K], 1);\n\n            SmemDescIterV2<Stages, ((sizeof(Ta) * CTA_M * CTA_K) >> 4)> smem_iter_A{smem_desc_A};\n            SmemDescIterV2<Stages, ((sizeof(Tb) * CTA_N * CTA_K) >> 4)> smem_iter_B{smem_desc_B};\n\n            constexpr int kStepMA = (sizeof(Ta) * MMA_M * CTA_K) >> 4;\n            constexpr int kStepNB = (sizeof(Tb) * MMA_N * CTA_K) >> 4;\n            constexpr int kStepKA = (sizeof(Ta) * MMA_K) >> 4;\n            constexpr int kStepKB = (sizeof(Tb) * MMA_K) >> 4;\n\n            cutlass::PipelineState<Stages> read_state{};\n            cutlass::PipelineState<Stages> release_state{};\n\n            while (sched.next()) {\n                
auto [cta_tile_p, cluster_tile_p] = sched.is_valid_tile();\n\n                if (!cluster_tile_p) {\n                    // OOB tile caused by swizzle pattern\n                    continue;\n                }\n\n                MMA_Atom::CRegisters frag_C[MMA_ITER_M][MMA_ITER_N];\n                MMA_Atom::CRegisters accum_C[MMA_ITER_M][MMA_ITER_N]{};  /// TODO: check the z-fill is eliminated\n\n                const auto tile_offset              = sched.tile_offset();\n                const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();\n\n                const int offset_m = tile_offset.x * CTA_M;\n                const int offset_n = tile_offset.y * CTA_N;\n                const int offset_k = 0;\n\n                auto gmem_U = (const Tu*)U_ + (offset_m / 128) * ldU + (offset_k / 128);\n                auto step_U = 1;\n\n                int k_iter = iter_k_end - iter_k_beg;\n\n                auto tile_gemm = [&] {\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < MMA_ITER_K; ++k) {\n                        PRAGMA_UNROLL\n                        for (int m = 0; m < MMA_ITER_M; ++m) {\n                            PRAGMA_UNROLL\n                            for (int n = 0; n < MMA_ITER_N; ++n) {\n                                wgmma<MMA_Atom>(smem_iter_A, smem_iter_B, frag_C[m][n], k == 0);\n                                smem_iter_B += kStepNB;\n                            }\n                            smem_iter_B -= MMA_ITER_N * kStepNB;\n                            smem_iter_A += kStepMA;\n                        }\n                        smem_iter_A += kStepKA - MMA_ITER_M * kStepMA;\n                        smem_iter_B += kStepKB;\n                    }\n                    smem_iter_A -= MMA_ITER_K * kStepKA;\n                    smem_iter_B -= MMA_ITER_K * kStepKB;\n                    cute::warpgroup_commit_batch();\n\n                    smem_iter_A.Advance(read_state.index());\n                    
smem_iter_B.Advance(read_state.index());\n                };\n\n                auto consumer_arrive = [&] {\n                    if constexpr (kClusterSize > 1) {\n                        ConsumerBar::arrive(&consumer_bar[release_state.index()],\n                                            threadIdx.x % WARPGROUP_SIZE,\n                                            threadIdx.x % WARPGROUP_SIZE < kClusterSize);\n                    }\n                    else {\n                        if (threadIdx.x % WARPGROUP_SIZE == 0) {\n                            ConsumerBar::arrive(&consumer_bar[release_state.index()]);\n                        }\n                    }\n                };\n\n                if constexpr (kClusterSize > 1) {\n                    if (!cta_tile_p) {\n                        // other CTAs in the cluster are still alive\n                        for (; k_iter > 0; --k_iter) {\n                            ProducerBar::wait(&producer_bar[read_state.index()], read_state.phase());\n                            consumer_arrive();\n                            smem_iter_A.Advance(read_state.index());\n                            smem_iter_B.Advance(read_state.index());\n                            ++read_state;\n                            ++release_state;\n                        }\n                        continue;\n                    }\n                }\n\n                float scale_U{};\n                auto  Load_U = [&] {\n                    scale_U = *gmem_U;\n                    gmem_U += step_U;\n                };\n\n                auto scale_accum = [&]() {  // cta_n = mma_iter_n * wg_n * mma_atom_n\n                    // auto scale_U = smem_U[read_state.index()][0];\n\n                    PRAGMA_UNROLL\n                    for (int i = threadIdx.x % WARPGROUP_SIZE; i < MMA_ATOM_N; i += WARPGROUP_SIZE) {\n                        smem_UV[i] = scale_U * smem_V[read_state.index()][i + warp_group_id_n * MMA_ATOM_N];\n                    }\n   
                 cute::warpgroup_wait<0>();\n\n                    const int lane_id = threadIdx.x % WARP_SIZE;\n\n                    cutlass::arch::NamedBarrier(WARPGROUP_SIZE, warpgroup_id + 1).sync();\n                    PRAGMA_UNROLL\n                    for (int n = 0; n < MMA_ITER_N; ++n) {\n                        PRAGMA_UNROLL\n                        for (int c = 0; c < MMA_ATOM_N; c += 8) {\n                            Array<float, 2> scale_Vs;\n                            int             idx = n * MMA_N + c + (lane_id & 3) * 2;\n                            Load(scale_Vs, &smem_UV[idx]);\n                            PRAGMA_UNROLL\n                            for (int m = 0; m < MMA_ITER_M; ++m) {\n                                accum_C[m][n][c / 2 + 0] += frag_C[m][n][c / 2 + 0] * scale_Vs[0];\n                                accum_C[m][n][c / 2 + 1] += frag_C[m][n][c / 2 + 1] * scale_Vs[1];\n                                accum_C[m][n][c / 2 + 2] += frag_C[m][n][c / 2 + 2] * scale_Vs[0];\n                                accum_C[m][n][c / 2 + 3] += frag_C[m][n][c / 2 + 3] * scale_Vs[1];\n                            }\n                        }\n                    }\n                };\n\n                Load_U();\n                ProducerBar::wait(&producer_bar[read_state.index()], read_state.phase());\n                cute::warpgroup_arrive();\n                warpgroup_fence_operand(frag_C);\n                tile_gemm();\n                warpgroup_fence_operand(frag_C);\n                scale_accum();\n                consumer_arrive();\n                --k_iter;\n                ++read_state;\n                ++release_state;\n\n                while (k_iter > 0) {\n                    Load_U();\n                    ProducerBar::wait(&producer_bar[read_state.index()], read_state.phase());\n                    cute::warpgroup_arrive();\n                    warpgroup_fence_operand(frag_C);\n                    tile_gemm();\n                    
warpgroup_fence_operand(frag_C);\n                    scale_accum();\n                    consumer_arrive();\n                    --k_iter;\n                    ++read_state;\n                    ++release_state;\n                }\n\n                if (threadIdx.x == 0) {\n                    cute::tma_store_wait<0>();\n                }\n\n                cutlass::arch::NamedBarrier(WARPGORUPS * WARPGROUP_SIZE).sync();\n\n                // epilogue\n                const int warp_id = threadIdx.x / WARP_SIZE;\n                const int lane_id = threadIdx.x % WARP_SIZE;\n\n                auto& smem_C = storage.C;\n\n                // (M,N):(1,M)\n                PRAGMA_UNROLL\n                for (int m = 0; m < MMA_ITER_M; ++m) {\n                    PRAGMA_UNROLL\n                    for (int n = 0; n < MMA_ITER_N; ++n) {\n                        PRAGMA_UNROLL\n                        for (int i = 0; i < MMA_ATOM_N; i += 16) {\n                            // clang-format off\n                            // const int mm   = m * MMA_M + warp_id * 16 + (lane_id & 8);\n                            // const int nn   = n * MMA_N +     i        + (lane_id & 7) + (lane_id & 16) / 2;\n                            const int mm   = m * MMA_M + (warp_id & 3) * 16 + (lane_id & 8);\n                            const int nn   = n * MMA_N + warp_group_id_n * MMA_ATOM_N + i + (lane_id & 7) + (lane_id & 16) / 2;\n                            // clang-format on\n                            __align__(16) Array<Tc, 8> tvec = cast<Tc>(*(Array<float, 8>*)&accum_C[m][n][i / 2]);\n                            cute::SM90_U16x8_STSM_T::copy((uint32_t&)tvec[0],\n                                                          (uint32_t&)tvec[2],\n                                                          (uint32_t&)tvec[4],\n                                                          (uint32_t&)tvec[6],\n                                                          (cutlass::uint128_t&)smem_C[nn * 
CTA_M + mm]);\n                        }\n                    }\n                }\n                cute::tma_store_fence();  // visibility: smem -> async proxy\n                cutlass::arch::NamedBarrier(WARPGORUPS * WARPGROUP_SIZE).sync();\n\n                if (threadIdx.x == 0) {\n                    cute::SM90_TMA_STORE_2D::copy(&tm_c, &smem_C, offset_m, offset_n);\n                    cute::tma_store_arrive();\n                }\n            }  // scheduler loop\n\n            if (threadIdx.x == 0) {\n                cute::tma_store_wait<0>();\n            }\n        }\n\n        cute::cluster_arrive();\n        cute::cluster_wait();\n\n    }  // operator()\n};\n\nextern __shared__ char smem_buf[];\n\ntemplate<class Kernel>\n__global__ void __launch_bounds__(Kernel::CTA_SIZE, 1) gemm_kernel_sm90(const __grid_constant__ CUtensorMap tm_a,\n                                                                        const __grid_constant__ CUtensorMap tm_b,\n                                                                        const __grid_constant__ CUtensorMap tm_c,\n                                                                        const __grid_constant__ CUtensorMap tm_u,\n                                                                        const __grid_constant__ CUtensorMap tm_v,\n                                                                        const void*                         U_,\n                                                                        int                                 ldU,\n                                                                        const void*                         V_,\n                                                                        int                                 ldV,\n                                                                        typename Kernel::Scheduler          sched)\n{\n#if __CUDA_ARCH__\n    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {\n        Kernel kernel;\n  
      kernel(tm_a, tm_b, tm_c, tm_u, tm_v, U_, ldU, V_, ldV, sched, smem_buf);\n    }\n#endif\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gemm_universal_sm90_v2.h",
    "content": "#pragma once\n\n#include <numeric>\n#include <utility>\n\n#include <cuda_fp8.h>\n\n#include \"cute/arch/cluster_sm90.hpp\"\n#include \"cute/arch/copy_sm90.hpp\"\n#include \"cute/arch/copy_sm90_tma.hpp\"\n#include \"cute/arch/mma_sm90_desc.hpp\"\n#include \"cute/arch/mma_sm90_gmma.hpp\"\n#include \"cute/atom/mma_traits.hpp\"\n#include \"cute/atom/mma_traits_sm90_gmma.hpp\"\n\n#include \"cutlass/arch/barrier.h\"\n#include \"cutlass/arch/reg_reconfig.h\"\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/pipeline/sm90_pipeline.hpp\"\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm90.h\"\n#include \"src/turbomind/kernels/gemm/scheduler.cuh\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\nnamespace GMMA = cute::SM90::GMMA;\n\ninline __device__ cute::GmmaDescriptor make_smem_desc(void* smem_ptr, int layout_type)\n{\n    auto uint_ptr = cast_smem_ptr_to_uint(smem_ptr);\n\n    cute::GmmaDescriptor desc{};\n    desc.bitfield.start_address_       = uint_ptr >> 4;\n    desc.bitfield.layout_type_         = layout_type;\n    desc.bitfield.leading_byte_offset_ = 0;\n    desc.bitfield.stride_byte_offset_  = 1024 >> 4;\n    desc.bitfield.base_offset_         = 0;\n\n    return desc;\n}\n\ntemplate<int Stages, int Step>\nstruct SmemDescIterV2 {\n    union {\n        uint32_t u32_[2];\n        uint64_t u64_;\n    };\n\n    uint32_t base_;\n\n    __device__ SmemDescIterV2(uint64_t desc): u64_{desc}, base_{u32_[0]} {}\n\n    __device__ void Advance(int stage)\n    {\n        u32_[0] += Step;\n        if (stage == Stages - 1) {\n            
u32_[0] = base_;\n        }\n    }\n\n    __device__ void Reset(int stage)\n    {\n        u32_[0] = base_ + stage * Step;\n    }\n\n    __device__ SmemDescIterV2& operator+=(int offset)\n    {\n        u32_[0] += offset;\n        return *this;\n    }\n\n    __device__ SmemDescIterV2& operator-=(int offset)\n    {\n        u32_[0] -= offset;\n        return *this;\n    }\n\n    __device__ operator uint64_t()\n    {\n        return u64_;\n    }\n};\n\ntemplate<class MMA_Atom, size_t... Is>\ninline __device__ void\nwgmma_impl(uint64_t desc_a, uint64_t desc_b, float* frag_C, bool clear, std::index_sequence<Is...>)\n{\n    return MMA_Atom::fma(desc_a, desc_b, frag_C[Is]..., clear ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One);\n}\n\ntemplate<class MMA_Atom, int N>\ninline __device__ void wgmma(uint64_t desc_a, uint64_t desc_b, float (&frag_C)[N], bool clear)\n{\n    return wgmma_impl<MMA_Atom>(desc_a, desc_b, frag_C, clear, std::make_index_sequence<N>{});\n}\n\ninline __device__ void warpgroup_fence_operand(float& reg)\n{\n    asm volatile(\"\" : \"+f\"(reg)::\"memory\");\n}\n\ntemplate<int M, int N, int K>\ninline __device__ void warpgroup_fence_operand(float (&x)[M][N][K])\n{\n    PRAGMA_UNROLL\n    for (int m = 0; m < M; ++m) {\n        PRAGMA_UNROLL\n        for (int n = 0; n < N; ++n) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < K; ++k) {\n                warpgroup_fence_operand(x[m][n][k]);\n            }\n        }\n    }\n}\n\ntemplate<int N, int K>\ninline __device__ void warpgroup_fence_operand(float (&x)[N][K])\n{\n    PRAGMA_UNROLL\n    for (int n = 0; n < N; ++n) {\n        PRAGMA_UNROLL\n        for (int k = 0; k < K; ++k) {\n            warpgroup_fence_operand(x[n][k]);\n        }\n    }\n}\n\ntemplate<class Func, size_t... 
Is>\n__device__ void for_(std::index_sequence<Is...>, Func func)\n{\n    return (func(constant<Is>{}), ...);\n}\n\nnamespace arch {\n\ntemplate<int M_, int N_, Order order>\nstruct Cluster {\n    static constexpr int M = M_;\n    static constexpr int N = N_;\n\n    static constexpr int C = mk2cs<order>(M, N).x;\n    static constexpr int S = mk2cs<order>(M, N).y;\n\n    static constexpr int size = M * N;\n\n    static constexpr uint16_t kMaskC = (1 << C) - 1;\n    static constexpr uint16_t kMaskS = ((1 << size) - 1) / kMaskC;\n\n    __device__ static ushort2 mask_cs(int cta_id)\n    {\n        const auto [c, s] = cta_cs(cta_id);\n        return make_ushort2(kMaskS << c, kMaskC << s * C);\n    }\n\n    __device__ static ushort2 mask_mn(int cta_id)\n    {\n        auto [c, s] = mask_cs(cta_id);\n        return order == kColMajor ? ushort2{c, s} : ushort2{s, c};\n    }\n\n    __device__ static int2 cta_cs(int cta_id)\n    {\n        return {C > 1 ? cta_id % C : 0, S > 1 ? cta_id / C : 0};\n    }\n\n    __device__ static int2 cta_mn(int cta_id)\n    {\n        return cs2mk<order>(cta_cs(cta_id));\n    }\n\n    int2    cta_mn_;\n    ushort2 mask_mn_;\n\n    __device__ explicit Cluster(int cta_id): cta_mn_(cta_mn(cta_id)), mask_mn_(mask_mn(cta_id)) {}\n\n    __device__ int cta_m()\n    {\n        return cta_mn_.x;\n    }\n\n    __device__ int cta_n()\n    {\n        return cta_mn_.y;\n    }\n\n    __device__ uint16_t mask_m()\n    {\n        return mask_mn_.x;\n    }\n\n    __device__ uint16_t mask_n()\n    {\n        return mask_mn_.y;\n    }\n};\n\n}  // namespace arch\n\nstruct GemmUniversalSm90_v2 {\n\n    static constexpr bool kDebug = false;\n\n    using Arch = Sm90;\n\n    // using MMA_Atom = GMMA::MMA_64x128x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>;\n    using MMA_Atom = GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN<>;\n    static constexpr typename cute::MMA_Traits<MMA_Atom>::Shape_MNK MMA_Shape{};\n\n    static constexpr int MMA_ATOM_M = 
cute::get<0>(MMA_Shape);\n    static constexpr int MMA_ATOM_N = cute::get<1>(MMA_Shape);\n    static constexpr int MMA_ATOM_K = cute::get<2>(MMA_Shape);\n\n    static constexpr int WARPGORUPS = 2;\n\n    static constexpr int TILE_M = 128;\n    static constexpr int TILE_N = MMA_ATOM_N;\n    static constexpr int TILE_K = 128;\n\n    static constexpr int MMA_ITER_M = TILE_M / MMA_ATOM_M;\n    static constexpr int MMA_ITER_N = TILE_N / MMA_ATOM_N;\n    static constexpr int MMA_ITER_K = TILE_K / MMA_ATOM_K;\n\n    static constexpr int kMulticastA = 1;\n    static constexpr int kMulticastB = 2;\n\n    static constexpr int kClusterSize = kMulticastA * kMulticastB;\n\n    static constexpr int Stages = 4;\n\n    static constexpr bool kSplitK     = false;\n    static constexpr int  kChunkSizeK = TILE_K;\n\n    static constexpr int WARPGROUP_SIZE = 128;\n\n    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);\n\n    using Ta = __nv_fp8_e4m3;\n    using Tb = __nv_fp8_e4m3;\n    using Tc = nv_bfloat16;\n\n    using Tu = float;\n    using Tv = float;\n\n    using Cluster = arch::Cluster<kMulticastB, kMulticastA, kRowMajor>;\n\n    using Scheduler = TileScheduler<kRowMajor, Cluster, false, false>;\n\n    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;\n    using ConsumerBar = cutlass::arch::ClusterBarrier;\n\n    static constexpr int MAX_K = 32768;\n\n    static constexpr int TILE_M_U = cdiv(TILE_M, 1);\n    static constexpr int CTA_K_U  = cdiv(TILE_K, 128);\n\n    static constexpr int kTmaTxBytes =\n        sizeof(Ta) * (TILE_M * TILE_K) + sizeof(Tb) * (TILE_N * TILE_K) + sizeof(Tu) * TILE_M_U * CTA_K_U;\n\n    // ! 
Smem addr must be SBO aligned for TMA load/store\n    struct SharedStorage {\n        struct Source {\n            __align__(1024) Array<Ta, Stages * TILE_M * TILE_K> A;\n            __align__(1024) Array<Tb, Stages * TILE_N * TILE_K> B;\n            __align__(1024) Tu U[Stages][round_up(TILE_M_U * CTA_K_U, 32)];\n            __align__(1024) Tv V[2][WARPGORUPS][cdiv(MAX_K, 128)];\n        };\n        Source source;\n        __align__(1024) Array<Tc, TILE_M * TILE_N> C;\n        __align__(128) uint64_t producer_bar[Stages];\n        __align__(128) uint64_t consumer_bar[Stages];\n        int pipe_count[WARPGORUPS];\n    };\n\n    static constexpr int kSmemSize = sizeof(SharedStorage);\n\n    static constexpr int kSwizzleC = 2 * std::gcd(TILE_N, 128 / sizeof(Tc));\n\n    using LayoutC = std::conditional_t<kSwizzleC >= 32,\n                                       SmemLayoutV2<TILE_M, TILE_N, -1, kSwizzleC / sizeof(Tc)>,\n                                       SmemLayoutV2<TILE_M, TILE_N>>;\n\n    __device__ void operator()(const CUtensorMap& tm_a,\n                               const CUtensorMap& tm_b,\n                               const CUtensorMap& tm_c,\n                               const CUtensorMap& tm_u,\n                               const CUtensorMap& tm_v,\n                               const void*        U_,\n                               int                ldU,\n                               const void*        V_,\n                               int                ldV,\n                               Scheduler          sched,\n                               char*              smem_buf)\n    {\n        SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        uint64_t* producer_bar = storage.producer_bar;\n        uint64_t* consumer_bar = storage.consumer_bar;\n\n        if (threadIdx.x == 0) {\n            PRAGMA_UNROLL\n            for (int s = 0; s < Stages; ++s) {\n                ProducerBar::init(&producer_bar[s], 1);\n    
            ConsumerBar::init(&consumer_bar[s], kClusterSize * 4);\n            }\n            cutlass::arch::fence_view_async_shared();\n            if constexpr (kClusterSize > 1) {\n                cutlass::arch::fence_barrier_init();\n            }\n            PRAGMA_UNROLL\n            for (int i = 0; i < WARPGORUPS; ++i) {\n                storage.pipe_count[i] = 0;\n            }\n        }\n\n        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();\n\n        const int warpgroup_id = cutlass::canonical_warp_group_idx();\n\n        if (warpgroup_id == WARPGORUPS) {\n            cutlass::arch::warpgroup_reg_dealloc<40>();\n\n            static_assert(TILE_M % kMulticastA == 0);\n            static_assert(TILE_N % kMulticastB == 0);\n\n            if (threadIdx.x == WARPGORUPS * WARPGROUP_SIZE) {\n\n                Cluster cluster(cute::block_id_in_cluster().x);\n\n                const int mc_offset_m = cluster.cta_n() * (TILE_M / kMulticastA);\n                const int mc_offset_n = cluster.cta_m() * (TILE_N / kMulticastB);\n\n                auto  smem_A = storage.source.A.data() + mc_offset_m * TILE_K;\n                auto  smem_B = storage.source.B.data() + mc_offset_n * TILE_K;\n                auto& smem_U = storage.source.U;\n\n                sched.grid_init();\n\n                cutlass::PipelineState<Stages> write_state{0, 1, 0};\n\n                while (sched.next()) {\n                    auto [valid_cta_tile_p, cluster_tile_p] = sched.is_valid_tile();\n\n                    if (!cluster_tile_p) {\n                        // OOB tile caused by swizzle pattern\n                        continue;\n                    }\n\n                    const auto tile_offset              = sched.tile_offset();\n                    const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();\n\n                    const int offset_k = iter_k_beg * TILE_K;\n\n                    const uint16_t mask_A = cluster.mask_m();\n                    const 
uint16_t mask_B = cluster.mask_n();\n\n                    const int offset_m = tile_offset.x * TILE_M;\n                    const int offset_n = tile_offset.y * TILE_N;\n\n                    int k_iter = iter_k_end - iter_k_beg;\n\n                    GmemIteratorSm90<kMulticastA> gmem_A{&tm_a, {offset_k, offset_m + mc_offset_m}, {TILE_K, 0}};\n                    GmemIteratorSm90<kMulticastB> gmem_B{&tm_b, {offset_k, offset_n + mc_offset_n}, {TILE_K, 0}};\n\n                    // column-major\n                    GmemIteratorSm90<kMulticastA> gmem_U{&tm_u, {offset_m + mc_offset_m, offset_k / 128}, {0, 1}};\n\n                    for (; k_iter > 0; --k_iter) {\n                        int pipe = write_state.index();\n                        ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());\n                        ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);\n                        gmem_A.Load(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);\n                        gmem_B.Load(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);\n                        gmem_U.Load(&producer_bar[pipe], &smem_U[pipe][0] + mc_offset_m, mask_A);\n                        ++write_state;\n                    }\n                }\n            }\n        }\n        else {\n            cutlass::arch::warpgroup_reg_alloc<232>();\n\n            sched.grid_init(WARPGORUPS);\n\n            auto& smem_A = storage.source.A;\n            auto& smem_B = storage.source.B;\n            auto& smem_U = storage.source.U;\n\n            auto smem_desc_A = make_smem_desc(&smem_A, 1);\n            auto smem_desc_B = make_smem_desc(&smem_B, 1);\n\n            SmemDescIterV2<Stages, ((sizeof(Ta) * TILE_M * TILE_K) >> 4)> smem_iter_A{smem_desc_A};\n            SmemDescIterV2<Stages, ((sizeof(Tb) * TILE_N * TILE_K) >> 4)> smem_iter_B{smem_desc_B};\n\n            constexpr int kStepMA = (sizeof(Ta) * MMA_ATOM_M * TILE_K) >> 4;\n            
constexpr int kStepNB = (sizeof(Tb) * MMA_ATOM_N * TILE_K) >> 4;\n            constexpr int kStepKA = (sizeof(Ta) * MMA_ATOM_K) >> 4;\n            constexpr int kStepKB = (sizeof(Tb) * MMA_ATOM_K) >> 4;\n\n            auto math_barrier_sync = [&](int phase, int alive = 1) {\n                constexpr int base    = (int)cutlass::arch::ReservedNamedBarriers::FirstUserBarrier;\n                constexpr int threads = WARPGORUPS * WARPGROUP_SIZE;\n                int           res;\n                asm volatile(\"{\\n\"\n                             \"  .reg.pred p;\\n\"\n                             \"  setp.ne.b32 p, %3, 0;\\n\"\n                             \"  barrier.cta.red.or.pred p, %1, %2, p;\\n\"\n                             \"  selp.s32 %0, 1, 0, p;\\n\"\n                             \"}\\n\"\n                             : \"=r\"(res)\n                             : \"r\"(base + warpgroup_id ^ phase), \"r\"(threads), \"r\"(alive));\n                return res;\n            };\n\n            cutlass::arch::NamedBarrier wg_barrier(WARPGROUP_SIZE, warpgroup_id + 2);  // 2,3\n\n            sched.next(warpgroup_id);\n\n            if (warpgroup_id == 1) {\n                math_barrier_sync(1);\n            }\n\n            while (sched.next(WARPGORUPS)) {\n                auto [cta_tile_p, cluster_tile_p] = sched.is_valid_tile();\n\n                if (!cluster_tile_p) {\n                    // OOB tile caused by swizzle pattern\n                    continue;\n                }\n\n                MMA_Atom::CRegisters frag_C[MMA_ITER_M][MMA_ITER_N];\n                MMA_Atom::CRegisters accum_C[MMA_ITER_M][MMA_ITER_N]{};\n\n                const auto tile_offset              = sched.tile_offset();\n                const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();\n\n                const auto [M, N, K, L] = sched.gemm_shape();\n\n                const int offset_m = tile_offset.x * TILE_M;\n                const int offset_n = tile_offset.y * TILE_N;\n 
               const int offset_k = 0;\n\n                int k_iter = iter_k_end - iter_k_beg;\n\n                const int warp_id = threadIdx.x / WARP_SIZE;\n                const int lane_id = threadIdx.x % WARP_SIZE;\n\n                const int wg_lane = threadIdx.x % WARPGROUP_SIZE;\n\n                cutlass::PipelineState<Stages> pipe_state{};\n\n                auto consumer_arrive = [&] {\n                    __syncwarp();\n                    if constexpr (kClusterSize > 1) {\n                        ConsumerBar::arrive(&consumer_bar[pipe_state.index()], lane_id, lane_id < kClusterSize);\n                    }\n                    else {\n                        if (lane_id == 0) {\n                            ConsumerBar::arrive(&consumer_bar[pipe_state.index()]);\n                        }\n                    }\n                };\n\n                if constexpr (kClusterSize > 1) {\n                    if (!cta_tile_p) {  // other CTAs in the cluster are still alive\n                        math_barrier_sync(0);\n                        pipe_state.advance(storage.pipe_count[warpgroup_id ^ 1]);\n                        for (; k_iter > 0; --k_iter) {\n                            ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                            consumer_arrive();\n                            ++pipe_state;\n                        }\n                        if (wg_lane == 0) {\n                            storage.pipe_count[warpgroup_id] = pipe_state.count();\n                        }\n                        math_barrier_sync(1);\n                        continue;\n                    }\n                }\n\n                auto Copy = [k = cdiv(K, 128)](Tv* dst, const Tv* src) {\n                    for (int i = threadIdx.x % WARPGROUP_SIZE; i < k; i += WARPGROUP_SIZE) {\n                        dst[i] = __ldg(&src[i]);\n                    }\n                };\n                auto gmem_V = (const Tv*)V_ + 
(offset_n / 128) * ldV + (offset_k / 128);\n                Copy(storage.source.V[0][warpgroup_id], gmem_V);\n\n                uint32_t pred_V{};\n                int      iter_V{};\n\n                constexpr int OUTER_N = std::gcd(MMA_ATOM_N, 128);\n                if constexpr (OUTER_N != 128) {\n\n                    static_assert(MMA_ATOM_N <= 128 + OUTER_N, \"MMA inst is crossing more than 2 scale blocks\");\n\n                    constexpr uint32_t mask = (1UL << (TILE_M / OUTER_N)) - 1;\n\n                    int phase = 128 - offset_n % 128;\n                    pred_V    = (mask << (phase / OUTER_N)) & mask;\n\n                    if (pred_V && offset_n / 128 + 1 < cdiv(N, 128)) {\n                        Copy(storage.source.V[1][warpgroup_id], gmem_V + ldV);\n                    }\n\n                    // if constexpr (kWorkGroupN > 1) {\n                    //     constexpr int tiles = MMA_ATOM_N / OUTER_N;\n                    //     pred_V              = (pred_V >> (warp_group_id_n * tiles)) & ((1 << tiles) - 1);\n                    // }\n                }\n\n                float scale_V[2];\n                auto  Load_V = [&] {\n                    scale_V[0] = storage.source.V[0][warpgroup_id][iter_V];\n                    if (pred_V) {\n                        scale_V[1] = storage.source.V[1][warpgroup_id][iter_V];\n                    }\n                    ++iter_V;\n                };\n\n                float     scale_U[MMA_ITER_M][2];\n                const int offset_U = warp_id % 4 * 16 + lane_id / 4;\n                auto      Load_U   = [&] {\n                    for (int m = 0; m < MMA_ITER_M; ++m) {\n                        scale_U[m][0] = smem_U[pipe_state.index()][offset_U + m * MMA_ATOM_M];\n                        scale_U[m][1] = smem_U[pipe_state.index()][offset_U + m * MMA_ATOM_M + 8];\n                    }\n                };\n\n                auto scale_accum = [&](int m) {  // cta_n = mma_iter_n * wg_n * mma_atom_n\n      
              float scales[2][2];\n                    scales[0][0] = scale_U[m][0] * scale_V[0];\n                    scales[1][0] = scale_U[m][1] * scale_V[0];\n                    scales[0][1] = scale_U[m][0] * scale_V[1];\n                    scales[1][1] = scale_U[m][1] * scale_V[1];\n                    PRAGMA_UNROLL\n                    for (int n = 0; n < MMA_ITER_N; ++n) {\n                        PRAGMA_UNROLL\n                        for (int c0 = 0; c0 < MMA_ATOM_N; c0 += OUTER_N) {\n                            bool pred = (pred_V & (1U << (c0 / OUTER_N)));\n                            PRAGMA_UNROLL\n                            for (int cc = 0; cc < OUTER_N; cc += 8) {\n                                int c = c0 + cc;\n                                // clang-format off\n                                accum_C[m][n][c / 2 + 0] += (pred ? scales[0][1] : scales[0][0]) * frag_C[m][n][c / 2 + 0];\n                                accum_C[m][n][c / 2 + 1] += (pred ? scales[0][1] : scales[0][0]) * frag_C[m][n][c / 2 + 1];\n                                accum_C[m][n][c / 2 + 2] += (pred ? scales[1][1] : scales[1][0]) * frag_C[m][n][c / 2 + 2];\n                                accum_C[m][n][c / 2 + 3] += (pred ? 
scales[1][1] : scales[1][0]) * frag_C[m][n][c / 2 + 3];\n                                // clang-format on\n                            }\n                        }\n                    }\n\n                };\n\n                auto gmma = [&](int m) {\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < MMA_ITER_K; ++k) {\n                        PRAGMA_UNROLL\n                        for (int n = 0; n < MMA_ITER_N; ++n) {\n                            wgmma<MMA_Atom>(smem_iter_A, smem_iter_B, frag_C[m][n], k == 0);\n                            smem_iter_B += kStepNB;\n                        }\n                        smem_iter_B -= MMA_ITER_N * kStepNB;\n                        smem_iter_A += kStepKA;\n                        smem_iter_B += kStepKB;\n                    }\n                    smem_iter_A -= MMA_ITER_K * kStepKA;\n                    smem_iter_B -= MMA_ITER_K * kStepKB;\n                    smem_iter_A += kStepMA;\n                    cute::warpgroup_commit_batch();\n                };\n\n                static_assert(MMA_ITER_N == 1);\n\n                math_barrier_sync(0);\n\n                pipe_state.advance(storage.pipe_count[warpgroup_id ^ 1]);\n\n                smem_iter_A.Reset(pipe_state.index());\n                smem_iter_B.Reset(pipe_state.index());\n                Load_V();\n                ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                Load_U();\n                cute::warpgroup_arrive();\n                gmma(0);\n                gmma(1);\n                cute::warpgroup_wait<1>();\n                scale_accum(0);\n                cute::warpgroup_wait<0>();\n                scale_accum(1);\n                consumer_arrive();\n                ++pipe_state;\n                --k_iter;\n\n                Load_V();\n                ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                Load_U();\n                
smem_iter_A.Reset(pipe_state.index());\n                smem_iter_B.Reset(pipe_state.index());\n\n                for (; k_iter > 1; --k_iter) {\n                    cute::warpgroup_arrive();\n                    gmma(0);\n                    gmma(1);\n                    cute::warpgroup_wait<1>();\n                    scale_accum(0);\n                    cute::warpgroup_wait<0>();\n                    scale_accum(1);\n                    consumer_arrive();\n                    ++pipe_state;\n                    Load_V();\n                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                    Load_U();\n                    smem_iter_A.Reset(pipe_state.index());\n                    smem_iter_B.Reset(pipe_state.index());\n                }\n\n                cute::warpgroup_arrive();\n                gmma(0);\n                gmma(1);\n                cute::warpgroup_wait<1>();\n                scale_accum(0);\n                cute::warpgroup_wait<0>();\n                scale_accum(1);\n                consumer_arrive();\n                ++pipe_state;\n\n                if (wg_lane == 0) {\n                    storage.pipe_count[warpgroup_id] = pipe_state.count();\n                }\n\n                math_barrier_sync(1);\n\n                // epilogue\n                PRAGMA_UNROLL\n                for (int m = 0; m < MMA_ITER_M; ++m) {\n                    PRAGMA_UNROLL\n                    for (int n = 0; n < MMA_ITER_N; ++n) {\n\n                        constexpr int N       = LayoutC::C0;\n                        constexpr int SW_bits = log2(kSwizzleC / 16);\n\n                        static_assert(!SW_bits || MMA_ATOM_N % LayoutC::C0 == 0);\n\n                        const int m0 = m * MMA_ATOM_M;\n                        const int n0 = n * MMA_ATOM_N;\n\n                        PRAGMA_UNROLL\n                        for (int i = 0; i < MMA_ATOM_N; i += 16) {\n                            __align__(16) Array<Tc, 8> 
tvec = cast<Tc>(*(Array<float, 8>*)&accum_C[m][n][i / 2]);\n\n                            int mm = m0 + warp_id % 4 * 16 + (lane_id & 8);\n                            int nn = n0 + i / N * N;\n\n                            int addr = ((nn / N) * TILE_M * N) + (mm * N) + (nn % N);\n\n                            int s = lane_id % 8;\n                            int c = (lane_id & 16) / 2 + i % N;\n\n                            addr += Swizzle<SW_bits, 3, 3>::apply(s * N + c);\n\n                            auto& uvec = (Array<uint32_t, 4>&)tvec;\n                            cute::SM90_U32x4_STSM_N::copy(\n                                uvec[0], uvec[1], uvec[2], uvec[3], (cutlass::uint128_t&)storage.C[addr]);\n                        }\n                    }\n                }\n\n                cute::tma_store_fence();  // visibility: smem -> async proxy\n\n                wg_barrier.sync();\n\n                const int wg_thread_id = threadIdx.x % WARPGROUP_SIZE;\n\n                if (wg_thread_id < LayoutC::C1) {\n                    const int tma_n = wg_thread_id * LayoutC::C0;\n                    cute::SM90_TMA_STORE::copy(\n                        &tm_c, &storage.C[wg_thread_id * TILE_M * LayoutC::C0], offset_n + tma_n, offset_m);\n                    cute::tma_store_arrive();\n                    cute::tma_store_wait<0>();\n                }\n\n                wg_barrier.sync();\n\n            }  // scheduler loop\n\n            if (warpgroup_id == 0) {\n                math_barrier_sync(0, 0);\n                while (math_barrier_sync(1, 0)) {\n                    math_barrier_sync(0, 0);\n                }\n            }\n            else {\n                while (math_barrier_sync(0, 0)) {\n                    math_barrier_sync(1, 0);\n                }\n            }\n\n            if (threadIdx.x % WARPGROUP_SIZE < LayoutC::C1) {\n                cute::tma_store_wait<0>();\n            }\n        }\n\n        if constexpr (kClusterSize > 1) {\n         
   cute::cluster_arrive();\n            cute::cluster_wait();\n        }\n\n    }  // operator()\n};\n\nextern __shared__ char smem_buf[];\n\ntemplate<class Kernel>\n__global__ void __launch_bounds__(Kernel::CTA_SIZE, 1) gemm_kernel_sm90(const __grid_constant__ CUtensorMap tm_a,\n                                                                        const __grid_constant__ CUtensorMap tm_b,\n                                                                        const __grid_constant__ CUtensorMap tm_c,\n                                                                        const __grid_constant__ CUtensorMap tm_u,\n                                                                        const __grid_constant__ CUtensorMap tm_v,\n                                                                        const void*                         U_,\n                                                                        int                                 ldU,\n                                                                        const void*                         V_,\n                                                                        int                                 ldV,\n                                                                        typename Kernel::Scheduler          sched)\n{\n#if __CUDA_ARCH__\n    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {\n        Kernel kernel;\n        kernel(tm_a, tm_b, tm_c, tm_u, tm_v, U_, ldU, V_, ldV, sched, smem_buf);\n    }\n#endif\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gemm_universal_sm90_v3.h",
    "content": "#pragma once\n\n#include <numeric>\n#include <utility>\n\n#include <cuda_fp8.h>\n#include <cuda_pipeline_primitives.h>\n\n#include \"cute/arch/cluster_sm90.hpp\"\n#include \"cute/arch/copy_sm90.hpp\"\n#include \"cute/arch/copy_sm90_desc.hpp\"\n#include \"cute/arch/copy_sm90_tma.hpp\"\n#include \"cute/arch/mma_sm90_desc.hpp\"\n\n#include \"cutlass/arch/barrier.h\"\n#include \"cutlass/arch/reg_reconfig.h\"\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/pipeline/sm90_pipeline.hpp\"\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n\n#include \"src/turbomind/kernels/core/smem.h\"\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/cp_async.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm90.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/scheduler.cuh\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\n#include \"src/turbomind/kernels/gemm/scaled_gmma_fp8_sm90.h\"\n#include \"src/turbomind/kernels/gemm/sm90_utils.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<Order raster_order, int multicast_a, int multicast_b, bool is_grouped_gemm_>\nstruct GemmUniversalSm90_v3 {\n\n    static constexpr bool kDebug = false;\n\n    using Arch = Sm90;\n\n    static constexpr int TILE_M = 128;\n    static constexpr int TILE_N = 192;\n    static constexpr int TILE_K = 128;\n\n    static constexpr int WG_M = 2;\n    static constexpr int WG_N = 1;\n\n    static constexpr int WG_TILE_M = TILE_M / WG_M;\n    static constexpr int WG_TILE_N = TILE_N / WG_N;\n\n    static constexpr int kSchedWarpGroups = 1;\n\n    static constexpr int WARPGORUPS = WG_M * WG_N;\n\n    using GMMA = ScaledGmmaFP8_TN<WG_TILE_M, WG_TILE_N, TILE_K, 1, 1, 1, 1>;\n\n    static constexpr int kMulticastA = 
multicast_a;\n    static constexpr int kMulticastB = multicast_b;\n\n    static constexpr int kClusterSize = kMulticastA * kMulticastB;\n\n    static constexpr int Stages = 4;\n\n    static constexpr bool kSplitK     = false;\n    static constexpr int  kChunkSizeK = TILE_K;\n\n    static constexpr int WARPGROUP_SIZE = 128;\n\n    static constexpr int kMathGroupSize = WARPGROUP_SIZE * WARPGORUPS;\n\n    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);\n\n    using Ta = __nv_fp8_e4m3;\n    using Tb = __nv_fp8_e4m3;\n    using Tc = nv_bfloat16;\n\n    using Tu = float;\n    using Tv = float;\n\n    using Cluster = arch::Cluster<kMulticastB, kMulticastA, kRowMajor>;\n\n    static constexpr auto is_grouped_gemm = is_grouped_gemm_;\n\n    using Scheduler = TileScheduler<raster_order, Cluster, true, true, TILE_M, TILE_N, Stages, is_grouped_gemm>;\n\n    static constexpr int kMulticastU = is_grouped_gemm ? 1 : kMulticastA;\n\n    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;\n    using ConsumerBar = cutlass::arch::ClusterBarrier;\n\n    static constexpr int kAlignmentU = 16 / sizeof(Tu);\n    static constexpr int kBoxU       = TILE_M + (is_grouped_gemm ? kAlignmentU : 0);\n\n    // Alignment requirement for SMEM addr. This forbids multicast factor 8.\n    static_assert(kMulticastU == 1 || sizeof(Tu) * kBoxU / kMulticastU % 128 == 0);\n\n    static constexpr int kTmaTxBytes =\n        sizeof(Ta) * (TILE_M * TILE_K) + sizeof(Tb) * (TILE_N * TILE_K) + sizeof(Tu) * kBoxU;\n\n    // ! 
SMEM addr must be SBO aligned for TMA load/store\n    struct SharedStorage {\n        __align__(1024) Array<Ta, Stages * TILE_M * TILE_K> A;\n        __align__(1024) Array<Tb, Stages * TILE_N * TILE_K> B;\n        __align__(1024) Array<Tc, TILE_M * TILE_N> C;\n        __align__(128) Tu U[Stages][round_up<int>(kBoxU, 128)];  // at least 128 byte alignment\n        __align__(128) Tv V[Stages][2];\n        __align__(128) CUtensorMap tensor_map[5];\n        __align__(8) uint64_t producer_bar[Stages];\n        __align__(8) uint64_t consumer_bar[Stages];\n        typename Scheduler::Storage sched;\n    };\n\n    static constexpr int kSmemSize = sizeof(SharedStorage);\n\n    static constexpr int kSwizzleC = 2 * std::gcd(WG_TILE_N, 128 / sizeof(Tc));\n\n    using LayoutC = std::conditional_t<kSwizzleC >= 32,\n                                       SmemLayoutV2<WG_TILE_M, WG_TILE_N, -1, kSwizzleC / sizeof(Tc)>,\n                                       SmemLayoutV2<WG_TILE_M, WG_TILE_N>>;\n\n    static constexpr int OUTER_N       = GMMA::OUTER_N;\n    static constexpr int MMA_SUBTILE_N = GMMA::OP_N / OUTER_N;\n\n    __device__ void operator()(const CUtensorMap& tm_a,\n                               const CUtensorMap& tm_b,\n                               const CUtensorMap& tm_c,\n                               const CUtensorMap& tm_u,\n                               const CUtensorMap& tm_v,\n                               const MatrixParam& param_A,\n                               const MatrixParam& param_B,\n                               const MatrixParam& param_U,\n                               const MatrixParam& param_V,\n                               const MatrixParam& param_C,\n                               Scheduler          sched,\n                               CUtensorMap*       tensormap_buf,\n                               char*              smem_buf)\n    {\n        SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        uint64_t* 
producer_bar = storage.producer_bar;\n        uint64_t* consumer_bar = storage.consumer_bar;\n\n        if (threadIdx.x == 0) {\n            PRAGMA_UNROLL\n            for (int s = 0; s < Stages; ++s) {\n                ProducerBar::init(&producer_bar[s], 1 + 1);\n                ConsumerBar::init(&consumer_bar[s], WARPGORUPS * kClusterSize * 4);\n            }\n            sched.init_dyanmic(storage.sched, kClusterSize * (WARPGORUPS * 4 + 1));\n            cutlass::arch::fence_view_async_shared();\n            if constexpr (kClusterSize > 1) {\n                cutlass::arch::fence_barrier_init();\n            }\n        }\n\n        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();\n\n        const int wg_idx = cutlass::canonical_warp_group_idx();\n\n        if (wg_idx == WARPGORUPS) {\n            cutlass::arch::warpgroup_reg_dealloc<40>();\n\n            static_assert(TILE_M % kMulticastA == 0);\n            static_assert(TILE_N % kMulticastB == 0);\n\n            cutlass::arch::NamedBarrier producers_bar(WARP_SIZE * 2, 7);\n\n            const int  warp_id = cutlass::canonical_warp_idx_sync();\n            const bool cta_0   = cute::block_id_in_cluster().x == 0;\n\n            if (warp_id % 4 == 0) {\n\n                Cluster cluster(cute::block_id_in_cluster().x);\n\n                const int mc_offset_m = cluster.cta_n() * (TILE_M / kMulticastA);\n                const int mc_offset_n = cluster.cta_m() * (TILE_N / kMulticastB);\n\n                auto  smem_A = storage.A.data() + mc_offset_m * TILE_K;\n                auto  smem_B = storage.B.data() + mc_offset_n * TILE_K;\n                auto& smem_U = storage.U;\n                auto& smem_V = storage.V;\n\n                if constexpr (is_grouped_gemm) {\n                    init_tma_descs<3>({&tm_a, &tm_b, &tm_u}, storage.tensor_map);\n                }\n\n                cutlass::PipelineState<Stages> write_state{0, 1, 0};\n\n                auto sched_state = 
sched.init_consumer(storage.sched);\n\n                int lane_predicate = cute::elect_one_sync();\n\n                typename Scheduler::Tile* tile;\n\n                while (sched_state.acquire(tile)) {\n\n                    if (tile->is_valid_cluster) {\n\n                        const CUtensorMap* Adesc = &tm_a;\n                        const CUtensorMap* Bdesc = &tm_b;\n                        const CUtensorMap* Udesc = &tm_u;\n\n                        const Tv* gmem_V0 = (const Tv*)param_V.ptr;\n                        const Tv* gmem_V1;\n\n                        if constexpr (is_grouped_gemm) {\n                            const int g  = tile->group_idx;\n                            const int m0 = tile->m0;\n                            const int m1 = tile->m1;\n                            const int m  = m1 - m0;\n\n                            Array<void*, 3> global_addrs;\n                            global_addrs[0] = (Ta*)param_A.ptr + m0 * (int64_t)param_A.stride;\n                            global_addrs[1] = ((void**)param_B.ptr)[g];\n\n                            const int beg_u = m0 / kAlignmentU * kAlignmentU;\n                            const int end_u = round_up(m1, kAlignmentU);\n                            global_addrs[2] = (Tu*)param_U.ptr + beg_u;\n\n                            Array<int, 3> dims;\n                            dims[0] = m;\n                            dims[1] = sched.gemm_shape().y;\n                            dims[2] = end_u - beg_u;\n\n                            auto descs = update_tma_descs(tensormap_buf, storage.tensor_map, global_addrs, dims);\n                            Adesc      = &descs[0];\n                            Bdesc      = &descs[1];\n                            Udesc      = &descs[2];\n\n                            gmem_V0 = ((Tv**)gmem_V0)[g];\n\n                            PRAGMA_UNROLL\n                            for (int i = 0; i < 3; ++i) {\n                                
cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)&descs[i]);\n                            }\n                        }\n\n                        if (lane_predicate) {\n                            const int offset_k = 0;\n\n                            const uint16_t mask_A = cluster.mask_m();\n                            const uint16_t mask_B = cluster.mask_n();\n\n                            const int offset_m = tile->offset_m;\n                            const int offset_n = tile->offset_n;\n\n                            int k_iter = sched.k_iters_;\n\n                            GmemIteratorSm90<kMulticastA> gmem_A{\n                                Adesc, {offset_k, offset_m + mc_offset_m}, {TILE_K, 0}};\n                            GmemIteratorSm90<kMulticastB> gmem_B{\n                                Bdesc, {offset_k, offset_n + mc_offset_n}, {TILE_K, 0}};\n\n                            const int mc_offset_u = kMulticastU > 1 ? mc_offset_m : 0;\n                            // column-major\n                            GmemIteratorSm90<kMulticastU> gmem_U{\n                                Udesc, {offset_m + mc_offset_u, offset_k / 128}, {0, 1}};\n\n                            gmem_V0 += (offset_n / 128) * param_V.stride + (offset_k / 128);\n                            gmem_V1 = gmem_V0;\n                            if (offset_n / 128 + 1 < cdiv(sched.gemm_shape().y, 128)) {\n                                gmem_V1 += param_V.stride;\n                            }\n\n                            for (; k_iter > 0; --k_iter) {\n                                int pipe = write_state.index();\n                                ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());\n                                ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);\n                                gmem_A.Step(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);\n                                gmem_B.Step(&producer_bar[pipe], 
&smem_B[pipe * TILE_N * TILE_K], mask_B);\n                                gmem_U.Step(&producer_bar[pipe], smem_U[pipe] + mc_offset_u, mask_A);\n                                uint32_t uint_ptr_V = cast_smem_ptr_to_uint(smem_V[pipe]);\n                                CP_ASYNC<CacheOp::kAlways, 4, 0>::apply(uint_ptr_V, gmem_V0, true);\n                                CP_ASYNC<CacheOp::kAlways, 4, 0>::apply(uint_ptr_V + sizeof(Tv), gmem_V1, true);\n                                ++gmem_V0;\n                                ++gmem_V1;\n                                cutlass::arch::cpasync_barrier_arrive_noinc(&producer_bar[pipe]);\n                                ++write_state;\n                            }\n                        }\n                    }\n\n                    if constexpr (Scheduler::is_dynamic) {\n                        if (cta_0) {\n                            producers_bar.arrive_unaligned();\n                        }\n                    }\n\n                    sched_state.release();\n\n                }  // scheduler loop\n\n                // release last tile\n                sched_state.release();\n\n                if constexpr (kClusterSize > 1) {\n                    if (lane_predicate) {\n                        for (int i = 0; i < Stages; ++i) {\n                            ConsumerBar::wait(&consumer_bar[write_state.index()], write_state.phase());\n                            ++write_state;\n                        }\n                    }\n                    __syncwarp();\n                }\n            }\n            else if (warp_id % 4 == 1 && cta_0) {\n                auto state = sched.init_producer(storage.sched);\n                while (state.next()) {\n                    if constexpr (Scheduler::is_dynamic) {\n                        producers_bar.arrive_and_wait_unaligned();\n                    }\n                }\n                sched.tail(state);\n            }\n        }\n        else {\n            
cutlass::arch::warpgroup_reg_alloc<232>();\n\n            if constexpr (is_grouped_gemm) {\n                if (threadIdx.x % WARPGROUP_SIZE / WARP_SIZE == 0) {\n                    init_tma_descs<1>({&tm_c}, storage.tensor_map + 3 + wg_idx);\n                }\n            }\n\n            auto& smem_A = storage.A;\n            auto& smem_B = storage.B;\n            auto& smem_U = storage.U;\n            auto& smem_V = storage.V;\n\n            const int wg_idx_m = WG_M > 1 ? wg_idx % WG_M : 0;\n            const int wg_idx_n = WG_N > 1 ? wg_idx / WG_M : 0;\n\n            auto smem_desc_A = make_smem_desc(&smem_A[wg_idx_m * WG_TILE_M * TILE_K], 1);\n            auto smem_desc_B = make_smem_desc(&smem_B[wg_idx_n * WG_TILE_N * TILE_K], 1);\n\n            SmemDescIterV2<Stages, ((TILE_M * TILE_K) >> 4)> smem_iter_A{smem_desc_A};\n            SmemDescIterV2<Stages, ((TILE_N * TILE_K) >> 4)> smem_iter_B{smem_desc_B};\n\n            cutlass::arch::NamedBarrier barrier(WARPGROUP_SIZE, 2 + wg_idx);  // 0, 1\n\n            cutlass::PipelineState<Stages> pipe_state{};\n\n            const int warp_id = cutlass::canonical_warp_idx_sync();\n            const int lane_id = cutlass::canonical_lane_idx();\n\n            auto consumer_arrive = [&] {\n                auto bar = &consumer_bar[pipe_state.index()];\n                __syncwarp();\n                if constexpr (kClusterSize > 1) {\n                    ConsumerBar::arrive(bar, lane_id, lane_id < kClusterSize);\n                }\n                else {\n                    if (lane_id == 0) {\n                        ConsumerBar::arrive(bar);\n                    }\n                }\n            };\n\n            auto sched_state = sched.init_consumer(storage.sched);\n\n            typename Scheduler::Tile* tile;\n\n            sched_state.acquire(tile);\n\n            while (tile->alive) {\n\n                if (tile->is_valid_cta) {\n                    GMMA::AccumC accum_C{};\n                    GMMA::FragC  
frag_C;\n\n                    auto pred_V = Fetch_V(tile, wg_idx_n);\n\n                    float scale_V[2];\n                    auto  Load_V = [&] {\n                        scale_V[0] = smem_V[pipe_state.index()][0];\n                        scale_V[1] = smem_V[pipe_state.index()][1];\n                    };\n\n                    int offset_U = wg_idx_m * WG_TILE_M + warp_id % 4 * 16 + lane_id / 4;\n                    if constexpr (is_grouped_gemm) {\n                        offset_U += tile->m0 % kAlignmentU;\n                    }\n                    GMMA::FragU frag_U;\n                    auto        Load_U = [&] {\n                        GMMA::foreach_m(frag_U, [&](auto& U, int m) {\n                            U[0] = smem_U[pipe_state.index()][offset_U + m * GMMA::OP_M];\n                            U[1] = smem_U[pipe_state.index()][offset_U + m * GMMA::OP_M + 8];\n                        });\n                    };\n\n                    auto gmma = [&] {  //\n                        GMMA::apply(smem_iter_A, smem_iter_B, frag_C, accum_C, frag_U, scale_V, pred_V);\n                    };\n\n                    if constexpr (is_grouped_gemm) {\n                        if (threadIdx.x % WARPGROUP_SIZE < LayoutC::C1) {\n                            cute::tma_store_wait<0>();\n                        }\n                        // No need to sync here as the update is warp synchronized\n                        if (warp_id % 4 == 0) {\n                            int  m0 = tile->m0, m1 = tile->m1;\n                            auto global_addr = (Tc*)param_C.ptr + m0 * (int64_t)param_C.stride;\n                            int  idx         = 3 + wg_idx;\n                            update_tma_descs<1>(\n                                tensormap_buf + idx, storage.tensor_map + idx, {global_addr}, {m1 - m0});\n                        }\n                        barrier.sync();\n                    }\n\n                    int k_iter = sched.k_iters_;\n\n         
           ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                    Load_V();\n                    Load_U();\n                    smem_iter_A.Reset(pipe_state.index());\n                    smem_iter_B.Reset(pipe_state.index());\n                    gmma();\n                    consumer_arrive();\n                    ++pipe_state;\n                    --k_iter;\n\n                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                    Load_V();\n                    Load_U();\n                    smem_iter_A.Reset(pipe_state.index());\n                    smem_iter_B.Reset(pipe_state.index());\n\n                    PRAGMA_NO_UNROLL\n                    for (; k_iter > 1; --k_iter) {\n                        gmma();\n                        consumer_arrive();\n                        ++pipe_state;\n                        ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                        Load_V();\n                        Load_U();\n                        smem_iter_A.Reset(pipe_state.index());\n                        smem_iter_B.Reset(pipe_state.index());\n                    }\n\n                    gmma();\n\n                    const int thread_idx = threadIdx.x % WARPGROUP_SIZE;\n                    if constexpr (!is_grouped_gemm) {\n                        if (thread_idx < LayoutC::C1) {\n                            cute::tma_store_wait<0>();\n                        }\n                        barrier.sync();\n                    }\n\n                    consumer_arrive();\n                    ++pipe_state;\n\n                    Tc* smem_C = &storage.C[wg_idx_m * WG_TILE_M * TILE_N + wg_idx_n * WG_TILE_N];\n\n                    // epilogue\n                    GMMA::foreach_C(accum_C, [&](const auto& C, int m, int n) {\n                        constexpr int N       = LayoutC::C0;\n                        constexpr int SW_bits = log2(kSwizzleC / 
16);\n\n                        static_assert(!SW_bits || GMMA::OP_N % LayoutC::C0 == 0);\n                        static_assert(GMMA::OP_N % 16 == 0);\n\n                        const int m0 = m * GMMA::OP_M;\n                        const int n0 = n * GMMA::OP_N;\n\n                        PRAGMA_UNROLL\n                        for (int i = 0; i < GMMA::OP_N; i += 16) {\n                            __align__(16) Array<Tc, 8> tvec = cast<Tc>((Array<float, 8>&)C[i / 2]);\n\n                            // fill(tvec, Tc(255));\n\n                            int mm = m0 + warp_id % 4 * 16 + (lane_id & 8);\n                            int nn = n0 + i / N * N;\n\n                            int addr = ((nn / N) * WG_TILE_M * N) + (mm * N) + (nn % N);\n\n                            int s = lane_id % 8;\n                            int c = (lane_id & 16) / 2 + i % N;\n\n                            addr += Swizzle<SW_bits, 3, 3>::apply(s * N + c);\n\n                            auto& uvec = (Array<uint32_t, 4>&)tvec;\n                            cute::SM90_U32x4_STSM_N::copy(uvec[0],  //\n                                                          uvec[1],\n                                                          uvec[2],\n                                                          uvec[3],\n                                                          (cutlass::uint128_t&)smem_C[addr]);\n                        }\n                    });\n\n                    cute::tma_store_fence();  // visibility: smem -> async proxy\n\n                    barrier.sync();\n\n                    const int offset_m = tile->offset_m;\n                    const int offset_n = tile->offset_n;\n\n                    const void* Cdesc = &tm_c;\n\n                    if (thread_idx < LayoutC::C1) {\n                        const int tma_n = thread_idx * LayoutC::C0;\n                        if constexpr (is_grouped_gemm) {\n                            Cdesc = tensormap_buf + blockIdx.x * 5 + 3 + 
wg_idx;\n                            cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)Cdesc);\n                        }\n                        cute::SM90_TMA_STORE::copy(Cdesc,\n                                                   &smem_C[thread_idx * WG_TILE_M * LayoutC::C0],\n                                                   offset_n + wg_idx_n * WG_TILE_N + tma_n,\n                                                   offset_m + wg_idx_m * WG_TILE_M);\n                        cute::tma_store_arrive();\n                    }\n                }\n                else if (tile->is_valid_cluster) {\n                    int k_iter = sched.k_iters_;\n                    for (; k_iter > 0; --k_iter) {\n                        ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                        consumer_arrive();\n                        ++pipe_state;\n                    }\n                }\n\n                sched_state.release();\n                sched_state.acquire(tile);\n\n            }  // scheduler loop\n\n            // release last tile\n            sched_state.release();\n\n            if (threadIdx.x % WARPGROUP_SIZE < LayoutC::C1) {\n                cute::tma_store_wait<0>();\n            }\n        }\n\n    }  // operator()\n\n    template<int N>\n    __device__ void init_tma_descs(Array<const CUtensorMap*, N> param_desc, CUtensorMap* smem_desc)\n    {\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        if (lane_id < sizeof(CUtensorMap) / sizeof(uint2)) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < N; ++i) {\n                ((uint2*)&smem_desc[i])[lane_id] = ((uint2*)param_desc[i])[lane_id];\n            }\n        }\n\n        __syncwarp();\n    }\n\n    template<int N>\n    __device__ CUtensorMap*\n    update_tma_descs(CUtensorMap* gmem_desc, CUtensorMap* smem_desc, Array<void*, N> global_addrs, Array<int, N> dims)\n    {\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        if 
(lane_id == 0) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < N; ++i) {\n                uint32_t uint_ptr = cast_smem_ptr_to_uint(&smem_desc[i]);\n                // clang-format off\n                asm volatile(\"tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;\" ::\"r\"(uint_ptr), \"l\"(global_addrs[i]));\n                if (i != 2) {\n                    asm volatile(\"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;\" ::\"r\"(uint_ptr), \"r\"(dims[i]));\n                } else { // special case for U\n                    asm volatile(\"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;\" ::\"r\"(uint_ptr), \"r\"(dims[i]));\n                }\n                // clang-format on\n            }\n        }\n\n        __syncwarp();\n\n        constexpr int kNumPerCta = 5;  // a,b,u,c0,c1\n        auto          gmem_ptr   = &gmem_desc[blockIdx.x * kNumPerCta];\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            uint32_t uint_ptr = cast_smem_ptr_to_uint(&smem_desc[i]);\n            // clang-format off\n            asm volatile(\"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;\" :: \"l\"(gmem_ptr + i), \"r\"(uint_ptr));\n            // clang-format on\n        }\n\n        return gmem_ptr;\n    }\n\n    __device__ auto Fetch_V(typename Scheduler::Tile* tile, int wg_idx_n)\n    {\n        constexpr int BLK_SUBTILE_N = 128 / OUTER_N;\n        static_assert(MMA_SUBTILE_N - 1 < BLK_SUBTILE_N + 1);  // n1 - 1 + n0 - 1 < 2 * n0\n\n        Array<bool, MMA_SUBTILE_N> pred_V{};\n        if constexpr (MMA_SUBTILE_N != 1) {\n            int offset = tile->offset_n % 128 + wg_idx_n * WG_TILE_N;\n            static_assert(WG_N == 1);\n            // Safely skip pred_V_0 when distributing WGs along M\n            PRAGMA_UNROLL\n            for (int i = 1; i < MMA_SUBTILE_N; ++i) {\n                pred_V[i] = (i * OUTER_N + 
offset) >= 128;\n            }\n        }\n\n        return pred_V;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gemm_universal_sm90_v4.h",
    "content": "#pragma once\n\n#include <numeric>\n#include <utility>\n\n#include <cuda_fp8.h>\n\n#include \"cute/arch/cluster_sm90.hpp\"\n#include \"cute/arch/copy_sm90.hpp\"\n#include \"cute/arch/copy_sm90_desc.hpp\"\n#include \"cute/arch/copy_sm90_tma.hpp\"\n#include \"cute/arch/mma_sm90_desc.hpp\"\n#include \"cute/arch/mma_sm90_gmma.hpp\"\n#include \"cute/atom/mma_traits.hpp\"\n#include \"cute/atom/mma_traits_sm90_gmma.hpp\"\n\n#include \"cutlass/arch/barrier.h\"\n#include \"cutlass/arch/reg_reconfig.h\"\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/pipeline/sm90_pipeline.hpp\"\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm90.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/scheduler.cuh\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\nnamespace GMMA = cute::SM90::GMMA;\n\ninline __device__ cute::GmmaDescriptor make_smem_desc(void* smem_ptr, int layout_type)\n{\n    auto uint_ptr = cast_smem_ptr_to_uint(smem_ptr);\n\n    cute::GmmaDescriptor desc{};\n    desc.bitfield.start_address_       = uint_ptr >> 4;\n    desc.bitfield.layout_type_         = layout_type;\n    desc.bitfield.leading_byte_offset_ = 0;\n    desc.bitfield.stride_byte_offset_  = 1024 >> 4;\n    desc.bitfield.base_offset_         = 0;\n\n    return desc;\n}\n\ntemplate<int Stages, int Step>\nstruct SmemDescIterV2 {\n    union {\n        uint32_t u32_[2];\n        uint64_t u64_;\n    };\n\n    uint32_t base_;\n\n    __device__ SmemDescIterV2(uint64_t desc): u64_{desc}, base_{u32_[0]} {}\n\n    __device__ void 
Advance(int stage)\n    {\n        u32_[0] += Step;\n        if (stage == Stages - 1) {\n            u32_[0] = base_;\n        }\n    }\n\n    __device__ void Reset(int stage)\n    {\n        u32_[0] = base_ + stage * Step;\n    }\n\n    __device__ SmemDescIterV2& operator+=(int offset)\n    {\n        u32_[0] += offset;\n        return *this;\n    }\n\n    __device__ SmemDescIterV2& operator-=(int offset)\n    {\n        u32_[0] -= offset;\n        return *this;\n    }\n\n    __device__ operator uint64_t()\n    {\n        return u64_;\n    }\n};\n\ntemplate<class MMA_Atom, size_t... Is>\ninline __device__ void\nwgmma_impl(uint64_t desc_a, uint64_t desc_b, float* frag_C, bool clear, std::index_sequence<Is...>)\n{\n    return MMA_Atom::fma(desc_a, desc_b, frag_C[Is]..., clear ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One);\n}\n\ntemplate<class MMA_Atom, int N>\ninline __device__ void wgmma(uint64_t desc_a, uint64_t desc_b, float (&frag_C)[N], bool clear)\n{\n    return wgmma_impl<MMA_Atom>(desc_a, desc_b, frag_C, clear, std::make_index_sequence<N>{});\n}\n\ninline __device__ void warpgroup_fence_operand(float& reg)\n{\n    asm volatile(\"\" : \"+f\"(reg)::\"memory\");\n}\n\ntemplate<int M, int N, int K>\ninline __device__ void warpgroup_fence_operand(float (&x)[M][N][K])\n{\n    PRAGMA_UNROLL\n    for (int m = 0; m < M; ++m) {\n        PRAGMA_UNROLL\n        for (int n = 0; n < N; ++n) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < K; ++k) {\n                warpgroup_fence_operand(x[m][n][k]);\n            }\n        }\n    }\n}\n\ntemplate<int N, int K>\ninline __device__ void warpgroup_fence_operand(float (&x)[N][K])\n{\n    PRAGMA_UNROLL\n    for (int n = 0; n < N; ++n) {\n        PRAGMA_UNROLL\n        for (int k = 0; k < K; ++k) {\n            warpgroup_fence_operand(x[n][k]);\n        }\n    }\n}\n\ntemplate<class Func, size_t... 
Is>\n__device__ void for_(std::index_sequence<Is...>, Func func)\n{\n    return (func(constant<Is>{}), ...);\n}\n\nnamespace arch {\n\ntemplate<int M_, int N_, Order order>\nstruct Cluster {\n    static constexpr int M = M_;\n    static constexpr int N = N_;\n\n    static constexpr int C = mk2cs<order>(M, N).x;\n    static constexpr int S = mk2cs<order>(M, N).y;\n\n    static constexpr int size = M * N;\n\n    static constexpr uint16_t kMaskC = (1 << C) - 1;\n    static constexpr uint16_t kMaskS = ((1 << size) - 1) / kMaskC;\n\n    __device__ static ushort2 mask_cs(int cta_id)\n    {\n        const auto [c, s] = cta_cs(cta_id);\n        return make_ushort2(kMaskS << c, kMaskC << s * C);\n    }\n\n    __device__ static ushort2 mask_mn(int cta_id)\n    {\n        auto [c, s] = mask_cs(cta_id);\n        return order == kColMajor ? ushort2{c, s} : ushort2{s, c};\n    }\n\n    __device__ static int2 cta_cs(int cta_id)\n    {\n        return {C > 1 ? cta_id % C : 0, S > 1 ? cta_id / C : 0};\n    }\n\n    __device__ static int2 cta_mn(int cta_id)\n    {\n        return cs2mk<order>(cta_cs(cta_id));\n    }\n\n    int2    cta_mn_;\n    ushort2 mask_mn_;\n\n    __device__ explicit Cluster(int cta_id): cta_mn_(cta_mn(cta_id)), mask_mn_(mask_mn(cta_id)) {}\n\n    __device__ int cta_m()\n    {\n        return cta_mn_.x;\n    }\n\n    __device__ int cta_n()\n    {\n        return cta_mn_.y;\n    }\n\n    __device__ uint16_t mask_m()\n    {\n        return mask_mn_.x;\n    }\n\n    __device__ uint16_t mask_n()\n    {\n        return mask_mn_.y;\n    }\n};\n\n}  // namespace arch\n\nstruct GemmUniversalSm90_v3 {\n\n    static constexpr bool kDebug = false;\n\n    using Arch = Sm90;\n\n    // using MMA_Atom = GMMA::MMA_64x128x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>;\n    using MMA_Atom = GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN<>;\n    static constexpr typename cute::MMA_Traits<MMA_Atom>::Shape_MNK MMA_Shape{};\n\n    static constexpr int MMA_ATOM_M = 
cute::get<0>(MMA_Shape);\n    static constexpr int MMA_ATOM_N = cute::get<1>(MMA_Shape);\n    static constexpr int MMA_ATOM_K = cute::get<2>(MMA_Shape);\n\n    static constexpr int TILE_M = 128;\n    static constexpr int TILE_N = 192;\n    static constexpr int TILE_K = 128;\n\n    static constexpr int WG_M = 2;\n    static constexpr int WG_N = 1;\n\n    static constexpr int WG_TILE_M = TILE_M / WG_M;\n    static constexpr int WG_TILE_N = TILE_N / WG_N;\n\n    static constexpr int kSchedWarpGroups = 1;\n\n    static constexpr int WARPGORUPS = WG_M * WG_N;\n\n    static constexpr int MMA_ITER_M = WG_TILE_M / MMA_ATOM_M;\n    static constexpr int MMA_ITER_N = WG_TILE_N / MMA_ATOM_N;\n    static constexpr int MMA_ITER_K = TILE_K / MMA_ATOM_K;\n\n    static constexpr int kMulticastA = 1;\n    static constexpr int kMulticastB = 2;\n\n    static constexpr int kClusterSize = kMulticastA * kMulticastB;\n\n    static constexpr int Stages = 4;\n\n    static constexpr bool kSplitK     = false;\n    static constexpr int  kChunkSizeK = TILE_K;\n\n    static constexpr int WARPGROUP_SIZE = 128;\n\n    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);\n\n    using Ta = __nv_fp8_e4m3;\n    using Tb = __nv_fp8_e4m3;\n    using Tc = nv_bfloat16;\n\n    using Tu = float;\n    using Tv = float;\n\n    using Cluster = arch::Cluster<kMulticastB, kMulticastA, kRowMajor>;\n\n    using Scheduler = TileScheduler<kRowMajor, Cluster, true, true, false>;\n\n    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;\n    using ConsumerBar = cutlass::arch::ClusterBarrier;\n\n    static constexpr int MAX_K = 32768;\n\n    static constexpr int TILE_M_U = cdiv(TILE_M, 1);\n    static constexpr int CTA_K_U  = cdiv(TILE_K, 128);\n\n    static constexpr int kTmaTxBytes =\n        sizeof(Ta) * (TILE_M * TILE_K) + sizeof(Tb) * (TILE_N * TILE_K) + sizeof(Tu) * TILE_M_U * CTA_K_U;\n\n    // ! 
Smem addr must be SBO aligned for TMA load/store\n    struct SharedStorage {\n        struct Source {\n            __align__(1024) Array<Ta, Stages * TILE_M * TILE_K> A;\n            __align__(1024) Array<Tb, Stages * TILE_N * TILE_K> B;\n            __align__(1024) Tu U[Stages][TILE_M_U * CTA_K_U];\n            // __align__(1024) Tv V[2][WARPGORUPS][cdiv(MAX_K, 128)];\n            __align__(1024) Tv V[Stages][2 * cdiv(MAX_K, 128)];\n        };\n        Source source;\n        __align__(1024) Array<Tc, TILE_M * TILE_N> C;\n        __align__(128) uint64_t producer_bar[Stages];\n        __align__(128) uint64_t consumer_bar[Stages];\n        __align__(128) CUtensorMap tma_desc_C[WARPGORUPS];\n    };\n\n    static constexpr int kSmemSize = sizeof(SharedStorage);\n\n    static constexpr int kSwizzleC = 2 * std::gcd(WG_TILE_N, 128 / sizeof(Tc));\n\n    using LayoutC = std::conditional_t<kSwizzleC >= 32,\n                                       SmemLayoutV2<WG_TILE_M, WG_TILE_N, -1, kSwizzleC / sizeof(Tc)>,\n                                       SmemLayoutV2<WG_TILE_M, WG_TILE_N>>;\n\n    __device__ void operator()(const CUtensorMap& tm_a,\n                               const CUtensorMap& tm_b,\n                               const CUtensorMap& tm_c,\n                               const CUtensorMap& tm_u,\n                               const CUtensorMap& tm_v,\n                               const MatrixParam& param_A,\n                               const MatrixParam& param_B,\n                               const MatrixParam& param_U,\n                               const MatrixParam& param_V,\n                               const MatrixParam& param_C,\n                               uint2              box_V,\n                               Scheduler          sched,\n                               CUtensorMap*       tensormap_buf,\n                               char*              smem_buf)\n    {\n        SharedStorage& storage = 
*reinterpret_cast<SharedStorage*>(smem_buf);\n\n        uint64_t* producer_bar = storage.producer_bar;\n        uint64_t* consumer_bar = storage.consumer_bar;\n\n        if (threadIdx.x == 0) {\n            PRAGMA_UNROLL\n            for (int s = 0; s < Stages; ++s) {\n                ProducerBar::init(&producer_bar[s], 1);\n                ConsumerBar::init(&consumer_bar[s], WARPGORUPS * kClusterSize * 4);\n            }\n            cutlass::arch::fence_view_async_shared();\n            if constexpr (kClusterSize > 1) {\n                cutlass::arch::fence_barrier_init();\n            }\n        }\n\n        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();\n\n        const int wg_idx = cutlass::canonical_warp_group_idx();\n\n        if (wg_idx == WARPGORUPS) {\n            cutlass::arch::warpgroup_reg_dealloc<40>();\n\n            static_assert(TILE_M % kMulticastA == 0);\n            static_assert(TILE_N % kMulticastB == 0);\n\n            // if (threadIdx.x == WARPGORUPS * WARPGROUP_SIZE) {\n            if (threadIdx.x % WARPGROUP_SIZE / WARP_SIZE == 0) {\n\n                Cluster cluster(cute::block_id_in_cluster().x);\n\n                const int mc_offset_m = cluster.cta_n() * (TILE_M / kMulticastA);\n                const int mc_offset_n = cluster.cta_m() * (TILE_N / kMulticastB);\n\n                auto  smem_A = storage.source.A.data() + mc_offset_m * TILE_K;\n                auto  smem_B = storage.source.B.data() + mc_offset_n * TILE_K;\n                auto& smem_U = storage.source.U;\n                auto& smem_V = storage.source.V;\n\n                sched.grid_init();\n\n                cutlass::PipelineState<Stages> write_state{0, 1, 0};\n                cutlass::PipelineState<Stages> v_state{0, 1, 0};\n\n                while (sched.next()) {\n                    if (cute::elect_one_sync()) {\n                        auto [valid_cta_tile_p, cluster_tile_p] = sched.is_valid_tile();\n\n                        if (!cluster_tile_p) {\n    
                        // OOB tile caused by swizzle pattern\n                            continue;\n                        }\n\n                        const auto tile_offset              = sched.tile_offset();\n                        const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();\n\n                        const int offset_k = iter_k_beg * TILE_K;\n\n                        const uint16_t mask_A = cluster.mask_m();\n                        const uint16_t mask_B = cluster.mask_n();\n\n                        const int offset_m = tile_offset.x * TILE_M;\n                        const int offset_n = tile_offset.y * TILE_N;\n\n                        int k_iter = iter_k_end - iter_k_beg;\n\n                        GmemIteratorSm90<kMulticastA> gmem_A{&tm_a, {offset_k, offset_m + mc_offset_m}, {TILE_K, 0}};\n                        GmemIteratorSm90<kMulticastB> gmem_B{&tm_b, {offset_k, offset_n + mc_offset_n}, {TILE_K, 0}};\n                        GmemIteratorSm90<kMulticastA> gmem_U{&tm_u, {offset_m + mc_offset_m, offset_k / 128}, {0, 1}};\n                        GmemIteratorSm90<1>           gmem_V(&tm_v, {0, offset_n / 128}, {0, 0});\n\n                        {\n                            int pipe = write_state.index();\n                            ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());\n                            const int v_bytes = sizeof(Tv) * box_V.x * box_V.y;\n                            ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes + v_bytes);\n                            gmem_A.Step(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);\n                            gmem_B.Step(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);\n                            gmem_U.Step(&producer_bar[pipe], &smem_U[pipe][0] + mc_offset_m, mask_A);\n                            gmem_V.Step(&producer_bar[pipe], &smem_V[v_state.index()], 0);\n                            ++write_state;\n             
               --k_iter;\n                        }\n\n                        for (; k_iter > 0; --k_iter) {\n                            int pipe = write_state.index();\n                            ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());\n                            ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);\n                            gmem_A.Step(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);\n                            gmem_B.Step(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);\n                            gmem_U.Step(&producer_bar[pipe], &smem_U[pipe][0] + mc_offset_m, mask_A);\n                            ++write_state;\n                        }\n\n                        ++v_state;\n                    }\n                }\n            }\n        }\n        else {\n            cutlass::arch::warpgroup_reg_alloc<232>();\n\n            sched.grid_init(kSchedWarpGroups);\n\n            auto& smem_A = storage.source.A;\n            auto& smem_B = storage.source.B;\n            auto& smem_U = storage.source.U;\n\n            const int wg_idx_m = WG_M > 1 ? wg_idx % WG_M : 0;\n            const int wg_idx_n = WG_N > 1 ? 
wg_idx / WG_M : 0;\n\n            auto smem_desc_A = make_smem_desc(&smem_A[wg_idx_m * WG_TILE_M * TILE_K], 1);\n            auto smem_desc_B = make_smem_desc(&smem_B[wg_idx_n * WG_TILE_N * TILE_K], 1);\n\n            SmemDescIterV2<Stages, ((sizeof(Ta) * TILE_M * TILE_K) >> 4)> smem_iter_A{smem_desc_A};\n            SmemDescIterV2<Stages, ((sizeof(Tb) * TILE_N * TILE_K) >> 4)> smem_iter_B{smem_desc_B};\n\n            constexpr int kStepMA = (sizeof(Ta) * MMA_ATOM_M * TILE_K) >> 4;\n            constexpr int kStepNB = (sizeof(Tb) * MMA_ATOM_N * TILE_K) >> 4;\n            constexpr int kStepKA = (sizeof(Ta) * MMA_ATOM_K) >> 4;\n            constexpr int kStepKB = (sizeof(Tb) * MMA_ATOM_K) >> 4;\n\n            cutlass::arch::NamedBarrier wg_barrier(WARPGROUP_SIZE, wg_idx + 2);  // 2,3\n\n            auto epi_barrier = [&](int phase) {  // 0, 1\n                return EmptyBarrier{};\n                // return cutlass::arch::NamedBarrier(WARPGORUPS * WARPGROUP_SIZE, wg_idx ^ phase);\n            };\n\n            if (wg_idx == 1) {\n                epi_barrier(1).arrive_unaligned();\n            }\n\n            cutlass::PipelineState<Stages> pipe_state{};\n            cutlass::PipelineState<Stages> v_state{};\n\n            while (sched.next(kSchedWarpGroups)) {\n                auto [cta_tile_p, cluster_tile_p] = sched.is_valid_tile();\n\n                if (!cluster_tile_p) {\n                    // OOB tile caused by swizzle pattern\n                    continue;\n                }\n\n                MMA_Atom::CRegisters frag_C[MMA_ITER_N];\n                MMA_Atom::CRegisters accum_C[MMA_ITER_M][MMA_ITER_N]{};\n\n                const auto tile_offset              = sched.tile_offset();\n                const auto [iter_k_beg, iter_k_end] = sched.iter_k_range();\n\n                // const auto [M, N, K, L] = sched.gemm_shape();\n\n                const int offset_m = tile_offset.x * TILE_M;\n                const int offset_n = tile_offset.y * TILE_N;\n\n       
         const int wg_offset_n = offset_n + wg_idx_n * WG_TILE_N;\n\n                int k_iter = iter_k_end - iter_k_beg;\n\n                const int warp_id = threadIdx.x / WARP_SIZE;\n                const int lane_id = threadIdx.x % WARP_SIZE;\n\n                auto consumer_arrive = [&] {\n                    __syncwarp();\n                    if constexpr (kClusterSize > 1) {\n                        ConsumerBar::arrive(&consumer_bar[pipe_state.index()], lane_id, lane_id < kClusterSize);\n                    }\n                    else {\n                        if (lane_id == 0) {\n                            ConsumerBar::arrive(&consumer_bar[pipe_state.index()]);\n                        }\n                    }\n                    __syncwarp();\n                };\n\n                if constexpr (kClusterSize > 1) {\n                    if (!cta_tile_p) {  // other CTAs in the cluster are still alive\n                        for (; k_iter > 0; --k_iter) {\n                            ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                            consumer_arrive();\n                            ++pipe_state;\n                        }\n                        epi_barrier(0).arrive_and_wait_unaligned();\n                        epi_barrier(1).arrive_unaligned();\n                        continue;\n                    }\n                }\n\n                // auto Copy = [k = cdiv(K, 128)](Tv* dst, const Tv* src) {\n                //     PRAGMA_NO_UNROLL\n                //     for (int i = threadIdx.x % WARPGROUP_SIZE; i < k; i += WARPGROUP_SIZE) {\n                //         dst[i] = __ldg(&src[i]);\n                //     }\n                // };\n                // auto gmem_V = (const Tv*)param_V.ptr + (wg_offset_n / 128) * param_V.stride + (offset_k / 128);\n                // Copy(storage.source.V[0][wg_idx], gmem_V);\n\n                uint32_t pred_V{};\n                int      iter_V{};\n\n                
constexpr int OUTER_N = std::gcd(MMA_ATOM_N, 128);\n                if constexpr (OUTER_N != 128) {\n\n                    static_assert(MMA_ATOM_N <= 128 + OUTER_N, \"MMA inst is crossing more than 2 scale blocks\");\n\n                    constexpr uint32_t mask = (1UL << (WG_TILE_N / OUTER_N)) - 1;\n\n                    int phase = 128 - wg_offset_n % 128;\n                    pred_V    = (mask << (phase / OUTER_N)) & mask;\n\n                    // if (pred_V && wg_offset_n / 128 + 1 < cdiv(N, 128)) {\n                    //     Copy(storage.source.V[1][wg_idx], gmem_V + param_V.stride);\n                    // }\n                    // if constexpr (WG_N > 1) {\n                    //     constexpr int tiles = MMA_ATOM_N / OUTER_N;\n                    //     pred_V              = (pred_V >> (wg_idx_n * tiles)) & ((1 << tiles) - 1);\n                    // }\n                }\n\n                __syncwarp();\n\n                float scale_V[2];\n                // auto  Load_V = [&] {\n                //     scale_V[0] = storage.source.V[0][wg_idx][iter_V];\n                //     if (pred_V) {\n                //         scale_V[1] = storage.source.V[1][wg_idx][iter_V];\n                //     }\n                //     ++iter_V;\n                // };\n                auto Load_V = [&] {\n                    // scale_V[0] = scale_V[1] = 1;\n                    scale_V[0] = storage.source.V[v_state.index()][iter_V];\n                    if (pred_V) {\n                        scale_V[1] = storage.source.V[v_state.index()][box_V.x + iter_V];\n                    }\n                    ++iter_V;\n                };\n\n                float     scale_U[MMA_ITER_M][2];\n                const int offset_U = wg_idx_m * WG_TILE_M + warp_id % 4 * 16 + lane_id / 4;\n                auto      Load_U   = [&] {\n                    for (int m = 0; m < MMA_ITER_M; ++m) {\n                        scale_U[m][0] = smem_U[pipe_state.index()][offset_U + m * MMA_ATOM_M];\n      
                  scale_U[m][1] = smem_U[pipe_state.index()][offset_U + m * MMA_ATOM_M + 8];\n                    }\n                };\n\n                auto scale_accum = [&](int m) {  // cta_n = mma_iter_n * wg_n * mma_atom_n\n                    float scales[2][2];\n                    scales[0][0] = scale_U[m][0] * scale_V[0];\n                    scales[1][0] = scale_U[m][1] * scale_V[0];\n                    scales[0][1] = scale_U[m][0] * scale_V[1];\n                    scales[1][1] = scale_U[m][1] * scale_V[1];\n                    cute::warpgroup_wait<0>();\n                    PRAGMA_UNROLL\n                    for (int n = 0; n < MMA_ITER_N; ++n) {\n                        PRAGMA_UNROLL\n                        for (int c0 = 0; c0 < MMA_ATOM_N; c0 += OUTER_N) {\n                            bool pred = (pred_V & (1U << (c0 / OUTER_N)));\n                            PRAGMA_UNROLL\n                            for (int cc = 0; cc < OUTER_N; cc += 8) {\n                                int c = c0 + cc;\n                                // clang-format off\n                                accum_C[m][n][c / 2 + 0] += (pred ? scales[0][1] : scales[0][0]) * frag_C[n][c / 2 + 0];\n                                accum_C[m][n][c / 2 + 1] += (pred ? scales[0][1] : scales[0][0]) * frag_C[n][c / 2 + 1];\n                                accum_C[m][n][c / 2 + 2] += (pred ? scales[1][1] : scales[1][0]) * frag_C[n][c / 2 + 2];\n                                accum_C[m][n][c / 2 + 3] += (pred ? 
scales[1][1] : scales[1][0]) * frag_C[n][c / 2 + 3];\n                                // clang-format on\n                            }\n                        }\n                    }\n                };\n\n                auto gmma = [&](int m) {\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < MMA_ITER_K; ++k) {\n                        PRAGMA_UNROLL\n                        for (int n = 0; n < MMA_ITER_N; ++n) {\n                            wgmma<MMA_Atom>(smem_iter_A, smem_iter_B, frag_C[n], k == 0);\n                            smem_iter_B += kStepNB;\n                        }\n                        smem_iter_B -= MMA_ITER_N * kStepNB;\n                        smem_iter_A += kStepKA;\n                        smem_iter_B += kStepKB;\n                    }\n                    smem_iter_A -= MMA_ITER_K * kStepKA;\n                    smem_iter_B -= MMA_ITER_K * kStepKB;\n                    smem_iter_A += kStepMA;\n                    cute::warpgroup_commit_batch();\n                };\n\n                static_assert(MMA_ITER_N == 1);\n\n                wg_barrier.sync();\n\n                ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                Load_V();\n                Load_U();\n                smem_iter_A.Reset(pipe_state.index());\n                smem_iter_B.Reset(pipe_state.index());\n                cute::warpgroup_arrive();\n                gmma(0);\n                scale_accum(0);\n                consumer_arrive();\n                ++pipe_state;\n                --k_iter;\n\n                Load_V();\n                ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                Load_U();\n                smem_iter_A.Reset(pipe_state.index());\n                smem_iter_B.Reset(pipe_state.index());\n\n                for (; k_iter > 1; --k_iter) {\n                    cute::warpgroup_arrive();\n                    gmma(0);\n                    
scale_accum(0);\n                    consumer_arrive();\n                    ++pipe_state;\n                    Load_V();\n                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                    Load_U();\n                    smem_iter_A.Reset(pipe_state.index());\n                    smem_iter_B.Reset(pipe_state.index());\n                }\n\n                cute::warpgroup_arrive();\n                gmma(0);\n                scale_accum(0);\n                consumer_arrive();\n                ++pipe_state;\n                ++v_state;\n\n                const int wg_lane = threadIdx.x % WARPGROUP_SIZE;\n\n                if (wg_lane < LayoutC::C1) {\n                    cute::tma_store_wait<0>();\n                }\n\n                epi_barrier(0).arrive_and_wait_unaligned();\n\n                wg_barrier.sync();\n\n                // void* Cdesc{};\n                // if (threadIdx.x % WARPGROUP_SIZE / WARP_SIZE == 0) {\n                //     Cdesc = update_tma_desc(tm_c, tensormap_buf, &storage.tma_desc_C[wg_idx], wg_idx, param_C.ptr,\n                //     M);\n                // }\n\n                Tc* smem_C = &storage.C[wg_idx_m * WG_TILE_M * TILE_N + wg_idx_n * WG_TILE_N];\n\n                // epilogue\n                PRAGMA_UNROLL\n                for (int m = 0; m < MMA_ITER_M; ++m) {\n                    PRAGMA_UNROLL\n                    for (int n = 0; n < MMA_ITER_N; ++n) {\n\n                        constexpr int N       = LayoutC::C0;\n                        constexpr int SW_bits = log2(kSwizzleC / 16);\n\n                        static_assert(!SW_bits || MMA_ATOM_N % LayoutC::C0 == 0);\n\n                        const int m0 = m * MMA_ATOM_M;\n                        const int n0 = n * MMA_ATOM_N;\n\n                        PRAGMA_UNROLL\n                        for (int i = 0; i < MMA_ATOM_N; i += 16) {\n                            __align__(16) Array<Tc, 8> tvec = cast<Tc>(*(Array<float, 
8>*)&accum_C[m][n][i / 2]);\n\n                            int mm = m0 + warp_id % 4 * 16 + (lane_id & 8);\n                            int nn = n0 + i / N * N;\n\n                            int addr = ((nn / N) * WG_TILE_M * N) + (mm * N) + (nn % N);\n\n                            int s = lane_id % 8;\n                            int c = (lane_id & 16) / 2 + i % N;\n\n                            addr += Swizzle<SW_bits, 3, 3>::apply(s * N + c);\n\n                            auto& uvec = (Array<uint32_t, 4>&)tvec;\n                            cute::SM90_U32x4_STSM_N::copy(uvec[0],  //\n                                                          uvec[1],\n                                                          uvec[2],\n                                                          uvec[3],\n                                                          (cutlass::uint128_t&)smem_C[addr]);\n                        }\n                    }\n                }\n\n                cute::tma_store_fence();  // visibility: smem -> async proxy\n\n                wg_barrier.sync();\n\n                if (wg_lane < LayoutC::C1) {\n                    const int tma_n = wg_lane * LayoutC::C0;\n                    // cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)Cdesc);\n                    cute::SM90_TMA_STORE::copy(&tm_c,\n                                               &smem_C[wg_lane * WG_TILE_M * LayoutC::C0],\n                                               offset_n + wg_idx_n * WG_TILE_N + tma_n,\n                                               offset_m + wg_idx_m * WG_TILE_M);\n                    cute::tma_store_arrive();\n                }\n\n                epi_barrier(1).arrive_unaligned();\n\n            }  // scheduler loop\n\n            if (threadIdx.x % WARPGROUP_SIZE < LayoutC::C1) {\n                cute::tma_store_wait<0>();\n            }\n\n            if (wg_idx == 0) {\n                epi_barrier(0).arrive_and_wait_unaligned();\n            }\n        
}\n\n        if constexpr (kClusterSize > 1) {\n            cute::cluster_arrive();\n            cute::cluster_wait();\n        }\n\n    }  // operator()\n\n    struct EmptyBarrier {\n        __device__      EmptyBarrier(...) {}\n        __device__ void arrive_and_wait_unaligned() {}\n        __device__ void arrive_unaligned() {}\n    };\n\n    __device__ void* update_tma_desc(const CUtensorMap& param_desc,\n                                     CUtensorMap*       gmem_desc,\n                                     CUtensorMap*       smem_desc,\n                                     int                index,\n                                     void*              global_addr,\n                                     int                dim)\n    {\n        uint32_t uint_ptr = cast_smem_ptr_to_uint(smem_desc);\n\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        if (lane_id < sizeof(CUtensorMap) / sizeof(uint2)) {\n            ((uint2*)smem_desc)[lane_id] = ((uint2*)&param_desc)[lane_id];\n        }\n\n        __syncwarp();\n\n        if (lane_id == 0) {\n            // clang-format off\n            asm volatile(\"tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;\" ::\"r\"(uint_ptr), \"l\"(global_addr));\n            asm volatile(\"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;\" ::\"r\"(uint_ptr), \"r\"(dim));\n            // clang-format on\n        }\n\n        __syncwarp();\n\n        constexpr int kNumPerCta = 4;\n\n        auto gmem_ptr = (void*)&gmem_desc[blockIdx.x * kNumPerCta + index];\n\n        // clang-format off\n        asm volatile(\"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;\" :: \"l\"(gmem_ptr), \"r\"(uint_ptr));\n        // clang-format on\n\n        return gmem_ptr;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gemm_universal_sm90_v5.h",
    "content": "#pragma once\n\n#include <numeric>\n#include <utility>\n\n#include <cuda_fp8.h>\n#include <cuda_pipeline_primitives.h>\n\n#include \"cute/arch/cluster_sm90.hpp\"\n#include \"cute/arch/copy_sm90.hpp\"\n#include \"cute/arch/copy_sm90_tma.hpp\"\n#include \"cute/arch/mma_sm90_desc.hpp\"\n#include \"cute/arch/mma_sm90_gmma.hpp\"\n#include \"cute/atom/mma_traits.hpp\"\n#include \"cute/atom/mma_traits_sm90_gmma.hpp\"\n\n#include \"cutlass/arch/barrier.h\"\n#include \"cutlass/arch/reg_reconfig.h\"\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/pipeline/sm90_pipeline.hpp\"\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/cp_async.h\"\n#include \"src/turbomind/kernels/gemm/iterator_sm90.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/scheduler.cuh\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\n#include \"src/turbomind/kernels/gemm/scaled_gmma_fp8_sm90.h\"\n#include \"src/turbomind/kernels/gemm/sm90_utils.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<Order raster_order, int multicast_a, int multicast_b, bool is_grouped_gemm_>\nstruct GemmUniversalSm90_v5 {\n\n    static constexpr bool kDebug = false;\n\n    using Arch = Sm90;\n\n    static constexpr int WARPGORUPS = 4;\n\n    static constexpr int TILE_M = 128;\n    static constexpr int TILE_N = 96;\n    static constexpr int TILE_K = 128;\n\n    static constexpr int WG_M = 2;\n    static constexpr int WG_N = 1;\n\n    static constexpr int WG_TILE_M = TILE_M / WG_M;\n    static constexpr int WG_TILE_N = TILE_N / WG_N;\n\n    using GMMA = ScaledGmmaFP8_TN<WG_TILE_M, 
WG_TILE_N, TILE_K, 1, 1, 1, 1>;\n\n    static constexpr int kMulticastA = multicast_a;\n    static constexpr int kMulticastB = multicast_b;\n\n    static constexpr int kClusterSize = kMulticastA * kMulticastB;\n\n    static constexpr int Stages = 5;\n\n    static constexpr bool kSplitK     = false;\n    static constexpr int  kChunkSizeK = TILE_K;\n\n    static constexpr int WARPGROUP_SIZE = 128;\n    static constexpr int kMathGroupSize = 256;\n\n    static constexpr int CTA_SIZE = WARPGROUP_SIZE * (WARPGORUPS + 1);\n\n    using Ta = __nv_fp8_e4m3;\n    using Tb = __nv_fp8_e4m3;\n    using Tc = nv_bfloat16;\n\n    using Tu = float;\n    using Tv = float;\n\n    using Cluster = arch::Cluster<kMulticastB, kMulticastA, kRowMajor>;\n\n    static constexpr auto is_grouped_gemm = is_grouped_gemm_;\n\n    using Scheduler = TileScheduler<raster_order, Cluster, true, true, TILE_M, TILE_N, Stages, is_grouped_gemm>;\n\n    static constexpr int kMulticastU = is_grouped_gemm ? 1 : kMulticastA;\n\n    using ProducerBar = cutlass::arch::ClusterTransactionBarrier;\n    using ConsumerBar = cutlass::arch::ClusterBarrier;\n\n    static constexpr int MAX_K        = 32768;\n    static constexpr int MAX_K_BLOCKS = cdiv(MAX_K, 128);\n\n    static constexpr int kAlignmentU = 16 / sizeof(Tu);\n    static constexpr int kBoxU       = TILE_M + (is_grouped_gemm ? kAlignmentU : 0);\n\n    // Alignment requirement for SMEM addr. This forbids multicast factor 8.\n    static_assert(kMulticastU == 1 || sizeof(Tu) * kBoxU / kMulticastU % 128 == 0);\n\n    static constexpr int kTmaTxBytes =\n        sizeof(Ta) * (TILE_M * TILE_K) + sizeof(Tb) * (TILE_N * TILE_K) + sizeof(Tu) * kBoxU;\n\n    static constexpr int kTmaDescNum = 7;\n\n    // ! 
Smem addr must be SBO aligned for TMA load/store\n    struct SharedStorage {\n        struct Source {\n            __align__(1024) Array<Ta, Stages * TILE_M * TILE_K> A;\n            __align__(1024) Array<Tb, Stages * TILE_N * TILE_K> B;\n            __align__(1024) Tu U[Stages][round_up<int>(kBoxU, 128)];\n            __align__(1024) Tv V[Stages][2];\n        };\n        Source source;\n        __align__(1024) Array<Tc, TILE_M * TILE_N> C;\n        __align__(128) uint64_t producer_bar[Stages];\n        __align__(128) uint64_t consumer_bar[Stages];\n        __align__(128) CUtensorMap tma_desc_buf[kTmaDescNum];  //\n        int                         pipe_count[2];\n        typename Scheduler::Storage sched;\n    };\n\n    static constexpr int kSmemSize = sizeof(SharedStorage);\n\n    static constexpr int kSwizzleC = 2 * std::gcd(WG_TILE_N, 128 / sizeof(Tc));\n\n    using LayoutC = std::conditional_t<kSwizzleC >= 32,\n                                       SmemLayoutV2<WG_TILE_M, WG_TILE_N, -1, kSwizzleC / sizeof(Tc)>,\n                                       SmemLayoutV2<WG_TILE_M, WG_TILE_N>>;\n\n    static constexpr int OUTER_N       = GMMA::OUTER_N;\n    static constexpr int MMA_SUBTILE_N = GMMA::OP_N / OUTER_N;\n\n    __device__ void operator()(const CUtensorMap& tm_a,\n                               const CUtensorMap& tm_b,\n                               const CUtensorMap& tm_c,\n                               const CUtensorMap& tm_u,\n                               const CUtensorMap& tm_v,\n                               const MatrixParam& param_A,\n                               const MatrixParam& param_B,\n                               const MatrixParam& param_U,\n                               const MatrixParam& param_V,\n                               const MatrixParam& param_C,\n                               Scheduler          sched,\n                               CUtensorMap*       tensormap_buf,\n                               char*              
smem_buf)\n    {\n        SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        uint64_t* producer_bar = storage.producer_bar;\n        uint64_t* consumer_bar = storage.consumer_bar;\n\n        if (threadIdx.x == 0) {\n            PRAGMA_UNROLL\n            for (int s = 0; s < Stages; ++s) {\n                ProducerBar::init(&producer_bar[s], 1 + 1);\n                ConsumerBar::init(&consumer_bar[s], kClusterSize * (kMathGroupSize / WARP_SIZE));\n            }\n            sched.init_dyanmic(storage.sched, kClusterSize * (kMathGroupSize / WARP_SIZE + 1));\n            cutlass::arch::fence_view_async_shared();\n            if constexpr (kClusterSize > 1) {\n                cutlass::arch::fence_barrier_init();\n            }\n            PRAGMA_UNROLL\n            for (int i = 0; i < 2; ++i) {\n                storage.pipe_count[i] = 0;\n            }\n        }\n\n        (kClusterSize > 1) ? cute::cluster_sync() : __syncthreads();\n\n        const int wg_idx = cutlass::canonical_warp_group_idx();\n\n        if (wg_idx == WARPGORUPS) {\n            cutlass::arch::warpgroup_reg_dealloc<32>();\n\n            static_assert(TILE_M % kMulticastA == 0);\n            static_assert(TILE_N % kMulticastB == 0);\n\n            cutlass::arch::NamedBarrier producers_bar(WARP_SIZE * 2, 6);\n\n            const int  warp_id = cutlass::canonical_warp_idx_sync();\n            const bool cta_0   = cute::block_id_in_cluster().x == 0;\n\n            if (warp_id % 4 == 0) {\n\n                Cluster cluster(cute::block_id_in_cluster().x);\n\n                const int mc_offset_m = cluster.cta_n() * (TILE_M / kMulticastA);\n                const int mc_offset_n = cluster.cta_m() * (TILE_N / kMulticastB);\n\n                auto  smem_A = storage.source.A.data() + mc_offset_m * TILE_K;\n                auto  smem_B = storage.source.B.data() + mc_offset_n * TILE_K;\n                auto& smem_U = storage.source.U;\n                auto& smem_V = 
storage.source.V;\n\n                if constexpr (is_grouped_gemm) {\n                    init_tma_descs<3>({&tm_a, &tm_b, &tm_u}, storage.tma_desc_buf);\n                }\n\n                cutlass::PipelineState<Stages> write_state{0, 1, 0};\n\n                auto sched_state = sched.init_consumer(storage.sched);\n\n                int lane_predicate = cute::elect_one_sync();\n\n                typename Scheduler::Tile* tile;\n\n                while (sched_state.acquire(tile)) {\n\n                    // if (cute::elect_one_sync()) {\n                    //     printf(\"READ m %d n %d g %d v %s\\n\",\n                    //            tile->offset_m,\n                    //            tile->offset_n,\n                    //            tile->group_idx,\n                    //            tile->is_valid_cluster ? \"true\" : \"false\");\n                    // }\n\n                    if constexpr (Scheduler::is_dynamic) {\n                        if (cta_0) {\n                            producers_bar.arrive_unaligned();\n                        }\n                    }\n\n                    if (tile->is_valid_cluster) {\n\n                        const CUtensorMap* Adesc = &tm_a;\n                        const CUtensorMap* Bdesc = &tm_b;\n                        const CUtensorMap* Udesc = &tm_u;\n\n                        const Tv* gmem_V0 = (const Tv*)param_V.ptr;\n                        const Tv* gmem_V1;\n\n                        if constexpr (is_grouped_gemm) {\n                            const int g  = tile->group_idx;\n                            const int m0 = tile->m0;\n                            const int m1 = tile->m1;\n                            const int m  = m1 - m0;\n\n                            Array<void*, 3> global_addrs;\n                            global_addrs[0] = (Ta*)param_A.ptr + m0 * (int64_t)param_A.stride;\n                            global_addrs[1] = ((void**)param_B.ptr)[g];\n\n                            const int beg_u = 
m0 / kAlignmentU * kAlignmentU;\n                            const int end_u = round_up(m1, kAlignmentU);\n                            global_addrs[2] = (Tu*)param_U.ptr + beg_u;\n\n                            Array<int, 3> dims;\n                            dims[0] = m;\n                            dims[1] = sched.gemm_shape().y;\n                            dims[2] = end_u - beg_u;\n\n                            auto descs = update_tma_descs(tensormap_buf, storage.tma_desc_buf, global_addrs, dims);\n                            Adesc      = &descs[0];\n                            Bdesc      = &descs[1];\n                            Udesc      = &descs[2];\n\n                            gmem_V0 = ((Tv**)gmem_V0)[g];\n\n                            PRAGMA_UNROLL\n                            for (int i = 0; i < 3; ++i) {\n                                cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)&descs[i]);\n                            }\n                        }\n\n                        if (lane_predicate) {\n\n                            const int offset_k = 0;\n\n                            const uint16_t mask_A = cluster.mask_m();\n                            const uint16_t mask_B = cluster.mask_n();\n\n                            const int offset_m = tile->offset_m;\n                            const int offset_n = tile->offset_n;\n\n                            int k_iter = sched.k_iters_;\n\n                            GmemIteratorSm90<kMulticastA> gmem_A{\n                                Adesc, {offset_k, offset_m + mc_offset_m}, {TILE_K, 0}};\n                            GmemIteratorSm90<kMulticastB> gmem_B{\n                                Bdesc, {offset_k, offset_n + mc_offset_n}, {TILE_K, 0}};\n\n                            const int mc_offset_u = kMulticastU > 1 ? 
mc_offset_m : 0;\n                            // column-major\n                            GmemIteratorSm90<kMulticastU> gmem_U{\n                                Udesc, {offset_m + mc_offset_u, offset_k / 128}, {0, 1}};\n\n                            gmem_V0 += (offset_n / 128) * param_V.stride + (offset_k / 128);\n                            gmem_V1 = gmem_V0;\n                            if (offset_n / 128 + 1 < cdiv(sched.gemm_shape().y, 128)) {\n                                gmem_V1 += param_V.stride;\n                            }\n\n                            for (; k_iter > 0; --k_iter) {\n                                int pipe = write_state.index();\n                                ConsumerBar::wait(&consumer_bar[pipe], write_state.phase());\n                                ProducerBar::arrive_and_expect_tx(&producer_bar[pipe], kTmaTxBytes);\n                                gmem_A.Step(&producer_bar[pipe], &smem_A[pipe * TILE_M * TILE_K], mask_A);\n                                gmem_B.Step(&producer_bar[pipe], &smem_B[pipe * TILE_N * TILE_K], mask_B);\n                                gmem_U.Step(&producer_bar[pipe], &smem_U[pipe][0] + mc_offset_u, mask_A);\n                                uint32_t uint_ptr_V = cast_smem_ptr_to_uint(smem_V[pipe]);\n                                CP_ASYNC<CacheOp::kAlways, 4, 0>::apply(uint_ptr_V, gmem_V0, true);\n                                CP_ASYNC<CacheOp::kAlways, 4, 0>::apply(uint_ptr_V + sizeof(Tv), gmem_V1, true);\n                                ++gmem_V0;\n                                ++gmem_V1;\n                                cutlass::arch::cpasync_barrier_arrive_noinc(&producer_bar[pipe]);\n                                ++write_state;\n                            }\n                        }\n                    }\n\n                    sched_state.release();\n\n                }  // scheduler loop\n\n                sched_state.release();\n\n                // pair with the EXTRA tile\n           
     sched_state.acquire(tile);\n                sched_state.release();\n\n                if constexpr (kClusterSize > 1) {\n                    if (lane_predicate) {\n                        for (int i = 0; i < Stages; ++i) {\n                            ConsumerBar::wait(&consumer_bar[write_state.index()], write_state.phase());\n                            ++write_state;\n                        }\n                    }\n                    __syncwarp();\n                }\n            }\n            else if (warp_id % 4 == 1 && cta_0) {\n                auto sched_state = sched.init_producer(storage.sched);\n                while (sched_state.next()) {\n                    if constexpr (Scheduler::is_dynamic) {\n                        producers_bar.arrive_and_wait_unaligned();\n                    }\n                }\n                // send EXTRA null tile (to math WGs)\n                sched_state.next();\n                sched.tail(sched_state);\n            }\n        }\n        else {\n            cutlass::arch::warpgroup_reg_alloc<112>();\n\n            const int math_group_idx = wg_idx / 2;\n\n            if constexpr (is_grouped_gemm) {\n                if (threadIdx.x % WARPGROUP_SIZE / WARP_SIZE == 0) {\n                    init_tma_descs<1>({&tm_c}, storage.tma_desc_buf + 3 + wg_idx);\n                }\n            }\n\n            auto& smem_A = storage.source.A;\n            auto& smem_B = storage.source.B;\n            auto& smem_U = storage.source.U;\n            auto& smem_V = storage.source.V;\n\n            const int wg_idx_m = WG_M > 1 ? wg_idx % 2 % WG_M : 0;\n            const int wg_idx_n = WG_N > 1 ? 
wg_idx % 2 / WG_M : 0;\n\n            auto smem_desc_A = make_smem_desc(&smem_A[wg_idx_m * WG_TILE_M * TILE_K], 1);\n            auto smem_desc_B = make_smem_desc(&smem_B[wg_idx_n * WG_TILE_N * TILE_K], 1);\n\n            SmemDescIterV2<Stages, ((TILE_M * TILE_K) >> 4)> smem_iter_A{smem_desc_A};\n            SmemDescIterV2<Stages, ((TILE_N * TILE_K) >> 4)> smem_iter_B{smem_desc_B};\n\n            const int  thread_idx    = threadIdx.x % WARPGROUP_SIZE;\n            const bool math_leader_p = threadIdx.x % kMathGroupSize == 0;\n\n            auto math_barrier_sync = [&](int phase, int alive = 1) {\n                constexpr int base       = (int)cutlass::arch::ReservedNamedBarriers::FirstUserBarrier;\n                const int     barrier_id = base + math_group_idx ^ phase;\n                constexpr int threads    = WARPGORUPS * WARPGROUP_SIZE;\n                int           res        = 0;\n                asm volatile(\"{\\n\"\n                             \"  .reg.pred p;\\n\"\n                             \"  setp.ne.b32 p, %3, 0;\\n\"\n                             \"  barrier.cta.red.or.pred p, %1, %2, p;\\n\"\n                             \"  selp.s32 %0, 1, 0, p;\\n\"\n                             \"}\\n\"\n                             : \"=r\"(res)\n                             : \"r\"(barrier_id), \"r\"(threads), \"r\"(alive));\n                return res;\n            };\n\n            cutlass::arch::NamedBarrier barrier(WARPGROUP_SIZE, 2 + wg_idx);  // 2,3,4,5\n\n            cutlass::PipelineState<Stages> pipe_state{};\n\n            const int warp_id = cutlass::canonical_warp_idx_sync();\n            const int lane_id = cutlass::canonical_lane_idx();\n\n            auto consumer_arrive = [&] {\n                auto bar = &consumer_bar[pipe_state.index()];\n                __syncwarp();\n                if constexpr (kClusterSize > 1) {\n                    ConsumerBar::arrive(bar, lane_id, lane_id < kClusterSize);\n                }\n                
else {\n                    if (lane_id == 0) {\n                        ConsumerBar::arrive(bar);\n                    }\n                }\n            };\n\n            auto sched_state = sched.init_consumer(storage.sched);\n\n            if (math_group_idx == 1) {\n                ++sched_state.pipe;\n                math_barrier_sync(1);\n            }\n\n            typename Scheduler::Tile* tile;\n\n            sched_state.acquire(tile);\n\n            while (tile->alive) {\n\n                if (tile->is_valid_cta) {\n\n                    GMMA::AccumC accum_C{};\n                    GMMA::FragC  frag_C;\n\n                    const auto [_, N, K, L] = sched.gemm_shape();\n\n                    const int offset_m = tile->offset_m;\n                    const int offset_n = tile->offset_n;\n\n                    int k_iter = sched.k_iters_;\n\n                    auto pred_V = Fetch_V(param_V, K, N, tile, math_group_idx, wg_idx_n, storage);\n\n                    float scale_V[2];\n                    auto  Load_V = [&] {\n                        scale_V[0] = smem_V[pipe_state.index()][0];\n                        scale_V[1] = smem_V[pipe_state.index()][1];\n                    };\n\n                    int offset_U = wg_idx_m * WG_TILE_M + warp_id % 4 * 16 + lane_id / 4;\n                    if constexpr (is_grouped_gemm) {\n                        offset_U += tile->m0 % kAlignmentU;\n                    }\n                    GMMA::FragU frag_U;\n                    auto        Load_U = [&] {\n                        GMMA::foreach_m(frag_U, [&](auto& U, int m) {\n                            U[0] = smem_U[pipe_state.index()][offset_U + m * GMMA::OP_M];\n                            U[1] = smem_U[pipe_state.index()][offset_U + m * GMMA::OP_M + 8];\n                        });\n                    };\n\n                    auto gmma = [&] {  //\n                        GMMA::apply(smem_iter_A, smem_iter_B, frag_C, accum_C, frag_U, scale_V, pred_V);\n            
        };\n\n                    if constexpr (is_grouped_gemm) {\n                        if (warp_id % 4 == 0) {\n                            int  m0 = tile->m0, m1 = tile->m1;\n                            auto addr = (Tc*)param_C.ptr + m0 * (int64_t)param_C.stride;\n                            int  idx  = 3 + wg_idx;\n                            update_tma_descs<1>(tensormap_buf + idx, storage.tma_desc_buf + idx, {addr}, {m1 - m0});\n                        }\n                    }\n\n                    math_barrier_sync(0);\n\n                    pipe_state = {};\n                    pipe_state.advance(storage.pipe_count[math_group_idx ^ 1]);\n\n                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                    Load_V();\n                    Load_U();\n                    smem_iter_A.Reset(pipe_state.index());\n                    smem_iter_B.Reset(pipe_state.index());\n                    gmma();\n                    consumer_arrive();\n                    ++pipe_state;\n                    --k_iter;\n\n                    ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                    Load_V();\n                    Load_U();\n                    smem_iter_A.Reset(pipe_state.index());\n                    smem_iter_B.Reset(pipe_state.index());\n\n                    PRAGMA_NO_UNROLL\n                    for (; k_iter > 1; --k_iter) {\n                        gmma();\n                        consumer_arrive();\n                        ++pipe_state;\n                        ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                        Load_V();\n                        Load_U();\n                        smem_iter_A.Reset(pipe_state.index());\n                        smem_iter_B.Reset(pipe_state.index());\n                    }\n\n                    if (math_leader_p) {\n                        storage.pipe_count[math_group_idx] = 
pipe_state.count() + 1;\n                    }\n                    math_barrier_sync(1);\n\n                    gmma();\n                    consumer_arrive();\n\n                    Tc* smem_C = &storage.C[wg_idx_m * WG_TILE_M * TILE_N + wg_idx_n * WG_TILE_N];\n\n                    GMMA::foreach_C(accum_C, [&](const auto& C, int m, int n) {\n                        constexpr int N       = LayoutC::C0;\n                        constexpr int SW_bits = log2(kSwizzleC / 16);\n\n                        static_assert(!SW_bits || GMMA::OP_N % LayoutC::C0 == 0);\n                        static_assert(GMMA::OP_N % 16 == 0);\n\n                        const int m0 = m * GMMA::OP_M;\n                        const int n0 = n * GMMA::OP_N;\n\n                        PRAGMA_UNROLL\n                        for (int i = 0; i < GMMA::OP_N; i += 16) {\n                            __align__(16) Array<Tc, 8> tvec = cast<Tc>(*(Array<float, 8>*)&C[i / 2]);\n                            // fill(tvec, Tc(255));\n                            int mm = m0 + warp_id % 4 * 16 + (lane_id & 8);\n                            int nn = n0 + i / N * N;\n\n                            int addr = ((nn / N) * WG_TILE_M * N) + (mm * N) + (nn % N);\n\n                            int s = lane_id % 8;\n                            int c = (lane_id & 16) / 2 + i % N;\n\n                            addr += Swizzle<SW_bits, 3, 3>::apply(s * N + c);\n\n                            auto& uvec = (Array<uint32_t, 4>&)tvec;\n                            cute::SM90_U32x4_STSM_N::copy(\n                                uvec[0], uvec[1], uvec[2], uvec[3], (cutlass::uint128_t&)smem_C[addr]);\n                        }\n                    });\n\n                    cute::tma_store_fence();  // visibility: smem -> async proxy\n\n                    barrier.sync();\n\n                    if (thread_idx < LayoutC::C1) {\n                        const void* Cdesc = &tm_c;\n                        const int   tma_n = thread_idx 
* LayoutC::C0;\n                        if constexpr (is_grouped_gemm) {\n                            Cdesc = &tensormap_buf[blockIdx.x * kTmaDescNum + 3 + wg_idx];\n                            cute::tma_descriptor_fence_acquire((cute::TmaDescriptor*)Cdesc);\n                        }\n                        cute::SM90_TMA_STORE::copy(Cdesc,\n                                                   &smem_C[thread_idx * WG_TILE_M * LayoutC::C0],\n                                                   offset_n + wg_idx_n * WG_TILE_N + tma_n,\n                                                   offset_m + wg_idx_m * WG_TILE_M);\n                        cute::tma_store_arrive();\n                        cute::tma_store_wait<0>();\n                    }\n\n                }  // valid cta tile\n                else {\n                    math_barrier_sync(0);\n\n                    pipe_state = {};\n                    pipe_state.advance(storage.pipe_count[math_group_idx ^ 1]);\n\n                    if (tile->is_valid_cluster) {\n                        // other CTAs in the cluster are still alive\n                        for (int k_iter = sched.k_iters_; k_iter > 0; --k_iter) {\n                            ProducerBar::wait(&producer_bar[pipe_state.index()], pipe_state.phase());\n                            consumer_arrive();\n                            ++pipe_state;\n                        }\n                    }\n\n                    if (math_leader_p) {\n                        storage.pipe_count[math_group_idx] = pipe_state.count();\n                    }\n\n                    math_barrier_sync(1);\n                }\n\n                sched_state.release(2);\n                sched_state.acquire(tile);\n            }  // scheduler loop\n\n            sched_state.release(2);  // release the last tile\n\n            if (math_group_idx == 0) {\n                math_barrier_sync(0, 0);\n                if (math_leader_p) {\n                    storage.pipe_count[0] = 
storage.pipe_count[1];\n                }\n                while (math_barrier_sync(1, 0)) {\n                    math_barrier_sync(0, 0);\n                    if (math_leader_p) {\n                        storage.pipe_count[0] = storage.pipe_count[1];\n                    }\n                }\n            }\n            else {\n                while (math_barrier_sync(0, 0)) {\n                    if (math_leader_p) {\n                        storage.pipe_count[1] = storage.pipe_count[0];\n                    }\n                    math_barrier_sync(1, 0);\n                }\n            }\n        }\n\n    }  // operator()\n\n    template<int N>\n    __device__ void init_tma_descs(Array<const CUtensorMap*, N> param_desc, CUtensorMap* smem_desc)\n    {\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        if (lane_id < sizeof(CUtensorMap) / sizeof(uint2)) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < N; ++i) {\n                ((uint2*)&smem_desc[i])[lane_id] = ((uint2*)param_desc[i])[lane_id];\n            }\n        }\n\n        __syncwarp();\n    }\n\n    template<int N>\n    __device__ CUtensorMap*\n    update_tma_descs(CUtensorMap* gmem_desc, CUtensorMap* smem_desc, Array<void*, N> global_addrs, Array<int, N> dims)\n    {\n        const int lane_id = threadIdx.x % WARP_SIZE;\n        if (lane_id == 0) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < N; ++i) {\n                uint32_t uint_ptr = cast_smem_ptr_to_uint(&smem_desc[i]);\n                // clang-format off\n                asm volatile(\"tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;\" ::\"r\"(uint_ptr), \"l\"(global_addrs[i]));\n                if (i != 2) {\n                    asm volatile(\"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;\" ::\"r\"(uint_ptr), \"r\"(dims[i]));\n                } else { // special case for U\n                    asm 
volatile(\"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;\" ::\"r\"(uint_ptr), \"r\"(dims[i]));\n                }\n                // clang-format on\n            }\n        }\n\n        __syncwarp();\n\n        auto gmem_ptr = &gmem_desc[blockIdx.x * kTmaDescNum];\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            uint32_t uint_ptr = cast_smem_ptr_to_uint(&smem_desc[i]);\n            // clang-format off\n            asm volatile(\"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;\" :: \"l\"(gmem_ptr + i), \"r\"(uint_ptr));\n            // clang-format on\n        }\n\n        return gmem_ptr;\n    }\n\n    __device__ auto Fetch_V(const MatrixParam&        param_V,\n                            int                       K,\n                            int                       N,\n                            typename Scheduler::Tile* tile,\n                            int                       math_group_idx,\n                            int                       wg_idx_n,\n                            SharedStorage&            storage)\n    {\n        const int offset_n = tile->offset_n;\n\n        Array<bool, MMA_SUBTILE_N> pred_V{};\n\n        if constexpr (MMA_SUBTILE_N != 1) {\n            int offset = offset_n % 128 + wg_idx_n * WG_TILE_N;\n            static_assert(WG_N == 1);\n            // Safely skip pred_V_0 when distributing WGs along M\n            PRAGMA_UNROLL\n            for (int i = 1; i < MMA_SUBTILE_N; ++i) {\n                pred_V[i] = (i * OUTER_N + offset) >= 128;\n            }\n        }\n\n        return pred_V;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gpu_metric.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/gemm/gpu_metric.h\"\n#include <thrust/device_vector.h>\n\n#include <cublas_v2.h>\n#include <iostream>\n\nnamespace turbomind::gemm {\n\nusing thrust::device_vector;\n\nnamespace {\n\ntemplate<int BLOCK_NUM, int BLOCK_DIM, int LOG_TILE>\n__global__ void l2_bw(float* dsink, const float* array, int count)\n{\n    int    tid = threadIdx.x + (blockIdx.x >> LOG_TILE) * blockDim.x;\n    float4 sink{};\n\n    constexpr int NUM_THREADS = BLOCK_NUM * BLOCK_DIM;\n\n    for (int i = 0; i < count; i += NUM_THREADS * 4) {\n        const float* ptr    = array + i;\n        const int    offset = tid * 4;\n        float4       data   = __ldcg(reinterpret_cast<const float4*>(ptr + offset));\n        sink.x += data.x;\n        sink.y += data.y;\n        sink.z += data.z;\n        sink.w += data.w;\n    }\n\n    dsink[threadIdx.x] = sink.x + sink.y + sink.z + sink.w;\n}\n\n}  // namespace\n\nfloat MeasureL2CacheThroughput()\n{\n    cudaDeviceProp prop{};\n    int            device{};\n    cudaGetDevice(&device);\n    cudaGetDeviceProperties(&prop, device);\n\n    size_t size = static_cast<size_t>(prop.l2CacheSize) * 64;\n\n    std::cout << size << std::endl;\n\n    constexpr int BLOCK_X  = 128;  // blocks participating in a single sweep\n    constexpr int BLOCK_Y  = 128;  // number of full sweeps over the buffer\n    constexpr int LOG_TILE = 5;    // swizzling factor to raise L2 hit rate; 0 minimizes the hit rate\n\n    constexpr int BLOCK_DIM = 256;\n\n    constexpr int CHUNK_SIZE = BLOCK_X * BLOCK_DIM * 4;  // x4 for float4 load pattern\n\n    device_vector<float> data(ceil_div(size, sizeof(float)) / CHUNK_SIZE * CHUNK_SIZE);\n    device_vector<float> dsink(BLOCK_DIM);\n\n    cudaStream_t stream;\n    cudaStreamCreate(&stream);\n\n    cudaMemsetAsync(data.data().get(), 0, 
sizeof(float) * data.size(), stream);\n\n    cudaEvent_t ev_start, ev_end;\n\n    cudaEventCreate(&ev_start);\n    cudaEventCreate(&ev_end);\n\n    cudaEventRecord(ev_start, stream);\n\n    l2_bw<BLOCK_X, BLOCK_DIM, LOG_TILE><<<dim3(BLOCK_X << LOG_TILE, BLOCK_Y >> LOG_TILE), BLOCK_DIM, 0, stream>>>(\n        dsink.data().get(), data.data().get(), data.size());\n\n    cudaEventRecord(ev_end, stream);\n\n    cudaEventSynchronize(ev_end);\n\n    float ms{};\n    cudaEventElapsedTime(&ms, ev_start, ev_end);\n\n    size_t bytes = BLOCK_Y * sizeof(float) * data.size();\n\n    const float bytes_per_second = bytes / ms * 1e3;\n    std::cout << bytes_per_second / 1e9 << \" GB/s\" << std::endl;\n\n    cudaEventDestroy(ev_start);\n    cudaEventDestroy(ev_end);\n\n    cudaStreamDestroy(stream);\n\n    return bytes_per_second;\n}\n\nfloat MeasureMmaThroughput(int problem_size)\n{\n    device_vector<half> a(problem_size * problem_size);\n    device_vector<half> b(a.size());\n    device_vector<half> c(a.size());\n\n    cublasHandle_t cublas{};\n    cublasCreate(&cublas);\n\n    cudaStream_t stream;\n    cudaStreamCreate(&stream);\n\n    cublasSetStream(cublas, stream);\n\n    cudaEvent_t ev_start, ev_end;\n\n    cudaEventCreate(&ev_start);\n    cudaEventCreate(&ev_end);\n\n    cudaEventRecord(ev_start, stream);\n\n    float alpha = 1.f;\n    float beta  = 0.f;\n    cublasGemmEx(cublas,\n                 CUBLAS_OP_N,\n                 CUBLAS_OP_N,\n                 problem_size,\n                 problem_size,\n                 problem_size,\n                 &alpha,\n                 a.data().get(),\n                 CUDA_R_16F,\n                 problem_size,\n                 b.data().get(),\n                 CUDA_R_16F,\n                 problem_size,\n                 &beta,\n                 c.data().get(),\n                 CUDA_R_16F,\n                 problem_size,\n                 CUBLAS_COMPUTE_32F,\n                 CUBLAS_GEMM_DEFAULT);\n\n    cudaEventRecord(ev_end, 
stream);\n\n    cudaEventSynchronize(ev_end);\n\n    float ms{};\n    cudaEventElapsedTime(&ms, ev_start, ev_end);\n\n    cudaEventDestroy(ev_start);\n    cudaEventDestroy(ev_end);\n\n    cudaStreamDestroy(stream);\n\n    cublasDestroy(cublas);\n\n    const size_t ops = (size_t)problem_size * problem_size * problem_size;\n\n    float fma_per_second = ops / ms * 1e3;\n\n    std::cout << 2 * fma_per_second / 1e9 << \" GFLOP/s\" << std::endl;\n\n    return fma_per_second;\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/gpu_metric.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\n// bytes / second\nfloat MeasureL2CacheThroughput();\n\n// fused multiply-add / second\nfloat MeasureMmaThroughput(int problem_size = 16384);\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/iterator.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/thread_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\nstruct VoidGmemIter {\n    static constexpr int  ITER_S = 0;\n    static constexpr auto kMode  = Striding::kFlat;\n    using Fragments              = int;\n    __device__      VoidGmemIter(...) {}\n    __device__ void ClearSmem() {}\n    __device__ void Prefetch(int, int, bool) {}\n    __device__ void Prefetch(bool) {}\n    __device__ void Fetch(Fragments&, bool) {}\n    __device__ void Store(const Fragments&) {}\n    __device__ void Advance() {}\n    int*            smem_data_;\n    bool            g_mask{false};\n};\n\nstruct GetGmemIter {\n    template<class Operand, class Iterator, class SmemLayout, int M, int K, int WARPS>\n    static constexpr auto\n        apply(basic_type<Operand>, basic_type<Iterator>, basic_type<SmemLayout>, pair<M, K>, constant<WARPS>)\n    {\n        using Dtype = typename Operand::Dtype;\n\n        constexpr int kAccessSize =\n            std::min<int>(128 / bitsof<Dtype>, std::max<int>(32 / bitsof<Dtype>, M * K / (WARPS * WARP_SIZE)));\n\n        constexpr int2 kAligned = mk2cs<Operand::kOrder>(0, 1);\n        constexpr int2 kCS      = mk2cs<Operand::kOrder>(M, K);\n\n#if 0\n        constexpr int kMaxThrS = std::min(WARP_SIZE, ceil_div(kCS.y, WARPS));\n        constexpr int kMaxThrC = std::min(WARP_SIZE, ceil_div(kCS.x, kAccessSize));\n\n        constexpr int kTgtThrC = ceil_div<int>(256, sizeof(Array<Dtype, kAccessSize>));\n\n        constexpr int kWarpThrC = std::min(kMaxThrC, std::max(WARP_SIZE / kMaxThrS, kTgtThrC));\n#endif\n        using GmemIter = typename 
Iterator::template Type<Dtype,\n                                                          gemm::ThreadMap_V2<kCS.x, kCS.y, kAccessSize, Blocked, WARPS>,\n                                                          SmemLayout,\n                                                          Operand::kPack,\n                                                          Operand::kOrder,\n                                                          kAligned.x,   // aligned C\n                                                          kAligned.y>;  // aligned S\n        return type_c<GmemIter>;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/iterator_sm70.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/gemm/cp_async.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/predicate.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n#include <cassert>\n#include <type_traits>\n\nnamespace turbomind::gemm {\n\ntemplate<typename T, int N>\ninline __device__ void _Ld(Array<T, N>& dst, const T* src)\n{\n    static_assert(sizeof(Array<T, N>) <= sizeof(uint4));\n\n    if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {\n        (uint4&)dst = __ldcs((const uint4*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {\n        (uint2&)dst = __ldcs((const uint2*)src);\n    }\n    else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {\n        (uint&)dst = __ldcs((const uint*)src);\n    }\n    else {\n        static_assert(!std::is_same_v<T, T>);\n    }\n}\n\ntemplate<class T,\n         class Map,\n         class SmemLayout,\n         Pack     kPack,\n         Order    kOrder,\n         bool     AlignedC,\n         bool     AlignedS,\n         Striding mode,\n         class Policy_>\nstruct GmemIteratorSm70 {\n\n    using ThreadMap = Map;\n\n    using AccessType = Array<T, Map::kAccessC>;\n    using Pointer    = get_pointer_type<T>;\n\n    using Policy = Policy_;\n\n    static constexpr int ITER_S = Map::kIterS;\n    static constexpr int ITER_C = Map::kIterC;\n\n    static constexpr Striding kMode      = mode;\n    static constexpr bool     is_indexed = mode == Striding::kIndexed;\n\n    const char* src_data_;\n\n    int src_offset_;\n    int dst_offset_;\n\n    int offset_c_;\n    int offset_s_;\n\n    int src_step_c_;\n    int 
src_step_s_;\n\n    int src_step_k_;\n\n    Predicate<Map::kIterS, Map::kIterC, (AlignedC && Map::kAlignedC), (AlignedS && Map::kAlignedS)> pred_;\n\n    bool g_mask{true};\n\n    SmemAccessor<T, SmemLayout> smem_data_;\n\n    static constexpr int2 kMK0     = cs2mk<kOrder>(SmemLayout::C0, SmemLayout::S0);\n    static constexpr int  kPeriodC = ceil_div(SmemLayout::C0, Map::kDeltaC);\n    static constexpr int  kPeriodS = ceil_div(SmemLayout::S0, Map::kDeltaS);\n\n    int phases_[kPeriodS][kPeriodC];\n\n    const char* src_data_vec_[ITER_S];\n\n    using Fragments = AccessType[Map::kIterS][Map::kIterC];\n\n    __device__ static constexpr int2 pack(int2 mk)\n    {\n        return Packing_v2<kPack, kOrder>::apply(mk);\n    }\n\n    __device__ static constexpr int2 to_cs(int2 mk)\n    {\n        return mk2cs<kOrder>(mk.x, mk.y);\n    }\n\n    __device__ GmemIteratorSm70(): smem_data_{Pointer{nullptr}} {};\n\n    __device__ GmemIteratorSm70(const MatrixData& mat, int2 offset, int2 extent): smem_data_{Pointer{(T*)nullptr}}\n    {\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        const Pointer data{(T*)mat.ptr.ptr};\n        const int     ld = mat.ptr.stride;\n\n        const int2 offsets = Map::get_offset(warp_id, lane_id);\n\n        offset_c_ = offsets.x;\n        offset_s_ = offsets.y;\n\n        // auto src_ptr = reinterpret_cast<const char*>((T*)data);\n\n        if constexpr (pred_.is_active) {\n            extent = to_cs(pack(extent));\n            PRAGMA_UNROLL\n            for (int s = 0; s < Map::kIterS; ++s) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < Map::kIterC; ++c) {\n                    int ss = offset_s_ + s * Map::kDeltaS;\n                    int cc = offset_c_ + c * Map::kDeltaC;\n                    if (ss < extent.y && cc < extent.x) {\n                        pred_.set(s, c);\n                    }\n                }\n            }\n        }\n\n        
PRAGMA_UNROLL\n        for (int s = 0; s < kPeriodS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < kPeriodC; ++c) {\n                phases_[s][c] = SmemLayout::apply(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);\n            }\n        }\n\n        const int src_offset = is_indexed ? offsets.x : offsets.x + offsets.y * ld;\n\n        src_offset_ = src_offset * bitsof<T> / bitsof<char>;\n\n        src_step_c_ = bitsof<T> * Map::kDeltaC / bitsof<char>;\n        src_step_s_ = bitsof<T> * Map::kDeltaS * ld / bitsof<char>;\n\n        src_step_k_ = bitsof<T> * cs2mk<kOrder>(Map::kDimC, Map::kDimS * ld).y / bitsof<char>;\n\n        // initialize for the first tile\n        if constexpr (is_indexed) {\n            const int2 cta_cs = to_cs(offset);\n            for (int s = 0; s < ITER_S; ++s) {\n                const int  ss    = cta_cs.y + offset_s_ + s * Map::kDeltaS;\n                const int  idx   = (mat.idxs && pred_(s, 0)) ? __ldg(mat.idxs + ss) : ss;\n                const auto tmp   = data + cs2idx({cta_cs.x, idx}, ld);\n                src_data_vec_[s] = reinterpret_cast<const char*>((T*)tmp) + src_offset_;\n            }\n        }\n        else {\n            auto src_data = data + cs2idx(to_cs(pack(offset)), ld);\n            src_data_     = reinterpret_cast<const char*>((T*)src_data) + src_offset_;\n        }\n    }\n\n    __device__ constexpr int _src_step_k() const\n    {\n        return src_step_k_;\n    }\n\n    __device__ void ClearSmem(int pipe_iter = 0)\n    {\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map::kIterS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                const int pred_s = offset_s_ + s * Map::kDeltaS < Map::kDimS;\n                const int pred_c = offset_c_ + c * Map::kDeltaC < Map::kDimC;\n                auto      ptr    = &smem_data_(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);\n                if ((Map::kAlignedC && 
Map::kAlignedS) || (pred_s && pred_c)) {\n                    turbomind::Store(ptr, Array<T, Map::kAccessC>{});\n                }\n            }\n        }\n    }\n\n    __device__ void Advance()\n    {\n        if constexpr (!is_indexed) {\n            if (!g_mask) {\n                src_data_ -= _src_step_k();\n            }\n        }\n    }\n\n    __device__ void Copy(std::true_type, T* dst, const char* __restrict__ src, bool mask)\n    {\n        if (mask) {\n            AccessType frag;\n            if constexpr (Policy_::kEvictPolicy != EvictPolicy::kEvictNormal) {\n                _Ld(frag, (const T*)src);\n            }\n            else {\n                Ldg(frag, (const T*)src);\n            }\n            turbomind::Store(dst, frag);\n        }\n    }\n\n    __device__ void Fetch(Fragments& frags, bool tile_mask)\n    {\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map::kIterS; ++s) {\n\n            if constexpr (is_indexed) {\n                src_data_ = src_data_vec_[s];\n            }\n\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                Copy2(frags[s][c], src_data_ + src_step_c_ * c, tile_mask && g_mask && pred_(s, c));\n            }\n\n            if constexpr (is_indexed) {\n                src_data_vec_[s] += _src_step_k();\n            }\n            else {\n                src_data_ += src_step_s_;\n                if (s == Map::kIterS - 1) {\n                    src_data_ -= src_step_s_ * Map::kIterS;\n                    src_data_ += _src_step_k();\n                }\n            }\n        }\n    }\n\n    __device__ void Store(Fragments& frags)\n    {\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map::kIterS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                // auto dst = &smem_data_(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);\n\n                const int i0  = SmemLayout::apply(  //\n                    s / 
kPeriodS * kPeriodS * Map::kDeltaS,\n                    c / kPeriodC * kPeriodC * Map::kDeltaC);\n                const int i1  = phases_[s % kPeriodS][c % kPeriodC];\n                auto      dst = &smem_data_.ptr_[i0 + i1];\n\n                if (pred_(s, c)) {\n                    turbomind::Store(dst, frags[s][c]);\n                }\n            }\n        }\n    }\n\n    __device__ void Copy2(AccessType& frag, const char* __restrict__ src, bool mask)\n    {\n        if (mask) {\n            if constexpr (Policy_::kEvictPolicy != EvictPolicy::kEvictNormal) {\n                _Ld(frag, (const T*)src);\n            }\n            else {\n                Ldg(frag, (const T*)src);\n            }\n        }\n    }\n};\n\ntemplate<Striding mode, class Policy>\nstruct IteratorSm70 {\n    template<class T, class Map, class SmemLayout, Pack kPack, Order kOrder, bool AlignedC, bool AlignedS>\n    using Type = GmemIteratorSm70<T, Map, SmemLayout, kPack, kOrder, AlignedC, AlignedS, mode, Policy>;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/iterator_sm80.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/gemm/cp_async.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/predicate.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n#include <cassert>\n#include <type_traits>\n\nnamespace turbomind::gemm {\n\ntemplate<class T,\n         class Map,\n         class SmemLayout,\n         Pack     kPack,\n         Order    kOrder,\n         bool     AlignedC,\n         bool     AlignedS,\n         Striding mode,\n         class Policy_>\nstruct GmemIteratorSm80 {\n\n    using ThreadMap = Map;\n\n    using AccessType = Array<T, Map::kAccessC>;\n    using Pointer    = get_pointer_type<T>;\n\n    using Policy = Policy_;\n\n    static constexpr int ITER_S = Map::kIterS;\n    static constexpr int ITER_C = Map::kIterC;\n\n    static constexpr Striding kMode      = mode;\n    static constexpr bool     is_indexed = mode == Striding::kIndexed;\n\n    const char* src_data_;\n\n    int src_offset_;\n    int dst_offset_;\n\n    int offset_c_;\n    int offset_s_;\n\n    int src_step_c_;\n    int src_step_s_;\n\n    int src_step_k_;\n\n    Predicate<Map::kIterS, Map::kIterC, (AlignedC && Map::kAlignedC), (AlignedS && Map::kAlignedS)> pred_;\n\n    bool g_mask{true};\n\n    SmemAccessor<T, SmemLayout> smem_data_;\n\n    static constexpr int2 kMK0     = cs2mk<kOrder>(SmemLayout::C0, SmemLayout::S0);\n    static constexpr int  kPeriodC = ceil_div(SmemLayout::C0, Map::kDeltaC);\n    static constexpr int  kPeriodS = ceil_div(SmemLayout::S0, Map::kDeltaS);\n\n    int phases_[kPeriodS][kPeriodC];\n\n    const char* 
src_data_vec_[ITER_S];\n\n    uint64_t cache_policy_{};\n\n    __device__ static constexpr int2 pack(int2 mk)\n    {\n        return Packing_v2<kPack, kOrder>::apply(mk);\n    }\n\n    __device__ static constexpr int2 to_cs(int2 mk)\n    {\n        return mk2cs<kOrder>(mk.x, mk.y);\n    }\n\n    __device__ GmemIteratorSm80(): smem_data_{Pointer{nullptr}} {};\n\n    __device__ GmemIteratorSm80(const MatrixData& mat, int2 offset, int2 extent): smem_data_{Pointer{(T*)nullptr}}\n    {\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        const Pointer data{(T*)mat.ptr.ptr};\n        const int     ld = mat.ptr.stride;\n\n        const int2 offsets = Map::get_offset(warp_id, lane_id);\n\n        offset_c_ = offsets.x;\n        offset_s_ = offsets.y;\n\n        if constexpr (pred_.is_active) {\n            extent = to_cs(pack(extent));\n            PRAGMA_UNROLL\n            for (int s = 0; s < Map::kIterS; ++s) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < Map::kIterC; ++c) {\n                    int ss = offset_s_ + s * Map::kDeltaS;\n                    int cc = offset_c_ + c * Map::kDeltaC;\n                    if (ss < extent.y && cc < extent.x) {\n                        pred_.set(s, c);\n                    }\n                }\n            }\n        }\n\n        PRAGMA_UNROLL\n        for (int s = 0; s < kPeriodS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < kPeriodC; ++c) {\n                phases_[s][c] = SmemLayout::apply(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);\n            }\n        }\n\n        const int src_offset = is_indexed ? 
offsets.x : offsets.x + offsets.y * ld;\n\n        src_offset_ = src_offset * bitsof<T> / bitsof<char>;\n\n        src_step_c_ = bitsof<T> * Map::kDeltaC / bitsof<char>;\n        src_step_s_ = bitsof<T> * Map::kDeltaS * ld / bitsof<char>;\n\n        src_step_k_ = bitsof<T> * cs2mk<kOrder>(Map::kDimC, Map::kDimS * ld).y / bitsof<char>;\n\n        // Initialize for the first tile\n        if constexpr (is_indexed) {\n            const int2 cta_cs = to_cs(offset);\n            for (int s = 0; s < ITER_S; ++s) {\n                const int  ss    = cta_cs.y + offset_s_ + s * Map::kDeltaS;\n                const int  idx   = (mat.idxs && pred_(s, 0)) ? __ldg(mat.idxs + ss) : ss;\n                const auto tmp   = data + cs2idx({cta_cs.x, idx}, ld);\n                src_data_vec_[s] = reinterpret_cast<const char*>((T*)tmp) + src_offset_;\n            }\n        }\n        else {\n            auto src_data = data + cs2idx(to_cs(pack(offset)), ld);\n            src_data_     = reinterpret_cast<const char*>((T*)src_data) + src_offset_;\n        }\n\n#if TURBOMIND_ARCH_SM80\n        if constexpr (Policy::kEvictPolicy != EvictPolicy::kEvictNormal) {\n            asm volatile(\"createpolicy.fractional.L2::evict_first.b64 %0;\\n\" : \"=l\"(cache_policy_) :);\n        }\n#endif\n    }\n\n    __device__ constexpr int _src_step_k() const\n    {\n        return src_step_k_;\n    }\n\n    __device__ void ClearSmem(int pipe_iter = 0)\n    {\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map::kIterS; ++s) {\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                const int pred_s = offset_s_ + s * Map::kDeltaS < Map::kDimS;\n                const int pred_c = offset_c_ + c * Map::kDeltaC < Map::kDimC;\n                auto      ptr    = &smem_data_(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);\n                if ((Map::kAlignedC && Map::kAlignedS) || (pred_s && pred_c)) {\n                    Store(ptr, Array<T, 
Map::kAccessC>{});\n                }\n            }\n        }\n    }\n\n    __device__ void Prefetch(int begin, int count, bool tile_mask)\n    {\n        PRAGMA_UNROLL\n        for (int s = begin; s < begin + count && s < Map::kIterS; ++s) {\n\n            if constexpr (is_indexed) {\n                src_data_ = src_data_vec_[s];\n            }\n\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map::kIterC; ++c) {\n                // auto dst = &smem_data_(offset_s_ + s * Map::kDeltaS, offset_c_ + c * Map::kDeltaC);\n\n                const int i0  = SmemLayout::apply(  //\n                    s / kPeriodS * kPeriodS * Map::kDeltaS,\n                    c / kPeriodC * kPeriodC * Map::kDeltaC);\n                const int i1  = phases_[s % kPeriodS][c % kPeriodC];\n                auto      dst = &smem_data_.ptr_[i0 + i1];\n\n                CpAsync(std::true_type{}, dst, src_data_ + src_step_c_ * c, tile_mask && g_mask && pred_(s, c));\n            }\n\n            if constexpr (is_indexed) {\n                src_data_vec_[s] += _src_step_k();\n            }\n            else {\n                src_data_ += src_step_s_;\n                if (s == Map::kIterS - 1) {\n                    src_data_ -= src_step_s_ * Map::kIterS;\n                    src_data_ += _src_step_k();\n                }\n            }\n        }\n    }\n\n    __device__ void Prefetch(bool tile_mask)\n    {\n        Prefetch(0, Map::kIterS, tile_mask);\n    }\n\n    __device__ void Advance()\n    {\n        if constexpr (!is_indexed) {\n            if (!g_mask) {\n                src_data_ -= _src_step_k();\n            }\n        }\n    }\n\n    __device__ void CpAsync(std::true_type, T* dst, const char* __restrict__ src, bool mask)\n    {\n#if TURBOMIND_ARCH_SM80\n        constexpr int size = sizeof(AccessType);\n        static_assert(size <= 16);\n\n        constexpr int prefetch_size = std::min(256, size * Map::kWarpThreadC);\n\n        auto ptr = cast_smem_ptr_to_uint(dst);\n\n 
       static constexpr auto cache_op = GetCacheOp<Policy::kCacheOp, size>::value;\n\n        if constexpr (Policy::kEvictPolicy != EvictPolicy::kEvictNormal) {\n            CP_ASYNC<cache_op, size, prefetch_size>::apply(ptr, src, cache_policy_, mask);\n        }\n        else {\n            CP_ASYNC<cache_op, size, prefetch_size>::apply(ptr, src, mask);\n        }\n#else\n        assert(TURBOMIND_ARCH_SM80);\n#endif\n    }\n};\n\ntemplate<Striding mode, class Policy>\nstruct IteratorSm80 {\n    template<class T, class Map, class SmemLayout, Pack kPack, Order kOrder, bool AlignedC, bool AlignedS>\n    using Type = GmemIteratorSm80<T, Map, SmemLayout, kPack, kOrder, AlignedC, AlignedS, mode, Policy>;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/iterator_sm90.h",
    "content": "#pragma once\n\n#include <cute/arch/copy_sm90_desc.hpp>\n#include <cute/arch/copy_sm90_tma.hpp>\n\nnamespace turbomind::gemm {\n\ntemplate<int multicast>\nstruct GmemIteratorSm90 {\n\n    const CUtensorMap* desc_ptr_;\n    int2               offset_;\n    int2               step_;\n\n    __device__ GmemIteratorSm90(const CUtensorMap* desc_ptr, int2 offset, int2 step)\n    {\n        desc_ptr_ = desc_ptr;\n        offset_   = offset;\n        step_     = step;\n    }\n\n    __device__ void Step(uint64_t* mbar_ptr, void* smem_ptr, uint16_t mask, uint64_t cache_hint = 0)\n    {\n        if constexpr (multicast > 1) {\n            cute::SM90_TMA_LOAD_MULTICAST_2D::copy(\n                desc_ptr_, mbar_ptr, mask, cache_hint, smem_ptr, offset_.x, offset_.y);\n        }\n        else {\n            cute::SM90_TMA_LOAD_2D::copy(desc_ptr_, mbar_ptr, cache_hint, smem_ptr, offset_.x, offset_.y);\n        }\n        offset_.x += step_.x;\n        offset_.y += step_.y;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm70_884_16.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch/config_sm70_s884.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm70_s884;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm70_884_16()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_F16<kColMajor, 0>;\n        Add<C::Type<256, 128,  16, 4, 2, 1, D, D, 2,   0 , 1, 1, 128, 128>>();\n        Add<C::Type<128, 256,  16, 2, 4, 1, D, D, 2,   0 , 1, 1, 128, 128>>();\n        Add<C::Type<128, 128,  16, 2, 2, 1, D, D, 2, true, 1, 1,  64, 128>>();\n        Add<C::Type< 96,  64,  32, 2, 2, 1, D, D, 2, true, 1, 1>>();\n        Add<C::Type< 64, 128,  32, 1, 4, 1, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 32, 128,  32, 1, 4, 1, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 16, 128,  64, 1, 4, 1, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 16, 128,  32, 1, 4, 1, D, S, 2, true, 1, 1>>();\n        Add<C::Type<  8, 128,  64, 1, 4, 1, D, S, 2, true, 1, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm70_884_4.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch/config_sm70_s884.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm70_s884;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm70_884_4()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_U4_d<kColMajor>;\n        Add<C::Type<128, 256, 16, 2, 4, 1, D, D, 2, true, 1, 128, 128, 128>>();\n        Add<C::Type<128, 128, 16, 2, 2, 1, D, D, 2, true, 1, 128, 64, 128>>();\n        Add<C::Type<128, 128, 16, 2, 2, 1, D, S, 2, true, 1, 128, 64, 128>>();\n        Add<C::Type< 96, 128, 32, 2, 2, 1, D, S, 2, true, 1, 128, 48, 128>>();\n        Add<C::Type< 64, 128, 32, 2, 2, 1, D, D, 2, true, 1, 128, 32, 128>>();\n        Add<C::Type< 64, 128, 32, 2, 2, 1, D, S, 2, true, 1, 128, 32, 128>>();\n        Add<C::Type< 64, 128, 16, 1, 4, 1, D, S, 2, true, 1, 128, 32, 128>>();\n        Add<C::Type< 64, 256, 16, 1, 4, 1, D, S, 2, true, 1, 128, 64, 128>>();\n        Add<C::Type< 32, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 32, 256, 32, 1, 4, 1, D, S, 2, true, 1, 128, 32, 128>>();\n        Add<C::Type< 16, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16, 256, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type<  8, 128, 64, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type<  8, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type<  8, 256, 64, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_U4_g<kColMajor>;\n        Add<C::Type<128, 256,  16, 2, 4, 1, D, D, 2,   0 , 1, 128, 128, 128>>();\n        Add<C::Type<128, 128,  16, 2, 2, 1, D, D, 2, true, 1, 128,  64, 128>>();\n        Add<C::Type< 64, 128,  
32, 1, 4, 1, D, S, 2, true, 1, 128,  32, 128>>();\n        Add<C::Type< 64, 256,  16, 1, 4, 1, D, S, 2, true, 1, 128,  64, 128>>();\n        Add<C::Type< 32, 128,  32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 32, 256,  32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16, 256,  64, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16, 256,  32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16, 128,  32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type<  8, 128,  64, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_MXF4<kColMajor, 0>;\n        Add<C::Type<128, 128,  16, 2, 2, 1, D, D, 2, true, 1, 32,  64, 128>>();\n        Add<C::Type< 64, 128,  32, 1, 4, 1, D, S, 2, true, 1, 32,  32, 128>>();\n        Add<C::Type< 32, 128,  32, 1, 4, 1, D, S, 2, true, 1, 32>>();\n        Add<C::Type< 16, 128,  32, 1, 4, 1, D, S, 2, true, 1, 32>>();\n        Add<C::Type<  8, 128,  64, 1, 4, 1, D, S, 2, true, 1, 32>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm70_884_8.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch/config_sm70_s884.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm70_s884;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm70_884_8()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_E4M3<kColMajor, 0>;\n        Add<C::Type<128, 128,  16, 2, 2, 1, D, D, 2, true, 1, 128,  64, 128>>();\n        Add<C::Type< 64, 128,  32, 1, 4, 1, D, S, 2, true, 1, 128,  32, 128>>();\n        Add<C::Type< 32, 128,  32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16, 128,  32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type<  8, 128,  64, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm75_16816_16.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/config_sm75_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm75_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm75_16816_16()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_F16<kColMajor, 0>;\n        Add<C::Type<128, 256,  32, 2, 4, 1, D, D, 2,    0, 1, 1, 128, 128>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 2, true, 1, 1,  64, 128>>();\n        Add<C::Type< 96,  64,  64, 2, 2, 1, D, D, 2, true, 1, 1>>();\n        Add<C::Type< 64, 128,  64, 1, 4, 1, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 64,  64, 128, 1, 2, 2, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 32,  64, 128, 1, 2, 2, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 32, 128,  64, 1, 4, 1, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 16,  64, 128, 1, 2, 2, D, S, 2, true, 1, 1>>();\n        Add<C::Type< 16, 128,  64, 1, 4, 1, D, S, 2, true, 1, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm75_16816_4.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch/config_sm75_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm75_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm75_16816_4()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_U4_d<kColMajor>;\n        Add<C::Type<128, 256, 32, 1, 8, 1, D, D, 2, true, 1, 128, 128, 128>>();\n        Add<C::Type<128, 128, 32, 1, 4, 1, D, D, 2, true, 1, 128,  64, 128>>();\n        Add<C::Type< 96,  64, 64, 1, 2, 2, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 64, 128, 32, 1, 4, 1, D, D, 2, true, 1, 128,  32, 128>>();\n        Add<C::Type< 64, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128,  32, 128>>();\n        Add<C::Type< 64,  64, 64, 1, 2, 2, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 48, 128, 64, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 48,  64, 64, 1, 2, 2, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 32,  64, 64, 1, 2, 2, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16,  64, 64, 1, 2, 2, D, S, 2, true, 1, 128>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_U4_g<kColMajor>;\n        Add<C::Type<128, 256,  32, 2, 4, 1, D, D, 2,    0, 1, 128, 128, 128>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 2, true, 1, 128,  64, 128>>();\n        Add<C::Type< 64, 128,  64, 1, 4, 1, D, S, 2, true, 1, 128,  32, 128>>();\n        Add<C::Type< 64, 256,  32, 1, 4, 1, D, S, 2, true, 1, 128,  32, 256>>();\n        Add<C::Type< 32,  64, 128, 1, 2, 2, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 32, 128,  64, 1, 4, 1, D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16, 128,  32, 1, 4, 1, 
D, S, 2, true, 1, 128>>();\n        Add<C::Type< 16,  64,  64, 1, 2, 2, D, S, 2, true, 1, 128>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_MXF4<kColMajor, 1>;\n        Add<C::Type<128, 128, 32, 4, 1, 1, D, D, 2, true, 32, 1, 128, 64>>();\n        Add<C::Type<128,  64, 32, 4, 1, 1, D, D, 2, true, 32, 1>>();\n        Add<C::Type<128,  32, 32, 4, 1, 1, S, D, 2, true, 32, 1>>();\n        Add<C::Type<128,  16, 32, 4, 1, 1, S, D, 2, true, 32, 1>>();\n        Add<C::Type<128,  16, 64, 4, 1, 1, S, D, 2, true, 32, 1>>();\n        Add<C::Type< 64,  16, 64, 4, 1, 1, S, D, 2, true, 32, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm75_16816_8.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/config_sm75_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm75_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm75_16816_8()\n{\n    if constexpr (1) {\n        // clang-format off\n        using Cg = Config_E4M3<kColMajor, 1>;\n        Add<Cg::Type<256, 128,  32, 8, 1, 1, D, D, 3, true, 128, 1, 128, 128>>();\n        Add<Cg::Type<256,  64,  32, 4, 1, 1, D, D, 3, true, 128, 1, 128,  64>>();\n        Add<Cg::Type<128, 128,  32, 4, 1, 1, D, D, 3, true, 128, 1, 128,  64>>();\n        Add<Cg::Type<128,  96,  32, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  64,  32, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  32,  32, 4, 1, 1, S, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  16,  64, 4, 1, 1, S, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  16,  32, 4, 1, 1, S, D, 5, true, 128, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm80_16816_16.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/config_sm80_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm80_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm80_16816_16()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_F16_g<Sm80, half, kColMajor>;\n        Add<C::Type<256, 128,  64, 4, 2, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 256,  64, 2, 4, 1, D, D, 3,   0 , 1, 1>>(); // 10\n        Add<C::Type<128, 256,  32, 2, 4, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 3, true, 1, 1>>(); // 6\n        Add<C::Type<128, 128,  64, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 5, true, 1, 1>>();\n        Add<C::Type< 96,  64,  64, 2, 2, 1, D, D, 3, true, 1, 1>>(); // 2\n        Add<C::Type< 64, 128,  64, 1, 4, 1, D, S, 3, true, 1, 1>>();\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, S, 3, true, 1, 1>>(); // *\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, S, 5, true, 1, 1>>();\n        Add<C::Type< 64,  64, 128, 1, 2, 2, D, S, 3, true, 1, 1>>(); // 4\n        Add<C::Type< 32,  64, 128, 1, 2, 2, D, S, 3, true, 1, 1>>();\n        Add<C::Type< 32, 128,  64, 1, 4, 1, D, S, 3, true, 1, 1>>();\n        Add<C::Type< 16,  64, 128, 1, 2, 2, D, S, 3, true, 1, 1>>(); // 10\n        Add<C::Type< 16, 128,  64, 1, 4, 1, D, S, 3, true, 1, 1>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_F16_g<Sm80, nv_bfloat16, kColMajor>;\n        Add<C::Type<256, 128,  64, 4, 2, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 256,  64, 2, 4, 1, D, D, 3,   0 , 1, 1>>(); // 10\n        Add<C::Type<128, 
256,  32, 2, 4, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 3, true, 1, 1>>(); // 6\n        Add<C::Type<128, 128,  64, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 5, true, 1, 1>>();\n        Add<C::Type< 96,  64,  64, 2, 2, 1, D, D, 3, true, 1, 1>>(); // 2\n        Add<C::Type< 64, 128,  64, 1, 4, 1, D, S, 3, true, 1, 1>>();\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, S, 3, true, 1, 1>>(); // *\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, S, 5, true, 1, 1>>();\n        Add<C::Type< 64,  64, 128, 1, 2, 2, D, S, 3, true, 1, 1>>(); // 4\n        Add<C::Type< 32,  64, 128, 1, 2, 2, D, S, 3, true, 1, 1>>();\n        Add<C::Type< 32, 128,  64, 1, 4, 1, D, S, 3, true, 1, 1>>();\n        Add<C::Type< 16,  64, 128, 1, 2, 2, D, S, 3, true, 1, 1>>(); // 10\n        Add<C::Type< 16, 128,  64, 1, 4, 1, D, S, 3, true, 1, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm80_16816_4.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/config_sm80_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm80_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm80_16816_4()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_U4_d<Sm80, half, kColMajor>;\n        // Add<C::Type<128, 256,  64, 1, 8, 1, D, S, 3, true, 1, 128>>(); // 0/0\n        Add<C::Type<128, 256,  32, 1, 8, 1, D, D, 3, true, 1, 128, 128, 128>>(); // 30/3\n        Add<C::Type<128, 256,  32, 1, 8, 1, D, D, 4, true, 1, 128, 128, 128>>(); // --/20\n        Add<C::Type<128, 128,  32, 1, 4, 1, D, D, 3, true, 1, 128, 64, 128>>();  // --/13\n        Add<C::Type<128, 128,  32, 1, 4, 1, D, S, 4, true, 1, 128, 64, 128>>();  // 21/13\n        Add<C::Type<128, 128,  64, 1, 4, 2, D, S, 3, true, 1, 128, 64, 128>>();  // 6/6\n\n        Add<C::Type<96, 256,  32, 1, 8, 1, D, D, 4, true, 1, 128>>();  // --/3\n        Add<C::Type<96, 256,  32, 1, 8, 1, D, S, 3, true, 1, 128>>();  // 13/13\n        Add<C::Type<96, 128,  32, 1, 4, 1, D, S, 4, true, 1, 128>>();  // 14/10\n        Add<C::Type<96, 128, 128, 1, 4, 2, D, S, 3, true, 1, 128>>();  // 2/2\n\n        Add<C::Type<64, 256,  32, 1, 4, 1, D, D, 3, true, 1, 128, 64, 128>>(); // --/21\n        Add<C::Type<64, 256,  32, 1, 4, 1, D, S, 4, true, 1, 128, 64, 128>>(); // 27/13\n        Add<C::Type<64, 128,  32, 1, 4, 1, D, S, 4, true, 1, 128>>();  // 8/5\n        Add<C::Type<64, 128,  64, 1, 4, 1, D, S, 3, true, 1, 128>>();  // 7/5\n        Add<C::Type<64, 128, 128, 1, 4, 2, D, S, 3, true, 1, 128>>();  // 6/7\n        Add<C::Type<64,  64,  64, 1, 2, 2, D, S, 6, true, 1, 128>>();\n\n        Add<C::Type<48, 256,  64, 1, 4, 1, D, S, 3, true, 1, 
128, 48, 128>>(); // 1/1\n        Add<C::Type<48, 128,  64, 1, 4, 1, D, S, 4, true, 1, 128>>();  // 1/1\n        Add<C::Type<48, 128, 128, 1, 4, 2, D, S, 3, true, 1, 128>>();  // 4/4\n        Add<C::Type<48,  64, 128, 1, 2, 2, D, S, 4, true, 1, 128>>();\n\n        Add<C::Type<32, 256,  64, 1, 4, 1, D, S, 3, true, 1, 128>>();\n        Add<C::Type<32, 128,  64, 1, 4, 1, D, S, 4, true, 1, 128>>();\n        Add<C::Type<32, 128, 128, 1, 4, 2, D, S, 3, true, 1, 128>>();\n        Add<C::Type<32,  64, 128, 1, 2, 2, D, S, 3, true, 1, 128>>();\n        Add<C::Type<32,  64, 128, 1, 2, 2, D, S, 4, true, 1, 128>>();\n\n        Add<C::Type<16, 128,  64, 1, 4, 1, D, S, 4, true, 1, 128>>();\n        Add<C::Type<16, 128, 128, 1, 4, 2, D, S, 3, true, 1, 128>>();\n        Add<C::Type<16, 128, 128, 1, 4, 2, D, S, 4, true, 1, 128>>();\n        Add<C::Type<16,  64, 128, 1, 2, 2, D, S, 3, true, 1, 128>>();\n        Add<C::Type<16,  64, 128, 1, 2, 2, D, S, 4, true, 1, 128>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_U4_g<Sm80, half, kColMajor>;\n        Add<C::Type<128, 256,  32, 2, 4, 1, D, D, 3,   0 , 1, 128>>();  // 10 + 5 + 4 + 10 + 10, 37\n        Add<C::Type<128, 128,  32, 1, 4, 1, D, D, 3, true, 1, 128>>();  // 1 + 6 + 4 + 4 + 2, 3\n        Add<C::Type< 64, 128,  64, 1, 4, 1, D, S, 3, true, 1, 128>>();  // 7 + 4 + 6 + 2 + 4, 26\n        Add<C::Type< 64, 256,  32, 1, 4, 1, D, S, 3, true, 1, 128>>();  // 18\n        Add<C::Type< 32,  64, 128, 1, 2, 2, D, S, 3, true, 1, 128>>();  // 2\n        Add<C::Type< 32, 128,  64, 1, 4, 1, D, S, 5, true, 1, 128>>();  // 1 + 2 + 2 + 2 + 2, 2\n        Add<C::Type< 32, 256,  64, 1, 4, 1, D, S, 3, true, 1, 128>>();  // 9\n        Add<C::Type< 16, 256,  64, 1, 4, 1, D, S, 3, true, 1, 128>>();  // 22\n        Add<C::Type< 16, 256,  32, 1, 4, 1, D, S, 3, true, 1, 128>>();  // 8\n        Add<C::Type< 16, 128,  64, 1, 4, 1, D, S, 3, true, 1, 128>>();  // 1 + 13 + 9 + 13 + 7, 7\n  
      Add<C::Type< 16,  64, 128, 1, 2, 2, D, S, 3, true, 1, 128>>();  // 12 + 2 + 6 + 2 + 8, 42\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using Cd = Config_MXF4<Sm80, bfloat16_t, 16, kColMajor>;\n        // Add<Cd::Type<256, 128, 32, 8, 1, 1, D, D, 3, true, 32, 1, 128, 128>>();\n\n        using Cg = Config_MXF4<Sm80, bfloat16_t, 16, kColMajor, 1>;\n        Add<Cg::Type<256, 128, 32, 8, 1, 1, D, D, 3, true, 32, 1, 128, 128>>();\n        Add<Cg::Type<256,  64, 32, 4, 1, 1, D, D, 3, true, 32, 1, 128,  64>>();\n        Add<Cg::Type<256,  32, 32, 4, 1, 1, S, D, 5, true, 32, 1>>();\n        Add<Cg::Type<128, 128, 32, 4, 1, 1, D, D, 3, true, 32, 1, 128,  64>>();\n        Add<Cg::Type<128,  96, 32, 4, 1, 1, D, D, 3, true, 32, 1>>();\n        Add<Cg::Type<128,  64, 32, 4, 1, 1, S, D, 3, true, 32, 1>>();\n        Add<Cg::Type<128,  32, 32, 4, 1, 1, S, D, 3, true, 32, 1>>();\n        Add<Cg::Type<128,  16, 32, 4, 1, 1, S, D, 5, true, 32, 1>>();\n        Add<Cg::Type<128,  16, 64, 4, 1, 1, S, D, 3, true, 32, 1>>();\n\n        using C8 = Config_MXF4<Sm80, bfloat16_t, 8, kColMajor, 1>;\n        Add<C8::Type<256, 8,  32, 4, 1, 1, S, D, 5, true, 32, 1>>();\n        Add<C8::Type<128, 8,  32, 4, 1, 1, S, D, 5, true, 32, 1>>();\n        Add<C8::Type<128, 8,  64, 4, 1, 1, S, D, 3, true, 32, 1>>();\n        Add<C8::Type< 64, 8,  64, 4, 1, 1, S, D, 5, true, 32, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm80_16816_8.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/config_sm80_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm80_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm80_16816_8()\n{\n    if constexpr (1) {\n        // clang-format off\n        using Cd = Config_E4M3<Sm80, bfloat16_t, 16, kColMajor>;\n        // Add<Cd::Type<256, 128, 32, 8, 1, 1, D, D, 3, true, 128, 1, 128, 128>>();\n\n        using Cg = Config_E4M3<Sm80, bfloat16_t, 16, kColMajor, 1>;\n        Add<Cg::Type<256, 128,  32, 8, 1, 1, D, D, 3, true, 128, 1, 128, 128>>();\n        Add<Cg::Type<256,  64,  32, 4, 1, 1, D, D, 3, true, 128, 1, 128,  64>>();\n        Add<Cg::Type<256,  32,  64, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128, 128,  32, 4, 1, 1, D, D, 3, true, 128, 1, 128,  64>>();\n        Add<Cg::Type<128,  96,  32, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  64,  32, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  32,  32, 4, 1, 1, S, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  16,  64, 4, 1, 1, S, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  16,  32, 4, 1, 1, S, D, 5, true, 128, 1>>();\n\n        using C8 = Config_E4M3<Sm80, bfloat16_t, 8, kColMajor, 1>;\n        Add<C8::Type<256, 8,  64, 4, 1, 1, S, D, 3, true, 128, 1>>();\n        Add<C8::Type<128, 8,  64, 4, 1, 1, S, D, 3, true, 128, 1>>();\n        Add<C8::Type< 64, 8, 128, 4, 1, 1, S, D, 3, true, 128, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm90_16816_16.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/config_sm80_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm80_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm90_16816_16()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_F16_g<Sm90, half, kColMajor>;\n        Add<C::Type<256, 128,  64, 4, 2, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 256,  64, 2, 4, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 256,  32, 2, 4, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type<128, 128,  64, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 5, true, 1, 1>>();\n        Add<C::Type< 96,  64,  64, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 64, 128,  64, 1, 4, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, D, 5, true, 1, 1>>();\n        Add<C::Type< 64,  64, 128, 1, 2, 2, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 32,  64, 128, 1, 2, 2, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 32, 128,  64, 1, 4, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 16,  64, 128, 1, 2, 2, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 16, 128,  64, 1, 4, 1, D, D, 3, true, 1, 1>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_F16_g<Sm90, nv_bfloat16, kColMajor>;\n        Add<C::Type<256, 128,  64, 4, 2, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 256,  64, 2, 4, 1, D, D, 3,   0 , 1, 1>>();\n        Add<C::Type<128, 256,  32, 2, 4, 1, D, D, 3,   0 , 1, 
1>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type<128, 128,  64, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type<128, 128,  32, 2, 2, 1, D, D, 5, true, 1, 1>>();\n        Add<C::Type< 96,  64,  64, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 64, 128,  64, 1, 4, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 64,  64,  64, 2, 2, 1, D, D, 5, true, 1, 1>>();\n        Add<C::Type< 64,  64, 128, 1, 2, 2, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 32,  64, 128, 1, 2, 2, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 32, 128,  64, 1, 4, 1, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 16,  64, 128, 1, 2, 2, D, D, 3, true, 1, 1>>();\n        Add<C::Type< 16, 128,  64, 1, 4, 1, D, D, 3, true, 1, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm90_16816_4.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/config_sm80_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm80_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm90_16816_4()\n{\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_U4_d<Sm90, half, kColMajor>;\n        Add<C::Type<128, 256,  64, 1, 8, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type<128, 256,  32, 1, 8, 1, D, D, 3, true, 1, 128, 128, 128>>();\n        Add<C::Type<128, 256,  32, 1, 8, 1, D, D, 4, true, 1, 128, 128, 128>>();\n        Add<C::Type<128, 128,  32, 1, 4, 1, D, D, 3, true, 1, 128, 64, 128>>();\n        Add<C::Type<128, 128,  32, 1, 4, 1, D, D, 4, true, 1, 128, 64, 128>>();\n        Add<C::Type<128, 128,  64, 1, 4, 2, D, D, 3, true, 1, 128, 64, 128>>();\n\n        Add<C::Type<96, 256,  32, 1, 8, 1, D, D, 4, true, 1, 128>>();\n        Add<C::Type<96, 256,  32, 1, 8, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type<96, 128,  32, 1, 4, 1, D, D, 4, true, 1, 128>>();\n        Add<C::Type<96, 128, 128, 1, 4, 2, D, D, 3, true, 1, 128>>();\n\n        Add<C::Type<64, 256,  32, 1, 4, 1, D, D, 3, true, 1, 128, 64, 128>>();\n        Add<C::Type<64, 256,  32, 1, 4, 1, D, D, 4, true, 1, 128, 64, 128>>();\n        Add<C::Type<64, 128,  32, 1, 4, 1, D, D, 4, true, 1, 128>>();\n        Add<C::Type<64, 128,  64, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type<64, 128, 128, 1, 4, 2, D, D, 3, true, 1, 128>>();\n        Add<C::Type<64,  64,  64, 1, 2, 2, D, D, 6, true, 1, 128>>();\n\n        Add<C::Type<48, 256,  64, 1, 4, 1, D, D, 3, true, 1, 128, 48, 128>>();\n        Add<C::Type<48, 128,  64, 1, 4, 1, D, D, 4, true, 1, 128>>();\n        Add<C::Type<48, 128, 128, 1, 4, 2, D, 
D, 3, true, 1, 128>>();\n        Add<C::Type<48,  64, 128, 1, 2, 2, D, D, 4, true, 1, 128>>();\n\n        Add<C::Type<32, 256,  64, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type<32, 128,  64, 1, 4, 1, D, D, 4, true, 1, 128>>();\n        Add<C::Type<32, 128, 128, 1, 4, 2, D, D, 3, true, 1, 128>>();\n        Add<C::Type<32,  64, 128, 1, 2, 2, D, D, 3, true, 1, 128>>();\n        Add<C::Type<32,  64, 128, 1, 2, 2, D, D, 4, true, 1, 128>>();\n\n        Add<C::Type<16, 128,  64, 1, 4, 1, D, D, 4, true, 1, 128>>();\n        Add<C::Type<16, 128, 128, 1, 4, 2, D, D, 3, true, 1, 128>>();\n        Add<C::Type<16, 128, 128, 1, 4, 2, D, D, 4, true, 1, 128>>();\n        Add<C::Type<16,  64, 128, 1, 2, 2, D, D, 3, true, 1, 128>>();\n        Add<C::Type<16,  64, 128, 1, 2, 2, D, D, 4, true, 1, 128>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using C = Config_U4_g<Sm90, half, kColMajor>;\n        Add<C::Type<128, 256,  32, 2, 4, 1, D, D, 3,   0 , 1, 128>>();\n        Add<C::Type<128, 128,  32, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type< 64, 128,  64, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type< 64, 256,  32, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type< 32,  64, 128, 1, 2, 2, D, D, 3, true, 1, 128>>();\n        Add<C::Type< 32, 128,  64, 1, 4, 1, D, D, 5, true, 1, 128>>();\n        Add<C::Type< 32, 256,  64, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type< 16, 256,  64, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type< 16, 256,  32, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type< 16, 128,  64, 1, 4, 1, D, D, 3, true, 1, 128>>();\n        Add<C::Type< 16,  64, 128, 1, 2, 2, D, D, 3, true, 1, 128>>();\n        // clang-format on\n    }\n\n    if constexpr (1) {\n        // clang-format off\n        using Cd = Config_MXF4<Sm90, bfloat16_t, 16, kColMajor>;\n        // Add<Cd::Type<256, 128, 32, 8, 1, 1, D, D, 3, true, 32, 1, 128, 128>>();\n\n        
using Cg = Config_MXF4<Sm90, bfloat16_t, 16, kColMajor, 1>;\n        Add<Cg::Type<256, 128, 32, 8, 1, 1, D, D, 3, true, 32, 1, 128, 128>>();\n        Add<Cg::Type<256,  64, 32, 4, 1, 1, D, D, 3, true, 32, 1, 128,  64>>();\n        Add<Cg::Type<256,  32, 32, 4, 1, 1, D, D, 5, true, 32, 1>>();\n        Add<Cg::Type<128, 128, 32, 4, 1, 1, D, D, 4, true, 32, 1, 128,  64>>();\n        Add<Cg::Type<128,  96, 32, 4, 1, 1, D, D, 3, true, 32, 1>>();\n        Add<Cg::Type<128,  64, 32, 4, 1, 1, D, D, 3, true, 32, 1>>();\n        Add<Cg::Type<128,  32, 32, 4, 1, 1, D, D, 3, true, 32, 1>>();\n        Add<Cg::Type<128,  16, 32, 4, 1, 1, D, D, 5, true, 32, 1>>();\n        Add<Cg::Type<128,  16, 64, 4, 1, 1, D, D, 3, true, 32, 1>>();\n\n        using C8 = Config_MXF4<Sm90, bfloat16_t, 8, kColMajor, 1>;\n        Add<C8::Type<256, 8,  32, 4, 1, 1, D, D, 5, true, 32, 1>>();\n        Add<C8::Type<128, 8,  32, 4, 1, 1, D, D, 5, true, 32, 1>>();\n        Add<C8::Type<128, 8,  64, 4, 1, 1, D, D, 3, true, 32, 1>>();\n        Add<C8::Type< 64, 8,  64, 4, 1, 1, D, D, 5, true, 32, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm90_16816_8.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/arch/config_sm80_s16816.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nusing namespace sm80_s16816;\nusing namespace cache_policy;\nusing S = cache_policy::Stream;\nusing D = cache_policy::Default;\n\nvoid Registry::sm90_16816_8()\n{\n    if constexpr (1) {\n        // clang-format off\n        using Cd = Config_E4M3<Sm90, bfloat16_t, 16, kColMajor>;\n        // Add<Cd::Type<256, 128, 32, 8, 1, 1, D, D, 3, true, 128, 1, 128, 128>>();\n\n        using Cg = Config_E4M3<Sm90, bfloat16_t, 16, kColMajor, 1>;\n        Add<Cg::Type<256, 128,  32, 8, 1, 1, D, D, 3, true, 128, 1, 128, 128>>();\n        Add<Cg::Type<256,  64,  32, 4, 1, 1, D, D, 3, true, 128, 1, 128,  64>>();\n        Add<Cg::Type<256,  32,  64, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128, 128,  32, 4, 1, 1, D, D, 3, true, 128, 1, 128,  64>>();\n        Add<Cg::Type<128,  96,  32, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  64,  32, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  32,  32, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  16,  64, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<Cg::Type<128,  16,  32, 4, 1, 1, D, D, 5, true, 128, 1>>();\n\n        using C8 = Config_E4M3<Sm90, bfloat16_t, 8, kColMajor, 1>;\n        Add<C8::Type<256, 8,  64, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<C8::Type<128, 8,  64, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        Add<C8::Type< 64, 8, 128, 4, 1, 1, D, D, 3, true, 128, 1>>();\n        // clang-format on\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel/sm90_64n32_8.cu",
    "content": "\n#include <cuda.h>\n\n#include \"src/turbomind/kernels/gemm/registry.h\"\n\n// We need modifiable TMA, which was added in CUDA 12.3\n#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 3))\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/gemm_universal_sm90_v3.h\"\n#include \"src/turbomind/kernels/gemm/gemm_universal_sm90_v5.h\"\n#include \"src/turbomind/kernels/gemm/kernel_impl_sm90.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nvoid Registry::sm90_64n32_8()\n{\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v5<kRowMajor, 1, 1, false>>>());\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v5<kRowMajor, 2, 1, false>>>());\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v5<kRowMajor, 1, 2, false>>>());\n\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v5<kColMajor, 1, 1, true>>>());\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v5<kColMajor, 2, 1, true>>>());\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v5<kColMajor, 1, 2, true>>>());\n\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v3<kRowMajor, 1, 1, false>>>());\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v3<kRowMajor, 2, 1, false>>>());\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v3<kRowMajor, 1, 2, false>>>());\n\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v3<kColMajor, 1, 1, true>>>());\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v3<kColMajor, 2, 1, true>>>());\n    Add(std::make_unique<KernelImplSm90<GemmUniversalSm90_v3<kColMajor, 1, 2, true>>>());\n}\n\n}  // namespace turbomind::gemm\n\n#else\n\nnamespace turbomind::gemm {\n\nvoid Registry::sm90_64n32_8() {}\n\n}  // namespace turbomind::gemm\n\n#endif\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <iostream>\n#include <numeric>\n#include <sstream>\n\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\nbool accept(Striding a, Striding b)\n{\n    if (a == Striding::kBlocked) {\n        switch (b) {\n            case Striding::kBlocked:\n            case Striding::kFlat:\n                return true;\n            default:\n                return false;\n        }\n    }\n    else if (a == Striding::kIndexed) {\n        switch (b) {\n            case Striding::kFlat:\n            case Striding::kBlocked:\n            case Striding::kIndexed:\n                return true;\n            default:\n                return false;\n        }\n    }\n    else {\n        return a == b;\n    }\n}\n\nbool Kernel::is_feasible(const GemmDesc& desc) const noexcept\n{\n    constexpr bool debug = 0;\n\n    if constexpr (debug)\n        printf(\"S\\n\");\n\n    // printf(\"%d %d\\n\", desc.arch, desc_.arch);\n\n    if (!is_arch_compatible(desc_.arch, desc.arch)) {\n        return false;\n    }\n\n    if constexpr (debug)\n        printf(\"S0\\n\");\n\n    if (std::tie(desc.order_a, desc.order_b, desc.order_c) != std::tie(desc_.order_a, desc_.order_b, desc_.order_c)) {\n        return false;\n    }\n\n    if (desc.group_axis >= 0 && desc.group_axis != desc_.group_axis) {\n        return false;\n    }\n\n    if (!(accept(desc_.striding_a, desc.striding_a)     //\n          && accept(desc_.striding_b, desc.striding_b)  //\n          && accept(desc_.striding_c, desc.striding_c))) {\n        return false;\n    }\n\n    if constexpr (debug)\n        printf(\"A\\n\");\n\n    if (std::tie(desc.type_a, desc.type_b, 
desc.type_c) != std::tie(desc_.type_a, desc_.type_b, desc_.type_c)) {\n        return false;\n    }\n\n    if constexpr (debug) {\n        printf(\"B\\n\");\n        printf(\"%X %X %X %X\\n\", desc.pack_a, desc_.pack_a, desc.pack_u, desc_.pack_u);\n    }\n\n    if (std::tie(desc.pack_a, desc.pack_u) != std::tie(desc_.pack_a, desc_.pack_u)) {\n        return false;\n    }\n\n    if constexpr (debug) {\n        printf(\"C\\n\");\n        printf(\"%X %X %X %X\\n\", desc.pack_b, desc_.pack_b, desc.pack_v, desc_.pack_v);\n    }\n\n    if (std::tie(desc.pack_b, desc.pack_v) != std::tie(desc_.pack_b, desc_.pack_v)) {\n        return false;\n    }\n\n    if constexpr (debug)\n        printf(\"D\\n\");\n\n    if (desc.quant_a.type != desc_.quant_a.type || desc.quant_a.group_size != desc_.quant_a.group_size) {\n        return false;\n    }\n\n    if constexpr (debug)\n        printf(\"E\\n\");\n\n    if (desc.quant_b.type != desc_.quant_b.type || desc.quant_b.group_size != desc_.quant_b.group_size) {\n        return false;\n    }\n\n    if constexpr (debug)\n        printf(\"F\\n\");\n\n    if (desc.m % desc_.align.x || desc.n % desc_.align.y || desc.k % desc_.align.z) {\n        return false;\n    }\n\n    if constexpr (debug)\n        printf(\"success\\n\");\n\n    return true;\n}\n\n//  mm:     m * n * k,     m * k,     n * k,     m * n\n// Bmm: b * m * n * k, b * m * k, b * n * k, b * m * n\n// Gmm: S $ M * n * k, S $ M * k, S $ n * k, S $ M * n\n\nstd::string Kernel::GetName() const\n{\n    std::stringstream ss;\n\n    ss << \"sm\" << desc_.arch / 10;\n    ss << \"_\" << to_string(desc_.type_a);  //\n    if (desc_.quant_a) {\n        ss << to_string(desc_.quant_a);\n    }\n    ss << \"_\" << to_string(desc_.type_b);  //\n    if (desc_.quant_b) {\n        ss << to_string(desc_.quant_b);\n    }\n    ss << \"_\" << to_string(desc_.type_c);\n    ss << \"_\"                                        //\n       << (desc_.order_a == kColMajor ? 
'n' : 't')   //\n       << (desc_.order_b == kColMajor ? 'n' : 't')   //\n       << (desc_.order_c == kColMajor ? 'n' : 't');  //\n    ss << \"_\"                                        //\n       << to_string(desc_.striding_a)                //\n       << to_string(desc_.striding_b)                //\n       << to_string(desc_.striding_c);\n    ss << \"_\" << desc_.cta_tile.x << \"x\" << desc_.cta_tile.y << \"x\" << desc_.cta_tile.z  //\n       << \"_\" << desc_.stages                                                            //\n       << \"_\" << desc_.cluster_shape.x << \"x\" << desc_.cluster_shape.y                   //\n       << \"_\" << to_string(desc_.op_class)                                               //\n       << \"_\" << desc_.mma_tile.x << \"x\" << desc_.mma_tile.y << \"x\" << desc_.mma_tile.z;\n    if (desc_.group_axis >= 0) {\n        ss << \"_\"\n           << \"mn\"[desc_.group_axis] << \"group\";\n    }\n    ss << \"_c\" << desc_.c_tile.x << \"x\" << desc_.c_tile.y                        //\n       << \"_a\" << desc_.align.x << \"x\" << desc_.align.y << \"x\" << desc_.align.z  //\n       << \"_\" << desc_.policy_a << desc_.policy_b;\n\n    return ss.str();\n}\n\nclass TransposedKernel: public Kernel {\npublic:\n    explicit TransposedKernel(Kernel& kernel): kernel_(&kernel)\n    {\n        desc_ = kernel.desc();\n        info_ = kernel.info();\n\n        desc_.transpose = !desc_.transpose;\n    }\n\n    int Launch(const Operation&    operation,\n               float               alpha,\n               const void*         A,\n               const MatrixLayout& Adesc,\n               const void*         U,\n               const MatrixLayout& Udesc,\n               const void*         B,\n               const MatrixLayout& Bdesc,\n               const void*         V,\n               const MatrixLayout& Vdesc,\n               float               beta,\n               const void*         C,\n               const MatrixLayout& Cdesc,\n           
    void*               D,\n               const MatrixLayout& Ddesc,\n               int                 swizzle,\n               int                 splits,\n               Workspace&          workspace,\n               cudaStream_t        stream) override\n    {\n        return kernel_->Launch(transpose(operation),\n                               alpha,\n                               B,\n                               transpose(Bdesc),\n                               V,\n                               transpose(Vdesc),\n                               A,\n                               transpose(Adesc),\n                               U,\n                               transpose(Udesc),\n                               beta,\n                               C,\n                               transpose(Cdesc),\n                               D,\n                               transpose(Ddesc),\n                               swizzle,\n                               splits,\n                               workspace,\n                               stream);\n    }\n\n    bool is_feasible(const GemmDesc& desc) const noexcept override\n    {\n        return kernel_->is_feasible(desc);\n    }\n\n    int GetMaxSwizzle(const int4& shape) const override\n    {\n        return kernel_->GetMaxSwizzle(shape);\n    }\n\n    int GetMaxSplits(const int4& shape, int swizzle, size_t bsize, size_t psize) const override\n    {\n        return kernel_->GetMaxSplits(shape, swizzle, bsize, psize);\n    }\n\nprivate:\n    Kernel* kernel_;\n};\n\nstd::unique_ptr<Kernel> transpose(Kernel& kernel)\n{\n    return std::make_unique<TransposedKernel>(kernel);\n}\n\ntemplate<class Op>\ninline static bool cmp(const int3& a, const int3& b, Op op)\n{\n    return op(std::tie(a.x, a.y, a.z), std::tie(b.x, b.y, b.z));\n}\n\nstd::vector<std::vector<LaunchSpec>> Cluster(const std::vector<LaunchSpec>& specs, const ClusteringParam& param)\n{\n    std::vector<const LaunchSpec*> ptrs;  // pointer into 
`specs`\n    for (auto& s : specs) {\n        ptrs.push_back(&s);\n    }\n\n    auto less = [&](const LaunchSpec* u, const LaunchSpec* v) {\n        const auto& a = u->kernel->desc();\n        const auto& b = v->kernel->desc();\n        if (!cmp(a.cta_tile, b.cta_tile, std::equal_to<>{})) {\n            return cmp(a.cta_tile, b.cta_tile, std::less<>{});\n        }\n        if (!cmp(a.mma_tile, b.mma_tile, std::equal_to<>{})) {\n            return cmp(a.mma_tile, b.mma_tile, std::less<>{});\n        }\n        if (param.cache_policy) {\n            const auto pa = std::tie(a.policy_a, a.policy_b);\n            const auto pb = std::tie(b.policy_a, b.policy_b);\n            if (pa != pb) {\n                return pa < pb;\n            }\n        }\n        if (param.max_active_ctas) {\n            const auto& a = u->kernel->info();\n            const auto& b = v->kernel->info();\n            if (a.max_active_ctas != b.max_active_ctas) {\n                return a.max_active_ctas < b.max_active_ctas;\n            }\n        }\n        return u->splits < v->splits;\n    };\n\n    std::stable_sort(ptrs.begin(), ptrs.end(), less);\n\n    if (ptrs.empty()) {\n        return {};\n    }\n    std::vector<std::vector<LaunchSpec>> clusters{{*ptrs[0]}};\n\n    auto equal = [&](const LaunchSpec* u, const LaunchSpec* v) {  //\n        return !less(u, v) && !less(v, u);\n    };\n    int p = 0;\n    for (size_t i = 1; i < ptrs.size(); ++i) {\n        if (equal(ptrs[p], ptrs[i])) {\n            clusters.back().push_back(*ptrs[i]);\n        }\n        else {\n            clusters.push_back({*ptrs[i]});\n            p = i;\n        }\n    }\n\n    return clusters;\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <array>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nstruct KernelMetric {\n    int64_t mio_cost;\n    int64_t mma_cost;\n};\n\nclass Kernel {\npublic:\n    Kernel(): desc_{}, info_{} {}\n\n    virtual ~Kernel() = default;\n\n    virtual int Launch(const Operation&    operation,\n                       float               alpha,\n                       const void*         A,\n                       const MatrixLayout& Adesc,\n                       const void*         U,\n                       const MatrixLayout& Udesc,\n                       const void*         B,\n                       const MatrixLayout& Bdesc,\n                       const void*         V,\n                       const MatrixLayout& Vdesc,\n                       float               beta,\n                       const void*         C,\n                       const MatrixLayout& Cdesc,\n                       void*               D,\n                       const MatrixLayout& Ddesc,\n                       int                 swizzle,\n                       int                 splits,\n                       Workspace&          workspace,\n                       cudaStream_t        stream) = 0;\n\n    // true if this kernel can be used to compute the gemm\n    virtual bool is_feasible(const GemmDesc& desc) const noexcept;\n\n    virtual int GetMaxSwizzle(const int4& shape) const = 0;\n\n    virtual int GetMaxSplits(const int4& shape, int swizzle, size_t bsize, size_t psize) const = 0;\n\n    const KernelDesc& desc() const noexcept\n    {\n        return desc_;\n    }\n\n    const KernelInfo& info() const noexcept\n    {\n        return info_;\n    }\n\n    int3 cta_tile_size() const noexcept\n    {\n        
return desc_.cta_tile;\n    }\n\n    int3 warp_tile_size() const noexcept\n    {\n        return desc_.mma_tile;\n    }\n\n    int chunk_size_k() const noexcept\n    {\n        return info_.chunk_size_k;\n    }\n\n    int stages() const noexcept\n    {\n        return desc_.stages;\n    }\n\n    bool split_k() const noexcept\n    {\n        return desc_.split_k;\n    }\n\n    int arch() const noexcept\n    {\n        return desc_.arch;\n    }\n\n    int smem_size() const noexcept\n    {\n        return info_.attr.sharedSizeBytes + info_.dynamic_smem_size;\n    }\n\n    std::string name() const\n    {\n        return info_.name;\n    }\n\nprotected:\n    std::string GetName() const;\n\n    KernelDesc desc_;\n    KernelInfo info_;\n};\n\nstruct ClusteringParam {\n    bool cache_policy;\n    bool max_active_ctas;\n};\n\nstd::vector<std::vector<LaunchSpec>> Cluster(const std::vector<LaunchSpec>& specs, const ClusteringParam& param);\n\nstd::unique_ptr<Kernel> transpose(Kernel& kernel);\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel_impl.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n\n#include \"src/turbomind/kernels/gemm/context.h\"\n#include \"src/turbomind/kernels/gemm/cta_map.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/epilogue.h\"\n#include \"src/turbomind/kernels/gemm/gemm_universal.h\"\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/thread_group_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class Gemm>\nclass KernelImpl: public Kernel {\npublic:\n    // import frequently used constants\n    static constexpr int CTA_M = Gemm::CTA_M;\n    static constexpr int CTA_N = Gemm::CTA_N;\n    static constexpr int CTA_K = Gemm::CTA_K;\n\n    using Impl  = typename Gemm::Impl;\n    using Sched = typename Gemm::Scheduler;\n\n    using OpA = typename Gemm::OperandA;\n    using OpB = typename Gemm::OperandB;\n    using OpU = typename Gemm::OperandU;\n    using OpV = typename Gemm::OperandV;\n\n    KernelImpl()\n    {\n        desc_.order_a = OpA::kOrder;\n        desc_.order_b = transpose(OpB::kOrder);\n        desc_.order_c = Gemm::kOrderC;\n\n        desc_.type_a = data_type_v<typename Gemm::Ta>;\n        desc_.type_b = data_type_v<typename Gemm::Tb>;\n        desc_.type_c = data_type_v<typename Gemm::Tc>;\n\n        using IterA = typename OpA::GmemIter;\n        using IterB = typename OpB::GmemIter;\n\n        desc_.striding_a = IterA::kMode;\n        desc_.striding_b = IterB::kMode;\n        desc_.striding_c = Gemm::Epilogue::kMode;\n\n        desc_.pack_a = OpA::kPack;\n        desc_.pack_b = OpB::kPack;\n        desc_.pack_u 
= OpU::kPack;\n        desc_.pack_v = OpV::kPack;\n\n        desc_.quant_a = QuantDesc{};\n        desc_.quant_b = QuantDesc{};\n\n        if constexpr (OpU::SmemLayout::kSize > 1) {\n            desc_.quant_a = QuantDesc{QuantType::kDefault, OpU::kGroupSize};\n        }\n\n        if constexpr (OpV::SmemLayout::kSize > 1) {\n            desc_.quant_b = QuantDesc{QuantType::kDefault, OpV::kGroupSize};\n        }\n\n        desc_.cta_tile = {Gemm::CTA_M, Gemm::CTA_N, Gemm::CTA_K};\n        desc_.mma_tile = {Impl::MMA_Map::kGroupM, Impl::MMA_Map::kGroupN, Impl::MMA_Map::kGroupK};\n\n        info_.chunk_size_k = Gemm::kChunkSizeK;\n\n        desc_.align.x = OpA::kOrder == kColMajor ? IterA::ThreadMap::kAccessC : 1;\n        desc_.align.y = OpB::kOrder == kColMajor ? IterB::ThreadMap::kAccessC : 1;\n        desc_.align.z = Gemm::CTA_K;\n\n        desc_.policy_a = (int)IterA::Policy::kEvictPolicy;\n        desc_.policy_b = (int)IterB::Policy::kEvictPolicy;\n        desc_.c_tile   = {Gemm::Epilogue::TM, Gemm::Epilogue::TN};\n        desc_.op_class = Impl::kOpClass;\n\n        desc_.cluster_shape = {1, 1};\n\n        auto func = gemm_kernel<Gemm, GemmParam, EpilogueParam, Sched>;\n\n        cudaFuncGetAttributes(&info_.attr, func);\n\n        info_.dynamic_smem_size = sizeof(typename Gemm::SharedStorage);\n\n        if (info_.dynamic_smem_size > (48 << 10)) {\n            cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, info_.dynamic_smem_size);\n        }\n\n        cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n            &info_.max_active_ctas, func, Impl::WARPS * WARP_SIZE, info_.dynamic_smem_size);\n\n        desc_.stages     = Impl::Stages;\n        desc_.split_k    = Gemm::kSplitK;\n        desc_.group_axis = Sched::group_axis;\n\n        desc_.arch = Gemm::Arch::value;\n\n        info_.name = GetName();\n    }\n\n    int Launch(const Operation&    operation,\n               float               alpha,\n               const void*         
A,\n               const MatrixLayout& _Adesc,\n               const void*         U,\n               const MatrixLayout& Udesc,\n               const void*         B,\n               const MatrixLayout& _Bdesc,\n               const void*         V,\n               const MatrixLayout& _Vdesc,\n               float               beta,\n               const void*         C,\n               const MatrixLayout& Cdesc,\n               void*               D,\n               const MatrixLayout& Ddesc,\n               int                 swizzle,\n               int                 splits,\n               Workspace&          workspace,\n               cudaStream_t        stream) override\n    {\n        MatrixLayout Adesc = _Adesc;\n\n        const int m = Ddesc.rows;\n        const int n = Ddesc.cols;\n        const int k = Adesc.cols;\n        const int l = std::max(1, Ddesc.num);\n\n        auto transpose = [](MatrixLayout x) {\n            std::swap(x.rows, x.cols);\n            x.order = gemm::transpose(x.order);\n            return x;\n        };\n\n        MatrixLayout Bdesc = transpose(_Bdesc);\n        MatrixLayout Vdesc = transpose(_Vdesc);\n\n        auto max_splits = GetMaxSplits({m, n, k, l}, swizzle, workspace.barriers_size, workspace.partials_size);\n\n        Sched sched{{m, n, k, l}, swizzle, std::min(splits, max_splits)};\n        sched.offsets_ = Ddesc.offsets;\n\n        using Ta = typename Gemm::Ta;\n        using Tb = typename Gemm::Tb;\n        using Tc = typename Gemm::Tc;\n\n        if constexpr (0) {\n            [[maybe_unused]] static const int _ = [] {\n                std::cout << \"A:\\n\";\n                Print(typename Gemm::OperandA::GmemIter::ThreadMap{});\n                std::cout << \"\\nB:\\n\";\n                Print(typename Gemm::OperandB::GmemIter::ThreadMap{});\n                if constexpr (!std::is_same_v<Ta, Tc>) {\n                    std::cout << \"\\nU:\\n\";\n                    Print(typename 
Gemm::OperandU::GmemIter::ThreadMap{});\n                }\n                if constexpr (!std::is_same_v<Tb, Tc>) {\n                    std::cout << \"\\nV:\\n\";\n                    Print(typename Gemm::OperandV::GmemIter::ThreadMap{});\n                }\n                printf(\"warp count: %d\\n\", Impl::WARPS);\n                Print_(typename Gemm::Impl::MMA_Map{});\n\n                printf(\"C:\\n\");\n                Print(typename Gemm::Epilogue::Map{});\n\n                std::cout << \"Smem for mainloop: \" << sizeof(Gemm::SharedStorage::mainloop) << \"\\n\";\n                std::cout << \"Smem for epilogue: \" << sizeof(Gemm::SharedStorage::epilogue) << \"\\n\";\n\n                return 0;\n            }();\n        }\n\n        const bool silu_act = ((int)operation.epilogue & (int)Epilogue::kGatedSilu);\n\n        MatrixLayout Pdesc = Ddesc;\n        Pdesc.ld           = mk2cs<Gemm::kOrderC>(Pdesc.rows, Pdesc.cols).x;\n\n        MatrixCombination_v3 combin_mat{to_param((void*)C, Cdesc), alpha, beta};\n\n        EpilogueParam epilogue{to_param((void*)D, Ddesc),\n                               to_param((void*)workspace.partials, Pdesc),\n                               (int*)workspace.barriers,\n                               combin_mat,\n                               silu_act};\n\n        // std::cout << Adesc.offsets << \" \" << Adesc.idxs << \"\\n\";\n\n        GemmParam param{\n            to_param((void*)A, Adesc),\n            to_param((void*)B, Bdesc),\n            to_param((void*)U, Udesc),\n            to_param((void*)V, Vdesc),\n        };\n\n        const auto grid  = sched.get_grid_shape();\n        const auto block = Gemm::Impl::WARPS * WARP_SIZE;\n\n        // std::cout << info_.name << \" \" << splits << \" \" << swizzle << \" \" << sched.tiles_[0] << \" \" <<\n        // sched.tiles_[1]\n        //           << std::endl;\n        // std::cout << grid.x << \" \" << grid.y << \" \" << grid.z << \"\\n\";\n\n        
gemm_kernel<Gemm><<<grid, block, info_.dynamic_smem_size, stream>>>(param, epilogue, sched);\n\n        return 0;\n    }\n\n    std::array<size_t, 2> GetWorkspaceSize(int tiles, int splits) const\n    {\n        static constexpr bool kSerial = true;\n\n        size_t barriers_size = sizeof(int) * tiles;\n        size_t partials_size = sizeof(float) * CTA_M * CTA_N * tiles;\n\n        if constexpr (!kSerial) {\n            barriers_size *= splits;\n            partials_size *= splits;\n        }\n\n        return {barriers_size, partials_size};\n    }\n\n    int GetMaxSplits(const int4& shape, int swizzle, size_t bsize, size_t psize) const override\n    {\n        if (!Gemm::kSplitK) {\n            return 1;\n        }\n\n        const auto& [m, n, k, l] = shape;\n\n        Sched sched{{m, n, k, l}, swizzle};  // for getting padded tiles\n\n        const auto& [a, b] = GetWorkspaceSize(sched.tiles_[0] * sched.tiles_[1], 1);\n\n        if (bsize >= a && psize >= b) {\n            // Serial split-k requires workspace for 1 split only\n            // But it can't exceed num of k chunks\n            return cdiv(k, Gemm::kChunkSizeK);\n        }\n        else {\n            return 1;\n        }\n    }\n\n    int GetMaxSwizzle(const int4& shape) const override\n    {\n        const auto& [m, n, k, l] = shape;\n\n        auto swizzle = Sched{{m, n, k, l}}.get_max_swizzle();\n        // std::cout << m << \" \" << n << \" \" << k << \" \" << l << \" \" << swizzle << \"\\n\";\n        return swizzle;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/kernel_impl_sm90.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"cute/util/debug.hpp\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/gemm/context.h\"\n#include \"src/turbomind/kernels/gemm/cta_map.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/epilogue.h\"\n#include \"src/turbomind/kernels/gemm/gemm_universal_sm90_v3.h\"\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/matrix_ptr.h\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/thread_group_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\n#include \"src/turbomind/kernels/gemm/tma.h\"\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n\n#define TM_GEMM_CUTLASS_NAME 0\n\n#if TM_GEMM_CUTLASS_NAME\n#define gemm_kernel_name cutlass_gemm_kernel_sm90\n#else\n#define gemm_kernel_name gemm_kernel_sm90\n#endif\n\nnamespace turbomind::gemm {\n\nextern __shared__ char smem_buf[];\n\ntemplate<class Kernel>\n__global__ void __launch_bounds__(Kernel::CTA_SIZE, 1) gemm_kernel_name(const __grid_constant__ CUtensorMap tm_a,\n                                                                        const __grid_constant__ CUtensorMap tm_b,\n                                                                        const __grid_constant__ CUtensorMap tm_c,\n                                                                        const __grid_constant__ CUtensorMap tm_u,\n                                                                        const __grid_constant__ CUtensorMap tm_v,\n                                                                        const MatrixParam                   param_A,\n                                                                        const MatrixParam           
        param_B,\n                                                                        const MatrixParam                   param_U,\n                                                                        const MatrixParam                   param_V,\n                                                                        const MatrixParam                   param_C,\n                                                                        typename Kernel::Scheduler          sched,\n                                                                        void* tensormap_buf)\n{\n\n#if __CUDA_ARCH__\n    if constexpr (Kernel::Arch::is_compatible(__CUDA_ARCH__)) {\n        Kernel kernel;\n        kernel(tm_a,\n               tm_b,\n               tm_c,\n               tm_u,\n               tm_v,\n               param_A,\n               param_B,\n               param_U,\n               param_V,\n               param_C,\n               sched,\n               (CUtensorMap*)tensormap_buf,\n               smem_buf);\n    }\n#endif\n}\n\ntemplate<class Gemm>\nclass KernelImplSm90: public Kernel {\npublic:\n    // import frequently used constants\n    static constexpr int TILE_M = Gemm::TILE_M;\n    static constexpr int TILE_N = Gemm::TILE_N;\n    static constexpr int TILE_K = Gemm::TILE_K;\n\n    static constexpr auto is_grouped_gemm = Gemm::is_grouped_gemm;\n\n    KernelImplSm90()\n    {\n        desc_.order_a = kRowMajor;  // m, k\n        desc_.order_b = kColMajor;  // k, n\n        desc_.order_c = kRowMajor;\n\n        desc_.type_a = data_type_v<typename Gemm::Ta>;\n        desc_.type_b = data_type_v<typename Gemm::Tb>;\n        desc_.type_c = data_type_v<typename Gemm::Tc>;\n\n        desc_.striding_a = {is_grouped_gemm ? Striding::kBlocked : Striding::kFlat};  // IterA::kMode;\n        desc_.striding_b = {is_grouped_gemm ? Striding::kBlocked : Striding::kFlat};  // IterB::kMode;\n        desc_.striding_c = {is_grouped_gemm ? 
Striding::kBlocked : Striding::kFlat};  // Gemm::Epilogue::kMode;\n\n        desc_.pack_a = {};  // OpA::kPack;\n        desc_.pack_b = {};  // OpB::kPack;\n        desc_.pack_u = {};  // OpU::kPack;\n        desc_.pack_v = {};  // OpV::kPack;\n\n        desc_.quant_a = QuantDesc{QuantType::kK, 128};\n        desc_.quant_b = QuantDesc{QuantType::kB, 128};\n\n        desc_.cta_tile = {TILE_M, TILE_N, TILE_K};\n        desc_.mma_tile = {1, 1, 1};\n\n        info_.chunk_size_k = Gemm::TILE_K;\n\n        desc_.align.x = 1;  // OpA::kOrder == kColMajor ? IterA::ThreadMap::kAccessC : 1;\n        desc_.align.y = 1;  // OpB::kOrder == kColMajor ? IterB::ThreadMap::kAccessC : 1;\n        desc_.align.z = 1;  // Gemm::TILE_K;\n\n        desc_.policy_a = 0;                 // (int)IterA::Policy::kEvictPolicy;\n        desc_.policy_b = 0;                 // (int)IterB::Policy::kEvictPolicy;\n        desc_.c_tile   = {TILE_M, TILE_N};  // {Gemm::Epilogue::TM, Gemm::Epilogue::TN};\n        desc_.op_class = OpClass::kGMMA_s64n16;\n\n        desc_.cluster_shape = {Gemm::Cluster::M, Gemm::Cluster::N};\n\n        info_.dynamic_smem_size = Gemm::kSmemSize;\n\n        desc_.stages     = Gemm::Stages;\n        desc_.split_k    = 1;  // Gemm::kSplitK;\n        desc_.group_axis = is_grouped_gemm ? 
0 : -1;\n\n        desc_.arch = Gemm::Arch::value;\n\n        auto func = gemm_kernel_name<Gemm>;\n\n        cudaFuncGetAttributes(&info_.attr, func);\n\n        if (info_.dynamic_smem_size > (48 << 10)) {\n            cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, info_.dynamic_smem_size);\n        }\n\n        if (1) {\n            cudaFuncSetAttribute(func, cudaFuncAttributeNonPortableClusterSizeAllowed, 16);\n        }\n\n        cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n            &info_.max_active_ctas, func, Gemm::CTA_SIZE, info_.dynamic_smem_size);\n\n        sm_count_ = getSMCount();\n\n        info_.name = GetName();\n    }\n\n    int Launch(const Operation&    operation,\n               float               alpha,\n               const void*         A,\n               const MatrixLayout& _Adesc,\n               const void*         U,\n               const MatrixLayout& Udesc,\n               const void*         B,\n               const MatrixLayout& _Bdesc,\n               const void*         V,\n               const MatrixLayout& _Vdesc,\n               float               beta,\n               const void*         C,\n               const MatrixLayout& Cdesc,\n               void*               D,\n               const MatrixLayout& Ddesc,\n               int                 swizzle,\n               int                 splits,\n               Workspace&          workspace,\n               cudaStream_t        stream) override\n    {\n        using Sched = typename Gemm::Scheduler;\n\n        MatrixLayout Adesc = _Adesc;\n\n        [[maybe_unused]] const int m = Ddesc.rows;\n        [[maybe_unused]] const int n = Ddesc.cols;\n        [[maybe_unused]] const int k = Adesc.cols;\n\n        // std::cout << \"M: \" << m << \", N: \" << n << \", K: \" << k << \"\\n\";\n\n        auto transpose = [](MatrixLayout x) {\n            std::swap(x.rows, x.cols);\n            x.order = gemm::transpose(x.order);\n            return x;\n    
    };\n\n        // (K, N) -> (N, K)\n        MatrixLayout Bdesc = transpose(_Bdesc);\n        MatrixLayout Vdesc = transpose(_Vdesc);\n\n        auto sched = [&] {\n            const int2 tiles = get_tiled_shape(m, n, TILE_M, TILE_N);\n            const int4 shape{m, n, k, Adesc.num};\n\n            swizzle = Sched::get_log_tile(tiles, 1 << swizzle);\n\n            Sched sched{};\n            sched.init(shape, swizzle, {TILE_M, TILE_N, TILE_K});\n\n            sched.next_cluster_id_ = TM_CHECK_NOTNULL(workspace.flags);\n\n            sched.offsets_ = Adesc.offsets;\n\n            return sched;\n        }();\n\n        constexpr int kMulticastA = Gemm::kMulticastA;\n        constexpr int kMulticastB = Gemm::kMulticastB;\n        constexpr int kMulticastU = Gemm::kMulticastU;\n\n        constexpr int kTileM = Gemm::TILE_M;\n        constexpr int kTileN = Gemm::TILE_N;\n\n        if (Gemm::Scheduler::is_dynamic) {\n            check_cuda_error(cudaMemsetAsync(workspace.flags, 0, sizeof(int), stream));\n        }\n\n        // std::cout << \"A: \" << Adesc << \"\\n\";\n        auto tm_a = make_2d_tma_desc((void*)A, Adesc, {kTileM / kMulticastA, TILE_K}, CU_TENSOR_MAP_SWIZZLE_128B);\n\n        // std::cout << \"B: \" << Bdesc << \"\\n\";\n        auto tm_b = make_2d_tma_desc(Gemm::is_grouped_gemm ? 
nullptr : (void*)B,\n                                     Bdesc,\n                                     {kTileN / kMulticastB, TILE_K},\n                                     CU_TENSOR_MAP_SWIZZLE_128B);\n\n        // std::cout << \"C: \" << Cdesc << \"\\n\";\n        using LayoutC = typename Gemm::LayoutC;\n        auto tm_c     = make_2d_tma_desc((void*)C, Cdesc, {LayoutC::S0, LayoutC::C0}, get_tma_swizzle(Gemm::kSwizzleC));\n\n        CUtensorMap tm_u{};\n        if (U) {\n            // std::cout << \"U: \" << Udesc << \"\\n\";\n            tm_u = make_2d_tma_desc((void*)U, Udesc, {Gemm::kBoxU / kMulticastU, 1}, CU_TENSOR_MAP_SWIZZLE_NONE);\n        }\n\n        CUtensorMap            tm_v{};\n        [[maybe_unused]] uint2 box_v{};\n        if (V) {\n            // std::cout << \"V: \" << Vdesc << \"\\n\";\n            // box_v = {(uint32_t)round_up(cdiv(k, 128), 4), 2};\n            // std::cout << \"V: \" << Vdesc << \", box: \" << box_v.x << \",\" << box_v.y << \"\\n\";\n            // tm_v = make_2d_tma_desc((void*)V, Vdesc, {box_v.y, box_v.x}, CU_TENSOR_MAP_SWIZZLE_NONE);\n        }\n\n        const int sm_count = sm_count_;\n\n        static constexpr int cluster_size = Gemm::kClusterSize;\n\n        auto       grid  = sm_count / cluster_size * cluster_size;\n        const auto block = Gemm::CTA_SIZE;\n\n        cudaLaunchConfig_t config{};\n        config.gridDim          = grid;\n        config.blockDim         = block;\n        config.dynamicSmemBytes = info_.dynamic_smem_size;\n        config.stream           = stream;\n\n        auto func = gemm_kernel_name<Gemm>;\n\n        [[maybe_unused]] static bool _ = [&] {\n            int max_cluster_size = 0;\n            cudaOccupancyMaxPotentialClusterSize(&max_cluster_size, func, &config);\n            // std::cout << \"max cluster size: \" << max_cluster_size << \"\\n\";\n            return false;\n        }();\n\n        cudaLaunchAttribute attrs[1];\n\n        attrs[0].id               = 
cudaLaunchAttributeClusterDimension;\n        attrs[0].val.clusterDim.x = cluster_size;\n        attrs[0].val.clusterDim.y = 1;\n        attrs[0].val.clusterDim.z = 1;\n\n        config.attrs    = attrs;\n        config.numAttrs = std::size(attrs);\n\n        int max_active_cluster{};\n        cudaOccupancyMaxActiveClusters(&max_active_cluster, func, &config);\n        config.gridDim = std::min<int>(config.gridDim.x, max_active_cluster * cluster_size);\n\n        // std::cout << \"max active cluster: \" << max_active_cluster << \"\\n\";\n\n        // std::cout << \"swizzle: \" << swizzle << \", split: \" << splits << \"\\n\";\n\n        auto ec = cudaLaunchKernelEx(&config,\n                                     func,\n                                     tm_a,\n                                     tm_b,\n                                     tm_c,\n                                     tm_u,\n                                     tm_v,\n                                     to_param((void*)A, Adesc),\n                                     to_param((void*)B, Bdesc),\n                                     to_param((void*)U, Udesc),\n                                     to_param((void*)V, Vdesc),\n                                     to_param((void*)D, Ddesc),\n                                     sched,\n                                     workspace.tensormaps);\n        TM_CHECK_EQ(ec, cudaSuccess) << cudaGetErrorString(ec);\n\n        return 0;\n    }\n\n    std::array<size_t, 2> GetWorkspaceSize(int tiles, int splits) const\n    {\n        static constexpr bool kSerial = true;\n\n        size_t barriers_size = sizeof(int) * tiles;\n        size_t partials_size = sizeof(float) * TILE_M * TILE_N * tiles;\n\n        if constexpr (!kSerial) {\n            barriers_size *= splits;\n            partials_size *= splits;\n        }\n\n        return {barriers_size, partials_size};\n    }\n\n    int GetMaxSplits(const int4& shape, int swizzle, size_t bsize, size_t psize) const 
override\n    {\n        return 1;\n    }\n\n    int GetMaxSwizzle(const int4& shape) const override\n    {\n        using Map = typename Gemm::Scheduler;\n        // TODO: fix tiled shape\n        const auto tiles = get_tiled_shape(shape.x, shape.y, TILE_M, TILE_N);\n        return Map::get_log_tile(tiles, 1 << 10);\n    }\n\n    bool is_feasible(const GemmDesc& desc) const noexcept override\n    {\n        if (desc.striding_a != desc_.striding_a) {\n            return false;\n        }\n        if (desc.striding_b != desc_.striding_b) {\n            return false;\n        }\n        if (desc.striding_c != desc_.striding_c) {\n            return false;\n        }\n        return Kernel::is_feasible(desc);\n    }\n\nprivate:\n    int sm_count_ = 0;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/mainloop_sm70.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/thread_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n#include <cuda_pipeline_primitives.h>\n\nnamespace turbomind::gemm {\n\ntemplate<int Stages>\nstruct GroupIter {\n\n    static_assert((Stages & (Stages - 1)) == 0);\n\n    int iter_ = 0;\n\n    __device__ void Advance()\n    {\n        iter_ = (iter_ + 1) % Stages;\n    }\n\n    __device__ constexpr explicit operator bool()\n    {\n        return iter_ == 0;\n    }\n};\n\ntemplate<>\nstruct GroupIter<1> {\n    __device__ void               Advance() {}\n    __device__ constexpr explicit operator bool()\n    {\n        return true;\n    }\n};\n\ntemplate<class Pointer, int Step, int Stages>\nstruct SmemIter {\n    Pointer pointer;\n    Pointer other_;\n\n    __device__ SmemIter(Pointer base): pointer{base}, other_{base + Step} {}\n\n    __device__ void Advance()\n    {\n        auto tmp = pointer;\n        pointer  = other_;\n        other_   = tmp;\n    }\n};\n\ntemplate<class A, class B, class U, class V>\nstruct Binding {\n    A&         a;\n    B&         b;\n    U&         u;\n    V&         v;\n    __device__ Binding(A& a, B& b, U& u, V& v): a{a}, b{b}, u{u}, v{v} {}  // CTAD\n};\n\n// Inspired by\n// https://github.com/NVIDIA/cutlass/blob/f93a69134ec8259fd235f220209d6f8734a5cb06/include/cutlass/gemm/threadblock/mma_pipelined.h\ntemplate<class MMA,\n         class OperandA_,\n         class IteratorA_,\n         class TransformA_,\n         class OperandU_,\n         int 
GroupSizeU_,\n         class OperandB_,\n         class IteratorB_,\n         class TransformB_,\n         class OperandV_,\n         int  GroupSizeV_,\n         int  Stages_,\n         bool FusePrefetch_>\nstruct MainloopSm70 {\n\n    using MMA_Atom = typename MMA::Atom;\n    using MMA_Map  = typename MMA::Map;\n\n    using FragC = typename MMA_Atom::FragC[MMA::kMmaIterM][MMA::kMmaIterN];\n\n    static constexpr int Stages = Stages_;\n\n    static constexpr int CTA_M = MMA::M;\n    static constexpr int CTA_N = MMA::N;\n    static constexpr int CTA_K = MMA::K;\n\n    static constexpr auto kOpClass = MMA_Atom::kOpClass;\n\n    static constexpr int WARPS = MMA::kThreadCount / WARP_SIZE;\n\n    using OperandA = MakeOperand<OperandA_, IteratorA_, CTA_M, CTA_K, WARPS>;\n    using OperandU = MakeOperand<OperandU_, IteratorA_, CTA_M, CTA_K, WARPS, GroupSizeU_>;\n\n    using OperandB = MakeOperand<OperandB_, IteratorB_, CTA_N, CTA_K, WARPS>;\n    using OperandV = MakeOperand<OperandV_, IteratorB_, CTA_N, CTA_K, WARPS, GroupSizeV_>;\n\n    using TransformA = TransformA_;\n    using TransformB = TransformB_;\n\n    using Ta = typename OperandA::Dtype;\n    using Tb = typename OperandB::Dtype;\n    using Tu = typename OperandU::Dtype;\n    using Tv = typename OperandV::Dtype;\n\n    using SmemLayoutA = typename OperandA::SmemLayout;\n    using SmemLayoutB = typename OperandB::SmemLayout;\n    using SmemLayoutU = typename OperandU::SmemLayout;\n    using SmemLayoutV = typename OperandV::SmemLayout;\n\n    using SmemCopyA = SmemCopy<OperandA, MMA_Map::kIterM, MMA_Map::kIterK, MMA_Map::kDeltaM, MMA_Map::kDeltaK>;\n    using SmemCopyU = SmemCopy<OperandU, MMA_Map::kIterM, MMA_Map::kIterK, MMA_Map::kDeltaM, MMA_Map::kDeltaK>;\n    using SmemCopyB = SmemCopy<OperandB, MMA_Map::kIterN, MMA_Map::kIterK, MMA_Map::kDeltaN, MMA_Map::kDeltaK>;\n    using SmemCopyV = SmemCopy<OperandV, MMA_Map::kIterN, MMA_Map::kIterK, MMA_Map::kDeltaN, MMA_Map::kDeltaK>;\n\n    using SmemAccessorA = 
SmemAccessor<Ta, SmemLayoutA>;\n    using SmemAccessorB = SmemAccessor<Tb, SmemLayoutB>;\n    using SmemAccessorU = SmemAccessor<Tu, SmemLayoutU>;\n    using SmemAccessorV = SmemAccessor<Tv, SmemLayoutV>;\n\n    using GmemIterA = typename OperandA::GmemIter;\n    using GmemIterB = typename OperandB::GmemIter;\n    using GmemIterU = typename OperandU::GmemIter;\n    using GmemIterV = typename OperandV::GmemIter;\n\n    struct SharedStorage {\n        __align__(16) Array<Ta, Stages * SmemLayoutA::kSize> A;\n        __align__(16) Array<Tb, Stages * SmemLayoutB::kSize> B;\n        __align__(16) Array<Tu, Stages * SmemLayoutU::kSize> U;\n        __align__(16) Array<Tv, Stages * SmemLayoutV::kSize> V;\n    };\n\n    template<class GmemIter, class SmemIter>\n    __device__ void _advance_smem(GmemIter& gmem_iter, SmemIter& smem_iter)\n    {\n        gmem_iter.smem_data_ = smem_iter.pointer;\n        smem_iter.Advance();\n    }\n\n    // zip with\n    template<class BindingG, class BindingS>\n    __device__ void AdvanceSmemStage(BindingG& g, BindingS& s)\n    {\n        _advance_smem(g.a, s.a);\n        _advance_smem(g.b, s.b);\n        _advance_smem(g.u, s.u);\n        _advance_smem(g.v, s.v);\n    }\n\n    template<class Binding>\n    __device__ void ClearSmem(Binding& g)\n    {\n        g.a.ClearSmem();\n        g.b.ClearSmem();\n        g.u.ClearSmem();\n        g.v.ClearSmem();\n    }\n\n    template<class Binding, class Fragments>\n    __device__ void Fetch(Binding& g, Fragments& f, bool mask)\n    {\n        g.a.Fetch(f.a, mask);\n        g.b.Fetch(f.b, mask);\n        g.u.Fetch(f.u, mask);\n        g.v.Fetch(f.v, mask);\n    }\n\n    template<class Binding, class Fragments>\n    __device__ void Store(Binding& g, Fragments& f)\n    {\n        g.a.Store(f.a);\n        g.b.Store(f.b);\n        g.u.Store(f.u);\n        g.v.Store(f.v);\n    }\n\n    template<class Binding>\n    __device__ void AdvanceGmemStage(Binding& g)\n    {\n        g.a.Advance();\n        
g.b.Advance();\n        g.u.Advance();\n        g.v.Advance();\n    }\n\n    __device__ void operator()(GmemIterA&     gmem_A,\n                               GmemIterB&     gmem_B,\n                               GmemIterU&     gmem_U,\n                               GmemIterV&     gmem_V,\n                               FragC&         frag_C,\n                               int            tile_iter,\n                               SharedStorage& storage)\n    {\n        static_assert(MMA::kAtomK == 1);\n\n        static constexpr int UU = 1;  // ceil_div(GroupSizeU_, MMA_Map::TileK);\n        static constexpr int VV = 1;  // ceil_div(GroupSizeV_, MMA_Map::TileK);\n\n        // mma_iter_x = tile_iter_x * atom_x\n        typename MMA_Atom::FragA frag_A[MMA::kTileIterK][MMA::kMmaIterM];\n        typename MMA_Atom::FragB frag_B[MMA::kTileIterK][MMA::kMmaIterN];\n\n        typename SmemCopyA::Frag data_A[MMA::kTileIterK];\n        typename SmemCopyB::Frag data_B[MMA::kTileIterK];\n        typename SmemCopyU::Frag data_U[ceil_div(MMA::kTileIterK, UU)];\n        typename SmemCopyV::Frag data_V[ceil_div(MMA::kTileIterK, VV)];\n\n        SmemIter<get_pointer_type<Ta>, SmemLayoutA::kSize, Stages> smem_A{storage.A.data()};\n        SmemIter<get_pointer_type<Tb>, SmemLayoutB::kSize, Stages> smem_B{storage.B.data()};\n        SmemIter<get_pointer_type<Tu>, SmemLayoutU::kSize, Stages> smem_U{storage.U.data()};\n        SmemIter<get_pointer_type<Tv>, SmemLayoutV::kSize, Stages> smem_V{storage.V.data()};\n\n        typename GmemIterA::Fragments rmem_A;\n        typename GmemIterB::Fragments rmem_B;\n        typename GmemIterU::Fragments rmem_U;\n        typename GmemIterV::Fragments rmem_V;\n\n        GroupIter<ceil_div(GroupSizeU_, CTA_K)> gmem_group_iter_U{};\n        GroupIter<ceil_div(GroupSizeV_, CTA_K)> gmem_group_iter_V{};\n\n        auto smem_group_iter_U = gmem_group_iter_U;\n        auto smem_group_iter_V = gmem_group_iter_V;\n\n        // a separate counter tends to 
generate better code\n        int gmem_iter = tile_iter;\n        int gmem_mask = true;\n\n        Binding gmem_iters{gmem_A, gmem_B, gmem_U, gmem_V};\n        Binding smem_iters{smem_A, smem_B, smem_U, smem_V};\n        Binding rmem{rmem_A, rmem_B, rmem_U, rmem_V};\n\n        // r0,w_\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < Stages; ++i) {\n            AdvanceSmemStage(gmem_iters, smem_iters);\n            ClearSmem(gmem_iters);\n        }\n\n        // r0,w1\n\n        __syncthreads();\n\n        auto fetch_stage = [&](auto& rmem) {\n            Fetch(gmem_iters, rmem, gmem_mask);\n            AdvanceGmemStage(gmem_iters);\n            gmem_group_iter_U.Advance();\n            gmem_group_iter_V.Advance();\n            gmem_U.g_mask = (bool)gmem_group_iter_U;\n            gmem_V.g_mask = (bool)gmem_group_iter_V;\n            if (--gmem_iter == 0) {\n                gmem_mask = false;\n            }\n        };\n\n        auto advance_and_wait_smem_stage = [&] {\n            __syncthreads();\n            AdvanceSmemStage(gmem_iters, smem_iters);\n        };\n\n        const int3 offset_mnk = MMA::get_offset(threadIdx.x);\n        const int  offset_m   = offset_mnk.x;\n        const int  offset_n   = offset_mnk.y;\n        const int  offset_k   = offset_mnk.z;\n\n        SmemCopyA smem_copy_A{{offset_m, offset_k}};\n        SmemCopyU smem_copy_U{{offset_m, offset_k}};\n        SmemCopyB smem_copy_B{{offset_n, offset_k}};\n        SmemCopyV smem_copy_V{{offset_n, offset_k}};\n\n        auto preload = [&](int k) {\n            smem_copy_A(smem_A.pointer, data_A[k], k);\n            smem_copy_U(smem_U.pointer, data_U[k / UU], k, k % UU == 0 && (bool)smem_group_iter_U);\n\n            smem_copy_B(smem_B.pointer, data_B[k], k);\n            smem_copy_V(smem_V.pointer, data_V[k / VV], k, k % VV == 0 && (bool)smem_group_iter_V);\n        };\n\n        AdvanceSmemStage(gmem_iters, smem_iters);\n        // r1,w0\n\n        fetch_stage(rmem);  // gmem -> rmem\n\n 
       Store(gmem_iters, rmem);  // rmem -> smem\n\n        advance_and_wait_smem_stage();\n        // r0,w1\n\n        preload(0);  // smem -> data_[A,B,U,V]\n\n        TransformA::apply(frag_A, 0, data_A, data_U, UU);\n        TransformB::apply(frag_B, 0, data_B, data_V, VV);\n\n        PRAGMA_NO_UNROLL\n        for (; tile_iter > 0; --tile_iter) {\n            constexpr int ITER_K = MMA::kTileIterK;\n            static_assert(ITER_K > 1);\n\n            PRAGMA_UNROLL\n            for (int k = 0; k < ITER_K; ++k) {\n                // The last iter, store prefetched fragments to smem\n                if (k == ITER_K - 1) {\n                    Store(gmem_iters, rmem);\n                    advance_and_wait_smem_stage();  // swap rw\n                    smem_group_iter_U.Advance();\n                    smem_group_iter_V.Advance();\n                }\n\n                // Preload for next iter, smem -> data_[A,B,U,V]\n                preload((k + 1) % ITER_K);\n\n                // The first iter, issue the prefetching of next stage\n                if (k == 0) {\n                    fetch_stage(rmem);\n                }\n\n                // PRAGMA_UNROLL\n                // for (int n = 0; n < MMA::kMmaIterN; ++n) {\n                //     PRAGMA_UNROLL\n                //     for (int m = 0; m < MMA::kMmaIterM; ++m) {\n                //         int mm = n % 2 ? MMA::kMmaIterM - m - 1 : m;\n                //         MMA_Atom::fma(frag_C[mm][n], frag_A[k][mm], frag_B[k][n], frag_C[mm][n]);\n                //     }\n                // }\n\n                PRAGMA_UNROLL\n                for (int m = 0; m < MMA::kMmaIterM; ++m) {\n                    PRAGMA_UNROLL\n                    for (int n = 0; n < MMA::kMmaIterN; ++n) {\n                        int nn = m % 2 ? 
MMA::kMmaIterN - n - 1 : n;\n                        MMA_Atom::fma(frag_C[m][nn], frag_A[k][m], frag_B[k][nn], frag_C[m][nn]);\n                    }\n                }\n\n                TransformA::apply(frag_A, (k + 1) % ITER_K, data_A, data_U, UU);\n                TransformB::apply(frag_B, (k + 1) % ITER_K, data_B, data_V, VV);\n            }\n        }\n\n        __syncthreads();\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/mainloop_sm80_v2.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/thread_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n#include <cuda_pipeline_primitives.h>\n\nnamespace turbomind::gemm {\n\ntemplate<int Stages>\nstruct GroupIter {\n\n    static_assert((Stages & (Stages - 1)) == 0);\n\n    int iter_ = 0;\n\n    __device__ void Advance()\n    {\n        iter_ = (iter_ + 1) % Stages;\n    }\n\n    __device__ constexpr explicit operator bool()\n    {\n        return iter_ == 0;\n    }\n};\n\ntemplate<>\nstruct GroupIter<1> {\n    __device__ void               Advance() {}\n    __device__ constexpr explicit operator bool()\n    {\n        return true;\n    }\n};\n\ntemplate<class Pointer, int Step, int Stages>\nstruct SmemIter {\n    Pointer base_;\n    Pointer pointer;\n    int     pipe_iter_;\n\n    __device__ SmemIter(Pointer base): base_{base}, pointer{base}, pipe_iter_{} {}\n\n    __device__ void Advance()\n    {\n        pipe_iter_ += 1;\n        pointer = pointer + Step;\n        if (pipe_iter_ == Stages) {\n            pipe_iter_ = 0;\n            pointer    = base_;\n        }\n    }\n};\n\ntemplate<class A, class B, class U, class V>\nstruct Binding {\n    A&         a;\n    B&         b;\n    U&         u;\n    V&         v;\n    __device__ Binding(A& a, B& b, U& u, V& v): a{a}, b{b}, u{u}, v{v} {}  // CTAD\n};\n\n// Inspired by\n// https://github.com/NVIDIA/cutlass/blob/f93a69134ec8259fd235f220209d6f8734a5cb06/include/cutlass/gemm/threadblock/mma_multistage.h\n// 
https://github.com/NVIDIA/cutlass/blob/f93a69134ec8259fd235f220209d6f8734a5cb06/include/cutlass/gemm/collective/sm80_mma_multistage.hpp\ntemplate<class MMA,\n         class OperandA_,\n         class IteratorA_,\n         class TransformA_,\n         class OperandU_,\n         int GroupSizeU_,\n         class OperandB_,\n         class IteratorB_,\n         class TransformB_,\n         class OperandV_,\n         int  GroupSizeV_,\n         int  Stages_,\n         bool FusePrefetch_>\nstruct MainloopSm80_v2 {\n\n    using MMA_Atom = typename MMA::Atom;\n    using MMA_Map  = typename MMA::Map;\n\n    using FragC = typename MMA_Atom::FragC[MMA::kMmaIterM][MMA::kMmaIterN];\n\n    static constexpr int Stages = Stages_;\n\n    static constexpr int CTA_M = MMA::M;\n    static constexpr int CTA_N = MMA::N;\n    static constexpr int CTA_K = MMA::K;\n\n    static constexpr auto kOpClass = MMA_Atom::kOpClass;\n\n    static constexpr int WARPS = MMA::kThreadCount / WARP_SIZE;\n\n    using OperandA = MakeOperand<OperandA_, IteratorA_, CTA_M, CTA_K, WARPS>;\n    using OperandU = MakeOperand<OperandU_, IteratorA_, CTA_M, CTA_K, WARPS, GroupSizeU_>;\n\n    using OperandB = MakeOperand<OperandB_, IteratorB_, CTA_N, CTA_K, WARPS>;\n    using OperandV = MakeOperand<OperandV_, IteratorB_, CTA_N, CTA_K, WARPS, GroupSizeV_>;\n\n    using TransformA = TransformA_;\n    using TransformB = TransformB_;\n\n    using Ta = typename OperandA::Dtype;\n    using Tb = typename OperandB::Dtype;\n    using Tu = typename OperandU::Dtype;\n    using Tv = typename OperandV::Dtype;\n\n    using SmemLayoutA = typename OperandA::SmemLayout;\n    using SmemLayoutB = typename OperandB::SmemLayout;\n    using SmemLayoutU = typename OperandU::SmemLayout;\n    using SmemLayoutV = typename OperandV::SmemLayout;\n\n    using SmemCopyA = SmemCopy<OperandA, MMA_Map::kIterM, MMA_Map::kIterK, MMA_Map::kDeltaM, MMA_Map::kDeltaK>;\n    using SmemCopyU = SmemCopy<OperandU, MMA_Map::kIterM, MMA_Map::kIterK, 
MMA_Map::kDeltaM, MMA_Map::kDeltaK>;\n    using SmemCopyB = SmemCopy<OperandB, MMA_Map::kIterN, MMA_Map::kIterK, MMA_Map::kDeltaN, MMA_Map::kDeltaK>;\n    using SmemCopyV = SmemCopy<OperandV, MMA_Map::kIterN, MMA_Map::kIterK, MMA_Map::kDeltaN, MMA_Map::kDeltaK>;\n\n    using SmemAccessorA = SmemAccessor<Ta, SmemLayoutA>;\n    using SmemAccessorB = SmemAccessor<Tb, SmemLayoutB>;\n    using SmemAccessorU = SmemAccessor<Tu, SmemLayoutU>;\n    using SmemAccessorV = SmemAccessor<Tv, SmemLayoutV>;\n\n    using GmemIterA = typename OperandA::GmemIter;\n    using GmemIterB = typename OperandB::GmemIter;\n    using GmemIterU = typename OperandU::GmemIter;\n    using GmemIterV = typename OperandV::GmemIter;\n\n    static constexpr int kFusePrefetch = FusePrefetch_;\n\n    static constexpr int kMaxPrefetchIter = 1;\n    // std::min(ceil_div(std::max(GmemIterA::ITER_S, GmemIterB::ITER_S), 4), MMA::kTileIterK);\n\n    static constexpr int kBatchA = ceil_div(GmemIterA::ITER_S, kMaxPrefetchIter);\n    static constexpr int kBatchB = ceil_div(GmemIterB::ITER_S, kMaxPrefetchIter);\n    static constexpr int kBatchU = ceil_div(GmemIterU::ITER_S, kMaxPrefetchIter);\n    static constexpr int kBatchV = ceil_div(GmemIterV::ITER_S, kMaxPrefetchIter);\n\n    struct SharedStorage {\n        __align__(16) Array<Ta, Stages * SmemLayoutA::kSize> A;\n        __align__(16) Array<Tb, Stages * SmemLayoutB::kSize> B;\n        __align__(16) Array<Tu, Stages * SmemLayoutU::kSize> U;\n        __align__(16) Array<Tv, Stages * SmemLayoutV::kSize> V;\n    };\n\n    __device__ void Wait()\n    {\n        __pipeline_wait_prior(Stages - 2);\n        __syncthreads();\n    }\n\n    template<class GmemIter, class SmemIter>\n    __device__ void _advance_smem(GmemIter& gmem_iter, SmemIter& smem_iter)\n    {\n        gmem_iter.smem_data_ = smem_iter.pointer;\n        smem_iter.Advance();\n    }\n\n    // zip with\n    template<class BindingG, class BindingS>\n    __device__ void AdvanceSmemStage(BindingG& g, 
BindingS& s)\n    {\n        _advance_smem(g.a, s.a);\n        _advance_smem(g.b, s.b);\n        _advance_smem(g.u, s.u);\n        _advance_smem(g.v, s.v);\n    }\n\n    template<class Binding>\n    __device__ void ClearSmem(Binding& g)\n    {\n        g.a.ClearSmem();\n        g.b.ClearSmem();\n        g.u.ClearSmem();\n        g.v.ClearSmem();\n    }\n\n    template<class Binding>\n    __device__ void Prefetch(Binding& g, bool mask)\n    {\n        g.a.Prefetch(mask);\n        g.b.Prefetch(mask);\n        g.u.Prefetch(mask);\n        g.v.Prefetch(mask);\n    }\n\n    template<class Binding>\n    __device__ void Prefetch(Binding& g, int k, bool mask)\n    {\n        int batch_A = min((k + 1) * kBatchA, GmemIterA::ITER_S) - k * kBatchA;\n        int batch_B = min((k + 1) * kBatchB, GmemIterB::ITER_S) - k * kBatchB;\n        int batch_U = min((k + 1) * kBatchU, GmemIterU::ITER_S) - k * kBatchU;\n        int batch_V = min((k + 1) * kBatchV, GmemIterV::ITER_S) - k * kBatchV;\n        g.a.Prefetch(k * kBatchA, batch_A, mask);\n        g.b.Prefetch(k * kBatchB, batch_B, mask);\n        g.u.Prefetch(k * kBatchU, batch_U, mask);\n        g.v.Prefetch(k * kBatchV, batch_V, mask);\n    }\n\n    template<class Binding>\n    __device__ void AdvanceGmemStage(Binding& g)\n    {\n        g.a.Advance();\n        g.b.Advance();\n        g.u.Advance();\n        g.v.Advance();\n    }\n\n    __device__ void operator()(GmemIterA&     gmem_A,\n                               GmemIterB&     gmem_B,\n                               GmemIterU&     gmem_U,\n                               GmemIterV&     gmem_V,\n                               FragC&         frag_C,\n                               int            tile_iter,\n                               SharedStorage& storage)\n    {\n        static_assert(MMA::kAtomK == 1);\n\n        static constexpr int UU = ceil_div(GroupSizeU_, MMA_Map::TileK);\n        static constexpr int VV = ceil_div(GroupSizeV_, MMA_Map::TileK);\n\n        // 
mma_iter_x = tile_iter_x * atom_x\n        typename MMA_Atom::FragA frag_A[MMA::kTileIterK][MMA::kMmaIterM];\n        typename MMA_Atom::FragB frag_B[MMA::kTileIterK][MMA::kMmaIterN];\n\n        typename SmemCopyA::Frag data_A[MMA::kTileIterK];\n        typename SmemCopyB::Frag data_B[MMA::kTileIterK];\n        typename SmemCopyU::Frag data_U[ceil_div(MMA::kTileIterK, UU)];\n        typename SmemCopyV::Frag data_V[ceil_div(MMA::kTileIterK, VV)];\n\n        SmemIter<get_pointer_type<Ta>, SmemLayoutA::kSize, Stages> smem_A{storage.A.data()};\n        SmemIter<get_pointer_type<Tb>, SmemLayoutB::kSize, Stages> smem_B{storage.B.data()};\n        SmemIter<get_pointer_type<Tu>, SmemLayoutU::kSize, Stages> smem_U{storage.U.data()};\n        SmemIter<get_pointer_type<Tv>, SmemLayoutV::kSize, Stages> smem_V{storage.V.data()};\n\n        GroupIter<ceil_div(GroupSizeU_, CTA_K)> gmem_group_iter_U{};\n        GroupIter<ceil_div(GroupSizeV_, CTA_K)> gmem_group_iter_V{};\n\n        auto smem_group_iter_U = gmem_group_iter_U;\n        auto smem_group_iter_V = gmem_group_iter_V;\n\n        // a separate counter tends to generate better code\n        int gmem_iter = tile_iter;\n        int gmem_mask = true;\n\n        Binding gmem_iters{gmem_A, gmem_B, gmem_U, gmem_V};\n        Binding smem_iters{smem_A, smem_B, smem_U, smem_V};\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < Stages; ++i) {\n            AdvanceSmemStage(gmem_iters, smem_iters);\n            ClearSmem(gmem_iters);\n        }\n\n        // r: 0, w:s-1\n\n        __syncthreads();\n\n        auto prefetch_stage = [&] {\n            Prefetch(gmem_iters, gmem_mask);\n            __pipeline_commit();\n            AdvanceGmemStage(gmem_iters);\n            gmem_group_iter_U.Advance();\n            gmem_group_iter_V.Advance();\n            gmem_U.g_mask = (bool)gmem_group_iter_U;\n            gmem_V.g_mask = (bool)gmem_group_iter_V;\n            if (--gmem_iter == 0) {\n                gmem_mask = false;\n            
}\n        };\n\n        [[maybe_unused]] auto prefetch_batch = [&](int k) {\n            Prefetch(gmem_iters, k, gmem_mask);\n            if (k == MMA::kTileIterK - 1) {\n                __pipeline_commit();\n                AdvanceGmemStage(gmem_iters);\n                gmem_group_iter_U.Advance();\n                gmem_group_iter_V.Advance();\n                gmem_U.g_mask = (bool)gmem_group_iter_U;\n                gmem_V.g_mask = (bool)gmem_group_iter_V;\n                if (--gmem_iter == 0) {\n                    gmem_mask = false;\n                }\n            }\n        };\n\n        auto advance_and_wait_smem_stage = [&] {\n            Wait();\n            AdvanceSmemStage(gmem_iters, smem_iters);\n        };\n\n        const int3 offset_mnk = MMA::get_offset(threadIdx.x);\n        const int  offset_m   = offset_mnk.x;\n        const int  offset_n   = offset_mnk.y;\n        const int  offset_k   = offset_mnk.z;\n\n        SmemCopyA smem_copy_A{{offset_m, offset_k}};\n        SmemCopyU smem_copy_U{{offset_m, offset_k}};\n        SmemCopyB smem_copy_B{{offset_n, offset_k}};\n        SmemCopyV smem_copy_V{{offset_n, offset_k}};\n\n        auto preload = [&](int k) {\n            smem_copy_A(smem_A.pointer, data_A[k], k);\n            smem_copy_U(smem_U.pointer, data_U[k / UU], k, k % UU == 0 && (bool)smem_group_iter_U);\n\n            smem_copy_B(smem_B.pointer, data_B[k], k);\n            smem_copy_V(smem_V.pointer, data_V[k / VV], k, k % VV == 0 && (bool)smem_group_iter_V);\n        };\n\n        PRAGMA_UNROLL\n        for (int stage = 0; stage < Stages - 1; ++stage) {\n            AdvanceSmemStage(gmem_iters, smem_iters);\n            prefetch_stage();\n        }\n        // r:-1, w:-2\n\n        advance_and_wait_smem_stage();\n        // r: 0, w:-1\n\n        preload(0);\n\n        TransformA::apply(frag_A, 0, data_A, data_U, UU);\n        TransformB::apply(frag_B, 0, data_B, data_V, VV);\n\n        if constexpr (kFusePrefetch) {\n            
prefetch_batch(0);\n        }\n\n        PRAGMA_NO_UNROLL\n        for (; tile_iter > 0; --tile_iter) {\n            if constexpr (!kFusePrefetch) {\n                prefetch_stage();\n            }\n            constexpr int ITER_K = MMA::kTileIterK;\n            static_assert(ITER_K > 1);\n\n            PRAGMA_UNROLL\n            for (int k = 0; k < ITER_K; ++k) {\n                // preload for next iter\n                preload((k + 1) % ITER_K);\n\n                MMA::mma_k_iter(frag_C, frag_A[k], frag_B[k], frag_C);\n\n                if constexpr (kFusePrefetch) {\n                    prefetch_batch((k + 1) % ITER_K);\n                }\n\n                if (k + 1 == ITER_K - 1) {\n                    advance_and_wait_smem_stage();\n                    smem_group_iter_U.Advance();\n                    smem_group_iter_V.Advance();\n                }\n\n                TransformA::apply(frag_A, (k + 1) % ITER_K, data_A, data_U, UU);\n                TransformB::apply(frag_B, (k + 1) % ITER_K, data_B, data_V, VV);\n            }\n        }\n\n        __pipeline_commit();\n        __pipeline_wait_prior(0);\n\n        __syncthreads();\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/matrix_ptr.h",
    "content": "#pragma once\n\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\nstruct __align__(16) StridedPtr\n{\n    void* ptr;\n    int   stride;\n};\n\nstruct MatrixParam {\n    void* ptr;\n    int   stride;\n    int*  offsets;\n    int*  idxs;\n};\n\nstruct MatrixData {\n    StridedPtr ptr;\n    const int* idxs;\n};\n\ninline MatrixParam to_param(void* ptr, MatrixLayout layout)\n{\n    return {ptr, layout.ld, layout.offsets, layout.idxs};\n}\n\n#if 0\ntemplate<Striding mode>\n__inline__ __device__ MatrixData resolve(const MatrixParam& param, int gemm_id)\n{\n    if constexpr (mode == Striding::kFlat) {\n        return {{param.ptr, param.stride}, nullptr};\n    }\n    else if constexpr (mode == Striding::kBlocked) {\n        StridedPtr ptr{param.ptr, param.stride};\n        if (param.stride == 0) {\n            (uint4&)ptr = __ldg((const uint4*)param.ptr + gemm_id);\n        }\n        return {ptr, nullptr};\n    }\n    else if constexpr (mode == Striding::kIndexed) {\n        const uintptr_t idx = param.idxs ? 
__ldg((uintptr_t*)param.idxs + gemm_id) : 0;\n        StridedPtr      ptr{param.ptr, param.stride};\n        if (param.stride == 0) {\n            (uint4&)ptr = __ldg((const uint4*)param.ptr + gemm_id);\n        }\n        return {ptr, reinterpret_cast<const int*>(idx)};\n    }\n    else {\n        static_assert(mode != mode, \"Not implemented.\");\n        return {};\n    }\n}\n#endif\n\ntemplate<class T, Striding mode>\n__inline__ __device__ MatrixData resolve(const MatrixParam& param, int g)\n{\n    StridedPtr ptr{param.ptr, param.stride};\n    const int* idxs{};\n    if constexpr (mode == Striding::kFlat) {\n        // pass\n    }\n    else if constexpr (mode == Striding::kBlocked) {\n        if (ptr.stride == 0) {\n            (uint4&)ptr = __ldg((const uint4*)param.ptr + g);\n        }  // Post-condition: ptr.stride != 0\n        if (param.offsets) {\n            ptr.ptr = (char*)ptr.ptr + __ldg(param.offsets + g) * (size_t)ptr.stride * bitsof<T> / bitsof<char>;\n        }\n    }\n    else if constexpr (mode == Striding::kIndexed) {\n        idxs = param.idxs;\n        if (ptr.stride == 0) {\n            (uint4&)ptr = __ldg((const uint4*)param.ptr + g);\n            idxs        = idxs ? 
((int**)idxs)[g] : nullptr;\n        }  // Post-condition: ptr.stride != 0\n        if (param.offsets) {\n            const int offset = __ldg(param.offsets + g);\n            if (idxs) {\n                idxs += offset;\n            }\n            else {\n                ptr.ptr = (char*)ptr.ptr + offset * (size_t)ptr.stride * bitsof<T> / bitsof<char>;\n            }\n        }\n    }\n    else {\n        static_assert(mode != mode, \"Not implemented.\");\n    }\n    return {ptr, idxs};\n}\n\n// p <- dat_ptrs[g]\n// i <- idx_ptrs[g]\n\n// pitch offset idxs\n//    1     0     0   -> {ptr, pitch}       , 0\n//    1     0     1   -> {ptr, pitch}       , idxs\n//    1     1     0   -> {ptr, pitch} + o[g], 0\n//    1     1     1   -> {ptr, pitch}       , idxs + o[g]\n//    0     0     0   ->       p            , 0\n//    0     0     1   ->       p            , i\n//    0     1     0   ->       p      + o[g], 0\n//    0     1     1   ->       p            , i    + o[g]\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/moe_utils_v2.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <cmath>\n#include <cstdio>\n#include <iostream>\n#include <limits>\n#include <numeric>\n#include <random>\n\n#include <cub/block/block_reduce.cuh>\n#include <cub/block/block_scan.cuh>\n#include <cub/warp/warp_scan.cuh>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/gemm/moe_utils_v2.h\"\n#include \"src/turbomind/kernels/reduce_kernel_utils.cuh\"\n\nnamespace turbomind {\n\ntemplate<int top_k, int block_dim>\n__global__ void MoeGateKernel_V2(float*       scales,  // [e,n]\n                                 int8_t*      masks,   // [E,n], padded\n                                 int*         accum,   // [E,tiles]\n                                 const float* logits,  // [E,n]\n                                 int          log_tile,\n                                 int          tiles,\n                                 int          tokens,\n                                 int          tokens_padded,\n                                 int          experts)\n{\n    constexpr int max_tiles = kMoeGateMaxTiles;\n\n    // Brute-force per thread top-k using a flat thread mapping\n    const int ti = threadIdx.x + blockIdx.x * blockDim.x;\n\n    // Clear masks\n    for (int e = 0; e < experts; ++e) {\n        if (ti < tokens_padded) {\n            masks[e * tokens_padded + ti] = -1;\n        }\n    }\n\n    __shared__ int shared_accum[32][max_tiles];\n\n    for (int i = threadIdx.x; i < experts * max_tiles; i += block_dim) {\n        int e = i / max_tiles;\n        int t = i % max_tiles;\n        if (e < experts && t < tiles) {\n            shared_accum[e][t] = 0;\n        }\n    }\n\n    __syncthreads();\n\n    
if (ti < tokens) {\n\n        static_assert(top_k <= 32);\n        int mask = -1;\n\n        float max_logit = 0.f;\n\n        // Find top-k\n        PRAGMA_UNROLL\n        for (int k = 0; k < top_k; ++k) {\n            int   max_bit = 0;\n            float max_val = -std::numeric_limits<float>::infinity();\n            int   bit     = 1;\n            for (int e = 0; e < experts; ++e) {\n                const auto val = logits[ti * experts + e];\n                // const auto val = logits[e * tokens + ti];\n                if ((mask & bit) && val > max_val) {\n                    max_bit = bit;\n                    max_val = val;\n                }\n                bit *= 2;\n            }\n            mask -= max_bit;\n            if (k == 0) {\n                max_logit = max_val;\n            }\n        }\n\n        mask = ~mask;\n\n        Array<float, top_k> top_val;\n        PRAGMA_UNROLL\n        for (int i = 0; i < top_k; ++i) {\n            const int lowbit = (mask & -mask);\n            const int e      = 31 - __clz(lowbit);\n\n            // printf(\"e = %d, ti = %d, idx = %d\\n\", e, ti, i);\n\n            masks[e * tokens_padded + ti] = i;\n            atomicAdd(&shared_accum[e][ti >> log_tile], 1);\n            top_val[i] = logits[ti * experts + e];\n            // top_val[i] = logits[e * tokens + ti];\n\n            mask -= lowbit;\n        }\n\n        float prob_sum = 0.f;\n        PRAGMA_UNROLL\n        for (int i = 0; i < top_k; ++i) {\n            top_val[i] = expf(top_val[i] - max_logit);\n            prob_sum += top_val[i];\n        }\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < top_k; ++i) {\n            scales[i * tokens + ti] = fdividef(top_val[i], prob_sum);\n        }\n    }\n\n    __syncthreads();\n\n    for (int i = threadIdx.x; i < experts * max_tiles; i += block_dim) {\n        int e = i / max_tiles;\n        int t = i % max_tiles;\n        if (e < experts && t < tiles) {\n            atomicAdd(accum + e * tiles + t, 
shared_accum[e][t]);\n        }\n    }\n}\n\ntemplate<int block_dim, class Mask>\n__global__ void MoeScanKernel_v2(int*       f2n,      // [e*n]\n                                 int*       f2E,      // [e*n]\n                                 int*       en2f,     // [e,n]\n                                 int*       offsets,  // [E+1]\n                                 Mask*      masks,    // [E,n], padded\n                                 const int* accum,    // [E,tiles]\n                                 int        log_tile,\n                                 int        tiles,\n                                 int        tokens,\n                                 int        tokens_padded,\n                                 int        experts)\n{\n    using BlockReduce = cub::BlockReduce<int, block_dim>;\n    using BlockScan   = cub::BlockScan<int, block_dim>;\n\n    __shared__ union TempStorage {\n        typename BlockReduce::TempStorage reduce;\n        typename BlockScan::TempStorage   scan;\n    } temp_storage;\n\n    constexpr int vec_size = kMoeGateVecSize;\n\n    using Vec = Array<Mask, vec_size>;\n\n    const int tile_id = blockIdx.x;\n    const int ei      = blockIdx.y;\n\n    const int  global_tile_id = ei * tiles + tile_id;\n    const bool is_valid       = global_tile_id <= experts * tiles;\n\n#if 0\n    int vacc[4]{};\n    {\n        int idx = threadIdx.x;\n        PRAGMA_UNROLL\n        for (int i = 0; i < 4; ++i) {\n            if (idx < global_tile_id) {\n                vacc[i] = accum[idx];\n            }\n            idx += block_dim;\n        }\n    }\n\n    int offset = BlockReduce{temp_storage.reduce}.Sum(vacc);\n#else\n\n    int vacc = 0;\n    for (int i = threadIdx.x; i < global_tile_id; i += block_dim) {\n        if (is_valid && i < global_tile_id) {\n            vacc += accum[i];\n        }\n    }\n\n    int offset = BlockReduce{temp_storage.reduce}.Sum(vacc);\n\n#endif\n\n    __shared__ int shared_offset;\n\n    if (threadIdx.x == 0) {\n     
   shared_offset = offset;\n        if (tile_id == 0) {\n            offsets[ei] = offset;\n        }\n    }\n\n    if (ei == experts) {\n        return;\n    }\n\n    __syncthreads();\n\n    offset = shared_offset;\n\n    const int token_vecs = tokens_padded / vec_size;\n\n    const int tile_size     = 1 << log_tile;\n    const int tile_vec_size = tile_size / vec_size;\n\n    const int tile_vec_beg    = tile_id * tile_vec_size;\n    const int tile_vec_end    = std::min(tile_vec_beg + tile_vec_size, token_vecs);\n    const int tile_vec_padded = tile_vec_beg + round_up(tile_vec_size, block_dim);\n\n    // if (threadIdx.x == 0) {\n    //     printf(\"%d %d %d\\n\", tile_vec_beg, tile_vec_end, tile_vec_padded);\n    // }\n\n    auto mask_ptr = (Vec*)masks + ei * token_vecs;\n\n    for (int vi = tile_vec_beg + threadIdx.x; vi < tile_vec_padded; vi += block_dim) {\n\n        const bool pred = vi < tile_vec_end;\n\n        Vec data;\n        fill(data, Mask{-1});\n        if (pred) {\n            Ldg(data, mask_ptr[vi].data());\n        }\n\n        int prefix[vec_size];\n        PRAGMA_UNROLL\n        for (int i = 0; i < vec_size; ++i) {\n            prefix[i] = int(data[i] >= 0);\n        }\n\n        int block_sum = 0;\n\n        BlockScan{temp_storage.scan}.ExclusiveSum(prefix, prefix, block_sum);\n        __syncthreads();\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < vec_size; ++i) {\n            if (pred && data[i] >= 0) {\n                const int flat_id = prefix[i] + offset;\n                const int ti      = vi * vec_size + i;\n                f2n[flat_id]      = ti;\n                f2E[flat_id]      = ei;\n                // No ti is generated for padded tokens so we are safe\n                en2f[data[i] * tokens + ti] = flat_id;\n            }\n        }\n\n        offset += block_sum;\n    }\n}\n\ntemplate<int max_expert_num,\n         int max_top_k,\n         int items_per_thread,\n         int block_dim,\n         int access_size,\n         
class Mask>\n__global__ void MoeGateKernel_v8(float*       scales,  // [e,n]\n                                 Mask*        masks,   // [E,n], padded\n                                 int*         accum,   // [E,tiles]\n                                 const float* logits,  // [n,E]\n                                 int          log_tile,\n                                 int          tiles,\n                                 int          token_num,\n                                 int          token_num_padded,\n                                 int          expert_num,\n                                 int          top_k,\n                                 bool         softmax,\n                                 bool         norm_topk,\n                                 float        routed_scale)\n{\n    constexpr int max_tiles         = kMoeGateMaxTiles;\n    constexpr int threads_per_token = max_expert_num / items_per_thread;  // 8\n    constexpr int tokens_per_cta    = block_dim / threads_per_token;\n\n    // We use bits in a uint32_t to represent selected experts\n    static_assert(items_per_thread <= 32);\n    // We use warp-level primitives for reduction\n    static_assert(threads_per_token <= 32);\n\n    static_assert((threads_per_token & (threads_per_token - 1)) == 0);\n\n    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;\n\n    const int ti = thread_idx / threads_per_token;\n    const int ei = thread_idx % threads_per_token;\n\n    const int bti = threadIdx.x / threads_per_token;\n\n    const int warp_ti = threadIdx.x % WARP_SIZE / threads_per_token;\n\n    // const int warp_offset  = thread_idx / WARP_SIZE * WARP_SIZE / threads_per_token;\n    // const int block_offset = thread_idx / block_dim * block_dim / threads_per_token;\n\n    float data[items_per_thread];\n    int   idxs[items_per_thread];\n\n#if 0\n    PRAGMA_UNROLL\n    for (int i = 0; i < items_per_thread; ++i) {\n        data[i] = -std::numeric_limits<float>::infinity();\n        
idxs[i] = threads_per_token * (i / access_size * access_size) + i % access_size + ei * access_size;\n    }\n    if (ti < token_num) {\n        PRAGMA_UNROLL\n        for (int i = 0; i < items_per_thread; i += access_size) {\n            const int e = threads_per_token * i + ei * access_size;\n            if (e < expert_num) {\n                Ldg((Array<float, access_size>&)data[i], &logits[ti * expert_num + e]);\n            }\n        }\n    }\n\n    __shared__ union {\n        struct {\n            // +1 padding greatly reduced (-80%) bank conflicts\n            int   shared_accum[max_tiles][max_expert_num + 1];\n            float shared_scales[max_top_k][tokens_per_cta];\n            int   shared_exp_id[max_top_k][tokens_per_cta];\n        };\n    } smem;\n#elif 1\n    PRAGMA_UNROLL\n    for (int i = 0; i < items_per_thread; ++i) {\n        data[i] = -std::numeric_limits<float>::infinity();\n        // idxs[i] = threads_per_token * (i / access_size * access_size) + i % access_size + ei * access_size;\n        idxs[i] = ei * items_per_thread + i;\n    }\n    if (ti < token_num) {\n        PRAGMA_UNROLL\n        for (int i = 0; i < items_per_thread; i += access_size) {\n            // const int e = threads_per_token * i + ei * access_size;\n            const int e = ei * items_per_thread + i;\n            if (e < expert_num) {\n                Ldg((Array<float, access_size>&)data[i], &logits[ti * expert_num + e]);\n            }\n        }\n    }\n\n    __shared__ union {\n        struct {\n            // +1 padding greatly reduced (-80%) bank conflicts\n            int   shared_accum[max_tiles][max_expert_num + 1];\n            float shared_scales[max_top_k][tokens_per_cta];\n            int   shared_exp_id[max_top_k][tokens_per_cta];\n        };\n    } smem;\n#else\n\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    const int lane_id = threadIdx.x % WARP_SIZE;\n\n    constexpr int vecs_per_thread = items_per_thread / access_size;\n\n    using Vec           
 = Array<float, access_size>;\n    constexpr int banks  = 128 / sizeof(Vec);\n    constexpr int chunks = 4;  // block_dim / WARP_SIZE;\n\n    __shared__ union {\n        Vec shared_data[chunks][vecs_per_thread * WARP_SIZE / banks][banks + 1];\n        struct {\n            // +1 padding greatly reduced (-80%) bank conflicts\n            int   shared_accum[max_tiles][max_expert_num + 1];\n            float shared_scales[max_top_k][tokens_per_cta];\n            int   shared_exp_id[max_top_k][tokens_per_cta];\n        };\n    } smem;\n\n    __align__(16) Vec vecs[vecs_per_thread];\n\n    {\n        const int warp_end = min(warp_offset + WARP_SIZE / threads_per_token, token_num) * expert_num;\n        int       p        = warp_offset * expert_num + access_size * lane_id;\n        PRAGMA_UNROLL\n        for (int i = 0; i < vecs_per_thread; ++i) {\n            fill(vecs[i], -std::numeric_limits<float>::infinity());\n            // const int p = warp_offset * expert_num + access_size * (lane_id + i * WARP_SIZE);\n            if (p < warp_end) {\n                Ldg(vecs[i], &logits[p]);\n            }\n            p += access_size * WARP_SIZE;\n        }\n    }\n\n    PRAGMA_UNROLL\n    for (int c = 0; c < block_dim / WARP_SIZE; c += chunks) {\n        PRAGMA_UNROLL\n        for (int i = 0; i < vecs_per_thread; ++i) {\n            int p = i * WARP_SIZE + lane_id;\n            if (c <= warp_id && warp_id < c + chunks) {\n                Store(smem.shared_data[warp_id - c][p / banks][p % banks].data(), vecs[i]);\n            }\n        }\n\n        __syncwarp();\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < vecs_per_thread; ++i) {\n            int p = lane_id * vecs_per_thread + i;\n            if (c <= warp_id && warp_id < c + chunks) {\n                Load(vecs[i], smem.shared_data[warp_id - c][p / banks][p % banks].data());\n            }\n        }\n\n        __syncthreads();\n    }\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < items_per_thread; ++i) {\n        
idxs[i] = ei * items_per_thread + i;\n    }\n    PRAGMA_UNROLL\n    for (int i = 0; i < vecs_per_thread; ++i) {\n        (Array<float, access_size>&)data[i * access_size] = vecs[i];\n    }\n\n#endif\n\n    // constexpr float kLog2e = 1.4426950408889634074;\n    // if (k == 0) {\n    //     PRAGMA_UNROLL\n    //     for (int i = 0; i < items_per_thread; ++i) {\n    //         data[i] *= kLog2e;\n    //     }\n    // }\n\n    unsigned mask = (unsigned)-1;\n    float    max_logit;\n\n    int count{};\n\n    const int warp_ti_offset = warp_ti * threads_per_token;\n\n    auto run = [&](int k) {\n        unsigned bit     = 1;\n        unsigned max_bit = 0;\n        float    max_val = -std::numeric_limits<float>::infinity();\n        // local maximum\n        PRAGMA_UNROLL\n        for (int i = 0; i < items_per_thread; ++i) {\n            if ((mask & bit) && data[i] > max_val) {\n                max_bit = bit;\n                max_val = data[i];\n            }\n            // nvcc tends to emit a funnel shift for `bit <<= 1`, so use an explicit shl.b32 instead\n            asm(\"shl.b32 %0, %1, 1;\\n\" : \"=r\"(bit) : \"r\"(bit));\n        }\n\n        int   g_max_ei  = ei;\n        float g_max_val = max_val;\n        if constexpr (threads_per_token > 1) {\n            // global maximum\n            PRAGMA_UNROLL\n            for (int m = threads_per_token / 2; m >= 1; m /= 2) {\n                g_max_val = fmaxf(g_max_val, __shfl_xor_sync((uint32_t)-1, g_max_val, m));\n            }\n            // tie breaking: the lowest-ranked thread holding the maximum wins\n            const auto active = __ballot_sync((uint32_t)-1, max_val == g_max_val);\n            g_max_ei          = __ffs(active >> (unsigned)warp_ti_offset) - 1;\n        }\n        if (k == 0) {\n            max_logit = g_max_val;\n        }\n        if (ei == g_max_ei) {\n            mask -= max_bit;\n            ++count;\n        }\n    };\n\n    run(0);\n\n    for (int k = 1; k < top_k; ++k) {\n        run(k);\n    }\n\n    mask = ~mask;\n\n    int used[items_per_thread];\n    
{\n        unsigned bit = 1;\n        PRAGMA_UNROLL\n        for (int i = 0; i < items_per_thread; ++i) {\n            used[i] = (mask & bit) > 0;\n            asm(\"shl.b32 %0, %1, 1;\\n\" : \"=r\"(bit) : \"r\"(bit));\n        }\n    }\n\n    float sum_prob{};\n\n    if (softmax) {\n        PRAGMA_UNROLL\n        for (int i = 0; i < items_per_thread; ++i) {\n            if (!norm_topk || used[i]) {\n                data[i] = expf(data[i] - max_logit);\n                sum_prob += data[i];\n            }\n        }\n        PRAGMA_UNROLL\n        for (int m = threads_per_token / 2; m >= 1; m /= 2) {\n            sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);\n        }\n        sum_prob = fdividef(1.f, sum_prob);\n    }\n    else {\n        sum_prob = 1.f;\n    }\n\n    using WarpScan = cub::WarpScan<int, threads_per_token>;\n    __shared__ typename WarpScan::TempStorage temp_storage[tokens_per_cta];\n\n    int idx{};\n    WarpScan{temp_storage[bti]}.ExclusiveSum(count, idx);\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < items_per_thread; ++i) {\n        if (used[i]) {\n            smem.shared_exp_id[idx][bti] = idxs[i];\n            smem.shared_scales[idx][bti] = data[i] * sum_prob;\n            ++idx;\n        }\n    }\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < max_tiles * max_expert_num; i += block_dim) {\n        int e = (i + threadIdx.x) % max_expert_num;\n        int t = (i + threadIdx.x) / max_expert_num;\n        if (t < max_tiles) {\n            smem.shared_accum[t][e] = 0;\n        }\n    }\n\n    __syncthreads();\n\n    constexpr int k_per_thread = cdiv(max_top_k, threads_per_token);\n\n    const int bti2 = threadIdx.x % tokens_per_cta;\n    const int ei2  = threadIdx.x / tokens_per_cta;\n    const int ti2  = blockIdx.x * tokens_per_cta + bti2;\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < k_per_thread; ++i) {\n        const int   idx       = ei2 * k_per_thread + i;\n        const int   expert_id = smem.shared_exp_id[idx][bti2];\n        
const float scale     = smem.shared_scales[idx][bti2];\n\n        if (ti2 < token_num && idx < top_k) {\n            masks[expert_id * token_num_padded + ti2] = idx;\n            scales[idx * token_num + ti2]             = scale * routed_scale;\n            atomicAdd(&smem.shared_accum[ti2 >> log_tile][expert_id], 1);\n        }\n    }\n\n    __syncthreads();\n\n    for (int i = 0; i < max_expert_num * max_tiles; i += block_dim) {\n        int t = (threadIdx.x + i) % max_tiles;\n        int e = (threadIdx.x + i) / max_tiles;\n        if (e < expert_num && t < tiles) {\n            atomicAdd(accum + e * tiles + t, smem.shared_accum[t][e]);\n        }\n    }\n}\n\ntemplate<int N>\ninline constexpr std::integral_constant<int, N> _Int{};\n\nvoid invokeMoeGate_V2(int*         f2n,            // [e*n] -> n\n                      int*         f2E,            // [e*n] -> E\n                      int*         en2f,           // [e,n] -> n*e\n                      int*         offsets,        // [E+1]\n                      float*       scales,         // [e,n]\n                      void*        masks,          // [E,n]\n                      int*         accum,          // [E,tiles]\n                      const float* logits,         // [n,E]\n                      int          tokens,         //  n\n                      int          tokens_padded,  //  round_up(n, 4)\n                      int          experts,        //  E\n                      int          experts_per_token,\n                      bool         softmax,\n                      bool         norm_topk,\n                      float        routed_scale,\n                      cudaStream_t st)\n{\n    constexpr int base_log_tile = 9;\n\n    int log_tile = base_log_tile;\n    while (((tokens_padded + (1 << log_tile) - 1) >> log_tile) > kMoeGateMaxTiles) {\n        ++log_tile;\n    }\n    const int tiles = ceil_div(tokens_padded, 1 << log_tile);\n\n    // std::cout << log_tile << \" \" << tiles << \"\\n\";\n\n    
auto invoke = [&](auto max_expert_num, auto top_k, auto items_per_thread, auto vec_size) {\n        constexpr int thrs_per_tok = max_expert_num.value / items_per_thread.value;\n        constexpr int threads      = 256;\n        const int     blocks       = ceil_div(tokens, threads / thrs_per_tok);\n\n        cudaMemsetAsync(masks, -1, sizeof(int8_t) * experts * tokens_padded, st);\n\n        MoeGateKernel_v8<max_expert_num.value, top_k.value, items_per_thread.value, threads, vec_size.value>\n            <<<blocks, threads, 0, st>>>(  //\n                scales,\n                (int8_t*)masks,\n                accum,\n                logits,\n                log_tile,\n                tiles,\n                tokens,\n                tokens_padded,\n                experts,\n                experts_per_token,\n                softmax,\n                norm_topk,\n                routed_scale);\n\n        return true;\n    };\n\n    if (!softmax && norm_topk) {\n        // norm top-k is part of softmax impl\n        TM_CHECK(0) << softmax << \" \" << norm_topk;\n    }\n\n    auto dispatch = [&] {\n        if (experts <= 8) {\n            if (experts_per_token <= 2) {\n                return invoke(_Int<8>, _Int<2>, _Int<8>, _Int<4>);\n            }\n            else {\n                return invoke(_Int<8>, _Int<8>, _Int<8>, _Int<4>);\n            }\n        }\n        else if (experts <= 64) {\n            if (experts_per_token <= 4) {\n                return invoke(_Int<64>, _Int<4>, _Int<16>, _Int<4>);\n            }\n            else if (experts_per_token <= 8) {\n                return invoke(_Int<64>, _Int<8>, _Int<16>, _Int<4>);\n            }\n        }\n        else if (experts <= 128) {\n            if (experts_per_token <= 8) {\n                return invoke(_Int<128>, _Int<8>, _Int<16>, _Int<4>);\n            }\n        }\n        else if (experts <= 160) {\n            if (experts_per_token <= 8) {\n                return invoke(_Int<160>, _Int<8>, 
_Int<10>, _Int<2>);\n            }\n        }\n        else if (experts <= 512) {\n            if (experts_per_token <= 8) {\n                return invoke(_Int<512>, _Int<8>, _Int<16>, _Int<4>);\n            }\n        }\n        return false;\n    };\n\n    auto success = dispatch();\n\n    sync_check_cuda_error();\n\n    TM_CHECK(success) << \"unsupported moe config: expert_num=\" << experts << \", top_k=\" << experts_per_token\n                      << \", softmax=\" << softmax << \", norm_topk=\" << norm_topk;\n\n    {\n        constexpr int threads = (1 << base_log_tile) / kMoeGateVecSize;\n        const dim3    blocks(tiles, experts + 1);\n\n        MoeScanKernel_v2<threads><<<blocks, threads, 0, st>>>(f2n,  //\n                                                              f2E,\n                                                              en2f,\n                                                              offsets,\n                                                              (int8_t*)masks,\n                                                              accum,\n                                                              log_tile,\n                                                              tiles,\n                                                              tokens,\n                                                              tokens_padded,\n                                                              experts);\n    }\n}\n\n// noaux_tc: scores = scoring_func(logits), scores_for_choice = scores + correction_bias,\n// top-k on scores_for_choice, weights from scores; renormalize; apply routed_scale.\n// Threading: one token per block, threads cooperate over expert dimension.\n__global__ void MoeGateNoAuxTCKernel(float*       scales,  // [top_k, tokens]\n                                     int8_t*      masks,   // [experts, tokens_padded]\n                                     int*         accum,   // [experts, tiles]\n                                   
  const float* logits,  // [tokens, experts]\n                                     const float* bias,    // [experts] or nullptr\n                                     int          tokens,\n                                     int          tokens_padded,\n                                     int          experts,\n                                     int          top_k,\n                                     bool         norm_topk,\n                                     float        routed_scale,\n                                     int          log_tile,\n                                     int          tiles,\n                                     bool         use_sigmoid)\n{\n    const int ti = blockIdx.x;  // one token per block\n    if (ti >= tokens) {\n        return;\n    }\n\n    extern __shared__ char smem[];\n    float*                 scores            = (float*)smem;\n    float*                 scores_for_choice = scores + experts;\n\n    const float* row = logits + ti * experts;\n\n    if (use_sigmoid) {\n        // Sigmoid scoring: scores[e] = 1 / (1 + exp(-logit[e]))\n        for (int e = threadIdx.x; e < experts; e += blockDim.x) {\n            float s              = 1.0f / (1.0f + expf(-row[e]));\n            scores[e]            = s;\n            scores_for_choice[e] = s + (bias ? 
bias[e] : 0.f);\n        }\n    }\n    else {\n        // Softmax scoring: scores[e] = exp(logit[e] - max) / sum(exp)\n        float max_logit = -1e30f;\n        for (int e = threadIdx.x; e < experts; e += blockDim.x) {\n            float v = row[e];\n            if (v > max_logit) {\n                max_logit = v;\n            }\n        }\n        max_logit = blockReduceMax<float>(max_logit);\n        __syncthreads();\n\n        float sum_exp = 0.f;\n        for (int e = threadIdx.x; e < experts; e += blockDim.x) {\n            float s   = expf(row[e] - max_logit);\n            scores[e] = s;\n            sum_exp += s;\n        }\n        sum_exp = blockReduceSum<float>(sum_exp);\n        __syncthreads();\n\n        for (int e = threadIdx.x; e < experts; e += blockDim.x) {\n            float s              = scores[e] / (sum_exp + 1e-20f);\n            scores[e]            = s;\n            scores_for_choice[e] = s + (bias ? bias[e] : 0.f);\n        }\n    }\n    __syncthreads();\n\n    if (threadIdx.x == 0) {\n        // Top-k on scores_for_choice (simple linear scan)\n        int   topk_idx[32];\n        float topk_val[32];\n        for (int k = 0; k < top_k; k++) {\n            int   best_e = -1;\n            float best_v = -INFINITY;\n            for (int e = 0; e < experts; e++) {\n                if (k > 0) {\n                    bool chosen = false;\n                    for (int j = 0; j < k; j++) {\n                        if (topk_idx[j] == e) {\n                            chosen = true;\n                            break;\n                        }\n                    }\n                    if (chosen) {\n                        continue;\n                    }\n                }\n                float v = scores_for_choice[e];\n                if (!isfinite(v)) {\n                    v = -INFINITY;\n                }\n                if (v > best_v) {\n                    best_v = v;\n                    best_e = e;\n                }\n            
}\n            if (best_e < 0) {\n                best_e      = 0;\n                topk_val[k] = 0.f;\n            }\n            else {\n                topk_val[k] = scores[best_e];\n            }\n            topk_idx[k] = best_e;\n        }\n\n        float wsum = 0.f;\n        for (int k = 0; k < top_k; k++) {\n            wsum += topk_val[k];\n        }\n        if (norm_topk && wsum > 1e-20f) {\n            for (int k = 0; k < top_k; k++) {\n                topk_val[k] /= wsum;\n            }\n        }\n        for (int k = 0; k < top_k; k++) {\n            scales[k * tokens + ti] = topk_val[k] * routed_scale;\n        }\n\n        for (int k = 0; k < top_k; k++) {\n            masks[topk_idx[k] * tokens_padded + ti] = (int8_t)k;\n        }\n\n        const int tile_id = ti >> log_tile;\n        for (int k = 0; k < top_k; k++) {\n            const int e = topk_idx[k];\n            atomicAdd(&accum[e * tiles + tile_id], 1);\n        }\n    }\n}\n\nvoid invokeMoeGate_NoAuxTC(int*         f2n,\n                           int*         f2E,\n                           int*         en2f,\n                           int*         offsets,\n                           float*       scales,\n                           void*        masks,\n                           int*         accum,\n                           const float* logits,\n                           const float* correction_bias,\n                           int          tokens,\n                           int          tokens_padded,\n                           int          experts,\n                           int          exp_per_tok,\n                           bool         norm_topk_prob,\n                           float        routed_scale,\n                           bool         use_sigmoid,\n                           cudaStream_t st)\n{\n    TM_CHECK(exp_per_tok > 0);\n    TM_CHECK_LE(exp_per_tok, 32);\n    TM_CHECK_LE(exp_per_tok, experts);\n\n    constexpr int base_log_tile = 9;\n    int           
log_tile      = base_log_tile;\n    while (((tokens_padded + (1 << log_tile) - 1) >> log_tile) > kMoeGateMaxTiles) {\n        ++log_tile;\n    }\n    const int tiles = ceil_div(tokens_padded, 1 << log_tile);\n\n    cudaMemsetAsync(accum, 0, sizeof(int) * experts * kMoeGateMaxTiles, st);\n    cudaMemsetAsync(masks, -1, sizeof(int8_t) * experts * tokens_padded, st);\n\n    // One token per block: threads cooperate over expert dimension\n    int block_dim = 1;\n    while (block_dim < experts && block_dim < 256) {\n        block_dim *= 2;  // next power of 2\n    }\n    const int    blocks = tokens;\n    const size_t smem   = sizeof(float) * experts * 2;\n\n    MoeGateNoAuxTCKernel<<<blocks, block_dim, smem, st>>>(scales,\n                                                          (int8_t*)masks,\n                                                          accum,\n                                                          logits,\n                                                          correction_bias,\n                                                          tokens,\n                                                          tokens_padded,\n                                                          experts,\n                                                          exp_per_tok,\n                                                          norm_topk_prob,\n                                                          routed_scale,\n                                                          log_tile,\n                                                          tiles,\n                                                          use_sigmoid);\n\n    constexpr int scan_threads = (1 << base_log_tile) / kMoeGateVecSize;\n    const dim3    scan_blocks(tiles, experts + 1);\n    MoeScanKernel_v2<scan_threads><<<scan_blocks, scan_threads, 0, st>>>(\n        f2n, f2E, en2f, offsets, (int8_t*)masks, accum, log_tile, tiles, tokens, tokens_padded, experts);\n}\n\ntemplate<int vec_size, int 
block_dim, class T>\n__global__ void MoeGatherKernel(T*         dst,  // [e*n, d]\n                                const T*   src,  // [  n, d]\n                                const int* f2n,  // [e*n] :: e*n -> n\n                                int        dims)\n{\n    using Vec        = Array<T, vec_size>;\n    const int64_t bi = blockIdx.x;\n\n    auto src_ptr = (const Vec*)src + dims * f2n[bi];\n    auto dst_ptr = (/* */ Vec*)dst + dims * bi;\n    for (int i = threadIdx.x; i < dims; i += block_dim) {\n        Vec v;\n        Ldg(v, src_ptr[i].data());\n        Store(dst_ptr[i].data(), v);\n    }\n}\n\nvoid invokeMoeDispatch(Ref<Tensor> out_, const Tensor& src, const int* f2n, int expert_per_token, cudaStream_t st)\n{\n    auto& out    = out_.get();\n    auto  invoke = [&](auto t) {\n        using T                = decltype(t);\n        auto [num, dim]        = src.shapes(0, 1);\n        constexpr int threads  = 256;\n        constexpr int vec_size = 16 / sizeof(T);\n        // std::cout << num * expert_per_token << \" \" << dim << \"\\n\";\n        MoeGatherKernel<vec_size, threads><<<num * expert_per_token, threads, 0, st>>>(  //\n            (T*)out.raw_data(),\n            (const T*)src.raw_data(),\n            f2n,\n            dim / vec_size);\n    };\n    TM_CHECK_EQ(src.dtype(), out.dtype());\n    const auto elem_size = byte_size(src.dtype());\n    if (elem_size == sizeof(uint16_t)) {\n        return invoke(uint16_t{});\n    }\n    else if (elem_size == sizeof(uint8_t)) {\n        return invoke(uint8_t{});\n    }\n    TM_CHECK(0) << \"unsupported data type: \" << src.dtype();\n}\n\ntemplate<int alignment, int block_dim, class T>\n__global__ void MoeDispatchScales(\n    T* dst, int* dst_offsets, const T* src, const int* f2n, const int* offsets, int dim, int expert_num, int stride)\n{\n    int bi = blockIdx.x;\n\n    __shared__ int shared_g;\n\n    for (int g = threadIdx.x; g < expert_num; ++g) {\n        if (offsets[g] <= bi && bi < offsets[g + 1]) {\n 
           shared_g = g;\n        }\n    }\n\n    __syncthreads();\n\n    int g = shared_g;\n\n    const int base = (offsets[g - 1] + alignment * (g - 1)) / alignment * alignment;\n    const int ti   = base + bi - offsets[g];\n\n    bi = f2n[bi];\n\n    // ! strided access\n    for (int di = threadIdx.x; di < dim; di += block_dim) {\n        dst[di * stride + ti] = src[di * stride + bi];\n    }\n}\n\ntemplate<class T>\n__global__ void\nMoeDispatchScalesNonaligned(T* dst, const T* src, int dst_stride, int src_stride, const int* f2n, int dim)\n{\n    const int bi = blockIdx.x;\n    const int ti = f2n[bi];\n\n    if (threadIdx.x < dim) {\n        dst[threadIdx.x * dst_stride + bi] = src[threadIdx.x * src_stride + ti];\n    }\n}\n\nvoid invokeMoeDispatchScales(Ref<Tensor> out_, const Tensor& src, const int* f2n, int expert_per_token, cudaStream_t st)\n{\n    using T                 = float;\n    constexpr int alignment = 16 / sizeof(T);\n\n    auto [dim, num] = src.shapes(0, 1);\n\n    const int size         = num * expert_per_token;\n    const int aligned_size = round_up<int>(size, alignment);\n\n    auto& out = out_.get();\n\n    if (!out) {\n        out = Tensor_<T>{{{dim, size}, {aligned_size, 1}}, kDEVICE};\n    }\n    else {\n        TM_CHECK(std::make_tuple(dim, size) == out.shapes(0, 1));\n        TM_CHECK(out.stride(1) == 1);\n        TM_CHECK(out.stride(0) % alignment == 0);\n    }\n\n    TM_CHECK_LE(dim, 1024);\n    const int threads = round_up<int>(dim, WARP_SIZE);\n    const int blocks  = size;\n\n    // std::cout << src << \" \" << out << \"\\n\";\n\n    MoeDispatchScalesNonaligned<<<blocks, threads, 0, st>>>((T*)out.raw_data(),  //\n                                                            (const T*)src.raw_data(),\n                                                            out.stride(0),\n                                                            src.stride(0),\n                                                            f2n,\n                       
                                     dim);\n}\n\ntemplate<int vec_size, int exp_k, bool has_bias, int block_dim, class T>\n__global__ void MoeReduceKernel(T*           dst,         // [  n, d]\n                                const T*     src,         // [e*n, d]\n                                const T*     bias,        // [  E, d]\n                                const float* scales,      // [  e, n]\n                                const int*   en2f,        // [  e, n] :: (e,n) -> e*n\n                                const int*   f2E,         // [  e* n]\n                                const float* dst_scales,  // [n]\n                                int          dim,\n                                int          tokens,\n                                T            bscale,\n                                float        dst_scale)\n{\n    if constexpr (TURBOMIND_ARCH_DTYPE_GUARD(data_type_v<T>)) {\n        const int64_t ti = blockIdx.x;\n\n        dst += dim * ti;\n\n        if (dst_scales) {\n            dst_scale = dst_scales[ti];\n            dst_scale = fdividef(1.f, 1.f + expf(-dst_scale));\n        }\n\n        // Should be warp uniforms\n        const T* src_[exp_k];\n        const T* bias_[exp_k];\n\n        float scale[exp_k];\n\n        PRAGMA_UNROLL\n        for (int e = 0; e < exp_k; ++e) {\n            int fid = __ldg(&en2f[e * tokens + ti]);\n            src_[e] = src + dim * fid;\n            if constexpr (has_bias) {\n                bias_[e] = bias + __ldg(&f2E[fid]) * dim;\n            }\n            scale[e] = scales ? 
__ldg(&scales[e * tokens + ti]) : 1.f;\n        }\n\n        using Vec = Array<T, vec_size>;\n\n        for (int i = threadIdx.x * vec_size; i < dim; i += block_dim * vec_size) {\n            Array<float, vec_size> accum{};\n            if (dst_scale) {\n                Vec v;\n                Load(v, &dst[i]);\n                using namespace ops;\n                accum = cast<float>(v) * dst_scale;\n            }\n            PRAGMA_UNROLL\n            for (int e = 0; e < exp_k; ++e) {\n                Vec v;\n                Load(v, src_[e] + i);\n                using namespace ops;\n                if constexpr (has_bias) {\n                    Vec b;\n                    Load(b, bias_[e] + i);\n                    PRAGMA_UNROLL\n                    for (int j = 0; j < vec_size; ++j) {\n                        v[j] = __hfma(b[j], bscale, v[j]);\n                    }\n                }\n                const auto x = cast<float>(v) * scale[e];\n                accum        = accum + x;\n            }\n            Store(&dst[i], cast<T>(accum));\n        }\n    }\n}\n\ntemplate<bool has_bias, class T>\nvoid invokeMoeReduce(T*           dst,\n                     const T*     src,\n                     const T*     bias,\n                     const float* scales,\n                     const int*   en2f,\n                     const int*   f2E,\n                     const float* dst_scales,\n                     int          tokens,\n                     int          experts_per_token,\n                     int          dim,\n                     T            bscale,\n                     float        dst_scale,\n                     cudaStream_t st)\n{\n    // std::cout << __PRETTY_FUNCTION__ << std::endl;\n\n    const auto invoke = [&](auto e) {\n        constexpr int threads     = 256;\n        constexpr int vec_size    = 16 / sizeof(T);\n        constexpr int exp_per_tok = decltype(e)::value;\n        MoeReduceKernel<vec_size, exp_per_tok, has_bias, 
threads><<<tokens, threads, 0, st>>>(  //\n            dst,\n            src,\n            bias,\n            scales,\n            en2f,\n            f2E,\n            dst_scales,\n            dim,\n            tokens,\n            bscale,\n            dst_scale);\n    };\n\n    switch (experts_per_token) {\n        case 1:\n            return invoke(std::integral_constant<int, 1>{});\n        case 2:\n            return invoke(std::integral_constant<int, 2>{});\n        case 4:\n            return invoke(std::integral_constant<int, 4>{});\n        case 6:\n            return invoke(std::integral_constant<int, 6>{});\n        case 8:\n            return invoke(std::integral_constant<int, 8>{});\n        default:\n            fprintf(stderr, \"Unsupported experts_per_token %d\\n\", experts_per_token);\n            std::abort();\n    }\n}\n\nvoid invokeMoeCombine(Ref<Tensor>   out_,\n                      const Tensor& src,\n                      const Tensor& bias,\n                      const float*  scales,\n                      const int*    en2f,\n                      const int*    f2E,\n                      const float*  dst_scales,\n                      int           experts_per_token,\n                      float         bscale,\n                      float         dst_scale,\n                      cudaStream_t  st)\n{\n    auto& out = out_.get();\n\n    const int tokens = out.shape(0);\n    TM_CHECK_EQ(src.shape(0), tokens * experts_per_token);\n\n    auto invoke = [&](auto has_bias, auto t) {\n        using T = decltype(t);\n        return invokeMoeReduce<has_bias.value>(out.data<T>(),\n                                               src.data<T>(),\n                                               bias.data_or((T*)nullptr),\n                                               scales,\n                                               en2f,\n                                               f2E,\n                                               dst_scales,\n             
                                  tokens,\n                                               experts_per_token,\n                                               src.shape(1),\n                                               (T)bscale,\n                                               dst_scale,\n                                               st);\n    };\n\n    auto dispatch_dtype = [&](auto t) {\n        if (bias) {\n            TM_CHECK_NOTNULL(f2E);\n            return invoke(std::true_type{}, t);\n        }\n        else {\n            return invoke(std::false_type{}, t);\n        }\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(src.dtype(), dispatch_dtype);\n}\n\nstd::vector<int> SampleUniform(int token_num, int expert_num, int exp_per_tok, std::mt19937& g)\n{\n    std::vector<int> idxs((size_t)token_num * exp_per_tok);\n    std::vector<int> r(expert_num);\n    std::iota(r.begin(), r.end(), 0);\n    auto it = idxs.begin();\n    for (int i = 0; i < token_num; ++i) {\n        it = std::sample(r.cbegin(), r.cend(), it, exp_per_tok, g);\n    }\n    return idxs;\n}\n\nstd::vector<int> SampleBalanced(int token_num, int expert_num, int exp_per_tok, std::mt19937& g)\n{\n    assert(exp_per_tok <= expert_num);\n    std::vector<int> idxs((size_t)token_num * exp_per_tok);\n    std::vector<int> q;\n\n    std::vector<int> r(expert_num);\n    std::iota(r.begin(), r.end(), 0);\n\n    auto it = idxs.begin();\n    for (int i = 0; i < token_num; ++i) {\n        if ((int)q.size() < exp_per_tok) {\n            const int k = q.size();\n            // prepend the experts: [xxx] -> [yyy | xxx]\n            q.insert(q.begin(), r.cbegin(), r.cend());\n            // move duplicated experts to the front: [yyy | xxx] -> [xxx' | yyy' | xxx]\n            int p = 0;\n            std::for_each(q.cend() - k, q.cend(), [&](auto x) { std::swap(q[p++], q[x]); });\n            // shuffle unique experts yyy'\n            std::shuffle(q.begin() + p, q.end() - k, g);\n        }\n        it = std::copy(q.end() - 
exp_per_tok, q.end(), it);\n        // remove used experts [xxx' | yyy' | xxx ] -> [xxx' | zzz]\n        q.resize(q.size() - exp_per_tok);\n        // alias [xxx] <- [xxx' | zzz]\n    }\n    assert(it == idxs.end());\n\n    // shuffle to decorrelate adjacent tokens\n    r.resize(token_num);\n    std::iota(r.begin(), r.end(), 0);\n    std::shuffle(r.begin(), r.end(), g);\n    std::vector<int> ret(idxs.size());\n    it = ret.begin();\n    for (const auto& i : r) {\n        it = std::copy_n(idxs.begin() + i * exp_per_tok, exp_per_tok, it);\n    }\n    assert(it == ret.end());\n    return ret;\n}\n\ntemplate<int max_expert_num, int items_per_thread, int access_size>\n__global__ void MoeSoftmaxMaskTopKGroups(float* logits, int token_num, int expert_num, int top_k)\n{\n    constexpr int threads_per_token = max_expert_num / items_per_thread;\n\n    static_assert((threads_per_token & (threads_per_token - 1)) == 0);\n    static_assert(items_per_thread % access_size == 0);\n\n    const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x;\n\n    const int ti = thread_idx / threads_per_token;\n    const int ei = thread_idx % threads_per_token;\n\n    float data[items_per_thread];\n    PRAGMA_UNROLL\n    for (int i = 0; i < items_per_thread; ++i) {\n        data[i] = -std::numeric_limits<float>::infinity();\n    }\n    // max logit in the group\n    float max_val = -std::numeric_limits<float>::infinity();\n    if (ti < token_num) {\n        PRAGMA_UNROLL\n        for (int i = 0; i < items_per_thread; i += access_size) {\n            const int e = ei * items_per_thread + i;  // blocked partition\n            if (e < expert_num) {\n                Ldg((Array<float, access_size>&)data[i], &logits[ti * expert_num + e]);\n                PRAGMA_UNROLL\n                for (int c = 0; c < access_size; ++c) {\n                    max_val = fmaxf(max_val, data[i + c]);\n                }\n            }\n        }\n    }\n\n    const int warp_ti        = threadIdx.x % WARP_SIZE / 
threads_per_token;\n    const int warp_ti_offset = warp_ti * threads_per_token;\n\n    bool  alive     = false;\n    float max_logit = 0;\n\n    for (int k = 0; k < top_k; ++k) {\n        int   g_max_ei  = ei;\n        float g_max_val = max_val;\n        PRAGMA_UNROLL\n        for (int m = threads_per_token / 2; m >= 1; m /= 2) {\n            g_max_val = fmaxf(g_max_val, __shfl_xor_sync((uint32_t)-1, g_max_val, m));\n        }\n        // tie breaking\n        const auto active = __ballot_sync((uint32_t)-1, max_val == g_max_val);\n        g_max_ei          = __ffs(active >> (unsigned)warp_ti_offset) - 1;\n        if (k == 0) {\n            max_logit = g_max_val;\n        }\n        if (ei == g_max_ei) {\n            alive   = true;\n            max_val = -std::numeric_limits<float>::infinity();\n        }\n    }\n\n    float sum_prob{};\n\n    PRAGMA_UNROLL  // keep data[] in registers\n    for (int i = 0; i < items_per_thread; ++i) {\n        data[i] = expf(data[i] - max_logit);\n        sum_prob += data[i];\n    }\n\n    PRAGMA_UNROLL\n    for (int m = threads_per_token / 2; m >= 1; m /= 2) {\n        sum_prob += __shfl_xor_sync((uint32_t)-1, sum_prob, m);\n    }\n\n    // mask dead logits\n    sum_prob = alive ? 
fdividef(1.f, sum_prob) : 0;\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < items_per_thread; ++i) {\n        data[i] *= sum_prob;\n    }\n\n    if (ti < token_num) {\n        PRAGMA_UNROLL\n        for (int i = 0; i < items_per_thread; i += access_size) {\n            const int e = ei * items_per_thread + i;\n            if (e < expert_num) {\n                Store(&logits[ti * expert_num + e], (Array<float, access_size>&)data[i]);\n            }\n        }\n    }\n}\n\nvoid invokeMoeSoftmaxMaskTopKGroups(\n    float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st)\n{\n    auto invoke = [&](auto max_expert_num, auto items_per_thread, auto vec_size) {\n        constexpr int thrs_per_tok = max_expert_num.value / items_per_thread.value;\n        constexpr int threads      = 256;\n        const int     blocks       = ceil_div(token_num, threads / thrs_per_tok);\n        MoeSoftmaxMaskTopKGroups<max_expert_num.value, items_per_thread.value, vec_size.value>\n            <<<blocks, threads, 0, st>>>(logits, token_num, expert_num, top_k);\n    };\n\n    if (expert_num == 160 && group_size == 20) {\n        return invoke(_Int<160>, _Int<20>, _Int<4>);\n    }\n\n    std::cerr << __FILE__ << \"(\" << __LINE__ << \"): unsupported moe config: expert_num=\" << expert_num\n              << \", group_size=\" << group_size << \"\\n\";\n    std::abort();\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/moe_utils_v2.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cstdint>\n#include <cuda_runtime.h>\n#include <random>\n#include <vector>\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nconstexpr int kMoeGateMaxTiles = 16;\nconstexpr int kMoeGateVecSize  = 4;\n\nvoid invokeMoeGate_V2(int*         f2n,\n                      int*         f2E,\n                      int*         en2f,\n                      int*         offsets,\n                      float*       scales,\n                      void*        masks,\n                      int*         accum,\n                      const float* logits,\n                      int          tokens,\n                      int          tokens_padded,\n                      int          experts,\n                      int          exp_per_tok,\n                      bool         softmax,\n                      bool         norm_topk,\n                      float        routed_scale,\n                      cudaStream_t st);\n\nvoid invokeMoeDispatch(Ref<Tensor>   out_,  //\n                       const Tensor& src,\n                       const int*    f2n,\n                       int           expert_per_token,\n                       cudaStream_t  st);\n\nvoid invokeMoeDispatchScales(Ref<Tensor>   out_,  //\n                             const Tensor& src,\n                             const int*    f2n,\n                             int           expert_per_token,\n                             cudaStream_t  st);\n\nvoid invokeMoeCombine(Ref<Tensor>   out_,\n                      const Tensor& src,\n                      const Tensor& bias,\n                      const float*  scales,\n                      const int*    en2f,\n                      const int*    f2E,\n                      const float*  dst_scales,\n                      int           experts_per_token,\n                      float         bscale,\n                      float         dst_scale,\n                      cudaStream_t  st);\n\nvoid invokeMoeSoftmaxMaskTopKGroups(\n    float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st);\n\n/// noaux_tc routing: scores = scoring_func(logits), scores_for_choice = scores + correction_bias,\n/// top-k on scores_for_choice, weights from scores; renormalize if norm_topk_prob; always apply routed_scale.\n/// correction_bias may be nullptr (then treated as 0).\n/// use_sigmoid: if true, scores = sigmoid(logits); if false, scores = softmax(logits).\nvoid invokeMoeGate_NoAuxTC(int*         f2n,\n                           int*         f2E,\n                           int*         en2f,\n                           int*         offsets,\n                           float*       scales,\n                           void*        masks,\n                           int*         accum,\n                           const float* logits,\n                           const float* correction_bias,\n                           int          tokens,\n                           int          tokens_padded,\n                           int          experts,\n                           int          exp_per_tok,\n                           bool         norm_topk_prob,\n                           float        routed_scale,\n                           bool         use_sigmoid,\n                           cudaStream_t st);\n\n// Sample `e` distinct experts out of `E` uniformly at random for every token\nstd::vector<int> SampleUniform(int token_num, int expert_num, int exp_per_tok, std::mt19937& g);\n\n// Sample `e` distinct experts out of `E` for every token so that all experts are used (nearly) equally often;\n// the token order is shuffled afterwards to decorrelate adjacent tokens\nstd::vector<int> SampleBalanced(int token_num, int expert_num, int exp_per_tok, std::mt19937& g);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/operand.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/iterator.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\nstruct VoidOperand {\n    using Dtype = int;\n\n    static constexpr Pack  kPack  = 0;\n    static constexpr Order kOrder = Order::kColMajor;\n\n    struct GetSmemLayout {\n        static constexpr SmemLayoutV2<1, 1> apply(...)\n        {\n            return {};\n        }\n    };\n\n    using SmemCopyAtom = VoidSmemCopyAtom;\n\n    struct GetGmemIter {\n        static constexpr auto apply(...)\n        {\n            return type_c<VoidGmemIter>;\n        }\n    };\n};\n\n/// TODO: fix AlignC, AlignS\n/// TODO: fix GroupSize\ntemplate<class Operand, class Iterator, int M, int K, int WARPS, int GroupSize = 1>\nstruct MakeOperand {\n\n    using Dtype = typename Operand::Dtype;\n\n    static constexpr Pack  kPack      = Operand::kPack;\n    static constexpr Order kOrder     = Operand::kOrder;\n    static constexpr int   kGroupSize = GroupSize;\n\n    static constexpr int2 kPackMK = Packing_v2<kPack, kOrder>::apply({M, ceil_div(K, kGroupSize)});\n\n    static constexpr pair<kPackMK.x, kPackMK.y> kShapeMK{};\n\n    using SmemLayout   = decltype(Operand::GetSmemLayout::apply(kShapeMK));\n    using SmemAccessor = SmemAccessorV2<Dtype, SmemLayout, kOrder>;\n\n    using GmemIter = typename decltype(Operand::GetGmemIter::apply(\n        type_c<Operand>, type_c<Iterator>, type_c<SmemLayout>, kShapeMK, constant<WARPS>{}))::type;\n\n    using SmemCopyAtom = typename Operand::SmemCopyAtom;\n};\n\n// CPO for getting specific operand templates\ntemplate<MMA_Tag mma, Op_Tag optag, class T, Order order, bool is_pack, class SFINAE = void>\nstruct GetOperand: std::false_type 
{\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/predicate.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cstdint>\n#include <type_traits>\n\nnamespace turbomind::gemm {\n\ntemplate<int S, int C, bool AlignedS, bool AlignedC>\nstruct Predicate {\n\n    static constexpr int kSizeC = AlignedC ? 1 : C;\n\n    static_assert(S * kSizeC <= 32);\n\n    static constexpr bool is_active = true;\n\n    uint32_t pred_{};\n\n    __device__ int operator()(int s, int c) const\n    {\n        return (pred_ & (1 << (s * kSizeC + c))) != 0;\n    }\n\n    __device__ void set(int s, int c)\n    {\n        pred_ |= (1 << (s * kSizeC + c));\n    }\n\n    __device__ void clear()\n    {\n        pred_ = 0;\n    }\n};\n\ntemplate<int S, int C>\nstruct Predicate<S, C, true, true> {\n\n    static constexpr bool is_active = false;\n\n    __device__ constexpr std::integral_constant<int, 1> operator()(int, int) const\n    {\n        return {};\n    }\n\n    __device__ void set(int, int) {}\n\n    __device__ void clear()\n    {\n        // pred_ = 0;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/registry.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/arch.h\"\n#include \"src/turbomind/kernels/gemm/registry.h\"\n\nnamespace turbomind::gemm {\n\nRegistry::Registry(std::shared_ptr<cudaDeviceProp> device_prop):\n    device_prop_{std::move(device_prop)}, arch_{device_prop_->major * 100 + device_prop_->minor * 10}\n{\n    sm90_16816_4();\n    sm90_16816_8();\n    sm90_16816_16();\n\n    sm80_16816_4();\n    sm80_16816_8();\n    sm80_16816_16();\n\n    sm75_16816_4();\n    sm75_16816_8();\n    sm75_16816_16();\n\n    sm70_884_4();\n    sm70_884_8();\n    sm70_884_16();\n\n    sm90_64n32_8();\n\n    cublas_float();\n}\n\nbool Registry::Add(std::unique_ptr<Kernel> kernel)\n{\n    bool is_valid = true;\n\n    if (!is_arch_compatible(kernel->arch(), arch_)) {\n        is_valid = false;\n    }\n\n    // if (is_valid) {\n    //     std::cout << \"register: \" << kernel->name()                                        //\n    //               << \", shared: \" << (kernel->smem_size() >> 10) << \" KB\"                  //\n    //               << \", regs: \" << kernel->info().attr.numRegs                             //\n    //               << \", local: \" << (float)kernel->info().attr.localSizeBytes << \" bytes\"  //\n    //               << \", max_active_ctas: \" << kernel->info().max_active_ctas << \" \\n\";\n    // }\n\n    if ((int)device_prop_->sharedMemPerBlockOptin < kernel->smem_size()) {\n        is_valid = false;\n    }\n\n    if (is_valid) {\n        ptrs_.push_back(kernels_.emplace_back(transpose(*kernel)).get());\n        ptrs_.push_back(kernels_.emplace_back(std::move(kernel)).get());\n    }\n\n    return true;\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/registry.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/kernel_impl.h\"\n#include <memory>\n\nnamespace turbomind::gemm {\n\nclass Registry {\npublic:\n    explicit Registry(std::shared_ptr<cudaDeviceProp> device_prop);\n\n    /// TODO: remove this\n    template<class Config>\n    [[maybe_unused]] bool Add()\n    {\n        return Add(std::make_unique<KernelImpl<typename Config::Kernel>>());\n    }\n\n    [[nodiscard]] const std::vector<Kernel*>& kernels() const\n    {\n        return ptrs_;\n    }\n\nprivate:\n    bool Add(std::unique_ptr<Kernel> kernel);\n\n    void sm90_16816_4();\n    void sm90_16816_8();\n    void sm90_16816_16();\n\n    void sm80_16816_4();\n    void sm80_16816_8();\n    void sm80_16816_16();\n\n    void sm75_16816_4();\n    void sm75_16816_8();\n    void sm75_16816_16();\n\n    void sm70_884_4();\n    void sm70_884_8();\n    void sm70_884_16();\n\n    void sm90_64n32_8();\n\n    void cublas_float();\n\nprivate:\n    std::shared_ptr<cudaDeviceProp>      device_prop_;\n    int                                  arch_;\n    std::vector<std::unique_ptr<Kernel>> kernels_;\n    std::vector<Kernel*>                 ptrs_;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/scaled_gmma_fp8_sm90.h",
    "content": "#pragma once\n\n#include <numeric>\n\n#include \"cute/arch/mma_sm90_gmma.hpp\"\n#include \"cute/atom/mma_traits.hpp\"\n#include \"cute/atom/mma_traits_sm90_gmma.hpp\"\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/sm90_utils.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<int TILE_M, int TILE_N, int TILE_K, int BATCH_M, int BATCH_N, int PIPE_M, int PIPE_N>\nstruct ScaledGmmaFP8_TN {\n\n    static constexpr auto select_gmma_operation()\n    {\n        static_assert(TILE_M % (BATCH_M * PIPE_M) == 0);\n        static_assert(TILE_N % (BATCH_N * PIPE_N) == 0);\n\n        constexpr int M = TILE_M / (BATCH_M * PIPE_M);\n        constexpr int N = TILE_N / (BATCH_N * PIPE_N);\n\n        static_assert(M % 64 == 0);\n\n        using namespace cute::SM90::GMMA;\n\n        if constexpr (N % 256 == 0) {\n            return type_c<MMA_64x256x32_F32E4M3E4M3_SS_TN<>>;\n        }\n        else if constexpr (N % 224 == 0) {\n            return type_c<MMA_64x224x32_F32E4M3E4M3_SS_TN<>>;\n        }\n        else if constexpr (N % 192 == 0) {\n            return type_c<MMA_64x192x32_F32E4M3E4M3_SS_TN<>>;\n        }\n        else if constexpr (N % 160 == 0) {\n            return type_c<MMA_64x160x32_F32E4M3E4M3_SS_TN<>>;\n        }\n        else if constexpr (N % 128 == 0) {\n            return type_c<MMA_64x128x32_F32E4M3E4M3_SS_TN<>>;\n        }\n        else if constexpr (N % 96 == 0) {\n            return type_c<MMA_64x96x32_F32E4M3E4M3_SS_TN<>>;\n        }\n        else if constexpr (N % 64 == 0) {\n            return type_c<MMA_64x64x32_F32E4M3E4M3_SS_TN<>>;\n        }\n        else {\n            static_assert(N == 0, \"unsupported configuration\");\n        }\n    }\n\n    using Operation = typename decltype(select_gmma_operation())::type;\n\n    static constexpr typename cute::MMA_Traits<Operation>::Shape_MNK OP_Shape{};\n\n    static constexpr int OP_M = 
cute::get<0>(OP_Shape);\n    static constexpr int OP_N = cute::get<1>(OP_Shape);\n    static constexpr int OP_K = cute::get<2>(OP_Shape);\n\n    static constexpr int ITER_M = TILE_M / OP_M / BATCH_M / PIPE_M;\n    static constexpr int ITER_N = TILE_N / OP_N / BATCH_N / PIPE_N;\n\n    using FragU = float[ITER_M][PIPE_M][BATCH_M][2];\n    using FragV = float[2];\n\n    using FragC = typename Operation::CRegisters[PIPE_M][PIPE_N][BATCH_M][BATCH_N];\n\n    using AccumC = FragC[ITER_M][ITER_N];\n\n    static constexpr int kStepMA = (OP_M * TILE_K) >> 4;\n    static constexpr int kStepNB = (OP_N * TILE_K) >> 4;\n    static constexpr int kStepKA = (OP_K) >> 4;\n    static constexpr int kStepKB = (OP_K) >> 4;\n\n    static constexpr int OUTER_N = std::gcd(TILE_N, 128);\n\n    template<class FragC, class AccumC, class FragU, class FragV, class PredV>\n    __device__ static void scale_batch_to_accum(AccumC&      accum_C,\n                                                const FragC& frag_C,\n                                                const FragU& frag_U,\n                                                const FragV& frag_V,\n                                                const PredV& pred_V,\n                                                int          offset_V)\n    {\n        PRAGMA_UNROLL\n        for (int m = 0; m < BATCH_M; ++m) {\n            float scales[2][2];\n            // TODO: check the compiler's ability to avoid re-computing this\n            scales[0][0] = frag_U[m][0] * frag_V[0];\n            scales[1][0] = frag_U[m][1] * frag_V[0];\n            scales[0][1] = frag_U[m][0] * frag_V[1];\n            scales[1][1] = frag_U[m][1] * frag_V[1];\n            PRAGMA_UNROLL\n            for (int n = 0; n < BATCH_N; ++n) {\n                PRAGMA_UNROLL\n                for (int c0 = 0; c0 < OP_N; c0 += OUTER_N) {\n                    int  i = (offset_V + c0) / OUTER_N;\n                    bool p = pred_V[i];\n                    PRAGMA_UNROLL\n                  
  for (int c1 = 0; c1 < OUTER_N; c1 += 8) {\n                        int c = c0 + c1;\n                        accum_C[m][n][c / 2 + 0] += (p ? scales[0][1] : scales[0][0]) * frag_C[m][n][c / 2 + 0];\n                        accum_C[m][n][c / 2 + 1] += (p ? scales[0][1] : scales[0][0]) * frag_C[m][n][c / 2 + 1];\n                        accum_C[m][n][c / 2 + 2] += (p ? scales[1][1] : scales[1][0]) * frag_C[m][n][c / 2 + 2];\n                        accum_C[m][n][c / 2 + 3] += (p ? scales[1][1] : scales[1][0]) * frag_C[m][n][c / 2 + 3];\n                    }\n                }\n            }\n        }\n    }\n\n    __device__ static void warpgroup_wait(int n)\n    {\n        if (n == 0) {\n            cute::warpgroup_wait<0>();\n        }\n        else if (n == 1) {\n            cute::warpgroup_wait<1>();\n        }\n        else if (n == 2) {\n            cute::warpgroup_wait<2>();\n        }\n        else if (n == 3) {\n            cute::warpgroup_wait<3>();\n        }\n        else if (n == 4) {\n            cute::warpgroup_wait<4>();\n        }\n        else if (n == 5) {\n            cute::warpgroup_wait<5>();\n        }\n        else if (n == 6) {\n            cute::warpgroup_wait<6>();\n        }\n        else if (n == 7) {\n            cute::warpgroup_wait<7>();\n        }\n    }\n\n    template<class SmemIterA, class SmemIterB, class FragC>\n    __device__ static void gmma_batch(SmemIterA& iter_A, SmemIterB& iter_B, FragC& frag_C)\n    {\n        constexpr int BATCH_K = TILE_K / OP_K;\n        PRAGMA_UNROLL\n        for (int k = 0; k < BATCH_K; ++k) {\n            PRAGMA_UNROLL\n            for (int m = 0; m < BATCH_M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < BATCH_N; ++n) {\n                    wgmma<Operation>(iter_A, iter_B, frag_C[m][n], k == 0);\n                    iter_B += kStepNB;\n                }\n                iter_B -= kStepNB * BATCH_N;\n                iter_A += kStepMA;\n            }\n            
iter_A -= kStepMA * BATCH_M;\n            iter_A += kStepKA;\n            iter_B += kStepKB;\n        }\n        iter_A -= kStepKA * BATCH_K;\n        iter_B -= kStepKB * BATCH_K;\n        cute::warpgroup_commit_batch();\n    }\n\n    template<class SmemIterA, class SmemIterB, class FragC, class AccumC, class FragU, class FragV, class PredV>\n    __device__ static void gmma_pipe(AccumC&      accum_C,\n                                     SmemIterA&   iter_A,\n                                     SmemIterB&   iter_B,\n                                     FragC&       frag_C,\n                                     const FragU& frag_U,\n                                     const FragV& frag_V,\n                                     const PredV& pred_V,\n                                     int          offset_V)\n    {\n        cute::warpgroup_arrive();\n        PRAGMA_UNROLL\n        for (int m = 0; m < PIPE_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < PIPE_N; ++n) {\n                gmma_batch(iter_A, iter_B, frag_C[m][n]);\n                iter_B += kStepNB * BATCH_N;\n            }\n            iter_B -= kStepNB * BATCH_N * PIPE_N;\n            iter_A += kStepMA * BATCH_M;\n        }\n        iter_A -= kStepMA * BATCH_M * PIPE_M;\n\n        int i = 0;\n        PRAGMA_UNROLL\n        for (int m = 0; m < PIPE_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < PIPE_N; ++n, ++i) {\n                warpgroup_wait(PIPE_M * PIPE_N - i - 1);\n                int offset = offset_V + n * BATCH_N * OP_N;\n                scale_batch_to_accum(accum_C[m][n], frag_C[m][n], frag_U[m], frag_V, pred_V, offset);\n            }\n        }\n    }\n\n    template<class SmemIterA, class SmemIterB, class FragC, class AccumC, class FragU, class FragV, class PredV>\n    __device__ static void apply(SmemIterA&   iter_A,\n                                 SmemIterB&   iter_B,\n                                 FragC&       frag_C,\n                    
             AccumC&      accum_C,\n                                 const FragU& frag_U,\n                                 const FragV& frag_V,\n                                 const PredV& pred_V)\n    {\n        PRAGMA_UNROLL\n        for (int m = 0; m < ITER_M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < ITER_N; ++n) {\n                int offset_V = n * PIPE_N * BATCH_N * OP_N;\n                gmma_pipe(accum_C[m][n], iter_A, iter_B, frag_C, frag_U[m], frag_V, pred_V, offset_V);\n                iter_B += kStepNB * BATCH_N * PIPE_N;\n            }\n            iter_B -= kStepNB * BATCH_N * PIPE_N * ITER_N;\n            iter_A += kStepMA * BATCH_M * PIPE_M;\n        }\n        iter_A -= kStepMA * BATCH_M * PIPE_M * ITER_M;\n    }\n\n    template<class Frag, class Func>\n    __device__ static void foreach_C(Frag& frag, Func&& func)\n    {\n        PRAGMA_UNROLL\n        for (int i_m = 0; i_m < ITER_M; ++i_m) {\n            PRAGMA_UNROLL\n            for (int i_n = 0; i_n < ITER_N; ++i_n) {\n                PRAGMA_UNROLL\n                for (int p_m = 0; p_m < PIPE_M; ++p_m) {\n                    PRAGMA_UNROLL\n                    for (int p_n = 0; p_n < PIPE_N; ++p_n) {\n                        PRAGMA_UNROLL\n                        for (int b_m = 0; b_m < BATCH_M; ++b_m) {\n                            PRAGMA_UNROLL\n                            for (int b_n = 0; b_n < BATCH_N; ++b_n) {\n                                int m = (i_m * PIPE_M + p_m) * BATCH_M + b_m;\n                                int n = (i_n * PIPE_N + p_n) * BATCH_N + b_n;\n                                func(frag[i_m][i_n][p_m][p_n][b_m][b_n], m, n);\n                            }  // BATCH_N\n                        }      // BATCH_M\n                    }          // PIPE_N\n                }              // PIPE_M\n            }                  // ITER_N\n        }                      // ITER_M\n    }\n\n    template<class Frag, class Func>\n    __device__ static void foreach_m(Frag& frag, Func&& func)\n    {\n        PRAGMA_UNROLL\n        for (int i_m = 0; i_m < ITER_M; ++i_m) {\n            PRAGMA_UNROLL\n            for (int p_m = 0; p_m < PIPE_M; ++p_m) {\n                PRAGMA_UNROLL\n                for (int b_m = 0; b_m < BATCH_M; ++b_m) {\n                    int m = (i_m * PIPE_M + p_m) * BATCH_M + b_m;\n                    func(frag[i_m][p_m][b_m], m);\n                }\n            }\n        }\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/scheduler.cuh",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"cutlass/fast_math.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/gemm/cta_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\n#include \"cute/arch/cluster_sm90.hpp\"\n#include \"cutlass/arch/barrier.h\"\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/pipeline/sm90_pipeline.hpp\"\n\nnamespace turbomind::gemm {\n\nTM_DEVICE void mbarrier_arrive_cluster(uint64_t* mbar, int cta_id, int pred)\n{\n    uint32_t smem_addr = cast_smem_ptr_to_uint(mbar);\n    if (pred) {\n        asm volatile(\"{\\n\"\n                     \".reg .b32 remAddr32;\\n\"\n                     \"mapa.shared::cluster.u32  remAddr32, %0, %1;\\n\"\n                     \"mbarrier.arrive.release.cluster.shared::cluster.b64  _, [remAddr32];\\n\"\n                     \"}\"\n                     :\n                     : \"r\"(smem_addr), \"r\"(cta_id));\n    }\n}\n\nTM_DEVICE void mbarrier_wait_cluster(uint64_t* mbar, uint32_t phase)\n{\n    uint32_t smem_addr = cast_smem_ptr_to_uint(mbar);\n    uint32_t ticks     = 0x989680;\n    asm volatile(\"{\\n\"\n                 \".reg .pred       P1; \\n\"\n                 \"LAB_WAIT: \\n\"\n                 \"mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1, %2; \\n\"\n                 \"@P1 bra DONE; \\n\"\n                 \"bra     LAB_WAIT; \\n\"\n                 \"DONE: \\n\"\n                 \"}\"\n                 :\n                 : \"r\"(smem_addr), \"r\"(phase), \"r\"(ticks));\n}\n\nTM_DEVICE void* map_to_cta(void* ptr, int cta_id)\n{\n    void* ret;\n    asm volatile(\"mapa.u64 %0, %1, %2;\\n\" : \"=l\"(ret) : \"l\"(ptr), \"r\"(cta_id));\n    return ret;\n}\n\nTM_DEVICE void st_shared_cluster(uint32_t ptr, int value)\n{\n    asm volatile(\"st.shared::cluster.s32 [%0], %1;\\n\" 
::\"r\"(ptr), \"r\"(value));\n}\n\ntemplate<class T, class M>\nconstexpr int member_offset(M T::*member)\n{\n    return reinterpret_cast<std::size_t>(&(reinterpret_cast<T*>(0)->*member));\n}\n\ntemplate<Order order,\n         class Cluster,\n         int  striped_m,\n         bool striped_n,\n         int  tile_m,\n         int  tile_n,\n         int  Stages_,\n         bool is_grouped_gemm>\nstruct TileScheduler {\n\n    static constexpr bool is_dynamic = 1;  // is_grouped_gemm;\n    static constexpr int  Stages     = Stages_;\n\n    static constexpr int2 tile_{tile_m, tile_n};\n    static constexpr int2 cluster_tile_{tile_m * Cluster::M, tile_n* Cluster::N};\n\n    int4 gemm_shape_;\n    int2 tiled_shape_;\n\n    int log_tile_;\n    int k_iters_;\n\n    int2 tile_offset_;\n    int2 iter_k_range_;\n\n    int clusters_;\n\n    //////// v2 /////\n    int2 swizzle_unit_;\n    int2 cluster_tiles_;\n    int2 padded_cluster_tiles_;\n    int2 swizzled_cluster_tiles_;\n\n    cutlass::FastDivmod swizzle_tile_x_;\n    /////////////\n\n    const int* offsets_;\n\n    int* next_cluster_id_;\n\n    using PipelineState = cutlass::PipelineState<Stages>;\n\n    struct Tile0 {\n        int is_valid_cta;\n        int is_valid_cluster;\n        int offset_m;\n        int offset_n;\n        int alive;\n    };\n\n    struct Tile1 {\n        int is_valid_cta;\n        int is_valid_cluster;\n        int offset_m;\n        int offset_n;\n        int alive;\n        int group_idx;\n        int m0;\n        int m1;\n    };\n\n    using Tile = std::conditional_t<is_grouped_gemm, Tile1, Tile0>;\n\n    struct Storage {\n        Tile tile[Stages];\n        __align__(8) uint64_t producer_bar[Stages];\n        __align__(8) uint64_t consumer_bar[Stages];\n    };\n\n    struct ConsumerState {\n        PipelineState  pipe;\n        Storage&       store;\n        TileScheduler& sched;\n\n        TM_DEVICE bool acquire(Tile*& tile)\n        {\n            return sched.acquire(*this, tile);\n        
}\n\n        TM_DEVICE void release(int step = 1)\n        {\n            return sched.release(*this, step);\n        }\n    };\n\n    struct ProducerState {\n        PipelineState  pipe;\n        int            group_id_offset;\n        int            cluster_idx;\n        Storage&       store;\n        TileScheduler& sched;\n\n        TM_DEVICE bool next()\n        {\n            return sched.next(*this);\n        }\n    };\n\npublic:\n    TM_DEVICE void init_dyanmic(Storage& store, int consumer_num)\n    {\n        for (int i = 0; i < Stages; ++i) {\n            cutlass::arch::ClusterBarrier::init(&store.producer_bar[i], 1);\n            cutlass::arch::ClusterBarrier::init(&store.consumer_bar[i], consumer_num);\n        }\n        // cutlass::arch::ClusterBarrier::init(&store.sync_bar, 1);\n    }\n\n    TM_HOST_DEVICE void init(int4 gemm_shape, int log_tile, int3 tile_shape)\n    {\n        gemm_shape_ = gemm_shape;\n\n        // printf(\"gemm shape: %d %d %d\\n\", gemm_shape.x, gemm_shape.y, gemm_shape.z);\n\n        log_tile_ = log_tile;\n        k_iters_  = cdiv(gemm_shape_.z, tile_shape.z);\n\n        tiled_shape_.x = cdiv(gemm_shape.x, tile_.x);\n        tiled_shape_.y = cdiv(gemm_shape.y, tile_.y);\n\n        cluster_tiles_.x = cdiv(gemm_shape.x, cluster_tile_.x);  // useless\n        cluster_tiles_.y = cdiv(gemm_shape.y, cluster_tile_.y);\n\n        // printf(\"cluster tiles: %d %d\\n\", cluster_tiles_.x, cluster_tiles_.y);\n\n        if constexpr (is_grouped_gemm) {\n            {\n                int2 unit     = get_swizzled_shape({1, 1}, log_tile);\n                swizzle_unit_ = order == kColMajor ? 
int2{unit.y, unit.x} : int2{unit.x, unit.y};\n            }\n\n            // col {8, 1}, row {1, 8}\n            // printf(\"swizzle unit: %d %d\\n\", swizzle_unit_.x, swizzle_unit_.y);\n\n            swizzle_tile_x_ = cluster_tile_.x * swizzle_unit_.x;\n\n            int num = gemm_shape_.w;\n\n            // num of tiles won't change after swizzle\n            padded_cluster_tiles_.x = (num + gemm_shape.x / (cluster_tile_.x * swizzle_unit_.x)) * swizzle_unit_.x;\n            padded_cluster_tiles_.y = cdiv(gemm_shape.y, cluster_tile_.y * swizzle_unit_.y) * swizzle_unit_.y;\n\n            // printf(\"padded   cluster tiles: %d %d\\n\", padded_cluster_tiles_.x, padded_cluster_tiles_.y);\n\n            swizzled_cluster_tiles_ = get_swizzled_shape(padded_cluster_tiles_, log_tile);\n\n            // printf(\"swizzled cluster tiles: %d %d\\n\", swizzled_cluster_tiles_.x, swizzled_cluster_tiles_.y);\n\n            clusters_ = padded_cluster_tiles_.x * padded_cluster_tiles_.y;\n\n            // printf(\"clusters = %d\\n\", clusters_);\n            // M is runtime value\n        }\n        else {\n            tiled_shape_.x = cdiv(gemm_shape.x, tile_.x);\n            tiled_shape_.y = cdiv(gemm_shape.y, tile_.y);\n\n            cluster_tiles_.x = cdiv(gemm_shape.x, cluster_tile_.x);\n            cluster_tiles_.y = cdiv(gemm_shape.y, cluster_tile_.y);\n\n            swizzled_cluster_tiles_ = get_swizzled_shape(cluster_tiles_, log_tile);\n\n            swizzle_tile_x_ = swizzled_cluster_tiles_.x;\n\n            clusters_ = swizzled_cluster_tiles_.x * swizzled_cluster_tiles_.y;\n        }\n    }\n\n    TM_HOST_DEVICE static int get_log_tile(int2 tiled_mn, int tile_size)\n    {\n        return gemm::get_log_tile(order == kColMajor ? 
tiled_mn.y : tiled_mn.x, tile_size);\n    }\n\n    TM_HOST_DEVICE static int2 get_swizzled_shape(int2 tiled_shape, int log_tile)\n    {\n        const int tile = 1 << log_tile;\n\n        if constexpr (order == kColMajor) {\n            return {tiled_shape.x * tile, (tiled_shape.y + tile - 1) >> log_tile};\n        }\n        else {\n            return {tiled_shape.y * tile, (tiled_shape.x + tile - 1) >> log_tile};\n        }\n    }\n\n    TM_DEVICE ProducerState init_producer(Storage& store)\n    {\n        int cluster_id = 0;\n        if constexpr (!is_dynamic) {\n            cluster_id = (int)cute::cluster_id_in_grid().x;\n        }\n        return {\n            PipelineState{0, 1, 0},\n            0,\n            cluster_id,\n            store,\n            *this,\n        };\n    }\n\n    TM_DEVICE ConsumerState init_consumer(Storage& store)\n    {\n        return {\n            PipelineState{},\n            store,\n            *this,\n        };\n    }\n\n    TM_DEVICE void\n    unswizzle(Tile& tile, int cluster_idx, int cta_id, int2 cta_tiles, int2 cluster_tiles, int2 swizzle_tiles) const\n    {\n        int cluster_idx_x, cluster_idx_y;\n\n        if constexpr (is_grouped_gemm) {\n            cluster_idx_x = cluster_idx % swizzle_tiles.x;\n            cluster_idx_y = cluster_idx / swizzle_tiles.x;\n        }\n        else {\n            swizzle_tile_x_(cluster_idx_y, cluster_idx_x, cluster_idx);\n        }\n\n        auto [cluster_cta_m, cluster_cta_n] = Cluster::cta_mn(cta_id);\n\n        const int offset_x = cluster_cta_m * (striped_m ? cluster_tiles.x : 1);\n        const int offset_y = cluster_cta_n * (striped_n ? 
cluster_tiles.y : 1);\n\n        int2 cluster_tile_offset;\n\n        if constexpr (order == kColMajor) {\n            cluster_tile_offset = {(cluster_idx_x >> log_tile_),\n                                   (cluster_idx_y << log_tile_) + (cluster_idx_x & ((1 << log_tile_) - 1))};\n        }\n        else {\n            cluster_tile_offset = {(cluster_idx_y << log_tile_) + (cluster_idx_x & ((1 << log_tile_) - 1)),\n                                   (cluster_idx_x >> log_tile_)};\n        }\n\n        // `tile` may be on DSMEM\n        int tile_idx_x        = offset_x + cluster_tile_offset.x * (striped_m ? 1 : Cluster::M);\n        int tile_idx_y        = offset_y + cluster_tile_offset.y * (striped_n ? 1 : Cluster::N);\n        tile.offset_m         = tile_idx_x * tile_.x;\n        tile.offset_n         = tile_idx_y * tile_.y;\n        int valid_cluster_p   = cluster_tile_offset.x < cluster_tiles.x && cluster_tile_offset.y < cluster_tiles.y;\n        tile.is_valid_cta     = valid_cluster_p && tile_idx_x < cta_tiles.x && tile_idx_y < cta_tiles.y;\n        tile.is_valid_cluster = valid_cluster_p;\n    }\n\n    TM_DEVICE int get_start_index(int g) const\n    {\n        // return (__ldg(&offsets_[g]) / (cluster_tile_.x * swizzle_unit_.x) + g) * swizzle_unit_.x\n        //        * padded_cluster_tiles_.y;\n        return (swizzle_tile_x_.div(__ldg(&offsets_[g])) + g) * swizzle_unit_.x * padded_cluster_tiles_.y;\n    }\n\n    TM_DEVICE bool update_sync(int   cluster_idx,\n                               int&  group_id_offset,\n                               int&  group_idx,\n                               int&  group_beg,\n                               int&  group_m0,\n                               int&  group_m1,\n                               int2& tiled_shape,\n                               int2& cluster_tiles,\n                               int2& swizzled_tiles) const\n    {\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        uint32_t mask;\n        
while (true) {\n            int e    = group_id_offset + lane_id;\n            int pred = e > gemm_shape_.w || cluster_idx < get_start_index(e);\n            mask     = __ballot_sync((uint32_t)-1, pred);\n            if (mask) {\n                break;\n            }\n            group_id_offset += WARP_SIZE;\n        }\n\n        // 32 - clz(~mask) - 1\n        group_idx = group_id_offset + 31 - __clz(~mask);\n\n        group_m0 = __ldg(&offsets_[group_idx]);\n        group_m1 = __ldg(&offsets_[group_idx + 1]);\n        int m    = group_m1 - group_m0;\n\n        group_beg = get_start_index(group_idx);\n\n        tiled_shape.x   = cdiv(m, tile_.x);\n        cluster_tiles.x = cdiv(m, cluster_tile_.x);\n\n        swizzled_tiles = get_swizzled_shape(cluster_tiles, log_tile_);\n\n        return true;\n    }\n\n    TM_DEVICE bool next(ProducerState& state)\n    {\n        const int lane_id = cutlass::canonical_lane_idx();\n\n        auto& store = state.store;\n        auto& pipe  = state.pipe;\n\n        int cluster_idx{};\n\n        if constexpr (is_dynamic) {\n            if (lane_id == 0) {\n                cutlass::arch::ClusterBarrier::wait(&store.consumer_bar[pipe.index()], pipe.phase());\n                cluster_idx = atomicAdd(next_cluster_id_, 1);\n            }\n            cluster_idx = __shfl_sync((uint32_t)-1, cluster_idx, 0);\n        }\n        else {\n            cutlass::arch::ClusterBarrier::wait(&store.consumer_bar[pipe.index()], pipe.phase());\n            cluster_idx = state.cluster_idx;\n            state.cluster_idx += (int)cute::cluster_grid_dims().x;\n        }\n\n        Tile* tile{};\n\n        if constexpr (Cluster::size == 1) {\n            tile = &store.tile[pipe.index()];\n        }\n        else {\n            if (lane_id < Cluster::size) {\n                tile = (Tile*)map_to_cta(&store.tile[pipe.index()], lane_id);\n            }\n        }\n\n        const int alive = cluster_idx < clusters_;\n\n        if (alive) {\n            int  
group_id      = 0;\n            int  group_beg     = 0;\n            int  group_m0      = 0;\n            int  group_m1      = 0;\n            auto cta_tiles     = tiled_shape_;\n            auto cluster_tiles = cluster_tiles_;\n            auto swizzle_tiles = swizzled_cluster_tiles_;\n            if constexpr (is_grouped_gemm) {\n                update_sync(cluster_idx,  //\n                            state.group_id_offset,\n                            group_id,\n                            group_beg,\n                            group_m0,\n                            group_m1,\n                            cta_tiles,\n                            cluster_tiles,\n                            swizzle_tiles);\n            }\n            if (lane_id < Cluster::size) {\n                unswizzle(*tile,  //\n                          cluster_idx - group_beg,\n                          lane_id,\n                          cta_tiles,\n                          cluster_tiles,\n                          swizzle_tiles);\n                if constexpr (is_grouped_gemm) {\n                    tile->group_idx = group_id;\n                    tile->m0        = group_m0;\n                    tile->m1        = group_m1;\n                }\n            }\n        }\n\n        if (lane_id < Cluster::size) {\n            tile->alive = alive;\n        }\n\n        if constexpr (Cluster::size == 1) {\n            if (lane_id == 0) {\n                cutlass::arch::ClusterBarrier::arrive(&store.producer_bar[pipe.index()]);\n            }\n        }\n        else {\n            mbarrier_arrive_cluster(&store.producer_bar[pipe.index()], lane_id, lane_id < Cluster::size);\n        }\n\n        ++pipe;\n\n        return alive;\n    }\n\n    TM_DEVICE void tail(ProducerState& state)\n    {\n        if constexpr (Cluster::size > 1) {\n            for (int i = 0; i < Stages; ++i) {\n                cutlass::arch::ClusterBarrier::wait(&state.store.consumer_bar[state.pipe.index()], 
state.pipe.phase());\n                ++state.pipe;\n            }\n        }\n    }\n\n    TM_DEVICE bool acquire(ConsumerState& state, Tile*& tile)\n    {\n        auto& store = state.store;\n        auto& pipe  = state.pipe;\n\n        if constexpr (Cluster::size == 1) {\n            cutlass::arch::ClusterBarrier::wait(&store.producer_bar[pipe.index()], pipe.phase());\n        }\n        else {\n            mbarrier_wait_cluster(&store.producer_bar[pipe.index()], pipe.phase());\n        }\n\n        tile = &store.tile[pipe.index()];\n\n        return tile->alive;\n    }\n\n    TM_DEVICE void release(ConsumerState& state, int step)\n    {\n        auto& store = state.store;\n        auto& pipe  = state.pipe;\n\n        __syncwarp();\n\n        if constexpr (Cluster::size == 1) {\n            if (cutlass::elect_one_sync()) {\n                cutlass::arch::ClusterBarrier::arrive(&store.consumer_bar[pipe.index()]);\n            }\n        }\n        else {\n            cutlass::arch::ClusterBarrier::arrive(&store.consumer_bar[pipe.index()], 0, cutlass::elect_one_sync());\n        }\n\n        pipe.advance(step);\n    }\n\n    TM_DEVICE int4 gemm_shape() const\n    {\n        return gemm_shape_;\n    }\n\n    TM_DEVICE int2 tiled_shape() const\n    {\n        return tiled_shape_;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/scheduler_sm70.cuh",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n\n#include \"src/turbomind/kernels/gemm/cta_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<Order order, int tile_m, int tile_n, int tile_k, int chunk_k, int split_k, int group_axis_>\nstruct SchedulerSm70 {\n\n    static constexpr int group_axis = group_axis_;\n\n    static constexpr Array<int, 3> tile_shape{tile_m, tile_n, tile_k};\n\n    static_assert(chunk_k % tile_k == 0);\n    static constexpr int chunk_iters = chunk_k / tile_k;\n\n    Array<int, 4> gemm_shape_;\n    Array<int, 3> tiles_;\n\n    int log_tile_;\n\n    int split_chunks_;\n    int chunk_offset_;\n\n    const int* offsets_;\n\n    struct Tile {\n        Array<int, 3> tile_id;\n        Array<int, 3> shape;\n        Array<int, 2> k_iters;\n\n        int group_id;\n        int linear_tile_id;\n    };\n\n    struct SharedStorage {\n        int group_id;\n        int dynamic_dim;\n        int base_tile_id;\n    };\n\n    __host__ dim3 get_grid_shape()\n    {\n        auto shape = get_swizzled_shape(tiles_, log_tile_);\n        return dim3(shape[0], shape[1], shape[2]);\n    }\n\n    __host__ SchedulerSm70(Array<int, 4> gemm_shape, int log_tile = 0, int splits = 1):\n        gemm_shape_{gemm_shape}, log_tile_{log_tile}\n    {\n        tiles_[0] = cdiv(gemm_shape[0], tile_m);\n        tiles_[1] = cdiv(gemm_shape[1], tile_n);\n        tiles_[2] = splits;\n\n        log_tile_ = log_tile;\n\n        Array<int, 2> log_unit{};\n        log_unit[1 - (int)order] = log_tile;\n\n        tiles_[0] = round_up(tiles_[0], 1 << log_unit[0]);\n        tiles_[1] = round_up(tiles_[1], 1 << log_unit[1]);\n\n        // printf(\"gemm shape: %d %d %d %d\\n\", gemm_shape_[0], gemm_shape_[1], 
gemm_shape_[2], gemm_shape_[3]);\n        // printf(\"tile shape: %d %d %d\\n\", tile_shape[0], tile_shape[1], tile_shape[2]);\n\n        if constexpr (group_axis >= 0) {\n            constexpr int i = group_axis;\n            // overwrite dynamic axis <- estimated upper bound\n            tiles_[i] = ((gemm_shape_[i] / tile_shape[i] >> log_unit[i]) + gemm_shape_[3]) << log_unit[i];\n        }\n\n        int chunks    = cdiv(gemm_shape[2], chunk_k);\n        split_chunks_ = chunks / splits;\n        chunk_offset_ = splits - chunks % splits;\n    }\n\n    __device__ int2 get_group_offset(int g)\n    {\n        constexpr int i = group_axis;\n\n        Array<int, 2> log_unit{};\n        log_unit[1 - (int)order] = log_tile_;\n\n        int offset      = __ldg(offsets_ + g);\n        int tile_offset = ((offset / tile_shape[i] >> log_unit[i]) + g) << log_unit[i];\n\n        return {offset, tile_offset};\n    }\n\n    __device__ int find_group(Array<int, 3>& tile_id, SharedStorage& storage)\n    {\n        constexpr int axis = group_axis;\n\n        int success = 0;\n\n        const int block_dim = blockDim.x;\n\n        for (int g = threadIdx.x; g < gemm_shape_[3]; g += block_dim) {\n            auto [beg, beg_tile] = get_group_offset(g);\n            auto [end, end_tile] = get_group_offset(g + 1);\n\n            if (beg_tile <= tile_id[axis] && tile_id[axis] < end_tile) {\n                storage.group_id     = g;\n                storage.dynamic_dim  = end - beg;\n                storage.base_tile_id = beg_tile;\n                success              = 1;\n            }\n\n            if (tile_id[axis] < end_tile) {\n                break;\n            }\n        }\n\n        return __syncthreads_or(success);\n    }\n\n    template<class Reinit>\n    __device__ int init(Tile& tile, SharedStorage& storage, Reinit)\n    {\n        Array<int, 3> cta_id{(int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z};\n        Array<int, 3> tile_id = unswizzle(cta_id);\n        
Array<int, 3> shape{gemm_shape_[0], gemm_shape_[1], gemm_shape_[2]};\n\n        tile.group_id       = 0;\n        tile.linear_tile_id = tile_id[1 - (int)order] * tiles_[(int)order] + tile_id[(int)order];\n\n        constexpr int axis = group_axis;\n\n        if constexpr (axis >= 0) {\n            if (offsets_) {\n                if constexpr (!Reinit::value) {\n                    if (!find_group(tile_id, storage)) {\n                        return false;\n                    }\n                }\n                tile_id[axis] -= storage.base_tile_id;\n                shape[axis]   = storage.dynamic_dim;\n                tile.group_id = storage.group_id;\n                // Crucial for the values above to be recognized as warp uniform, `__syncwarp()`\n                // does not prevent modifying CTA scope SMEM from other warps\n                __syncthreads();\n            }\n        }\n\n        if constexpr (split_k) {\n            int split_id    = tile_id[2];\n            int chunk_id    = split_id * split_chunks_ + max(split_id - chunk_offset_, 0);\n            tile.k_iters[0] = chunk_id * chunk_iters;\n            tile.k_iters[1] = (split_chunks_ + int(split_id >= chunk_offset_)) * chunk_iters;\n        }\n        else {\n            tile.k_iters[0] = 0;\n            tile.k_iters[1] = split_chunks_ * chunk_iters;\n        }\n\n        tile.tile_id = tile_id;\n        tile.shape   = shape;\n\n        return true;\n    }\n\n    __device__ Array<int, 3> unswizzle(Array<int, 3> cta_id)\n    {\n        int tile_c = cta_id[0] >> log_tile_;\n        int tile_s = cta_id[1] << log_tile_ | (cta_id[0] & ((1 << log_tile_) - 1));\n\n        Array<int, 3> tile_id;\n\n        tile_id[(int)order]     = tile_c;\n        tile_id[1 - (int)order] = tile_s;\n\n        tile_id[2] = cta_id[2];\n\n        return tile_id;\n    }\n\n    __host__ __device__ static Array<int, 3> get_swizzled_shape(Array<int, 3> tiles, int log_tile)\n    {\n        constexpr int i = (int)order;  // 
expansion axis\n        return {tiles[i] << log_tile, (tiles[1 - i] + (1 << log_tile) - 1) >> log_tile, tiles[2]};\n    }\n\n    __host__ int get_max_swizzle()\n    {\n        constexpr int axis = 1 - (int)order;\n\n        int n = tiles_[axis];\n\n        if (group_axis == axis) {\n            n = cdiv(n, gemm_shape_[3]);\n        }\n\n        return get_log_tile(n);\n    }\n\n    __host__ __device__ static int get_log_tile(int size)\n    {\n        if (size >= 24)\n            return 5;\n        if (size >= 12)\n            return 4;\n        if (size >= 6)\n            return 3;\n        if (size >= 3)\n            return 2;\n        if (size >= 2)\n            return 1;\n        return 0;\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/simt.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind::gemm::simt {\n\n// constexpr int OP_M = 2;\n// constexpr int OP_N = 16;\n// constexpr int OP_K = 4;\n\n// constexpr int OP_M = 4;\n// constexpr int OP_N = 8;\n// constexpr int OP_K = 8;\n\nconstexpr int OP_M = 1;\nconstexpr int OP_N = 32;\nconstexpr int OP_K = 8;\n\n}  // namespace turbomind::gemm::simt\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/sm90_utils.h",
    "content": "\n\n#pragma once\n\n#include \"cute/arch/mma_sm90_gmma.hpp\"\n#include \"cute/atom/mma_traits.hpp\"\n#include \"cute/atom/mma_traits_sm90_gmma.hpp\"\n\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\nnamespace GMMA = cute::SM90::GMMA;\n\ninline __device__ cute::GmmaDescriptor make_smem_desc(void* smem_ptr, int layout_type)\n{\n    auto uint_ptr = cast_smem_ptr_to_uint(smem_ptr);\n\n    cute::GmmaDescriptor desc{};\n    desc.bitfield.start_address_       = uint_ptr >> 4;\n    desc.bitfield.layout_type_         = layout_type;\n    desc.bitfield.leading_byte_offset_ = 0;\n    desc.bitfield.stride_byte_offset_  = 1024 >> 4;\n    desc.bitfield.base_offset_         = 0;\n\n    return desc;\n}\n\ntemplate<int Stages, int Step>\nstruct SmemDescIterV2 {\n    union {\n        uint32_t u32_[2];\n        uint64_t u64_;\n    };\n\n    uint32_t base_;\n\n    __device__ SmemDescIterV2(uint64_t desc): u64_{desc}, base_{u32_[0]} {}\n\n    __device__ void Advance(int stage)\n    {\n        u32_[0] += Step;\n        if (stage == Stages - 1) {\n            u32_[0] = base_;\n        }\n    }\n\n    __device__ void Reset(int stage)\n    {\n        u32_[0] = base_ + stage * Step;\n    }\n\n    __device__ SmemDescIterV2& operator+=(int offset)\n    {\n        u32_[0] += offset;\n        return *this;\n    }\n\n    __device__ SmemDescIterV2& operator-=(int offset)\n    {\n        u32_[0] -= offset;\n        return *this;\n    }\n\n    __device__ operator uint64_t()\n    {\n        return u64_;\n    }\n};\n\ntemplate<class MMA_Atom, size_t... Is>\ninline __device__ void\nwgmma_impl(uint64_t desc_a, uint64_t desc_b, float* frag_C, bool clear, std::index_sequence<Is...>)\n{\n    return MMA_Atom::fma(desc_a, desc_b, frag_C[Is]..., clear ? 
GMMA::ScaleOut::Zero : GMMA::ScaleOut::One);\n}\n\ntemplate<class MMA_Atom, int N>\ninline __device__ void wgmma(uint64_t desc_a, uint64_t desc_b, float (&frag_C)[N], bool clear)\n{\n    return wgmma_impl<MMA_Atom>(desc_a, desc_b, frag_C, clear, std::make_index_sequence<N>{});\n}\n\ninline __device__ void warpgroup_fence_operand(float& reg)\n{\n    asm volatile(\"\" : \"+f\"(reg)::\"memory\");\n}\n\ntemplate<int M, int N, int K>\ninline __device__ void warpgroup_fence_operand(float (&x)[M][N][K])\n{\n    PRAGMA_UNROLL\n    for (int m = 0; m < M; ++m) {\n        PRAGMA_UNROLL\n        for (int n = 0; n < N; ++n) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < K; ++k) {\n                warpgroup_fence_operand(x[m][n][k]);\n            }\n        }\n    }\n}\n\ntemplate<int N, int K>\ninline __device__ void warpgroup_fence_operand(float (&x)[N][K])\n{\n    PRAGMA_UNROLL\n    for (int n = 0; n < N; ++n) {\n        PRAGMA_UNROLL\n        for (int k = 0; k < K; ++k) {\n            warpgroup_fence_operand(x[n][k]);\n        }\n    }\n}\n\ntemplate<class Func, size_t... Is>\n__device__ void for_(std::index_sequence<Is...>, Func func)\n{\n    return (func(constant<Is>{}), ...);\n}\n\nnamespace arch {\n\ntemplate<int M_, int N_, Order order>\nstruct Cluster {\n    static constexpr int M = M_;\n    static constexpr int N = N_;\n\n    static constexpr int C = mk2cs<order>(M, N).x;\n    static constexpr int S = mk2cs<order>(M, N).y;\n\n    static constexpr int size = M * N;\n\n    static constexpr uint16_t kMaskC = (1 << C) - 1;\n    static constexpr uint16_t kMaskS = ((1 << size) - 1) / kMaskC;\n\n    __device__ static ushort2 mask_cs(int cta_id)\n    {\n        const auto [c, s] = cta_cs(cta_id);\n        return make_ushort2(kMaskS << c, kMaskC << s * C);\n    }\n\n    __device__ static ushort2 mask_mn(int cta_id)\n    {\n        auto [c, s] = mask_cs(cta_id);\n        return order == kColMajor ? 
ushort2{c, s} : ushort2{s, c};\n    }\n\n    __device__ static int2 cta_cs(int cta_id)\n    {\n        return {C > 1 ? cta_id % C : 0, S > 1 ? cta_id / C : 0};\n    }\n\n    __device__ static int2 cta_mn(int cta_id)\n    {\n        return cs2mk<order>(cta_cs(cta_id));\n    }\n\n    int2    cta_mn_;\n    ushort2 mask_mn_;\n\n    __device__ explicit Cluster(int cta_id): cta_mn_(cta_mn(cta_id)), mask_mn_(mask_mn(cta_id)) {}\n\n    __device__ int cta_m()\n    {\n        return cta_mn_.x;\n    }\n\n    __device__ int cta_n()\n    {\n        return cta_mn_.y;\n    }\n\n    __device__ uint16_t mask_m()\n    {\n        return mask_mn_.x;\n    }\n\n    __device__ uint16_t mask_n()\n    {\n        return mask_mn_.y;\n    }\n};\n\n}  // namespace arch\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/smem_copy.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\nstruct VoidSmemCopyAtom {\n\n    static constexpr int M = 1;\n    static constexpr int K = 1;\n\n    static constexpr int kFragNum = 1;\n\n    using Frag = Array<int, 1>;\n\n    template<class S, class D>\n    __device__ static void copy(S, D, bool)\n    {\n    }\n\n    __device__ static int2 get_offset(int)\n    {\n        return {};\n    }\n\n    __device__ static int2 unique(int thread_idx, int pack_idx)\n    {\n        return {};\n    }\n};\n\ntemplate<class T, class Layout, Order order>\nstruct SmemAccessorV2 {\n};\n\ntemplate<class T, class Layout>\nstruct SmemAccessorV2<T, Layout, kRowMajor>: SmemAccessor<T, Layout> {\n    using SmemAccessor<T, Layout>::SmemAccessor;\n};\n\ntemplate<class T, class Layout>\nstruct SmemAccessorV2<T, Layout, kColMajor> {\n    SmemAccessor<T, Layout> base_;\n\n    __device__ SmemAccessorV2(get_pointer_type<T> ptr): base_{ptr} {}\n    __device__ T& operator()(int m, int k)\n    {\n        return base_(k, m);\n    }\n};\n\ntemplate<class T, Order order, int M_, int K_, int FragSize, int FragNum_, int RepeatC = 1>\nstruct SmemCopyAtom_Pack_v2 {\n    static constexpr int M = M_;\n    static constexpr int K = K_;\n\n    static constexpr int kFragNum = FragNum_;\n\n    using Frag = Array<T, FragSize * kFragNum>;\n\n    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)\n    {\n        const int lane_id = thread_idx % WARP_SIZE;\n\n        const int c = lane_id / RepeatC * Frag::size();\n\n        return order == 
kRowMajor ? int2{0, c} : int2{c, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S src_ptr, D dst_ptr, bool mask)\n    {\n        auto dst_raw_ptr = (T*)dst_ptr;  // SubBytePtr<T> -> T*\n        if (mask) {\n            Lds(*(Frag*)dst_raw_ptr, src_ptr);\n        }\n    }\n};\n\ntemplate<class T, class CopyAtom, Order order, int FragNum_>\nstruct SmemCopyAtom_Pack_v3 {\n    static constexpr int M = CopyAtom::M * FragNum_;\n    static constexpr int K = CopyAtom::K;\n\n    static constexpr int kFragNum = FragNum_;\n\n    using Frag = Array<T, CopyAtom::Frag::size() * kFragNum>;\n\n    __device__ static int2 get_offset(int thread_idx)  // -> (m, k)\n    {\n        const int c = CopyAtom::unique(thread_idx, 0).x * Frag::size();\n\n        return order == kRowMajor ? int2{0, c} : int2{c, 0};\n    }\n\n    template<class S, class D>\n    __device__ static void copy(S src_ptr, D dst_ptr, bool mask)\n    {\n        if (mask) {\n            auto dst_raw_ptr = (T*)dst_ptr;  // SubBytePtr<T> -> T*\n            Lds(*(Frag*)dst_raw_ptr, src_ptr);\n        }\n    }\n};\n\ntemplate<class Operand, int iM, int iK, int dM, int dK>\nstruct SmemCopy {\n    using Atom = typename Operand::SmemCopyAtom;\n\n    static constexpr int kFragNum = Atom::kFragNum;\n\n    static constexpr int ITER_M = iM / Atom::kFragNum;\n\n    static_assert(ITER_M > 0);\n\n    using Frag = typename Atom::Frag[ITER_M];\n\n    using Pack = Packing_v2<Operand::kPack, Operand::kOrder>;\n\n    static constexpr int2 delta = Pack::apply(int2{dM * kFragNum, dK});\n\n    using Layout = typename Operand::SmemLayout;\n\n    static constexpr int2 kMK0 = cs2mk<Operand::kOrder>(Layout::C0, Layout::S0);\n\n    static constexpr int kPeriodM = ceil_div(kMK0.x, delta.x);\n    static constexpr int kPeriodK = ceil_div(kMK0.y, delta.y);\n\n    const int2 offset_;\n\n    int phases_[kPeriodK][kPeriodM];\n\n    __device__ SmemCopy(int2 offset): offset_{offset}\n    {\n        const int2 thr = 
Atom::get_offset(threadIdx.x);\n        PRAGMA_UNROLL\n        for (int k = 0; k < kPeriodK; ++k) {\n            PRAGMA_UNROLL\n            for (int m = 0; m < kPeriodM; ++m) {\n                const int2 pack = Pack::apply({offset.x + m * dM * kFragNum, offset.y + k * dK});\n                const int2 cs   = mk2cs<Operand::kOrder>({pack.x + thr.x, pack.y + thr.y});\n                phases_[k][m]   = Layout::apply(cs.y, cs.x);\n            }\n        }\n    }\n\n    template<class Pointer>\n    __device__ void operator()(Pointer src_ptr, Frag& dst, int k, bool mask = true)\n    {\n        using Accessor = typename Operand::SmemAccessor;\n        if constexpr (Operand::kGroupSize == 1) {\n            PRAGMA_UNROLL\n            for (int m = 0; m < ITER_M; ++m) {\n                const int  mm = m / kPeriodM * kPeriodM * dM * kFragNum;\n                const int  kk = k / kPeriodK * kPeriodK * dK;\n                const int2 cs = mk2cs<Operand::kOrder>(Pack::apply(int2{mm, kk}));\n                const int  i0 = Layout::apply(cs.y, cs.x);\n                const int  i1 = phases_[k % kPeriodK][m % kPeriodM];\n                Atom::copy(&src_ptr[i0 + i1], dst[m].data(), mask);\n            }\n        }\n        else {  // generic case\n            Accessor   smem{src_ptr};\n            const int2 thr = Atom::get_offset(threadIdx.x);\n            PRAGMA_UNROLL\n            for (int m = 0; m < ITER_M; ++m) {\n                const int  mm = offset_.x + m * dM * kFragNum;\n                const int  kk = offset_.y + k * dK;  // Note: this forbids sub-tile group sizes\n                const int2 mk = Pack::apply(int2{mm, kk / Operand::kGroupSize});\n                Atom::copy(&smem(mk.x + thr.x, mk.y + thr.y), dst[m].data(), mask);\n            }\n        }\n        // else if constexpr (Operand::kPack != 0 && Operand::kGroupSize != 1) {  // group size = 1, pack != 0\n        //     const int  mask_k = Operand::kGroupSize == 1;\n        //     const int2 pack   = 
Pack::apply(int2{offset_.x, offset_.y});\n        //     const int2 thr    = Atom::get_offset(threadIdx.x);\n        //     const int2 cs     = mk2cs<Operand::kOrder>({pack.x + thr.x, (pack.y + thr.y) * mask_k});\n        //     auto       smem   = src_ptr + Layout::apply(cs.y, cs.x);\n        //     PRAGMA_UNROLL\n        //     for (int m = 0; m < ITER_M; ++m) {\n        //         const int  mm  = m * dM * kFragNum;\n        //         const int  kk  = k * dK;\n        //         const int2 cs  = mk2cs<Operand::kOrder>(Pack::apply(int2{mm, kk * mask_k}));\n        //         const int  idx = Layout::apply(cs.y, cs.x);\n        //         Atom::copy(&smem[idx], dst[m].data(), mask);\n        //     }\n        // }\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/gemm_bench.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"nvbench/main.cuh\"\n#include \"src/turbomind/kernels/gemm/operand.h\"\n#include \"src/turbomind/kernels/gemm/test/models.h\"\n#include \"src/turbomind/kernels/gemm/test/testbed.h\"\n#include <cuda_runtime_api.h>\n#include <map>\n#include <nvbench/nvbench.cuh>\n#include <string>\n\nvoid gemm_bench(nvbench::state& state)\n{\n    const auto idx = state.get_int64(\"idx\");\n\n    const auto bs = state.get_int64(\"bs\");\n    const auto tp = state.get_int64(\"tp\");\n\n    const auto expert_num  = state.get_int64(\"e_num\");\n    const auto exp_per_tok = state.get_int64(\"e_tok\");\n\n    auto [output_dims, input_dims] = config[idx];\n\n    constexpr int group_size = 128;\n\n    if (idx % 4 == 0 || idx % 4 == 2) {\n        if (output_dims % tp)\n            return;\n        output_dims /= tp;\n    }\n    else {\n        if (input_dims % tp)\n            return;\n        input_dims /= tp;\n    }\n\n    if (input_dims % group_size)\n        return;\n\n    using turbomind::gemm::get_test;\n\n    {\n        int m = bs;\n        int n = output_dims;\n        int k = input_dims;\n        if (get_test().kBatchDim == 1) {\n            std::swap(m, n);\n        }\n        std::cerr << \"m\" << m << \"n\" << n << \"k\" << k << \"\\n\";\n\n        get_test().Initialize(m, n, k, group_size, expert_num, exp_per_tok, state.get_cuda_stream());\n    }\n\n    state.add_element_count(get_test().get_element_count());\n\n    // state.collect_dram_throughput();\n    // state.collect_l2_hit_rates();\n\n    if constexpr (1) {\n        state.add_global_memory_reads(get_test().get_global_memory_reads());\n        get_test().Run();\n        state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {  //\n            get_test().Run();\n        });\n    }\n    else {\n        state.add_global_memory_reads(get_test().get_ref_global_memory_reads());\n        state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {  
//\n            get_test().RunCublas();\n        });\n    }\n\n    get_test().ctx_.reset();\n}\n\nNVBENCH_BENCH(gemm_bench)\n    .add_int64_axis(\"idx\", nvbench::range(0, (int)config.size() - 1))\n    .add_int64_power_of_two_axis(\"bs\", nvbench::range(0, 14))\n    .add_int64_axis(\"tp\", {1, 2, 4})\n    .add_int64_axis(\"e_num\", {0})\n    .add_int64_axis(\"e_tok\", {1});\n\nint main(int argc, char* argv[])\n{\n    NVBENCH_MAIN_BODY(argc, argv);\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/models.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cstdint>\n#include <map>\n#include <utility>\n#include <vector>\n\nstatic const std::vector<std::pair<int64_t, int64_t>> config{\n    {11008 * 2, 4096}, {4096, 11008}, {12288, 4096}, {4096, 4096},  // llama2-7b\n    {14336 * 2, 4096}, {4096, 14336}, {6144, 4096},  {4096, 4096},  // llama3-8b / internlm2.5-7b\n    {16384 * 2, 6144}, {6144, 16384}, {8192, 6144},  {6144, 6144},  // internlm2-20b\n    {13696 * 2, 4096}, {4096, 13696}, {4608, 4096},  {4096, 4096},  // glm4-9b\n    {18944 * 2, 3584}, {3584, 18944}, {4608, 3584},  {3584, 3584},  // qwen2-7b\n    {20480 * 2, 7168}, {7168, 20480}, {9216, 7168},  {7168, 7168},  // yi-34b\n    {28672 * 2, 8192}, {8192, 28672}, {10240, 8192}, {8192, 8192},  // llama2-70b / llama3-70b\n    {29696 * 2, 8192}, {8192, 29696}, {10240, 8192}, {8192, 8192},  // qwen2-72b-instruct-awq\n    {14336 * 2, 4096}, {4096, 14336}, {6144, 4096},  {4096, 4096},  // mixtral-8x7b, E8e2\n    {16384 * 2, 6144}, {6144, 16384}, {0, 0},        {0, 0},        // mixtral-8x22b, E8e2\n    {1536 * 2, 5120},  {5120, 1536},  {0, 0},        {0, 0},        // deepseek-v2, E160e6\n    {1536 * 2, 2048},  {2048, 1536},  {0, 0},        {0, 0},        // deepseek-v2-lite, E64e6\n    {2560 * 2, 3840},  {3840, 2560},  {0, 0},        {0, 0},        // qwen2-a14b, E64e8\n    {6400 * 2, 4096},  {4096, 6400},  {0, 0},        {0, 0},        // phi-3.5-MoE, E16e2\n};\n\n// static const std::map<int, std::pair<int, int>> moe_config{{32, {8, 2}}, {33, {8, 2}}};\n\n// {29568 * 2, 8192}, {8192, 29568}, {10240, 8192}, {8192, 8192},  // qwen2-72b\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/quantization.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/test/quantization_impl.h\"\n\nnamespace turbomind::gemm {\n\ntemplate void Quantize<uint4_t>(const thrust::universal_vector<half>& x,\n                                int                                   m,\n                                int                                   k,\n                                Order                                 order,\n                                int                                   group_size,\n                                thrust::universal_vector<half>&       x_p,  // pseudo-quantized\n                                thrust::universal_vector<uint16_t>&   x_q,  // quantized ushort\n                                thrust::universal_vector<half>&       x_u,  // scales & zeros (always m-major)\n                                cudaStream_t                          stream);\n\ntemplate void Quantize<uint8_t>(const thrust::universal_vector<half>& x,\n                                int                                   m,\n                                int                                   k,\n                                Order                                 order,\n                                int                                   group_size,\n                                thrust::universal_vector<half>&       x_p,  // pseudo-quantized\n                                thrust::universal_vector<uint16_t>&   x_q,  // quantized ushort\n                                thrust::universal_vector<half>&       x_u,  // scales & zeros (always m-major)\n                                cudaStream_t                          stream);\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/quantization.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include <cstdint>\n#include <thrust/device_vector.h>\n#include <thrust/universal_vector.h>\n\nnamespace turbomind::gemm {\n\n// Group-wise asymmetric quantization of the (m, k) matrix `x` to `D`, with one (scale, zero) pair per `group_size`\n// elements along k. Outputs the dequantized copy `x_p`, the raw codes `x_q` and the interleaved pairs `x_u`.\ntemplate<class D, class S>\nvoid Quantize(const thrust::universal_vector<S>&  x,\n              int                                 m,\n              int                                 k,\n              Order                               order,\n              int                                 group_size,\n              thrust::universal_vector<S>&        x_p,  // pseudo-quantized\n              thrust::universal_vector<uint16_t>& x_q,  // quantized ushort\n              thrust::universal_vector<S>&        x_u,  // scales & zeros (always m-major)\n              cudaStream_t                        stream);\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/quantization_impl.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/attention/quantization.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/gemm/test/test_utils.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\n#include <limits>\n#include <thrust/execution_policy.h>\n#include <thrust/universal_vector.h>\n\nnamespace turbomind::gemm {\n\n// find the per-group min/max of each row; `find_params` turns them into `scale` and `zeros`\ntemplate<class T>\n__global__ void find_stats(Array<T, 2>* minmax, const T* src, int N, int K, int G)\n{\n    int n_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int k_idx = blockIdx.y;\n\n    if (n_idx >= N || k_idx * G >= K) {\n        return;\n    }\n\n    float minval = std::numeric_limits<float>::infinity();\n    float maxval = -minval;\n\n    const int L = min(K, G);\n\n    for (int k = 0; k < L; k += 8) {\n        Array<T, 8> vec;\n        Load(vec, &src[n_idx * K + k_idx * G + k]);\n        PRAGMA_UNROLL\n        for (int i = 0; i < vec.size(); ++i) {\n            minval = __hmin(minval, vec[i]);\n            maxval = __hmax(maxval, vec[i]);\n        }\n    }\n\n    // store in n-major\n    Store(minmax[k_idx * N + n_idx].data(), Array<T, 2>{minval, maxval});\n}\n\ntemplate<class Q, bool asym, class T>\n__global__ void find_params(T* param, const Array<T, 2>* minmax, int count)\n{\n    int global_idx = threadIdx.x + blockIdx.x * blockDim.x;\n    if (global_idx >= count) {\n        return;\n    }\n    auto        stats     = minmax[global_idx];\n    const float inv_q_max = fdividef(1.f, (1 << bitsof<Q>)-1);\n\n    static_assert(asym);\n\n    float scale = (T)(((float)stats[1] - (float)stats[0]) * inv_q_max);\n\n    // force trivial scale / zero for debugging\n    if constexpr (0) {\n        stats[0] = 0;\n        scale    = 1.f;\n    }\n\n    Store(param + global_idx * 2, Array<T, 2>{scale, 
stats[0]});\n}\n\ntemplate<class Q, class T>\n__global__ void quantize(uint16_t* dst, T* pseudo, const T* src, const T* stats, int N, int K, int G)\n{\n    static_assert(bitsof<Q> <= 16);\n    static_assert(bitsof<T> == 16);  // fp16 & bf16\n\n    int n_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int k_idx = blockIdx.y;\n\n    if (n_idx >= N || k_idx * G >= K) {\n        return;\n    }\n\n    Array<T, 2> param;\n    Load(param, stats + (k_idx * N + n_idx) * 2);\n\n    float inv_scale = fdividef(1.f, param[0]);\n\n    const int L = min(K, G);\n\n    for (int k = 0; k < L; k += 8) {\n        Array<T, 8>        vi;\n        Array<uint16_t, 8> vo;\n        Load(vi, &src[n_idx * K + k_idx * G + k]);\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < 8; ++i) {\n            float u = (static_cast<float>(vi[i] - param[1])) * inv_scale;\n            vo[i]   = quant<uint16_t>(u, bitsof<Q>);\n        }\n        Store(&dst[n_idx * K + k_idx * G + k], vo);\n\n        if (pseudo) {\n            Array<T, 8> vf;\n            PRAGMA_UNROLL\n            for (int i = 0; i < 8; ++i) {\n                vf[i] = __hfma(static_cast<T>(vo[i]), param[0], param[1]);\n            }\n            Store(&pseudo[n_idx * K + k_idx * G + k], vf);\n        }\n    }\n}\n\ntemplate<class T>\n__global__ void transpose(const T* src, T* dst, int s, int c)\n{\n    const int cid = threadIdx.x + blockIdx.x * blockDim.x;\n    const int sid = threadIdx.y + blockIdx.y * blockDim.y;\n    if (sid < s && cid < c) {\n        dst[cid * s + sid] = src[sid * c + cid];\n    }\n}\n\ntemplate<class T>\nvoid invokeTranspose(const T* src, T* dst, int s, int c, cudaStream_t stream)\n{\n    const dim3 block{32, 16};\n    const dim3 grid(ceil_div<int>(c, block.x), ceil_div<int>(s, block.y));\n\n    transpose<<<grid, block, 0, stream>>>(src, dst, s, c);\n}\n\ntemplate<class D, class S>\nvoid Quantize(const thrust::universal_vector<S>&  x,\n              int                                 m,\n              int        
                         k,\n              Order                               order,\n              int                                 group_size,\n              thrust::universal_vector<S>&        x_p,  // pseudo-quantized\n              thrust::universal_vector<uint16_t>& x_q,  // quantized ushort\n              thrust::universal_vector<S>&        x_u,  // scales & zeros (always m-major)\n              cudaStream_t                        stream)\n\n{\n    auto policy = thrust::device.on(stream);\n\n    thrust::universal_vector<S>           _x(x.size());\n    thrust::universal_vector<S>           _x_p(x.size());\n    thrust::universal_vector<uint16_t>    _x_q(x.size());\n    thrust::universal_vector<Array<S, 2>> stats(ceil_div(k, group_size) * m);\n\n    x_p.resize(x.size());\n    x_q.resize(x.size());\n    /// FIXME: correct the size\n    x_u.resize(stats.size() * 2);\n\n    if (order == Order::kRowMajor) {\n        thrust::copy(policy, x.begin(), x.end(), _x.begin());\n    }\n    else {\n        invokeTranspose(x.data().get(), _x.data().get(), k, m, stream);\n    }\n\n    const int  block = std::min(256, m);\n    const dim3 grid(ceil_div(m, block), ceil_div(k, group_size));\n\n    find_stats<<<grid, block, 0, stream>>>(stats.data().get(),  //\n                                           _x.data().get(),\n                                           m,\n                                           k,\n                                           group_size);\n\n    find_params<D, true><<<ceil_div<int>(stats.size(), 256), 256, 0, stream>>>(  //\n        x_u.data().get(),\n        stats.data().get(),\n        stats.size());\n\n    quantize<D><<<grid, block, 0, stream>>>(_x_q.data().get(),  //\n                                            _x_p.data().get(),\n                                            _x.data().get(),\n                                            x_u.data().get(),\n                                            m,\n                                            
k,\n                                            group_size);\n\n    if (order == Order::kRowMajor) {\n        thrust::copy(policy, _x_p.begin(), _x_p.end(), x_p.begin());\n        thrust::copy(policy, _x_q.begin(), _x_q.end(), x_q.begin());\n    }\n    else {\n        invokeTranspose(_x_p.data().get(), x_p.data().get(), m, k, stream);\n        invokeTranspose(_x_q.data().get(), x_q.data().get(), m, k, stream);\n    }\n\n    cudaStreamSynchronize(stream);\n\n    // Compare(_x_p.data().get(), _x.data().get(), k, k, m);\n\n    const int kg = ceil_div(k, group_size);\n    for (int i = 0; i < m * kg; ++i) {\n        // int mi = i % m;\n        // int ki = i / m;\n\n        // x_u[i * 2]     = i;\n        // x_u[i * 2 + 1] = i;\n\n        // x_u[i * 2]     = i * 2;\n        // x_u[i * 2 + 1] = i * 2 + 1;\n    }\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/reference.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/test/reference.h\"\n#include <cstdio>\n\nnamespace turbomind::gemm {\n\n#define CHECK(cond)                                                                                                    \\\n    do {                                                                                                               \\\n        if (!(cond)) {                                                                                                 \\\n            fprintf(stderr, \"*** Check failed: (%s) @ %s:%d\\n\", #cond, __FILE__, __LINE__);                            \\\n            std::abort();                                                                                              \\\n        }                                                                                                              \\\n    } while (0)\n\nnamespace {\n\nMatrixLayout transpose(MatrixLayout x)\n{\n    std::swap(x.rows, x.cols);\n    x.order = x.order == Order::kColMajor ? 
Order::kRowMajor : Order::kColMajor;\n    return x;\n}\n\ncudaDataType to_cuda_dtype(DataType dtype)\n{\n    switch (dtype) {\n        case DataType::kFloat16:\n            return CUDA_R_16F;\n        case DataType::kBfloat16:\n            return CUDA_R_16BF;\n        default:\n            CHECK(\"unsupported data type\" && 0);\n    }\n    return {};\n}\n\n}  // namespace\n\nReference::Reference()\n{\n    cublasCreate(&handle_);\n\n    cublasSetWorkspace(handle_, nullptr, 0);\n    cublasSetMathMode(handle_, CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);\n}\n\nReference::~Reference()\n{\n    if (handle_) {\n        cublasDestroy(handle_);\n        handle_ = {};\n    }\n}\n\nvoid Reference::set_stream(cudaStream_t stream)\n{\n    cublasSetStream(handle_, stream);\n}\n\nvoid Reference::gemm(const void* A, MatrixLayout Adesc, const void* B, MatrixLayout Bdesc, void* C, MatrixLayout Cdesc)\n{\n\n    // Transpose the problem for C to be column major\n    if (Cdesc.order == Order::kRowMajor) {\n        std::swap(A, B);\n        std::swap(Adesc, Bdesc);\n        Adesc = transpose(Adesc);\n        Bdesc = transpose(Bdesc);\n        Cdesc = transpose(Cdesc);\n        // (n, k) (k, m)\n    }\n\n    TM_CHECK_EQ(Adesc.cols, Bdesc.rows);\n\n    // (m, k) (k, n)\n    int m = Cdesc.rows;\n    int n = Cdesc.cols;\n    int k = Adesc.cols;\n\n    TM_CHECK_EQ(Adesc.rows, m);\n    TM_CHECK_EQ(Bdesc.cols, n);\n    TM_CHECK_EQ(Bdesc.rows, k);\n\n    float alpha = 1.f;\n    float beta  = 0.f;\n\n    auto to_cublas_op = [](Order o) { return o == Order::kColMajor ? 
CUBLAS_OP_N : CUBLAS_OP_T; };\n\n    auto status = cublasGemmEx(handle_,\n                               to_cublas_op(Adesc.order),\n                               to_cublas_op(Bdesc.order),\n                               m,\n                               n,\n                               k,\n                               &alpha,\n                               A,\n                               to_cuda_dtype(Adesc.type),\n                               Adesc.ld,\n                               B,\n                               to_cuda_dtype(Bdesc.type),\n                               Bdesc.ld,\n                               &beta,\n                               C,\n                               to_cuda_dtype(Cdesc.type),\n                               Cdesc.ld,\n                               CUBLAS_COMPUTE_32F,\n                               CUBLAS_GEMM_DEFAULT_TENSOR_OP);\n\n    TM_CHECK_EQ(status, CUBLAS_STATUS_SUCCESS);\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/reference.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/types.h\"\n\n#include <cublas_v2.h>\n\nnamespace turbomind::gemm {\n\nclass Reference {\npublic:\n    Reference();\n    ~Reference();\n\n    void set_stream(cudaStream_t stream);\n\n    void gemm(const void* A, MatrixLayout Adesc, const void* B, MatrixLayout Bdesc, void* C, MatrixLayout Cdesc);\n\nprivate:\n    cublasHandle_t handle_{};  // null until cublasCreate succeeds; checked in the destructor\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/test_gemm_v2.cc",
    "content": "\n\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"testbed_v3.h\"\n\nusing namespace turbomind;\n\nstruct TestParameter: Testbed_v3::Parameter {\n    TestParameter(DataType dtype, DataType wtype, DataType itype, int group_size = 128): Testbed_v3::Parameter{}\n    {\n        data_type   = dtype;\n        weight_type = wtype;\n        input_type  = itype;\n\n        this->group_size = group_size;\n    }\n};\n\nint main()\n{\n    auto stream = core::Stream::create();\n\n    core::ContextGuard ctx{stream, core::Allocator{kCPU}, core::Allocator{stream, false}};\n    // clang-format off\n    // TestParameter p{kHalf, kUint4      , kHalf, 128};\n    // TestParameter p{kHalf, kFloat4_e2m1, kHalf,  32};\n    // TestParameter p{kHalf, kFloat8_e4m3, kHalf, 128};\n    // TestParameter p{kHalf, kHalf       , kHalf};\n\n    // TestParameter p{kBfloat16, kBfloat16   , kBfloat16};\n    // TestParameter p{kBfloat16, kFloat8_e4m3, kFloat8_e4m3, 128};\n    TestParameter p{kBfloat16, kFloat8_e4m3, kBfloat16   , 128};\n    // TestParameter p{kBfloat16, kFloat4_e2m1, kBfloat16   ,  32};\n    // clang-format on\n\n    // p.input_dim      = 512;\n    // p.output_dim     = 1024;\n    // p.max_batch_size = 256;\n\n    // p.input_dim      = 1024;\n    // p.output_dim     = 1024;\n    // p.max_batch_size = 1024;\n\n    // p.input_dim      = 12288;\n    // p.output_dim     = 16384;\n    // p.max_batch_size = 8192;\n\n    // p.expert_num        = 1;\n    // p.experts_per_token = 1;\n\n    // p.input_dim      = 2880;\n    // p.output_dim     = 2880;\n    // p.max_batch_size = 64;\n\n    // p.input_dim         = 7168;\n    // p.output_dim        = 4096;\n    // p.max_batch_size    = 16384;\n    // p.expert_num        = 256;\n    // p.experts_per_token = 8;\n\n    // Qwen3-MoE\n    p.expert_num        = 128;\n    p.experts_per_token = 8;\n    // 30B\n    // p.input_dim  = 2048;\n    // p.output_dim = 768 * 2;\n    // 235B\n    
// p.input_dim  = 4096;\n    // p.output_dim = 1536 * 2;\n    // 480B\n    p.input_dim  = 6144;\n    p.output_dim = 2560 * 2;\n\n    p.max_batch_size = 256;\n\n    // p.input_dim         = 16384;\n    // p.output_dim        = 16384;\n    // p.max_batch_size    = 16384;\n\n    // p.input_dim         = 2880;\n    // p.output_dim        = 5760;\n    // p.max_batch_size    = 16384;\n    // p.expert_num        = 32;\n    // p.experts_per_token = 4;\n\n    // p.input_dim      = 128;\n    // p.output_dim     = 32;\n    // p.max_batch_size = 1;\n\n    Testbed_v3 test{p};\n\n    test.GetReference();\n    test.Run();\n    test.Compare();\n\n    cudaDeviceSynchronize();\n\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/test_moe_utils.cu",
    "content": "#include \"src/turbomind/kernels/gemm/moe_utils_v2.h\"\n#include \"src/turbomind/kernels/gemm/test/test_utils.h\"\n#include \"src/turbomind/kernels/gemm/tuner/cache_utils.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstdint>\n#include <iomanip>\n#include <iostream>\n#include <numeric>\n#include <sstream>\n#include <string>\n#include <thrust/host_vector.h>\n#include <thrust/universal_vector.h>\n#include <vector>\n\nusing namespace turbomind;\n\ntemplate<class T>\nvoid print_vecs(const T* data, int m, int k, std::string msg, int width = 4)\n{\n    if (!msg.empty()) {\n        std::cout << msg << \":\\n\";\n    }\n    for (int mm = 0; mm < m; ++mm) {\n        for (int kk = 0; kk < k; ++kk) {\n            std::cout << std::setw(width) << data[mm * k + kk];\n        }\n        std::cout << \"\\n\";\n    }\n}\n\ntemplate<class T>\nvoid diff_vecs(const T* data, const T* refs, int m, int k, std::string msg)\n{\n    if (!msg.empty()) {\n        std::cout << msg << \": [\" << m << \", \" << k << \"]\\n\";\n    }\n    for (int mm = 0; mm < m; ++mm) {\n        std::cout << \"m=\" << mm << \": \";\n        for (int kk = 0; kk < k; ++kk) {\n            const auto& x = data[mm * k + kk];\n            const auto& y = refs[mm * k + kk];\n            if (x != y) {\n                std::cout << kk << \"(\" << x << \", \" << y << \") \";\n            }\n        }\n        std::cout << \"\\n\";\n    }\n}\n\nRNG& gRNG()\n{\n    static RNG inst{};\n    return inst;\n}\n\nusing thrust::universal_vector;\n\nvoid moe_gate_ref(int                            tokens,\n                  int                            expert_num,\n                  int                            experts_per_token,\n                  const universal_vector<float>& logits,\n                  universal_vector<int>&         offsets,\n                  universal_vector<int>&         eids,\n                  universal_vector<int>&         f2n,\n                  universal_vector<int>&         en2f,\n                  
universal_vector<float>&       scales)\n{\n    std::vector<int> eid_range(expert_num);\n    std::iota(eid_range.begin(), eid_range.end(), 0);\n\n    for (int t = 0; t < tokens; ++t) {\n        const float* logit   = logits.data().get() + expert_num * t;\n        const float  max_val = *std::max_element(logit, logit + expert_num);\n        if constexpr (0) {\n            std::vector<float> probs(logit, logit + expert_num);\n            float              sum = 0;\n            for (auto& p : probs) {\n                p = std::exp(p - max_val);\n                sum += p;\n            }\n            for (auto& p : probs) {\n                p /= sum;\n            }\n            std::vector<int> idxs = eid_range;\n            // Had to use stable sort since there is no `std::stable_nth_element`\n            std::stable_sort(idxs.begin(), idxs.end(), [&](int i, int j) {  //\n                return probs[i] > probs[j];\n            });\n            // Recover natural order in top-k\n            std::sort(idxs.begin(), idxs.begin() + experts_per_token);\n            idxs.resize(experts_per_token);\n            sum = 0;\n            for (int e = 0; e < experts_per_token; ++e) {\n                eids[e * tokens + t] = idxs[e];\n                sum += probs[idxs[e]];\n            }\n            for (int e = 0; e < experts_per_token; ++e) {\n                scales[e * tokens + t] = probs[idxs[e]] / sum;\n            }\n        }\n        else {\n            std::vector<int> idxs = eid_range;\n            // Had to use stable sort since there is no `std::stable_nth_element`\n            std::stable_sort(idxs.begin(), idxs.end(), [&](int i, int j) {  //\n                return logit[i] > logit[j];\n            });\n            // Recover natural order in top-k\n            std::sort(idxs.begin(), idxs.begin() + experts_per_token);\n            idxs.resize(experts_per_token);\n            std::vector<float> probs(experts_per_token);\n            float              sum = 0;\n       
     for (int e = 0; e < experts_per_token; ++e) {\n                eids[e * tokens + t] = idxs[e];\n                probs[e]             = std::exp(logit[idxs[e]] - max_val);\n                sum += probs[e];\n            }\n            for (int e = 0; e < experts_per_token; ++e) {\n                scales[e * tokens + t] = probs[e] / sum;\n            }\n        }\n    }\n\n    // f2en\n    std::vector<int> f2en(eids.size());\n    std::iota(f2en.begin(), f2en.end(), 0);\n\n    std::stable_sort(f2en.begin(), f2en.end(), [&](int i, int j) {  //\n        if (eids[i] != eids[j]) {\n            return eids[i] < eids[j];\n        }\n        return i % tokens < j % tokens;\n    });\n\n    std::fill_n(offsets.begin(), offsets.size(), 0);\n    std::vector<int> accum(expert_num);\n\n    for (size_t i = 0; i < f2en.size(); ++i) {\n        f2n[i]        = f2en[i] % tokens;\n        en2f[f2en[i]] = i;\n        ++accum[eids[i]];\n    }\n\n    for (size_t i = 1; i < offsets.size(); ++i) {\n        offsets[i] = offsets[i - 1] + accum[i - 1];\n    }\n}\n\nvoid mask2eids(universal_vector<int8_t>& masks, universal_vector<int>& eids, int tokens, int expert_num)\n{\n    const int tokens_padded = masks.size() / expert_num;\n    // std::cout << eids.size() << std::endl;\n    for (int e = 0; e < expert_num; ++e) {\n        for (int t = 0; t < tokens_padded; ++t) {\n            if (auto v = masks[e * tokens_padded + t]; v >= 0) {\n                // if (v >= 2 || t >= 8193) {\n                //     std::cerr << \"FUCK \" << v << \" \" << t << std::endl;\n                // }\n                eids[v * tokens + t] = e;\n            }\n        }\n    }\n}\n\nstruct Tiling {\n    int  output_dims;\n    int  input_dims;\n    int3 cta_tile;\n};\n\nbool test_moe_gate(int                     tokens,  //\n                   int                     expert_num,\n                   int                     experts_per_token,\n                   gemm::Tape&             tape,\n                   const 
Tiling&           tiling,\n                   universal_vector<float> logits = {})\n{\n    if (logits.empty()) {\n        logits.resize(tokens * expert_num);\n        gRNG().GenerateUniform(logits.data().get(), logits.size());\n    }\n    assert(logits.size() == tokens * expert_num);\n\n    const int tokens_padded = (tokens + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize;\n    // const int max_coords    = get_max_coords(tokens, expert_num, experts_per_token, tiling);\n\n    universal_vector<int>    offsets(expert_num + 1);\n    universal_vector<int>    accum(expert_num * kMoeGateMaxTiles);\n    universal_vector<int8_t> masks(expert_num * tokens_padded);\n    universal_vector<int>    eids(experts_per_token * tokens);\n    universal_vector<int>    f2n(experts_per_token * tokens);\n    universal_vector<int>    f2E(experts_per_token * tokens);\n    universal_vector<int>    en2f(experts_per_token * tokens);\n    universal_vector<float>  scales(experts_per_token * tokens);\n    // universal_vector<int2>  coords(max_coords);\n    // thrust::fill(coords.begin(), coords.end(), int2{-1, 0});\n\n    auto offsets_ref = offsets;\n    auto eids_ref    = eids;\n    auto f2n_ref     = f2n;\n    auto en2f_ref    = en2f;\n    auto scales_ref  = scales;\n\n    moe_gate_ref(tokens, expert_num, experts_per_token, logits, offsets_ref, eids_ref, f2n_ref, en2f_ref, scales_ref);\n\n    cudaMemPrefetchAsync(f2n.data().get(), sizeof(int) * f2n.size(), 0);\n    cudaMemPrefetchAsync(f2E.data().get(), sizeof(int) * f2E.size(), 0);\n    cudaMemPrefetchAsync(en2f.data().get(), sizeof(int) * en2f.size(), 0);\n    cudaMemPrefetchAsync(offsets.data().get(), sizeof(int) * offsets.size(), 0);\n    cudaMemPrefetchAsync(scales.data().get(), sizeof(float) * scales.size(), 0);\n    cudaMemPrefetchAsync(logits.data().get(), sizeof(float) * logits.size(), 0);\n\n    bool softmax = true;\n\n    if (1) {\n        invokeMoeSoftmaxMaskTopKGroups(logits.data().get(), tokens, expert_num, expert_num / 
8, 8, nullptr);\n        softmax = false;\n    }\n\n    for (int i = 0; i < 1; ++i) {\n        gemm::CacheFlushing::flush();\n        cudaMemset(accum.data().get(), 0, sizeof(int) * accum.size());\n        cudaMemset(masks.data().get(), -1, sizeof(int8_t) * masks.size());\n        invokeMoeGate_V2(f2n.data().get(),\n                         f2E.data().get(),\n                         en2f.data().get(),\n                         offsets.data().get(),\n                         scales.data().get(),\n                         masks.data().get(),\n                         accum.data().get(),\n                         logits.data().get(),\n                         tokens,\n                         tokens_padded,\n                         expert_num,\n                         experts_per_token,\n                         softmax,\n                         false,\n                         1.f,\n                         nullptr);\n    }\n\n    // invokeMoeTiling(coords.data().get(), offsets.data().get(), expert_num, coords.size(), &tiling, 1, 0);\n\n    // gemm::scheduleGemmMoe(tape,\n    //                       offsets.data().get(),\n    //                       tokens,\n    //                       experts_per_token,\n    //                       expert_num,\n    //                       tiling.output_dims,\n    //                       tiling.input_dims,\n    //                       tiling.cta_tile,\n    //                       tiling.cta_tile.z,\n    //                       1,\n    //                       0,\n    //                       0);\n\n    if (auto err = cudaDeviceSynchronize(); err != cudaSuccess) {\n        std::cerr << cudaGetErrorString(err) << std::endl;\n        std::abort();\n    }\n\n    // print_vecs(masks.data().get(), expert_num, tokens_padded, \"masks\");\n    mask2eids(masks, eids, tokens, expert_num);\n\n    bool success = true;\n\n    // success = offsets == offsets_ref && eids == eids_ref && f2n == f2n_ref && en2f == en2f_ref;\n\n    if 
(offsets != offsets_ref) {\n        std::cerr << \"offset\\n\";\n        success = false;\n    }\n    if (eids != eids_ref) {\n        std::cerr << \"eids\\n\";\n        success = false;\n    }\n    if (f2n != f2n_ref) {\n        std::cerr << \"f2n\\n\";\n        success = false;\n    }\n    if (en2f != en2f_ref) {\n        std::cerr << \"en2f\\n\";\n        success = false;\n    }\n\n    // print_vecs(logits.data().get(), tokens, expert_num, \"logits\", 12);\n\n    if (!success && 1) {\n\n        diff_vecs(eids.data().get(), eids_ref.data().get(), experts_per_token, tokens, \"eids\");\n\n        print_vecs(offsets_ref.data().get(), 1, expert_num + 1, \"offsets_ref\");\n        print_vecs(offsets.data().get(), 1, expert_num + 1, \"offsets\");\n\n        print_vecs(eids_ref.data().get(), experts_per_token, tokens, \"eids_ref\");\n        print_vecs(eids.data().get(), experts_per_token, tokens, \"eids\");\n\n        print_vecs(f2n_ref.data().get(), 1, experts_per_token * tokens, \"f2n_ref\");\n        print_vecs(f2n.data().get(), 1, experts_per_token * tokens, \"f2n\");\n\n        print_vecs(en2f_ref.data().get(), experts_per_token, tokens, \"en2f_ref\");\n        print_vecs(en2f.data().get(), experts_per_token, tokens, \"en2f\");\n\n        print_vecs(scales_ref.data().get(), experts_per_token, tokens, \"scales_ref\", 12);\n        print_vecs(scales.data().get(), experts_per_token, tokens, \"scales\", 12);\n\n        for (int i = 0; i < tokens; ++i) {\n            float sum = 0;\n            for (int j = 0; j < experts_per_token; ++j) {\n                sum += scales[j * tokens + i];\n            }\n            std::cout << sum << \" \";\n        }\n        std::cout << \"\\n\";\n\n        // print_vecs(accum.data().get(), expert_num, 1, \"accum\");\n\n        // print_vecs(coords.data().get(), 1, max_coords, \"coords\");\n\n        // thrust::host_vector<int4> tile_offsets(tape.max_ctas);\n        // std::cout << tape.max_ctas << std::endl;\n        // 
cudaMemcpy(tile_offsets.data(), tape.tile_offsets, sizeof(int4) * tile_offsets.size(),\n        // cudaMemcpyDefault); cudaDeviceSynchronize();\n\n        // std::cout << \"coords:\\n\";\n        // int last = -1;\n        // for (int i = 0; i < tape.max_ctas; ++i) {\n        //     auto& c = tile_offsets[i];\n        //     if (last >= 0 && c.w != last) {\n        //         std::cout << \"\\n\";\n        //     }\n        //     if (c.w == -1) {\n        //         std::cout << i << \"\\n\";\n        //         break;\n        //     }\n        //     last = c.w;\n        //     std::stringstream ss;\n        //     ss << c.x << \",\" << c.y;\n        //     std::cout << std::setw(6) << ss.str();\n        // }\n        // std::cout << \"\\n\";\n    }\n\n    return success;\n}\n\nint main()\n{\n    gemm::Tape       tape{};\n    constexpr Tiling tiling{14336, 128, {128, 128, 32}};\n\n    // test_moe_gate(32768 * 4, 60, 4, tape, tiling);\n    // test_moe_gate(32768, 64, 8, tape, tiling);\n    // test_moe_gate(8, 60, 4, tape, tiling);\n\n    test_moe_gate(16, 160, 6, tape, tiling);\n\n    return 0;\n\n    for (int i = 1; i < 16384; ++i) {\n        // std::cerr << i << std::endl;\n        auto success = test_moe_gate(i, 8, 2, tape, tiling);\n        if (!success) {\n            std::cerr << i << std::endl;\n            // std::abort();\n        }\n        // break;\n    }\n}\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/test_utils.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/gemm/test/test_utils.h\"\n#include <cublas_v2.h>\n#include <cuda_bf16.h>\n#include <curand.h>\n#include <curand_kernel.h>\n#include <fstream>\n#include <iostream>\n\n#define _CG_ABI_EXPERIMENTAL\n#include <cooperative_groups.h>\n#include <cooperative_groups/memcpy_async.h>\n#include <cooperative_groups/reduce.h>\n\n#include <thrust/device_ptr.h>\n#include <thrust/iterator/zip_iterator.h>\n#include <thrust/system/cuda/execution_policy.h>\n#include <thrust/transform_reduce.h>\n\nnamespace turbomind {\n\ncublasHandle_t cublas_handle{};\ncudaStream_t   cublas_stream{};\n\ntemplate<typename T>\nvoid Compare(const T* src, const T* ref, size_t stride, int dims, int bsz, bool show, float rtol, float atol)\n{\n    float asums{};\n    float rsums{};\n    int   outliers{};\n    for (int nn = 0; nn < bsz; ++nn) {\n        float abs_diff_sum{};\n        float rel_diff_sum{};\n        for (int mm = 0; mm < dims; ++mm) {\n            auto x = float(src[nn * stride + mm]);\n            auto y = float(ref[nn * stride + mm]);\n            // if (show) {\n            //     std::cout << x << \"\\t\" << y << std::endl;\n            // }\n            auto abs_diff = std::abs(x - y);\n            auto rel_diff = abs_diff / (std::max(std::abs(y), std::abs(x)) + 1e-8f);\n            if (!(abs_diff <= atol + rtol * std::abs(y))) {\n                ++outliers;\n                if (show) {\n                    std::cout << nn << \",\" << mm << \"\\t\" << x << \"\\t\" << y << std::endl;\n                }\n            }\n            abs_diff_sum += abs_diff;\n            rel_diff_sum += rel_diff;\n        }\n        asums += abs_diff_sum / dims;\n        rsums += rel_diff_sum / dims;\n    }\n    const float abs_diff = asums / bsz;\n    const float rel_diff = rsums / bsz;\n    const float outlier  = outliers / 
(float)bsz;\n    std::cout << \"abs_diff = \" << abs_diff << \" rel_diff = \" << rel_diff << \" outliers = \" << outlier << std::endl;\n}\n\ntemplate void\nCompare(const half* src, const half* ref, size_t stride, int dims, int bsz, bool show, float rtol, float atol);\ntemplate void\nCompare(const float* src, const float* ref, size_t stride, int dims, int bsz, bool show, float rtol, float atol);\n#if ENABLE_BF16\ntemplate void Compare(const nv_bfloat16* src,\n                      const nv_bfloat16* ref,\n                      size_t             stride,\n                      int                dims,\n                      int                bsz,\n                      bool               show,\n                      float              rtol,\n                      float              atol);\n#endif\n\nvoid Compare(\n    const void* x, const void* r, DataType dtype, size_t stride, int dim, int bsz, bool show, float rtol, float atol)\n{\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        Compare((const T*)x, (const T*)r, stride, dim, bsz, show, rtol, atol);\n    };\n    TM_DISPATCH_DTYPES(dtype, invoke, half_t, bfloat16_t);\n}\n\ntemplate<class T>\nstd::vector<float>\nFastCompare(const T* src, const T* ref, int dims, int bsz, cudaStream_t stream, float rtol, float atol)\n{\n    auto       zip_iter = thrust::make_zip_iterator(src, ref);\n    const auto count    = (size_t)dims * bsz;\n    // nvcc-11.8: __host__ __device__ lambda can't be generic\n    using Tuple = thrust::tuple<float, float, float, float, float, float, int64_t>;\n    auto res    = thrust::transform_reduce(\n        thrust::cuda::par.on(stream),\n        zip_iter,\n        zip_iter + count,\n        [=] __host__ __device__(thrust::tuple<float, float> tup) -> Tuple {\n            float   s        = thrust::get<0>(tup);\n            float   r        = thrust::get<1>(tup);\n            float   abs_diff = fabsf(s - r);\n            float   abs_s    = fabsf(s);\n            float   abs_r   
 = fabsf(r);\n            float   rel_diff = abs_diff / (fmaxf(abs_r, abs_s) + 1e-8f);\n            int64_t outlier  = !(abs_diff <= (atol + rtol * abs_r));\n            return thrust::make_tuple(abs_s, abs_r, abs_diff, abs_diff, rel_diff, rel_diff, outlier);\n        },\n        thrust::make_tuple(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0LL),\n        [] __host__ __device__(const Tuple& a, const Tuple& b) {  // `__host__`: compiler needs the return type\n            return thrust::make_tuple(thrust::get<0>(a) + thrust::get<0>(b),\n                                      thrust::get<1>(a) + thrust::get<1>(b),\n                                      thrust::get<2>(a) + thrust::get<2>(b),\n                                      fmaxf(thrust::get<3>(a), thrust::get<3>(b)),\n                                      thrust::get<4>(a) + thrust::get<4>(b),\n                                      fmaxf(thrust::get<5>(a), thrust::get<5>(b)),\n                                      thrust::get<6>(a) + thrust::get<6>(b));\n        });\n    return {thrust::get<0>(res) / dims / bsz,   // avg abs src\n            thrust::get<1>(res) / dims / bsz,   // avg abs ref\n            thrust::get<2>(res) / dims / bsz,   // avg abs diff\n            thrust::get<3>(res),                // max abs diff\n            thrust::get<4>(res) / dims / bsz,   // avg rel diff\n            thrust::get<5>(res),                // max rel diff\n            (float)thrust::get<6>(res) / bsz};  // outlier count\n}\n\ntemplate std::vector<float> FastCompare(const half*  src,  //\n                                        const half*  ref,\n                                        int          dims,\n                                        int          bsz,\n                                        cudaStream_t stream,\n                                        float        rtol,\n                                        float        atol);\n\ntemplate std::vector<float> FastCompare(const nv_bfloat16* src,  //\n                        
                const nv_bfloat16* ref,\n                                        int                dims,\n                                        int                bsz,\n                                        cudaStream_t       stream,\n                                        float              rtol,\n                                        float              atol);\n\ntemplate std::vector<float> FastCompare(const float* src,  //\n                                        const float* ref,\n                                        int          dims,\n                                        int          bsz,\n                                        cudaStream_t stream,\n                                        float        rtol,\n                                        float        atol);\n\nstd::vector<float> FastCompare(const Tensor& x, const Tensor& r, cudaStream_t stream, float rtol, float atol)\n{\n    TM_CHECK_EQ(x.ndim(), 2);\n    TM_CHECK(x.is_contiguous());\n    TM_CHECK(x.layout() == r.layout());\n    TM_CHECK(x.dtype() == r.dtype());\n\n    auto invoke = [&](auto t) {\n        using T         = decltype(t);\n        auto [dim, bsz] = x.shapes(1, 0);\n        return FastCompare(x.data<T>(), r.data<T>(), dim, bsz, stream, rtol, atol);\n    };\n\n    TM_DISPATCH_DTYPES_RET(x.dtype(), invoke, half_t, bfloat16_t, float);\n}\n\nvoid FC_Header()\n{\n    printf(\"%16s%16s%16s%16s%16s%16s%16s\\n\",\n           \"amean\",\n           \"amean_ref\",\n           \"absdiff\",\n           \"absdiff_max\",\n           \"reldiff\",\n           \"reldiff_max\",\n           \"#outlier\");\n}\n\nvoid FC_Print(const std::vector<float>& d)\n{\n    printf(\"%16f%16f%16f%16f%16f%16f%16f\\n\", d[0], d[1], d[2], d[3], d[4], d[5], d[6]);\n}\n\nvoid LoadBinary(const std::string& path, size_t size, void* dst)\n{\n    std::ifstream ifs(path, std::ios::binary | std::ios::in);\n    if (!ifs.is_open()) {\n        std::cerr << \"failed to open \" << path << \"\\n\";\n        
std::abort();\n    }\n    ifs.seekg(0, ifs.end);\n    auto actual_size_in_bytes = ifs.tellg();\n    ifs.seekg(0, ifs.beg);\n    if (size != actual_size_in_bytes) {\n        std::cerr << \"[warning] file \" << path << \" has \" << actual_size_in_bytes << \" bytes, while \" << size\n                  << \" bytes is requested\\n\";\n    }\n    ifs.read((char*)dst, size);\n    std::cerr << \"[info] \" << path << \" \" << size << \"\\n\";\n}\n\nnamespace cg = cooperative_groups;\n\n__global__ void curand_init(curandState* state)\n{\n    auto tid = cg::this_grid().thread_rank();\n    curand_init(0xe4c45822e90461ddULL, tid, 0, state + tid);\n}\n\ntemplate<typename T>\n__global__ void curand_uniform(curandState* state, size_t count, T* result, float scale, float shift)\n{\n    auto grid = cg::this_grid();\n    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {\n        float tmp = curand_uniform(state + grid.thread_rank());\n        result[i] = T(scale * tmp + shift);\n    }\n}\n\ntemplate<typename T>\n__global__ void curand_normal(curandState* state, size_t count, T* result, float scale, float shift)\n{\n    auto grid = cg::this_grid();\n    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {\n        float tmp = curand_normal(state + grid.thread_rank());\n        result[i] = T(scale * tmp + shift);\n    }\n}\n\n__global__ void curand_bytes(curandState* state, size_t count, uint* result)\n{\n    auto grid = cg::this_grid();\n    for (auto i = grid.thread_rank(); i < count; i += grid.size()) {\n        result[i] = curand(state + grid.thread_rank());\n    }\n}\n\nstruct RNG::Impl {\n\n    curandState* states{};\n\n    Impl()\n    {\n        cudaMalloc(&states, sizeof(curandState) * 64 * 64);\n        curand_init<<<64, 64>>>(states);\n    }\n\n    ~Impl()\n    {\n        cudaFree(states);\n    }\n\n    void GenerateUInt(uint* out, size_t count)\n    {\n        curand_bytes<<<64, 64, 0, stream_>>>(states, count, out);\n    }\n\n    template<typename 
T>\n    void GenerateUniform(T* out, size_t count, float scale, float shift)\n    {\n        curand_uniform<<<64, 64, 0, stream_>>>(states, count, out, scale, shift);\n    }\n\n    template<typename T>\n    void GenerateNormal(T* out, size_t count, float scale, float shift)\n    {\n        curand_normal<<<64, 64, 0, stream_>>>(states, count, out, scale, shift);\n    }\n\n    cudaStream_t stream_{};\n};\n\nRNG::RNG(): impl_(std::make_unique<Impl>()) {}\n\nRNG::~RNG() = default;\n\nvoid RNG::GenerateUInt(uint* out, size_t count)\n{\n    impl_->GenerateUInt(out, count);\n}\n\ntemplate<typename T>\nvoid RNG::GenerateUniform(T* out, size_t count, float scale, float shift)\n{\n    impl_->GenerateUniform(out, count, scale, shift);\n}\n\ntemplate<typename T>\nvoid RNG::GenerateNormal(T* out, size_t count, float scale, float shift)\n{\n    impl_->GenerateNormal(out, count, scale, shift);\n}\n\ncudaStream_t RNG::stream() const\n{\n    return impl_->stream_;\n}\n\nvoid RNG::set_stream(cudaStream_t stream)\n{\n    impl_->stream_ = stream;\n}\n\ntemplate void RNG::GenerateUniform(half* out, size_t count, float scale, float shift);\ntemplate void RNG::GenerateUniform(float* out, size_t count, float scale, float shift);\ntemplate void RNG::GenerateUniform(nv_bfloat16* out, size_t count, float scale, float shift);\ntemplate void RNG::GenerateNormal(half* out, size_t count, float scale, float shift);\ntemplate void RNG::GenerateNormal(float* out, size_t count, float scale, float shift);\ntemplate void RNG::GenerateNormal(nv_bfloat16* out, size_t count, float scale, float shift);\n\nvoid RNG::RandomBytes(Ref<Tensor> out_)\n{\n    auto& out = out_.get();\n    TM_CHECK(out.size() == out.layout().cosize());\n    TM_CHECK(out.byte_size() % sizeof(uint) == 0);\n    GenerateUInt((uint*)out.raw_data(), out.byte_size() / sizeof(uint));\n}\n\nvoid RNG::UniformFloat(Ref<Tensor> out_, float scale, float shift)\n{\n    auto& out = out_.get();\n    TM_CHECK(out.size() == 
out.layout().cosize());\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        GenerateUniform(out.data<T>(), out.size(), scale, shift);\n    };\n    TM_DISPATCH_DTYPES(out.dtype(), invoke, float, half_t, bfloat16_t);\n}\n\nvoid RNG::NormalFloat(Ref<Tensor> out_, float scale, float shift)\n{\n    auto& out = out_.get();\n    TM_CHECK(out.size() == out.layout().cosize());\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        GenerateNormal(out.data<T>(), out.size(), scale, shift);\n    };\n    TM_DISPATCH_DTYPES(out.dtype(), invoke, float, half_t, bfloat16_t);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/test_utils.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/macro.h\"\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <memory>\n#include <string>\n#include <vector>\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\ntemplate<typename T>\nvoid Compare(const T* src,\n             const T* ref,\n             size_t   stride,\n             int      dims,\n             int      bsz,\n             bool     show = false,\n             float    rtol = 1e-2,\n             float    atol = 1e-4);\n\nvoid Compare(const void* x,\n             const void* r,\n             DataType    dtype,\n             size_t      stride,\n             int         dim,\n             int         bsz,\n             bool        show,\n             float       rtol = 1e-2,\n             float       atol = 1e-4);\n\ntemplate<class T>\nstd::vector<float> FastCompare(const T*     src,  //\n                               const T*     ref,\n                               int          dims,\n                               int          bsz,\n                               cudaStream_t stream,\n                               float        rtol = 1e-2,\n                               float        atol = 1e-4);\n\nstd::vector<float> FastCompare(const Tensor& x,  //\n                               const Tensor& r,\n                               cudaStream_t  stream,\n                               float         rtol = 1e-2,\n                               float         atol = 1e-4);\n\nvoid FC_Header();\n\nvoid FC_Print(const std::vector<float>& d);\n\nvoid LoadBinary(const std::string& path, size_t size, void* dst);\n\nclass RNG {\npublic:\n    RNG();\n    ~RNG();\n    void GenerateUInt(uint* out, size_t count);\n\n    template<typename T>\n    void GenerateUniform(T* out, size_t count, float scale = 1.f, float shift = 0.f);\n\n    template<typename T>\n    void GenerateNormal(T* out, size_t count, float scale = 1.f, float 
shift = 0.f);\n\n    void RandomBytes(Ref<Tensor> out_);\n\n    void UniformFloat(Ref<Tensor> out_, float scale = 1.f, float shift = 0.f);\n\n    void NormalFloat(Ref<Tensor> out_, float scale = 1.f, float shift = 0.f);\n\n    cudaStream_t stream() const;\n\n    void set_stream(cudaStream_t stream);\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/test/testbed_v3.h",
    "content": "\n#include <memory>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/core/tensor.h\"\n#include \"src/turbomind/kernels/gemm/moe_utils_v2.h\"\n#include \"src/turbomind/kernels/gemm/test/reference.h\"\n#include \"src/turbomind/kernels/gemm/test/test_utils.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/quantization.h\"\n\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/LlamaLinear.h\"\n\n#include \"src/turbomind/kernels/gpt_kernels.h\"\n\nnamespace turbomind {\n\nusing std::vector;\nusing std::unique_ptr;\n\nusing DenseWeight = LlamaDenseWeight;\nusing Linear      = LlamaLinear;\n\nusing namespace gemm;\n\nstruct Parameter {\n    int input_dim;\n    int output_dim;\n\n    DataType data_type;\n    DataType weight_type;\n    DataType input_type;\n\n    int group_size;\n\n    int max_batch_size;\n\n    int  expert_num;\n    int  experts_per_token;\n    bool combine_experts;\n};\n\n/// TODO: add a generic copy / casting for non-sub-byte Tensor\nstatic Tensor CopyTransposed(const Tensor& src, Tensor out = {})\n{\n    if (out) {\n        TM_CHECK(out.shapes(0, 1) == src.shapes(1, 0)) << src << \" vs \" << out;\n        TM_CHECK_EQ(out.dtype(), src.dtype());\n    }\n    else {\n        out = {{src.shape(1), src.shape(0)}, src.dtype(), src.device()};\n    }\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        invokeTransposeAxis01(\n            (T*)out.raw_data(), (T*)src.raw_data(), src.shape(0), src.shape(1), 1, core::Context::stream().handle());\n    };\n\n    const int bits = byte_size(src.dtype(), 8);\n    if (bits == 8) {\n        invoke(uint8_t{});\n    }\n    else if (bits == 16) {\n        invoke(uint16_t{});\n    }\n    else if (bits == 32) {\n        invoke(int{});\n    }\n    else {\n        TM_CHECK(0) << \"Not implemented. 
bits = \" << bits;\n    }\n\n    return out;\n}\n\nstruct Testbed_v3: Parameter {\n\n    Testbed_v3(const Parameter& param): Parameter{param}, stream_{core::Context::stream().handle()}, linear_{}\n    {\n        rng_.set_stream(stream_);\n        ref_.set_stream(stream_);\n\n        if (auto str = std::getenv(\"TM_GEMM_IMPORT\")) {\n            import_file_ = str;\n            std::ifstream ifs(import_file_, std::ios::binary);\n            auto          n = linear_.Import(ifs);\n            std::cout << \"Records imported: \" << n << \"\\n\";\n        }\n        if (auto str = std::getenv(\"TM_GEMM_TUNE\"); str && import_file_.empty()) {\n            tuning_ = true;\n            std::cout << \"Enable tuning\\n\";\n        }\n        if (auto str = std::getenv(\"TM_GEMM_EXPORT\"); str && import_file_.empty()) {\n            export_file_ = str;\n        }\n\n        cudaGetDeviceProperties(&prop_, 0);\n\n        w_original_ = std::make_unique<DenseWeight>();\n        w_quant_    = std::make_unique<DenseWeight>();\n        w_dequant_  = std::make_unique<DenseWeight>();\n\n        for (int i = 0; i < expert_num; ++i) {\n            e_original_.push_back(std::make_unique<DenseWeight>());\n            e_quant_.push_back(std::make_unique<DenseWeight>());\n            e_dequant_.push_back(std::make_unique<DenseWeight>());\n        }\n\n        GenerateWeight();\n        GenerateInput();\n\n        if (expert_num) {\n            LinkExperts([&](int i) { return e_original_[i].get(); }, expert_num, *w_original_);\n            LinkExperts([&](int i) { return e_quant_[i].get(); }, expert_num, *w_quant_);\n            LinkExperts([&](int i) { return e_dequant_[i].get(); }, expert_num, *w_dequant_);\n            Route();\n        }\n    }\n\n    ~Testbed_v3()\n    {\n        if (!export_file_.empty()) {\n            std::cerr << \"export file: \" << export_file_ << \"\\n\";\n            std::ofstream ofs(export_file_, std::ios::binary);\n            if (ofs.is_open()) {\n         
       auto n = linear_.Export(ofs);\n                std::cout << \"Records exported: \" << n << \"\\n\";\n            }\n        }\n    }\n\n    void GenerateInput()\n    {\n        x_original_ = Tensor{{max_batch_size, input_dim}, data_type, kDEVICE};\n        rng_.NormalFloat(x_original_, 1., 1.);\n\n        if (input_type == data_type) {\n            x_quant_   = empty_like(x_original_);\n            x_dequant_ = empty_like(x_original_);\n            Copy(x_original_, x_quant_);\n            Copy(x_original_, x_dequant_);\n        }\n        else if (input_type == kFloat8_e4m3) {\n            QuantizeSymm(x_quant_, x_scale_, x_original_, stream_);\n            DequantizeSymm(x_dequant_, x_quant_, x_scale_, stream_);\n        }\n        else {\n            TM_CHECK(0) << \"Not implemented for input type \" << to_string(input_type);\n        }\n    }\n\n    void Route()\n    {\n        const int bsz = max_batch_size;\n\n        std::mt19937 g{};\n\n        /// TODO: Control the distribution\n        auto expert_ids = SampleUniform(bsz, expert_num, experts_per_token, g);\n\n        std::uniform_real_distribution<float> dist(1e-3, 1.f);\n\n        Buffer_<float> tmp(experts_per_token, kCPU);\n        Buffer_<float> scales(bsz * experts_per_token, kCPU);\n\n        for (int i = 0; i < bsz; ++i) {\n            float sum{};\n            for (auto& x : tmp) {\n                x = dist(g);\n                sum += x;\n            }\n            for (int e = 0; e < experts_per_token; ++e) {\n                scales[e * bsz + i] = tmp[e] / sum;\n            }\n        }\n\n        vector<int>         count(expert_num);\n        vector<vector<int>> f2i(expert_num);\n        for (int i = 0; i < (int)expert_ids.size(); ++i) {\n            ++count[expert_ids[i]];\n            f2i[expert_ids[i]].push_back(i);\n        }\n\n        Buffer_<int> offsets(expert_num + 1, kCPU);\n        offsets[0] = 0;\n        for (int i = 0; i < expert_num; ++i) {\n            offsets[i + 1] = 
offsets[i] + count[i];\n        }\n\n        for (const auto& x : count) {\n            std::cout << x << \" \";\n        }\n        std::cout << \"\\n\";\n\n        Buffer_<int> f2n(expert_ids.size(), kCPU);\n        Buffer_<int> en2f(expert_ids.size(), kCPU);\n        for (int e = 0, i = 0; e < expert_num; ++e) {\n            for (auto x : f2i[e]) {\n                f2n[i]   = x / experts_per_token;\n                int en   = x % experts_per_token * bsz + x / experts_per_token;\n                en2f[en] = i;\n                ++i;\n            }\n        }\n\n        f2n_ = {f2n.size(), kDEVICE};\n        Copy(f2n, f2n_);\n\n        en2f_ = {en2f.size(), kDEVICE};\n        Copy(en2f, en2f_);\n\n        scales_ = {scales.size(), kDEVICE};\n        Copy(scales, scales_);\n\n        offsets_ = {offsets.size(), kDEVICE};\n        Copy(offsets, offsets_);\n        h_offsets_ = offsets;\n\n        core::Context::stream().Sync();\n    }\n\n    void GenerateWeight()\n    {\n        if (expert_num) {\n            for (int i = 0; i < expert_num; ++i) {\n                GenerateWeight(*e_original_[i], *e_quant_[i], *e_dequant_[i]);\n            }\n        }\n        else {\n            GenerateWeight(*w_original_, *w_quant_, *w_dequant_);\n        }\n    }\n\n    // - quantize weight\n    // - dequantize weight\n    void GenerateWeight(DenseWeight& original, DenseWeight& quant, DenseWeight& dequant)\n    {\n        original.emplace(input_dim, output_dim, data_type, false, data_type, group_size);\n        rng_.NormalFloat(original.weight, 1., .1);\n\n        quant.emplace(input_dim, output_dim, data_type, false, weight_type, group_size);\n        dequant.emplace(input_dim, output_dim, data_type, false, data_type, group_size);\n\n        Buffer_<unsigned> rbits;\n        // rbits = {original.weight.size(), kDEVICE};\n        // rng_.RandomBytes(Tensor{rbits});\n\n        /// Weights are allocated in MN-major, but some quantization requires K-major tensor\n\n        if 
(weight_type == data_type) {\n            Copy(original.weight, quant.weight);\n            Copy(original.weight, dequant.weight);\n        }\n        else if (weight_type == kFloat8_e4m3) {\n            QuantizeSymmBlock(quant.weight, quant.scales, original.weight, stream_);\n            DequantizeSymmBlock(dequant.weight, quant.weight, quant.scales, stream_);\n        }\n        else if (weight_type == kUint4) {\n            /// Weights are allocated in (M,N), quantization needs K-major tensor\n            QuantizeGroupwise(quant.weight.t(),\n                              quant.scales.t(),\n                              quant.zeros.t(),\n                              dequant.weight.t(),\n                              original.weight.t(),\n                              {},\n                              group_size);\n        }\n        else if (weight_type == kFloat4_e2m1) {\n            QuantizeGroupwise(quant.weight.t(),  //\n                              quant.scales.t(),\n                              {},\n                              dequant.weight.t(),\n                              original.weight.t(),\n                              rbits,\n                              group_size);\n        }\n        else {\n            TM_CHECK(0);\n        }\n\n        original.prepare(0);\n        quant.prepare(expert_num > 0);\n        dequant.prepare(0);\n    }\n\n    void GetReference()\n    {\n        if (expert_num) {\n            GetReference(x_original_, e_original_, d_original_);\n            GetReference(x_dequant_, e_dequant_, d_dequant_);\n        }\n        else {\n            GetReference(x_original_, w_original_, d_original_);\n            GetReference(x_dequant_, w_dequant_, d_dequant_);\n        }\n    }\n\n    void GetReference(const Tensor& x, const unique_ptr<DenseWeight>& dense, Ref<Tensor> d_)\n    {\n        auto& d = d_.get();\n        if (!d) {\n            d = Tensor{{x.shape(0), dense->output_dim}, x.dtype(), x.device()};\n        }\n        
/// TODO: refactor reference API\n        const MatrixLayout desc_A{x.dtype(), kRowMajor, (int)x.shape(0), (int)x.shape(1), (int)x.stride(0)};  // m,k\n        const MatrixLayout desc_D{d.dtype(), kRowMajor, (int)d.shape(0), (int)d.shape(1), (int)d.stride(0)};  // m,n\n        ref_.gemm(x.raw_data(), desc_A, dense->weight.raw_data(), dense->k_desc, d.raw_data(), desc_D);\n    }\n\n    void GetReference(const Tensor& x, const vector<unique_ptr<DenseWeight>>& experts, Ref<Tensor> d_)\n    {\n        Tensor xe{{x.shape(0) * experts_per_token, input_dim}, data_type, kDEVICE};\n        Tensor de{{x.shape(0) * experts_per_token, output_dim}, data_type, kDEVICE};\n\n        invokeMoeDispatch(xe, x, f2n_.data(), experts_per_token, stream_);\n\n        for (int i = 0; i < expert_num; ++i) {\n            const int base = h_offsets_[i], size = h_offsets_[i + 1] - base;\n            GetReference(xe.slice(base, size), experts[i], de.slice(base, size));\n        }\n\n        auto& d = d_.get();\n        if (combine_experts) {\n            d = Tensor{{x.shape(0), output_dim}, data_type, kDEVICE};\n            invokeMoeCombine(d,  //\n                             de,\n                             {},\n                             scales_.data(),\n                             en2f_.data(),\n                             nullptr,\n                             nullptr,\n                             experts_per_token,\n                             1.,\n                             0.,\n                             stream_);\n        }\n        else {\n            d = de;\n        }\n    }\n\n    void Run()\n    {\n        if (tuning_) {\n            linear_.set_measure(true);\n        }\n        if (expert_num) {\n            auto de = linear_.Forward(x_original_, *w_quant_, f2n_, offsets_);\n            if (combine_experts) {\n                d_quant_ = Tensor{{x_original_.shape(0), output_dim}, data_type, kDEVICE};\n                invokeMoeCombine(d_quant_,\n                         
        de,\n                                 {},\n                                 scales_.data(),\n                                 en2f_.data(),\n                                 nullptr,\n                                 nullptr,\n                                 experts_per_token,\n                                 1.,\n                                 0.,\n                                 stream_);\n            }\n            else {\n                d_quant_ = de;\n            }\n        }\n        else {\n            d_quant_ = linear_.Forward(x_original_, *w_quant_);\n        }\n        if (tuning_) {\n            linear_.set_measure(false);\n        }\n    }\n\n    void Run(const Tensor& x, const vector<unique_ptr<DenseWeight>>& experts) {}\n\n    void Compare()\n    {\n        // Buffer_<float> h(16 * 16, kCPU);\n        // Buffer_<float> x(linear_.buf, 16 * 16, kDEVICE);\n        // Copy(x, h);\n\n        // auto y = empty_like(w_dequant_->weight, kCPU);\n        // Copy(w_dequant_->weight, y);\n\n        // clang-format off\n        printf(\"%20s\", \"\"); FC_Header();\n        if (!expert_num) {\n            printf(\"%20s\", \"w_dequant v w_origi\"); FC_Print(FastCompare(w_dequant_->weight, w_original_->weight, stream_));\n        }\n        printf(\"%20s\", \"quant   vs  dequant\"); FC_Print(FastCompare(d_quant_, d_dequant_, stream_));\n        printf(\"%20s\", \"quant   vs original\"); FC_Print(FastCompare(d_quant_, d_original_, stream_));\n        printf(\"%20s\", \"dequant vs original\"); FC_Print(FastCompare(d_dequant_, d_original_, stream_));\n        // clang-format on\n\n        // for (int m = 0; m < 16; ++m) {\n        //     for (int k = 0; k < 16; ++k) {\n        //         printf(\"%5.1f\", h[m * 16 + k]);\n        //     }\n        //     printf(\"\\n\");\n        // }\n\n        // printf(\"\\n\");\n\n        // for (int m = 0; m < 16; ++m) {\n        //     for (int k = 0; k < 16; ++k) {\n        //         printf(\"%5.1f\", 
(float)y.data<bfloat16_t>()[k * output_dim + m]);\n        //     }\n        //     printf(\"\\n\");\n        // }\n    }\n\n    cudaStream_t stream_;\n\n    cudaDeviceProp prop_;\n\n    Linear linear_;\n\n    // ! weights are non-movable\n    unique_ptr<DenseWeight> w_original_;\n    unique_ptr<DenseWeight> w_quant_;\n    unique_ptr<DenseWeight> w_dequant_;\n\n    Tensor x_original_;\n    Tensor x_quant_, x_scale_;\n    Tensor x_dequant_;\n\n    Tensor d_original_;  // x_original * w_original\n    Tensor d_quant_;     // x_original * w_quant, quant for X done by `Linear`\n    Tensor d_dequant_;   // x_dequant  * w_dequant\n\n    vector<unique_ptr<DenseWeight>> e_original_;\n    vector<unique_ptr<DenseWeight>> e_quant_;\n    vector<unique_ptr<DenseWeight>> e_dequant_;\n\n    Buffer_<int> f2n_;\n    Buffer_<int> en2f_;\n\n    Buffer_<int>   offsets_;\n    Buffer_<float> scales_;\n\n    Buffer_<int> h_offsets_;\n\n    bool tuning_{};\n\n    std::string import_file_;\n    std::string export_file_;\n\n    RNG       rng_;\n    Reference ref_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/thread_group_map.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/thread_map.h\"\n\n#include <iostream>\n\nnamespace turbomind::gemm {\n\ntemplate<int M_, int N_, int K_, int TM, int TN, int TK, int GM, int GN, int GK>\nstruct RakedThreadGroupMap {\n    static constexpr int M = M_;\n    static constexpr int N = N_;\n    static constexpr int K = K_;\n\n    static constexpr int TileM = TM;\n    static constexpr int TileN = TN;\n    static constexpr int TileK = TK;\n\n    static constexpr int kGroupM = GM;\n    static constexpr int kGroupN = GN;\n    static constexpr int kGroupK = GK;\n\n    static constexpr int kGroupCount = GM * GN * GK;\n\n    static constexpr int M1 = GM * TM;\n    static constexpr int N1 = GN * TN;\n    static constexpr int K1 = GK * TK;\n\n    static constexpr int kIterM = M / M1;\n    static constexpr int kIterN = N / N1;\n    static constexpr int kIterK = K / K1;\n\n    static constexpr int kFootprintM = kIterM * TM;\n    static constexpr int kFootprintN = kIterN * TN;\n    static constexpr int kFootprintK = kIterK * TK;\n\n    static constexpr int kDeltaM = TM;\n    static constexpr int kDeltaN = TN;\n    static constexpr int kDeltaK = TK;\n\n    __device__ static int3 get_offset(int group_id)\n    {\n        const int m = group_id % GM;\n        const int n = group_id / GM % GN;\n        const int k = group_id / GM / GN;\n        return {m * kFootprintM, n * kFootprintN, k * kFootprintK};\n    }\n};\n\ntemplate<int M_, int N_, int K_, int tM_, int tN_, int tK_, class ArrangementMN, int gK, bool rK = 0>\nstruct MMA_Map {\n    static constexpr int M = M_;\n    static constexpr int N = N_;\n    static constexpr int K = K_;\n\n    static constexpr int TileM = tM_;\n    static constexpr int TileN = tN_;\n    static constexpr int TileK = tK_;\n\n    
static constexpr int kGroupM = ArrangementMN::gM;\n    static constexpr int kGroupN = ArrangementMN::gN;\n    static constexpr int kGroupK = gK;\n\n    static constexpr int kGroupCount = kGroupM * kGroupN * kGroupK;\n\n    static constexpr int kIterM = M / tM_ / kGroupM;\n    static constexpr int kIterN = N / tN_ / kGroupN;\n    static constexpr int kIterK = K / tK_ / kGroupK;\n\n    static constexpr int kFootprintM = kIterM * tM_;\n    static constexpr int kFootprintN = kIterN * tN_;\n    static constexpr int kFootprintK = kIterK * tK_;\n\n    static constexpr int kDeltaM = tM_ * ArrangementMN::dM;\n    static constexpr int kDeltaN = tN_ * ArrangementMN::dN;\n    static constexpr int kDeltaK = tK_ * (rK ? gK : 1);\n\n    static constexpr auto kPartitionM = ArrangementMN::pM;\n    static constexpr auto kPartitionN = ArrangementMN::pN;\n    static constexpr auto kPartitionK = rK ? Partition::kRaked : Partition::kBlocked;\n\n    __device__ static int3 get_offset(int group_id)\n    {\n        constexpr int kGroupMN = kGroupM * kGroupN;\n\n        const auto mn = ArrangementMN::get_offset(group_id % kGroupMN, pair<M / TileM, N / TileN>{});\n        const int  k  = group_id / kGroupMN;\n\n        return {mn.x * tM_, mn.y * tN_, k * tK_ * (rK ? 
1 : kIterK)};\n    }\n};\n\nnamespace {\n\ntemplate<class TMap>\nvoid Print_(TMap)\n{\n    std::cout << \"M, N, K = \" << TMap::M << \" \" << TMap::N << \" \" << TMap::K << \"\\n\";\n    std::cout << \"TM, TN, TK = \" << TMap::TileM << \" \" << TMap::TileN << \" \" << TMap::TileK << \"\\n\";\n    std::cout << \"group count = \" << TMap::kGroupCount << \"\\n\";\n    // std::cout << \"M1, N1, K1 = \" << TMap::M1 << \" \" << TMap::N1 << \" \" << TMap::K1 << \"\\n\";\n    std::cout << \"itM, itN, itK = \" << TMap::kIterM << \" \" << TMap::kIterN << \" \" << TMap::kIterK << \"\\n\";\n    std::cout << \"fpM, fpN, fpK = \" << TMap::kFootprintM << \" \" << TMap::kFootprintN << \" \" << TMap::kFootprintK\n              << \"\\n\";\n    std::cout << \"dM, dN, dK = \" << TMap::kDeltaM << \" \" << TMap::kDeltaN << \" \" << TMap::kDeltaK << \"\\n\";\n}\n\n}  // namespace\n\n/// TODO: Striped partition?\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/thread_map.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n\n#include \"src/turbomind/kernels/gemm/types.h\"\n\n#include <iostream>\n\nnamespace turbomind::gemm {\n\ntemplate<int DimC, int DimS, int AccessC, int WarpCount, int WarpThreadC = std::min(WARP_SIZE, DimC / AccessC)>\nstruct ThreadMap {\n    static constexpr int kDimC = DimC;\n    static constexpr int kDimS = DimS;\n\n    static constexpr int kWarpCount = WarpCount;\n    static constexpr int kAccessC   = AccessC;\n\n    static constexpr int kWarpThreadC = WarpThreadC;\n    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;\n\n    static_assert(kWarpThreadC <= WARP_SIZE);\n\n    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;\n    static constexpr int kWarpAccessS = kWarpThreadS;\n\n    static constexpr int kWarpIterC = ceil_div(kDimC, kWarpAccessC);\n    static constexpr int kWarpIterS = ceil_div(kDimS, kWarpAccessS);\n\n    // Partition warps along the strided axis first to reduce strided iters\n    static constexpr int kWarpS = kWarpIterS >= kWarpCount ? kWarpCount : kWarpIterS;\n    static constexpr int kWarpC = kWarpCount > kWarpIterS ? 
kWarpCount / kWarpS : 1;\n\n    static constexpr int kIterC = ceil_div(kWarpIterC, kWarpC);\n    static constexpr int kIterS = ceil_div(kWarpIterS, kWarpS);\n\n    // Allow partial tile when there is ONLY 1 iteration\n    static_assert(kDimC % kWarpAccessC == 0 || kIterC == 1);\n\n    // static_assert(kIterC > 0);\n    // static_assert(kIterS > 0);\n\n    static constexpr bool kAlignedC = (kDimC % kWarpAccessC == 0) && (kWarpIterC % kWarpC == 0);\n    static constexpr bool kAlignedS = (kDimS % kWarpAccessS == 0) && (kWarpIterS % kWarpS == 0);\n\n    static constexpr int kFootprintC = kWarpAccessC * kIterC;\n    static constexpr int kFootprintS = kWarpAccessS * kIterS;\n\n    static constexpr int kDeltaC = kWarpAccessC;\n    static constexpr int kDeltaS = kWarpAccessS;\n\n    // static constexpr int kDeltaC = kWarpAccessC * kWarpC;\n    // static constexpr int kDeltaS = kWarpAccessS * kWarpS;\n\n    __device__ static int2 get_offset(int warp_id, int lane_id)\n    {\n        int warp_offset_c = warp_id % kWarpC;\n        int warp_offset_s = warp_id / kWarpC;\n\n        int warp_thread_offset_c = lane_id % kWarpThreadC;\n        int warp_thread_offset_s = lane_id / kWarpThreadC;\n\n        int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;\n        int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;\n\n        // int cta_thread_offset_c = kWarpAccessC * warp_offset_c + warp_thread_offset_c * kAccessC;\n        // int cta_thread_offset_s = kWarpAccessS * warp_offset_s + warp_thread_offset_s;\n\n        return {cta_thread_offset_c, cta_thread_offset_s};\n    }\n};\n\ntemplate<Order order, int M, int K>\n__host__ __device__ static constexpr int2 idx2mk(int idx, pair<M, K>)\n{\n    if constexpr (order == kColMajor) {\n        return {idx % M, idx / M};\n    }\n    else {\n        return {idx / K, idx % K};\n    }\n}\n\nenum class Partition\n{\n    kBlocked,\n    kRaked,\n};\n\ntemplate<int gM_, int gN_, Order 
order>\nstruct Blocked {\n    static constexpr int gM = gM_;\n    static constexpr int gN = gN_;\n\n    // static_assert((gM - 1) * sM + (gN - 1) * sN == gM * gN - 1);\n\n    static constexpr int dM = 1;\n    static constexpr int dN = 1;\n\n    static constexpr Partition pM = Partition::kBlocked;\n    static constexpr Partition pN = Partition::kBlocked;\n\n    template<int M, int N>\n    __device__ static int2 get_offset(int idx, pair<M, N>)\n    {\n        constexpr int iM = ceil_div(M, gM);\n        constexpr int iN = ceil_div(N, gN);\n\n        // const int mi = idx / sM % gM;\n        // const int ni = idx / sN % gN;\n\n        const int2 mn = idx2mk<order>(idx, pair<gM, gN>{});\n        return {mn.x * iM, mn.y * iN};\n    }\n};\n\ntemplate<int gM_, int gN_, Order order>\nstruct Raked {\n    static constexpr int gM = gM_;\n    static constexpr int gN = gN_;\n\n    // static_assert((gM - 1) * sM + (gN - 1) * sN == gM * gN - 1);\n\n    static constexpr int dM = gM;\n    static constexpr int dN = gN;\n\n    static constexpr Partition pM = Partition::kRaked;\n    static constexpr Partition pN = Partition::kRaked;\n\n    template<class Shape>\n    __device__ static int2 get_offset(int idx, Shape)\n    {\n        return idx2mk<order>(idx, pair<gM, gN>{});\n    }\n};\n\ntemplate<int gM_, int gN_, Order order>\nstruct Blocked_C_Raked_S {\n    static constexpr int gM = gM_;\n    static constexpr int gN = gN_;\n\n    static constexpr int dM = 1;\n    static constexpr int dN = gN;\n\n    static constexpr Partition pM = Partition::kBlocked;\n    static constexpr Partition pN = Partition::kRaked;\n\n    template<int M, int N>\n    __device__ static int2 get_offset(int idx, pair<M, N>)\n    {\n        constexpr int iM = ceil_div(M, gM);\n\n        const int2 mn = idx2mk<order>(idx, pair<gM, gN>{});\n        return {mn.x * iM, mn.y};\n    }\n};\n\ntemplate<int C,\n         int S,\n         int AccessC,\n         template<int, int, Order>\n         typename Arrangement_,\n     
    int WarpCount,\n         int WarpThrC = std::min(WARP_SIZE, C / AccessC)>\nstruct ThreadMap_V2 {\n    static constexpr int kDimC = C;\n    static constexpr int kDimS = S;\n\n    static constexpr int kWarpCount = WarpCount;\n    static constexpr int kAccessC   = AccessC;\n\n    static_assert(WarpThrC <= WARP_SIZE);\n\n    static constexpr int kWarpThreadC = WarpThrC;\n    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;\n\n    static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;\n    static constexpr int kWarpAccessS = kWarpThreadS;\n\n    static constexpr int kWarpIterC = ceil_div(kDimC, kWarpAccessC);\n    static constexpr int kWarpIterS = ceil_div(kDimS, kWarpAccessS);\n\n    static constexpr int kWarpS = kWarpIterS >= kWarpCount ? kWarpCount : kWarpIterS;\n    static constexpr int kWarpC = kWarpCount > kWarpIterS ? kWarpCount / kWarpS : 1;\n\n    using Arrangement = Arrangement_<kWarpC, kWarpS, kColMajor>;\n\n    static constexpr auto kPartitionM = Arrangement::pM;\n    static constexpr auto kPartitionN = Arrangement::pN;\n\n    static constexpr int kIterC = ceil_div(kWarpIterC, kWarpC);\n    static constexpr int kIterS = ceil_div(kWarpIterS, kWarpS);\n\n    static constexpr bool kAlignedC = (kDimC % kWarpAccessC == 0) && (kWarpIterC % kWarpC == 0);\n    static constexpr bool kAlignedS = (kDimS % kWarpAccessS == 0) && (kWarpIterS % kWarpS == 0);\n\n    static constexpr int kFootprintC = kWarpAccessC * kIterC;\n    static constexpr int kFootprintS = kWarpAccessS * kIterS;\n\n    static constexpr int kDeltaC = kWarpAccessC * Arrangement::dM;\n    static constexpr int kDeltaS = kWarpAccessS * Arrangement::dN;\n\n    __device__ static int2 get_offset(int warp_id, int lane_id)\n    {\n        const int2 warp_offset = Arrangement::get_offset(warp_id, pair<kWarpIterC, kWarpIterS>{});\n\n        int warp_thr_offset_c = lane_id % kWarpThreadC;\n        int warp_thr_offset_s = lane_id / kWarpThreadC;\n\n        if constexpr (kWarpThreadC == 
WARP_SIZE) {\n            warp_thr_offset_c = lane_id;\n            warp_thr_offset_s = 0;\n        }\n\n        const int offset_c = warp_offset.x * kWarpAccessC + warp_thr_offset_c * kAccessC;\n        const int offset_s = warp_offset.y * kWarpAccessS + warp_thr_offset_s;\n\n        return {offset_c, offset_s};\n    }\n};\n\nnamespace {\n\ntemplate<class TMap>\nvoid Print(TMap)\n{\n    std::cout << \"     warps: \" << TMap::kWarpCount << \"\\n\";\n    std::cout << \"     shape: (\" << TMap::kDimC << \", \" << TMap::kDimS << \")\\n\";\n    std::cout << \"    access: (\" << TMap::kAccessC << \", \" << 1 << \")\\n\";\n    std::cout << \"warpThread: (\" << TMap::kWarpThreadC << \", \" << TMap::kWarpThreadS << \")\\n\";\n    std::cout << \"warpAccess: (\" << TMap::kWarpAccessC << \", \" << TMap::kWarpAccessS << \")\\n\";\n    std::cout << \"  warpIter: (\" << TMap::kWarpIterC << \", \" << TMap::kWarpIterS << \")\\n\";\n    std::cout << \"      warp: (\" << TMap::kWarpC << \", \" << TMap::kWarpS << \")\\n\";\n    std::cout << \"      iter: (\" << TMap::kIterC << \", \" << TMap::kIterS << \")\\n\";\n    std::cout << \" footprint: (\" << TMap::kFootprintC << \", \" << TMap::kFootprintS << \")\\n\";\n    std::cout << \"     delta: (\" << TMap::kDeltaC << \", \" << TMap::kDeltaS << \")\\n\";\n    std::cout << \"   aligned: (\" << TMap::kAlignedC << \",\" << TMap::kAlignedS << \")\\n\";\n}\n\n}  // namespace\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tiled_mma.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/core/mma.h\"\n#include \"src/turbomind/kernels/core/smem.h\"\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/simt.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/thread_map.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n\nnamespace turbomind::gemm {\n\ntemplate<class MMA_Atom_, class MMA_Map_, Order order_ = kColMajor>\nstruct Tiled_MMA_v2 {\n    using Atom = MMA_Atom_;\n    using Map  = MMA_Map_;\n\n    static constexpr int M = Map::M;\n    static constexpr int N = Map::N;\n    static constexpr int K = Map::K;\n\n    static constexpr int kGroupCount  = Map::kGroupCount;\n    static constexpr int kThreadCount = kGroupCount * Atom::kThreadCount;\n\n    static constexpr int kTileIterM = Map::kIterM;\n    static constexpr int kTileIterN = Map::kIterN;\n    static constexpr int kTileIterK = Map::kIterK;\n\n    static constexpr int kDeltaM = Map::kDeltaM;\n    static constexpr int kDeltaN = Map::kDeltaN;\n    static constexpr int kDeltaK = Map::kDeltaK;\n\n    static constexpr int kAtomM = Map::TileM / Atom::M;\n    static constexpr int kAtomN = Map::TileN / Atom::N;\n    static constexpr int kAtomK = Map::TileK / Atom::K;\n\n    static constexpr int kMmaIterM = kTileIterM * kAtomM;\n    static constexpr int kMmaIterN = kTileIterN * kAtomN;\n    static constexpr int kMmaIterK = kTileIterK * kAtomK;\n\n    __device__ static int3 get_offset(int thread_idx)\n    {\n        return Map::get_offset(Atom::get_group_id(thread_idx));\n    }\n\n    // (M,N)\n    template<class FragD, class FragA, class FragB, class FragC>\n    __device__ static void mma_k_iter(FragD& frag_D, 
const FragA& frag_A, const FragB& frag_B, const FragC& frag_C)\n    {\n        if constexpr (order_ == kColMajor) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < kMmaIterN; ++n) {\n                PRAGMA_UNROLL\n                for (int m = 0; m < kMmaIterM; ++m) {\n                    int mm = n % 2 ? (kMmaIterM - m - 1) : m;\n                    Atom::fma(frag_D[mm][n], frag_A[mm], frag_B[n], frag_C[mm][n]);\n                }\n            }\n        }\n        else {\n            PRAGMA_UNROLL\n            for (int m = 0; m < kMmaIterM; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < kMmaIterN; ++n) {\n                    int nn = n;\n                    int mm = m;\n                    Atom::fma(frag_D[mm][nn], frag_A[mm], frag_B[nn], frag_C[mm][nn]);\n                }\n            }\n        }\n    }\n};\n\ntemplate<class MMA>\nstruct Rearrange {\n    using Map  = typename MMA::Map;\n    using Atom = typename MMA::Atom;\n\n    template<class T, int V, int M, int N, class Layout, Order order, int TM, int TN>\n    __device__ static void\n    apply(Array<T, V> (&frag_C)[M][N], SmemAccessorV2<T, Layout, order>& smem_C, int2 offset_mn, pair<TM, TN>)\n    {\n        const int3 offset_mnk = MMA::get_offset(threadIdx.x);\n        const int  group_id_k = offset_mnk.z / Map::kFootprintK;\n\n        constexpr bool kRakedM = Map::kPartitionM == Partition::kRaked;\n        constexpr bool kRakedN = Map::kPartitionN == Partition::kRaked;\n\n        static constexpr int2 kMN0 = cs2mk<order>(Layout::C0, Layout::S0);\n\n        constexpr int kPeriodM  = ceil_div(kMN0.x, Map::kDeltaM);\n        constexpr int kPeriodN  = ceil_div(kMN0.y, Map::kDeltaN);\n        constexpr int kPeriodM1 = ceil_div(kMN0.x, Atom::M);\n        constexpr int kPeriodN1 = ceil_div(kMN0.y, Atom::N);\n\n        constexpr auto offset_C = Atom::static_offset_C();\n        const int2     thr      = Atom::thread_offset_C();\n\n        // Contract: All these indices is 
not a part of swizzling\n        int phases[kPeriodM][kPeriodN][kPeriodM1][kPeriodN1][offset_C.size()];\n        PRAGMA_UNROLL\n        for (int m = 0; m < kPeriodM; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < kPeriodN; ++n) {\n                PRAGMA_UNROLL\n                for (int m1 = 0; m1 < kPeriodM1; ++m1) {\n                    PRAGMA_UNROLL\n                    for (int n1 = 0; n1 < kPeriodN1; ++n1) {\n                        const int mm = offset_mnk.x + m * Map::kDeltaM + m1 * Atom::M + thr.x;\n                        const int nn = offset_mnk.y + n * Map::kDeltaN + n1 * Atom::N + thr.y;\n                        PRAGMA_UNROLL\n                        for (int i = 0; i < offset_C.size(); ++i) {\n                            const int2 cs           = mk2cs<order>(mm + offset_C[i].x, nn + offset_C[i].y);\n                            phases[m][n][m1][n1][i] = Layout::apply(cs.y, cs.x);\n                        }\n                    }\n                }\n            }\n        }\n\n        constexpr int K = Map::kGroupK;\n        constexpr int C = offset_C.size();\n\n        int offsets[K][M][N][C];\n        int masks[K][M][N][C];\n\n        PRAGMA_UNROLL\n        for (int k = 0; k < K; ++k) {\n            PRAGMA_UNROLL\n            for (int m = 0; m < M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < N; ++n) {\n                    int m0 = m / MMA::kAtomM, m1 = m % MMA::kAtomM, n0 = n / MMA::kAtomN, n1 = n % MMA::kAtomN;\n                    int m01 =\n                        m0 / kPeriodM * kPeriodM * Map::kDeltaM + m1 / kPeriodM1 * kPeriodM1 * Atom::M - offset_mn.x;\n                    int n01 =\n                        n0 / kPeriodN * kPeriodN * Map::kDeltaN + n1 / kPeriodN1 * kPeriodN1 * Atom::N - offset_mn.y;\n                    const int2 cs       = mk2cs<order>(m01, n01);\n                    int        offset_0 = Layout::apply(cs.y, cs.x);\n                    PRAGMA_UNROLL\n                    for 
(int i = 0; i < offset_C.size(); ++i) {\n                        int offset_1        = phases[m0 % kPeriodM][n0 % kPeriodN][m1 % kPeriodM1][n1 % kPeriodN1][i];\n                        offsets[k][m][n][i] = offset_0 + offset_1;\n                        const int bm        = offset_mnk.x - offset_mn.x + m0 * Map::kDeltaM + m1 * Atom::M + thr.x;\n                        const int bn        = offset_mnk.y - offset_mn.y + n0 * Map::kDeltaN + n1 * Atom::N + thr.y;\n                        const int mm        = kRakedM ? m01 : bm;\n                        const int nn        = kRakedN ? n01 : bn;\n                        masks[k][m][n][i]   = (Map::kGroupK == 1 || group_id_k == k)\n                                            && (TM >= Map::M || (0 <= mm && mm < TM))\n                                            && (TN >= Map::N || (0 <= nn && nn < TN));\n                    }\n                }\n            }\n        }\n\n        auto _store = [](auto ptr, auto offset, auto vec) {\n            if constexpr (order == kRowMajor) {\n                Store(&ptr[offset], vec);\n            }\n            else {\n                for (int i = 0; i < vec.size(); ++i) {\n                    ptr[offset + Layout::apply(i, 0)] = vec[i];\n                }\n            }\n        };\n\n        typename Atom::FragC_ reshape_C;\n\n        auto ptr = &smem_C(0, 0);\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < M; ++m) {\n            PRAGMA_UNROLL\n            for (int n = 0; n < N; ++n) {\n                Atom::ReshapeC(frag_C[m][n], reshape_C);\n                PRAGMA_UNROLL\n                for (int c = 0; c < C; ++c) {\n                    auto& vec    = reshape_C[c];\n                    int   offset = offsets[0][m][n][c];\n                    if (masks[0][m][n][c]) {\n                        _store(ptr, offset, vec);\n                    }\n                }\n            }\n        }\n\n        __syncthreads();\n\n#if 1\n        auto _load = [](auto ptr, auto offset, auto& 
vec) {\n            if constexpr (order == kRowMajor) {\n                Load(vec, &ptr[offset]);\n            }\n            else {\n                for (int i = 0; i < vec.size(); ++i) {\n                    vec[i] = ptr[offset + Layout::apply(i, 0)];\n                }\n            }\n        };\n\n        PRAGMA_UNROLL\n        for (int k = 1; k < K; ++k) {\n            PRAGMA_UNROLL\n            for (int m = 0; m < M; ++m) {\n                PRAGMA_UNROLL\n                for (int n = 0; n < N; ++n) {\n                    Atom::ReshapeC(frag_C[m][n], reshape_C);\n                    PRAGMA_UNROLL\n                    for (int c = 0; c < C; ++c) {\n                        auto& vec    = reshape_C[c];\n                        int   offset = offsets[k][m][n][c];\n                        if (masks[k][m][n][c]) {\n                            std::remove_reference_t<decltype(vec)> tmp;\n                            _load(ptr, offset, tmp);\n                            {\n                                using namespace ops;\n                                vec = vec + tmp;\n                            }\n                            _store(ptr, offset, vec);\n                        }\n                    }\n                }\n            }\n            __syncthreads();\n        }\n#endif\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tma.cu",
    "content": "\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/cuda_data_type.h\"\n#include \"src/turbomind/kernels/gemm/tma.h\"\n\nnamespace turbomind::gemm {\n\n#if __CUDACC_VER_MAJOR__ >= 12\n\n#if (CUDA_VERSION >= 13000) && (!defined(PFN_cuTensorMapEncodeTiled))\n// PFN_cuTensorMapEncodeTiled not defined in cuda 13 headers.\n#define PFN_cuTensorMapEncodeTiled PFN_cuTensorMapEncodeTiled_v12000\n#endif\n\nnamespace {\n\nPFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()\n{\n    static const auto ptr = [] {\n        // Get pointer to `cuTensorMapEncodeTiled`\n        cudaDriverEntryPointQueryResult driver_status;\n        void*                           cuTensorMapEncodeTiled_ptr = nullptr;\n\n// https://github.com/NVIDIA/cutlass/pull/2086\n#if CUDA_VERSION >= 13000\n        cudaGetDriverEntryPointByVersion(\n            \"cuTensorMapEncodeTiled\", &cuTensorMapEncodeTiled_ptr, 12000, cudaEnableDefault, &driver_status);\n#else\n        cudaGetDriverEntryPoint(\n            \"cuTensorMapEncodeTiled\", &cuTensorMapEncodeTiled_ptr, cudaEnableDefault, &driver_status);\n#endif\n        TM_CHECK_EQ(driver_status, cudaDriverEntryPointSuccess);\n        return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);\n    }();\n    return ptr;\n}\n\nCUtensorMap make_2d_tma_desc(void*              global_address,\n                             DataType           data_type,\n                             uint64_t           gmem_dims[2],\n                             uint64_t           stride_in_bytes,\n                             uint32_t           smem_dims[2],\n                             CUtensorMapSwizzle swizzle)\n{\n    uint64_t global_stride[1] = {stride_in_bytes};\n    uint32_t elem_strides[2]  = {1, 1};\n\n    auto encode_func = get_cuTensorMapEncodeTiled();\n\n    CUtensorMap tensor_map = {};\n\n    auto result = encode_func(&tensor_map,\n                              to_CUtensorMap_dtype(data_type),\n                  
            2,\n                              global_address,\n                              gmem_dims,\n                              global_stride,\n                              smem_dims,\n                              elem_strides,\n                              CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,\n                              swizzle,\n                              CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,\n                              CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);\n\n    TM_CHECK_EQ(result, CUDA_SUCCESS);\n\n    return tensor_map;\n}\n\n}  // namespace\n\nCUtensorMap make_2d_tma_desc(void*              global_address,\n                             DataType           data_type,\n                             uint32_t           gmem_rows,\n                             uint32_t           gmem_cols,\n                             uint32_t           smem_rows,\n                             uint32_t           smem_cols,\n                             Order              order,\n                             CUtensorMapSwizzle swizzle,\n                             int                ld)\n{\n    if (order == kRowMajor) {\n        uint64_t gmem_dims[] = {gmem_cols, gmem_rows};\n        uint32_t smem_dims[] = {smem_cols, smem_rows};\n        return make_2d_tma_desc(\n            global_address, data_type, gmem_dims, byte_size(data_type, ld ? ld : gmem_cols), smem_dims, swizzle);\n    }\n    else {\n        uint64_t gmem_dims[] = {gmem_rows, gmem_cols};\n        uint32_t smem_dims[] = {smem_rows, smem_cols};\n        return make_2d_tma_desc(\n            global_address, data_type, gmem_dims, byte_size(data_type, ld ? 
ld : gmem_rows), smem_dims, swizzle);\n    }\n}\n\nCUtensorMap make_2d_tma_desc(void* ptr, const MatrixLayout& desc, uint2 smem_shape, CUtensorMapSwizzle swizzle)\n{\n    return make_2d_tma_desc(\n        ptr, desc.type, desc.rows, desc.cols, smem_shape.x, smem_shape.y, desc.order, swizzle, desc.ld);\n}\n\n#endif\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tma.h",
    "content": "#pragma once\n\n#include <cuda.h>\n#include <cudaTypedefs.h>\n#include <cuda_runtime.h>\n#include <stdexcept>\n#include <string>\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\n#if __CUDACC_VER_MAJOR__ >= 12\n\nCUtensorMap make_2d_tma_desc(void*              global_address,\n                             DataType           data_type,\n                             uint32_t           gmem_rows,\n                             uint32_t           gmem_cols,\n                             uint32_t           smem_rows,\n                             uint32_t           smem_cols,\n                             Order              order,\n                             CUtensorMapSwizzle swizzle,\n                             int                ld = 0);\n\nCUtensorMap make_2d_tma_desc(void* ptr, const MatrixLayout& desc, uint2 smem_shape, CUtensorMapSwizzle swizzle);\n\nconstexpr CUtensorMapSwizzle get_tma_swizzle(int bytes)\n{\n    switch (bytes) {\n        case 128:\n            return CU_TENSOR_MAP_SWIZZLE_128B;\n        case 64:\n            return CU_TENSOR_MAP_SWIZZLE_64B;\n        case 32:\n            return CU_TENSOR_MAP_SWIZZLE_32B;\n        case 16:  // unit swizzle is equivalent to \"none\"\n        case 0:\n            return CU_TENSOR_MAP_SWIZZLE_NONE;\n        default:\n            throw std::logic_error(\"unsupported swizzle type: \" + std::to_string(bytes));\n    }\n    return {};\n}\n\n#endif\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/transform.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/attention/quantization.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n#include \"src/turbomind/kernels/gemm/smem_copy.h\"\n#include \"src/turbomind/kernels/gemm/tiled_mma.h\"\n\nnamespace turbomind::gemm {\n\nstruct Transform_Default {\n    template<class T, int Nf, int Mf, int K, int Nd, int Md, class S>\n    __device__ static void apply(Array<T, Nf> (&frag)[K][Mf], int k, Array<T, Nd> (&data)[K][Md], S&, int div)\n    {\n        static_assert(Nf * Mf == Nd * Md);\n        static_assert(Nd % Nf == 0 && Mf % Md == 0);\n        static_assert(sizeof(frag) == sizeof(data));\n\n        // Alignment must be manually enforced for `reinterpret_cast`\n        auto& frag_k = reinterpret_cast<Array<T, Nd>(&)[Md]>(frag[k]);\n        auto& data_k = data[k];\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < std::size(frag_k); ++i) {\n            frag_k[i] = data_k[i];\n        }\n    }\n};\n\ntemplate<int StatStepS, int StatStepC>\nstruct Transform_HMMA_16816 {\n    template<class F, int Nf, int Mf, int K, class D, int Nd, int Md, class S, int Ns, int Ms, int Ks>\n    __device__ static void\n    apply(Array<F, Nf> (&frag)[K][Mf], int k, Array<D, Nd> (&data)[K][Md], Array<S, Ns> (&stat)[Ks][Ms], int div)\n    {\n        static_assert(Nf * Mf == Nd * Md);\n        static_assert(Nd % Nf == 0 && Mf % Md == 0);\n        static_assert(Nf * Mf == Ns * Ms * 4);\n\n        auto& frag_k = reinterpret_cast<Array<F, Nd>(&)[Md]>(frag[k]);\n        auto& stat_k = reinterpret_cast<Array<S, 1>(&)[Ns * Ms]>(stat[k / div]);\n        auto& data_k = data[k];\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < Md; ++m) {\n            auto tmp = ConvertKvCache<D, F>::convert(data_k[m]);\n            static_assert(Nd % 8 == 0);\n            PRAGMA_UNROLL\n            for (int i 
= 0; i < Nd; i += 8) {\n                PRAGMA_UNROLL\n                for (int s = 0; s < 2; ++s) {\n                    PRAGMA_UNROLL\n                    for (int c = 0; c < 2; ++c) {\n                        const int idx = (m * Nd + i) / 8 * 2 + s * StatStepS + c * StatStepC;\n                        dequant((Array<F, 2>&)tmp[i + s * 4 + c * 2], stat_k[idx]);\n                    }\n                }\n            }\n\n            frag_k[m] = tmp;\n        }\n    }\n\n    template<class F>\n    __device__ static void dequant(Array<F, 2>& x, Array<uint32_t, 1> s)\n    {\n        Array<F, 2>& _s = (Array<F, 2>&)s;\n        x[0]            = __hfma(x[0], _s[0], _s[1]);\n        x[1]            = __hfma(x[1], _s[0], _s[1]);\n    }\n\n    __device__ static void dequant(Array<bfloat16_t, 2>& x, Array<uint8_t, 1> s)\n    {\n        bfloat16_t s1 = __ushort_as_bfloat16((uint16_t)s[0] << 7);\n        x[0]          = __hmul(x[0], s1);\n        x[1]          = __hmul(x[1], s1);\n    }\n\n    __device__ static void dequant(Array<half_t, 2>& x, Array<uint8_t, 1> s)\n    {\n        // half_t s1 = __ushort_as_half(((uint16_t)s[0] + 15 - 127) << 10);\n        // Adjusted in `AdjustUe8m0ScaleForHalf`\n        half_t s1 = __ushort_as_half((uint16_t)s[0] << 10);\n        x[0]      = __hmul(x[0], s1);\n        x[1]      = __hmul(x[1], s1);\n    }\n\n    __device__ static void dequant(Array<bfloat16_t, 2>& x, Array<uint16_t, 1> s)\n    {\n        auto s1 = __ushort_as_bfloat16(s[0]);\n        x[0]    = __hmul(x[0], s1);\n        x[1]    = __hmul(x[1], s1);\n    }\n\n    __device__ static void dequant(Array<half, 2>& x, Array<uint16_t, 1> s)\n    {\n        auto s1 = __ushort_as_half(s[0]);\n        x[0]    = __hmul(x[0], s1);\n        x[1]    = __hmul(x[1], s1);\n    }\n};\n\n// Used by SM70 MMA\nstruct Transform_HMMA_SIMT_B {\n    template<class F, int Nf, int Mf, int K, class D, int Nd, int Md, class S, int Ns, int Ms, int Ks>\n    __device__ static void\n    apply(Array<F, Nf> 
(&frag)[K][Mf], int k, Array<D, Nd> (&data)[K][Md], Array<S, Ns> (&stat)[Ks][Ms], int div)\n    {\n        static_assert(Nf * Mf == Nd * Md);\n        static_assert(Nd % Nf == 0 && Mf % Md == 0);\n\n        auto& frag_k = reinterpret_cast<Array<F, Nd>(&)[Md]>(frag[k]);\n        auto& stat_k = reinterpret_cast<Array<S, 1>(&)[Ns * Ms]>(stat[k / div]);\n        auto& data_k = data[k];\n\n        // static_assert(Nf != Nf);\n\n        PRAGMA_UNROLL\n        for (int m = 0; m < Md; ++m) {\n            auto tmp = ConvertKvCache<D, F>::convert(data_k[m]);\n            PRAGMA_UNROLL\n            for (int i = 0; i < Nd; i += 2) {\n                dequant((Array<F, 2>&)tmp[i], stat_k[(m * Nd + i) / Nf]);\n            }\n            frag_k[m] = tmp;\n        }\n    }\n\n    template<class F>\n    __device__ static void dequant(Array<F, 2>& x, Array<uint32_t, 1> s)\n    {\n        Array<F, 2>& _s = (Array<F, 2>&)s;\n\n        x[0] = __hfma(x[0], _s[0], _s[1]);\n        x[1] = __hfma(x[1], _s[0], _s[1]);\n    }\n\n    __device__ static void dequant(Array<half_t, 2>& x, Array<uint8_t, 1> s)\n    {\n        // half_t s1 = __ushort_as_half(((uint16_t)s[0] + 15 - 127) << 10);\n        // Adjusted in `AdjustUe8m0ScaleForHalf`\n        half_t s1 = __ushort_as_half((uint16_t)s[0] << 10);\n        x[0]      = __hmul(x[0], s1);\n        x[1]      = __hmul(x[1], s1);\n    }\n\n    __device__ static void dequant(Array<half, 2>& x, Array<uint16_t, 1> s)\n    {\n        auto s1 = __ushort_as_half(s[0]);\n        x[0]    = __hmul(x[0], s1);\n        x[1]    = __hmul(x[1], s1);\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/cache_utils.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/tuner/cache_utils.h\"\n\nnamespace turbomind::gemm {\n\nCacheFlushing::CacheFlushing()\n{\n    // Query the L2 size of the current device, which is also the device `cudaMalloc` allocates on\n    int device{};\n    cudaGetDevice(&device);\n\n    cudaDeviceProp props{};\n    cudaGetDeviceProperties(&props, device);\n\n    size_ = props.l2CacheSize;\n\n    cudaMalloc(&buffer_, size_);\n}\n\nvoid CacheFlushing::flush(cudaStream_t stream)\n{\n    thread_local CacheFlushing inst{};\n    inst(stream);\n}\n\nvoid CacheFlushing::operator()(cudaStream_t stream) const\n{\n    cudaMemsetAsync(buffer_, 0, size_, stream);\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/cache_utils.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cstddef>\n#include <cstdint>\n#include <cuda_runtime.h>\n\nnamespace turbomind::gemm {\n\nclass CacheFlushing {\npublic:\n    static void flush(cudaStream_t stream = {});\n\nprivate:\n    CacheFlushing();\n    void operator()(cudaStream_t stream) const;\n\n    uint32_t* buffer_;\n    size_t    size_;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/measurer.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/tuner/cache_utils.h\"\n#include \"src/turbomind/kernels/gemm/tuner/measurer.h\"\n#include <iostream>\n\nnamespace turbomind::gemm {\n\nMeasurer::Measurer(std::unique_ptr<StoppingCriterion> stop_criterion): stop_criterion_{std::move(stop_criterion)}\n{\n    cudaEventCreate(&ev_beg_);\n    cudaEventCreate(&ev_end_);\n}\n\nMeasurer::~Measurer()\n{\n    cudaEventDestroy(ev_beg_);\n    cudaEventDestroy(ev_end_);\n    ev_beg_ = ev_end_ = {};\n}\n\nstd::vector<Measurement>\nMeasurer::Measure(const std::vector<LaunchSpec>& specs, const Launcher& launcher, cudaStream_t stream)\n{\n    std::vector<Measurement> m;\n    m.reserve(specs.size());\n    for (const auto& spec : specs) {\n        auto measure = MeasureOne(spec, launcher, stream);\n        if (measure.sample_count) {\n            m.push_back(measure);\n        }\n        /// TODO: report error\n    }\n    return m;\n}\n\nMeasurement Measurer::MeasureOne(LaunchSpec spec, const Launcher& launcher, cudaStream_t stream)\n{\n    Stats       stats{};\n    cudaError_t status = cudaSuccess;\n    while (true) {\n        float ms{};\n        std::tie(ms, status) = ColdRun(spec, launcher, stream);\n        if (status != cudaSuccess) {\n            break;\n        }\n        stats.add_sample(ms);\n        // std::cout << spec.kernel->name() << \" \" << spec.swizzle << \" \" << stats.count() << \" \" << stats.mean() << \" \"\n        //           << stats.get_variance() << \"\\n\";\n        if (stop_criterion_->should_stop(stats)) {\n            break;\n        }\n    }\n    return Measurement{\n        status,\n        stats.count(),\n        stats.mean(),\n        stats.get_variance(),\n    };\n}\n\nstd::pair<float, cudaError_t> Measurer::ColdRun(LaunchSpec spec, const Launcher& launcher, cudaStream_t stream)\n{\n    CacheFlushing::flush(stream);\n\n    
cudaEventRecord(ev_beg_, stream);\n\n    // std::cout << spec.kernel->name() << \" \" << spec.splits << \" \" << spec.swizzle << std::endl;\n\n    launcher(spec, stream);\n\n    cudaEventRecord(ev_end_, stream);\n    cudaEventSynchronize(ev_end_);\n\n    const auto status = cudaGetLastError();\n    float      ms{};\n\n    if (status == cudaSuccess) {\n        cudaEventElapsedTime(&ms, ev_beg_, ev_end_);\n    }\n    else {\n        TM_CHECK(status == cudaSuccess) << cudaGetErrorString(status);\n    }\n\n    return {ms, status};\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/measurer.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/tuner/stopping_criterion.h\"\n#include <climits>\n#include <functional>\n#include <memory>\n#include <utility>\n#include <vector>\n\nnamespace turbomind::gemm {\n\nstruct Measurement {\n    cudaError_t status;\n    int         sample_count;\n    float       mean;\n    float       variance;\n};\n\nusing Launcher = std::function<int(LaunchSpec, cudaStream_t)>;\n\nclass Measurer {\npublic:\n    Measurer(std::unique_ptr<StoppingCriterion> stop_criterion);\n\n    ~Measurer();\n\n    std::vector<Measurement>\n    Measure(const std::vector<LaunchSpec>& specs, const Launcher& launcher, cudaStream_t stream);\n\nprivate:\n    Measurement MeasureOne(LaunchSpec spec, const Launcher& launcher, cudaStream_t stream);\n\n    std::pair<float, cudaError_t> ColdRun(LaunchSpec spec, const Launcher& launcher, cudaStream_t stream);\n\nprivate:\n    cudaEvent_t                        ev_beg_;\n    cudaEvent_t                        ev_end_;\n    std::unique_ptr<StoppingCriterion> stop_criterion_;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/params.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/tuner/params.h\"\n#include \"src/turbomind/utils/parser.h\"\n#include <algorithm>\n#include <iostream>\n#include <regex>\n\nnamespace turbomind::gemm {\n\nvoid ParseTuningParams(TuningParams& params, const std::string& str)\n{\n    const auto list = ParseArgsList(str);\n\n    auto try_parse = [&](auto& value, auto name) {\n        auto it = std::find_if(list.begin(), list.end(), [&](auto a) { return a.first == name; });\n        if (it != list.end()) {\n            std::cout << name << \" \" << it->second << \"\\n\";\n            Parse(value, it->second);\n        }\n    };\n\n    try_parse(params.max_splits, \"max_splits\");\n    try_parse(params.max_waves, \"max_waves\");\n    try_parse(params.swizzle, \"swizzle\");\n    try_parse(params.top_k, \"top_k\");\n    try_parse(params.clusters, \"clusters\");\n    try_parse(params.min_iter, \"min_iter\");\n    try_parse(params.max_iter, \"max_iter\");\n    try_parse(params.max_time, \"max_time\");\n\n    if (auto it = std::find_if(list.begin(), list.end(), [&](auto a) { return a.first == \"seq\"; }); it != list.end()) {\n        params.seq = ParseTuningSequence(it->second);\n    }\n}\n\nstd::vector<int> ParseTuningSequence(const std::string& str)\n{\n    const std::regex triplet(R\"((\\d+)-(\\d+)-(\\d+))\");\n\n    std::vector<std::array<int, 3>> generators;\n\n    const auto tokens = ParseListOrTuple(str);\n\n    for (const auto& token : tokens) {\n        std::smatch match;\n        if (std::regex_match(token, match, triplet)) {\n            generators.push_back({std::stoi(match[1].str()),  //\n                                  std::stoi(match[2].str()),\n                                  std::stoi(match[3].str())});\n        }\n        else {  // must be an integer string\n            generators.push_back({std::stoi(token), 0, 0});\n        }\n    }\n\n    if (generators.size() == 1) {  // Replace sentinel of the 
default generators\n        auto fallback   = GetDefaultTuningGenerators();\n        fallback.back() = {generators.front().front(), 0, 0};\n        generators      = std::move(fallback);\n    }\n\n    return GenerateTuningSequence(generators);\n}\n\nstd::vector<int> GenerateTuningSequence(const std::vector<std::array<int, 3>>& generators)\n{\n    std::vector<int> ret;\n    if (generators.empty()) {\n        return ret;\n    }\n    const int last = generators.back().front();\n    // The last generator is a sentinel `(max_bs, 0, 0)`\n    for (int i = 0; i < (int)generators.size() - 1; ++i) {\n        auto [curr, next, step] = generators[i];\n        if (curr >= last) {\n            break;\n        }\n        if (next == 0 && step == 0) {  // single value\n            ret.push_back(curr);\n        }\n        else {  // generator\n            const int end = std::min(generators[i + 1][0], last);\n            while (curr < end) {\n                ret.push_back(curr);\n                if (curr == next) {\n                    step *= 2;\n                    next *= 2;\n                }\n                curr += step;\n            }\n        }\n    }\n    ret.push_back(last);\n    return ret;\n}\n\nstd::vector<std::array<int, 3>> GetDefaultTuningGenerators()\n{\n    /// TODO: set generators based on device\n    return {{8, 16, 8}, {16, 64, 16}, {65536}};\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/params.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <array>\n#include <string>\n#include <vector>\n\nnamespace turbomind::gemm {\n\nstruct TuningParams {\n    // Split-k params\n    int max_splits = 8;\n    int max_waves  = 10;\n\n    // Swizzling params\n    std::vector<int> swizzle{0, 3};\n\n    // Sampling params for hierarchical kernel selection\n    float top_k    = 0;\n    int   clusters = 5;\n    int   min_iter = 1;\n    int   max_iter = 10;\n    float max_time = 1.f;\n\n    std::vector<int> seq;\n};\n\n// example (unrecognized keys are ignored)\n//   max_splits=8,max_waves=16,top_k=10,swizzle=[2,3,4],clusters=5,max_iter=10,min_iter=1,max_time=10.0\nvoid ParseTuningParams(TuningParams& params, const std::string& str);\n\n// example\n//   16-16-128,256-128-1024,8192\nstd::vector<int> ParseTuningSequence(const std::string& str);\n\nstd::vector<int> GenerateTuningSequence(const std::vector<std::array<int, 3>>& generators);\n\nstd::vector<std::array<int, 3>> GetDefaultTuningGenerators();\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/sampler.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/kernel.h\"\n#include \"src/turbomind/kernels/gemm/tuner/sampler.h\"\n#include <algorithm>\n#include <iostream>\n#include <numeric>\n#include <vector>\n\nnamespace turbomind::gemm {\n\ntemplate<class Cmp>\nstatic std::vector<int> ArgSort(size_t size, const Cmp& cmp)\n{\n    std::vector<int> idxs(size);\n    std::iota(idxs.begin(), idxs.end(), 0);\n    std::stable_sort(idxs.begin(), idxs.end(), cmp);\n    return idxs;\n}\n\nstd::vector<LaunchSpec> Sampler::Run(std::vector<LaunchSpec> specs, const Launcher& launcher, cudaStream_t stream)\n{\n    std::vector<std::vector<LaunchSpec>> clusters;  // ptr into `specs`\n    if (k_clusters_) {\n        clusters = Cluster(specs, ClusteringParam{true, true});\n    }\n    else {\n        for (auto& s : specs) {\n            clusters.push_back({s});\n        }\n    }\n    // std::cout << \"k_clusters=\" << k_clusters_ << \", #specs\" << specs.size() << \", #clusters\" << clusters.size() <<\n    // \"\\n\";\n\n    std::vector<LaunchSpec> s_1;\n    for (const auto& c : clusters) {\n        s_1.push_back(c.front());\n    }\n\n    auto m_1 = measurer_.Measure(s_1, launcher, stream);\n\n    auto idxs = ArgSort(m_1.size(), [&](int i, int j) { return m_1[i].mean < m_1[j].mean; });\n\n    if (k_clusters_) {\n        const auto top_k = std::min(k_clusters_, (int)idxs.size());\n        idxs.resize(top_k);\n\n        std::vector<LaunchSpec> s_2;\n        for (const auto& idx : idxs) {\n            auto& cluster = clusters[idx];\n            // Skip cluster leader\n            for (size_t j = 1; j < cluster.size(); ++j) {\n                s_2.push_back(cluster[j]);\n            }\n        }\n\n        // std::cout << \"#s_2=\" << s_2.size() << \"\\n\";\n\n        auto m_2 = measurer_.Measure(s_2, launcher, stream);\n        // Merge measurements of the 2 runs\n        m_2.insert(m_2.end(), 
m_1.begin(), m_1.end());\n        s_2.insert(s_2.end(), s_1.begin(), s_1.end());\n        m_1.swap(m_2);\n        s_1.swap(s_2);\n    }\n\n    idxs = ArgSort(m_1.size(), [&](int i, int j) { return m_1[i].mean < m_1[j].mean; });\n\n    std::vector<LaunchSpec> ret;\n    for (const auto& i : idxs) {\n        s_1[i].measured = m_1[i].mean;\n        ret.push_back(s_1[i]);\n    }\n\n    return ret;\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/sampler.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/desc.h\"\n#include \"src/turbomind/kernels/gemm/tuner/measurer.h\"\n\n#include <vector>\n\nnamespace turbomind::gemm {\n\nclass Sampler {\npublic:\n    explicit Sampler(Measurer& measurer, int k_clusters): measurer_{measurer}, k_clusters_{k_clusters} {}\n\n    std::vector<LaunchSpec> Run(std::vector<LaunchSpec> specs, const Launcher& launcher, cudaStream_t stream);\n\nprivate:\n    Measurer& measurer_;\n    int       k_clusters_;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/stats.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <limits>\n\nnamespace turbomind::gemm {\n\n// Running mean and variance of timing samples, updated with Welford's online algorithm\nclass Stats {\npublic:\n    Stats(): count_{}, mean_{}, m2_{} {}\n\n    float mean() const noexcept\n    {\n        return mean_;\n    }\n\n    float sum() const noexcept\n    {\n        return mean_ * count_;\n    }\n\n    int count() const noexcept\n    {\n        return count_;\n    }\n\n    // Population variance (M2 / n); NaN until at least two samples have been added\n    float get_variance() const noexcept\n    {\n        return count_ < 2 ? std::numeric_limits<float>::quiet_NaN() : m2_ / count_;\n    }\n\n    void add_sample(float x) noexcept\n    {\n        ++count_;\n        float delta = x - mean_;\n        mean_ += delta / count_;\n        float delta2 = x - mean_;\n        m2_ += delta * delta2;\n    }\n\nprivate:\n    int   count_;\n    float mean_;\n    float m2_;  // sum of squared deviations from the running mean\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/stopping_criterion.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/gemm/tuner/stopping_criterion.h\"\n#include <algorithm>\n#include <limits>\n#include <memory>\n\nnamespace turbomind::gemm {\n\nnamespace stopping_criterions {\n\n// Stops once at least `min_iter` samples are taken and either `max_iter` samples or `max_ms` of total measured\n// time is reached. A non-positive `max_iter` or `max_ms` disables the corresponding limit.\nclass Optimistic: public StoppingCriterion {\npublic:\n    Optimistic(int min_iter, int max_iter, float max_ms)\n    {\n        min_iter_ = std::max(min_iter, 1);\n        max_iter_ = max_iter > 0 ? max_iter : std::numeric_limits<int>::max();\n        max_ms_   = max_ms > 0 ? max_ms : std::numeric_limits<float>::infinity();\n    }\n    bool should_stop(const Stats& stats) override\n    {\n        return stats.count() >= min_iter_ && (stats.count() >= max_iter_ || stats.sum() >= max_ms_);\n    }\n\nprivate:\n    int   min_iter_;\n    int   max_iter_;\n    float max_ms_;\n};\n\n}  // namespace stopping_criterions\n\nstd::unique_ptr<StoppingCriterion> CreateStoppingCriterion(int min_iter, int max_iter, float max_ms)\n{\n    return std::make_unique<stopping_criterions::Optimistic>(min_iter, max_iter, max_ms);\n}\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/tuner/stopping_criterion.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/tuner/stats.h\"\n#include <memory>\n\nnamespace turbomind::gemm {\n\nclass StoppingCriterion {\npublic:\n    virtual ~StoppingCriterion()                 = default;\n    virtual bool should_stop(const Stats& stats) = 0;\n};\n\nstd::unique_ptr<StoppingCriterion> CreateStoppingCriterion(int min_iter, int max_iter, float max_ms);\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/types.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#if ENABLE_BF16\n#include <cuda_bf16.h>\n#endif\n\nnamespace turbomind::gemm {\n\nenum class Order : int\n{\n    kColMajor = 0,\n    kRowMajor = 1,\n};\n\ninline constexpr Order kColMajor = Order::kColMajor;\ninline constexpr Order kRowMajor = Order::kRowMajor;\n\nconstexpr Order operator~(Order a)\n{\n    return a == kColMajor ? kRowMajor : kColMajor;\n}\n\nconstexpr const char* to_string(Order order)\n{\n    switch (order) {\n        case kColMajor:\n            return \"Col\";\n        case kRowMajor:\n            return \"Row\";\n    }\n    return \"\";\n}\n\nusing Pack = uint32_t;\n\ntypedef enum MMA_Tag\n{\n    HMMA_16816 = 0x100,  // sm80+\n    HMMA_1688  = 0x200,  // sm75\n    HMMA_884   = 0x300,  // sm70\n    HMMA_SIMT  = 0x400,  // sm75-\n} MMA_Tag;\n\ntypedef enum Op_Tag\n{\n    OPERAND_A = 0x010,\n    OPERAND_B = 0x020,\n    OPERAND_U = 0x030,\n    OPERAND_V = 0x040,\n    OPERAND_C = 0x050,\n    OPERAND_D = 0x060,\n} Op_Tag;\n\nconstexpr MMA_Tag get_mma_tag(Pack pack)\n{\n    return static_cast<MMA_Tag>(pack & 0xf00);\n}\n\nconstexpr Op_Tag get_operand_tag(Pack pack)\n{\n    return static_cast<Op_Tag>(pack & 0x0f0);\n}\n\nconstexpr int get_pack_num(Pack pack)\n{\n    return pack & 0x00f;\n}\n\nenum class Striding : int\n{\n    kFlat,     // [1111,2222,3333]\n    kRagged,   // [11,2222222,333]  [0 , 2      , 9  ]\n    kIndexed,  // [xx xxxxxxx xxx], [01, 2345678, 9ab]\n    kBlocked,  // [11][22222][333]\n};\n\ninline const char* to_string(Striding striding)\n{\n    switch (striding) {\n        case Striding::kFlat:\n            return \"f\";\n        case Striding::kRagged:\n            return \"r\";\n        case Striding::kIndexed:\n            return \"i\";\n        case Striding::kBlocked:\n            return \"b\";\n        default:\n            return 
\"unknown\";\n    }\n}\n\nenum class QuantType : int\n{\n    kNone    = 0,\n    kK       = 1,\n    kM       = 2,\n    kB       = 3,\n    kDefault = kK,\n};\n\ninline const char* to_string(QuantType q)\n{\n    switch (q) {\n        case QuantType::kNone:\n            return \"none\";\n        case QuantType::kK:\n            return \"k\";\n        case QuantType::kM:\n            return \"m\";\n        case QuantType::kB:\n            return \"b\";\n        default:\n            return \"unknown\";\n    }\n}\n\nenum class Epilogue : int\n{\n    kNone               = 0,\n    kChannelCombination = 0x1,\n    kGatedSilu          = 0x2,\n};\n\nstruct QuantDesc {\n    QuantType type;\n    int       group_size;\n\n    operator bool() const noexcept\n    {\n        return (int)type || group_size;\n    }\n};\n\ninline std::string to_string(QuantDesc desc)\n{\n    if (desc) {\n        return to_string(desc.type) + std::to_string(desc.group_size);\n    }\n    else {\n        return to_string(desc.type);\n    }\n}\n\nenum class DispatchPolicy : int\n{\n    kDefault = 0,\n    kMeasure = 1,\n    kReuse   = 2,\n    kAppend  = 3,\n};\n\nconstexpr bool operator&(const DispatchPolicy& a, const DispatchPolicy& b)\n{\n    return ((int)a & (int)b);\n}\n\nclass Kernel;\nclass Context;\n\nstruct Tape {\n    int   ctas;\n    int   max_num;\n    int   max_ctas;\n    char* buffer;\n    int4* gemm_shapes;\n    int4* tiled_shapes;\n    int4* tile_offsets;\n    int2* iter_k_ranges;\n    int*  tile_ids;\n};\n\nstruct Operation {\n    DispatchPolicy dispatch;\n    Epilogue       epilogue;\n    QuantDesc      quant_a;\n    QuantDesc      quant_b;\n    int            batch_dim;\n    // void*          reserved;\n};\n\ninline Operation transpose(Operation o)\n{\n    std::swap(o.quant_a, o.quant_b);\n    o.batch_dim = 1 - o.batch_dim;\n    return o;\n}\n\nstruct MatrixLayout {\n    DataType type;\n    Order    order;\n    int      rows;\n    int      cols;\n    int      ld;\n    Pack     pack;\n    
int      num;\n    int*     offsets;\n    int*     idxs;\n};\n\ninline std::ostream& operator<<(std::ostream& os, const MatrixLayout& x)\n{\n    os << x.type << \" \" << to_string(x.order) << \" \" << x.rows << \" \" << x.cols << \" \" << x.num << \" \" << x.ld;\n    return os;\n}\n\ninline int64_t byte_size(const MatrixLayout& m)\n{\n    return byte_size(m.type, (int64_t)m.rows * m.cols);\n}\n\ninline Striding get_mode(const MatrixLayout& m)\n{\n    if (m.idxs) {\n        return Striding::kIndexed;\n    }\n    else if (m.ld == 0 || m.offsets) {\n        return Striding::kBlocked;\n    }\n    return Striding::kFlat;\n}\n\nstruct Workspace {\n    void*  barriers;\n    size_t barriers_size;\n    void*  partials;\n    size_t partials_size;\n    void*  tensormaps;\n    size_t tensormaps_size;\n    int*   flags;\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/unpack.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/data_type.h\"\n#include <iostream>\n\nnamespace turbomind {\n\nnamespace {\n\n__device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)\n{\n    uint32_t old = *address;\n    uint32_t assumed;\n    do {\n        assumed      = old;\n        uint32_t tmp = (assumed & ~(0xfu << (index * 4u))) | (value << (index * 4u));\n        old          = atomicCAS(address, assumed, tmp);\n    } while (assumed != old);\n}\n\n__device__ uint32_t read_u4(const uint32_t* address, uint32_t index)\n{\n    return (*address >> (index * 4u)) & 0xfu;\n}\n\ntemplate<int... Ds>\n__global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)> dims)\n{\n    constexpr int N = sizeof...(Ds);\n\n    size_t count = 1;\n    PRAGMA_UNROLL\n    for (int i = 0; i < N; ++i) {\n        count *= dims[i];\n    }\n\n    constexpr int order[] = {Ds...};\n\n    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {\n\n        int indices[N]{};\n\n        PRAGMA_UNROLL\n        for (int j = N - 1, ii = i; j >= 0; --j) {\n            indices[j] = ii % dims[j];\n            ii /= dims[j];\n        }\n\n        auto data = read_u4(src + i / 8, i % 8);\n\n        int index = 0;\n\n        PRAGMA_UNROLL\n        for (int j = N - 1, stride = 1; j >= 0; --j) {\n            index += indices[order[j]] * stride;\n            stride *= dims[order[j]];\n        }\n\n        atomic_assign_u4(dst + index / 8, index % 8, data);\n    }\n}\n\n}  // namespace\n\n// col-major interleaved\nvoid unpack_awq_gemm(uint4_t* dst, const uint4_t* src, int rows, int cols, cudaStream_t st)\n{\n    Array<int, 4> shape{cols, rows / 8, 2, 4};\n    permute_u4<0, 1, 3, 2><<<512, 512, 0, st>>>((uint*)dst, (const uint*)src, shape);\n}\n\n__global__ void 
transpose_u4_kernel(uint4_t* dst, const uint4_t* src, int s, int c)\n{\n    const int idx_c = 8 * (threadIdx.x + blockIdx.x * blockDim.x);\n    const int idx_s = 8 * (threadIdx.y + blockIdx.y * blockDim.y);\n    if (idx_c >= c || idx_s >= s) {\n        return;\n    }\n    uint32_t ivec[8];\n    PRAGMA_UNROLL\n    for (int i = 0; i < 8; ++i) {\n        ivec[i] = ((const uint32_t*)src)[((idx_s + i) * c + idx_c) / 8];\n    }\n    uint32_t ovec[8]{};\n    PRAGMA_UNROLL\n    for (int i = 0; i < 8; ++i) {\n        PRAGMA_UNROLL\n        for (int j = 0; j < 8; ++j) {\n            ovec[i] |= (((ivec[j] >> (i * 4)) & 0xfu) << (j * 4));\n        }\n    }\n    PRAGMA_UNROLL\n    for (int i = 0; i < 8; ++i) {\n        ((uint32_t*)dst)[((idx_c + i) * s + idx_s) / 8] = ovec[i];\n    }\n}\n\nvoid transpose_u4(uint4_t* dst, const uint4_t* src, int s, int c, cudaStream_t st)\n{\n    if (s % 8 || c % 8) {\n        std::cerr << \"transpose_u4: invalid shape (\" << s << \",\" << c << \"), must be multiple of 8\" << std::endl;\n        return;\n    }\n    // Array<int, 2> shape{s, c};\n    // permute_u4<1, 0><<<512, 512, 0, st>>>((uint*)dst, (const uint*)src, shape);\n\n    const dim3 block(16, 16);\n    const dim3 grid((c + 15) / 16, (s + 15) / 16);\n    transpose_u4_kernel<<<grid, block, 0, st>>>(dst, src, s, c);\n}\n\n// load -> unpack -> extend_to_u8 -> manipulation -> compat_to_u4 -> store\n// load -> extend_to_u16 -> convert -> run\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gemm/utils.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/simt.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\nnamespace turbomind::gemm {\n\n__host__ __device__ constexpr Order transpose(Order order)\n{\n    return order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor;\n}\n\n__host__ __device__ constexpr MatrixLayout transpose(MatrixLayout x)\n{\n    auto tmp = x.cols;  // `std::swap` is not constexpr\n    x.cols   = x.rows;\n    x.rows   = tmp;\n    x.order  = transpose(x.order);\n    return x;\n}\n\ntemplate<Order order>\n__host__ __device__ constexpr int2 mk2cs(int m, int k)\n{\n    if constexpr (order == Order::kRowMajor) {\n        return {k, m};\n    }\n    else {\n        return {m, k};\n    }\n}\n\ntemplate<Order order>\n__host__ __device__ constexpr int2 mk2cs(int2 mk)\n{\n    return mk2cs<order>(mk.x, mk.y);\n}\n\ntemplate<Order order>\n__host__ __device__ constexpr int2 cs2mk(int c, int s)\n{\n    if constexpr (order == Order::kRowMajor) {\n        return {s, c};\n    }\n    else {\n        return {c, s};\n    }\n}\n\ntemplate<Order order>\n__host__ __device__ constexpr int2 cs2mk(int2 cs)\n{\n    return cs2mk<order>(cs.x, cs.y);\n}\n\ntemplate<Order order>\n__host__ __device__ constexpr int2 _kn2cs(int k, int n)\n{\n    if constexpr (order == Order::kColMajor) {\n        return {k, n};\n    }\n    else {\n        return {n, k};\n    }\n}\n\ntemplate<class Index>\n__host__ __device__ constexpr Index cs2idx(int2 cs, Index ld)\n{\n    return ld * cs.y + cs.x;\n}\n\ntemplate<class Index>\n__host__ __device__ constexpr Index cs2idx(int2 cs, Index ld, int s0)\n{\n    return ld * (cs.y + s0) + cs.x;\n}\n\n__host__ __device__ constexpr auto dot(int2 a, int2 b)\n{\n    return a.x * b.x + a.y * b.y;\n}\n\n__host__ __device__ constexpr auto dot(int2 a, long2 b)\n{\n    return a.x * b.x + a.y * b.y;\n}\n\ntemplate<MMA_Tag mma, Op_Tag op, int num, Order order>\nstruct PackingImpl {\n  
  __host__ __device__ static constexpr int2 apply(int2 mk)\n    {\n        return mk;\n    }\n};\n\ntemplate<Pack pack, Order order>\nstruct Packing_v2: PackingImpl<get_mma_tag(pack), get_operand_tag(pack), get_pack_num(pack), order> {\n};\n\n/// TODO: move packing utility to arch/smem_copy_xxx\n\ntemplate<int num>\nstruct PackingImpl<HMMA_16816, OPERAND_A, num, kRowMajor> {\n    __host__ __device__ static constexpr int2 apply(int2 mk)\n    {\n        return {mk.x / 16 / num, mk.y * 16 * num};\n    }\n};\n\ntemplate<int num>\nstruct PackingImpl<HMMA_16816, OPERAND_A, num, kColMajor> {\n    __host__ __device__ static constexpr int2 apply(int2 mk)\n    {\n        return {mk.x * 16, mk.y / 16};\n    }\n};\n\ntemplate<int num, Order order>\nstruct PackingImpl<HMMA_16816, OPERAND_B, num, order>: PackingImpl<HMMA_16816, OPERAND_A, num, order> {\n};\n\ntemplate<int num>\nstruct PackingImpl<HMMA_SIMT, OPERAND_A, num, kRowMajor> {\n    __host__ __device__ static constexpr int2 apply(int2 mk)\n    {\n        return {mk.x / (simt::OP_M * num), mk.y * simt::OP_M * num};\n    }\n};\n\ntemplate<int num>\nstruct PackingImpl<HMMA_SIMT, OPERAND_B, num, kRowMajor> {\n    __host__ __device__ static constexpr int2 apply(int2 mk)\n    {\n        return {mk.x / (simt::OP_N * num), mk.y * simt::OP_N * num};\n    }\n};\n\ntemplate<int num>\nstruct PackingImpl<HMMA_884, OPERAND_B, num, kRowMajor> {\n    __host__ __device__ static constexpr int2 apply(int2 mk)\n    {\n        // return {mk.x / (16 * num), mk.y * 16 * num};\n        return {mk.x / (32 * num), mk.y * 32 * num};\n    }\n};\n\n}  // namespace turbomind::gemm\n"
  },
  {
    "path": "src/turbomind/kernels/gpt_kernels.cu",
    "content": "/*\n * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <cub/cub.cuh>\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/gpt_kernels.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\nnamespace turbomind {\n\ntemplate<class T, int vec_size>\n__global__ void\nembeddingLookupKernel(T* dst, int dst_stride, const T* src, int src_stride, const int* ids, int num, int dim)\n{\n    const int ti = blockIdx.x;\n\n    const int64_t idx = ids[ti];\n\n    src += idx * src_stride;\n    dst += ti * dst_stride;\n\n    for (int di = threadIdx.x * vec_size; di < dim; di += blockDim.x * vec_size) {\n        Array<T, vec_size> vec;\n        Ldg(vec, &src[di]);\n        Store(&dst[di], vec);\n    }\n}\n\nvoid invokeEmbeddingLookup(Ref<Tensor>         out_,\n                           const Buffer_<int>& token_ids,\n                           const Tensor&       embedding_table,\n                           cudaStream_t        st)\n{\n    auto& out = out_.get();\n\n    TM_CHECK_EQ(out.shape(0), token_ids.size());\n    TM_CHECK_EQ(out.shape(1), embedding_table.shape(1));\n\n    int num, dim;\n    std::tie(num, dim) = out.shapes(0, 1);\n\n    auto invoke = [&](auto t) {\n        using T                = decltype(t);\n        constexpr int vec_size = sizeof(uint4) / sizeof(T);\n        TM_CHECK(dim % vec_size == 0) << dim << \" \" << 
vec_size;\n        const int threads = std::min(dim / vec_size, 1024);\n        const int blocks  = num;\n        TM_CHECK(out_.get());\n        TM_CHECK(token_ids);\n        TM_CHECK(embedding_table);\n        embeddingLookupKernel<T, vec_size><<<blocks, threads, 0, st>>>((T*)out.raw_data(),\n                                                                       out.stride(0),\n                                                                       (const T*)embedding_table.raw_data(),\n                                                                       embedding_table.stride(0),\n                                                                       token_ids.data(),\n                                                                       num,\n                                                                       dim);\n    };\n\n    if (byte_size(out.dtype()) == byte_size<uint16_t>()) {\n        return invoke(uint16_t{});\n    }\n    TM_CHECK(0) << \"not implemented\";\n}\n\n// TODO Add half2 implementation\ntemplate<typename T>\n__global__ void transposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2)\n{\n    int index = threadIdx.x + blockIdx.x * blockDim.x;\n    if (index < dim0 * dim1 * dim2) {\n        const int input_dim2_index = index % dim2;\n        index                      = (index - input_dim2_index) / dim2;\n        const int input_dim1_index = index % dim1;\n        index                      = (index - input_dim1_index) / dim1;\n        const int input_dim0_index = index % dim0;\n\n        out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + input_dim2_index] =\n            in[input_dim0_index * dim1 * dim2 + input_dim1_index * dim2 + input_dim2_index];\n    }\n}\n\ntemplate<typename T>\nvoid invokeTransposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream)\n{\n    dim3 block(512);\n    dim3 grid((int)(ceil(dim0 * dim1 * dim2 / 512.)));\n    transposeAxis01<<<grid, block, 0, 
stream>>>(out, in, dim0, dim1, dim2);\n}\n\ntemplate void\ninvokeTransposeAxis01(float* out, float* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);\n\ntemplate void\ninvokeTransposeAxis01(half* out, half* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);\n\ntemplate void\ninvokeTransposeAxis01(int* out, int* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);\n\ntemplate void\ninvokeTransposeAxis01(uint16_t* out, uint16_t* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);\n\ntemplate void\ninvokeTransposeAxis01(uint8_t* out, uint8_t* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);\n\n#ifdef ENABLE_BF16\ntemplate void invokeTransposeAxis01(\n    __nv_bfloat16* out, __nv_bfloat16* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);\n#endif\n\ntemplate<typename T>\n__global__ void transposeAxis01(T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1)\n{\n    // out: [dim1, dim0]\n    // in: [dim0, dim1]\n    // in_skipping_dim1: [dim1]\n\n    int index = threadIdx.x + blockIdx.x * blockDim.x;\n    if (index < dim0 * dim1) {\n        const int input_dim1_index = index % dim1;\n        index                      = (index - input_dim1_index) / dim1;\n        const int input_dim0_index = index % dim0;\n        const int in_offset        = in_skipping_dim1 == nullptr ? 
0 : in_skipping_dim1[input_dim1_index] * dim1;\n\n        out[input_dim1_index * dim0 + input_dim0_index] = in[in_offset + input_dim0_index * dim1 + input_dim1_index];\n    }\n}\n\ntemplate<typename T>\nvoid invokeTransposeAxis01(\n    T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream)\n{\n    dim3 block(512);\n    dim3 grid((int)(ceil(dim0 * dim1 / 512.)));\n    transposeAxis01<<<grid, block, 0, stream>>>(out, in, in_skipping_dim1, dim0, dim1);\n}\n\ntemplate void invokeTransposeAxis01(\n    int* out, int* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream);\n\ntemplate<int TILE_DIM, int BLOCK_ROWS, class T>\n__global__ void transpose_2d_kernel(T* __restrict__ dst, const T* __restrict__ src, int rows, int cols, bool swap_xy)\n{\n    __shared__ T smem[TILE_DIM][TILE_DIM + 1];\n\n    const int block_idx_x = swap_xy ? blockIdx.y : blockIdx.x;\n    const int block_idx_y = swap_xy ? blockIdx.x : blockIdx.y;\n\n    {\n        const int j = block_idx_x * TILE_DIM + threadIdx.x;\n        const int i = block_idx_y * TILE_DIM + threadIdx.y;\n\n#pragma unroll\n        for (int y = 0; y < TILE_DIM; y += BLOCK_ROWS) {\n            if (i + y < rows && j < cols) {\n                smem[threadIdx.y + y][threadIdx.x] = src[(i + y) * cols + j];\n            }\n        }\n    }\n\n    __syncthreads();\n\n    {\n        const int j = block_idx_y * TILE_DIM + threadIdx.x;\n        const int i = block_idx_x * TILE_DIM + threadIdx.y;\n\n#pragma unroll\n        for (int y = 0; y < TILE_DIM; y += BLOCK_ROWS) {\n            if (i + y < cols && j < rows) {\n                dst[(i + y) * rows + j] = smem[threadIdx.x][threadIdx.y + y];\n            }\n        }\n    }\n}\n\ntemplate<class T>\nvoid invokeTranspose2D_(T* dst, const T* src, int rows, int cols, cudaStream_t st)\n{\n    constexpr int TILE_DIM   = 32;  // warp size\n    constexpr int BLOCK_ROWS = 8;\n\n    const dim3 block(TILE_DIM, BLOCK_ROWS);\n\n    
dim3 grid((cols + TILE_DIM - 1) / TILE_DIM,  //\n              (rows + TILE_DIM - 1) / TILE_DIM);\n    bool swap_xy = false;\n\n    if (grid.y > 65535) {  // max dim for grid.y\n        std::swap(grid.x, grid.y);\n        swap_xy = true;\n    }\n\n    transpose_2d_kernel<TILE_DIM, BLOCK_ROWS><<<grid, block, 0, st>>>(dst, src, rows, cols, swap_xy);\n}\n\ntemplate void invokeTranspose2D_(uint32_t*, const uint32_t*, int, int, cudaStream_t);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/gpt_kernels.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <unordered_map>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\nnamespace turbomind {\n\ntemplate<typename T>\nstruct inputIdsEmbeddingLookupPosEncodingSoftPromptParam {\n    T*           from_tensor;\n    int*         output_ids;\n    int*         input_lengths;\n    const T*     embedding_table;\n    const T*     pos_table;\n    const float* prefix_soft_prompt_embedding;\n    const int*   prefix_soft_prompt_lengths;\n    int*         input_ids;\n    int          start_step;\n    int          max_input_length;\n    int          max_prefix_soft_prompt_length;\n    int          batch_size;\n    int          beam_width;\n    int          hidden_units;\n    cudaStream_t stream;\n};\n\ntemplate<typename T>\nstruct pPromptTuningParam {\n    // One pointer per batch entry; each points to the p/prompt tuning weights for that sequence\n    const T** p_prompt_tuning_batch_weights = nullptr;\n    // The start id of p_prompt_tuning token ids (based on the tokenizer)\n    // PROMPT_0 --> p_prompt_tuning_id_start; PROMPT_1 --> p_prompt_tuning_id_start + 1; ...\n    const int p_prompt_tuning_id_start = 0;\n    // Max length of the request prompt embedding\n    const int 
request_prompt_max_length = 0;\n    // Whether or not use the request prompt embeddings\n    const bool use_request_p_prompt_embedding = false;\n    // Request prompt embeddings\n    const T* request_prompt_embedding = nullptr;\n};\n\ntemplate<typename T>\nvoid invokeInputIdsEmbeddingLookupPosEncoding(T*                    from_tensor,\n                                              int*                  output_ids,\n                                              const T*              embedding_table,\n                                              const T*              pos_table,\n                                              pPromptTuningParam<T> prompt_param,\n                                              const int*            input_ids,\n                                              const int             start_step,\n                                              const int             length,\n                                              const int             max_length,\n                                              const int             batch_size,\n                                              const int             hidden_units,\n                                              cudaStream_t          stream);\n\ntemplate<typename T>\nvoid invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<T> param);\n\ntemplate<typename T>\nvoid invokeTransposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeTransposeAxis01(\n    T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeBuildDecoderAttentionMask(T*           attention_mask,\n                                     const int*   sequence_lengths,\n                                     const int*   prefix_prompt_lengths,\n                                     const int    batch_size,\n                                     
const int    max_seq_len,\n                                     const int    max_prompt_length,\n                                     cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeLookupHiddenStateOfLastToken(T*           from_tensor,\n                                        const T*     hidden_state,\n                                        const int*   input_lengths,\n                                        const int    max_input_length,\n                                        const int    batch_size,\n                                        const int    hidden_units,\n                                        cudaStream_t stream);\n\nvoid invokeTileGptPromptInputs(int*         tiled_input_ids,\n                               int*         tiled_input_lengths,\n                               int*         tiled_prompt_lengths,\n                               const int*   input_ids,\n                               const int*   input_lengths,\n                               const int*   prefix_prompt_lengths,\n                               const int    batch_size,\n                               const int    beam_width,\n                               const int    max_input_length,\n                               cudaStream_t stream);\n\nvoid invokeTileGptInputs(int*         tiled_input_ids,\n                         int*         tiled_input_lengths,\n                         const int*   input_ids,\n                         const int*   input_lengths,\n                         const int    batch_size,\n                         const int    beam_width,\n                         const int    max_input_length,\n                         cudaStream_t stream);\n\nvoid invokeFindContextDups(int*         shared_contexts,\n                           int*         batch_to_compact,\n                           int*         compact_to_batch,\n                           int*         compact_size,\n                           const int*   input_ids,\n                    
       const size_t batch_size,\n                           const size_t input_seq_len,\n                           cudaStream_t stream = 0);\n\ntemplate<typename T>\nvoid invokeCompactInputs(T*           compact_input,\n                         T*           compact_attention_mask,\n                         int*         compact_input_lengths,\n                         const T*     decoder_input,\n                         const T*     decoder_mask,\n                         const int*   input_lengths,\n                         const int*   compact_idx,\n                         size_t       compact_size,\n                         size_t       seq_len,\n                         size_t       hidden_dimension,\n                         cudaStream_t stream = 0);\n\ntemplate<typename T>\nvoid invokeUnCompactOutputs(T*           uncompact_buffer,\n                            const T*     compact_buffer,\n                            const int*   batch_to_compact_idx,\n                            size_t       batch_size,\n                            size_t       buffer_stride,\n                            cudaStream_t stream = 0);\n\ntemplate<typename T>\nvoid invokeUnCompactCaches(T*           uncompact_k_cache,\n                           T*           uncompact_v_cache,\n                           const T*     compact_k_cache,\n                           const T*     compact_v_cache,\n                           const int*   batch_to_compact_idx,\n                           size_t       batch_size,\n                           size_t       num_heads,\n                           size_t       max_seq_len,\n                           size_t       seq_len,\n                           size_t       size_per_head,\n                           size_t       local_batch_size,\n                           size_t       ite,\n                           cudaStream_t stream = 0);\n\nvoid invokeUpdatePaddingCount(int*         total_padding_count,\n                              const int*   
input_lengths,\n                              const int*   tiled_prompt_lengths,\n                              size_t       max_input_length,\n                              size_t       max_prompt_length,\n                              size_t       batch_size,\n                              size_t       beam_width,\n                              cudaStream_t stream = 0);\n\ninline void invokeUpdatePaddingCount(int*         total_padding_count,\n                                     const int*   input_lengths,\n                                     size_t       max_input_length,\n                                     size_t       batch_size,\n                                     size_t       beam_width,\n                                     cudaStream_t stream = 0)\n{\n    invokeUpdatePaddingCount(\n        total_padding_count, input_lengths, (const int*)nullptr, max_input_length, 0, batch_size, beam_width, stream);\n}\n\nvoid invokeMaskPaddingTokens(bool*        masked_tokens,\n                             const int*   input_lengths,\n                             const int*   tiled_prefix_prompt_lengths,\n                             const size_t memory_len,\n                             const size_t max_input_length,\n                             const size_t initial_step,\n                             size_t       batch_size,\n                             size_t       beam_width,\n                             cudaStream_t stream = 0);\n\ninline void invokeMaskPaddingTokens(bool*        masked_tokens,\n                                    const int*   input_lengths,\n                                    const size_t memory_len,\n                                    const size_t max_input_length,\n                                    const size_t initial_step,\n                                    size_t       batch_size,\n                                    size_t       beam_width,\n                                    cudaStream_t stream = 0)\n{\n    
invokeMaskPaddingTokens(masked_tokens,\n                            input_lengths,\n                            (const int*)nullptr,\n                            memory_len,\n                            max_input_length,\n                            initial_step,\n                            batch_size,\n                            beam_width,\n                            stream);\n}\n\ntemplate<typename T>\nvoid invokeSumLengthDimension(float*       out_buf,\n                              const T*     in_buf,\n                              const size_t batch_size,\n                              const size_t input_length,\n                              const size_t hidden_dim,\n                              cudaStream_t stream = 0);\n\ntemplate<class T>\nvoid invokeTranspose2D_(T* dst, const T* src, int rows, int cols, cudaStream_t st);\n\ntemplate<class T>\nvoid invokeTranspose2D(T* dst, const T* src, int rows, int cols, cudaStream_t st)\n{\n    if constexpr (sizeof(T) == 4) {\n        // FT_CHECK(0);\n        invokeTranspose2D_((uint32_t*)dst, (const uint32_t*)src, rows, cols, st);\n    }\n    else {\n        FT_CHECK(0);\n    }\n}\n\nvoid invokeEmbeddingLookup(Ref<Tensor>         out_,\n                           const Buffer_<int>& token_ids,\n                           const Tensor&       embedding_table,\n                           cudaStream_t        st);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/logprob_kernels.cu",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <assert.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#ifndef CUDART_VERSION\n#error CUDART_VERSION Undefined!\n#elif (CUDART_VERSION >= 11000)\n#include <cub/cub.cuh>\n#else\n#include \"3rdparty/cub/cub.cuh\"\n#endif\n\n#include \"src/turbomind/kernels/logprob_kernels.h\"\n#include \"src/turbomind/kernels/reduce_kernel_utils.cuh\"\n#include \"src/turbomind/macro.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind {\n\ntemplate<typename T>\n__global__ void log_probs_kernel(float*       log_probs,\n                                 const T*     logits,\n                                 const int*   ids,\n                                 const int*   lengths,\n                                 const size_t max_input_length,\n                                 const size_t batch_size,\n                                 const size_t vocab_size,\n                                 const size_t vocab_size_padded,\n                                 bool         batch_first)\n{\n    // Calculate the log probability from logits.\n    //   log_probs[t, :] = log(softmax(logits))[ids[t + 1, :]]\n    //\n    // log_probs: [max_length - 1, batch_size] or [batch_size, max_length -1],\n    //     log probabilities of each token.\n    // logits: [max_length, batch_size, vocab_size_padded] or 
[batch_size, max_length, vocab_size_padded]\n    // lengths: [batch_size], sequence lengths\n    // ids: [max_length, batch_size] or [batch_size, max_length], token ids.\n    // batch_size: [1], batch size. In case of beam > 1, batch x beam.\n    // vocab_size: [1], vocab size.\n    // vocab_size_padded: [1], padded vocab size.\n\n    const bool IS_FP16   = std::is_same<T, half>::value;\n    const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;\n\n    int tidx = threadIdx.x;                            // vocab dim\n    int bidx = batch_first ? blockIdx.x : blockIdx.y;  // batch dim\n    int step = batch_first ? blockIdx.y : blockIdx.x;  // step dim\n\n    __shared__ float s_max_logit;\n\n    if (bidx < batch_size && step < lengths[bidx] - 1) {\n        // Reposition logits to the data of the current batch and step.\n        int step_offset  = batch_first ? step * vocab_size_padded : step * batch_size * vocab_size_padded;\n        int batch_offset = batch_first ? bidx * max_input_length * vocab_size_padded : bidx * vocab_size_padded;\n        logits += step_offset + batch_offset;\n\n        // Find max(logits).\n        float local_max = -MAX_T_VAL;\n        float val       = -MAX_T_VAL;\n        for (int i = tidx; i < vocab_size; i += blockDim.x) {\n            val       = static_cast<float>(logits[i]);\n            local_max = fmax(local_max, val);\n        }\n\n        float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);\n        if (tidx == 0) {\n            s_max_logit = max_val;\n        }\n        __syncthreads();\n\n        // Calculate the denominator: sum_i exp(logits[i])\n        float local_sum_exp = 0.0f;\n        for (int i = tidx; i < vocab_size; i += blockDim.x) {\n            val = __expf(static_cast<float>(logits[i]) - s_max_logit);\n            local_sum_exp += val;\n        }\n\n        float sum_exp = blockDim.x <= 32 ? 
warpReduceSum(local_sum_exp) : blockReduceSum<float>(local_sum_exp);\n        if (tidx == 0) {\n            int idx = batch_first ? step + bidx * (max_input_length - 1) : step * batch_size + bidx;\n            // log_probs[step, ...] is the log probability of the token at step + 1.\n            int token_idx  = batch_first ? step + 1 + bidx * max_input_length : (step + 1) * batch_size + bidx;\n            log_probs[idx] = static_cast<float>(logits[ids[token_idx]]) - s_max_logit - __logf(sum_exp + 1e-9f);\n        }\n    }\n}\n\n__global__ void accumulate_log_probs(float*       cum_log_probs,\n                                     const float* log_probs,\n                                     const int*   lengths,\n                                     const size_t max_input_length,\n                                     const size_t batch_size,\n                                     const bool   batch_first)\n{\n    // Accumulate the log probability along the sequence dimension.\n    //   cum_log_probs[j] = sum_i log(softmax(logits))[ids[i,j]]\n    //\n    // cum_log_probs: [batch_size], cumulative log probability\n    // log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],\n    //   log probability of each token\n    // lengths: [batch_size], sequence lengths\n    // batch_size: [1], batch size. In case of beam > 1, batch x beam.\n\n    int bidx = blockIdx.x;   // batch dim\n    int tidx = threadIdx.x;  // step dim\n\n    if (bidx < batch_size) {\n        int length = lengths[bidx];\n        // Reposition log_probs to the data of the current batch.\n        log_probs += batch_first ? bidx * (max_input_length - 1) : bidx;\n        int   stride      = batch_first ? 1 : batch_size;  // stride along the seq dim.\n        float local_accum = 0.0f;\n        for (int step = tidx; step < length - 1; step += blockDim.x) {\n            local_accum += static_cast<float>(log_probs[step * stride]);\n        }\n        float accum = blockDim.x <= 32 ? 
warpReduceSum(local_accum) : blockReduceSum<float>(local_accum);\n        if (tidx == 0) {\n            cum_log_probs[bidx] = accum;\n        }\n    }\n}\n\ntemplate<typename T>\nvoid invokeLogProbFromLogits(float*       cum_log_probs,\n                             const T*     logits,\n                             const int*   input_ids,\n                             const int*   input_lengths,\n                             const size_t max_input_length,\n                             const size_t batch_size,\n                             const size_t vocab_size,\n                             const size_t vocab_size_padded,\n                             void*        workspace,\n                             const size_t workspace_size,\n                             cudaStream_t stream,\n                             const bool   batch_first)\n{\n    // A batched version of log prob computation.\n    //\n    // cum_log_probs: [batch_size]\n    // logits: [max_input_length, batch_size, vocab_size_padded] or [batch_size, max_input_length, vocab_size_padded]\n    // input_ids: [max_input_length, batch_size] or [batch_size, max_input_length]\n    // input_lengths: [batch_size]\n    // workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.\n\n    TM_LOG_DEBUG(__PRETTY_FUNCTION__);\n    // block_size should be a multiple of 32 to use warpReduceMax.\n    const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;\n    assert(block_size % 32 == 0);\n    assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);\n    assert(vocab_size <= vocab_size_padded);\n\n    float* log_probs = reinterpret_cast<float*>(workspace);\n    int    gx        = batch_first ? batch_size : max_input_length - 1;\n    int    gy        = batch_first ? 
max_input_length - 1 : batch_size;\n    dim3   grid(gx, gy);\n    log_probs_kernel<T><<<grid, block_size, 0, stream>>>(log_probs,\n                                                         logits,\n                                                         input_ids,\n                                                         input_lengths,\n                                                         max_input_length,\n                                                         batch_size,\n                                                         vocab_size,\n                                                         vocab_size_padded,\n                                                         batch_first);\n    accumulate_log_probs<<<batch_size, block_size, 0, stream>>>(\n        cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);\n}\n\ntemplate void invokeLogProbFromLogits(float*       cum_log_probs,\n                                      const float* logits,\n                                      const int*   input_ids,\n                                      const int*   input_lengths,\n                                      const size_t max_input_length,\n                                      const size_t batch_size,\n                                      const size_t vocab_size,\n                                      const size_t vocab_size_padded,\n                                      void*        workspace,\n                                      const size_t workspace_size,\n                                      cudaStream_t stream,\n                                      const bool   batch_first);\n\ntemplate void invokeLogProbFromLogits(float*       cum_log_probs,\n                                      const half*  logits,\n                                      const int*   input_ids,\n                                      const int*   input_lengths,\n                                      const size_t max_input_length,\n                   
                   const size_t batch_size,\n                                      const size_t vocab_size,\n                                      const size_t vocab_size_padded,\n                                      void*        workspace,\n                                      const size_t workspace_size,\n                                      cudaStream_t stream,\n                                      const bool   batch_first);\n}  // end of namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/logprob_kernels.h",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cstddef>\n#include <cuda_runtime.h>\n\nnamespace turbomind {\n\ntemplate<typename T>\nvoid invokeLogProbFromLogits(float*       cum_log_probs,\n                             const T*     logits,\n                             const int*   input_ids,\n                             const int*   input_lengths,\n                             const size_t max_input_length,\n                             const size_t batch_size,\n                             const size_t vocab_size,\n                             const size_t vocab_size_padded,\n                             void*        workspace,\n                             const size_t workspace_size,\n                             cudaStream_t stream,\n                             const bool   batch_first = false);\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/norm/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\nadd_library(rms_norm rms_norm.cu)\nset_property(TARGET rms_norm PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET rms_norm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\n"
  },
  {
    "path": "src/turbomind/kernels/norm/rms_norm.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <stdexcept>\n\n#include \"cub/block/block_reduce.cuh\"\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/core/meta.h\"\n\n#include \"src/turbomind/kernels/norm/rms_norm.h\"\n\nnamespace turbomind {\n\nnamespace kernel {\n\ntemplate<class T, class Accum, int block_dim, int vec_size>\n__global__ void RMSNorm(T*       dst,\n                        int      dst_ld,\n                        const T* src,\n                        int      src_ld,\n                        const T* __restrict__ weights,\n                        int   dims,\n                        int   num,\n                        float eps,\n                        float inv_dims)\n{\n    const int ti = blockIdx.x;\n    const int di = threadIdx.x * vec_size;\n\n    if (ti >= num) {\n        return;\n    }\n\n    src += src_ld * ti;\n\n    Array<Accum, vec_size> accum{};\n    Array<T, vec_size>     vec;\n\n    for (int i = di; i < dims; i += block_dim * vec_size) {\n        Load(vec, &src[i]);\n        Array<Accum, vec_size> tmp = cast<Accum>(vec);\n        using namespace ops;\n        accum = accum + tmp * tmp;\n    }\n\n    float sum{};\n    PRAGMA_UNROLL\n    for (int i = 0; i < vec_size; ++i) {\n        sum += accum[i];\n    }\n\n    using BlockReduce = cub::BlockReduce<Accum, block_dim>;\n    __shared__ typename BlockReduce::TempStorage temp_storage;\n\n    sum = BlockReduce{temp_storage}.Sum(sum);\n\n    __shared__ float shared_sum;\n\n    if (threadIdx.x == 0) {\n        shared_sum = rsqrtf(sum * inv_dims + eps);\n    }\n\n    __syncthreads();\n\n    sum = shared_sum;\n\n    dst += dst_ld * ti;\n\n    Array<T, vec_size> sv;\n    for (int i = di; i < dims; i += block_dim * vec_size) {\n        Load(vec, &src[i]);\n        Ldg(sv, 
&weights[i]);\n        PRAGMA_UNROLL\n        for (int c = 0; c < vec_size; ++c) {\n            vec[c] = (T)((float)vec[c] * sum) * sv[c];\n            // vec[c] = (T)((float)vec[c] * sum * (float)sv[c]);\n        }\n        Store(&dst[i], vec);\n    }\n}\n\n}  // namespace kernel\n\nvoid invokeRMSNorm(Tensor& out, const Tensor& x, const Tensor& w, float eps, cudaStream_t st)\n{\n    if (x.size() == 0) {\n        return;\n    }\n\n    TM_CHECK(x.ndim() == 2);\n    TM_CHECK(out.shape() == x.shape());\n    TM_CHECK(out.dtype() == x.dtype());\n    TM_CHECK(w.dtype() == x.dtype() && w.shape(-1) == x.shape(-1));\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n\n        const auto [num, dim] = x.shapes(0, 1);\n\n        constexpr int vec_size = 16 / sizeof(T);\n\n        constexpr int threads = 512;\n        const int     blocks  = num;\n\n        kernel::RMSNorm<T, float, threads, vec_size><<<blocks, threads, 0, st>>>((T*)out.raw_data(),  //\n                                                                                 out.stride(0),\n                                                                                 (const T*)x.raw_data(),\n                                                                                 x.stride(0),\n                                                                                 (const T*)w.raw_data(),\n                                                                                 dim,\n                                                                                 num,\n                                                                                 eps,\n                                                                                 1.f / dim);\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(x.dtype(), invoke);\n}\n\nnamespace kernel {\n\ntemplate<class T, class A, int vec_size, int max_dim>\n__global__ void RMSNormQK(T*       data,  //\n                          int      ld,\n                          const T* 
weight,\n                          int      dim,\n                          int      n,\n                          int      token_num,\n                          float    eps,\n                          float    inv_dim)\n{\n    static_assert((max_dim & (max_dim - 1)) == 0);\n\n    constexpr int thr_per_qk = max_dim / vec_size;\n\n    const int bi = (threadIdx.x + blockIdx.x * blockDim.x) / thr_per_qk;\n    const int di = threadIdx.x % thr_per_qk * vec_size;\n    const int ti = bi / n;\n    const int hi = bi % n;\n\n    if (bi >= token_num * n) {\n        return;\n    }\n\n    data += ti * ld + hi * dim;\n\n    Array<T, vec_size> vec{};\n    if (di < dim) {\n        Load(vec, &data[di]);\n    }\n\n    using namespace ops;\n    auto acc = cast<A>(vec);\n    acc      = acc * acc;\n\n    float sum{};\n    PRAGMA_UNROLL\n    for (int i = 0; i < vec_size; ++i) {\n        sum += acc[i];\n    }\n\n    PRAGMA_UNROLL\n    for (int mask = thr_per_qk / 2; mask >= 1; mask /= 2) {\n        sum += __shfl_xor_sync((uint32_t)-1, sum, mask);\n    }\n\n    sum = rsqrtf(sum * inv_dim + eps);\n\n    Array<T, vec_size> w;\n    if (di < dim) {\n        Ldg(w, &weight[di]);\n        PRAGMA_UNROLL\n        for (int i = 0; i < vec_size; ++i) {\n            vec[i] = (T)((float)vec[i] * sum) * w[i];\n        }\n        Store(&data[di], vec);\n    }\n}\n\n}  // namespace kernel\n\nvoid invokeQkRMSNorm(void*        data,\n                     int          ld,\n                     const void*  weight,\n                     DataType     dtype,\n                     int          head_dim,\n                     int          n,\n                     int          token_num,\n                     float        eps,\n                     cudaStream_t stream)\n{\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n\n        auto launch = [&](auto max_dim_c) {\n            constexpr int kMaxDim = std::decay_t<decltype(max_dim_c)>::value;\n            TM_CHECK_LE(head_dim, kMaxDim);\n\n    
        constexpr int vec_size   = sizeof(uint4) / sizeof(T);\n            constexpr int thr_per_qk = kMaxDim / vec_size;\n\n            FT_CHECK(head_dim % vec_size == 0);\n\n            const int threads   = thr_per_qk * n * (int64_t)token_num;\n            const int block_dim = 512;\n            const int grid_dim  = cdiv(threads, block_dim);\n\n            kernel::RMSNormQK<T, float, vec_size, kMaxDim><<<grid_dim, block_dim, 0, stream>>>(\n                (T*)data, ld, (const T*)weight, head_dim, n, token_num, eps, 1.f / head_dim);\n        };\n\n        if (head_dim <= 128) {\n            launch(constant<128>{});\n        }\n        else {\n            launch(constant<256>{});\n        }\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(dtype, invoke);\n}\n\nvoid invokeRMSNormQK(Tensor& x, const Tensor& w, float eps, cudaStream_t st)\n{\n    TM_CHECK(x.ndim() == 3);\n\n    int token_num, head_num, head_dim;\n    std::tie(token_num, head_num, head_dim) = x.shapes(0, 1, 2);\n\n    TM_CHECK(x.stride(1) == head_dim);\n\n    auto data   = x.raw_data();\n    auto stride = x.stride(0);\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n\n        auto launch = [&](auto max_dim_c) {\n            constexpr int kMaxDim = std::decay_t<decltype(max_dim_c)>::value;\n            TM_CHECK_LE(head_dim, kMaxDim);\n\n            constexpr int vec_size   = sizeof(uint4) / sizeof(T);\n            constexpr int thr_per_qk = kMaxDim / vec_size;\n\n            TM_CHECK(head_dim % vec_size == 0);\n\n            const int threads   = token_num * head_num * thr_per_qk;\n            const int block_dim = 512;\n            const int grid_dim  = cdiv(threads, block_dim);\n\n            kernel::RMSNormQK<T, float, vec_size, kMaxDim><<<grid_dim, block_dim, 0, st>>>(\n                (T*)data, stride, (const T*)w.raw_data(), head_dim, head_num, token_num, eps, 1.f / head_dim);\n        };\n\n        if (head_dim <= 128) {\n            launch(constant<128>{});\n        }\n        else 
{\n            launch(constant<256>{});\n        }\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(x.dtype(), invoke);\n}\n\n// r' <- r + (h + b)\n// h' <- norm(r') * w\ntemplate<class T, class Tacc, int block_dim, int vec_size>\n__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual,\n                                          T* __restrict__ hidden_states,\n                                          const T* __restrict__ weights,\n                                          const T* __restrict__ bias,\n                                          int   dims,\n                                          int   num,\n                                          float eps,\n                                          float inv_dims)\n{\n    const int ti = blockIdx.x;\n    const int di = threadIdx.x * vec_size;\n\n    if (ti >= num) {\n        return;\n    }\n\n    residual += dims * ti;\n    hidden_states += dims * ti;\n\n    Array<Tacc, vec_size> accum{};\n\n    Array<T, vec_size> r_vec;\n    Array<T, vec_size> h_vec;\n    Array<T, vec_size> b_vec;\n\n    for (int i = di; i < dims; i += block_dim * vec_size) {\n        Load(r_vec, &residual[i]);\n        Load(h_vec, &hidden_states[i]);\n\n        using namespace ops;\n        r_vec = r_vec + h_vec;\n\n        if (bias) {\n            Ldg(b_vec, &bias[i]);\n            r_vec = r_vec + b_vec;\n        }\n\n        Store(&residual[i], r_vec);\n\n        Array<Tacc, vec_size> tmp = cast<Tacc>(r_vec);\n\n        accum = accum + tmp * tmp;\n    }\n\n    float sum{};\n    PRAGMA_UNROLL\n    for (int i = 0; i < vec_size; ++i) {\n        sum += accum[i];\n    }\n\n    using BlockReduce = cub::BlockReduce<Tacc, block_dim>;\n    __shared__ typename BlockReduce::TempStorage temp_storage;\n\n    sum = BlockReduce{temp_storage}.Sum(sum);\n\n    __shared__ float shared_sum;\n\n    if (threadIdx.x == 0) {\n        shared_sum = rsqrtf(sum * inv_dims + eps);\n    }\n\n    __syncthreads();\n\n    sum = shared_sum;\n\n    Array<T, vec_size> 
w_vec;\n    for (int i = di; i < dims; i += block_dim * vec_size) {\n        Load(r_vec, &residual[i]);\n        Ldg(w_vec, &weights[i]);\n        PRAGMA_UNROLL\n        for (int c = 0; c < vec_size; ++c) {\n            r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c];\n            // r_vec[c] = (T)((float)r_vec[c] * sum * (float)w_vec[c]);\n        }\n        Store(&hidden_states[i], r_vec);\n    }\n}\n\ntemplate<class T>\nvoid invokeBiasResidualRMSNorm(\n    T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st)\n{\n    constexpr int vec_size = 16 / sizeof(T);\n    constexpr int threads  = 512;\n    const int     blocks   = num;\n\n    BiasResidualRMSNormKernel<T, float, threads, vec_size><<<blocks, threads, 0, st>>>(residual,  //\n                                                                                       hidden_states,\n                                                                                       weights,\n                                                                                       bias,\n                                                                                       dims,\n                                                                                       num,\n                                                                                       eps,\n                                                                                       1.f / dims);\n}\n\ntemplate void invokeBiasResidualRMSNorm(half*        residual,\n                                        half*        hidden_states,\n                                        const half*  weights,\n                                        const half*  bias,\n                                        int          dims,\n                                        int          num,\n                                        float        eps,\n                                        cudaStream_t st);\n\n#if 
ENABLE_BF16\ntemplate void invokeBiasResidualRMSNorm(nv_bfloat16*       residual,\n                                        nv_bfloat16*       hidden_states,\n                                        const nv_bfloat16* weights,\n                                        const nv_bfloat16* bias,\n                                        int                dims,\n                                        int                num,\n                                        float              eps,\n                                        cudaStream_t       st);\n#endif\n\nvoid invokeResidualBiasRMSNorm(void*        hidden_states,\n                               void*        residual,\n                               const void*  weights,\n                               const void*  bias,\n                               DataType     dtype,\n                               int          dims,\n                               int          num,\n                               float        eps,\n                               cudaStream_t st)\n{\n    if (num == 0) {\n        return;\n    }\n    auto invoke = [&](auto t) {\n        using T                = decltype(t);\n        constexpr int vec_size = sizeof(uint4) / sizeof(T);\n        constexpr int threads  = 512;\n        const int     blocks   = num;\n        BiasResidualRMSNormKernel<T, float, threads, vec_size><<<blocks, threads, 0, st>>>((T*)residual,  //\n                                                                                           (T*)hidden_states,\n                                                                                           (const T*)weights,\n                                                                                           (const T*)bias,\n                                                                                           dims,\n                                                                                           num,\n                                                              
                             eps,\n                                                                                           1.f / dims);\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(dtype, invoke);\n}\n\ntemplate<class T, class B, int vec_size>\n__global__ void biasKernel(T* data, const B* bias, int num, int dim)\n{\n    int ti = blockIdx.x;\n    int di = threadIdx.x * vec_size;\n\n    Array<B, vec_size> b;\n    Ldg(b, bias + di);\n\n    Array<T, vec_size> x;\n    Load(x, data + ti * dim + di);\n    using namespace ops;\n    x = x + cast<T>(b);\n    Store(data + ti * dim + di, x);\n}\n\nvoid ApplyBias(Tensor& data, const Tensor& bias, cudaStream_t st)\n{\n    if (!bias) {\n        return;\n    }\n\n    const int num = data.shape(0);\n    const int dim = data.shape(1);\n\n    TM_CHECK_EQ(dim, bias.shape(-1));\n\n    auto invoke0 = [&](auto t) {\n        using T      = decltype(t);\n        auto invoke1 = [&](auto b) {\n            using B                = decltype(b);\n            constexpr int vec_size = sizeof(uint4) / std::max(sizeof(T), sizeof(B));\n            TM_CHECK(dim % vec_size == 0);\n            const int blocks  = num;\n            const int threads = dim / vec_size;\n            TM_CHECK_LE(threads, 1024);\n            biasKernel<T, B, vec_size><<<blocks, threads, 0, st>>>(data.data<T>(),  //\n                                                                   bias.data<B>(),\n                                                                   num,\n                                                                   dim);\n        };\n        if constexpr (data_type_v<T> == kFloat) {\n            TM_DISPATCH_PRIMARY_DTYPES(bias.dtype(), invoke1);\n        }\n        else {  // skip mixing half and bf16\n            invoke1(t);\n        }\n    };\n    TM_DISPATCH_DTYPES(data.dtype(), invoke0, float, half, nv_bfloat16);\n}\n\ntemplate<class T, int vec_size>\n__global__ void biasKernel(T* data, const T* bias, const int* offsets, int num, int dim, int groups, 
float scale)\n{\n    int ti = blockIdx.x;\n    int di = threadIdx.x * vec_size;\n\n    __shared__ int s_idx;\n\n    if (int tid = threadIdx.x; tid < groups) {\n        int b = __ldg(&offsets[tid]);\n        int e = __ldg(&offsets[tid + 1]);\n        if (b <= ti && ti < e) {\n            s_idx = tid;\n        }\n    }\n\n    data += ti * dim;\n\n    __syncthreads();\n\n    bias += s_idx * dim;\n\n    if (di >= dim) {\n        return;\n    }\n\n    Array<T, vec_size> b;\n    Ldg(b, bias + di);\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < vec_size; ++i) {\n        b[i] = (T)((float)b[i] * scale);\n    }\n\n    Array<T, vec_size> x;\n    Load(x, data + di);\n\n    using namespace ops;\n    x = x + b;\n\n    Store(data + di, x);\n}\n\nvoid ApplyBias(Tensor& data, const Tensor& bias, const Buffer_<int>& offsets, float scale, cudaStream_t st)\n{\n    if (!bias) {\n        return;\n    }\n\n    const int num    = data.shape(0);\n    const int dim    = data.shape(1);\n    const int groups = offsets.size() - 1;\n\n    TM_CHECK_EQ(dim, bias.shape(-1));\n\n    // std::cout << data << \" \" << bias << \" \" << offsets << \"\\n\";\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n\n        constexpr int vec_size = sizeof(uint4) / sizeof(T);\n        TM_CHECK(dim % vec_size == 0);\n\n        const int blocks  = num;\n        const int threads = std::max(dim / vec_size, groups);\n\n        TM_CHECK_LE(threads, 1024);\n\n        biasKernel<T, vec_size><<<blocks, threads, 0, st>>>(data.data<T>(),  //\n                                                            bias.data<T>(),\n                                                            offsets.data(),\n                                                            num,\n                                                            dim,\n                                                            offsets.size() - 1,\n                                                            scale);\n    };\n\n    
TM_DISPATCH_PRIMARY_DTYPES(data.dtype(), invoke);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/norm/rms_norm.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nvoid invokeRMSNorm(Tensor& out, const Tensor& x, const Tensor& w, float eps, cudaStream_t st);\n\nvoid invokeRMSNormQK(Tensor& x, const Tensor& w, float eps, cudaStream_t st);\n\ntemplate<class T>\nvoid invokeBiasResidualRMSNorm(\n    T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st);\n\nvoid invokeResidualBiasRMSNorm(void*        hidden_states,\n                               void*        residual,\n                               const void*  weights,\n                               const void*  bias,\n                               DataType     dtype,\n                               int          dims,\n                               int          num,\n                               float        eps,\n                               cudaStream_t st);\n\nvoid ApplyBias(Tensor& x, const Tensor& bias, const Buffer_<int>& offsets, float scale, cudaStream_t st);\n\nvoid ApplyBias(Tensor& x, const Tensor& bias, cudaStream_t st);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/penalty_types.h",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <string>\n#include <unordered_map>\n\n#include \"src/turbomind/utils/string_utils.h\"\n\nnamespace turbomind {\n\nenum class RepetitionPenaltyType\n{\n    Additive,        // the presence penalty\n    Multiplicative,  // the repetition penalty\n    None             // No repetition penalty.\n};\n\ninline float getDefaultPenaltyValue(RepetitionPenaltyType penalty_type)\n{\n    switch (penalty_type) {\n        case RepetitionPenaltyType::Additive:\n            return 0.0f;\n        case RepetitionPenaltyType::Multiplicative:\n            return 1.0f;\n        default:\n            break;\n    }\n    return 0.0f;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/quantization.cu",
    "content": "\n\n#include <limits>\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_fp8.h>\n\n#include <cuda_runtime.h>\n\n#include <cub/block/block_reduce.cuh>\n#include <type_traits>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/floating_point.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\n#include \"src/turbomind/kernels/quantization.cuh\"\n#include \"src/turbomind/kernels/quantization.h\"\n\n#include \"src/turbomind/kernels/attention/quantization.h\"\n\nnamespace turbomind {\n\ntemplate<int vec_size, int group_size, class Tout, class Tscale, class T>\n__global__ void quant_symm_row(\n    Tout* out, int out_ld, Tscale* scales, int scales_ld, const T* src, int src_ld, int num, int dim, Tscale qmax)\n{\n#if TURBOMIND_ARCH_SM90\n    static_assert(group_size % vec_size == 0);\n    constexpr int threads = group_size / vec_size;\n    const int     dim1    = round_up(dim, WARP_SIZE * vec_size);\n    for (int ti = blockIdx.x; ti < num; ti += gridDim.x) {\n        for (int di = threadIdx.x * vec_size; di < dim1; di += blockDim.x * vec_size) {\n            Array<T, vec_size> vec{};\n            if (di < dim) {\n                Ldg(vec, src + ti * src_ld + di);\n            }\n            auto         absmax    = fmaxf(static_cast<Tscale>(find_absmax<threads>(vec)), 1e-8f);\n            const Tscale scale     = absmax / qmax;\n            const Tscale inv_scale = qmax / absmax;\n            if (threadIdx.x % threads == 0 && di < dim) {\n                // column-major\n                scales[(di / group_size) * scales_ld + ti] = scale;\n            }\n            Array<Tout, vec_size> tmp;\n            PRAGMA_UNROLL\n            for (int c = 0; c < vec_size; ++c) {\n                tmp[c] = 
Tout(static_cast<Tscale>(vec[c]) * inv_scale);\n            }\n            if (di < dim) {\n                Store(out + ti * out_ld + di, tmp);\n            }\n        }\n    }\n#endif\n}\n\nvoid QuantizeSymm(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st)\n{\n    TM_CHECK_EQ(src.ndim(), 2);\n    TM_CHECK_EQ(src.stride(1), 1);  // row-major\n\n    const auto [num, dim] = src.shapes(0, 1);\n\n    using T      = bfloat16_t;\n    using Tout   = fp8_e4m3_t;\n    using Tscale = float;\n\n    constexpr int group_size = 128;\n    constexpr int vec_size   = 8;\n\n    constexpr int alignment = 16 / sizeof(Tscale);\n\n    if (!out) {\n        out = Tensor_<Tout>{src.shape(), kDEVICE};\n    }\n    else {\n        TM_CHECK(out.shape() == src.shape());\n    }\n\n    const int aligned_num = round_up<int>(num, alignment);\n\n    const int s_dim = cdiv<ssize_t>(dim, group_size);\n\n    if (!scale) {\n        scale = Tensor_<Tscale>({{s_dim, num}, {aligned_num, 1}}, kDEVICE);\n    }\n    else {\n        TM_CHECK(std::make_tuple(s_dim, num) == scale.shapes(0, 1));\n        TM_CHECK(scale.stride(1) == 1);\n        TM_CHECK(scale.stride(0) % alignment == 0);\n    }\n\n    constexpr int block_dim = 512;\n\n    quant_symm_row<vec_size, group_size><<<num, block_dim, 0, st>>>(out.data<Tout>(),  //\n                                                                    out.stride(0),\n                                                                    scale.data<Tscale>(),\n                                                                    scale.stride(0),\n                                                                    src.data<T>(),\n                                                                    src.stride(0),\n                                                                    num,\n                                                                    dim,\n                                                                    448.f);\n}\n\ntemplate<int vec_size, int 
group_size, class Tout, class Tscale, class T>\n__global__ void\ndequant_symm_row(Tout* out, int out_ld, const T* src, int src_ld, const Tscale* scales, int scales_ld, int num, int dim)\n{\n#if TURBOMIND_ARCH_SM90\n    static_assert(group_size % vec_size == 0);\n    for (int ti = blockIdx.x; ti < num; ti += gridDim.x) {\n        for (int di = threadIdx.x * vec_size; di < dim; di += blockDim.x * vec_size) {\n            Array<T, vec_size> vec;\n            Ldg(vec, src + ti * src_ld + di);\n            const auto            scale = __ldg(&scales[(di / group_size) * scales_ld + ti]);\n            Array<Tout, vec_size> tmp;\n            PRAGMA_UNROLL\n            for (int c = 0; c < vec_size; ++c) {\n                tmp[c] = Tout(static_cast<Tscale>(vec[c]) * scale);\n            }\n            Store(out + ti * out_ld + di, tmp);\n        }\n    }\n#endif\n}\n\nvoid DequantizeSymm(Tensor& out, const Tensor& src, const Tensor& scale, cudaStream_t st)\n{\n    using T      = fp8_e4m3_t;\n    using Tout   = bfloat16_t;\n    using Tscale = float;\n\n    if (!out) {\n        out = Tensor_<Tout>{src.layout(), kDEVICE};\n    }\n    else {\n        TM_CHECK(out.layout() == src.layout());\n    }\n\n    auto [num, dim] = src.shapes(0, 1);\n\n    constexpr int group_size = 128;\n    constexpr int vec_size   = 8;\n\n    constexpr int block_dim = 512;\n\n    dequant_symm_row<vec_size, group_size, Tout, Tscale, T><<<num, block_dim, 0, st>>>(out.data<Tout>(),  //\n                                                                                       out.stride(0),\n                                                                                       src.data<T>(),\n                                                                                       src.stride(0),\n                                                                                       scale.data<Tscale>(),\n                                                                                       scale.stride(0),\n      
                                                                                 num,\n                                                                                       dim);\n}\n\ntemplate<int vec_size, int cta_size, int block_size, class Tout, class Tscale, class T>\n__global__ void quant_symm_block(Tout* out, Tscale* scales, const T* src, Tscale qmax, int num, int dim)\n{\n    if constexpr (TURBOMIND_ARCH_BF16_GUARD(data_type_v<T>)) {\n        static_assert(block_size % vec_size == 0);\n        constexpr int threads = block_size / vec_size;\n\n        static_assert(cta_size % threads == 0);\n        constexpr int rows = cta_size / threads;\n\n        constexpr int S = cdiv(block_size, rows);\n\n        using BlockReduce = cub::BlockReduce<T, cta_size>;\n        __shared__ typename BlockReduce::TempStorage temp_storage;\n        __shared__ T                                 shared_inv_scale;\n\n        const int row = threadIdx.x / threads;\n        const int col = threadIdx.x % threads;\n        const int ti  = blockIdx.x * block_size;\n        const int di  = blockIdx.y * block_size + col * vec_size;\n\n        T                  absmax{};\n        Array<T, vec_size> xs[S]{};\n        PRAGMA_UNROLL\n        for (int s = 0; s < S; ++s) {\n            if (auto r = ti + s * rows + row; r < num && di < dim) {\n                Ldg(xs[s], src + (int64_t)r * dim + di);\n            }\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                absmax = __hmax(absmax, __habs(xs[s][i]));\n            }\n        }\n\n        absmax = BlockReduce{temp_storage}.Reduce(absmax, [](auto a, auto b) { return __hmax(a, b); });\n        if (threadIdx.x == 0) {\n            auto maxval                                 = fmaxf(static_cast<Tscale>(absmax), 1e-8f);\n            scales[blockIdx.x * gridDim.y + blockIdx.y] = maxval / qmax;\n            shared_inv_scale                            = qmax / maxval;\n        }\n        __syncthreads();\n 
       const Tscale inv_scale = shared_inv_scale;\n\n        Array<Tout, vec_size> ys[S];\n        PRAGMA_UNROLL\n        for (int s = 0; s < S; ++s) {\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                ys[s][i] = Tout(static_cast<Tscale>(xs[s][i]) * inv_scale);\n            }\n            if (auto r = ti + s * rows + row; r < num && di < dim) {\n                Store(out + (int64_t)r * dim + di, ys[s]);\n            }\n        }\n    }\n}\n\nvoid QuantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> scale_, const Tensor& src, cudaStream_t st)\n{\n    TM_CHECK(src.is_contiguous());\n    TM_CHECK_EQ(src.ndim(), 2);\n\n    auto invoke = [&](auto t) {\n        using T      = decltype(t);\n        using Tout   = fp8_e4m3_t;\n        using Tscale = float;\n\n        constexpr int block_size = 128;\n        constexpr int vec_size   = 8;\n\n        const auto [num, dim] = src.shapes(0, 1);\n\n        const int bnum = cdiv<int>(num, block_size);\n        const int bdim = cdiv<int>(dim, block_size);\n\n        constexpr int cta_size = 1024;\n        const dim3    grid(bnum, bdim);\n\n        auto& out   = out_.get();\n        auto& scale = scale_.get();\n\n        if (!out) {\n            out = Tensor_<Tout>{src.layout(), kDEVICE};\n        }\n        else {\n            TM_CHECK(out.layout() == src.layout());\n        }\n\n        if (!scale) {\n            scale = Tensor_<Tscale>({bnum, bdim}, kDEVICE);\n        }\n        else {\n            TM_CHECK(std::make_tuple(bnum, bdim) == scale.shapes(0, 1));\n        }\n\n        quant_symm_block<vec_size, cta_size, block_size><<<grid, cta_size, 0, st>>>(  //\n            out.data<Tout>(),\n            scale.data<Tscale>(),\n            src.data<T>(),\n            448.f,\n            num,\n            dim);\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(src.dtype(), invoke);\n}\n\ntemplate<int vec_size, int cta_size, int block_size, class Tout, class Tscale, class T>\n__global__ void 
dequant_symm_block(Tout* out, const T* src, const Tscale* scales, int num, int dim)\n{\n    if constexpr (TURBOMIND_ARCH_BF16_GUARD(data_type_v<T>)) {\n        static_assert(block_size % vec_size == 0);\n        constexpr int threads = block_size / vec_size;\n        static_assert(cta_size % threads == 0);\n        constexpr int rows  = cta_size / threads;\n        constexpr int S     = cdiv(block_size, rows);\n        const int     col   = threadIdx.x % threads;\n        const int     row   = threadIdx.x / threads;\n        const auto    scale = __ldg(&scales[blockIdx.x * gridDim.y + blockIdx.y]);\n        const auto    di    = blockIdx.y * block_size + col * vec_size;\n        PRAGMA_UNROLL\n        for (int s = 0; s < S; ++s) {\n            const auto ti = blockIdx.x * block_size + s * rows + row;\n            if (ti < num && di < dim) {\n                Array<T, vec_size> x;\n                Ldg(x, src + (int64_t)ti * dim + di);\n                Array<Tout, vec_size> y;\n                PRAGMA_UNROLL\n                for (int i = 0; i < vec_size; ++i) {\n                    y[i] = Tout(static_cast<Tscale>(x[i]) * scale);\n                }\n                Store(out + (int64_t)ti * dim + di, y);\n            }\n        }\n    }\n}\n\nvoid DequantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> src_, const Tensor& scale, cudaStream_t st)\n{\n    auto invoke = [&](auto tout) {\n        using T      = fp8_e4m3_t;\n        using Tout   = decltype(tout);\n        using Tscale = float;\n\n        constexpr int block_size = 128;\n        constexpr int vec_size   = 8;\n\n        auto& out = out_.get();\n        auto& src = src_.get();\n\n        if (!out) {\n            out = Tensor_<Tout>{src.layout(), kDEVICE};\n        }\n        else {\n            TM_CHECK(out.layout() == src.layout());\n        }\n\n        const auto [num, dim] = src.shapes(0, 1);\n\n        const int bnum = cdiv<int>(num, block_size);\n        const int bdim = cdiv<int>(dim, block_size);\n\n        
constexpr int cta_size = 1024;\n        const dim3    grid(bnum, bdim);\n\n        dequant_symm_block<vec_size, cta_size, block_size><<<grid, cta_size, 0, st>>>(  //\n            out.data<Tout>(),\n            src.data<T>(),\n            scale.data<Tscale>(),\n            num,\n            dim);\n    };\n\n    if (!out_.get()) {\n        return invoke(nv_bfloat16{});\n    }\n\n    TM_DISPATCH_PRIMARY_DTYPES(out_.get().dtype(), invoke);\n}\n\ntemplate<int start_bit, int end_bit, class D, class T>\n__global__ void Compact1D_Kernel(D* d, const T* s, int n)\n{\n    constexpr int bits     = end_bit - start_bit;\n    constexpr int vec_size = bitsof<D> / bits;\n\n    const auto idx = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;\n\n    if (idx * vec_size >= n) {\n        return;\n    }\n\n    Array<T, vec_size> s_vec;\n\n    Load(s_vec, &s[idx * vec_size]);\n\n    constexpr T mask = ((1 << bits) - 1) << start_bit;\n\n    D pack{};\n\n    PRAGMA_UNROLL\n    for (int i = 0; i < vec_size; ++i) {\n        pack |= ((s_vec[i] & mask) >> start_bit) << (i * bits);\n    }\n\n    d[idx] = pack;\n}\n\ntemplate<class T_, int bits_, class Q_>\nstruct IntegralQuantizer {\n\n    using T = T_;\n    using Q = Q_;\n\n    using Scale = T;\n    using Zero  = T;\n\n    static constexpr int bits  = bits_;\n    static constexpr int max_q = (1 << bits) - 1;\n\n    template<class T, int N, class R>\n    __device__ void operator()(const Array<T, N>&    x,  //\n                               const Array<bool, N>& pred,\n                               const R&              rbits,\n                               Array<Q, N>&          q,\n                               Array<T, N>&          d,\n                               T&                    scale,\n                               T&                    zero,\n                               int                   threads) const\n    {\n        auto f = cast<float>(x);\n\n        float minval = std::numeric_limits<float>::infinity();\n        float 
maxval = -minval;\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            if (pred[i]) {\n                minval = fminf(minval, f[i]);\n                maxval = fmaxf(maxval, f[i]);\n            }\n        }\n\n        for (int offset = threads / 2; offset >= 1; offset /= 2) {\n            minval = fminf(minval, __shfl_xor_sync((uint32_t)-1, minval, offset));\n            maxval = fmaxf(maxval, __shfl_xor_sync((uint32_t)-1, maxval, offset));\n        }\n\n        auto clamp = [](int x, int a, int b) { return max(a, min(b, x)); };\n\n        float scale_ = fmaxf(maxval - minval, 1e-5f) / (float)max_q;\n        int   zero_  = clamp(-round<int32_t>(minval / scale_), 0, max_q);\n\n        scale = (T)scale_;\n        zero  = (T)zero_;\n\n        // T sz = zero_ * scale_;\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            q[i] = clamp(round<int32_t>(f[i] / scale_) + zero_, 0, max_q);\n            d[i] = (T)((int)q[i] - zero_) * (T)scale_;\n            // d[i] = __hfma((T)q[i], (T)scale_, -sz);\n        }\n    }\n};\n\ntemplate<class T_, int E, int M, class Q_>\nstruct FloatingPointQuantizer {\n\n    using T = T_;\n    using Q = Q_;\n\n    using Scale = uint8_t;\n    using Zero  = void;\n\n    using traits = FloatingPoint<E, M>;\n\n    static constexpr int bits = traits::bits;\n\n    float pre_rounding_scale_;\n\n    __host__ __device__ FloatingPointQuantizer(float pre_rounding_scale = 1.f): pre_rounding_scale_{pre_rounding_scale}\n    {\n    }\n\n    template<int N, class Z, class R>\n    __device__ void operator()(const Array<T, N>&    x,  //\n                               const Array<bool, N>& pred,\n                               const R&              rbits,\n                               Array<Q, N>&          q,\n                               Array<T, N>&          d,\n                               Scale&                scale,\n                               Z                     ignore,\n                             
  int                   threads) const\n    {\n        auto f = cast<float>(x);\n\n        float absmax = 0.f;\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            if (pred[i]) {\n                absmax = fmaxf(absmax, fabsf(f[i]));\n            }\n        }\n\n        for (int offset = threads / 2; offset >= 1; offset /= 2) {\n            absmax = fmaxf(absmax, __shfl_xor_sync((uint32_t)-1, absmax, offset));\n        }\n\n        auto get_exponent = [](float x) -> int { return (__float_as_uint(x) >> 23U) & 0xFFU; };\n\n        int scale_i32 = get_exponent(absmax) - (traits::exponent_bias + 1);\n\n        // int scale_i32 = 127;\n\n        if (scale_i32 < 0) {  // absmax(group) < 2*2^-125, flush to zero\n            scale_i32 = 0;\n            f         = {};\n        }\n\n        scale = scale_i32;\n\n        float scale_f32 = __uint_as_float((uint32_t)scale_i32 << 23U);\n\n        PRAGMA_UNROLL\n        for (int i = 0; i < N; ++i) {\n            q[i] = traits::from_f32((f[i] * pre_rounding_scale_) / scale_f32, rbits[i]);\n            d[i] = (traits::to_f32(q[i]) * scale_f32) / pre_rounding_scale_;\n        }\n    }\n};\n\ntemplate<int vec_size,\n         class Quantizer,\n         class T = typename Quantizer::T,\n         class Q = typename Quantizer::Q,\n         class S = typename Quantizer::Scale,\n         class Z = typename Quantizer::Zero>\n__global__ void QuantizeGroupwise_Kernel(Quantizer       quantizer,\n                                         Q*              q,\n                                         S*              s,\n                                         Z*              z,\n                                         T*              d,\n                                         const T*        x,\n                                         const unsigned* r,\n                                         Array<int, 2>   stride_q,\n                                         Array<int, 2>   stride_s,\n                                 
        Array<int, 2>   stride_d,\n                                         Array<int, 2>   stride_x,\n                                         int             M,\n                                         int             K,\n                                         int             G)\n{\n    if constexpr (TURBOMIND_ARCH_BF16_GUARD(data_type_v<T>)) {\n        static constexpr bool has_zero = !std::is_void_v<Z>;\n\n        int m = blockIdx.x;\n        int k = threadIdx.x + blockIdx.y * blockDim.x;\n\n        const int threads_per_group = G / vec_size;\n        const int warp_k            = WARP_SIZE * vec_size;\n\n        k *= vec_size;\n\n        for (; k < round_up(K, warp_k); k += gridDim.y * blockDim.x * vec_size) {\n\n            Array<T, vec_size>    x_vec;\n            Array<bool, vec_size> p_vec;\n\n            Array<unsigned, vec_size> r_vec;\n\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                p_vec[i] = k + i < K;\n                x_vec[i] = p_vec[i] ? x[stride_x[0] * m + stride_x[1] * (k + i)] : T{0};\n                if (r) {\n                    r_vec[i] = p_vec[i] ? r[m * K + k + i] : 0;\n                }\n            }\n\n            Array<Q, vec_size> q_vec;\n            Array<T, vec_size> d_vec;\n\n            S                                    scale;\n            std::conditional_t<has_zero, Z, int> zero{};\n\n            auto invoke = [&](auto rbits) {\n                quantizer(x_vec, p_vec, rbits, q_vec, d_vec, scale, zero, threads_per_group);\n            };\n\n            r ? 
invoke(r_vec) : invoke(Array<char, vec_size>{});\n\n            PRAGMA_UNROLL\n            for (int i = 0; i < vec_size; ++i) {\n                const auto idx = stride_q[0] * m + stride_q[1] * (k + i);\n                if (p_vec[i]) {\n                    q[idx] = q_vec[i];\n                    d[stride_d[0] * m + stride_d[1] * (k + i)] = d_vec[i];\n                }\n            }\n            if (threadIdx.x % threads_per_group == 0) {\n                const auto idx = stride_s[0] * m + stride_s[1] * (k / G);\n                if (p_vec[0]) {\n                    s[idx] = (S)scale;\n                    if constexpr (has_zero) {\n                        z[idx] = (S)zero;\n                    }\n                }\n            }\n        }\n    }\n}\n\nvoid QuantizeGroupwise(Tensor            quant,    // (m,k)\n                       Tensor            scales,   // (m,k/g)\n                       Tensor            zeros,    // (m,k/g)\n                       Tensor            dequant,  // (m,k)\n                       Tensor            src,      // (m,k)\n                       Buffer_<unsigned> rbits,    // (m*k)\n                       int               group_size)\n{\n    // std::cout << quant << std::endl;\n    // std::cout << scales << std::endl;\n    // std::cout << zeros << std::endl;\n    // std::cout << dequant << std::endl;\n    // std::cout << src << std::endl;\n\n    if (zeros) {\n        TM_CHECK(scales.layout() == zeros.layout());\n    }\n    TM_CHECK(quant.shape() == dequant.shape());\n    TM_CHECK(quant.size() == quant.layout().cosize());\n\n    auto stream = core::Context::stream().handle();\n\n    auto stride_2d = [](const Tensor& t) {\n        TM_CHECK_EQ(t.ndim(), 2);\n        auto [a, b] = t.strides(0, 1);\n        return Array<int, 2>{(int)a, (int)b};\n    };\n\n    const int m = src.shape(0);\n    const int k = src.shape(1);\n\n    // std::cout << \"m\" << m << \"k\" << k << \"\\n\";\n\n    auto invoke = [&](auto quantizer) {\n        using Quantizer = decltype(quantizer);\n\n   
     using T = typename Quantizer::T;\n        using Q = typename Quantizer::Q;\n        using S = typename Quantizer::Scale;\n        using Z = typename Quantizer::Zero;\n\n        constexpr int bits = Quantizer::bits;\n\n        Tensor_<Q> proxy = empty_like(quant, data_type_v<Q>);\n\n        constexpr int vec = 8;\n\n        TM_CHECK((group_size & (group_size - 1)) == 0);\n        TM_CHECK_GE(group_size, vec);\n        TM_CHECK_LE(group_size, WARP_SIZE * vec);\n\n        const int threads = round_up(std::min(cdiv(k, vec), 1024), WARP_SIZE);\n\n        QuantizeGroupwise_Kernel<vec><<<m, threads, 0, stream>>>(quantizer,\n                                                                 proxy.data(),\n                                                                 scales.data<S>(),\n                                                                 zeros.data_or((Z*)nullptr),\n                                                                 dequant.data<T>(),\n                                                                 src.data<T>(),\n                                                                 rbits.data_or(nullptr),\n                                                                 stride_2d(proxy),\n                                                                 stride_2d(scales),\n                                                                 stride_2d(dequant),\n                                                                 stride_2d(src),\n                                                                 m,\n                                                                 k,\n                                                                 group_size);\n\n        Compact1D_Kernel<0, bits><<<cdiv((int)quant.size(), 512), 512, 0, stream>>>(\n            (uint32_t*)quant.raw_data(), (Q*)proxy.raw_data(), quant.size());\n    };\n\n    if (0) {}\n    else if (src.dtype() == kHalf && quant.dtype() == kUint4) {\n        invoke(IntegralQuantizer<half_t, 4, 
uint16_t>{});\n    }\n    else if (src.dtype() == kBfloat16 && quant.dtype() == kFloat4_e2m1) {\n        invoke(FloatingPointQuantizer<bfloat16_t, 2, 1, uint16_t>{});\n    }\n    else if (src.dtype() == kHalf && quant.dtype() == kFloat4_e2m1) {\n        invoke(FloatingPointQuantizer<half_t, 2, 1, uint16_t>{});\n    }\n    else {\n        TM_CHECK(0) << \"Unsupported types: \" << to_string(src.dtype()) << \", \" << to_string(quant.dtype());\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/quantization.cuh",
    "content": "#pragma once\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n\nnamespace turbomind {\n\n#if 0\ntemplate<int threads, class T, int N>\n__device__ Array<T, 2> find_minmax(const Array<T, N>& a)\n{\n    static_assert((threads & (threads - 1)) == 0);\n    static_assert(sizeof(Array<T, 2>) == sizeof(uint32_t));\n    uint32_t data;\n    auto&    minmax = reinterpret_cast<Array<T, 2>&>(data);\n    minmax          = {a[0], a[0]};\n    PRAGMA_UNROLL\n    for (int i = 1; i < N; ++i) {\n        minmax = hmin(minmax[0], a[i]);\n        minmax = hmax(minmax[1], a[i]);\n    }\n    PRAGMA_UNROLL\n    for (int mask = threads / 2; mask > 0; mask /= 2) {\n        uint32_t tmp = __shfl_xor_sync(uint32_t(-1), data, mask);\n        auto&    vec = reinterpret_cast<Array<T, 2>&>(tmp);\n        minmax[0]    = hmin(minmax[0], vec[0]);\n        minmax[1]    = hmax(minmax[1], vec[1]);\n    }\n    return minmax;\n}\n#endif\n\ntemplate<int threads, class T, int N>\n__device__ T find_absmax(const Array<T, N>& a)\n{\n    static_assert((threads & (threads - 1)) == 0);\n    static_assert(sizeof(Array<T, 2>) == sizeof(uint32_t));\n    uint32_t data;\n    auto&    val = *reinterpret_cast<T*>(&data);\n    val          = __habs(a[0]);\n    PRAGMA_UNROLL\n    for (int i = 1; i < N; ++i) {\n        val = __hmax(val, __habs(a[i]));\n    }\n    PRAGMA_UNROLL\n    for (int mask = threads / 2; mask > 0; mask /= 2) {\n        uint32_t tmp = __shfl_xor_sync(uint32_t(-1), data, mask);\n        auto&    x   = *reinterpret_cast<T*>(&tmp);\n        val          = __hmax(val, x);\n    }\n    return val;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/quantization.h",
    "content": "#pragma once\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nvoid QuantizeSymm(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st);\n\nvoid DequantizeSymm(Tensor& out, const Tensor& src, const Tensor& scale, cudaStream_t st);\n\nvoid QuantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> scale_, const Tensor& src, cudaStream_t st);\n\nvoid DequantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> src_, const Tensor& scale, cudaStream_t st);\n\nvoid QuantizeGroupwise(Tensor            quant,    // (m,k)\n                       Tensor            scales,   // (m,k/g)\n                       Tensor            zeros,    // (m,k/g)\n                       Tensor            dequant,  // (m,k)\n                       Tensor            src,      // (m,k)\n                       Buffer_<unsigned> rbits,    // (m*k)\n                       int               group_size);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/reduce_kernel_utils.cuh",
    "content": "/*\n * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n#include <array>\n#include <assert.h>\n#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))\n#include <cooperative_groups/reduce.h>\n#else\n#include <cooperative_groups.h>\n#endif\n#include \"src/turbomind/utils/cuda_bf16_wrapper.h\"\n#include \"src/turbomind/utils/cuda_type_utils.cuh\"\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <curand_kernel.h>\n#include <float.h>\n#include <type_traits>\n\nnamespace cg = cooperative_groups;\n\nnamespace turbomind {\n\ntemplate<int VPT>\nstruct BytesToType;\n\ntemplate<>\nstruct BytesToType<2> {\n    using type = uint16_t;\n};\ntemplate<>\nstruct BytesToType<4> {\n    using type = uint32_t;\n};\ntemplate<>\nstruct BytesToType<8> {\n    using type = uint64_t;\n};\ntemplate<>\nstruct BytesToType<16> {\n    using type = float4;\n};\n\ntemplate<typename T>\n__device__ inline T getMaxValue();\n\ntemplate<>\n__device__ inline float getMaxValue<float>()\n{\n    return FLT_MAX;\n}\n\ntemplate<>\n__device__ inline half getMaxValue<half>()\n{\n    return __ushort_as_half((unsigned short)0x7BFFU);\n}\n\n#ifdef ENABLE_BF16\ntemplate<>\n__device__ inline __nv_bfloat16 getMaxValue<__nv_bfloat16>()\n{\n#if __CUDA_ARCH__ >= 800\n    return __ushort_as_bfloat16((unsigned short)0x7F7FU);\n#endif\n    return 
{};\n}\n#endif\n\ntemplate<typename T>\n__device__ inline T getInfValue();\n\ntemplate<>\n__device__ inline float getInfValue<float>()\n{\n    return INFINITY;\n}\n\ntemplate<>\n__device__ inline half getInfValue<half>()\n{\n    return __ushort_as_half((unsigned short)0x7C00U);\n}\n\n#ifdef ENABLE_BF16\ntemplate<>\n__device__ inline __nv_bfloat16 getInfValue<__nv_bfloat16>()\n{\n#if __CUDA_ARCH__ >= 800\n    return __ushort_as_bfloat16((unsigned short)0x7F80U);\n#endif\n    return {};\n}\n#endif\n\ntemplate<int Bytes>\n__device__ inline void copy(const void* local, void* data)\n{\n    using T = typename BytesToType<Bytes>::type;\n\n    const T* in  = static_cast<const T*>(local);\n    T*       out = static_cast<T*>(data);\n    *out         = *in;\n}\n\n#define HALF_FLT_MAX 65504.F\n#define FINAL_MASK 0xffffffff\n\ntemplate<typename T>\n__inline__ __device__ T warpReduceSum(T val)\n{\n#pragma unroll\n    for (int mask = 16; mask > 0; mask >>= 1)\n        val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));  //__shfl_sync bf16 return float when sm < 80\n    return val;\n}\n\n/* Calculate the sum of all elements in a block */\ntemplate<typename T>\n__inline__ __device__ T blockReduceSum(T val)\n{\n    static __shared__ T shared[32];\n    int                 lane = threadIdx.x & 0x1f;\n    int                 wid  = threadIdx.x >> 5;\n\n    val = warpReduceSum<T>(val);\n\n    if (lane == 0)\n        shared[wid] = val;\n\n    __syncthreads();\n\n    // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent\n    // blockDim.x is not divided by 32\n    val = (threadIdx.x < (blockDim.x / 32.f)) ? 
shared[lane] : (T)(0.0f);\n    val = warpReduceSum<T>(val);\n\n    return val;\n}\n\ntemplate<typename T>\n__inline__ __device__ T warpReduceMax(T val)\n{\n#pragma unroll\n    for (int mask = 16; mask > 0; mask >>= 1)\n        val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));\n    return val;\n}\n\n/* Calculate the maximum of all elements in a block */\ntemplate<typename T>\n__inline__ __device__ T blockReduceMax(T val)\n{\n    static __shared__ T shared[32];\n    int                 lane = threadIdx.x & 0x1f;  // in-warp idx\n    int                 wid  = threadIdx.x >> 5;    // warp idx\n\n    val = warpReduceMax(val);  // get maxx in each warp\n\n    if (lane == 0)  // record in-warp maxx by warp Idx\n        shared[wid] = val;\n\n    __syncthreads();\n\n    // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent\n    // blockDim.x is not divided by 32\n    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;\n    val = warpReduceMax(val);\n\n    return val;\n}\n\n/* Calculate the maximum of all elements in a block */\ntemplate<typename T>\n__inline__ __device__ T blockAllReduceMax(T val)\n{\n    static __shared__ T shared[32];\n    int                 lane = threadIdx.x & 0x1f;  // in-warp idx\n    int                 wid  = threadIdx.x >> 5;    // warp idx\n\n    val = warpReduceMax(val);  // get maxx in each warp\n\n    if (lane == 0)  // record in-warp maxx by warp Idx\n        shared[wid] = val;\n\n    __syncthreads();\n\n    // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent\n    // blockDim.x is not divided by 32\n    val = (lane < (blockDim.x / 32.f)) ? 
shared[lane] : -1e20f;\n    val = warpReduceMax(val);\n\n    return val;\n}\n\ntemplate<typename T, int NUM>\n__inline__ __device__ T warpReduceSumV2(T* val)\n{\n#pragma unroll\n    for (int i = 0; i < NUM; i++) {\n#pragma unroll\n        for (int mask = 16; mask > 0; mask >>= 1)\n            val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);\n    }\n    return (T)(0.0f);\n}\n\ntemplate<typename T, int NUM>\n__inline__ __device__ T blockReduceSumV2(T* val)\n{\n    static __shared__ T shared[NUM][33];\n    int                 lane = threadIdx.x & 0x1f;\n    int                 wid  = threadIdx.x >> 5;\n\n    warpReduceSumV2<T, NUM>(val);\n\n    if (lane == 0) {\n#pragma unroll\n        for (int i = 0; i < NUM; i++) {\n            shared[i][wid] = val[i];\n        }\n    }\n\n    __syncthreads();\n\n    bool is_mask = threadIdx.x < (blockDim.x / 32.f);\n#pragma unroll\n    for (int i = 0; i < NUM; i++) {\n        val[i] = is_mask ? shared[i][lane] : (T)(0.0f);\n    }\n    warpReduceSumV2<T, NUM>(val);\n    return (T)0.0f;\n}\n\ntemplate<typename T, int NUM>\n__inline__ __device__ T warpReduceMaxV2(T* val)\n{\n#pragma unroll\n    for (int i = 0; i < NUM; i++) {\n#pragma unroll\n        for (int mask = 16; mask > 0; mask >>= 1)\n            val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));\n    }\n    return (T)(0.0f);\n}\n\ntemplate<typename T, int NUM>\n__inline__ __device__ T blockReduceMaxV2(T* val)\n{\n    static __shared__ T shared[32][NUM];\n    int                 lane = threadIdx.x & 0x1f;  // in-warp idx\n    int                 wid  = threadIdx.x >> 5;    // warp idx\n\n    warpReduceMaxV2<T, NUM>(val);  // get maxx in each warp\n\n    if (lane == 0)  // record in-warp maxx by warp Idx\n    {\n#pragma unroll\n        for (int i = 0; i < NUM; i++) {\n            shared[wid][i] = val[i];\n        }\n    }\n\n    __syncthreads();\n\n    // Modify from blockDim.x << 5 to blockDim.x / 32. 
to prevent\n    // blockDim.x is not divided by 32\n    bool is_mask = threadIdx.x < (blockDim.x / 32.f);\n#pragma unroll\n    for (int i = 0; i < NUM; i++) {\n        val[i] = is_mask ? shared[lane][i] : (T)-1e20f;\n    }\n    warpReduceMaxV2<T, NUM>(val);\n\n    return (T)0.0f;\n}\n\ntemplate<int NUM>\n__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm)\n{\n    cg::thread_block          cta  = cg::this_thread_block();\n    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);\n\n    const int tid    = cta.thread_rank();\n    const int blockz = blockDim.x;\n    for (int i = 0; i < NUM; i++) {\n#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))\n        cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());\n#else\n        // TODO Add implementation here\n        if (threadIdx.x == 0 && blockIdx.x == 0) {\n            printf(\"[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \\n\");\n            assert(false);\n        }\n#endif\n    }\n    cg::sync(cta);\n    if (tid == 0) {\n#pragma unroll\n        for (int i = 0; i < NUM; i++) {\n            float beta = 0.0f;\n            for (int j = 0; j < blockz; j += 32) {\n                beta += cgBlockReduceSumElements_shm[i * blockz + j];\n            }\n            element_list[i] = beta;\n        }\n    }\n}\n\ntemplate<typename T, int MAX_K>\nstruct TopK {\n    int p[MAX_K];\n    T   u[MAX_K];\n\n    __device__ __forceinline__ void insert(T elem, int elem_id)\n    {\n        if (elem > u[MAX_K - 1] || (p[MAX_K - 1] == -1) || ((elem == u[MAX_K - 1]) && (elem_id < p[MAX_K - 1])))\n        // if (elem > u[MAX_K-1] || ((elem == u[MAX_K-1]) && (elem_id < p[MAX_K-1])))\n        {\n            u[MAX_K - 1] = elem;\n            p[MAX_K - 1] = elem_id;\n        }\n\n        for (int k = MAX_K - 2; k >= 0; --k) {\n            if ((u[k + 1] > u[k]) || (p[k] == 
-1) || ((u[k + 1] == u[k]) && (p[k + 1] < p[k])))\n            // if ((u[k+1] > u[k]) || ((u[k+1] == u[k])&&(p[k+1] < p[k])))\n            {\n                T   u2   = u[k];\n                int p2   = p[k];\n                u[k]     = u[k + 1];\n                p[k]     = p[k + 1];\n                u[k + 1] = u2;\n                p[k + 1] = p2;\n            }\n        }\n    }\n\n    __device__ __forceinline__ void init()\n    {\n        const bool IS_FP16   = std::is_same<T, half>::value;\n        const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;\n\n        for (int i = 0; i < MAX_K; i++) {\n            p[i] = -1;\n            u[i] = -MAX_T_VAL;\n        }\n    }\n};\n\ntemplate<typename T, int MAX_K>\n__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(const TopK<T, MAX_K>& a, const TopK<T, MAX_K>& b)\n{\n    TopK<T, MAX_K> res = a;\n    for (int i = 0; i < MAX_K; ++i)\n        res.insert(b.u[i], b.p[i]);\n    return res;\n}\n\ntemplate<typename T>\nstruct TopK_2 {\n    int p = 0;\n    T   u = -getInfValue<T>();\n\n    __device__ __forceinline__ void insert(T elem, int elem_id)\n    {\n        if (elem > u) {\n            u = elem;\n            p = elem_id;\n        }\n    }\n\n    __device__ __forceinline__ void init()\n    {\n        u = -getInfValue<T>();\n        p = 0;\n    }\n};\n\ntemplate<typename T>\n__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a, const TopK_2<T>& b)\n{\n    return a.u > b.u ? a : b;\n}\n\ntemplate<typename T>\n__device__ __forceinline__ T clamp_inf_for_half(const float input)\n{\n    return input;\n}\n\ntemplate<>\n__device__ __forceinline__ half clamp_inf_for_half(const float input)\n{\n    // clamp inf values to enable fp16 training\n    return input > 0.0f ? 
(half)min(input, HALF_FLT_MAX - 1000) : (half)max(input, -HALF_FLT_MAX + 1000);\n}\n\n#ifdef ENABLE_BF16\ntemplate<>\n__device__ __forceinline__ __nv_bfloat16 clamp_inf_for_half(const float input)\n{\n    return __float2bfloat16(input);\n}\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/sampling_kernels.cu",
    "content": "#ifndef CUDART_VERSION\n#error CUDART_VERSION Undefined!\n#elif (CUDART_VERSION >= 11000)\n#include <cub/cub.cuh>\n#else\n#include \"3rdparty/cub/cub.cuh\"\n#endif\n#include \"src/turbomind/kernels/sampling_kernels.h\"\n#include \"src/turbomind/kernels/sampling_topp_kernels.h\"\n#include \"src/turbomind/utils/constant.h\"\n\nnamespace turbomind {\n\ntemplate<typename T, int BLOCK_SIZE>\n__global__ void sampling(const T*       logits,\n                         const int      stride,\n                         const int*     indices,\n                         const int*     kept,\n                         curandState_t* curandstate,\n                         int*           output_ids,\n                         int*           sequence_length,\n                         T*             sampled_logprobs,\n                         int*           sampled_indexes,\n                         int*           sampled_nums)\n{\n    int tid      = threadIdx.x;\n    int batch_id = blockIdx.x;\n    int n        = kept[batch_id];\n\n    logits += stride * batch_id;\n    indices += stride * batch_id;\n\n    __shared__ float rand_num_s;\n    __shared__ int   selected;\n    if (tid == 0) {\n        rand_num_s = curand_uniform(curandstate + batch_id);\n    }\n    __syncthreads();\n\n    typedef cub::BlockScan<float, BLOCK_SIZE>  BlockScan;\n    __shared__ typename BlockScan::TempStorage temp_storage;\n\n    float                 local_rand = rand_num_s;\n    float                 prefix_sum = 0.f;\n    BlockPrefixCallbackOp prefix_op{0};\n    int                   end = (n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;\n    for (int i = tid; i < end; i += BLOCK_SIZE) {\n        float thread_logit = (i < n) ? 
static_cast<float>(logits[i]) : 0.f;\n        BlockScan(temp_storage).InclusiveSum(thread_logit, prefix_sum, prefix_op);\n        auto count = __syncthreads_count(prefix_sum > local_rand);\n        if (count != 0 || (i + BLOCK_SIZE) >= end) {\n            if (tid == min(BLOCK_SIZE - count, BLOCK_SIZE - 1)) {\n                selected             = min(i, n - 1);\n                output_ids[batch_id] = indices[selected];\n            }\n            break;\n        }\n    }\n\n    if (tid == 0) {\n        sequence_length[batch_id] += 1;\n    }\n\n    if (sampled_logprobs != nullptr && sampled_indexes != nullptr && sampled_nums != nullptr) {\n        __syncthreads();\n        sampled_logprobs += batch_id * kMaxLogProb;\n        sampled_indexes += batch_id * kMaxLogProb;\n        int end = min(n, kMaxLogProb);\n        for (int i = tid; i < end; i += BLOCK_SIZE) {\n            sampled_logprobs[i] = logf(logits[i]);\n            sampled_indexes[i]  = indices[i];\n        }\n        if (n > kMaxLogProb && selected >= kMaxLogProb) {\n            if ((kMaxLogProb - 1 + BLOCK_SIZE - tid) % BLOCK_SIZE == 0) {\n                sampled_logprobs[kMaxLogProb - 1] = logf(logits[selected]);\n                sampled_indexes[kMaxLogProb - 1]  = indices[selected];\n            }\n        }\n        sampled_nums[batch_id] = min(n, kMaxLogProb);\n    }\n}\n\ntemplate<typename T>\nvoid invokeSampling(SamplingParams& params, cudaStream_t stream)\n{\n    const int grid  = params.batch_size;\n    const int block = 256;\n    sampling<T, block><<<grid, block, 0, stream>>>((T*)params.logits,\n                                                   params.stride,\n                                                   params.indices,\n                                                   params.kept,\n                                                   params.curandstate,\n                                                   params.output_ids,\n                                                   
params.sequence_length,\n                                                   (T*)params.sampled_logprobs,\n                                                   params.sampled_indexes,\n                                                   params.sampled_nums);\n}\n\ntemplate void invokeSampling<float>(SamplingParams& params, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/sampling_kernels.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cstdint>\n\n#include <cuda_runtime.h>\n#include <curand_kernel.h>\n\nnamespace turbomind {\n\nstruct SamplingParams {\n    void*          logits;\n    int            stride;\n    int*           indices;\n    int*           kept;\n    curandState_t* curandstate;\n    size_t         batch_size;\n    int*           output_ids;\n    int*           sequence_length;\n    void*          sampled_logprobs;\n    int*           sampled_indexes;\n    int*           sampled_nums;\n};\n\ntemplate<typename T>\nvoid invokeSampling(SamplingParams& params, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/sampling_penalty_kernels.cu",
    "content": "/*\n * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <assert.h>\n#include <float.h>\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/reduce_kernel_utils.cuh\"\n#include \"src/turbomind/kernels/sampling_penalty_kernels.h\"\n\nnamespace turbomind {\n\ntemplate<typename T, int vec_size>\n__global__ void batchApplyTemperaturePenalty_v2(T*           logits,\n                                                const T*     bias,\n                                                const float* temperatures,\n                                                const int    batch_size,\n                                                const int    vocab_size,\n                                                const int    vocab_size_padded)\n{\n    const int vi = blockIdx.x * blockDim.x + threadIdx.x;\n    const int bi = blockIdx.y;\n\n    __shared__ float shared_scale;\n\n    if (threadIdx.x == 0) {\n        shared_scale = fdividef(1.f, temperatures[bi] + 1e-6f);\n    }\n\n    __syncthreads();\n\n    const float scale = shared_scale;\n\n    logits += (size_t)bi * vocab_size_padded;\n\n    const int step = gridDim.x * blockDim.x * vec_size;\n\n    for (int i = vi * vec_size; i < vocab_size_padded; i += step) {\n        Array<T, vec_size> vec;\n        // load\n        if constexpr 
(sizeof(vec) >= sizeof(uint)) {\n            Load(vec, logits + i);\n        }\n        else {\n            PRAGMA_UNROLL\n            for (int j = 0; j < vec_size; ++j) {\n                vec[j] = logits[i + j];\n            }\n        }\n\n        // process\n        PRAGMA_UNROLL\n        for (int c = 0; c < vec_size; ++c) {\n            if (i + c < vocab_size) {\n                vec[c] = (float)vec[c] * scale;\n            }\n            else {\n                vec[c] = -getInfValue<T>();\n            }\n        }\n\n        // store\n        if constexpr (sizeof(vec) >= sizeof(uint)) {\n            Store(logits + i, vec);\n        }\n        else {\n            PRAGMA_UNROLL\n            for (int j = 0; j < vec_size; ++j) {\n                logits[i + j] = vec[j];\n            }\n        }\n    }\n}\n\ntemplate<typename T>\nvoid invokeBatchApplyTemperaturePenalty_v2(T*           logits,\n                                           const T*     bias,\n                                           const float* temperatures,\n                                           const int    batch_size,\n                                           const int    vocab_size,\n                                           const int    vocab_size_padded,\n                                           cudaStream_t stream)\n{\n\n    auto invoke = [&](auto vec_size) {\n        constexpr int threads        = 256;\n        const int     blocks_per_tok = (vocab_size_padded + threads * vec_size - 1) / (threads * vec_size);\n        const dim3    blocks(blocks_per_tok, batch_size);\n        batchApplyTemperaturePenalty_v2<T, vec_size.value><<<blocks, threads, 0, stream>>>(  //\n            logits,\n            bias,\n            temperatures,\n            batch_size,\n            vocab_size,\n            vocab_size_padded);\n    };\n\n    if (vocab_size_padded % 4 == 0) {\n        invoke(std::integral_constant<int, 4>{});\n    }\n    else if (vocab_size_padded % 2 == 0) {\n        
invoke(std::integral_constant<int, 2>{});\n    }\n    else {\n        invoke(std::integral_constant<int, 1>{});\n    }\n}\n\n#define INSTANTIATE_INVOKE_BATCH_APPLY_TEMPERATURE_PENALTY_V2(T)                                                       \\\n    template void invokeBatchApplyTemperaturePenalty_v2(T*           logits,                                           \\\n                                                        const T*     bias,                                             \\\n                                                        const float* temperatures,                                     \\\n                                                        const int    batch_size,                                       \\\n                                                        const int    vocab_size,                                       \\\n                                                        const int    vocab_size_padded,                                \\\n                                                        cudaStream_t stream);\n\nINSTANTIATE_INVOKE_BATCH_APPLY_TEMPERATURE_PENALTY_V2(float);\n\ntemplate<class T>\n__global__ void RepetitionPenaltyKernel(T*                logits,\n                                        const float*      penalties,\n                                        const int* const* token_ids_ptrs,\n                                        const int*        sequence_length,\n                                        int               vocab_size,\n                                        int               mask_size)\n{\n    const int bi = blockIdx.x;\n\n    const int  seq_len   = sequence_length[bi];\n    const int* token_ids = token_ids_ptrs[bi];\n\n    extern __shared__ uint32_t masks[];  // up to 512k vocab size on 64k smem devices\n\n    for (int i = threadIdx.x; i < mask_size; i += blockDim.x) {\n        masks[i] = 0;\n    }\n\n    __syncthreads();\n\n    for (int ti = threadIdx.x; ti < seq_len; ti += blockDim.x) {\n        
const int token_id = token_ids[ti];\n        atomicOr(&masks[token_id / 32], 1U << (token_id % 32));\n    }\n\n    __syncthreads();\n\n    logits += bi * (int64_t)vocab_size;\n\n    const float penalty = penalties[bi];\n\n    for (int di = threadIdx.x; di < vocab_size; di += blockDim.x) {\n        if (masks[di / 32] & (1U << (di % 32))) {\n            const float logit = logits[di];\n            logits[di]        = logit < 0.f ? logit * penalty : logit / penalty;\n        }\n    }\n}\n\nvoid ApplyRepetitionPenalty(Tensor&               logits,\n                            const Buffer_<float>& penalties,\n                            const Buffer_<int*>&  token_ids_ptrs,\n                            const Buffer_<int>&   sequence_length,\n                            cudaStream_t          stream)\n{\n    TM_CHECK_EQ(logits.ndim(), 2);\n    auto invoke = [&](auto dtype) {\n        using T                      = decltype(dtype);\n        const auto [bsz, vocab_size] = logits.shapes(0, 1);\n        const int mask_size          = cdiv((int)vocab_size, 32);\n        const int smem_size          = sizeof(uint32_t) * mask_size;\n        auto      func               = RepetitionPenaltyKernel<T>;\n        if (smem_size > (48 << 10)) {\n            TM_CHECK_EQ(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size), 0);\n        }\n        TM_LOG_DEBUG(\"smem_size = %d\", smem_size);\n        func<<<bsz, 1024, smem_size, stream>>>(\n            logits.data<T>(), penalties.data(), token_ids_ptrs.data(), sequence_length.data(), vocab_size, mask_size);\n    };\n    invoke(float{});\n}\n\ntemplate<typename T>\n__global__ void batchApplyMinLengthPenalty(T* __restrict__ logits,\n                                           const int* __restrict__ min_lengths,\n                                           const int* __restrict__ sequence_lengths,\n                                           const int vocab_size_padded,\n                                           
const int batch_size,\n                                           const int* __restrict__ end_ids,\n                                           const int end_ids_size)\n{\n    int tid = threadIdx.x + blockIdx.x * blockDim.x;\n    int bid = tid / end_ids_size;\n    int eid = tid % end_ids_size;\n    if (bid < batch_size) {\n        int end_id = end_ids[bid * end_ids_size + eid];\n        if (end_id > 0 && sequence_lengths[bid] + 1 < min_lengths[bid]) {\n            T mask_val                               = -getMaxValue<T>();\n            logits[bid * vocab_size_padded + end_id] = mask_val;\n        }\n    }\n}\n\ntemplate<typename T>\nvoid invokeMinLengthPenalty(T*           logits,\n                            const int*   min_lengths,\n                            const int*   sequnece_lengths,\n                            const int    vocab_size_padded,\n                            const int    batch_size,\n                            const int*   end_ids,\n                            const int    end_ids_size,\n                            cudaStream_t stream)\n{\n    const dim3 block(std::min(batch_size * end_ids_size, 1024));\n    const dim3 grid((batch_size * end_ids_size + block.x - 1) / block.x);\n    batchApplyMinLengthPenalty<<<grid, block, 0, stream>>>(\n        logits, min_lengths, sequnece_lengths, vocab_size_padded, batch_size, end_ids, end_ids_size);\n}\n\n#define INSTANTIATE_INVOKE_MIN_LENGTH_PENALTY(T)                                                                       \\\n    template void invokeMinLengthPenalty(T*           logits,                                                          \\\n                                         const int*   min_lengths,                                                     \\\n                                         const int*   sequnece_lengths,                                                \\\n                                         const int    vocab_size_padded,                                            
   \\\n                                         const int    batch_size,                                                      \\\n                                         const int*   end_ids,                                                         \\\n                                         const int    end_ids_size,                                                    \\\n                                         cudaStream_t stream);\n\nINSTANTIATE_INVOKE_MIN_LENGTH_PENALTY(float);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/sampling_penalty_kernels.h",
    "content": "/*\n * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n\n#include <cuda_fp16.h>\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nvoid ApplyRepetitionPenalty(Tensor&               logits,\n                            const Buffer_<float>& penalties,\n                            const Buffer_<int*>&  token_ids_ptrs,\n                            const Buffer_<int>&   sequence_length,\n                            cudaStream_t          stream);\n\ntemplate<typename T>\nvoid invokeBatchApplyTemperaturePenalty_v2(T*           logits,\n                                           const T*     bias,\n                                           const float* temperatures,\n                                           const int    batch_size,\n                                           const int    vocab_size,\n                                           const int    vocab_size_padded,\n                                           cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeMinLengthPenalty(T*           logits,\n                            const int*   min_lengths,\n                            const int*   sequence_lengths,\n                            const int    vocab_size_padded,\n                            const int    batch_size,\n                            const int*   
end_ids,\n                            const int    end_ids_size,\n                            cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/sampling_topk_kernels.cu",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <stdexcept>\n#ifndef CUDART_VERSION\n#error CUDART_VERSION Undefined!\n#elif (CUDART_VERSION >= 11000)\n#include <cub/cub.cuh>\n#else\n#include \"3rdparty/cub/cub.cuh\"\n#endif\n\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/kernels/reduce_kernel_utils.cuh\"\n#include \"src/turbomind/kernels/sampling_topk_kernels.h\"\n\n#include \"src/turbomind/utils/constant.h\"\n\nnamespace turbomind {\n\n// __global__ void curandInitialize(curandState_t* state, const int size, const unsigned long long random_seed)\n// {\n//     if (threadIdx.x + blockIdx.x * blockDim.x < size) {\n//         curand_init(random_seed, 0, 0, &state[blockIdx.x * blockDim.x + threadIdx.x]);\n//     }\n// }\n\n// void invokeCurandInitialize(curandState_t*           state,\n//                             const size_t             batch_size,\n//                             const unsigned long long random_seed,\n//                             cudaStream_t             stream)\n// {\n//     dim3 block(256);\n//     dim3 grid((int)(ceil(batch_size * 1.0 / 256)));\n//     curandInitialize<<<grid, block, 0, stream>>>(state, batch_size, random_seed);\n// }\n\n// __global__ void curandBatchInitialize(curandState_t* states, const int size, const unsigned long long* 
random_seeds)\n// {\n//     int idx = threadIdx.x + blockIdx.x * blockDim.x;\n//     if (idx < size) {\n//         curand_init(random_seeds[idx], 0, 0, &states[idx]);\n//     }\n// }\n\n// void invokeCurandBatchInitialize(curandState_t*  states,\n//                                  const size_t    batch_size,\n//                                  const uint64_t* random_seeds,\n//                                  cudaStream_t    stream)\n// {\n//     dim3 block(256);\n//     dim3 grid((int)(ceil(batch_size * 1.0 / 256)));\n//     static_assert(sizeof(uint64_t) == sizeof(unsigned long long));\n//     curandBatchInitialize<<<grid, block, 0, stream>>>(states, batch_size, (unsigned long long*)random_seeds);\n// }\n\n__global__ void InitializeRandomStates_Kernel(curandState_t*            states,\n                                              const unsigned long long* random_seeds,\n                                              const bool*               mask,\n                                              const size_t              size)\n{\n    if (auto idx = threadIdx.x + blockIdx.x * (size_t)blockDim.x; idx < size && mask[idx]) {\n        curand_init(random_seeds[idx], 0, 0, &states[idx]);\n    }\n}\n\nvoid InitializeRandomStates(\n    curandState_t* states, const uint64_t* random_seeds, const bool* mask, size_t batch_size, cudaStream_t stream)\n{\n    constexpr int threads = 128;\n    const int     blocks  = (batch_size + threads - 1) / threads;\n\n    static_assert(sizeof(uint64_t) == sizeof(unsigned long long));\n\n    InitializeRandomStates_Kernel<<<blocks, threads, 0, stream>>>(\n        (curandState_t*)states, (const unsigned long long*)random_seeds, mask, batch_size);\n}\n\ntemplate<typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>\n__global__ void topKSortStage1(T*         logits,\n                               int*       topk_tmp_id_buf,\n                               T*         topk_tmp_val_buf,\n                               const int  max_top_k,\n           
                    const int* top_ks,\n                               const int  vocab_size,\n                               const int  vocab_size_padded)\n{\n    typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;\n    __shared__ typename BlockReduce::TempStorage    temp_storage;\n\n    const int tid = threadIdx.x;\n    const int bid = blockIdx.x;\n\n    const int block_lane = bid % BLOCKS_PER_BEAM;  // block id for a beam\n    const int batch_id   = bid / BLOCKS_PER_BEAM;  // row id for log_probs\n    const int k          = top_ks[batch_id];\n    if (k == 0) {\n        return;\n    }\n\n    logits += batch_id * vocab_size_padded;\n    topk_tmp_id_buf += batch_id * BLOCKS_PER_BEAM * max_top_k + block_lane * k;\n    topk_tmp_val_buf += batch_id * BLOCKS_PER_BEAM * max_top_k + block_lane * k;\n\n    TopK_2<T> partial;\n    const T   MAX_T_VAL = getMaxValue<T>();\n\n    for (int ite = 0; ite < k; ite++) {\n        partial.init();\n#pragma unroll\n        for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size;\n             elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) {\n            partial.insert(logits[elem_id], elem_id);\n        }\n\n        TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);\n\n        if (tid == 0) {\n            topk_tmp_id_buf[ite]  = total.p;\n            topk_tmp_val_buf[ite] = total.u;\n            if (total.u != -getInfValue<T>()) {\n                logits[total.p] = -MAX_T_VAL;\n            }\n        }\n        __syncthreads();\n    }\n}\n\ntemplate<typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>\n__global__ void topKSortStage2(const int* top_ks,\n                               const int  max_top_k,\n                               const int* topk_tmp_id_buf,\n                               T*         topk_tmp_val_buf,\n                               const int  vocab_size_padded,\n                               T*         sorted_logits,\n                               int*       
sorted_indices,\n                               int*       kept)\n{\n    const T MAX_T_VAL = getMaxValue<T>();\n\n    const int tid      = threadIdx.x;\n    const int batch_id = blockIdx.x;\n    const int k        = top_ks[batch_id];\n\n    if (k == 0) {\n        return;\n    }\n\n    sorted_indices += batch_id * vocab_size_padded;\n    sorted_logits += batch_id * vocab_size_padded;\n    const int size   = k * BLOCKS_PER_BEAM;\n    const int stride = max_top_k * BLOCKS_PER_BEAM;\n\n    typedef cub::BlockReduce<TopK_2<float>, BLOCK_SIZE> BlockReduce;\n    __shared__ typename BlockReduce::TempStorage        temp_storage;\n    extern __shared__ char                              array[];\n    __shared__ float                                    s_sum;\n    __shared__ float                                    s_max;\n    T*                                                  s_val  = topk_tmp_val_buf + batch_id * stride;\n    int*                                                s_id   = reinterpret_cast<int*>(array);\n    float*                                              s_val2 = reinterpret_cast<float*>(s_id + k);\n\n    if (tid == 0) {\n        kept[batch_id] = min(kept[batch_id], k);\n        s_sum          = 0.0f;\n    }\n\n    TopK_2<float> partial;\n    for (int ite = 0; ite < k; ite++) {\n        partial.init();\n#pragma unroll\n        for (int i = tid; i < size; i += BLOCK_SIZE) {\n            partial.insert((float)s_val[i], i);\n        }\n\n        TopK_2<float> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<float>);\n\n        if (tid == 0) {\n            if (ite == 0) {\n                s_max = total.u;\n            }\n            s_id[ite]      = total.p;\n            s_val[total.p] = -MAX_T_VAL;\n            total.u        = __expf(total.u - s_max);\n            s_val2[ite]    = total.u;\n            s_sum += total.u;\n        }\n        __syncthreads();\n    }\n\n    // norm selected\n    float thread_sum = s_sum;\n    topk_tmp_id_buf += 
batch_id * stride;\n    for (int i = tid; i < k; i += BLOCK_SIZE) {\n        sorted_logits[i]  = (T)(s_val2[i] / thread_sum);\n        sorted_indices[i] = topk_tmp_id_buf[s_id[i]];\n    }\n}\n\n#define CASE_K(K_MAX, BLOCK_SIZE_1, BLOCK_SIZE_2, BLOCKS_PER_BEAM)                                                     \\\n    topKSortStage1<T, BLOCK_SIZE_1, BLOCKS_PER_BEAM>                                                                   \\\n        <<<batch_size * BLOCKS_PER_BEAM, BLOCK_SIZE_1, 0, stream>>>((T*)params.logits,                                 \\\n                                                                    topk_tmp_ids_buf,                                  \\\n                                                                    topk_tmp_val_buf,                                  \\\n                                                                    max_top_k,                                         \\\n                                                                    params.top_ks,                                     \\\n                                                                    params.vocab_size,                                 \\\n                                                                    params.vocab_size_padded);                         \\\n    topKSortStage2<T, BLOCK_SIZE_2, BLOCKS_PER_BEAM>                                                                   \\\n        <<<batch_size, BLOCK_SIZE_2, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(params.top_ks,             \\\n                                                                                            params.max_top_k,          \\\n                                                                                            topk_tmp_ids_buf,          \\\n                                                                                            topk_tmp_val_buf,          \\\n                                                                                    
        params.vocab_size_padded,  \\\n                                                                                            (T*)params.sorted_logits,  \\\n                                                                                            params.sorted_indices,     \\\n                                                                                            params.kept);\n\ntemplate<typename T>\nvoid invokeTopKSortFilter(TopKSortFilterParams& params, cudaStream_t stream)\n{\n    const int max_top_k             = params.max_top_k;\n    const int batch_size            = params.batch_size;\n    const int max_block_per_beam    = 8;\n    int       topk_tmp_ids_buf_size = batch_size * max_top_k * max_block_per_beam;  // type int\n    int       topk_tmp_val_buf_size = batch_size * max_top_k * max_block_per_beam;  // type T\n\n    TM_CHECK(core::Context::stream().handle() == stream);\n\n    Buffer_<int> topk_tmp_ids(round_up(topk_tmp_ids_buf_size, 32), kDEVICE);\n    Buffer_<T>   topk_tmp_val(round_up(topk_tmp_val_buf_size, 32), kDEVICE);\n\n    auto topk_tmp_ids_buf = topk_tmp_ids.data();\n    auto topk_tmp_val_buf = topk_tmp_val.data();\n\n    if (max_top_k <= 16) {\n        CASE_K(16, 128, 128, 8);\n    }\n    else if (max_top_k <= 32) {\n        CASE_K(32, 256, 128, 8);\n    }\n    else if (max_top_k <= 64) {\n        CASE_K(64, 256, 256, 8);\n    }\n    else if (max_top_k <= 1024) {\n        CASE_K(1024, 256, 256, 8);\n    }\n    else {\n        throw std::domain_error(fmtstr(\"top-k kernel supports 1<=k<=1024 but got k=%d\", max_top_k));\n    }\n}\n\ntemplate void invokeTopKSortFilter<float>(TopKSortFilterParams& params, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/sampling_topk_kernels.h",
    "content": "/*\n * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.\n * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n\n#include \"src/turbomind/utils/logger.h\"\n#include <curand_kernel.h>\nnamespace turbomind {\n\ntemplate<typename T>\nvoid invokeBatchTopKSampling(void*          workspace,\n                             size_t&        workspace_size,\n                             const T*       log_probs,\n                             int*           ids,\n                             int*           sequence_length,\n                             bool*          finished,\n                             float*         cum_log_probs,\n                             float*         output_log_probs,\n                             float*         sampled_logprobs,\n                             uint32_t*      sampled_indexes,\n                             uint32_t*      sampled_nums,\n                             curandState_t* curandstate,\n                             const int      max_top_k,\n                             const int*     top_ks,\n                             const float    top_p,\n                             const float*   top_ps,\n                             const int      vocab_size_padded,\n                             const int*     end_ids,\n                             cudaStream_t   stream,\n                             const int      
batch_size,\n                             const bool*    skip_decode);\n\n// void invokeCurandInitialize(curandState_t*     state,\n//                             const size_t       batch_size,\n//                             unsigned long long random_seed,\n//                             cudaStream_t       stream);\n\n// void invokeCurandBatchInitialize(curandState_t*  states,\n//                                  const size_t    batch_size,\n//                                  const uint64_t* random_seeds,\n//                                  cudaStream_t    stream);\n\nvoid InitializeRandomStates(curandState_t*  states,  //\n                            const uint64_t* random_seeds,\n                            const bool*     mask,\n                            size_t          batch_size,\n                            cudaStream_t    stream);\n\nstruct TopKSortFilterParams {\n    void* logits;\n    void* sorted_logits;\n    int*  sorted_indices;\n    int*  kept;\n    int*  top_ks;\n    int   max_top_k;\n    int   batch_size;\n    int   vocab_size;\n    int   vocab_size_padded;\n};\n\ntemplate<typename T>\nvoid invokeTopKSortFilter(TopKSortFilterParams& params, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/sampling_topp_kernels.cu",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#ifndef CUDART_VERSION\n#error CUDART_VERSION Undefined!\n#elif (CUDART_VERSION >= 11000)\n#include <cub/cub.cuh>\n#else\n#include \"3rdparty/cub/cub.cuh\"\n#endif\n\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/reduce_kernel_utils.cuh\"\n#include \"src/turbomind/kernels/sampling_topp_kernels.h\"\n\n#include \"src/turbomind/utils/constant.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\n__global__ void topPSortInitialize(const int    vocab_size_padded,\n                                   const int    vocab_size,\n                                   const size_t batch_size,\n                                   const int*   top_ks,\n                                   int*         topp_id_val_buf,\n                                   int*         begin_offset_buf,\n                                   int*         end_offset_buf)\n{\n    int tid = threadIdx.x;\n    int bid = blockIdx.x;\n\n    // According to https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html\n    // `num_items` should match the largest element within the range `[d_end_offsets, d_end_offsets + num_segments)`\n    // We need to move `begin_offset` (instead of `end_offset`) to make empty intervals\n    if (bid == 
0) {\n        for (int i = tid; i < batch_size; i += blockDim.x) {\n            int beg = i * vocab_size_padded;\n            int end = i * vocab_size_padded + vocab_size;\n            if (top_ks[i] > 0) {  // already sorted by topk, make it an empty interval\n                beg = end;\n            }\n            begin_offset_buf[i] = beg;\n            end_offset_buf[i]   = end;\n        }\n    }\n\n    int index = tid + bid * blockDim.x;\n    while (index < batch_size * vocab_size_padded) {\n        int batch_id = index / vocab_size_padded;\n        if (top_ks[batch_id] == 0) {\n            // sort by topp\n            topp_id_val_buf[index] = index % vocab_size_padded;\n        }\n        index += blockDim.x * gridDim.x;\n    }\n}\n\nvoid invokeTopPSortInitialize(const int    vocab_size_padded,\n                              const int    vocab_size,\n                              const size_t batch_size,\n                              const int*   top_ks,\n                              int*         topp_id_val_buf,\n                              int*         begin_offset_buf,\n                              int*         end_offset_buf,\n                              cudaStream_t stream)\n{\n    const size_t block_size = 512;\n    const size_t grid_size  = (batch_size * vocab_size_padded + block_size - 1) / block_size;\n    topPSortInitialize<<<grid_size, block_size, 0, stream>>>(\n        vocab_size_padded, vocab_size, batch_size, top_ks, topp_id_val_buf, begin_offset_buf, end_offset_buf);\n}\n\ntemplate<typename T>\nstatic __global__ void softmax(T* logits, const int vocab_size_padded, const int vocab_size, const int* kept)\n{\n    int bid = blockIdx.x;\n    int n   = kept[bid];\n    // skip softmax as it was already done by topk\n    if (n != vocab_size) {\n        return;\n    }\n    logits += bid * vocab_size_padded;\n\n    float            max_val = -1 * FLT_MAX;\n    __shared__ float s_max_val;\n    __shared__ float s_sum_val;\n\n    for (int tid = 
threadIdx.x; tid < vocab_size; tid += blockDim.x) {\n        max_val = max(max_val, (float)logits[tid]);\n    }\n\n    max_val = blockReduceMax<float>((float)max_val);\n    if (threadIdx.x == 0) {\n        s_max_val = max_val;\n    }\n    __syncthreads();\n\n    max_val       = s_max_val;\n    float sum_val = 0.0f;\n    for (int tid = threadIdx.x; tid < vocab_size; tid += blockDim.x) {\n        logits[tid] = __expf((float)logits[tid] - max_val);\n        sum_val += (float)logits[tid];\n    }\n\n    sum_val = blockReduceSum<float>(sum_val);\n    if (threadIdx.x == 0) {\n        s_sum_val = sum_val;\n    }\n    __syncthreads();\n\n    sum_val = s_sum_val;\n    for (int tid = threadIdx.x; tid < vocab_size; tid += blockDim.x) {\n        logits[tid] = ((float)logits[tid] / sum_val);\n    }\n}\n\ntemplate<typename T>\nvoid invokeSoftmax(T*           logits,\n                   const int    vocab_size_padded,\n                   const int    vocab_size,\n                   const int    batch_size,\n                   const int*   kept,\n                   cudaStream_t stream)\n{\n    dim3 grid(batch_size);\n    dim3 block(std::min(vocab_size_padded, 1024));\n    softmax<<<grid, block, 0, stream>>>(logits, vocab_size_padded, vocab_size, kept);\n}\n\n#define INSTANTIATE_INVOKE_SOFTMAX(T)                                                                                  \\\n    template void invokeSoftmax<T>(T * logits,                                                                         \\\n                                   const int    vocab_size_padded,                                                     \\\n                                   const int    vocab_size,                                                            \\\n                                   const int    batch_size,                                                            \\\n                                   const int*   kept,                                                                  
\\\n                                   cudaStream_t stream);\n\nINSTANTIATE_INVOKE_SOFTMAX(float);\n\ntemplate<typename T, int MAX_K, int THREADBLOCK_SIZE>\n__launch_bounds__(THREADBLOCK_SIZE) __global__ void topp_beam_topk_kernel(const T*     logits,\n                                                                          T*           sorted_logits,\n                                                                          int*         sorted_indices,\n                                                                          int*         kept,\n                                                                          const int    vocab_size,\n                                                                          const int    vocab_size_padded,\n                                                                          int*         begin_offset_buf,\n                                                                          int*         end_offset_buf,\n                                                                          const float* top_ps,\n                                                                          const int*   top_ks)\n{\n    int thread_id = threadIdx.x;\n    int batch_id  = blockIdx.x;\n    if (top_ks[batch_id] > 0) {\n        return;\n    }\n\n    logits += batch_id * vocab_size_padded;\n    sorted_logits += batch_id * vocab_size_padded;\n    sorted_indices += batch_id * vocab_size_padded;\n    float p_threshold = top_ps[batch_id];\n\n    typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;\n    __shared__ typename BlockReduce::TempStorage               temp_storage;\n    TopK<T, MAX_K>                                             partial;\n\n    const T MAX_T_VAL = getMaxValue<T>();\n\n#pragma unroll\n    for (int i = 0; i < MAX_K; ++i) {\n        partial.p[i] = -1;\n        partial.u[i] = -MAX_T_VAL;\n    }\n\n#pragma unroll\n    for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {\n      
  partial.insert(logits[elem_id], elem_id);\n    }\n\n    TopK<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, MAX_K>);\n\n    if (thread_id == 0) {\n        float sum_prob = 0.f;\n\n#pragma unroll\n        for (int i = 0; i < MAX_K; i++) {\n            sum_prob += (float)total.u[i];\n        }\n\n        if (sum_prob >= p_threshold) {\n            begin_offset_buf[batch_id] = end_offset_buf[batch_id];\n            kept[batch_id]             = MAX_K;\n\n#pragma unroll\n            for (int i = 0; i < MAX_K; ++i) {\n                sorted_logits[i]  = (float)total.u[i] / sum_prob;\n                sorted_indices[i] = total.p[i];\n            }\n        }\n    }\n}\n\ntemplate<typename T>\nvoid invokeTopPSort(TopPSortParams& params, cudaStream_t stream)\n{\n    const int num_items = params.vocab_size_padded * (params.batch_size - 1) + params.vocab_size;\n\n    size_t cub_temp_storage_size{};\n    check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,\n                                                                        cub_temp_storage_size,\n                                                                        (T*)nullptr,\n                                                                        (T*)nullptr,\n                                                                        (int*)nullptr,\n                                                                        (int*)nullptr,\n                                                                        num_items,\n                                                                        params.batch_size,\n                                                                        (int*)nullptr,\n                                                                        (int*)nullptr,\n                                                                        0,              // begin_bit\n                                                                        sizeof(T) * 
8,  // end_bit = sizeof(KeyT) * 8\n                                                                        stream));       // cudaStream_t\n\n    TM_CHECK(core::Context::stream().handle() == stream);\n\n    Buffer_<uint8_t> cub_temp_storage(cub_temp_storage_size, kDEVICE);\n\n    Buffer_<int> topp_ids(params.batch_size * params.vocab_size_padded, kDEVICE);\n    Buffer_<int> beg_offset(params.batch_size, kDEVICE);\n    Buffer_<int> end_offset(params.batch_size, kDEVICE);\n\n    auto topp_ids_buf   = topp_ids.data();\n    auto beg_offset_buf = beg_offset.data();\n    auto end_offset_buf = end_offset.data();\n\n    invokeTopPSortInitialize(params.vocab_size_padded,\n                             params.vocab_size,\n                             params.batch_size,\n                             params.top_ks,\n                             topp_ids_buf,\n                             beg_offset_buf,\n                             end_offset_buf,\n                             stream);\n\n    topp_beam_topk_kernel<T, 1, 256><<<params.batch_size, 256, 0, stream>>>((T*)params.logits,\n                                                                            (T*)params.sorted_logits,\n                                                                            params.sorted_indices,\n                                                                            params.kept,\n                                                                            params.vocab_size,\n                                                                            params.vocab_size_padded,\n                                                                            beg_offset_buf,\n                                                                            end_offset_buf,\n                                                                            params.top_ps,\n                                                                            params.top_ks);\n\n    
check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage.data(),\n                                                                        cub_temp_storage_size,\n                                                                        (T*)params.logits,\n                                                                        (T*)params.sorted_logits,\n                                                                        topp_ids_buf,\n                                                                        params.sorted_indices,\n                                                                        num_items,\n                                                                        params.batch_size,\n                                                                        beg_offset_buf,\n                                                                        end_offset_buf,\n                                                                        0,              // begin_bit\n                                                                        sizeof(T) * 8,  // end_bit = sizeof(KeyT) * 8\n                                                                        stream));       // cudaStream_t\n}\n\ntemplate void invokeTopPSort<float>(TopPSortParams& params, cudaStream_t stream);\n\ntemplate<typename T, int BLOCK_SIZE>\n__global__ void topPMinPFilter(T*           sorted_logits,\n                               int*         sorted_indices,\n                               int*         kept,\n                               const int    vocab_size_padded,\n                               const float* top_ps,\n                               const float* min_ps)\n{\n    int   tid        = threadIdx.x;\n    int   bid        = blockIdx.x;\n    int   n          = kept[bid];\n    float sum_logits = 1.f;\n    float top_p      = top_ps[bid];\n    float min_p      = min_ps[bid];\n    sorted_logits += bid * vocab_size_padded;\n    sorted_indices += 
bid * vocab_size_padded;\n\n    const float kEps = 1e-6f;\n\n    __shared__ int   s_kept;\n    __shared__ float s_sum;\n\n    if (tid == 0) {\n        s_kept = n;\n        s_sum  = 1.f;\n    }\n    __syncthreads();\n\n    if (top_p != 1.0f) {\n        typedef cub::BlockScan<float, BLOCK_SIZE>  BlockScan;\n        __shared__ typename BlockScan::TempStorage temp_storage;\n        // Initialize running total\n        BlockPrefixCallbackOp prefix_op(0);\n        // topp\n        int   end        = ((n + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;\n        float prefix_sum = 0.f;\n        for (int i = tid; i < end; i += BLOCK_SIZE) {\n            float thread_count = (i < n) ? (float)sorted_logits[i] : 0.f;\n            BlockScan(temp_storage).InclusiveSum(thread_count, prefix_sum, prefix_op);\n            auto count = __syncthreads_count(prefix_sum > top_p);\n            if (count != 0 || (i + BLOCK_SIZE >= end)) {\n                if (tid == min(BLOCK_SIZE - count, BLOCK_SIZE - 1)) {\n                    s_kept = min(i + 1, n);\n                    s_sum  = prefix_sum;\n                }\n                break;\n            }\n        };\n        __syncthreads();\n    }\n\n    if (min_p != 0.f) {\n        n          = s_kept;\n        sum_logits = s_sum;\n\n        typedef cub::BlockScan<float, BLOCK_SIZE>  BlockScan;\n        __shared__ typename BlockScan::TempStorage temp_storage;\n        // Initialize running total\n        BlockPrefixCallbackOp prefix_op(0);\n        // minp\n        float scaled_min_p = (float)sorted_logits[0] / (sum_logits + kEps) * min_p;\n        int   end          = ((n + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;\n        float prefix_sum   = 0.f;\n        for (int i = tid; i < end; i += BLOCK_SIZE) {\n            float thread_count = (i < n) ? 
(float)sorted_logits[i] / (sum_logits + kEps) : 0.f;\n            BlockScan(temp_storage).ExclusiveSum(thread_count, prefix_sum, prefix_op);\n            auto count = __syncthreads_count(thread_count < scaled_min_p);\n            if (count != 0 || (i + BLOCK_SIZE >= end)) {\n                if (tid == min(BLOCK_SIZE - count, BLOCK_SIZE - 1)) {\n                    if (count == 0) {\n                        ++i;\n                        prefix_sum += thread_count;\n                    }\n                    s_kept = min(i, n);\n                    s_sum *= prefix_sum;\n                }\n                break;\n            }\n        };\n        __syncthreads();\n    }\n\n    if (top_p != 1.f || min_p != 0.f) {\n        n          = s_kept;\n        sum_logits = s_sum;\n        if (tid == 0) {\n            kept[bid] = n;\n        }\n        // norm\n        for (int i = tid; i < n; i += BLOCK_SIZE) {\n            sorted_logits[i] = (float)sorted_logits[i] / sum_logits;\n        }\n    }\n}\n\ntemplate<typename T>\nvoid invokeTopPMinPFilter(TopPMinPFilterParams& params, cudaStream_t stream)\n{\n    topPMinPFilter<T, 256><<<params.batch_size, 256, 0, stream>>>((T*)params.sorted_logits,\n                                                                  params.sorted_indices,\n                                                                  params.kept,\n                                                                  params.vocab_size_padded,\n                                                                  params.top_ps,\n                                                                  params.min_ps);\n}\n\ntemplate void invokeTopPMinPFilter<float>(TopPMinPFilterParams& params, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/sampling_topp_kernels.h",
    "content": "/*\n * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n\n#include <curand_kernel.h>\n\nnamespace turbomind {\n\nvoid invokeTopPSortInitialize(const int    vocab_size_padded,\n                              const int    vocab_size,\n                              const size_t batch_size,\n                              const int*   top_ks,\n                              int*         topp_id_val_buf,\n                              int*         begin_offset_buf,\n                              int*         end_offset_buf,\n                              cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeSoftmax(T*           logits,\n                   const int    vocab_size_padded,\n                   const int    vocab_size,\n                   const int    batch_size,\n                   const int*   kept,\n                   cudaStream_t stream);\n\nstruct BlockPrefixCallbackOp {\n    // Running prefix\n    float running_total;\n    // Constructor\n    __device__ BlockPrefixCallbackOp(float running_total): running_total(running_total) {}\n    // Callback operator to be entered by the first warp of threads in the block.\n    // Thread-0 is responsible for returning a value for seeding the block-wide scan.\n    __device__ float operator()(float block_aggregate)\n    {\n        float old_prefix = running_total;\n        running_total += 
block_aggregate;\n        return old_prefix;\n    }\n};\n\nstruct TopPSortParams {\n    void*  logits;\n    void*  sorted_logits;\n    int*   sorted_indices;\n    int*   kept;\n    int*   top_ks;\n    float* top_ps;\n    int    batch_size;\n    int    vocab_size;\n    int    vocab_size_padded;\n};\n\ntemplate<typename T>\nvoid invokeTopPSort(TopPSortParams& params, cudaStream_t stream);\n\nstruct TopPMinPFilterParams {\n    void*  sorted_logits;\n    int*   sorted_indices;\n    int*   kept;\n    float* top_ps;\n    float* min_ps;\n    int    batch_size;\n    int    vocab_size;\n    int    vocab_size_padded;\n};\n\ntemplate<typename T>\nvoid invokeTopPMinPFilter(TopPMinPFilterParams& params, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/stop_criteria_kernels.cu",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/stop_criteria_kernels.h\"\n\nnamespace turbomind {\n\n__global__ void stop_words_criterion_v2(const int** token_ids_ptrs,\n                                        const int*  sequence_length,\n                                        const int*  stop_words,\n                                        bool*       finished,\n                                        int         stop_words_len,\n                                        int         batch_size)\n{\n    const int id        = blockIdx.x * blockDim.x + threadIdx.x;\n    const int batch_idx = blockIdx.y;\n\n    const int* base_stop_words = stop_words + batch_idx * 2 * stop_words_len;\n    const int* base_offsets    = base_stop_words + stop_words_len;\n\n    if (id >= stop_words_len || base_offsets[id] < 0) {\n        return;\n    }\n\n    const int item_end   = base_offsets[id];\n    const int item_start = (id > 0) ? 
base_offsets[id - 1] : 0;\n    const int item_size  = item_end - item_start;\n\n    const int  seq_len   = sequence_length[batch_idx];\n    const int* token_ids = token_ids_ptrs[batch_idx];\n\n    /* Enough previously generated tokens to look for a match */\n    if (seq_len >= item_size) {\n        // token_ids[seq_len - 1] is the last token\n        for (int token_idx = item_size - 1, offset = seq_len - 1; token_idx >= 0; token_idx--, offset--) {\n            if (token_ids[offset] != base_stop_words[item_start + token_idx]) {\n                return;\n            }\n        }\n        finished[batch_idx] = true;\n    }\n}\n\nvoid invokeStopWordsCriterion_v2(const int**  token_ids_ptrs,\n                                 const int*   sequence_length,\n                                 const int*   stop_words,\n                                 bool*        finished,\n                                 int          stop_words_len,\n                                 int          batch_size,\n                                 cudaStream_t stream)\n{\n    // Check if we have sampled a word from the stop_words list. 
If so, stop the sequence.\n\n    const int  block = std::min(round_up(stop_words_len, 32), 256);\n    const dim3 grid(cdiv(stop_words_len, block), batch_size);\n\n    stop_words_criterion_v2<<<grid, block, 0, stream>>>(\n        token_ids_ptrs, sequence_length, stop_words, finished, stop_words_len, batch_size);\n}\n\n__global__ void length_criterion_v2(bool*      finished,  //\n                                    const int* sequence_length,\n                                    const int* sequence_length_limit,\n                                    int        batch_size)\n{\n    const int idx = threadIdx.x + blockDim.x * blockIdx.x;\n    if (idx >= batch_size) {\n        return;\n    }\n    if (sequence_length[idx] >= sequence_length_limit[idx]) {\n        finished[idx] = true;\n    }\n}\n\nvoid invokeLengthCriterion_v2(bool*        finished,  //\n                              const int*   sequence_length,\n                              const int*   sequence_length_limit,\n                              int          batch_size,\n                              cudaStream_t stream)\n{\n    // Check if we have attained the sequence length limit. If so, stop the sequence.\n\n    constexpr int block = 256;\n    const int     grid  = cdiv(batch_size, block);\n\n    length_criterion_v2<<<grid, block, 0, stream>>>(finished, sequence_length, sequence_length_limit, batch_size);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/stop_criteria_kernels.h",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n\n#include <cstdint>\n\n#include <cuda_runtime.h>\n\nnamespace turbomind {\n\nvoid invokeStopWordsCriterion_v2(const int**  token_ids_ptrs,\n                                 const int*   sequence_length,\n                                 const int*   stop_words,\n                                 bool*        finished,\n                                 int          stop_words_len,\n                                 int          batch_size,\n                                 cudaStream_t stream);\n\nvoid invokeLengthCriterion_v2(bool*        finished,  //\n                              const int*   sequence_length,\n                              const int*   sequence_length_limit,\n                              int          batch_size,\n                              cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/test_quantization.cc",
    "content": "\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/stream.h\"\n\n#include \"src/turbomind/kernels/gemm/test/test_utils.h\"\n#include \"src/turbomind/kernels/quantization.h\"\n\nusing namespace turbomind;\n\nint main()\n{\n    core::ContextGuard ctx{core::Stream::create(), core::Allocator{kCPU}, core::Allocator{kDEVICE}};\n\n    auto stream = core::Context::stream().handle();\n\n    const int m = 1024, n = 2048, gs = 128;\n\n    Tensor_<bfloat16_t> h_x{{m, n}, kCPU};\n    Tensor_<bfloat16_t> h_x_f{{m, n}, kCPU};\n\n    Tensor_<bfloat16_t> x{{m, n}, kDEVICE};\n    Tensor_<bfloat16_t> x_f{{m, n}, kDEVICE};\n    Tensor_<fp8_e4m3_t> x_q{{m, n}, kDEVICE};\n\n    // Tensor_<float> x_s{{{m, n / gs}, {1, round_up(m, 4)}}, kDEVICE};\n    Tensor_<float> x_s;\n\n    RNG r;\n    r.set_stream(stream);\n\n    /////////////////////////////////////////////////////////////////////////////////////\n    // round trip of dequant(quant(x))\n    r.UniformFloat(x, 2.f, 2.f);  // [-1, +1]\n    Copy(x, h_x);\n    QuantizeSymm(x_q, x_s, x, stream);\n    DequantizeSymm(x_f, x_q, x_s, stream);\n    Copy(x_f, h_x_f);\n    FC_Header();\n    FC_Print(FastCompare(x_f, x, stream));\n\n    /////////////////////////////////////////////////////////////////////////////////////\n    // round trip of dequant(quant(dequant(quant(x)))), aligned representable values\n    Copy(x_f, x);\n    Clear(x_f);\n    QuantizeSymm(x_q, x_s, x, stream);\n    DequantizeSymm(x_f, x_q, x_s, stream);\n    FC_Print(FastCompare(x_f, x, stream));\n\n    /////////////////////////////////////////////////////////////////////////////////////\n    // round trip of dequant(quant(x))\n    // x_s = {{cdiv(m, gs), cdiv(n, gs)}, kDEVICE};\n    x_s = {};\n    r.UniformFloat(x, 2.f, 2.f);  // [-1, +1]\n    Copy(x, h_x);\n    QuantizeSymmBlock(x_q, x_s, x, stream);\n    DequantizeSymmBlock(x_f, x_q, x_s, 
stream);\n    FC_Print(FastCompare(x_f, x, stream));\n\n    /////////////////////////////////////////////////////////////////////////////////////\n    // round trip of dequant(quant(dequant(quant(x)))), aligned representable values\n    Copy(x_f, x);\n    Clear(x_f);\n    QuantizeSymmBlock(x_q, x_s, x, stream);\n    DequantizeSymmBlock(x_f, x_q, x_s, stream);\n    FC_Print(FastCompare(x_f, x, stream));\n\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/kernels/unfused_attention_kernels.cu",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <limits>\n#include <type_traits>\n\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/reduce_kernel_utils.cuh\"\n#include \"src/turbomind/kernels/unfused_attention_kernels.h\"\n\n#include \"src/turbomind/utils/cuda_type_utils.cuh\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\ntemplate<typename T, int ITEMS_PER_THREAD>\n__global__ void __launch_bounds__(1024) softmax_kernel(T*           attn_score,\n                                                       const float* qk,\n                                                       const T*     attn_mask,\n                                                       const T*     sinks,\n                                                       const int    batch_size,\n                                                       const int    head_num,\n                                                       const int    q_length,\n                                                       const int    k_length)\n{\n    // attn_score [batch_size, num_heads, q_length, k_length]\n    // qk         [batch_size, num_heads, q_length, k_length]\n    // attn_mask  [batch_size,            q_length, k_length]\n\n    
const long bi = blockIdx.y;  // Batch index.\n    const int  hi = blockIdx.z;  // Head index.\n\n    __shared__ float s_mean, s_max;\n\n    float sink = -std::numeric_limits<float>::infinity();\n    if (sinks) {\n        sink = sinks[hi];\n    }\n\n    // Loop along with Q dimension.\n    for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {\n\n        float data[ITEMS_PER_THREAD];\n        long  qk_offset;\n        float local_max = -std::numeric_limits<float>::infinity();\n\n        // Loop along with K dimension.\n        for (int i = 0; i < ITEMS_PER_THREAD; i++) {\n            if (int ki = blockDim.x * i + threadIdx.x; ki < k_length) {  // Index of K dimension.\n\n                qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + ki;\n\n                float qk_val  = static_cast<float>(qk[qk_offset]);\n                float qk_bias = 0.0f;\n\n                long  mask_offset = (bi * q_length + qi) * k_length + ki;\n                float mask_val    = static_cast<float>(ldg(&attn_mask[mask_offset]));\n\n                if (!mask_val) {\n                    qk_bias -= std::numeric_limits<float>::infinity();\n                }\n\n                data[i]   = qk_val + qk_bias;\n                local_max = fmaxf(local_max, data[i]);\n            }\n        }\n\n        float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);\n\n        if (threadIdx.x == 0) {\n            s_max = fmaxf(max_val, sink);\n        }\n\n        __syncthreads();\n\n        float local_sum = 0;\n\n        for (int i = 0; i < ITEMS_PER_THREAD; i++) {\n            if (blockDim.x * i + threadIdx.x < k_length) {\n                data[i] = expf(data[i] - s_max);\n                local_sum += data[i];\n            }\n        }\n\n        float sum_val = blockDim.x <= 32 ? 
warpReduceSum(local_sum) : blockReduceSum<float>(local_sum);\n\n        if (threadIdx.x == 0) {\n            sum_val += expf(sink - s_max);\n            s_mean = sum_val;\n            s_mean = fdividef(1.f, s_mean);\n        }\n        __syncthreads();\n\n        for (int i = 0; i < ITEMS_PER_THREAD; i++) {\n            if (blockDim.x * i + threadIdx.x < k_length) {\n                qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + blockDim.x * i + threadIdx.x;\n                attn_score[qk_offset] = (T)(data[i] * s_mean);\n            }\n        }\n    }\n}\n\ntemplate<typename T>\nvoid invokeMaskedSoftmax(MaskedSoftmaxParam<T>& param, cudaStream_t stream)\n{\n    // attention_score,    (batch_size, head_num, q_length, k_length), softmax output.\n    // qk,                 (batch_size, head_num, q_length, k_length), QK^T.\n    // attention_mask,     (batch_size, q_length, k_length), attention mask.\n\n    dim3 grid(param.q_length, param.batch_size, param.num_heads);\n\n    auto invoke = [&](auto items_per_thread) {\n        const int block = round_up(cdiv(param.k_length, items_per_thread.value), WARP_SIZE);\n        FT_CHECK(block <= 1024);\n        softmax_kernel<T, items_per_thread.value><<<grid, block, 0, stream>>>(param.attention_score,\n                                                                              param.qk,\n                                                                              param.attention_mask,\n                                                                              param.sinks,\n                                                                              param.batch_size,\n                                                                              param.num_heads,\n                                                                              param.q_length,\n                                                                              param.k_length);\n    };\n\n    const auto k = param.k_length;\n\n    if 
(k <= 1024) {\n        invoke(std::integral_constant<int, 1>{});\n    }\n    else if (k <= 2048) {\n        invoke(std::integral_constant<int, 2>{});\n    }\n    else if (k <= 4096) {\n        invoke(std::integral_constant<int, 4>{});\n    }\n    else if (k <= 8192) {\n        invoke(std::integral_constant<int, 8>{});\n    }\n    else if (k <= 16384) {\n        invoke(std::integral_constant<int, 16>{});\n    }\n    else if (k <= 32768) {\n        invoke(std::integral_constant<int, 32>{});\n    }\n    else if (k <= 65536) {\n        invoke(std::integral_constant<int, 64>{});\n    }\n    else if (k <= 131072) {\n        invoke(std::integral_constant<int, 128>{});\n    }\n    else {\n        throw std::runtime_error(\"not implemented\");\n    }\n}\n\ntemplate void invokeMaskedSoftmax(MaskedSoftmaxParam<half>& param, cudaStream_t stream);\n#ifdef ENABLE_BF16\ntemplate void invokeMaskedSoftmax(MaskedSoftmaxParam<nv_bfloat16>& param, cudaStream_t stream);\n#endif\n#if ENABLE_FP32\ntemplate void invokeMaskedSoftmax(MaskedSoftmaxParam<float>& param, cudaStream_t stream);\n#endif\n\n// clang-format off\ntemplate<typename T> struct packed_type;\ntemplate <>          struct packed_type<float>         { using type = float; }; // we don't need to pack float by default\ntemplate <>          struct packed_type<half>          { using type = half2; };\n\n#ifdef ENABLE_BF16\ntemplate<>\nstruct packed_type<__nv_bfloat16> {\n    using type = __nv_bfloat162;\n};\n#endif\n\ntemplate<typename T> struct num_elems;\ntemplate <>          struct num_elems<float>           { static constexpr int value = 1; };\ntemplate <>          struct num_elems<float2>          { static constexpr int value = 2; };\ntemplate <>          struct num_elems<float4>          { static constexpr int value = 4; };\ntemplate <>          struct num_elems<half>            { static constexpr int value = 1; };\ntemplate <>          struct num_elems<half2>           { static constexpr int value = 2; };\n#ifdef 
ENABLE_BF16\ntemplate <>          struct num_elems<__nv_bfloat16>   { static constexpr int value = 1; };\ntemplate <>          struct num_elems<__nv_bfloat162>  { static constexpr int value = 2; };\n#endif\n\ntemplate<typename T, int num> struct packed_as;\ntemplate<typename T>          struct packed_as<T, 1>              { using type = T; };\ntemplate<>                    struct packed_as<half,  2>          { using type = half2; };\ntemplate<>                    struct packed_as<float,  2>         { using type = float2; };\ntemplate<>                    struct packed_as<int8_t, 2>         { using type = int16_t; };\ntemplate<>                    struct packed_as<int32_t, 2>        { using type = int2; };\ntemplate<>                    struct packed_as<half2, 1>          { using type = half; };\n#ifdef ENABLE_BF16\ntemplate<> struct packed_as<__nv_bfloat16,  2> { using type = __nv_bfloat162; };\ntemplate<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16;  };\n#endif\n\ninline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); }\ninline __device__ float2 operator*(float2 a, float  b) { return make_float2(a.x * b, a.y * b); }\n// clang-format on\n\ntemplate<typename T>\n__global__ void transpose_remove_padding(const T*     src,\n                                         T*           dst,\n                                         const int    batch_size,\n                                         const int    seq_len,\n                                         const int    head_num,\n                                         const int    size_per_head,\n                                         const int*   mask_offset,\n                                         const float* scale,\n                                         const int    int8_mode)\n{\n    // TODO: optimize this kernel?\n    // do remove_sequence_length_padding\n    const int bid = blockIdx.x;  // batch * seq_len or valid_word_num\n\n    const int 
token_offset = mask_offset ? mask_offset[bid] : 0;\n\n    const int src_batch_id = (bid + token_offset) / seq_len;\n    const int src_seq_id   = (bid + token_offset) % seq_len;\n\n    const int dst_seq_id = bid;\n\n    const int src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head;\n    const int dst_offset_base = dst_seq_id * head_num * size_per_head;\n\n    using Int8_Packed_T  = typename packed_as<int8_t, num_elems<T>::value>::type;\n    using Float_Packed_T = typename packed_as<float, num_elems<T>::value>::type;\n    const Float_Packed_T scale_val =\n        int8_mode == 2 ? cuda_cast<Float_Packed_T>(*scale) : cuda_cast<Float_Packed_T>(0.0f);\n\n    for (int idx = threadIdx.x; idx < head_num * size_per_head; idx += blockDim.x) {\n        const int head_id   = idx / size_per_head;\n        const int hidden_id = idx % size_per_head;\n        const T   src_elem  = ldg(&src[src_offset_base + head_id * seq_len * size_per_head + hidden_id]);\n        if (int8_mode == 2) {\n            reinterpret_cast<Int8_Packed_T*>(dst)[dst_offset_base + idx] =\n                cuda_cast<Int8_Packed_T>(cuda_cast<Float_Packed_T>(src_elem) * scale_val);\n        }\n        else {\n            dst[dst_offset_base + idx] = src_elem;\n        }\n    }\n}\n\n// clang-format off\ntemplate<typename T>\nvoid invokeTransposeAttentionOutRemovePadding(T*           src,\n                                              T*           dst,\n                                              const int    valid_word_num,\n                                              const int    batch_size,\n                                              const int    seq_len,\n                                              const int    head_num,\n                                              const int    size_per_head,\n                                              const int*   mask_offset,\n                                              const float* scale,\n                         
                     const int    int8_mode,\n                                              cudaStream_t stream)\n{\n#ifdef ENABLE_BF16\n    bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (size_per_head % 2 == 0);\n#else\n    bool is_half2 = (std::is_same<T, half>::value) && (size_per_head % 2 == 0);\n#endif\n    using T2       = typename TypeConverter<T>::Type;  // fp16 to half2, bf16 to bf162\n    int block_size = head_num * size_per_head;\n    if (is_half2) {\n        while (block_size > 512) {\n            if (block_size % 2 == 0) {\n                block_size /= 2;\n            }\n            else {\n                is_half2   = false;\n                block_size = std::min(block_size, 1024);\n                break;\n            }\n        }\n    }\n    else {\n        block_size = std::min(block_size, 1024);\n    }\n\n    if (is_half2) {\n        transpose_remove_padding<T2><<<valid_word_num, block_size, 0, stream>>>(\n            (T2*)src, (T2*)dst, batch_size, seq_len, head_num, size_per_head / 2, mask_offset, scale, int8_mode);\n    }\n    else {\n        transpose_remove_padding<<<valid_word_num, block_size, 0, stream>>>(\n            src, dst, batch_size, seq_len, head_num, size_per_head, mask_offset, scale, int8_mode);\n    }\n}\n// clang-format on\n\n#define INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(T)                                                               \\\n    template void invokeTransposeAttentionOutRemovePadding(T*           src,                                           \\\n                                                           T*           dst,                                           \\\n                                                           const int    valid_word_num,                                \\\n                                                           const int    batch_size,                                    \\\n                                                         
  const int    seq_len,                                       \\\n                                                           const int    head_num,                                      \\\n                                                           const int    size_per_head,                                 \\\n                                                           const int*   mask_offset,                                   \\\n                                                           const float* scale,                                         \\\n                                                           const int    int8_mode,                                     \\\n                                                           cudaStream_t stream)\n#ifdef ENABLE_FP32\nINSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(float);\n#endif\nINSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(half);\n#ifdef ENABLE_BF16\nINSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(__nv_bfloat16);\n#endif\n#undef INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING\n\ntemplate<typename T>\n__global__ void addRelativeAttentionBias(\n    T* qk_buf, const T* relative_attention_bias, const int batch_size, const int head_num, const int seq_len)\n{\n    for (int i = threadIdx.x; i < batch_size * seq_len; i += blockDim.x) {\n        int batch_id = i / seq_len;\n        int seq_id   = i % seq_len;\n\n        const int bias_index = blockIdx.x * seq_len + seq_id;\n        const int qk_index   = batch_id * gridDim.x * seq_len + bias_index;\n        qk_buf[qk_index]     = add(qk_buf[qk_index], relative_attention_bias[bias_index]);\n    }\n}\n\ntemplate<typename T>\nvoid invokeAddRelativeAttentionBias(T*           qk_buf,\n                                    const T*     relative_attention_bias,\n                                    const int    batch_size,\n                                    const int    head_num,\n                                    const int    seq_len,\n                                    
cudaStream_t stream)\n{\n    // qk_buf: [batch_size, head_num, seq_len, seq_len]\n    // relative_attention_bias: [1, head_num, seq_len, seq_len]\n    dim3 grid(head_num * seq_len);\n    dim3 block(512);\n    using T2 = typename TypeConverter<T>::Type;\n#ifdef ENABLE_BF16\n    const bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (seq_len % 2 == 0);\n#else\n    const bool is_half2 = (std::is_same<T, half>::value) && (seq_len % 2 == 0);\n#endif\n    if (is_half2) {\n        addRelativeAttentionBias<T2><<<grid, block, 0, stream>>>(\n            (T2*)qk_buf, (const T2*)relative_attention_bias, batch_size, head_num, seq_len / 2);\n    }\n    else {\n        addRelativeAttentionBias<<<grid, block, 0, stream>>>(\n            qk_buf, relative_attention_bias, batch_size, head_num, seq_len);\n    }\n}\n\n#define INSTANTIATEADDRELATIVEATTENTIONBIAS(T)                                                                         \\\n    template void invokeAddRelativeAttentionBias(T*           qk_buf,                                                  \\\n                                                 const T*     relative_attention_bias,                                 \\\n                                                 const int    batch_size,                                              \\\n                                                 const int    head_num,                                                \\\n                                                 const int    seq_len,                                                 \\\n                                                 cudaStream_t stream)\n#if 0\n#ifdef ENABLE_FP32\nINSTANTIATEADDRELATIVEATTENTIONBIAS(float);\n#endif\nINSTANTIATEADDRELATIVEATTENTIONBIAS(half);\n#ifdef ENABLE_BF16\nINSTANTIATEADDRELATIVEATTENTIONBIAS(__nv_bfloat16);\n#endif\n#undef INSTANTIATEADDRELATIVEATTENTIONBIAS\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/kernels/unfused_attention_kernels.h",
    "content": "/*\n * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#pragma once\n\nnamespace turbomind {\n\ntemplate<typename T>\nstruct MaskedSoftmaxParam {\n    // Common parameters.\n    T*           attention_score = nullptr;  // (batch_size, head_num, q_length, k_length)\n    const float* qk              = nullptr;  // (batch_size, head_num, q_length, k_length)\n    const T*     attention_mask  = nullptr;  // (batch_size, q_length, k_length)\n    int          batch_size      = 0;\n    int          q_length        = 0;\n    int          k_length        = 0;\n    int          num_heads       = 0;\n    const T*     sinks           = nullptr;\n};\n\ntemplate<typename T>\nvoid invokeMaskedSoftmax(MaskedSoftmaxParam<T>& param, cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeTransposeQKV(T*           dst,\n                        T*           src,\n                        const int    batch_size,\n                        const int    seq_len,\n                        const int    head_num,\n                        const int    size_per_head,\n                        const float* scale,\n                        const int    int8_mode,\n                        cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeTransposeAttentionOutRemovePadding(T*           src,\n                                              T*           dst,\n                                   
           const int    valid_word_num,\n                                              const int    batch_size,\n                                              const int    seq_len,\n                                              const int    head_num,\n                                              const int    size_per_head,\n                                              const int*   mask_offset,\n                                              const float* scale,\n                                              const int    int8_mode,\n                                              cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeAddFusedQKVBiasTranspose(T*           q_buf,\n                                    T*           k_buf,\n                                    T*           v_buf,\n                                    T*           QKV,\n                                    const T*     qkv_bias,\n                                    const int*   padding_offset,\n                                    const int*   context_length,\n                                    const int*   input_length,\n                                    const float* rope_theta,\n                                    const int    batch_size,\n                                    const int    seq_len,\n                                    const int    token_num,\n                                    const int    head_num,\n                                    const int    kv_head_num,\n                                    const int    size_per_head,\n                                    const int    rotary_embedding_dim,\n                                    float        rotary_embedding_base,\n                                    int          max_position_embeddings,\n                                    bool         use_dynamic_ntk,\n                                    bool         use_logn_attn,\n                                    cudaStream_t stream);\n\ntemplate<typename T>\nvoid 
invokeTranspose4d(T*           dst,\n                       T*           src,\n                       const int    local_batch_size,\n                       const int    seq_len,\n                       const int    size_per_head,\n                       const int    local_hidden_units,\n                       const int    local_head_num,\n                       const int    batch_size,\n                       const int    ite,\n                       cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeTranspose4dBatchMajor(T*           k_dst,\n                                 T*           v_dst,\n                                 const T*     k_src,\n                                 const T*     v_src,\n                                 const int    local_batch_size,\n                                 const int    seq_len,\n                                 const int    max_seq_len,\n                                 const int    size_per_head,\n                                 const int    local_head_num,\n                                 cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeAddRelativeAttentionBias(T*           qk_buf,\n                                    const T*     relative_attention_bias,\n                                    const int    batch_size,\n                                    const int    head_num,\n                                    const int    seq_len,\n                                    cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeAddHead3SizeQKVBias(const T*     mm_qkv,\n                               const T*     bias_qkv,\n                               T*           q_buf_,\n                               T*           k_buf_,\n                               T*           v_buf_,\n                               const int    batch,\n                               const int    window_num,\n                               const int    window_len,\n                               const int    head_num,\n          
                     const int    size_per_head,\n                               cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeMaskedSoftMaxWithRelPosBias(T*           qk_buf,\n                                       const T*     attn_mask,\n                                       const T*     relative_pos_bias,\n                                       const int    batch_size,\n                                       const int    num_head,\n                                       const int    window_num,\n                                       const int    window_len,\n                                       const float  qk_scale,\n                                       cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/macro.h",
    "content": "#pragma once\n\n#if !defined(__PRETTY_FUNCTION__) && !defined(__GNUC__)\n\n#define __PRETTY_FUNCTION__ __FUNCSIG__\n\n#endif\n\ntypedef unsigned int uint;\n"
  },
  {
    "path": "src/turbomind/models/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.8)\n\nadd_library(models STATIC\n        language_model.cc\n        input_processor.cc\n        output_processor.cc\n        llama/LlamaLinear.cu\n        llama/BlockManager.cc\n        llama/BlockTrie.cc\n        llama/SequenceManager.cc\n        llama/LlamaWeight.cc\n        llama/LlamaDenseWeight.cc\n        llama/LlamaDecoderLayerWeight.cc\n        llama/LlamaFfnLayer.cc\n        llama/moe_ffn_layer.cc\n        llama/unified_decoder.cc\n        llama/unified_attention_layer.cc\n        llama/llama_kernels.cu\n        llama/llama_utils.cu\n        llama/mla_utils.cu\n        llama/GatedDeltaNetWeight.cc\n        llama/GatedDeltaNetLayer.cc\n        llama/gated_delta_net_kernels.cu)\nset_property(TARGET models PROPERTY POSITION_INDEPENDENT_CODE ON)\nset_property(TARGET models PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)\ntarget_link_libraries(models PUBLIC\n        generation\n        core\n        gemm2\n        rms_norm\n        CUDA::cublas\n        CUDA::cudart\n        nvidia::cutlass::cutlass\n        activation_kernels\n        activation\n        attention\n        decoding_kernels\n        quantization_kernels\n        unfused_attention_kernels\n        gpt_kernels\n        memory_utils\n        cuda_utils\n        logger\n        anomaly_handler)\ntarget_compile_options(models PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xptxas=-v --generate-line-info --threads 8>)\n\nif(BUILD_TEST)\n    add_executable(bench_gated_delta_net\n            llama/bench_gated_delta_net.cc)\n    target_link_libraries(bench_gated_delta_net PRIVATE\n            models\n            CUDA::cudart)\n\n    add_executable(bench_conv1d_silu\n            llama/bench_conv1d_silu.cc)\n    target_link_libraries(bench_conv1d_silu PRIVATE\n            models\n            CUDA::cudart)\nendif()\n"
  },
  {
    "path": "src/turbomind/models/input_processor.cc",
    "content": "\n#include \"src/turbomind/models/input_processor.h\"\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/engine/request.h\"\n\n#include \"src/turbomind/models/llama/SequenceManager.h\"\n\nnamespace turbomind {\n\nusing std::vector;\n\nstruct InputProcessor::Impl {\npublic:\n    Impl(const EngineParam& engine, const ModelParam& model, int phases):\n        max_batch_size_{engine.max_batch_size}, max_forward_token_num_{engine.max_forward_token_num}\n    {\n        input_ids_buf_         = {max_forward_token_num_, kCPUpinned};\n        input_ids_offsets_buf_ = {max_batch_size_ + 1, kCPUpinned};\n        decode_token_pos_buf_  = {max_batch_size_, kCPUpinned};\n\n        data_.reserve(phases);\n        for (int i = 0; i < phases; ++i) {\n            auto& d              = data_.emplace_back();\n            d.input_ids          = empty_like(input_ids_buf_, kDEVICE);\n            d.input_ids_offsets  = empty_like(input_ids_offsets_buf_, kDEVICE);\n            d.selected_token_pos = empty_like(decode_token_pos_buf_, kDEVICE);\n\n            d.autoreg_ids_pos = {max_batch_size_, kCPU};  // ! 
CPU buffer\n\n            /// TODO: initialize only when required\n            d.input_embeds_buf = {{max_forward_token_num_, (int)model.hidden_units}, model.data_type, kCPUpinned};\n        }\n    }\n\n    int Add(RequestCache& c)\n    {\n        const auto& [r, s] = std::tie(*c.req, *c.seq);\n\n        // trim input embeds\n        if (!s.input_embeds_offsets.empty()) {\n            Interval l{0, (int)s.tokens.size()};\n            using Size    = Interval::Size;\n            auto& embeds  = s.input_embeds;\n            auto& offsets = s.input_embeds_offsets;\n            int   i       = embeds.size() - 1;\n            for (; i >= 0; --i) {\n                Interval r{offsets[i], Size{(int)embeds[i].shape(0)}};\n                if (auto o = r & l) {\n                    if (o.end() < r.end()) {\n                        embeds[i] = embeds[i].slice(0, o.end() - r.begin());\n                    }\n                    break;\n                }\n            }\n            embeds.resize(i + 1);\n            offsets.resize(i + 1);\n        }\n\n        if (auto ranges_ptr = r.inputs.try_(\"input_embedding_ranges\")) {  // [n, 2]\n            auto embeds = r.inputs.at(\"input_embeddings\");                // [k, d]\n            if (ranges_ptr->ndim() != 2 || embeds.ndim() != 2 || ranges_ptr->shape(1) != 2) {\n                /// TODO: reject for invalid shapes\n                return Request::kInvalid;\n            }\n\n            // clone the embeds if the request persists\n            if (!r.session.end_flag) {\n                auto tmp = std::exchange(embeds, empty_like(embeds));\n                std::copy_n((const uint8_t*)tmp.raw_data(), tmp.byte_size(), (uint8_t*)embeds.raw_data());\n            }\n\n            const auto [sum, dim] = embeds.shapes(0, 1);\n            const auto n          = ranges_ptr->shape(0);\n            const auto ranges     = ranges_ptr->data<int>();\n\n            int offset = 0;\n            int last   = c.step0;\n            for (int i 
= 0; i < n; ++i) {\n                Interval range{c.step0 + ranges[i * 2], c.step0 + ranges[i * 2 + 1]};\n                auto     size = (int)range.size();\n                if (range.begin() < last) {\n                    /// TODO: reject for non-sorted ranges\n                    return Request::kInvalid;\n                }\n                if (range.end() > c.seq_len) {\n                    /// TODO: reject for dst range OOB\n                    return Request::kInvalid;\n                }\n                if (offset + size > sum) {\n                    /// TODO: reject for src range OOB\n                    return Request::kInvalid;\n                }\n                s.input_embeds_offsets.push_back(range.begin());\n                s.input_embeds.push_back(embeds.slice(offset, size));  // reference into `embeds`\n                offset += size;\n                last = range.end();\n            }\n        }\n\n        return 0;\n    }\n\n    void Add(int phase, TensorMap& env)\n    {\n        const Buffer_<RequestCache*> rc = env.at(\"requests\").buffer();\n        for (int i = 0; i < rc.size(); ++i) {\n            auto& c = *TM_CHECK_NOTNULL(rc[i]);\n            if (c.status == 0) {\n                c.status = Add(c);\n            }\n        }\n    }\n\n    void Setup(int phase, TensorMap& env)\n    {\n        auto& d    = data_.at(phase);\n        auto& b    = *env.at(\"batch\").data<BatchData*>()[0];\n        auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n        const auto& rc = b.rc;\n\n        input_ids_offsets_buf_[0] = 0;\n        for (int i = 0; i < rc.size(); ++i) {\n            input_ids_offsets_buf_[i + 1] = input_ids_offsets_buf_[i];\n            if (const auto& c = *rc[i]; TM_UNLIKELY(!c.autoregres)) {\n                const auto src = c.token_ids + c.history_len + c.alpha;\n                std::copy_n(src, c.input_len, input_ids_buf_.data() + input_ids_offsets_buf_[i]);\n                // dbg(std::vector<int>(src, src + 
c.input_len));\n                d.autoreg_ids_pos[i] = -1;\n                input_ids_offsets_buf_[i + 1] += c.input_len;\n            }\n            else {\n                d.autoreg_ids_pos[i] = input_ids_offsets_buf_[i];\n                input_ids_offsets_buf_[i + 1] += 1;\n            }\n            decode_token_pos_buf_[i] = input_ids_offsets_buf_[i + 1] - 1;\n        }\n\n        // dbg(core::to_vector<int>(input_ids_offsets_buf_.slice(0, bsz + 1)));\n        // dbg(core::to_vector<int>(decode_token_pos_buf_.slice(0, bsz)));\n\n        copy(input_ids_buf_, input_ids_offsets_buf_[b.bsz], d.input_ids);\n        copy(decode_token_pos_buf_, b.bsz, d.selected_token_pos);\n        copy(input_ids_offsets_buf_, b.bsz + 1, d.input_ids_offsets);\n\n        // dbg(decode_token_pos_buf_[0]);\n\n        d.input_token_num = input_ids_offsets_buf_[b.bsz];\n        // dbg(d.input_token_num);\n\n        env.produce(\"token_num\", Buffer{&d.input_token_num, 1, kCPU});\n\n        ////////////////////////////////////////////////////////////////\n        /// input embeddings\n        d.input_embeds_coords.clear();\n        auto embed_ptr = (uint8_t*)d.input_embeds_buf.raw_data();\n        for (int k = 0; k < rc.size(); ++k) {\n            if (auto& c = *rc[k]; !c.autoregres) {\n                const auto& embeds  = c.seq->input_embeds;\n                const auto& offsets = c.seq->input_embeds_offsets;\n                Interval    p{input_ids_offsets_buf_[k], input_ids_offsets_buf_[k + 1]};\n                Interval    s{c.history_len + c.alpha, p.size()};\n                for (int i = (int)offsets.size() - 1; i >= 0; --i) {\n                    Interval r{offsets[i], Interval::Size{(int)embeds[i].shape(0)}};\n                    auto     o = r & s;\n                    if (auto size = (int)o.size()) {\n                        auto src  = embeds[i].slice(o.begin() - r.begin(), size);\n                        embed_ptr = std::copy_n((const uint8_t*)src.raw_data(), src.byte_size(), 
embed_ptr);\n                        d.input_embeds_coords.emplace_back(size, p.begin() + (o.begin() - s.begin()));\n                    }\n                }\n            }\n        }\n    }\n\n    void Prepare(int phase, TensorMap& env)\n    {\n        auto& d    = data_.at(phase);\n        auto& b    = *env.at(\"batch\").data<BatchData*>()[0];\n        auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n        // last output token + draft tokens\n        const Buffer_<int> autoreg_ids = env.at(\"autoreg_ids\").buffer();\n\n        // core::CopyT copy{};\n\n        if (auto g = copy.group()) {\n            for (int i = 0; i < b.bsz; ++i) {\n                if (auto pos = d.autoreg_ids_pos[i]; pos >= 0) {\n                    TM_CHECK_LT(b.perm[i], b.bs0);\n                    copy(autoreg_ids.data() + b.perm[i], 1, &d.input_ids[pos]);\n                }\n            }\n        }\n\n        env.produce(\"input_ids\", d.input_ids.slice(0, d.input_token_num));\n        env.produce(\"q_offsets\", d.input_ids_offsets.slice(0, b.bsz + 1));\n        env.produce(\"selected_token_pos\", d.selected_token_pos.slice(0, b.bsz));\n    }\n\n    void PatchEmbedding(int phase, Tensor& embeds, BatchCopy& copy)\n    {\n        auto&      d           = data_.at(phase);\n        const auto byte_stride = byte_size(embeds.dtype(), embeds.stride(0));\n        int        offset      = 0;\n        for (const auto& [size, pos] : d.input_embeds_coords) {\n            auto src = d.input_embeds_buf.slice(offset, size);\n            copy((uint8_t*)src.raw_data(), src.byte_size(), (uint8_t*)embeds.raw_data() + byte_stride * pos);\n            offset += size;\n        }\n    }\n\nprivate:\n    struct Data {\n        Buffer_<int> input_ids;\n        Buffer_<int> input_ids_offsets;\n        int          input_token_num;\n\n        Buffer_<int> selected_token_pos;\n\n        Buffer_<int> autoreg_ids_pos;\n\n        Tensor                      input_embeds_buf;\n        vector<std::pair<int, 
int>> input_embeds_coords;  // (size, pos)\n    };\n\nprivate:\n    const int max_batch_size_;\n    const int max_forward_token_num_;\n\n    vector<Data> data_;\n\n    Buffer_<int> input_ids_buf_;\n    Buffer_<int> input_ids_offsets_buf_;\n\n    Buffer_<int> decode_token_pos_buf_;\n};\n\nInputProcessor::~InputProcessor() = default;\n\nInputProcessor::InputProcessor(const EngineParam& engine, const ModelParam& model, int phases):\n    impl_{std::make_unique<Impl>(engine, model, phases)}\n{\n}\n\nvoid InputProcessor::Run(BatchOp op, int phase, TensorMap& env)\n{\n    switch (op) {\n        case BatchOp::kAdd:\n            return impl_->Add(phase, env);\n        case BatchOp::kSetup:\n            return impl_->Setup(phase, env);\n        case BatchOp::kPrepare:\n            return impl_->Prepare(phase, env);\n        default:\n            return;\n    }\n}\n\nvoid InputProcessor::PatchEmbedding(int phase, Tensor& embeds, BatchCopy& copy)\n{\n    impl_->PatchEmbedding(phase, embeds, copy);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/input_processor.h",
    "content": "#pragma once\n\n#include <memory>\n\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nclass InputProcessor {\npublic:\n    ~InputProcessor();\n\n    InputProcessor(const EngineParam& engine, const ModelParam& model, int phases);\n\n    void Run(BatchOp op, int phase, TensorMap& env);\n\n    void PatchEmbedding(int phase, Tensor& embeds, BatchCopy& copy);\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/language_model.cc",
    "content": "\n#include \"src/turbomind/models/language_model.h\"\n\n#include <memory>\n\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/copy.h\"\n#include \"src/turbomind/core/interval.h\"\n#include \"src/turbomind/core/state.h\"\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/engine/request.h\"\n#include \"src/turbomind/generation/generation.h\"\n#include \"src/turbomind/kernels/gpt_kernels.h\"\n#include \"src/turbomind/models/input_processor.h\"\n#include \"src/turbomind/models/llama/LlamaWeight.h\"\n#include \"src/turbomind/models/llama/llama_kernels.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/models/llama/unified_decoder.h\"\n#include \"src/turbomind/models/output_processor.h\"\n#include \"src/turbomind/utils/anomaly_handler.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nusing std::vector;\nusing std::unique_ptr;\nusing std::shared_ptr;\n\nstruct LanguageModel::Impl {\n    const DataType       dtype_;\n    const ModelParam     param_;\n    const AttentionParam attn_param_;\n    const Communicators& comm_;\n    const LlamaWeight&   weights_;\n    LlamaLinear&         linear_;\n\n    const int  tp_size_;\n    const int  tp_rank_;\n    const bool use_ag2d_;\n\n    const bool debug_;\n\n    Buffer_<bool> false_;\n\n    // mutable state\n    State finished_;\n    State sequence_length_;  // length of known tokens\n    // immutable state\n    Buffer_<int> autoreg_ids_;\n    // Buffer_<int> autoreg_ids_offsets_;\n\n    // Symmetric buffer for holding global hidden states or logits\n    Buffer_<uint8_t> symm_buf_;\n\n    // Max chunk size for compute / output full logits\n    int max_logits_len_ = 0;\n\n    Buffer_<int>  
sequence_length_buf_;\n    Buffer_<bool> finished_buf_;\n\n    struct Data {\n        Buffer_<int>  sequence_length;\n        Buffer_<bool> finished;\n\n        Buffer_<bool> autoregres;\n        Buffer_<bool> generating;\n\n        int n_generating;\n    };\n\n    vector<Data> data_;\n\n    std::optional<InputProcessor>   input_processor_;\n    std::unique_ptr<UnifiedDecoder> unified_decoder_;\n    std::optional<OutputProcessor>  output_processor_;\n    std::unique_ptr<Generation>     generation_;  // token generator\n\n    void Run(BatchOp op, int phase, TensorMap& env)\n    {\n        switch (op) {\n            case BatchOp::kSetup:\n                return Setup(phase, env);\n            case BatchOp::kPrepare:\n                return Prepare(phase, env);\n            case BatchOp::kForward:\n                return Forward(phase, env);\n            case BatchOp::kUnprep:\n                return Unprep(phase, env);\n            case BatchOp::kFetch:\n                return Fetch(phase, env);\n            default:\n                input_processor_->Run(op, phase, env);\n                unified_decoder_->Run(op, phase, env);\n                generation_->Run(op, phase, env);\n                output_processor_->Run(op, phase, env);\n        }\n    }\n\n    Impl(DataType              dtype,\n         const ModelParam&     model,\n         const EngineParam&    engine,\n         const AttentionParam& attn,\n         const MoeParam&       moe,\n         const Context&        ctx,\n         const LlamaWeight&    weights,\n         int                   phases);\n\n    Tensor LookupEmbedding(const Buffer_<int>& input_ids, Buffer symm_buf);\n    Tensor PostEmbedding(const Tensor& features, Buffer symm_buf);\n\n    void Setup(int phase, TensorMap& env);\n    void Prepare(int phase, TensorMap& env);\n    void Forward(int phase, TensorMap& env);\n    void Unprep(int phase, TensorMap& env);\n    void Fetch(int phase, TensorMap& env);\n};\n\nLanguageModel::Impl::Impl(DataType  
            dtype,\n                          const ModelParam&     model,\n                          const EngineParam&    engine,\n                          const AttentionParam& attn,\n                          const MoeParam&       moe,\n                          const Context&        ctx,\n                          const LlamaWeight&    weights,\n                          int                   phases):\n    dtype_{dtype},\n    param_{model},\n    attn_param_{attn},\n    comm_{ctx.comm},\n    weights_{weights},\n    linear_{*ctx.linear},\n    tp_size_{comm_.h_tp_group->n_ranks()},\n    tp_rank_{comm_.h_tp_group->rank()},\n    use_ag2d_{comm_.d_comm && comm_.d_comm->Query(comm::kHasAllGather2D)},\n    debug_{isDebug()}\n{\n\n    false_ = {engine.max_batch_size, kDEVICE};\n    Clear(false_);\n\n    finished_buf_ = {engine.max_batch_size, kCPUpinned};\n    finished_     = {{engine.max_batch_size}, kBool, kDEVICE};\n\n    autoreg_ids_ = {engine.max_batch_size, kDEVICE};\n    // autoreg_ids_offsets_ = {engine.max_batch_size + 1, kCPU};\n    // std::fill_n(autoreg_ids_offsets_.data(), autoreg_ids_offsets_.size(), 0);\n\n    sequence_length_buf_ = {engine.max_batch_size, kCPUpinned};\n    sequence_length_     = {{engine.max_batch_size}, kInt, kDEVICE};\n    for (int i = 0; i < phases; ++i) {\n        auto& d           = data_.emplace_back();\n        d.sequence_length = empty_like(sequence_length_buf_, kDEVICE);\n        d.finished        = empty_like(finished_buf_, kDEVICE);\n        d.autoregres      = {engine.max_batch_size, kCPU};\n        d.generating      = {engine.max_batch_size, kCPU};\n    }\n\n    input_processor_.emplace(engine, param_, phases);\n\n    unified_decoder_ = std::make_unique<UnifiedDecoder>(model, engine, attn, moe, ctx, phases);\n\n    generation_ = std::make_unique<Generation>(kFloat32,\n                                               engine.max_batch_size,\n                                               engine.session_len,\n                   
                            model.vocab_size,\n                                               weights.post_decoder_embedding.output_dim * tp_size_,\n                                               comm_.h_tp_group,\n                                               phases);\n\n    const int     vocab_size     = weights_.post_decoder_embedding.output_dim * tp_size_;\n    const ssize_t max_fwd_tokens = engine.max_forward_token_num;\n\n    if (ctx.comm.d_comm) {\n        auto symm_alloc = GetSymmAllocator(ctx.comm.d_comm);\n        // Native comm fuses allreduce & rmsnorm in token granularity\n        TM_CHECK(engine.max_forward_token_num % tp_size_ == 0);\n\n        ssize_t bytes{};\n        bytes = std::max(bytes, byte_size(dtype_, max_fwd_tokens * engine.attn_dp_size * model.hidden_units));\n        bytes = std::max(bytes, byte_size(dtype_, engine.max_batch_size * vocab_size));\n\n        symm_buf_ = {bytes, symm_alloc};\n        // Compute max logits length based on symm buffer size\n        max_logits_len_ = symm_buf_.view(dtype_).size() / vocab_size;\n    }\n    else {\n        max_logits_len_ = std::max<int>(max_fwd_tokens * model.hidden_units / vocab_size, engine.max_batch_size);\n    }\n\n    output_processor_.emplace(param_, max_logits_len_, tp_rank_, phases, [this](const Tensor& hstate) {\n        return PostEmbedding(hstate, symm_buf_);\n    });\n}\n\nTensor LanguageModel::Impl::LookupEmbedding(const Buffer_<int>& input_ids, Buffer symm_buf)\n{\n    const auto st = core::Context::stream().handle();\n\n    const int hidden_units = param_.hidden_units;\n\n    const auto& embedding_table = weights_.pre_decoder_embedding.weight;\n    TM_CHECK_EQ(embedding_table.shape(1) * tp_size_, hidden_units);\n\n    const int token_num = input_ids.size();\n\n    Tensor input_embeds{{token_num, hidden_units}, dtype_, kDEVICE};\n\n    if (token_num == 0) {\n        return input_embeds;\n    }\n\n    if (tp_size_ == 1) {\n        invokeEmbeddingLookup(input_embeds, input_ids, 
embedding_table, st);\n        sync_check_cuda_error();\n    }\n    else if (use_ag2d_) {\n        const auto local_hidden_units = embedding_table.shape(1);\n\n        Tensor temp{symm_buf.view(dtype_), {token_num, tp_size_, local_hidden_units}};\n        Tensor local{temp.slice({0, tp_rank_, 0}, {-1, 1, -1}).squeeze(1)};\n\n        invokeEmbeddingLookup(local, input_ids, embedding_table, st);\n        sync_check_cuda_error();\n\n        comm_.d_comm->AllGather2D(local.raw_data(),\n                                  temp.raw_data(),\n                                  hidden_units,\n                                  local_hidden_units,\n                                  local_hidden_units,\n                                  token_num,\n                                  local.dtype(),\n                                  {true, true},\n                                  comm_.d_tp_group,\n                                  st);\n        sync_check_cuda_error();\n\n        Copy(temp.buffer(), input_embeds.buffer());\n    }\n    else {\n        const auto local_hidden_units = embedding_table.shape(1);\n\n        Tensor temp{symm_buf.view(dtype_), {tp_size_, token_num, local_hidden_units}};\n        Tensor local{temp.slice(tp_rank_).squeeze(0)};\n\n        invokeEmbeddingLookup(local, input_ids, embedding_table, st);\n        sync_check_cuda_error();\n\n        comm_.d_comm->AllGather(local.raw_data(), temp.raw_data(), local.size(), dtype_, comm_.d_tp_group, st);\n        sync_check_cuda_error();\n\n        invokeInPlaceTranspose102((uint16_t*)input_embeds.raw_data(),\n                                  (uint16_t*)temp.raw_data(),\n                                  tp_size_,\n                                  token_num,\n                                  local_hidden_units,\n                                  false,\n                                  st);\n        sync_check_cuda_error();\n    }\n\n    return input_embeds;\n}\n\nTensor LanguageModel::Impl::PostEmbedding(const 
Tensor& features, Buffer symm_buf)\n{\n    NvtxScope scope(\"postDecodeEmbedding\");\n\n    const auto st = core::Context::stream().handle();\n\n    const int bsz              = features.shape(0);\n    const int local_vocab_size = weights_.post_decoder_embedding.output_dim;\n    const int vocab_size       = local_vocab_size * tp_size_;\n\n    if (bsz == 0) {\n        return Tensor{{0, vocab_size}, dtype_, kDEVICE};\n    }\n\n    if (tp_size_ == 1) {\n        Tensor logits{{bsz, vocab_size}, dtype_, kDEVICE};\n        linear_.Forward(features, weights_.post_decoder_embedding, logits);\n        sync_check_cuda_error();\n        TM_DEBUG_TENSOR(logits, \"logits\", 1);\n        return logits;\n    }\n    else if (use_ag2d_) {\n        Tensor logits{symm_buf.view(dtype_), {bsz, tp_size_, local_vocab_size}};\n        Tensor local = logits.slice({0, tp_rank_, 0}, {-1, 1, -1});\n        linear_.Forward(features, weights_.post_decoder_embedding, local.squeeze(1));\n        sync_check_cuda_error();\n        comm_.d_comm->AllGather2D(local.raw_data(),\n                                  logits.raw_data(),\n                                  vocab_size,\n                                  local_vocab_size,\n                                  local_vocab_size,\n                                  bsz,\n                                  logits.dtype(),\n                                  {true, true},\n                                  comm_.d_tp_group,\n                                  st);\n        sync_check_cuda_error();\n        return logits.view({bsz, -1});\n    }\n    else {\n        Tensor logits{symm_buf.view(dtype_), {tp_size_, bsz, local_vocab_size}};\n        Tensor local = logits.slice({tp_rank_, 0, 0}, {1, -1, -1});\n        linear_.Forward(features, weights_.post_decoder_embedding, local.squeeze(0));\n        sync_check_cuda_error();\n        comm_.d_comm->AllGather(local.raw_data(), logits.raw_data(), local.size(), local.dtype(), comm_.d_tp_group, st);\n        
sync_check_cuda_error();\n        Tensor out{{bsz, vocab_size}, features.dtype(), features.device()};\n        invokeTransposeAxis01(\n            (uint16_t*)out.raw_data(), (uint16_t*)logits.raw_data(), tp_size_, bsz, local_vocab_size, st);\n        sync_check_cuda_error();\n        return out;\n    }\n}\n\nvoid LanguageModel::Impl::Setup(int phase, TensorMap& env)\n{\n    input_processor_->Run(BatchOp::kSetup, phase, env);\n\n    auto& d    = data_.at(phase);\n    auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    const auto& rc = env.at(\"batch\").data<BatchData*>()[0]->rc;\n\n    d.n_generating = 0;\n\n    for (int i = 0; i < rc.size(); ++i) {\n        auto& c         = *rc[i];\n        d.autoregres[i] = c.autoregres;\n        d.generating[i] = c.generating;\n        d.n_generating += c.generating;\n        if (TM_UNLIKELY(!c.autoregres)) {\n            sequence_length_buf_[i] = c.history_len + c.alpha + c.input_len;\n        }\n    }\n\n    copy(sequence_length_buf_, rc.size(), d.sequence_length);\n\n    unified_decoder_->Run(BatchOp::kSetup, phase, env);\n    generation_->Run(BatchOp::kSetup, phase, env);\n    output_processor_->Run(BatchOp::kSetup, phase, env);\n}\n\nvoid LanguageModel::Impl::Prepare(int phase, TensorMap& env)\n{\n    env.emplace(\"autoreg_ids\", autoreg_ids_);\n\n    input_processor_->Run(BatchOp::kPrepare, phase, env);\n\n    auto& d = data_.at(phase);\n\n    auto& b    = *env.at(\"batch\").data<BatchData*>()[0];\n    auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    // core::CopyT copy{};\n\n    if (auto group = copy.group()) {\n        for (int i = 0; i < b.bsz; ++i) {\n            if (const int j = b.perm[i]; j < b.bs0) {\n                copy(finished_.front().data<bool>() + j, 1, finished_.back().data<bool>() + i);\n            }\n            else {\n                copy(false_.data() + i, 1, finished_.back().data<bool>() + i);\n            }\n        }\n        finished_.Swap();\n    }\n\n    if (auto group = 
copy.group()) {\n        // sequence_length = history_len + input_len\n        for (int i = 0; i < b.bsz; ++i) {\n            if (const int j = b.perm[i]; j < b.bs0 && d.autoregres[i]) {\n                copy(sequence_length_.front().data<int>() + j, 1, sequence_length_.back().data<int>() + i);\n            }\n            else {\n                copy(d.sequence_length.data() + i, 1, sequence_length_.back().data<int>() + i);\n            }\n        }\n        sequence_length_.Swap();\n    }\n\n    Buffer_<int> k_offsets{b.bsz + 1, kDEVICE};\n    // PrefixSum(sequence_length_.front().data<int>(), bsz, k_offsets.data(), core::Context::stream().handle());\n\n    // Buffer_<int> k_offsets_tmp{k_offsets.size(), kCPU};\n    // Buffer_<int> sequence_length_tmp{sequence_length_.front().size(), kCPU};\n\n    // Copy(k_offsets, k_offsets_tmp);\n    // Copy(sequence_length_.front().buffer(), sequence_length_tmp);\n\n    // core::Context::stream().Sync();\n\n    // dbg(core::to_vector<int>(sequence_length_tmp.slice(0, bsz)));\n    // dbg(core::to_vector<int>(k_offsets_tmp.slice(0, bsz + 1)));\n\n    env.produce(\"finished\", finished_.front());\n    env.produce(\"sequence_length\", sequence_length_.front());\n    env.produce(\"k_offsets\", k_offsets);\n\n    unified_decoder_->Run(BatchOp::kPrepare, phase, env);\n    generation_->Run(BatchOp::kPrepare, phase, env);\n    output_processor_->Run(BatchOp::kPrepare, phase, env);\n}\n\nvoid LanguageModel::Impl::Forward(int phase, TensorMap& env)\n{\n\n    auto& d = data_.at(phase);\n    auto& b = *env.at(\"batch\").data<BatchData*>()[0];\n\n    {\n        Buffer_<int> k_offsets = env.at(\"k_offsets\").buffer();\n        PrefixSum(sequence_length_.front().data<int>(), b.bsz, k_offsets.data(), core::Context::stream().handle());\n    }\n\n    {  // compute input embeddings\n        auto input_ids = env.at(\"input_ids\").buffer();\n\n        Tensor input_embeds = LookupEmbedding(input_ids, symm_buf_);\n        
TM_DEBUG_TENSOR(input_embeds, \"embeddings\", 1);\n\n        auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n        input_processor_->PatchEmbedding(phase, input_embeds, copy);\n        copy.Run();\n\n        env.produce(\"input_embeds\", std::move(input_embeds));\n        // dbg(env);\n    }\n\n    if (symm_buf_) {\n        env.produce(\"symm_buf\", symm_buf_);\n    }\n\n    env.produce(\"output_norm_weight\", weights_.output_norm_weight);\n\n    unified_decoder_->Forward(phase, env, weights_.decoder_layer_weights);\n\n    // env.at(\"batch\").data<BatchData*>()[0]->Notify();\n\n    output_processor_->OutputHiddenStatesAndLogits(phase, env, 2);\n\n    auto& hidden_states = env.at(\"hidden_states\");\n\n    env.produce(\"logits\", PostEmbedding(hidden_states, symm_buf_));\n\n    output_processor_->OutputHiddenStatesAndLogits(phase, env, 1);\n\n    if (d.n_generating) {\n        generation_->Run(BatchOp::kForward, phase, env);\n        Copy(env.at(\"output_ids\").buffer(), autoreg_ids_);\n    }\n}\n\nvoid LanguageModel::Impl::Unprep(int phase, TensorMap& env)\n{\n    auto& d    = data_.at(phase);\n    auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    copy(sequence_length_.front().buffer(), d.sequence_length.size(), d.sequence_length);\n\n    copy(finished_.front().buffer(), d.finished.size(), d.finished);\n\n    generation_->Run(BatchOp::kUnprep, phase, env);\n}\n\nvoid LanguageModel::Impl::Fetch(int phase, TensorMap& env)\n{\n    auto& d    = data_.at(phase);\n    auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    copy(d.sequence_length, d.sequence_length.size(), sequence_length_buf_);\n    env.produce(\"sequence_length\", sequence_length_buf_);\n\n    copy(d.finished, d.finished.size(), finished_buf_);\n    env.produce(\"finished\", finished_buf_);\n\n    env.produce(\"generating\", d.generating);\n\n    generation_->Run(BatchOp::kFetch, phase, env);\n}\n\nLanguageModel::~LanguageModel() = 
default;\n\nLanguageModel::LanguageModel(LanguageModel&&) noexcept = default;\n\nLanguageModel::LanguageModel(DataType              dtype,\n                             const ModelParam&     model,\n                             const EngineParam&    engine,\n                             const AttentionParam& attn,\n                             const MoeParam&       moe,\n                             const Context&        ctx,\n                             const LlamaWeight&    weights,\n                             int                   phases)\n{\n    impl_ = std::make_unique<Impl>(dtype, model, engine, attn, moe, ctx, weights, phases);\n}\n\nvoid LanguageModel::Run(BatchOp op, int phase, TensorMap& env)\n{\n    return TM_CHECK_NOTNULL(impl_)->Run(op, phase, env);\n}\n\nconst ModelParam& LanguageModel::model_param() const noexcept\n{\n    return TM_CHECK_NOTNULL(impl_)->param_;\n}\n\nconst AttentionParam& LanguageModel::attn_param() const noexcept\n{\n    return TM_CHECK_NOTNULL(impl_)->attn_param_;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/language_model.h",
    "content": "#pragma once\n\n#include <memory>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/models/llama/context.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nclass LlamaWeight;\n\nclass LanguageModel {\npublic:\n    ~LanguageModel();\n\n    LanguageModel() = default;\n\n    LanguageModel(LanguageModel&&) noexcept;\n\n    explicit operator bool() const noexcept\n    {\n        return static_cast<bool>(impl_);\n    }\n\n    LanguageModel(DataType              dtype,\n                  const ModelParam&     model,\n                  const EngineParam&    engine,\n                  const AttentionParam& attn,\n                  const MoeParam&       moe,\n                  const Context&        ctx,\n                  const LlamaWeight&    weights,\n                  int                   phases);\n\n    void Run(BatchOp op, int phase, TensorMap& env);\n\n    const ModelParam&     model_param() const noexcept;\n    const AttentionParam& attn_param() const noexcept;\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/Barrier.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n#ifndef _MSC_VER\n#include <pthread.h>\n#endif\n\nnamespace turbomind {\n\n#ifdef _MSC_VER\n\nclass Barrier {\npublic:\n    Barrier(unsigned count)\n    {\n        TM_LOG_INFO(\"Barrier(%d)\", (int)count);\n        FT_CHECK(count == 1);\n    }\n\n    Barrier(const Barrier&) = delete;\n    Barrier& operator=(const Barrier&) = delete;\n    Barrier(Barrier&&) noexcept        = delete;\n    Barrier& operator=(Barrier&&) noexcept = delete;\n\n    void wait() {}\n\n    ~Barrier() {}\n};\n\n#else\n\nclass Barrier {\npublic:\n    Barrier(unsigned count): count_(count)\n    {\n        if (count_ > 1) {\n            pthread_barrier_init(&barrier_, nullptr, count);\n        }\n    }\n\n    Barrier(const Barrier&) = delete;\n    Barrier& operator=(const Barrier&) = delete;\n    Barrier(Barrier&&) noexcept        = delete;\n    Barrier& operator=(Barrier&&) noexcept = delete;\n\n    void wait()\n    {\n        if (count_ > 1) {\n            pthread_barrier_wait(&barrier_);\n        }\n    }\n\n    ~Barrier()\n    {\n        if (count_ > 1) {\n            pthread_barrier_destroy(&barrier_);\n        }\n    }\n\nprivate:\n    const int         count_;\n    pthread_barrier_t barrier_{};\n};\n\n#endif\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/BlockManager.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n\n#include \"src/turbomind/models/llama/BlockManager.h\"\n#include \"src/turbomind/utils/debug_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n#include \"src/turbomind/utils/string_utils.h\"\n\nnamespace turbomind {\n\nBlockManager::BlockManager(\n    size_t block_size, double block_count, int chunk_size, core::Allocator allocator, GetFreeMemSize get_free_size):\n    block_size_(block_size), allocator_(allocator)\n{\n    if (block_count < 1.) {\n        max_block_count_ = GetBlockCount(block_size, block_count, get_free_size);\n    }\n    else {\n        max_block_count_ = block_count;\n    }\n\n    if (chunk_size == 0) {\n        chunk_size_ = static_cast<int>(std::sqrt(max_block_count_));\n    }\n    else if (chunk_size < 0) {\n        chunk_size_ = max_block_count_;\n    }\n    else {\n        chunk_size_ = chunk_size;\n    }\n\n    TM_LOG_INFO(\"[BlockManager] block_size = %.3f MB\", (float)block_size_ / (1 << 20));\n    TM_LOG_INFO(\"[BlockManager] max_block_count = %d\", max_block_count_);\n    TM_LOG_INFO(\"[BlockManager] chunk_size = %d\", chunk_size_);\n\n    blocks_.reserve(max_block_count_);\n\n    active_ids_.reserve(max_block_count_);\n    cached_ids_.reserve(max_block_count_);\n    free_ids_.reserve(max_block_count_);\n\n    // pre-allocate first chunk\n    Malloc();\n    dbg(free_ids_);\n}\n\nBlockManager::~BlockManager()\n{\n    for (auto& chunk : chunks_) {\n        allocator_->deallocate(chunk, block_size_);\n    }\n}\n\nbool BlockManager::Malloc()\n{\n    auto chunk_size = std::min<int>(chunk_size_, max_block_count_ - blocks_.size());\n\n    if (!chunk_size) {\n        return false;\n    }\n\n    auto ptr = (std::byte*)allocator_->allocate(block_size_ * chunk_size);\n    if (!ptr) {\n        return false;\n    }\n\n    chunks_.push_back(ptr);\n\n    for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {\n        auto& block     = 
blocks_.emplace_back();\n        block.use_count = 0;\n        block.id        = (int)blocks_.size() - 1;\n        block.timestamp = 0;\n        block.data      = ptr;\n\n        free_ids_.push_back(block.id);\n    }\n\n    return true;\n}\n\nsize_t BlockManager::GetBlockCount(size_t block_size, double ratio, GetFreeMemSize get_free_size)\n{\n    size_t free = get_free_size();\n    return static_cast<size_t>(free * ratio) / block_size;\n}\n\nvoid BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)\n{\n    TM_CHECK_GE(src.size(), delta.size());\n    std::vector<int> src1(src.size() - delta.size());\n    {\n        auto end = std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());\n        TM_CHECK(end == src1.end());\n    }\n    src.swap(src1);\n\n    std::vector<int> dst1(dst.size() + delta.size());\n    {\n        auto end = std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());\n        TM_CHECK(end == dst1.end());\n    }\n    dst.swap(dst1);\n}\n\nauto BlockManager::Allocate(int count) -> std::pair<BlockIds, UniqueIds>\n{\n    while (free_ids_.size() < count) {\n        if (!Malloc()) {\n            throw std::runtime_error(\"out of memory\");\n        }\n    }\n\n    BlockIds  block_ids(count);\n    UniqueIds unique_ids(count);\n\n    for (int i = 0; i < count; ++i) {\n        int   idx = free_ids_[i];\n        auto& b   = blocks_[idx];\n        TM_CHECK(is_free(b));  // pre-condition: uc == 0 && ts == 0\n        b.use_count = 1;\n        b.unique_id = unique_id_++;\n        b.timestamp = timestamp_++;\n        TM_CHECK(is_active(b));  // post-condition\n        block_ids[i]  = idx;\n        unique_ids[i] = b.unique_id;\n    }\n\n    Move(free_ids_, block_ids, active_ids_);\n\n    dbg(free_ids_, active_ids_);\n\n    return {block_ids, unique_ids};\n}\n\nvoid BlockManager::Evict(int count)\n{\n    TM_CHECK_LE(count, cached_ids_.size());\n    std::vector<int> 
idxs(cached_ids_);\n    // get first `count` cached ids according to timestamp\n    std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {\n        return blocks_[i].timestamp < blocks_[j].timestamp;\n    });\n    idxs.resize(count);\n\n    // sort the retrieved ids\n    std::sort(idxs.begin(), idxs.end());\n\n    // set as free\n    for (const auto& idx : idxs) {\n        auto& b = blocks_[idx];\n        TM_CHECK(is_cached(b));  // pre-condition\n        b.unique_id = 0;\n        b.timestamp = 0;\n        TM_CHECK(is_free(b));  // post-condition\n    }\n\n    Move(cached_ids_, idxs, free_ids_);\n\n    dbg(cached_ids_, free_ids_);\n}\n\nvoid BlockManager::Free(BlockIds ids)\n{\n    std::sort(ids.begin(), ids.end());\n\n    for (const auto& i : ids) {\n        auto& b = blocks_[i];\n        TM_CHECK(is_cached(b));  // pre-condition\n        b.unique_id = 0;\n        b.timestamp = 0;\n        TM_CHECK(is_free(b));  // post-condition\n    }\n\n    Move(cached_ids_, ids, free_ids_);\n}\n\nint BlockManager::Unlock(const BlockIds& ids)\n{\n    BlockIds unlock;\n    unlock.reserve(ids.size());\n\n    for (const auto& i : ids) {\n        auto& b = blocks_[i];\n        TM_CHECK(is_active(b));  // pre-condition\n        if (--b.use_count == 0) {\n            unlock.push_back(b.id);\n            TM_CHECK(is_cached(b));  // post-condition\n        }\n    }\n\n    std::sort(unlock.begin(), unlock.end());\n\n    Move(active_ids_, unlock, cached_ids_);\n\n    dbg(active_ids_, cached_ids_);\n    return unlock.size();\n}\n\nint BlockManager::Lock(const BlockIds& ids)\n{\n    BlockIds lock;\n    lock.reserve(ids.size());\n\n    for (const auto& i : ids) {\n        auto& b = blocks_[i];\n        if (++b.use_count == 1) {\n            lock.push_back(i);\n            TM_CHECK(is_active(b));  // post-condition\n        }\n    }\n\n    std::sort(lock.begin(), lock.end());\n\n    Move(cached_ids_, lock, active_ids_);\n\n    // dbg(cached_ids_, 
active_ids_);\n\n    return lock.size();\n}\n\nvoid BlockManager::Touch(const BlockIds& ids)\n{\n    std::for_each(ids.crbegin(), ids.crend(), [this](int i) {\n        TM_CHECK(is_active(blocks_[i]));\n        blocks_[i].timestamp = timestamp_++;\n    });\n}\n\nint BlockManager::Verify(const std::vector<int>& block_ids, const std::vector<uint64_t>& unique_ids)\n{\n    TM_CHECK_EQ(block_ids.size(), unique_ids.size());\n    int valid = block_ids.size();\n    for (int i = 0; i < block_ids.size(); ++i) {\n        if (unique_id(block_ids[i]) != unique_ids[i]) {\n            valid = i;\n            break;\n        }\n    }\n    int miss = 0;\n    for (int i = valid; i < block_ids.size(); ++i) {\n        miss += (unique_id(block_ids[i]) != unique_ids[i]);\n    }\n    // All later blocks should have been invalidated\n    TM_CHECK_EQ(miss, (int)block_ids.size() - valid)\n        << fmtstr(\"count = %d, valid = %d, miss = %d\", (int)block_ids.size(), valid, miss);\n    return valid;\n}\n\nSnapshot BlockManager::TakeSnapshot()\n{\n    std::vector<int> use_count(blocks_.size());\n    for (const auto& idx : active_ids_) {\n        use_count[idx] = blocks_[idx].use_count;\n    }\n    return {active_count(), cached_count(), free_count(), std::move(use_count)};\n}\n\nstd::ostream& operator<<(std::ostream& os, const BlockManager& manager)\n{\n    os << \"block_size: \" << manager.block_size_ << \", \";\n    os << \"max_block_count: \" << manager.max_block_count_ << \", \";\n    os << \"chunk_size: \" << manager.chunk_size_ << \", \";\n    os << \"chunks: \" << manager.chunks_.size() << \", \";\n    os << \"active_ids: \" << manager.active_ids_.size() << \", \";\n    os << \"cached_ids: \" << manager.cached_ids_.size() << \", \";\n    os << \"free_ids: \" << manager.free_ids_.size() << \", \";\n    os << \"blocks: \" << manager.blocks_.size() << \", \";\n    os << \"unique_id: \" << manager.unique_id_ << \", \";\n    os << \"timestamp: \" << manager.timestamp_;\n    return 
os;\n}\n\nstd::ostream& operator<<(std::ostream& os, const Block& block)\n{\n    os << \"id=\" << block.id << \", use_count=\" << block.use_count << \", unique_id=\" << block.unique_id\n       << \", timestamp=\" << block.timestamp << \", data=\" << block.data;\n    return os;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/BlockManager.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/models/llama/Barrier.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n#include <algorithm>\n#include <atomic>\n#include <cstdint>\n#include <cuda_runtime.h>\n#include <functional>\n#include <iterator>\n#include <numeric>\n#include <queue>\n#include <sstream>\n#include <unordered_map>\n#include <vector>\n\nnamespace turbomind {\n\n// [L, H, S, D]\n\n// [L, S/x, H, x, D]\n\nstruct Block {\n    int      id;         // fixed linear id in the pool\n    int      use_count;  // active sequences using the block\n    uint64_t unique_id;  // unique for every block allocation\n    uint64_t timestamp;\n    void*    data;\n\n    friend std::ostream& operator<<(std::ostream& os, const Block& block);\n    friend std::string   to_string(const Block& b)\n    {\n        std::stringstream ss;\n        ss << b;\n        return ss.str();\n    }\n};\n\nusing BlockIds  = std::vector<int>;\nusing UniqueIds = std::vector<uint64_t>;\n\ninline bool is_active(const Block& block)\n{\n    // timestamp may be 0 for newly allocated block that has not been written\n    return block.use_count > 0;\n}\n\ninline bool is_cached(const Block& block)\n{\n    return block.use_count == 0 && block.timestamp != 0;\n}\n\ninline bool is_free(const Block& block)\n{\n    return block.use_count == 0 && block.timestamp == 0;\n}\n\nstruct Snapshot {\n    int              active;\n    int              cached;\n    int              free;\n    std::vector<int> use_count;\n};\n\nusing GetFreeMemSize = std::function<size_t()>;\n\nclass BlockManager {\npublic:\n    explicit BlockManager(\n        size_t block_size, double block_count, int chunk_size, core::Allocator allocator, GetFreeMemSize get_free_size);\n\n    ~BlockManager();\n\n    // free -> active (use_count = 1, ref_count = 1)\n    [[nodiscard]] std::pair<BlockIds, 
UniqueIds> Allocate(int count);\n\n    // cached -> active (use_count += 1)\n    [[maybe_unused]] int Lock(const BlockIds& ids);\n\n    // active -> cached (use_count -= 1)\n    [[maybe_unused]] int Unlock(const BlockIds& ids);\n\n    // cached -> free (ref_count = 0)\n    void Evict(int count);\n\n    // cached -> free (ref_count -= 1)\n    void Free(BlockIds bs);\n\n    // increase timestamp in reversed order\n    void Touch(const BlockIds& bs);\n\n    [[nodiscard]] int Verify(const BlockIds& block_ids, const UniqueIds& unique_ids);\n\n    Snapshot TakeSnapshot();\n\n    int max_block_count() const noexcept\n    {\n        return max_block_count_;\n    }\n\n    int total_count() const noexcept\n    {\n        return blocks_.size();\n    }\n\n    int active_count() const noexcept\n    {\n        return active_ids_.size();\n    }\n\n    int cached_count() const noexcept\n    {\n        return cached_ids_.size();\n    }\n\n    int free_count() const noexcept\n    {\n        return free_ids_.size();\n    }\n\n    Block& block(int idx)\n    {\n        return blocks_[idx];\n    }\n\n    // full 64-bit id; must not be truncated, since callers compare it against UniqueIds entries\n    uint64_t unique_id(int idx)\n    {\n        return blocks_[idx].unique_id;\n    }\n\n    friend std::ostream& operator<<(std::ostream& os, const BlockManager&);\n\nprivate:\n    static size_t GetBlockCount(size_t block_size, double ratio, GetFreeMemSize get_free_size);\n\n    // move indices between sets\n    static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst);\n\n    // allocate a chunk of blocks\n    bool Malloc();\n\nprivate:\n    size_t block_size_;\n    int    max_block_count_{};\n    int    chunk_size_{};\n\n    core::Allocator allocator_;\n\n    std::vector<void*> chunks_;\n\n    BlockIds active_ids_;\n    BlockIds cached_ids_;\n    BlockIds free_ids_;\n\n    std::vector<Block> blocks_;  // < 100k\n\n    uint64_t unique_id_{1};\n    uint64_t timestamp_{1};\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/BlockTrie.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/models/llama/BlockTrie.h\"\n#include \"src/turbomind/models/llama/SequenceManager.h\"\n\nnamespace turbomind {\n\nsize_t hash(const std::vector<int>& vec)\n{\n    size_t seed = vec.size();\n    for (const auto& i : vec) {\n        seed ^= std::hash<int>{}(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);\n    }\n    return seed;\n}\n\nBlockTrie::BlockTrie(size_t block_len, std::shared_ptr<BlockManager> block_manager):\n    block_seq_len_(block_len), block_manager_(block_manager)\n{\n    root_ = std::make_shared<TrieNode>();\n}\n\nstd::tuple<BlockIds, UniqueIds> BlockTrie::Match(const Sequence& seq)\n{\n    BlockIds  block_ids;\n    UniqueIds unique_ids;\n\n    auto node  = root_;\n    auto first = seq.prompt.begin();\n\n    // Warning: Do not use \"<=\" operator even when seq.prompt length is evenly\n    // divisible by block_seq_len_. The model needs at least one input token to generate output.\n    while (first + block_seq_len_ < seq.prompt.end()) {\n        const std::vector<int> segment{first, first + block_seq_len_};\n        const size_t           hash_key = hash(segment);\n        if (const auto it = node->children.find(hash_key); it != node->children.end()) {\n            if (segment == it->second->tokens) {\n                block_ids.push_back(it->second->block_id);\n                unique_ids.push_back(it->second->block_unique_id);\n                node = it->second;\n                first += block_seq_len_;\n            }\n            else {\n                TM_LOG_WARNING(\"hash collision detected\");\n                break;\n            }\n        }\n        else {\n            break;\n        }\n    }\n\n    return std::make_tuple(block_ids, unique_ids);\n}\n\nstd::tuple<BlockIds, UniqueIds> BlockTrie::Cache(const Sequence& seq, const std::vector<int>& tokens)\n{\n    // Ensure the seq is active or locked so that all cache blocks must be valid\n    
TM_CHECK_NE(seq.status, Sequence::kCached);\n    TM_CHECK_LE(seq.cache_len, seq.blocks.size() * block_seq_len_);\n\n    auto node = root_;\n\n    BlockIds  cache_block_ids;\n    UniqueIds cache_block_unique_ids;\n\n    const int n_blocks = std::min(seq.cache_len, (int)tokens.size()) / block_seq_len_;\n\n    int new_cached = 0;\n\n    for (int idx = 0; idx < n_blocks; ++idx) {\n        auto start = tokens.begin() + idx * block_seq_len_;\n        auto end   = start + block_seq_len_;\n\n        const std::vector<int> segment(start, end);\n        const size_t           hash_key = hash(segment);  // TODO(lvhan): add salt to ensure the hash security\n\n        int      block_id        = seq.blocks[idx];\n        uint64_t block_unique_id = seq.block_unique_ids[idx];\n\n        if (auto it = node->children.find(hash_key); it != node->children.end()) {\n            if (segment == it->second->tokens) {  // fast-forward\n                node                  = it->second;\n                node->block_id        = block_id;\n                node->block_unique_id = block_unique_id;\n            }\n            else {\n                TM_LOG_WARNING(\"[BlockTrie][cache] Hash collision detected\");\n                break;\n            }\n        }\n        else {\n            // insert new node\n            node                  = node->children.emplace_hint(it, hash_key, std::make_shared<TrieNode>())->second;\n            node->hash_key        = hash_key;\n            node->tokens          = segment;\n            node->block_id        = block_id;\n            node->block_unique_id = block_unique_id;\n            new_cached += block_seq_len_;\n        }\n        cache_block_ids.emplace_back(block_id);\n        cache_block_unique_ids.emplace_back(block_unique_id);\n    }\n\n    TM_LOG_INFO(\"[BlockTrie][cache] %d new tokens cached\", new_cached);\n\n    return std::make_tuple(cache_block_ids, cache_block_unique_ids);\n}\n\nvoid BlockTrie::Verify()\n{\n    DFS(root_);\n}\n\nvoid 
BlockTrie::DFS(std::shared_ptr<TrieNode>& node)\n{\n    for (auto it = node->children.begin(); it != node->children.end();) {\n        if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) {\n            // child invalid\n            it = node->children.erase(it);\n        }\n        else {\n            DFS(it->second);\n            it++;\n        }\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/BlockTrie.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/models/llama/BlockManager.h\"\n#include <memory>\n#include <unordered_map>\n#include <vector>\n\nnamespace turbomind {\n\nstruct Sequence;\n\nstruct TrieNode {\n    std::unordered_map<size_t, std::shared_ptr<TrieNode>> children;\n    size_t                                                hash_key;\n    std::vector<int>                                      tokens;\n    int                                                   block_id;\n    uint64_t                                              block_unique_id;\n    int                                                   num_matched;\n};\n\nclass BlockTrie {\npublic:\n    explicit BlockTrie(size_t block_len, std::shared_ptr<BlockManager> block_manager);\n\n    /**\n     * @brief Attempt to match cached key-value (KV) blocks for a given sequence.\n     *\n     * This function iterates the tokens of the sequence and attempts\n     * to match them with the cached KV blocks. If the max prefix match is found,\n     * it returns the IDs, unique IDs of the matched blocks.\n     *\n     * @param seq The sequence whose tokens are to be matched against the cached KV blocks.\n     * @return A tuple containing the following:\n     *         - BlockIds: A list of IDs of the matched blocks.\n     *         - UniqueIds: A list of unique IDs of the matched blocks.\n     *\n     * @note If no blocks are matched, all containers in the returned tuple will be empty.\n     */\n    std::tuple<BlockIds, UniqueIds> Match(const Sequence& seq);\n\n    /**\n     * @brief Cache the key-value (KV) blocks of a given sequence.\n     *\n     * This function caches the KV blocks of the specified sequence. 
Only valid blocks\n     * of a sequence whose status is NOT `Sequence::kCached` are eligible\n     * for caching.\n     *\n     * @param seq The sequence whose KV blocks are to be cached.\n     * @param tokens The token list corresponding to the KV blocks.\n     * @return A tuple containing the following:\n     *         - BlockIds: A list of IDs of the cached blocks.\n     *         - UniqueIds: A list of unique IDs of the cached blocks.\n     */\n    std::tuple<BlockIds, UniqueIds> Cache(const Sequence& seq, const std::vector<int>& tokens);\n\n    /**\n     * @brief Remove nodes whose blocks are no longer valid.\n     */\n    void Verify();\n\nprivate:\n    void DFS(std::shared_ptr<TrieNode>& node);\n\nprivate:\n    size_t block_seq_len_;\n\n    std::shared_ptr<BlockManager> block_manager_;\n\n    std::shared_ptr<TrieNode> root_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\ncmake_minimum_required(VERSION 3.11)\n\n\nfind_package(CUDAToolkit REQUIRED)\n\nadd_library(Llama STATIC\n        LlamaV2.cc\n        LlamaBatch.cc\n        LlamaLinear.cu\n        BlockManager.cc\n        BlockTrie.cc\n        SequenceManager.cc\n        LlamaWeight.cc\n        LlamaDenseWeight.cc\n        LlamaDecoderLayerWeight.cc\n        LlamaFfnLayer.cc\n        moe_ffn_layer.cc\n        unified_decoder.cc\n        unified_attention_layer.cc\n        llama_kernels.cu\n        llama_utils.cu\n        mla_utils.cu\n        GatedDeltaNetWeight.cc\n        GatedDeltaNetLayer.cc\n        gated_delta_net_kernels.cu\n)\nset_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\ntarget_link_libraries(Llama PUBLIC CUDA::cudart\n        engine\n        core\n        gemm2\n        CUDA::cublas\n        nvidia::cutlass::cutlass\n        rms_norm\n        DynamicDecodeLayer\n        activation_kernels\n        activation\n        attention\n        decoding_kernels\n        quantization_kernels\n        unfused_attention_kernels\n        gpt_kernels\n        memory_utils\n        cuda_utils\n        logger\n        anomaly_handler)\n"
  },
  {
    "path": "src/turbomind/models/llama/GatedDeltaNetLayer.cc",
    "content": "#include \"src/turbomind/models/llama/GatedDeltaNetLayer.h\"\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/models/llama/SequenceManager.h\"\n#include \"src/turbomind/models/llama/gated_delta_net_kernels.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind {\n\nGatedDeltaNetLayer::GatedDeltaNetLayer(const ModelParam&     model,\n                                       const AttentionParam& attn,\n                                       const EngineParam&    engine,\n                                       int                   tp_size,\n                                       const Context&        ctx,\n                                       int                   phases):\n    hidden_units_(model.hidden_units),\n    num_k_heads_(model.linear_num_key_heads / tp_size),\n    num_v_heads_(model.linear_num_value_heads / tp_size),\n    key_head_dim_(model.linear_key_head_dim > 0 ? model.linear_key_head_dim : model.head_dim),\n    value_head_dim_(model.linear_value_head_dim > 0 ? model.linear_value_head_dim : model.head_dim),\n    d_conv_(model.linear_conv_kernel_dim > 0 ? 
model.linear_conv_kernel_dim : 4),\n    key_dim_(num_k_heads_ * key_head_dim_),\n    value_dim_(num_v_heads_ * value_head_dim_),\n    conv_dim_(key_dim_ * 2 + value_dim_),\n    norm_eps_(model.norm_eps),\n    dtype_(model.data_type),\n    state_dtype_(model.linear_state_dtype),\n    linear_(*ctx.linear)\n{\n    layer_types_       = model.layer_types;\n    num_linear_layers_ = 0;\n    for (auto t : layer_types_) {\n        if (t == 1)\n            ++num_linear_layers_;\n    }\n\n    TM_LOG_INFO(\"GatedDeltaNetLayer: num_k=%d num_v=%d k_dim=%d v_dim=%d \"\n                \"conv_dim=%d d_conv=%d num_linear_layers=%d\",\n                num_k_heads_,\n                num_v_heads_,\n                key_dim_,\n                value_dim_,\n                conv_dim_,\n                d_conv_,\n                num_linear_layers_);\n\n    if (num_linear_layers_ > 0) {\n        conv_state_ptrs_buf_      = {engine.max_batch_size, kCPUpinned};\n        recurrent_state_ptrs_buf_ = {engine.max_batch_size, kCPUpinned};\n    }\n\n    for (int i = 0; i < phases; ++i) {\n        data_.emplace_back();\n        if (num_linear_layers_ > 0) {\n            data_.at(i).conv_state_ptrs      = empty_like(conv_state_ptrs_buf_, kDEVICE);\n            data_.at(i).recurrent_state_ptrs = empty_like(recurrent_state_ptrs_buf_, kDEVICE);\n        }\n    }\n\n    int device = 0;\n    cudaGetDevice(&device);\n    cudaDeviceGetAttribute(&sm_count_, cudaDevAttrMultiProcessorCount, device);\n    work_counter_ = {1, kDEVICE};\n\n    check_cuda_error(cudaStreamCreateWithPriority(&aux_stream_, cudaStreamNonBlocking, -1));\n    check_cuda_error(cudaEventCreateWithFlags(&ev_before_, cudaEventDisableTiming));\n    check_cuda_error(cudaEventCreateWithFlags(&ev_after_, cudaEventDisableTiming));\n}\n\nGatedDeltaNetLayer::~GatedDeltaNetLayer()\n{\n    cudaStreamDestroy(aux_stream_);\n    cudaEventDestroy(ev_before_);\n    cudaEventDestroy(ev_after_);\n}\n\nvoid GatedDeltaNetLayer::Run(BatchOp op, int phase, 
TensorMap& env)\n{\n    if (op == BatchOp::kAdd) {\n        Buffer_<RequestCache*> rc    = env.at(\"requests\").buffer();\n        const auto             dtype = dtype_;\n        for (int i = 0; i < rc.size(); ++i) {}\n    }\n    else if (op == BatchOp::kSetup) {\n        Setup(phase, env);\n    }\n    else if (op == BatchOp::kPrepare) {\n        auto& d     = data_.at(phase);\n        d.q_offsets = env.at(\"q_offsets\").buffer().borrow();\n        d.k_offsets = env.at(\"k_offsets\").buffer().borrow();\n    }\n}\n\nvoid GatedDeltaNetLayer::Setup(int phase, TensorMap& env)\n{\n    auto&       d = data_.at(phase);\n    const auto& b = *env.at(\"batch\").data<BatchData*>()[0];\n\n    d.batch_size = b.rc.size();\n    d.rc.resize(d.batch_size);\n    d.input_lens.resize(d.batch_size);\n\n    d.conv_states.resize(d.batch_size);\n    d.recurrent_states.resize(d.batch_size);\n\n    for (int i = 0; i < d.batch_size; ++i) {\n        d.rc[i]         = b.rc[i].get();\n        d.input_lens[i] = b.rc[i]->input_len;\n\n        auto& s = *b.rc[i]->seq;\n        TM_CHECK(s.conv_states && s.recurrent_states)\n            << \"Linear-attention state slot is not bound for sequence \" << s.id;\n        if (s.linear_states_need_reset) {\n            // Reset newly assigned pooled slot state on first use. 
Keep GPU-side\n            // state initialization out of SequenceManager.\n            Clear(s.conv_states);\n            Clear(s.recurrent_states);\n            s.linear_states_need_reset = false;\n        }\n\n        // Linear-attention requests are restricted to stateless execution, so\n        // the sequence-owned states can be passed directly here.\n        d.conv_states[i]      = s.conv_states;\n        d.recurrent_states[i] = s.recurrent_states;\n\n        conv_state_ptrs_buf_[i]      = d.conv_states[i].raw_data();\n        recurrent_state_ptrs_buf_[i] = d.recurrent_states[i].raw_data();\n    }\n\n    Copy(conv_state_ptrs_buf_, d.batch_size, d.conv_state_ptrs);\n    Copy(recurrent_state_ptrs_buf_, d.batch_size, d.recurrent_state_ptrs);\n}\n\nstatic int linear_layer_index(int layer_id, const std::vector<int>& layer_types)\n{\n    int idx = 0;\n    for (int i = 0; i < layer_id && i < (int)layer_types.size(); ++i) {\n        if (layer_types[i] == 1)\n            ++idx;\n    }\n    return idx;\n}\n\nvoid GatedDeltaNetLayer::Forward(ForwardParam p)\n{\n    TM_LOG_DEBUG(__PRETTY_FUNCTION__);\n\n    const int token_num = p.input.shape(0);\n    if (token_num == 0)\n        return;\n\n    const auto  dtype   = p.input.dtype();\n    const auto  device  = p.input.device();\n    const auto  stream  = core::Context::stream().handle();\n    const auto& weights = *p.weights;\n\n    auto& pd = data_.at(p.phase);\n\n    auto dispatch = [&](auto t) {\n        using T = decltype(t);\n\n        // =================================================================\n        // 1. 
Single fused input projection: reads p.input once from HBM.\n        //    Output columns are ordered: [qkv | z | b | a]\n        //    where the split dims are: conv_dim_, value_dim_, v_heads_tp, v_heads_tp\n        // =================================================================\n        const int v_heads_tp = num_v_heads_;  // already TP-sharded\n        Tensor    all_proj   = linear_.Forward(p.input, weights.in_proj_all);\n        sync_check_cuda_error();\n\n        // Column offsets per token (all_proj is token-major, row-major):\n        //   [0, conv_dim_)           -> mixed_qkv\n        //   [conv_dim_, +value_dim_) -> z\n        //   [conv_dim_+value_dim_, +v_heads_tp) -> b (beta logit)\n        //   [conv_dim_+value_dim_+v_heads_tp, +v_heads_tp) -> a (alpha/dt)\n        const int all_col = conv_dim_ + value_dim_ + v_heads_tp * 2;\n        // Every column slice taken from all_proj below is a strided view with row stride = all_col.\n\n        // =================================================================\n        // 2. Compute beta and g for all tokens\n        //    b and a are sliced from the fused projection output.\n        //    Stride between tokens is all_col elements.\n        // =================================================================\n        const int bg_total = token_num * num_v_heads_;\n\n        const int b_offset = conv_dim_ + value_dim_;  // column offset to b logits\n        const int a_offset = b_offset + v_heads_tp;   // column offset to a logits\n\n        Tensor beta{{token_num, num_v_heads_}, dtype, device};\n        Tensor g{{token_num, num_v_heads_}, dtype, device};\n\n        auto b = all_proj.slice({0, b_offset}, {-1, v_heads_tp});\n        auto a = all_proj.slice({0, a_offset}, {-1, v_heads_tp});\n\n        ComputeBetaG_v2(beta, g, b, a, weights.A_log, weights.dt_bias, stream);\n\n        // =================================================================\n        // 3. 
Process all requests at once via batched kernel launches\n        // =================================================================\n        Tensor attn_out{{token_num, value_dim_}, dtype, device};\n        Tensor conv_out{{token_num, conv_dim_}, dtype, device};\n\n        const int state_layer_idx              = linear_layer_index(p.layer_id, layer_types_);\n        const int conv_state_layer_offset      = state_layer_idx * (conv_dim_ * d_conv_);\n        const int recurrent_state_layer_offset = state_layer_idx * (num_v_heads_ * key_head_dim_ * value_head_dim_);\n\n        // ----- 3a. Fused Causal Conv1d + SiLU (all requests) -----\n        // all_proj carries the non-contiguous qkv slice (stride = all_col);\n        // in_stride is derived from all_proj.stride(0) inside the launcher.\n        invokeFusedConv1dSiLU(conv_out,\n                              all_proj,\n                              weights.conv1d,\n                              Tensor{},\n                              pd.conv_state_ptrs,\n                              pd.q_offsets,\n                              pd.k_offsets,\n                              pd.batch_size,\n                              conv_state_layer_offset,\n                              sm_count_,\n                              work_counter_.data(),\n                              stream);\n        sync_check_cuda_error();\n\n        // ----- 3b. 
Gated Delta Rule -----\n        // Requests are sorted by input_len: decode (seq_len==1) first, prefill last.\n        // Find the split point and dispatch each half to its optimal kernel.\n        // When both are present, run them concurrently on separate streams.\n        {\n            int decode_count = 0;\n            for (int i = 0; i < pd.batch_size; ++i) {\n                if (pd.input_lens[i] <= 1)\n                    ++decode_count;\n                else\n                    break;\n            }\n            const int prefill_count = pd.batch_size - decode_count;\n\n            if (decode_count > 0 && prefill_count > 0) {\n                // Fork: aux_stream (high priority) waits for prior work on main stream\n                check_cuda_error(cudaEventRecord(ev_before_, stream));\n                check_cuda_error(cudaStreamWaitEvent(aux_stream_, ev_before_));\n\n                // Decode on main stream\n                auto dc_state = pd.recurrent_state_ptrs.slice(0, decode_count);\n                auto dc_q     = pd.q_offsets.slice(0, decode_count + 1);\n                invokeGatedDeltaRuleBatched_v3(attn_out,\n                                               conv_out,\n                                               beta,\n                                               g,\n                                               dc_state,\n                                               dc_q,\n                                               decode_count,\n                                               num_k_heads_,\n                                               recurrent_state_layer_offset,\n                                               state_dtype_,\n                                               sm_count_,\n                                               work_counter_.data(),\n                                               stream);\n\n                // Prefill on aux stream (higher priority)\n                auto pf_state = 
pd.recurrent_state_ptrs.slice(decode_count, prefill_count);\n                auto pf_q     = pd.q_offsets.slice(decode_count, prefill_count + 1);\n                invokeChunkedGatedDeltaRuleBatched(attn_out,\n                                                   conv_out,\n                                                   beta,\n                                                   g,\n                                                   pf_state,\n                                                   pf_q,\n                                                   prefill_count,\n                                                   num_k_heads_,\n                                                   recurrent_state_layer_offset,\n                                                   state_dtype_,\n                                                   sm_count_,\n                                                   work_counter_.data(),\n                                                   aux_stream_);\n\n                // Join: main stream waits for prefill to finish\n                check_cuda_error(cudaEventRecord(ev_after_, aux_stream_));\n                check_cuda_error(cudaStreamWaitEvent(stream, ev_after_));\n            }\n            else if (decode_count > 0) {\n                auto state_slice = pd.recurrent_state_ptrs.slice(0, decode_count);\n                auto q_slice     = pd.q_offsets.slice(0, decode_count + 1);\n                invokeGatedDeltaRuleBatched_v3(attn_out,\n                                               conv_out,\n                                               beta,\n                                               g,\n                                               state_slice,\n                                               q_slice,\n                                               decode_count,\n                                               num_k_heads_,\n                                               recurrent_state_layer_offset,\n                                  
             state_dtype_,\n                                               sm_count_,\n                                               work_counter_.data(),\n                                               stream);\n            }\n            else if (prefill_count > 0) {\n                auto state_slice = pd.recurrent_state_ptrs.slice(decode_count, prefill_count);\n                auto q_slice     = pd.q_offsets.slice(decode_count, prefill_count + 1);\n                invokeChunkedGatedDeltaRuleBatched(attn_out,\n                                                   conv_out,\n                                                   beta,\n                                                   g,\n                                                   state_slice,\n                                                   q_slice,\n                                                   prefill_count,\n                                                   num_k_heads_,\n                                                   recurrent_state_layer_offset,\n                                                   state_dtype_,\n                                                   sm_count_,\n                                                   work_counter_.data(),\n                                                   stream);\n                // invokeChunkedGatedDeltaRuleBatched\n            }\n        }\n        sync_check_cuda_error();\n\n        // ----- 3c. RMSNormGated (all tokens at once) -----\n        // Gate (z) lives at column conv_dim_ of all_proj with row-stride all_col.\n        Tensor gate        = all_proj.slice({0, conv_dim_}, {-1, value_dim_});\n        Tensor hidden_view = attn_out.view({token_num * num_v_heads_, value_head_dim_});\n        invokeRMSNormGated(hidden_view, gate, weights.norm, norm_eps_, stream);\n        sync_check_cuda_error();\n\n        // =================================================================\n        // 4. 
Output projection (all tokens at once)\n        // =================================================================\n        (void)linear_.Forward(attn_out, weights.out_proj, p.output);\n        sync_check_cuda_error();\n    };\n\n    if (dtype == kHalf) {\n        dispatch(half{});\n    }\n    else if (dtype == kBfloat16) {\n        dispatch(nv_bfloat16{});\n    }\n    else {\n        TM_CHECK(0) << \"Unsupported dtype for GatedDeltaNetLayer\";\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/GatedDeltaNetLayer.h",
    "content": "#pragma once\n\n#include \"src/turbomind/core/tensor.h\"\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/models/llama/GatedDeltaNetWeight.h\"\n#include \"src/turbomind/models/llama/LlamaLinear.h\"\n#include \"src/turbomind/models/llama/context.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nclass GatedDeltaNetLayer {\npublic:\n    struct ForwardParam {\n        int                        phase;\n        Tensor                     input;\n        Tensor                     output;\n        const GatedDeltaNetWeight* weights;\n        int                        layer_id;\n    };\n\n    GatedDeltaNetLayer(const ModelParam&     model,\n                       const AttentionParam& attn,\n                       const EngineParam&    engine,\n                       int                   tp_size,\n                       const Context&        ctx,\n                       int                   phases);\n\n    ~GatedDeltaNetLayer();\n\n    void Run(BatchOp op, int phase, TensorMap& env);\n\n    void Forward(ForwardParam p);\n\nprivate:\n    void Setup(int phase, TensorMap& env);\n\n    // Model dimensions\n    int              hidden_units_;\n    int              num_k_heads_;\n    int              num_v_heads_;\n    int              key_head_dim_;\n    int              value_head_dim_;\n    int              d_conv_;\n    int              key_dim_;            // num_k_heads * key_head_dim\n    int              value_dim_;          // num_v_heads * value_head_dim\n    int              conv_dim_;           // key_dim * 2 + value_dim\n    int              num_linear_layers_;  // count of linear attention layers for state sizing\n    std::vector<int> layer_types_;        // model layer types for index mapping\n\n    float    norm_eps_;\n    DataType dtype_;\n    DataType state_dtype_;  // recurrent state dtype (may differ from dtype_ for float32 state)\n\n    LlamaLinear& linear_;\n\n    // Per-phase 
batch data (mirrors UnifiedAttentionLayer pattern)\n    struct Data {\n        std::vector<RequestCache*> rc;          // borrowed batch RequestCache pointers\n        std::vector<int>           input_lens;  // snapshot of input_len per request (captured at Setup time)\n        int                        batch_size = 0;\n        Buffer_<int>               q_offsets;  // cumulative input-token offsets, device buffer\n        Buffer_<int>               k_offsets;  // cumulative key (history+input) offsets, device buffer\n        std::vector<Tensor>        conv_states;\n        std::vector<Tensor>        recurrent_states;\n        Buffer_<void*>             conv_state_ptrs;\n        Buffer_<void*>             recurrent_state_ptrs;\n    };\n    std::vector<Data> data_;\n\n    // staging buffers\n    Buffer_<void*> conv_state_ptrs_buf_;\n    Buffer_<void*> recurrent_state_ptrs_buf_;\n\n    // Queried once at construction; passed to all three kernel launchers.\n    int          sm_count_{1};\n    Buffer_<int> work_counter_;  // 1-element device int for v3 atomic claiming\n\n    // Dual-stream dispatch: prefill on high-priority aux stream, decode on main\n    cudaStream_t aux_stream_{};\n    cudaEvent_t  ev_before_{};  // main→aux: prior work done\n    cudaEvent_t  ev_after_{};   // aux→main: prefill done\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/GatedDeltaNetWeight.cc",
    "content": "#include \"src/turbomind/models/llama/GatedDeltaNetWeight.h\"\n#include \"src/turbomind/kernels/gpt_kernels.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\nGatedDeltaNetWeight::GatedDeltaNetWeight(int      hidden_dim,\n                                         int      num_k_heads,\n                                         int      num_v_heads,\n                                         int      key_head_dim,\n                                         int      value_head_dim,\n                                         int      d_conv,\n                                         bool     bias,\n                                         int      tp_size,\n                                         int      tp_rank,\n                                         DataType data_type,\n                                         DataType weight_type,\n                                         int      group_size):\n    tp_rank_(tp_rank), tp_size_(tp_size)\n{\n    const int key_dim    = num_k_heads * key_head_dim / tp_size;\n    const int value_dim  = num_v_heads * value_head_dim / tp_size;\n    const int v_heads_tp = num_v_heads / tp_size;\n    const int conv_dim   = key_dim * 2 + value_dim;\n\n    // GatedDeltaNet projections are stored as plain dense weights in the checkpoint\n    // (dense_wtype = data_type avoids quantization path for these projections).\n    const DataType dense_wtype = data_type;\n    const int      dense_gsz   = 0;\n\n    // Individual projections registered for checkpoint loading\n    in_proj_qkv.emplace(hidden_dim, conv_dim, data_type, bias, dense_wtype, dense_gsz);\n    in_proj_z.emplace(hidden_dim, value_dim, data_type, bias, dense_wtype, dense_gsz);\n    in_proj_b.emplace(hidden_dim, v_heads_tp, data_type, bias, dense_wtype, dense_gsz);\n    in_proj_a.emplace(hidden_dim, v_heads_tp, data_type, bias, dense_wtype, dense_gsz);\n    out_proj.emplace(value_dim, hidden_dim, data_type, bias, dense_wtype, dense_gsz);\n\n  
  register_module(\"in_proj_qkv\", in_proj_qkv, tp_rank_);\n    register_module(\"in_proj_z\", in_proj_z, tp_rank_);\n    register_module(\"in_proj_b\", in_proj_b, tp_rank_);\n    register_module(\"in_proj_a\", in_proj_a, tp_rank_);\n    register_module(\"out_proj\", out_proj, tp_rank_);\n\n    // conv1d: depthwise weights, shape (conv_dim, d_conv)\n    conv1d = Tensor{{conv_dim, d_conv}, data_type, kDEVICE};\n    register_parameter(\"conv1d.\" + std::to_string(tp_rank_) + \".weight\", conv1d);\n\n    // A_log: log-space decay per head, shape (num_v_heads/tp,)\n    A_log = Tensor{{v_heads_tp}, data_type, kDEVICE};\n    register_parameter(\"A_log.\" + std::to_string(tp_rank_) + \".weight\", A_log);\n\n    // dt_bias: per head, shape (num_v_heads/tp,)\n    dt_bias = Tensor{{v_heads_tp}, data_type, kDEVICE};\n    register_parameter(\"dt_bias.\" + std::to_string(tp_rank_) + \".weight\", dt_bias);\n\n    // norm: RMSNormGated weight, shape (value_head_dim,)\n    norm = Tensor{{value_head_dim}, data_type, kDEVICE};\n    register_parameter(\"norm.weight\", norm);\n}\n\n// ---------------------------------------------------------------------------\n// Column-wise concatenation of 4 weight matrices into a single pre-allocated\n// destination tensor (the sources are stacked along the output dimension).\n//\n// Each source weight has shape (input_dim, out_dim_i) in row-major storage.\n// The destination has shape (input_dim, sum_i out_dim_i); each destination row\n// is the concatenation of the corresponding source rows, in order.\n//\n// Implemented with cudaMemcpy2DAsync so that no extra temporary is needed:\n// each source \"column block\" is scattered into the correct column range of\n// the destination in one pass per source.\n// ---------------------------------------------------------------------------\nstatic void\nconcat_weights_4(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, Tensor& dst, cudaStream_t st)\n{\n    // Tensors are (K=input_dim, M=output_dim) in row-major order.\n    // Each row of `dst` is [a_row 
| b_row | c_row | d_row].\n    const int K       = dst.shape(0);\n    const int M_a     = a.shape(1);\n    const int M_b     = b.shape(1);\n    const int M_c     = c.shape(1);\n    const int M_d     = d.shape(1);\n    const int M_dst   = dst.shape(1);  // M_a + M_b + M_c + M_d\n    const int elem_sz = byte_size(dst.dtype(), 1);\n\n    // Pitch of the destination row in bytes\n    const size_t dst_pitch   = (size_t)M_dst * elem_sz;\n    const size_t src_pitch_a = (size_t)M_a * elem_sz;\n    const size_t src_pitch_b = (size_t)M_b * elem_sz;\n    const size_t src_pitch_c = (size_t)M_c * elem_sz;\n    const size_t src_pitch_d = (size_t)M_d * elem_sz;\n\n    char* dst_ptr = reinterpret_cast<char*>(dst.raw_data());\n\n    // Columns [0, M_a)\n    check_cuda_error(\n        cudaMemcpy2DAsync(dst_ptr, dst_pitch, a.raw_data(), src_pitch_a, src_pitch_a, K, cudaMemcpyDefault, st));\n\n    // Columns [M_a, M_a+M_b)\n    check_cuda_error(cudaMemcpy2DAsync(\n        dst_ptr + src_pitch_a, dst_pitch, b.raw_data(), src_pitch_b, src_pitch_b, K, cudaMemcpyDefault, st));\n\n    // Columns [M_a+M_b, M_a+M_b+M_c)\n    check_cuda_error(cudaMemcpy2DAsync(dst_ptr + src_pitch_a + src_pitch_b,\n                                       dst_pitch,\n                                       c.raw_data(),\n                                       src_pitch_c,\n                                       src_pitch_c,\n                                       K,\n                                       cudaMemcpyDefault,\n                                       st));\n\n    // Columns [M_a+M_b+M_c, M_dst)\n    check_cuda_error(cudaMemcpy2DAsync(dst_ptr + src_pitch_a + src_pitch_b + src_pitch_c,\n                                       dst_pitch,\n                                       d.raw_data(),\n                                       src_pitch_d,\n                                       src_pitch_d,\n                                       K,\n                                       cudaMemcpyDefault,\n          
                             st));\n    sync_check_cuda_error();\n}\n\nvoid GatedDeltaNetWeight::prepare()\n{\n    auto stream = core::Context::stream().handle();\n\n    // Preprocess individual weights (converts blockscale FP8, etc.)\n    in_proj_qkv.preprocess();\n    in_proj_z.preprocess();\n    in_proj_b.preprocess();\n    in_proj_a.preprocess();\n    out_proj.preprocess();\n    out_proj.prepare();\n\n    // Build the fused input projection weight:\n    //   shape (hidden_dim,  conv_dim + value_dim + 2*v_heads_tp)\n    //   = [in_proj_qkv | in_proj_z | in_proj_b | in_proj_a]  (column-wise)\n    const int out_all = in_proj_qkv.output_dim  //\n                        + in_proj_z.output_dim  //\n                        + in_proj_b.output_dim  //\n                        + in_proj_a.output_dim;\n\n    in_proj_all.emplace(in_proj_qkv.input_dim,\n                        out_all,\n                        in_proj_qkv.data_type,\n                        /*bias=*/false,\n                        in_proj_qkv.weight_type,\n                        in_proj_qkv.group_size);\n\n    concat_weights_4(\n        in_proj_qkv.weight, in_proj_z.weight, in_proj_b.weight, in_proj_a.weight, in_proj_all.weight, stream);\n\n    // Prepare (convert/repack) the fused weight for GEMM\n    in_proj_all.prepare();\n\n    // Release the now-redundant individual weight tensors to free HBM\n    in_proj_qkv = {};\n    in_proj_z   = {};\n    in_proj_b   = {};\n    in_proj_a   = {};\n\n    // Transpose conv1d from checkpoint layout [conv_dim, d_conv] to kernel layout [d_conv, conv_dim]\n    {\n        const int rows = conv1d.shape(0);  // conv_dim\n        const int cols = conv1d.shape(1);  // d_conv\n\n        Tensor conv1d_t{{cols, rows}, conv1d.dtype(), kDEVICE};\n        invokeTransposeAxis01((uint16_t*)conv1d_t.raw_data(), (uint16_t*)conv1d.raw_data(), rows, cols, 1, stream);\n        sync_check_cuda_error();\n        conv1d = std::move(conv1d_t);\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/GatedDeltaNetWeight.h",
    "content": "#pragma once\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/core/module.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n\nnamespace turbomind {\n\nstruct GatedDeltaNetWeight: public core::Module {\n\n    GatedDeltaNetWeight() = default;\n\n    GatedDeltaNetWeight(int      hidden_dim,\n                        int      num_k_heads,\n                        int      num_v_heads,\n                        int      key_head_dim,\n                        int      value_head_dim,\n                        int      d_conv,\n                        bool     bias,\n                        int      tp_size,\n                        int      tp_rank,\n                        DataType data_type,\n                        DataType weight_type,\n                        int      group_size);\n\n    void prepare();\n\n    // Individual projections – populated at load time from the checkpoint.\n    // After prepare() completes they are released (null-ed) to free HBM.\n    LlamaDenseWeight in_proj_qkv;  // hidden -> key_dim*2 + value_dim\n    LlamaDenseWeight in_proj_z;    // hidden -> value_dim (output gate)\n    LlamaDenseWeight in_proj_b;    // hidden -> num_v_heads (beta, per-head scalar)\n    LlamaDenseWeight in_proj_a;    // hidden -> num_v_heads (alpha/dt, per-head scalar)\n\n    // Fused projection: hidden -> (conv_dim + value_dim + 2*v_heads_tp).\n    // Built from the four above in prepare(); used for all inference GEMMs.\n    // Reduces p.input HBM reads from 4× to 1× per forward pass.\n    LlamaDenseWeight in_proj_all;\n\n    LlamaDenseWeight out_proj;  // value_dim -> hidden\n\n    // Non-dense parameters\n    Tensor conv1d;   // depthwise conv weights: (conv_dim, d_conv) as loaded, (d_conv, conv_dim) after prepare()\n    Tensor A_log;    // log-space decay: (num_v_heads / tp_size,)\n    Tensor dt_bias;  // dt bias: (num_v_heads / tp_size,)\n    Tensor norm;     // RMSNormGated weight: (value_head_dim,)\n\n    int tp_rank_;\n    int tp_size_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaDecoderLayerWeight.cc",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from\n// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc\n\n#include <cstdlib>\n\n#include <cuda_bf16.h>\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/models/llama/LlamaDecoderLayerWeight.h\"\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind {\n\nstatic bool is_fuse_silu_act()\n{\n    static const bool value = [] {\n        const auto str = std::getenv(\"TM_FUSE_SILU_ACT\");\n        if (str) {\n            try {\n                auto v = std::stoi(str) != 0;\n                TM_LOG_INFO(\"TM_FUSE_SILU_ACT=%d\", (int)v);\n                return v;\n            }\n            catch (...) 
{\n            }\n        }\n        // TM_LOG_INFO(\"TM_FUSE_SILU_ACT=1\");\n        return true;\n    }();\n    return value;\n}\n\nLlamaDecoderLayerWeight::LlamaDecoderLayerWeight(\n    DataType data_type, int layer_id, const ModelParam& model, const EngineParam& engine, const MoeParam& moe_param):\n    head_num_(model.head_num),\n    kv_head_num_(model.kv_head_num),\n    size_per_head_(model.head_dim),\n    hidden_units_(model.hidden_units),\n    inter_size_(model.inter_size.at(layer_id)),\n    data_type_{data_type},\n    weight_type_(model.weight_type),\n    expert_weight_type_(model.expert_weight_type),\n    attn_bias_(model.attn_bias),\n    attn_tp_size_(engine.attn_tp_size),\n    attn_tp_rank_(engine.attn_tp_rank),\n    mlp_tp_size_(engine.mlp_tp_size),\n    mlp_tp_rank_(engine.mlp_tp_rank)\n{\n    bool is_linear_attention = false;\n    if (layer_id < (int)model.layer_types.size() && model.layer_types[layer_id] == 1) {\n        is_linear_attention = true;\n    }\n\n    if (is_linear_attention) {\n        linear_attn_weights.reset(\n            new GatedDeltaNetWeight{hidden_units_,\n                                    model.linear_num_key_heads,\n                                    model.linear_num_value_heads,\n                                    model.linear_key_head_dim,\n                                    model.linear_value_head_dim,\n                                    model.linear_conv_kernel_dim > 0 ? 
model.linear_conv_kernel_dim : 4,\n                                    attn_bias_,\n                                    attn_tp_size_,\n                                    attn_tp_rank_,\n                                    data_type_,\n                                    weight_type_,\n                                    model.group_size});\n        register_module(\"linear_attn\", *linear_attn_weights);\n    }\n    else {\n        // Attention uses weight_type (fp16 in mixed quant scenarios)\n        self_attn_weights.reset(new LlamaAttentionWeight{hidden_units_,\n                                                         size_per_head_,\n                                                         head_num_,\n                                                         kv_head_num_,\n                                                         model.mla,\n                                                         attn_bias_,\n                                                         model.qk_norm,\n                                                         attn_tp_size_,\n                                                         attn_tp_rank_,\n                                                         data_type_,\n                                                         weight_type_,\n                                                         model.group_size,\n                                                         model.window_size.empty() ? 0 : model.window_size.at(layer_id),\n                                                         model.attn_sink,\n                                                         model.attn_output_gate});\n        register_module(\"attention\", *self_attn_weights);\n    }\n\n    // FFN uses ffn_weight_type, except for layers fully excluded from\n    // quantization (e.g. 'model.layers.0.' 
in modules_to_not_convert)\n    // where all weights—including FFN—are in data_type (fp16).\n    if (inter_size_) {\n        const DataType ffn_wtype = model.unquantized_expert_layers.count(layer_id) ? data_type_ : model.ffn_weight_type;\n        const bool     is_cublas_gemm = byte_size(ffn_wtype, 8) == 16;\n        ffn_weights.reset(new LlamaFfnWeight{\n            hidden_units_,\n            inter_size_,\n            model.mlp_bias,\n            mlp_tp_size_,\n            mlp_tp_rank_,\n            data_type_,\n            ffn_wtype,\n            model.group_size,\n            model.act_type,\n            is_fuse_silu_act() && !is_cublas_gemm,\n        });\n        register_module(\"feed_forward\", *ffn_weights);\n    }\n\n    // MoE routed experts use expert_weight_type (int4 for AWQ, e2m1 for mxfp4)\n    // unless the layer is in unquantized_expert_layers (e.g. layer 0 excluded\n    // from quantization via modules_to_not_convert).\n    if (layer_id < moe_param.expert_num.size() && moe_param.expert_num[layer_id]) {\n        const DataType moe_wtype = model.unquantized_expert_layers.count(layer_id) ? 
data_type_ : expert_weight_type_;\n        moe_weights.reset(new MoeFfnWeight{layer_id,\n                                           moe_param,\n                                           hidden_units_,\n                                           model.mlp_bias,\n                                           data_type_,\n                                           moe_wtype,\n                                           model.group_size,\n                                           mlp_tp_size_,\n                                           mlp_tp_rank_,\n                                           model.act_type,\n                                           is_fuse_silu_act()});\n        register_module(\"moe_ffn\", *moe_weights);\n    }\n\n    self_attn_norm = Tensor{{hidden_units_}, data_type_, kDEVICE};\n    ffn_norm       = Tensor{{hidden_units_}, data_type_, kDEVICE};\n    register_parameter(\"attention_norm.weight\", self_attn_norm);\n    register_parameter(\"ffn_norm.weight\", ffn_norm);\n}\n\nLlamaDecoderLayerWeight::~LlamaDecoderLayerWeight() = default;\n\nvoid LlamaDecoderLayerWeight::prepare(const cudaDeviceProp& prop, cudaStream_t st)\n{\n    if (self_attn_weights) {\n        self_attn_weights->prepare();\n    }\n\n    if (linear_attn_weights) {\n        linear_attn_weights->prepare();\n    }\n\n    if (ffn_weights) {\n        ffn_weights->prepare(false);\n    }\n\n    if (moe_weights) {\n        moe_weights->prepare();\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaDecoderLayerWeight.h",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from\n// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.h\n\n#pragma once\n\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/models/llama/GatedDeltaNetWeight.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nstruct LlamaDecoderLayerWeight: core::Module {\npublic:\n    LlamaDecoderLayerWeight() = delete;\n\n    LlamaDecoderLayerWeight(DataType           data_type,\n                            int                layer_id,\n                            const ModelParam&  model,\n                            const EngineParam& engine,\n                            const MoeParam&    moe_param);\n\n    ~LlamaDecoderLayerWeight();\n    LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight&) = delete;\n    LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight&) = delete;\n\n    void prepare(const cudaDeviceProp& prop, cudaStream_t st);\n\n    Tensor self_attn_norm;\n    Tensor ffn_norm;\n\n    std::unique_ptr<LlamaAttentionWeight> self_attn_weights;\n    std::unique_ptr<GatedDeltaNetWeight>  linear_attn_weights;\n\n    std::unique_ptr<LlamaFfnWeight> 
ffn_weights;\n    std::unique_ptr<MoeFfnWeight>   moe_weights;\n\nprivate:\n    int head_num_;\n    int kv_head_num_;\n    int size_per_head_;\n    int hidden_units_;\n    int inter_size_;\n\n    DataType data_type_;\n    DataType weight_type_;\n    DataType expert_weight_type_;\n\n    int  bit_size_;\n    bool attn_bias_;\n    int  attn_tp_size_;\n    int  attn_tp_rank_;\n    int  mlp_tp_size_;\n    int  mlp_tp_rank_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaDenseWeight.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <utility>\n\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/activation.h\"\n#include \"src/turbomind/kernels/gemm/cast.h\"\n#include \"src/turbomind/kernels/gemm/convert.h\"\n#include \"src/turbomind/kernels/gemm/gemm.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n#include \"src/turbomind/kernels/gemm/utils.h\"\n#include \"src/turbomind/kernels/gpt_kernels.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\nvoid LlamaDenseWeight::emplace(\n    int input_dim, int output_dim, DataType data_type, bool bias, DataType weight_type, int group_size)\n{\n    this->data_type   = data_type;\n    this->input_type  = data_type;\n    this->weight_type = weight_type;\n    this->input_dim   = input_dim;\n    this->output_dim  = output_dim;\n    this->group_size  = group_size;\n\n    const bool is_qweight = weight_type == kUint4 || weight_type == kUint8;\n\n    weight = Tensor({input_dim, output_dim}, weight_type, kDEVICE);\n    register_parameter(is_qweight ? 
\"qweight\" : \"weight\", weight);\n\n    if (bias) {\n        this->bias = Tensor{{output_dim}, data_type, kDEVICE};\n        register_parameter(\"bias\", this->bias);\n    }\n\n    if (weight_type == kFloat8_e4m3) {\n        TM_CHECK_EQ(group_size, 128);\n        scales       = Tensor{{cdiv(input_dim, group_size), cdiv(output_dim, group_size)}, kFloat, kDEVICE};\n        weight_quant = QuantDesc{gemm::QuantType::kB, group_size};\n        if (getSMVersion() == 90) {\n            input_type  = kFloat8_e4m3;\n            input_quant = QuantDesc{gemm::QuantType::kK, group_size};\n        }\n        register_parameter(\"scales\", scales);\n    }\n    else if (weight_type == kFloat4_e2m1) {\n        scales       = Tensor{{cdiv(input_dim, group_size), output_dim}, kUint8, kDEVICE};\n        input_type   = data_type;\n        weight_quant = QuantDesc{gemm::QuantType::kK, group_size};\n        register_parameter(\"scales\", scales);\n    }\n    else if (is_qweight) {\n        TM_CHECK(input_dim % group_size == 0) << input_dim << \" \" << group_size;\n        scales       = Tensor{{input_dim / group_size, output_dim}, data_type, kDEVICE};\n        zeros        = Tensor{{input_dim / group_size, output_dim}, data_type, kDEVICE};\n        weight_quant = QuantDesc{gemm::QuantType::kK, group_size};\n        register_parameter(\"scales\", scales);\n        register_parameter(\"zeros\", zeros);\n    }\n\n    k_desc = {};\n    q_desc = {};\n\n    // default case: floating point, N-major\n    k_desc.type  = weight.dtype();\n    k_desc.order = gemm::kRowMajor;\n    k_desc.rows  = input_dim;\n    k_desc.cols  = output_dim;\n    k_desc.ld    = output_dim;\n}\n\nvoid LlamaDenseWeight::preprocess()\n{\n    if (!weight) {\n        return;\n    }\n    if (weight_quant.type == gemm::QuantType::kB && input_quant.type == gemm::QuantType::kNone) {\n        // Convert blockwise scales to groupwise scales\n        weight_quant.type = gemm::QuantType::kK;\n        scales            = 
BlockscaleToGroupscale(scales, data_type, weight_quant.group_size);\n    }\n}\n\nstatic void Convert(LlamaDenseWeight& dense, bool is_grouped, cudaStream_t st)\n{\n    using namespace gemm;\n\n    auto [conv_w, conv_s] =\n        GetConverters(dense.data_type, dense.weight_type, dense.input_type, is_grouped, getSMVersion());\n\n    if (conv_w) {\n        const auto order_w = conv_w->order;\n        const bool is_A    = get_operand_tag(conv_w->pack) == OPERAND_A;\n        const bool is_B    = !is_A;\n\n        const int bits = byte_size(dense.weight_type, 8);\n\n        Tensor_<uint16_t> tmp{{dense.input_dim, dense.output_dim}, kDEVICE};\n\n        if (bits == 4) {  // u4 -> u16\n            extend_to_u16(tmp.data(), (const uint4_t*)dense.weight.raw_data(), tmp.size(), st);\n            sync_check_cuda_error();\n        }\n        else if (bits == 8) {  // u8 -> u16\n            extend_to_u16(tmp.data(), (const uint8_t*)dense.weight.raw_data(), tmp.size(), st);\n            sync_check_cuda_error();\n        }\n        else if (bits == 16) {\n            check_cuda_error(\n                cudaMemcpyAsync(tmp.raw_data(), dense.weight.raw_data(), tmp.byte_size(), cudaMemcpyDefault, st));\n        }\n\n        if (order_w == kRowMajor) {  // (k,m) -> (m,k)\n            Tensor_<uint16_t> trans{{dense.output_dim, dense.input_dim}, kDEVICE};\n            invokeTransposeAxis01(trans.data(), tmp.data(), dense.input_dim, dense.output_dim, 1, st);\n            tmp = trans;\n        }\n\n        MatrixLayout w_desc{\n            dense.data_type,\n            order_w,\n            (int)dense.output_dim,  // M\n            (int)dense.input_dim,   // K\n            order_w == kRowMajor ? 
(int)dense.input_dim : (int)dense.output_dim,\n        };\n\n        if (is_B) {\n            std::swap(w_desc.rows, w_desc.cols);\n            w_desc.order = ~w_desc.order;\n        }\n\n        MatrixLayout k_desc = w_desc;\n        k_desc.type         = dense.weight_type;\n        // Converter does not recognize e2m1 / e4m3\n        if (bits == 4) {\n            k_desc.type = data_type_v<uint4_t>;\n        }\n        else if (bits == 8) {\n            k_desc.type = data_type_v<uint8_t>;\n        }\n        k_desc.pack = conv_w->pack;\n\n        check_cuda_error(cudaMemsetAsync(dense.weight.raw_data(), 0, dense.weight.byte_size(), st));\n\n        TM_CHECK(conv_w->Convert(tmp.data(), w_desc, dense.weight.raw_data(), k_desc, st) == 0);\n\n        sync_check_cuda_error();\n\n        k_desc.type = dense.weight_type;\n        if (is_A) {\n            k_desc = transpose(k_desc);\n        }\n        dense.k_desc = k_desc;\n    }\n\n    if (conv_s) {\n        const auto order_s = conv_s->order;\n        const auto pack_s  = conv_s->pack;\n        const bool is_A    = get_operand_tag(conv_s->pack) == OPERAND_U;\n        const bool is_B    = !is_A;\n\n        Tensor   tmp_q;\n        DataType scale_type;\n\n        if (dense.zeros) {  // AWQ/GPTQ fuse scales and zeros\n            tmp_q = {{dense.scales.size(), 2}, kHalf, kDEVICE};\n            fuse_scales_and_zeros(\n                tmp_q.data<half>(), dense.scales.data<half>(), dense.zeros.data<half>(), dense.scales.size(), st);\n            scale_type   = kUint32;  // half2\n            dense.zeros  = {};\n            dense.scales = empty_like(tmp_q);\n        }\n        else if (dense.weight_type == kFloat8_e4m3) {  // e4m3\n            tmp_q = empty_like(dense.scales);\n            Copy(dense.scales, tmp_q);\n            scale_type = kUint16;  // bf16\n        }\n        else {  // mxfp4\n            tmp_q = empty_like(dense.scales);\n            Copy(dense.scales, tmp_q);\n            scale_type = kUint8;  // 
ue8m0\n        }\n\n        if (dense.data_type == kHalf && dense.weight_type == kFloat4_e2m1) {  // mxfp4\n            AdjustUe8m0ScaleForHalf(tmp_q.data<uint8_t>(), tmp_q.size(), st);\n            sync_check_cuda_error();\n        }\n\n        MatrixLayout s_desc{\n            scale_type,\n            order_s,\n            (int)dense.output_dim,                    // M\n            (int)dense.input_dim / dense.group_size,  // K\n            (int)dense.output_dim,                    // always MN-major\n        };\n\n        if (is_B) {\n            std::swap(s_desc.rows, s_desc.cols);\n            s_desc.order = ~s_desc.order;\n        }\n\n        MatrixLayout q_desc = s_desc;\n        q_desc.pack         = pack_s;\n\n        TM_CHECK(conv_s->Convert(tmp_q.raw_data(), s_desc, dense.scales.raw_data(), q_desc, st) == 0);\n        sync_check_cuda_error();\n\n        // weight is placed at B in `Linear`\n        if (is_A) {\n            q_desc = transpose(q_desc);\n        }\n        dense.q_desc = q_desc;\n    }\n}\n\nstatic void ConvertBlockscaleFP8Native(LlamaDenseWeight& dense, cudaStream_t stream)\n{\n    using namespace gemm;\n\n    TM_CHECK_GE(getSMVersion(), 90);\n    TM_CHECK_EQ(dense.data_type, data_type_v<bfloat16_t>);\n\n    auto process = [&](Tensor& x, MatrixLayout& d, auto dtype) {\n        using T = decltype(dtype);\n        Tensor trans{{x.shape(1), x.shape(0)}, x.dtype(), kDEVICE};\n        invokeTransposeAxis01((T*)trans.raw_data(), (T*)x.raw_data(), x.shape(0), x.shape(1), 1, stream);\n        x = std::move(trans);\n        d = MatrixLayout{x.dtype(),  //\n                         kColMajor,\n                         (int)x.shape(1),\n                         (int)x.shape(0),\n                         (int)x.stride(0)};\n    };\n\n    TM_CHECK_EQ(dense.weight.dtype(), kFloat8_e4m3);\n    process(dense.weight, dense.k_desc, uint8_t{});\n\n    TM_CHECK_EQ(dense.scales.dtype(), kFloat);\n    process(dense.scales, dense.q_desc, float{});\n}\n\nvoid 
LlamaDenseWeight::prepare(bool fused_moe)\n{\n    if (!weight) {\n        return;\n    }\n\n    auto stream = core::Context::stream().handle();\n\n    if (weight_type == kFloat8_e4m3 && input_type == kFloat8_e4m3) {\n        ConvertBlockscaleFP8Native(*this, stream);\n    }\n    else {\n        Convert(*this, fused_moe, stream);\n    }\n}\n\nLlamaAttentionWeight::LlamaAttentionWeight(int      hidden_dim,\n                                           int      head_dim,\n                                           int      head_num,\n                                           int      kv_head_num,\n                                           MLAParam mla,\n                                           bool     bias,\n                                           bool     qk_norm,\n                                           int      tp_size,\n                                           int      tp_rank,\n                                           DataType data_type,\n                                           DataType weight_type,\n                                           int      group_size,\n                                           int      window_size,\n                                           bool     sink,\n                                           bool     attn_output_gate)\n{\n    this->window_size = window_size;\n\n    // attn_output_gate doubles Q dimension (extra gate projection fused into Q)\n    const int q_factor = attn_output_gate ? 
2 : 1;\n\n    if (mla.kv_lora_rank == 0) {\n        qkv.emplace(hidden_dim,\n                    (head_num * q_factor + 2 * kv_head_num) * head_dim / tp_size,\n                    data_type,\n                    bias,\n                    weight_type,\n                    group_size);\n        register_module(\"w_qkv\", qkv, tp_rank);\n        if (qk_norm) {\n            q_a_layernorm  = Tensor{{head_dim}, data_type, kDEVICE};\n            kv_a_layernorm = Tensor{{head_dim}, data_type, kDEVICE};\n            register_parameter(\"q_norm\", q_a_layernorm);\n            register_parameter(\"k_norm\", kv_a_layernorm);\n        }\n    }\n    else {\n        const int qk_nope_dim = head_dim - mla.qk_rope_dim;\n        if (mla.q_lora_rank) {\n            q_a_proj.emplace(hidden_dim, mla.q_lora_rank, data_type, false, weight_type, group_size);\n            q_b_proj.emplace(mla.q_lora_rank, head_num * head_dim / tp_size, data_type, false, weight_type, group_size);\n            q_a_layernorm = Tensor{{q_b_proj.input_dim}, data_type, kDEVICE};\n            register_module(\"q_a_proj\", q_a_proj);\n            register_module(\"q_b_proj\", q_b_proj, tp_rank);\n            register_parameter(\"q_a_layernorm\", q_a_layernorm);\n        }\n        else {\n            q_proj.emplace(hidden_dim, head_num * head_dim / tp_size, data_type, false, weight_type, group_size);\n            register_module(\"q_proj\", q_proj, tp_rank);\n        }\n        kv_a_proj.emplace(hidden_dim, mla.kv_lora_rank + mla.qk_rope_dim, data_type, false, weight_type, group_size);\n        // kv_b_proj.emplace(mla.kv_lora_rank,\n        //                   head_num * (qk_nope_dim + mla.v_head_dim) / tp_size,\n        //                   data_type,\n        //                   false,\n        //                   weight_type,\n        //                   group_size);\n\n        kv_a_layernorm = Tensor{{mla.kv_lora_rank}, data_type, kDEVICE};\n        register_module(\"kv_a_proj\", kv_a_proj);\n        // 
register_module(\"kv_b_proj\", kv_b_proj, tp_rank);\n        register_parameter(\"kv_a_layernorm\", kv_a_layernorm);\n    }\n    output.emplace((head_num * head_dim) / tp_size, hidden_dim, data_type, bias, weight_type, group_size);\n    register_module(\"wo\", output, tp_rank);\n\n    if (sink) {\n        sinks = Tensor{{head_num / tp_size}, data_type, kDEVICE};\n        register_parameter(std::to_string(tp_rank) + \".sinks\", sinks);\n    }\n}\n\nvoid LlamaAttentionWeight::prepare()\n{\n    // each projection is prepared exactly once; unused (empty) weights return early in prepare()\n    std::vector weights{\n        &qkv, &output, &q_proj, &q_a_proj, &q_b_proj, &kv_a_proj  // &kv_b_proj,\n    };\n    for (auto& w : weights) {\n        w->preprocess();\n        w->prepare();\n    }\n}\n\nLlamaFfnWeight::LlamaFfnWeight(int            hidden_dim,\n                               int            inter_size,\n                               bool           bias,\n                               int            tp_size,\n                               int            tp_rank,\n                               DataType       data_type,\n                               DataType       weight_type,\n                               int            group_size,\n                               ActivationType act_type,\n                               bool           fuse_silu_act)\n{\n    TM_CHECK(inter_size % tp_size == 0) << inter_size << \" \" << tp_size;\n\n    inter_size /= tp_size;\n\n    this->inter_size    = inter_size;\n    this->tp_rank       = tp_rank;\n    this->act_type      = act_type;\n    this->is_fused_silu = fuse_silu_act && this->act_type == ActivationType::kSilu;\n\n    gating.emplace(hidden_dim, inter_size, data_type, bias, weight_type, group_size);\n\n    intermediate.emplace(hidden_dim, inter_size, data_type, bias, weight_type, group_size);\n\n    output.emplace(inter_size, hidden_dim, data_type, bias, weight_type, group_size);\n\n    if (gating.input_type == kFloat8_e4m3) {  // SM90 FP8*FP8 GEMM, can't fuse\n        this->is_fused_silu = false;\n    }\n\n    
register_module(\"w1\", gating, tp_rank);\n    register_module(\"w3\", intermediate, tp_rank);\n    register_module(\"w2\", output, tp_rank);\n}\n\nstatic void Interleave(const Tensor& a, const Tensor& b, Tensor& c, cudaStream_t st)\n{\n    TM_CHECK(a.layout() == b.layout());\n    int M, K;\n    if (a.ndim() == 2) {\n        std::tie(K, M) = a.shapes(0, 1);\n    }\n    else {\n        M = a.shape(0);\n        K = 1;\n    }\n    auto a_ = a.raw_data();\n    auto b_ = b.raw_data();\n    auto c_ = c.raw_data();\n\n    const int bits = byte_size(a.dtype(), 8);\n    if (bits == 4) {\n        Buffer_<uint8_t> ta{a.size(), kDEVICE};\n        Buffer_<uint8_t> tb{b.size(), kDEVICE};\n        Buffer_<uint8_t> tc{c.size(), kDEVICE};\n        extend_to_u8(ta.data(), (uint4_t*)a_, a.size(), st);\n        extend_to_u8(tb.data(), (uint4_t*)b_, b.size(), st);\n        interleave_output_dims(tc.data(), ta.data(), tb.data(), M, K, st);\n        compact_to_u4((uint4_t*)c_, tc.data(), c.size(), st);\n    }\n    else if (bits == 8) {\n        interleave_output_dims((uint8_t*)c_, (uint8_t*)a_, (uint8_t*)b_, M, K, st);\n    }\n    else if (bits == 16) {\n        interleave_output_dims((uint16_t*)c_, (uint16_t*)a_, (uint16_t*)b_, M, K, st);\n    }\n    else if (bits == 32) {\n        interleave_output_dims((uint32_t*)c_, (uint32_t*)a_, (uint32_t*)b_, M, K, st);\n    }\n    else {\n        TM_CHECK(0);\n    }\n}\n\nvoid interleave(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& b, DataType data_type, cudaStream_t st)\n{\n    TM_CHECK_EQ(c.input_dim, a.input_dim);\n    TM_CHECK_EQ(c.input_dim, b.input_dim);\n    TM_CHECK_EQ(c.output_dim, a.output_dim * 2);\n    TM_CHECK_EQ(c.output_dim, b.output_dim * 2);\n    TM_CHECK_EQ(c.group_size, a.group_size);\n    TM_CHECK_EQ(c.group_size, b.group_size);\n\n    Interleave(a.weight, b.weight, c.weight, st);\n    sync_check_cuda_error();\n\n    if (a.scales) {\n        Interleave(a.scales, b.scales, c.scales, st);\n        
sync_check_cuda_error();\n    }\n    if (a.zeros) {\n        Interleave(a.zeros, b.zeros, c.zeros, st);\n        sync_check_cuda_error();\n    }\n    if (a.bias) {\n        Interleave(a.bias, b.bias, c.bias, st);\n        sync_check_cuda_error();\n    }\n}\n\nstatic void Chunk(const Tensor& a, const Tensor& b, Tensor& c, cudaStream_t st)\n{\n    TM_CHECK(a.layout() == b.layout());\n    int M, K, spitch, dpitch;\n    if (a.ndim() == 2) {\n        std::tie(K, M) = a.shapes(0, 1);\n        spitch         = byte_size(a.dtype(), a.stride(0));\n        dpitch         = byte_size(c.dtype(), c.stride(0));\n    }\n    else {\n        M      = a.shape(0);\n        K      = 1;\n        spitch = byte_size(a.dtype(), M);\n        dpitch = byte_size(c.dtype(), c.shape(0));\n    }\n    int height = K;\n    int width  = byte_size(a.dtype(), M);\n    check_cuda_error(cudaMemcpy2DAsync((char*)c.raw_data(),  //\n                                       dpitch,\n                                       (const char*)a.raw_data(),\n                                       spitch,\n                                       width,\n                                       height,\n                                       cudaMemcpyDefault,\n                                       st));\n    check_cuda_error(cudaMemcpy2DAsync((char*)c.raw_data() + width,  //\n                                       dpitch,\n                                       (const char*)b.raw_data(),\n                                       spitch,\n                                       width,\n                                       height,\n                                       cudaMemcpyDefault,\n                                       st));\n}\n\nvoid chunk(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& b, DataType data_type, cudaStream_t st)\n{\n    TM_CHECK_EQ(c.input_dim, a.input_dim);\n    TM_CHECK_EQ(c.input_dim, b.input_dim);\n    TM_CHECK_EQ(c.output_dim, a.output_dim * 2);\n    TM_CHECK_EQ(c.output_dim, 
b.output_dim * 2);\n    TM_CHECK_EQ(c.group_size, a.group_size);\n    TM_CHECK_EQ(c.group_size, b.group_size);\n\n    Chunk(a.weight, b.weight, c.weight, st);\n    sync_check_cuda_error();\n\n    if (a.scales) {\n        Chunk(a.scales, b.scales, c.scales, st);\n        sync_check_cuda_error();\n    }\n    if (a.zeros) {\n        Chunk(a.zeros, b.zeros, c.zeros, st);\n        sync_check_cuda_error();\n    }\n    if (a.bias) {\n        Chunk(a.bias, b.bias, c.bias, st);\n        sync_check_cuda_error();\n    }\n}\n\nvoid LlamaFfnWeight::prepare(bool fused_moe)\n{\n    const auto data_type = gating.data_type;\n\n    auto stream = core::Context::stream().handle();\n\n    gating.preprocess();\n    intermediate.preprocess();\n\n    if (fuse_up_and_gate) {\n        auto& gate_and_up = fused_gating_intermediate;\n\n        gate_and_up.emplace(gating.input_dim,  //\n                            gating.output_dim * 2,\n                            gating.data_type,\n                            (bool)gating.bias,\n                            gating.weight_type,\n                            gating.group_size);\n        gate_and_up.preprocess();\n        register_module(\"w1w3\", gate_and_up, this->tp_rank);\n\n        if (is_fused_silu) {\n            interleave(gate_and_up, gating, intermediate, data_type, stream);\n            gate_and_up.epilogue = gemm::Epilogue::kGatedSilu;\n        }\n        else {\n            chunk(gate_and_up, gating, intermediate, data_type, stream);\n        }\n\n        fused_gating_intermediate.prepare(fused_moe);\n\n        gating       = {};\n        intermediate = {};\n    }\n    else {\n        gating.prepare(fused_moe);\n        intermediate.prepare(fused_moe);\n    }\n\n    output.preprocess();\n    output.prepare(fused_moe);\n}\n\nMoeFfnWeight::MoeFfnWeight(int             layer_id,\n                           const MoeParam& param,\n                           int             hidden_dim,\n                           bool            
mlp_bias,\n                           DataType        data_type,\n                           DataType        weight_type,\n                           int             group_size,\n                           int             tp_size,\n                           int             tp_rank,\n                           ActivationType  act_type,\n                           bool            fuse_silu_act)\n{\n    if ((int)param.expert_num.size() <= layer_id) {\n        return;\n    }\n\n    const int expert_num = param.expert_num[layer_id];\n\n    if (expert_num == 0) {\n        return;\n    }\n\n    gate.emplace(hidden_dim, expert_num, data_type, param.router_bias, data_type, 1);\n    register_module(\"gate\", gate);\n\n    if (param.topk_method == \"noaux_tc\") {\n        score_correction_bias = Tensor{{expert_num}, kFloat, kDEVICE};\n        register_parameter(\"gate.score_correction_bias\", score_correction_bias);\n    }\n\n    method = param.method;\n\n    const bool is_cublas_gemm = method == MoeParam::kNaive && byte_size(weight_type, 8) == 16;\n    if (is_cublas_gemm || mlp_bias) {\n        fuse_silu_act = false;\n    }\n\n    experts.reserve(expert_num);\n    for (int i = 0; i < expert_num; ++i) {\n        experts.emplace_back(new LlamaFfnWeight{hidden_dim,\n                                                param.inter_size,\n                                                mlp_bias,\n                                                tp_size,\n                                                tp_rank,\n                                                data_type,\n                                                weight_type,\n                                                group_size,\n                                                act_type,\n                                                fuse_silu_act});\n        register_module(\"experts\", *experts.back(), i);\n    }\n\n    if (param.shared_gate) {\n        shared_gate.emplace(hidden_dim, 1, data_type, false, data_type, 
1);\n        register_module(\"shared_gate\", shared_gate);\n    }\n}\n\nvoid MoeFfnWeight::prepare()\n{\n    const auto fused_moe = method == MoeParam::kFused;\n\n    gate.prepare();\n    shared_gate.prepare();\n\n    for (auto& e : experts) {\n        e->prepare(fused_moe);\n    }\n\n    const int n = experts.size();\n    LinkExperts([&](int i) { return &experts[i]->fused_gating_intermediate; }, n, block.fused_gating_intermediate);\n    LinkExperts([&](int i) { return &experts[i]->output; }, n, block.output);\n\n    auto& e = *experts.at(0);\n    // Copy MLP properties\n    block.inter_size    = e.inter_size;\n    block.is_fused_silu = e.is_fused_silu;\n    block.act_type      = e.act_type;\n}\n\nvoid LinkExperts(std::function<LlamaDenseWeight*(int)> experts, int n, LlamaDenseWeight& d)\n{\n    const auto& e = *experts(0);\n\n    d.input_dim    = e.input_dim;\n    d.output_dim   = e.output_dim;\n    d.group_size   = e.group_size;\n    d.data_type    = e.data_type;\n    d.input_type   = e.input_type;\n    d.weight_type  = e.weight_type;\n    d.input_quant  = e.input_quant;\n    d.weight_quant = e.weight_quant;\n    d.k_desc       = e.k_desc;\n    d.q_desc       = e.q_desc;\n    d.epilogue     = e.epilogue;\n\n    d.k_desc.num = d.q_desc.num = n;\n\n    if (e.bias) {\n        d.bias = Tensor{{n, e.output_dim}, e.bias.dtype(), kDEVICE};\n    }\n\n    std::vector<std::pair<void*, int>> weights;\n    std::vector<std::pair<void*, int>> scales;\n\n    for (int i = 0; i < n; ++i) {\n        auto& e = *experts(i);\n        weights.emplace_back(e.weight.raw_data(), e.k_desc.ld);\n        if (e.scales) {\n            scales.emplace_back(e.scales.raw_data(), e.q_desc.ld);\n        }\n        if (e.bias) {\n            Copy(e.bias, d.bias.slice(i, 1).squeeze(0));\n        }\n    }\n\n    auto stream = core::Context::stream().handle();\n\n    if (d.weight_type == kFloat8_e4m3 && d.input_type == kFloat8_e4m3) {\n        auto make_blocked_ptr = [&](const auto& ptrs) {\n          
  return std::shared_ptr<void>{gemm::MakeBlockedPtrs(ptrs, stream), [](auto p) { cudaFree(p); }};\n        };\n        d.weight = Tensor{make_blocked_ptr(weights), {n}, e.weight.dtype(), kDEVICE};\n        d.scales = Tensor{make_blocked_ptr(scales), {n}, e.scales.dtype(), kDEVICE};\n        // A non-null dummy `offsets` pointer is needed for the descriptors to be recognized as blocked striding mode\n        d.k_desc.offsets = d.q_desc.offsets = (int*)1;\n    }\n    else {\n        auto make_strided_ptr = [&](const auto& ptrs) {\n            return std::shared_ptr<void>{gemm::MakeStridedPtrs(ptrs, stream), [](auto p) { cudaFree(p); }};\n        };\n        d.weight = Tensor{make_strided_ptr(weights), {n}, d.weight_type, kDEVICE};\n        if (e.scales) {\n            d.scales = Tensor{make_strided_ptr(scales), {n}, e.scales.dtype(), kDEVICE};\n        }\n        // pre-sm90 grouped GEMM needs `ld == 0` to resolve strided_ptr\n        d.k_desc.ld = d.q_desc.ld = 0;\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaDenseWeight.h",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/DenseWeight.h\n\n#pragma once\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/core/module.h\"\n\n#include \"src/turbomind/kernels/activation.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nusing gemm::QuantDesc;\nusing gemm::MatrixLayout;\nusing gemm::Epilogue;\n\nstruct LlamaDenseWeight: public core::Module {\n\n    LlamaDenseWeight():\n        data_type{}, weight_type{}, input_type{}, weight_quant{}, input_quant{}, epilogue{}, k_desc{}, q_desc{}\n    {\n    }\n\n    void emplace(int input_dim, int output_dim, DataType data_type, bool bias, DataType weight_type, int group_size);\n\n    void preprocess();\n\n    void prepare(bool fused_moe = false);\n\n    LlamaDenseWeight& operator=(std::nullptr_t)\n    {\n        this->~LlamaDenseWeight();\n        new (this) LlamaDenseWeight{};\n        return *this;\n    }\n\n    operator bool() const noexcept\n    {\n        return static_cast<bool>(weight);\n    }\n\n    int input_dim  = 0;\n    int output_dim = 0;\n    int group_size = 1;\n\n    Tensor weight;\n    Tensor bias;\n\n    Tensor scales;\n    Tensor 
zeros;\n\n    DataType data_type;\n\n    DataType weight_type;\n    DataType input_type;\n\n    QuantDesc weight_quant;\n    QuantDesc input_quant;\n\n    Epilogue epilogue;\n\n    MatrixLayout k_desc;\n    MatrixLayout q_desc;\n};\n\nstruct LlamaAttentionWeight: public core::Module {\n\n    LlamaAttentionWeight() = default;\n\n    LlamaAttentionWeight(int      hidden_dim,\n                         int      head_dim,\n                         int      head_num,\n                         int      kv_head_num,\n                         MLAParam mla,\n                         bool     bias,\n                         bool     qk_norm,\n                         int      tp_size,\n                         int      tp_rank,\n                         DataType data_type,\n                         DataType weight_type,\n                         int      group_size,\n                         int      window_size,\n                         bool     sink,\n                         bool     attn_output_gate = false);\n\n    void prepare();\n\n    LlamaDenseWeight qkv;\n    LlamaDenseWeight output;\n\n    Tensor sinks;\n\n    LlamaDenseWeight q_proj;\n    LlamaDenseWeight q_a_proj;\n    LlamaDenseWeight q_b_proj;\n    LlamaDenseWeight kv_a_proj;\n    // LlamaDenseWeight kv_b_proj;\n\n    Tensor q_a_layernorm;\n    Tensor kv_a_layernorm;\n\n    int window_size{};\n};\n\nstruct LlamaFfnWeight: core::Module {\n\n    LlamaFfnWeight() = default;\n\n    LlamaFfnWeight(int            hidden_dim,\n                   int            inter_size,\n                   bool           bias,\n                   int            tp_size,\n                   int            tp_rank,\n                   DataType       data_type,\n                   DataType       weight_type,\n                   int            group_size,\n                   ActivationType act_type,\n                   bool           fuse_silu_act);\n\n    static constexpr bool fuse_up_and_gate = true;\n\n    void prepare(bool 
fused_moe);\n\n    LlamaDenseWeight gating;\n    LlamaDenseWeight intermediate;\n    LlamaDenseWeight output;\n    LlamaDenseWeight fused_gating_intermediate;\n\n    ActivationType act_type;\n\n    int  inter_size{};\n    bool is_fused_silu{};\n\n    int tp_rank{};\n};\n\nstruct MoeFfnWeight: core::Module {\n\n    MoeFfnWeight() = default;\n\n    MoeFfnWeight(int             layer_id,\n                 const MoeParam& param,\n                 int             hidden_dim,\n                 bool            mlp_bias,\n                 DataType        data_type,\n                 DataType        weight_type,\n                 int             group_size,\n                 int             tp_size,\n                 int             tp_rank,\n                 ActivationType  act_type,\n                 bool            fuse_silu_act);\n\n    void prepare();\n\n    LlamaDenseWeight gate;\n    LlamaDenseWeight shared_gate;\n\n    /// Per-expert score correction bias for noaux_tc routing (optional; used when topk_method == \"noaux_tc\")\n    Tensor score_correction_bias;\n\n    std::vector<std::unique_ptr<LlamaFfnWeight>> experts;\n\n    // reference into `experts`\n    LlamaFfnWeight block;\n\n    MoeParam::Method method{};\n};\n\nvoid LinkExperts(std::function<LlamaDenseWeight*(int)> experts, int n, LlamaDenseWeight& d);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaFfnLayer.cc",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/FfnLayer.h\n\n#include \"src/turbomind/models/llama/LlamaFfnLayer.h\"\n#include \"src/turbomind/kernels/activation.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/utils/anomaly_handler.h\"\n\nnamespace turbomind {\n\nvoid LlamaFfnLayer::forward(ForwardParam param)\n{\n    NvtxScope scope(\"ffn\");\n\n    const auto& mlp = *param.weights;\n\n    const int token_num  = param.input.shape(0);\n    const int inter_size = mlp.inter_size;\n    const int layer_id   = param.layer_id;\n\n    const auto stream = core::Context::stream().handle();\n\n    Tensor gating;\n    Tensor inter;\n\n    if (mlp.fused_gating_intermediate.weight) {\n        auto mix = linear_.Forward(param.input, mlp.fused_gating_intermediate);\n        sync_check_cuda_error();\n\n        gating = mix.slice({0, 0}, {(int)token_num, inter_size});\n        if (!mlp.is_fused_silu) {\n            inter = mix.slice({0, inter_size}, {(ssize_t)token_num, inter_size});\n        }\n    }\n    else {\n        gating = linear_.Forward(param.input, mlp.gating);\n        sync_check_cuda_error();\n        TM_DEBUG_TENSOR(gating, Concat(\"w1\", layer_id), 3);\n\n        inter = 
linear_.Forward(param.input, mlp.intermediate);\n        sync_check_cuda_error();\n        TM_DEBUG_TENSOR(inter, Concat(\"w3\", layer_id), 3);\n    }\n\n    if (!mlp.is_fused_silu) {\n        // gate' = silu(gate) * up\n        Activation(gating, inter, mlp.act_type, stream);\n        sync_check_cuda_error();\n        TM_DEBUG_TENSOR(gating, Concat(\"act\", layer_id), 3);\n    }\n\n    {  // w2(x)\n        NvtxScope scope(\"w2\");\n        linear_.Forward(gating, mlp.output, param.output);\n        sync_check_cuda_error();\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaFfnLayer.h",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/FfnLayer.cc\n\n#pragma once\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/LlamaLinear.h\"\n#include \"src/turbomind/models/llama/context.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nclass LlamaFfnLayer {\npublic:\n    LlamaFfnLayer(const ModelParam& model, const Context& ctx): hidden_units_(model.hidden_units), linear_(*ctx.linear)\n    {\n    }\n\n    struct ForwardParam {\n        Tensor                input;\n        Tensor                output;\n        const LlamaFfnWeight* weights;\n        int                   layer_id;\n    };\n\n    void forward(ForwardParam param);\n\nprivate:\n    const size_t hidden_units_;\n    LlamaLinear& linear_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaLinear.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/core/cuda_data_type.h\"\n#include \"src/turbomind/core/data_type.h\"\n\n#include \"src/turbomind/kernels/gemm/gemm.h\"\n#include \"src/turbomind/kernels/gemm/moe_utils_v2.h\"\n#include \"src/turbomind/kernels/gemm/types.h\"\n\n#include \"src/turbomind/kernels/quantization.h\"\n\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/LlamaLinear.h\"\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\nusing namespace gemm;\n\nstruct LlamaLinear::Impl {\n\n    explicit Impl()\n    {\n        workspace_ = {};\n\n        workspace_.barriers_size   = gemm::Gemm::kBarriersSize;\n        workspace_.partials_size   = gemm::Gemm::kPartialsSize;\n        workspace_.tensormaps_size = 8192 * 128;  // maximum 4096 tensor maps\n\n        auto st = core::Context::stream().handle();\n\n        check_cuda_error(cudaMallocAsync(&workspace_.barriers, workspace_.barriers_size, st));\n        check_cuda_error(cudaMallocAsync(&workspace_.partials, workspace_.partials_size, st));\n        check_cuda_error(cudaMallocAsync(&workspace_.tensormaps, workspace_.tensormaps_size, st));\n        check_cuda_error(cudaMemsetAsync(workspace_.barriers, 0, workspace_.barriers_size, st));\n        check_cuda_error(cudaMallocAsync(&workspace_.flags, sizeof(int), st));\n\n        core::Context::stream().Sync();\n    }\n\n    ~Impl()\n    {\n        auto st = core::Context::stream().handle();\n\n        cudaFreeAsync(workspace_.barriers, st);\n        cudaFreeAsync(workspace_.partials, st);\n        cudaFreeAsync(workspace_.tensormaps, st);\n        cudaFreeAsync(workspace_.flags, st);\n        workspace_ = {};\n    }\n\n    std::tuple<Tensor, MatrixLayout, Tensor, MatrixLayout> GetOperandB(const LlamaDenseWeight& dense)\n   
 {\n        const Tensor& B      = dense.weight;\n        const Tensor& V      = dense.scales;\n        MatrixLayout  desc_B = dense.k_desc;\n        MatrixLayout  desc_V = dense.q_desc;\n        return {B, desc_B, V, desc_V};\n    }\n\n    std::tuple<Tensor, MatrixLayout, Tensor, MatrixLayout>\n    GetOperandA(const LlamaDenseWeight& dense, const Tensor& input, Buffer_<int> indices, const Buffer_<int>& offsets)\n    {\n        auto st = core::Context::stream().handle();\n\n        Tensor A;\n        Tensor U;\n\n        const int m = indices ? indices.size() : input.shape(0);\n\n        // Currently, FP8 only; INT8 may be added later\n        if (input.dtype() != dense.input_type) {\n            QuantizeSymm(A, U, input, st);\n            sync_check_cuda_error();\n        }\n        else {\n            A = input;\n        }\n\n        if (indices && A.dtype() == kFloat8_e4m3) {\n            const auto [bsz, k] = A.shapes(0, 1);\n            const int e         = indices.size() / bsz;\n            Tensor    A_e       = {{m, k}, A.dtype(), kDEVICE};\n            invokeMoeDispatch(A_e, A, indices.data(), e, st);\n            sync_check_cuda_error();\n            Tensor U_e;\n            invokeMoeDispatchScales(U_e, U, indices.data(), e, st);\n            sync_check_cuda_error();\n            A       = A_e;\n            U       = U_e;\n            indices = {};  // indices already applied\n        }\n\n        MatrixLayout desc_A{A.dtype(), gemm::Order::kRowMajor, m, (int)A.shape(1), (int)A.stride(0)};\n        MatrixLayout desc_U{};\n        if (U) {\n            desc_U = {U.dtype(), kColMajor, (int)U.shape(1), (int)U.shape(0), (int)U.stride(0)};\n        }\n        if (offsets) {\n            desc_A.num = desc_U.num = dense.k_desc.num;\n            desc_A.offsets = desc_U.offsets = const_cast<int*>(offsets.data());\n        }\n        if (indices) {\n            desc_A.idxs = desc_U.idxs = const_cast<int*>(indices.data());\n        }\n\n        return {A, desc_A, U, 
desc_U};\n    }\n\n    void Forward(Tensor&                 output,\n                 const Tensor&           input,  //\n                 const LlamaDenseWeight& dense,\n                 const Buffer_<int>&     indices,\n                 const Buffer_<int>&     offsets)\n    {\n        using namespace gemm;\n\n        Operation op{};\n        op.dispatch  = dispatch_policy_;\n        op.epilogue  = dense.epilogue;\n        op.quant_a   = dense.input_quant;\n        op.quant_b   = dense.weight_quant;\n        op.batch_dim = 0;\n\n        auto&& [A, desc_A, U, desc_U] = GetOperandA(dense, input, indices, offsets);\n        auto&& [B, desc_B, V, desc_V] = GetOperandB(dense);\n\n        Tensor& D = output;\n        if (!D) {\n            int dim = dense.epilogue == Epilogue::kGatedSilu ? dense.output_dim / 2 : dense.output_dim;\n            D       = Tensor{{desc_A.rows, dim}, dense.data_type, kDEVICE};\n        }\n\n        // std::cout << \"D: \" << D << \" \" << desc_B.num << \"\\n\";\n\n        MatrixLayout desc_D{\n            output.dtype(),\n            kRowMajor,\n            (int)output.shape(0),\n            dense.output_dim,\n            (int)output.stride(0),\n        };\n\n        if (offsets) {\n            desc_D.num     = desc_B.num;\n            desc_D.offsets = const_cast<int*>(offsets.data());\n        }\n\n        auto ec = gemm_.Run(op,\n                            1.f,\n                            A.raw_data(),\n                            desc_A,\n                            U.data_or((void*)nullptr),\n                            desc_U,\n                            B.raw_data(),\n                            desc_B,\n                            V.data_or((void*)nullptr),\n                            desc_V,\n                            0.f,\n                            D.raw_data(),\n                            desc_D,\n                            D.raw_data(),\n                            desc_D,\n                            workspace_,\n       
                     core::Context::stream().handle());\n\n        if (ec) {\n            TM_LOG_ERROR(\"%s: %d\", __PRETTY_FUNCTION__, ec);\n        }\n    }\n\n    gemm::Gemm           gemm_;\n    gemm::DispatchPolicy dispatch_policy_{gemm::DispatchPolicy::kDefault};\n\n    gemm::Workspace workspace_;\n};\n\nLlamaLinear::LlamaLinear(): impl_{std::make_shared<Impl>()} {}\n\nTensor LlamaLinear::Forward(const Tensor&           input,  //\n                            const LlamaDenseWeight& weight,\n                            std::optional<Tensor>   output)\n{\n    return Forward(input, weight, {}, {}, output);\n}\n\nTensor LlamaLinear::Forward(const Tensor&           input,  //\n                            const LlamaDenseWeight& weight,\n                            const Buffer_<int>&     indices,\n                            const Buffer_<int>&     offsets,\n                            std::optional<Tensor>   output)\n{\n    Tensor in = input.view({-1, input.shape(-1)});\n    Tensor out;\n\n    if (output) {\n        out = output->view({-1, output->shape(-1)});\n    }\n\n    impl_->Forward(out, in, weight, indices, offsets);\n\n    return out;\n}\n\nvoid LlamaLinear::set_measure(bool measure)\n{\n    impl_->dispatch_policy_ = measure ? gemm::DispatchPolicy::kMeasure : gemm::DispatchPolicy::kReuse;\n}\n\nint LlamaLinear::Export(std::ostream& os)\n{\n    if (os) {\n        return impl_->gemm_.Export(os);\n    }\n    return 0;\n}\n\nint LlamaLinear::Import(std::istream& is)\n{\n    auto n_records = 0;\n    if (is) {\n        n_records = impl_->gemm_.Import(is);\n    }\n    if (n_records) {\n        impl_->dispatch_policy_ = gemm::DispatchPolicy::kReuse;\n    };\n    return n_records;\n}\n\nstd::vector<int> LlamaLinear::GetTuningSeq() const\n{\n    return impl_->gemm_.GetTuningSeq();\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaLinear.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <istream>\n#include <ostream>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n\nnamespace turbomind {\n\nclass LlamaLinear {\npublic:\n    explicit LlamaLinear();\n\n    Tensor Forward(const Tensor&           input,  //\n                   const LlamaDenseWeight& weight,\n                   std::optional<Tensor>   output = {});\n\n    Tensor Forward(const Tensor&           input,\n                   const LlamaDenseWeight& weight,\n                   const Buffer_<int>&     indices,\n                   const Buffer_<int>&     offsets,\n                   std::optional<Tensor>   output = {});\n\n    void set_measure(bool measure);\n\n    [[maybe_unused]] int Export(std::ostream& os);\n\n    [[maybe_unused]] int Import(std::istream& is);\n\n    std::vector<int> GetTuningSeq() const;\n\nprivate:\n    struct Impl;\n    std::shared_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaWeight.cc",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from\n// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.cc\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/LlamaWeight.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\nLlamaWeight::LlamaWeight(DataType           data_type,\n                         const ModelParam&  model,\n                         const EngineParam& engine_param,\n                         const MoeParam&    moe_param):\n    model_param_{model},\n    engine_param_{engine_param},\n    moe_param_{moe_param},\n    hidden_units_(model.hidden_units),\n    inter_size_(model.inter_size),\n    vocab_size_(model.vocab_size),\n    vocab_size_padded_(model.vocab_size),\n    embedding_size_(model.embedding_size),\n    num_layer_(model.layer_num),\n    data_type_{data_type},\n    weight_type_{model.weight_type},\n    tp_size_(engine_param.attn_tp_size * engine_param.attn_cp_size),\n    tp_rank_(engine_param.attn_tp_rank * engine_param.attn_cp_size + 
engine_param.attn_cp_rank)\n{\n    if (vocab_size_padded_ % tp_size_ != 0) {\n        vocab_size_padded_ = (vocab_size_ + tp_size_ - 1) / tp_size_ * tp_size_;\n        TM_LOG_WARNING(\"pad vocab size from %d to %d\", vocab_size_, vocab_size_padded_);\n    }\n    if (embedding_size_ % tp_size_ != 0) {\n        const int padded_embedding_size = (embedding_size_ + tp_size_ - 1) / tp_size_ * tp_size_;\n        TM_LOG_WARNING(\"pad embed size from %d to %d\", embedding_size_, padded_embedding_size);\n        embedding_size_ = padded_embedding_size;\n    }\n    FT_CHECK(hidden_units_ % tp_size_ == 0);\n    TM_CHECK_EQ(vocab_size_padded_ % tp_size_, 0);\n    TM_CHECK_EQ(hidden_units_ % tp_size_, 0);\n\n    stream_ = core::Stream::create();\n    alloca_ = core::Allocator{stream_, false};\n\n    initialize();\n}\n\nLlamaWeight::~LlamaWeight()\n{\n    release();\n}\n\nbool LlamaWeight::is_initialized() const\n{\n    return initialized_;\n}\n\nvoid LlamaWeight::initialize()\n{\n    core::ContextGuard guard = context();\n\n    pre_decoder_embedding.emplace(embedding_size_, hidden_units_ / tp_size_, data_type_, false, data_type_, 1);\n    post_decoder_embedding.emplace(hidden_units_, vocab_size_padded_ / tp_size_, data_type_, false, data_type_, 1);\n    register_module(\"tok_embeddings\", pre_decoder_embedding, tp_rank_);\n    register_module(\"output\", post_decoder_embedding, tp_rank_);\n\n    /// Lower VRAM pressure on consumer grade GPUs\n    /// TODO: Support token embeds on pinned host memory\n    pre_decoder_embedding.weight  = empty_like(pre_decoder_embedding.weight, kCPU);\n    post_decoder_embedding.weight = empty_like(post_decoder_embedding.weight, kCPU);\n\n    decoder_layer_weights.reserve(num_layer_);\n    for (int i = 0; i < num_layer_; ++i) {\n        decoder_layer_weights.emplace_back(\n            new LlamaDecoderLayerWeight(data_type_, i, model_param_, engine_param_, moe_param_));\n        register_module(\"layers\", *decoder_layer_weights.back(), i);\n    }\n\n    output_norm_weight = Tensor{{hidden_units_}, data_type_, 
kDEVICE};\n    register_parameter(\"norm.weight\", output_norm_weight);\n    initialized_ = true;\n}\n\nvoid LlamaWeight::release()\n{\n    core::ContextGuard guard = context();\n\n    pre_decoder_embedding  = {};\n    post_decoder_embedding = {};\n    output_norm_weight     = {};\n\n    for (auto& p : decoder_layer_weights) {\n        delete p;\n    }\n\n    decoder_layer_weights.clear();\n    pinned_weights_.clear();\n\n    // Wait for deallocations\n    core::Context::stream().Sync();\n\n    // release memory back to os\n    core::Context::device_alloc()->trim(0);\n    initialized_ = false;\n}\n\nvoid LlamaWeight::to_device(const core::Device& device)\n{\n    TM_CHECK(device.type == kCPU || device.type == kDEVICE);\n    core::ContextGuard guard{stream_, alloca_, Allocator{kCPUpinned}};\n\n    auto tensor_ptr_map = get_parameters();\n    for (auto& [name, tensor_ptr] : tensor_ptr_map) {\n        if (device.type == kCPU) {\n            if (pinned_weights_.find(name) == pinned_weights_.end()) {\n                pinned_weights_[name] = empty_like(*tensor_ptr, kCPUpinned);\n                Copy(*tensor_ptr, pinned_weights_[name]);\n            }\n            *tensor_ptr = {};\n        }\n        else {\n            TM_CHECK(pinned_weights_.find(name) != pinned_weights_.end());\n            *tensor_ptr = empty_like(pinned_weights_[name], kDEVICE);\n            Copy(pinned_weights_[name], *tensor_ptr);\n        }\n    }\n    core::Context::stream().Sync();\n    if (device.type == kCPU) {\n        core::Context::device_alloc()->trim(0);\n    }\n}\n\ncore::ContextGuard LlamaWeight::context() const\n{\n    return core::ContextGuard{stream_, alloca_};\n}\n\nvoid LlamaWeight::prepare(const cudaDeviceProp& prop)\n{\n    core::ContextGuard guard = context();\n\n    // Wait for the weights to be filled externally\n    check_cuda_error(cudaDeviceSynchronize());\n\n    auto stream = core::Context::stream().handle();\n\n    for (auto& layer : decoder_layer_weights) {\n        
layer->prepare(prop, stream);\n    }\n\n    auto to_device = [](Tensor& x) {\n        auto tmp = std::exchange(x, empty_like(x, kDEVICE));\n        Copy(tmp, x);\n        return tmp;\n    };\n\n    // Keep the host tensor until stream synchronization\n    auto tmp_token_embeds = to_device(pre_decoder_embedding.weight);\n    auto tmp_lm_head      = to_device(post_decoder_embedding.weight);\n\n    post_decoder_embedding.prepare();\n\n    // Block until processing is done\n    check_cuda_error(cudaStreamSynchronize(stream));\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/LlamaWeight.h",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from\n// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptWeight.h\n\n#pragma once\n\n#include <unordered_map>\n\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/models/llama/LlamaDecoderLayerWeight.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nstruct LlamaWeight: core::Module {\n    LlamaWeight() = default;\n\n    LlamaWeight(DataType           data_type,\n                const ModelParam&  model_param,\n                const EngineParam& engine_param,\n                const MoeParam&    moe_param);\n\n    ~LlamaWeight();\n\n    LlamaWeight(const LlamaWeight&) = delete;\n    LlamaWeight& operator=(const LlamaWeight&) = delete;\n\n    void prepare(const cudaDeviceProp& prop);\n\n    bool is_initialized() const;\n\n    void initialize();\n\n    void release();\n\n    void to_device(const core::Device& device);\n\n    core::ContextGuard context() const;\n\n    std::vector<LlamaDecoderLayerWeight*> decoder_layer_weights;\n\n    LlamaDenseWeight pre_decoder_embedding;\n    LlamaDenseWeight post_decoder_embedding;\n\n    Tensor output_norm_weight;\n\nprivate:\n    const 
ModelParam  model_param_;\n    const EngineParam engine_param_;\n    const MoeParam    moe_param_;\n\n    int hidden_units_;\n    int vocab_size_;\n    int vocab_size_padded_;\n    int embedding_size_;\n    int num_layer_;\n\n    DataType data_type_;\n    DataType weight_type_;\n\n    std::unordered_map<std::string, Tensor> pinned_weights_;\n\n    int tp_size_;  // this will follow attn tp param\n    int tp_rank_;\n\n    std::vector<int> inter_size_;\n\n    core::Stream    stream_;\n    core::Allocator alloca_;\n    bool            initialized_{false};\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/SequenceManager.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cstddef>\n#include <cstdlib>\n#include <ctime>\n#include <numeric>\n\n#include \"src/turbomind/kernels/attention/block.h\"\n#include \"src/turbomind/models/llama/BlockManager.h\"\n#include \"src/turbomind/models/llama/SequenceManager.h\"\n#include \"src/turbomind/utils/logger.h\"\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\ntemplate<typename T>\nstd::string vector2string(const std::vector<T>& data)\n{\n    if (data.empty()) {\n        return \"nil\";\n    }\n    std::stringstream ss;\n\n    auto it = data.begin();\n    ss << *it;\n\n    for (++it; it != data.end(); ++it) {\n        ss << \", \" << *it;\n    }\n    return ss.str();\n}\n\nSequenceManager::SequenceManager(const ModelParam& model_param,\n                                 DataType          runtime_dtype,\n                                 int               cache_block_seq_len,\n                                 int               attn_tp_size,\n                                 int               max_batch_size,\n                                 double            block_count,\n                                 int               chunk_size,\n                                 bool              enable_prefix_caching,\n                                 int               rank,\n                                 int               attn_cp_size,\n                                 core::Allocator   allocator,\n                                 GetFreeMemSize    get_free_size):\n    block_seq_len_(cache_block_seq_len), rank_(rank), attn_cp_size_(attn_cp_size)\n{\n    TM_CHECK_GT(attn_tp_size, 0);\n    TM_CHECK_GT(cache_block_seq_len, 0);\n\n    int cache_layer_num   = model_param.layer_num;\n    int num_linear_layers = 0;\n    for (const auto& type : model_param.layer_types) {\n        if (type == 1) {\n            --cache_layer_num;\n            ++num_linear_layers;\n        }\n    }\n\n    const size_t free_before = (block_count < 1. 
&& num_linear_layers > 0) ? get_free_size() : 0;\n\n    if (num_linear_layers > 0) {\n\n        const int key_head_dim =\n            model_param.linear_key_head_dim > 0 ? model_param.linear_key_head_dim : model_param.head_dim;\n        const int value_head_dim =\n            model_param.linear_value_head_dim > 0 ? model_param.linear_value_head_dim : model_param.head_dim;\n        const int d_conv      = model_param.linear_conv_kernel_dim > 0 ? model_param.linear_conv_kernel_dim : 4;\n        const int num_k_heads = model_param.linear_num_key_heads / attn_tp_size;\n        const int num_v_heads = model_param.linear_num_value_heads / attn_tp_size;\n        const int key_dim     = num_k_heads * key_head_dim;\n        const int value_dim   = num_v_heads * value_head_dim;\n        const int conv_dim    = key_dim * 2 + value_dim;\n\n        TM_CHECK_GT(max_batch_size, 0);\n        pooled_conv_states_ = {{max_batch_size, num_linear_layers, d_conv, conv_dim}, model_param.data_type, kDEVICE};\n        pooled_recurrent_states_ = {{max_batch_size, num_linear_layers, num_v_heads, key_head_dim, value_head_dim},\n                                    model_param.linear_state_dtype,\n                                    kDEVICE};\n\n        free_linear_state_slots_.reserve(max_batch_size);\n        for (int slot = max_batch_size - 1; slot >= 0; --slot) {\n            free_linear_state_slots_.push_back(slot);\n        }\n        TM_LOG_INFO(\"[SeqMgr] linear-state slot pool initialized: %d slots\", max_batch_size);\n        const auto   conv_one      = pooled_conv_states_.slice(0, 1).squeeze(0);\n        const auto   recurrent_one = pooled_recurrent_states_.slice(0, 1).squeeze(0);\n        const double mb            = 1.0 / (1024.0 * 1024.0);\n        TM_LOG_INFO(\"[SeqMgr] linear-state per slot: conv %.2f MB + recurrent %.2f MB = %.2f MB\",\n                    conv_one.byte_size() * mb,\n                    recurrent_one.byte_size() * mb,\n                    (conv_one.byte_size() 
+ recurrent_one.byte_size()) * mb);\n        TM_LOG_INFO(\"[SeqMgr] linear-state combined total: %.2f MB\",\n                    (pooled_conv_states_.byte_size() + pooled_recurrent_states_.byte_size()) * mb);\n    }\n\n    const int  dbits        = byte_size(runtime_dtype, 8);\n    const auto quant_policy = model_param.quant_policy;\n    const int  elem_bits    = quant_policy ? quant_policy : dbits;\n\n    BlockConfig block_config{\n        (int)model_param.head_dim,\n        (int)model_param.kv_head_num / attn_tp_size,\n        cache_block_seq_len,\n        elem_bits == dbits ? 0 : dbits,\n        elem_bits,\n        model_param.head_dim == 576,  // share kv\n    };\n\n    block::Layout layout{block_config};\n    // dump(layout);\n\n    size_t block_size = layout.block_size(cache_layer_num);\n\n    if (num_linear_layers > 0 && block_count < 1.) {\n        const size_t linear_bytes = pooled_conv_states_.byte_size() + pooled_recurrent_states_.byte_size();\n        const size_t target_bytes = static_cast<size_t>(free_before * block_count);\n        TM_LOG_INFO(\"[SeqMgr] Adjusting block_count: free_before %.2f MB, linear %.2f MB, target %.2f MB\",\n                    free_before / (1024. * 1024.),\n                    linear_bytes / (1024. * 1024.),\n                    target_bytes / (1024. * 1024.));\n        if (target_bytes <= linear_bytes) {\n            TM_LOG_ERROR(\"[SeqMgr] Linear-state memory (%.2f MB) >= cache budget (%.2f MB). \",\n                         linear_bytes / (1024. * 1024.),\n                         target_bytes / (1024. 
* 1024.));\n            TM_CHECK(0)\n                << \"Please decrease max_batch_size to reduce total linear state size or increase cache_max_entry_count.\";\n        }\n        const size_t cache_bytes = target_bytes - linear_bytes;\n        block_count              = static_cast<double>(cache_bytes) / static_cast<double>(block_size);\n        TM_LOG_INFO(\"[SeqMgr] Adjusted block_count to %.0f\", block_count);\n    }\n\n    block_manager_ = std::make_shared<BlockManager>(block_size, block_count, chunk_size, allocator, get_free_size);\n\n    if (enable_prefix_caching) {\n        block_trie_ = std::make_shared<BlockTrie>(block_config.block_len_, block_manager_);\n    }\n    TM_LOG_WARNING(\"[SeqMgr] prefix caching is %s\", enable_prefix_caching ? \"enabled\" : \"disabled\");\n}\n\nconst Sequence* SequenceManager::Create(uint64_t id)\n{\n    Sequence sequence{id};\n    auto     it = sequences_.find(id);\n    if (it != sequences_.end()) {\n        if (rank_ == 0) {\n            TM_LOG_WARNING(\"[SeqMgr][Create] Removing conflicting ID %llu\", id);\n        }\n        Erase(it);\n    }\n    it = sequences_.emplace_hint(it, id, std::move(sequence));\n    if (rank_ == 0) {\n        TM_LOG_INFO(\"[SeqMgr][Create] ID %llu\", id);\n    }\n    return &it->second;\n}\n\nconst Sequence* SequenceManager::Get(uint64_t id)\n{\n    if (auto it = sequences_.find(id); it != sequences_.end()) {\n        return &it->second;\n    }\n    return nullptr;\n}\n\nbool SequenceManager::Contains(uint64_t id)\n{\n    return sequences_.find(id) != sequences_.end();\n}\n\nvoid SequenceManager::Erase(std::map<uint64_t, Sequence>::iterator& it)\n{\n    auto& seq = it->second;\n    if (seq.status == Sequence::kCached) {\n        const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);\n        seq.blocks.resize(count);\n    }\n    else {\n        UpdateAndSetUnlock(seq);\n    }\n    // if prefix cache enabled, blocks will be shared by sequences, cannot be freed immediately\n  
  if (!block_trie_) {\n        freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());\n    }\n    ReleaseLinearStateSlot(seq);\n    it = sequences_.erase(it);\n}\n\nbool SequenceManager::Erase(uint64_t id)\n{\n    if (auto it = sequences_.find(id); it != sequences_.end()) {\n        Erase(it);\n        return true;\n    }\n    return false;\n}\n\nvoid SequenceManager::AcquireLinearStateSlot(const Sequence& sequence)\n{\n    if (!pooled_recurrent_states_) {\n        return;\n    }\n\n    auto& seq = const_cast<Sequence&>(sequence);\n\n    auto slot_it = seq_to_linear_state_slot_.find(seq.id);\n    if (slot_it != seq_to_linear_state_slot_.end()) {\n        const int slot       = slot_it->second;\n        seq.conv_states      = pooled_conv_states_.slice(slot).squeeze(0);\n        seq.recurrent_states = pooled_recurrent_states_.slice(slot).squeeze(0);\n        return;\n    }\n\n    TM_CHECK(!free_linear_state_slots_.empty()) << \"No free linear-state slot for sequence \" << seq.id\n                                                << \", max_batch_size=\" << pooled_recurrent_states_.shape(0);\n\n    const int slot = free_linear_state_slots_.back();\n    free_linear_state_slots_.pop_back();\n    seq_to_linear_state_slot_.emplace(seq.id, slot);\n\n    seq.conv_states              = pooled_conv_states_.slice(slot).squeeze(0);\n    seq.recurrent_states         = pooled_recurrent_states_.slice(slot).squeeze(0);\n    seq.linear_states_need_reset = true;\n}\n\nvoid SequenceManager::ReleaseLinearStateSlot(const Sequence& sequence)\n{\n    if (!pooled_recurrent_states_) {\n        return;\n    }\n\n    auto& seq = const_cast<Sequence&>(sequence);\n\n    if (auto slot_it = seq_to_linear_state_slot_.find(seq.id); slot_it != seq_to_linear_state_slot_.end()) {\n        free_linear_state_slots_.push_back(slot_it->second);\n        seq_to_linear_state_slot_.erase(slot_it);\n    }\n    seq.conv_states              = {};\n    seq.recurrent_states         = {};\n    
seq.linear_states_need_reset = false;\n}\n\nvoid SequenceManager::InvalidateStatesAndCache(const Sequence& sequence)\n{\n    InvalidateStatesAndCache(sequence, freed_);\n}\n\nvoid SequenceManager::InvalidateStatesAndCache(const Sequence& sequence, BlockIds& freed_blocks)\n{\n    auto& seq = const_cast<Sequence&>(sequence);\n    if (seq.status != Sequence::kCached) {\n        UpdateAndSetUnlock(seq);\n    }\n    freed_blocks.insert(freed_blocks.end(), seq.blocks.begin(), seq.blocks.end());\n\n    seq.blocks.clear();\n    seq.block_unique_ids.clear();\n    seq.input_length = 0;\n    seq.cache_len    = 0;\n    ReleaseLinearStateSlot(seq);\n}\n\nvoid SequenceManager::CachePrompt(const Sequences& sequences, int active_size)\n{\n    if (!block_trie_) {\n        return;\n    }\n\n    for (int i = 0; i < active_size; ++i) {\n        if (auto& seq = *sequences[i]; !seq.prompt.empty()) {\n            const auto& [block_ids, unique_ids] = block_trie_->Cache(seq, seq.prompt);\n            if (rank_ == 0) {\n                // clang-format off\n                TM_LOG_INFO(\"[SeqMgr][CachePrompt] ID %llu, cached blocks %d, tokens %d\", seq.id,\n                            (int)block_ids.size(), (int)seq.prompt.size());\n                TM_LOG_DEBUG(\"[SeqMgr][CachePrompt] ID %llu, cached block_ids %s, unique_ids %s\", seq.id,\n                             vector2string(block_ids).c_str(), vector2string(unique_ids).c_str());\n                // clang-format on\n            }\n            if (seq.cache_len >= seq.prompt.size()) {\n                seq.prompt.clear();\n            }\n        }\n    }\n}\n\nvoid SequenceManager::CacheGeneration(const Sequence& seq)\n{\n    if (!block_trie_) {\n        return;\n    }\n\n    const auto& [block_ids, unique_ids] = block_trie_->Cache(seq, seq.tokens);\n\n    if (rank_ == 0) {\n        // clang-format off\n        TM_LOG_INFO(\"[SeqMgr][CacheGeneration] ID %llu, cached blocks %d, tokens %d\",\n                    seq.id, 
(int)block_ids.size(), (int)seq.tokens.size());\n        TM_LOG_DEBUG(\"[SeqMgr][CacheGeneration] ID %llu, cached block_ids %s, unique_ids %s\", seq.id,\n                     vector2string(block_ids).c_str(), vector2string(unique_ids).c_str());\n        // clang-format on\n    }\n}\n\nvoid SequenceManager::VerifyAndLockCached(const Sequences& sequences)\n{\n    BlockIds valid_blocks;\n    BlockIds freed_blocks;\n    for (const auto& p : sequences) {\n        auto& seq = const_cast<Sequence&>(*p);\n        if (seq.status != Sequence::kCached) {\n            continue;\n        }\n        TM_CHECK_EQ(seq.blocks.size(), seq.block_unique_ids.size());\n        // Verify cache blocks that may be invalidated\n        const int original_count = seq.blocks.size();\n        const int count          = block_manager_->Verify(seq.blocks, seq.block_unique_ids);\n        seq.blocks.resize(count);\n        seq.block_unique_ids.resize(count);\n\n        const bool has_linear_states = static_cast<bool>(seq.recurrent_states);\n        if (has_linear_states && count < original_count) {\n            InvalidateStatesAndCache(seq, freed_blocks);\n            // This request can still continue in the current scheduling round.\n            // Rebind a slot immediately so GatedDeltaNetLayer::Setup always sees\n            // valid linear-state views.\n            AcquireLinearStateSlot(seq);\n            continue;\n        }\n\n        valid_blocks.insert(valid_blocks.end(), seq.blocks.begin(), seq.blocks.end());\n        seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_seq_len_);\n        seq.status    = Sequence::kLocked;\n    }\n    if (!freed_blocks.empty()) {\n        block_manager_->Free(freed_blocks);\n    }\n    block_manager_->Lock(valid_blocks);\n}\n\nvoid SequenceManager::CommitUnlockAndFree()\n{\n    if (!unlocked_.empty()) {\n        block_manager_->Unlock(unlocked_);\n        unlocked_.clear();\n    }\n\n    if (!freed_.empty()) {\n        
block_manager_->Free(freed_);\n        freed_.clear();\n    }\n}\n\nvoid SequenceManager::UpdateAndSetUnlock(const Sequence& sequence)\n{\n    TM_CHECK_NE(sequence.status, Sequence::kCached);\n    auto& seq = const_cast<Sequence&>(sequence);\n    block_manager_->Touch(seq.blocks);\n    unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());\n    seq.status = Sequence::kCached;\n}\n\nnamespace {\n\nstruct Schedule {\n    int free;\n    int cached;\n\n    int allocate{};\n    int evict{};\n    int preempt{};\n\n    int last;\n\n    int max_fwd_tokens;\n    int max_tmp_tokens;\n\n    Sequences        active;\n    std::vector<int> block_counts;\n    Sequences        inactive;\n    Sequences        victims;\n\n    Schedule(Snapshot snapshot, int size, int max_fwd_tokens, int max_tmp_tokens):\n        free{snapshot.free},\n        cached{snapshot.cached},\n        last{size},\n        max_fwd_tokens{max_fwd_tokens},\n        max_tmp_tokens{max_tmp_tokens},\n        use_count_{std::move(snapshot.use_count)},\n        unlocked_(size),  // ! This is a vector, DO NOT brace initialize it\n        it_{size}\n    {\n    }\n\n    int Unlock(const Sequences& seqs, int vidx)\n    {\n        while (vidx < it_) {\n            const auto& blocks = seqs[--it_]->blocks;\n            int         count  = 0;\n            for (const auto& bid : blocks) {\n                count += static_cast<int>(--use_count_[bid] == 0);\n            }\n            unlocked_[it_] = count;\n        }\n        return unlocked_[vidx];\n    }\n\nprivate:\n    std::vector<int> use_count_;\n    std::vector<int> unlocked_;\n    int              it_;\n};\n\ntemplate<typename T>\nstd::ostream& operator<<(std::ostream& os, const std::vector<T>& v)\n{\n    os << \"[\";\n    for (int i = 0; i < v.size(); ++i) {\n        os << (i ? 
\",\" : \"\") << v[i];\n    }\n    os << \"]\";\n    return os;\n}\n\nstd::ostream& operator<<(std::ostream& os, const Schedule& s)\n{\n    os << \"free=\" << s.free << \", cached=\" << s.cached << \", allocate=\" << s.allocate << \", evict=\" << s.evict\n       << \", preempt=\" << s.preempt << \", active=\" << s.active << \", victims=\" << s.victims\n       << \", block_counts=\" << s.block_counts << \", inactive=\" << s.inactive;\n    return os;\n}\n\nstruct Transaction {\n    int index_;\n    int block_count_;\n    int input_len_;\n    int temp_len_;\n\n    int allocate_{};\n    int evict_{};\n    int preempt_{};\n\n    Sequences victims_;\n\n    const Sequences& sequences_;\n    Schedule&        schedule_;\n\n    explicit Transaction(\n        const Sequences& sequences, int index, int block_count, int input_len, int temp_len, Schedule& sched):\n        index_{index},\n        block_count_{block_count},\n        input_len_{input_len},\n        temp_len_{temp_len},\n        sequences_{sequences},\n        schedule_{sched}\n    {\n    }\n\n    void Process()\n    {\n        if (schedule_.max_fwd_tokens > 0 && schedule_.max_tmp_tokens >= temp_len_) {\n            int count = block_count_;\n\n            int tmp = std::min(schedule_.free, count);\n            count -= tmp;\n            allocate_ += tmp;\n\n            tmp = std::min(schedule_.cached, count);\n            count -= tmp;\n            evict_ += tmp;\n\n            for (int vidx = schedule_.last - 1; count && vidx > index_; --vidx) {\n                if (sequences_[vidx]->status == Sequence::kCached) {\n                    continue;\n                }\n                victims_.push_back(sequences_[vidx]);\n                preempt_ += schedule_.Unlock(sequences_, vidx);\n\n                if (count <= preempt_) {\n                    evict_ += count;\n                    count -= count;\n                    schedule_.last = vidx;  // ! 
modifying `schedule_.last` is part of commit\n                    break;\n                }\n            }\n            if (count == 0) {\n                return Commit();\n            }\n        }\n\n        const_cast<Sequence*>(sequences_[index_])->input_length = 0;\n        schedule_.inactive.push_back(sequences_[index_]);\n    }\n\n    void Commit()\n    {\n        // update available resources\n        schedule_.free -= allocate_;\n        TM_CHECK_GE(schedule_.free, 0);\n        schedule_.cached += preempt_;\n        schedule_.cached -= evict_;\n        TM_CHECK_GE(schedule_.cached, 0);\n\n        // update scheduled operations\n        schedule_.allocate += allocate_;\n        schedule_.evict += evict_;\n        schedule_.preempt += preempt_;\n        schedule_.victims.insert(schedule_.victims.end(), victims_.begin(), victims_.end());\n\n        // update active sequences\n        schedule_.active.push_back(sequences_[index_]);\n        schedule_.block_counts.push_back(block_count_);\n\n        input_len_ = std::min(input_len_, schedule_.max_fwd_tokens);\n        schedule_.max_fwd_tokens -= input_len_;\n        const_cast<Sequence*>(sequences_[index_])->input_length = input_len_;\n\n        schedule_.max_tmp_tokens -= temp_len_;\n    }\n};\n\nstd::ostream& operator<<(std::ostream& os, const Transaction& trans)\n{\n    os << \"index=\" << trans.index_ << \", block_count=\" << trans.block_count_ << \", allocate=\" << trans.allocate_\n       << \", evict=\" << trans.evict_ << \", preempt=\" << trans.preempt_ << \", victims=\" << trans.victims_;\n    return os;\n}\n\n}  // namespace\n\ntemplate<class Key, class... Ts>\nstatic void SortByKey(const std::vector<Key>& keys, std::vector<Ts>&... 
vals)\n{\n    std::vector<int> idxs(keys.size());\n    std::iota(idxs.begin(), idxs.end(), 0);\n    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return keys[i] < keys[j]; });\n    auto reorder = [&](auto& xs) {\n        std::remove_reference_t<decltype(xs)> ys(xs.size());\n        for (size_t i = 0; i < xs.size(); ++i) {\n            ys[i] = xs[idxs[i]];\n        }\n        xs.swap(ys);\n    };\n    (reorder(vals), ...);\n}\n\nstd::vector<int> SequenceManager::CountRequiredBlocks(const Sequences&        sequences,\n                                                      const std::vector<int>& context_length)\n{\n    std::vector<int> required(sequences.size());\n    for (int i = 0; i < sequences.size(); ++i) {\n        int length  = (context_length[i] + attn_cp_size_ - 1) / attn_cp_size_;\n        int count   = (length + block_seq_len_ - 1) / block_seq_len_ - static_cast<int>(sequences[i]->blocks.size());\n        required[i] = std::max(0, count);\n    }\n    return required;\n}\n\nvoid SequenceManager::AssignAndActivate(const Sequences&        sequences,  //\n                                        const std::vector<int>& counts,\n                                        const BlockIds&         blocks,\n                                        const UniqueIds&        unique_ids)\n{\n    TM_CHECK_EQ(sequences.size(), counts.size());\n    int first = 0;\n    for (int i = 0; i < sequences.size(); ++i) {\n        auto& s     = const_cast<Sequence&>(*sequences[i]);\n        auto  count = counts[i];\n        int   last  = first + count;\n        TM_CHECK_LE(last, blocks.size());\n        s.blocks.insert(s.blocks.end(), blocks.begin() + first, blocks.begin() + last);\n        s.block_unique_ids.insert(s.block_unique_ids.end(), unique_ids.begin() + first, unique_ids.begin() + last);\n        s.status = Sequence::kActive;\n        first    = last;\n    }\n}\n\nvoid SequenceManager::PrefixMatch(Sequences& sequences, const std::vector<int>& alpha)\n{\n    if 
(!block_trie_) {\n        return;\n    }\n\n    for (int i = 0; i < sequences.size(); i++) {\n\n        auto& seq = const_cast<Sequence&>(*sequences[i]);\n\n        /// TODO: Is there a way to exploit the alpha[i] != 0 case?\n        if (alpha[i] != 0 || seq.cache_len >= seq.prompt.size()) {\n            continue;\n        }\n\n        const auto& [block_ids, unique_ids] = block_trie_->Match(seq);\n\n        if (rank_ == 0) {\n            // clang-format off\n            TM_LOG_INFO(\"[SeqMgr][match] ID %llu, hit blocks %d, cache_len %d\", seq.id, (int)block_ids.size(), seq.cache_len);\n            TM_LOG_DEBUG(\"[SeqMgr][match] ID %llu, hit block_ids %s, unique_ids %s\", seq.id,\n                         vector2string(block_ids).c_str(), vector2string(unique_ids).c_str());\n            // clang-format on\n        }\n\n        /// TODO: `Unlock` and `Lock` can't be batched because there may be repeated blocks between sequences\n        if (const int offset = seq.cache_len / block_seq_len_; offset < block_ids.size()) {\n            if (BlockIds tail{seq.blocks.begin() + offset, seq.blocks.end()}; !tail.empty()) {\n                block_manager_->Unlock(tail);\n                seq.blocks.resize(offset);\n                seq.block_unique_ids.resize(offset);\n            }\n            seq.blocks.insert(seq.blocks.end(), block_ids.begin() + offset, block_ids.end());\n            seq.block_unique_ids.insert(seq.block_unique_ids.end(), unique_ids.begin() + offset, unique_ids.end());\n            seq.cache_len = seq.blocks.size() * block_seq_len_;\n            block_manager_->Lock({block_ids.begin() + offset, block_ids.end()});\n        }\n\n        if (rank_ == 0) {\n            // clang-format off\n            TM_LOG_INFO(\"[SeqMgr][match] ID %llu, after matching, blocks %d, cache_len %d\",\n                        seq.id, (int)seq.blocks.size(), seq.cache_len);\n            TM_LOG_DEBUG(\"[SeqMgr][match] ID %llu, after matching, block_ids %s, unique_ids %s\", seq.id,\n      
                   vector2string(seq.blocks).c_str(), vector2string(seq.block_unique_ids).c_str());\n            // clang-format on\n        }\n    }\n}\n\nauto SequenceManager::Materialize(Sequences             sequences,\n                                  std::vector<int>      context_length,\n                                  std::vector<int>      alpha,\n                                  std::vector<uint64_t> priorities,\n                                  int                   max_fwd_tokens,\n                                  int                   max_tmp_tokens) -> Outcome\n{\n    ////////////////////////////////////////////////////////////////////////////////\n    /// Schedule the assignment of blocks to sequences\n\n    // process deferred unlock and free operations\n    CommitUnlockAndFree();\n\n    SortByKey(priorities, sequences, context_length, alpha);\n\n    // Verify and lock cache sequences to avoid their blocks being evicted unnoticed\n    // the blocks can still be preempted later\n    VerifyAndLockCached(sequences);\n\n    PrefixMatch(sequences, alpha);\n\n    std::vector required = CountRequiredBlocks(sequences, context_length);\n\n    Schedule schedule(block_manager_->TakeSnapshot(), sequences.size(), max_fwd_tokens, max_tmp_tokens);\n\n    // `schedule.last` is decreasing in the loop\n    for (int i = 0; i < schedule.last; ++i) {\n        auto&     s         = *sequences[i];\n        const int input_len = context_length[i] - alpha[i] - s.cache_len;\n        // sanity check\n        TM_CHECK_GT(input_len, 0) << \"Logical error: \" << context_length[i] << \" \" << alpha[i] << \" \" << s.cache_len\n                                  << \" \" << s.status;\n        // temp buffer for flatten KV cache\n        const int temp_len = (input_len > 1 || s.status != Sequence::kActive) ? 
context_length[i] : 0;\n        Transaction{sequences, i, required[i], input_len, temp_len, schedule}.Process();\n    }\n\n    // mark remaining sequences invalid\n    for (int i = schedule.last; i < sequences.size(); ++i) {\n        schedule.inactive.push_back(sequences[i]);\n    }\n\n    ////////////////////////////////////////////////////////////////////////////////\n    /// Schedule is ready, time to execute it. (locked -> cached -> free -> locked)\n\n    // combine allocate and evict since evicted blocks are reused by allocation\n    schedule.allocate += schedule.evict;\n\n    // if (schedule.allocate) {\n    //     dbg(*block_manager_);\n    // }\n\n    Outcome outcome{};\n    outcome.allocation = schedule.allocate;\n    outcome.swap_in    = std::count_if(schedule.active.begin(), schedule.active.end(), [](auto p) {\n        // if (p->status != Sequence::kActive) {\n        //     dbg(*p);\n        // }\n        return p->status != Sequence::kActive;\n    });\n    outcome.swap_out = std::count_if(schedule.inactive.begin(), schedule.inactive.end(), [](auto p) {\n        // if (p->status == Sequence::kActive) {\n        //     dbg(*p);\n        // }\n        return p->status == Sequence::kActive;\n    });\n\n    // release preempted blocks -> cached\n    if (!schedule.victims.empty()) {\n        TM_LOG_INFO(\"[SeqMgr] #victim: %d\", (int)schedule.victims.size());\n        for (const auto& p : schedule.victims) {\n            UpdateAndSetUnlock(*p);\n        }\n        CommitUnlockAndFree();\n    }\n\n    // evict cached blocks -> free\n    if (schedule.evict) {\n        block_manager_->Evict(schedule.evict);\n    }\n\n    // allocate & assign blocks\n    {\n        BlockIds  block_ids;\n        UniqueIds unique_ids;\n        if (schedule.allocate) {\n            std::tie(block_ids, unique_ids) = block_manager_->Allocate(schedule.allocate);\n        }\n        AssignAndActivate(schedule.active, schedule.block_counts, block_ids, unique_ids);\n    }\n\n    // 
active -> locked\n    for (const auto& p : schedule.inactive) {\n        if (p->status == Sequence::kActive) {\n            const_cast<Sequence*>(p)->status = Sequence::kLocked;\n        }\n    }\n\n    // TM_LOG_ERROR(\"active: %4d, cached: %4d, free: %4d\",\n    //              block_manager_->active_count(),\n    //              block_manager_->cached_count(),\n    //              block_manager_->free_count());\n    if (block_trie_) {\n        block_trie_->Verify();\n    }\n\n    return outcome;\n}\n\nstd::tuple<int, int, int> SequenceManager::seq_stats() const noexcept\n{\n    int total  = static_cast<int>(sequences_.size());\n    int active = 0;\n    int cached = 0;\n    for (const auto& p : sequences_) {\n        if (p.second.status == Sequence::kActive) {\n            ++active;\n        }\n        else if (p.second.status == Sequence::kCached) {\n            ++cached;\n        }\n    }\n    return std::make_tuple(total, active, cached);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/SequenceManager.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <functional>\n#include <unordered_map>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/models/llama/BlockManager.h\"\n#include \"src/turbomind/models/llama/BlockTrie.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nstruct Sequence {\n\n    enum Status\n    {\n        kCached = 0,\n        kLocked,\n        kActive\n    };\n\n    uint64_t id;\n    Status   status = kCached;\n\n    BlockIds  blocks;\n    UniqueIds block_unique_ids;\n\n    int input_length = 0;  // the number of tokens to be processed in each forward iter\n\n    mutable std::vector<int> prompt;\n\n    mutable std::vector<int> tokens;  // update by user or when the sequence is finished\n\n    mutable int cache_len = 0;\n\n    // additional data kept round-to-round\n    mutable std::vector<std::byte> random_state;  // update by user\n\n    mutable float rope_theta = 0.f;\n\n    // embedding data\n    mutable std::vector<Tensor> input_embeds;\n    mutable std::vector<int>    input_embeds_offsets;\n\n    // Gated DeltaNet linear attention persistent states (e.g. 
Qwen3.5-MoE).\n    // Allocated on first request, preserved across requests for the same session,\n    // and freed automatically when the sequence is erased from the SequenceManager.\n    //   conv_states:      (num_linear_layers, conv_dim, d_conv) — per-channel rolling conv history\n    //   recurrent_states: (num_linear_layers, num_v_heads, key_head_dim, value_head_dim) — SSM state\n    mutable Tensor conv_states;\n    mutable Tensor recurrent_states;\n    mutable bool   linear_states_need_reset = false;\n\n    explicit Sequence(uint64_t _id): id(_id) {}\n\n    friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);\n};\n\nusing Sequences = std::vector<const Sequence*>;\n\ninline std::ostream& operator<<(std::ostream& os, const Sequence& seq)\n{\n    os << \"id=\" << seq.id << \", status=\" << seq.status << \", token_count=\" << seq.tokens.size()\n       << \", block_count=\" << seq.blocks.size() << \", cache_len=\" << seq.cache_len\n       << \", random_state_size=\" << seq.random_state.size() << \", input_length=\" << seq.input_length;\n    return os;\n}\n\nclass SequenceManager {\npublic:\n    // clang-format off\n    struct BlockConfig {\n        int head_dim_;\n        int head_num_;\n        int block_len_;\n        int t_bits_;\n        int q_bits_;\n        bool share_kv_;\n        int t_bits() const { return t_bits_; }\n        int q_bits() const { return q_bits_; }\n        int head_dim() const { return head_dim_; }\n        int head_num() const { return head_num_; }\n        int block_len() const { return block_len_; }\n        bool is_share_kv() const { return share_kv_; }\n    };\n    // clang-format on\n\n    explicit SequenceManager(const ModelParam& model_param,\n                             DataType          runtime_dtype,\n                             int               cache_block_seq_len,\n                             int               attn_tp_size,\n                             int               max_batch_size,\n               
              double            block_count,\n                             int               chunk_size,\n                             bool              enable_prefix_caching,\n                             int               rank,\n                             int               attn_cp_size,\n                             core::Allocator   allocator,\n                             GetFreeMemSize    get_free_size);\n\n    SequenceManager(const SequenceManager&)     = delete;\n    SequenceManager(SequenceManager&&) noexcept = default;\n\n    [[nodiscard]] const Sequence* Create(uint64_t id);\n\n    [[nodiscard]] const Sequence* Get(uint64_t id);\n\n    [[nodiscard]] bool Contains(uint64_t id);\n\n    [[nodiscard]] bool Erase(uint64_t id);\n\n    void AcquireLinearStateSlot(const Sequence& seq);\n\n    void ReleaseLinearStateSlot(const Sequence& seq);\n\n    void InvalidateStatesAndCache(const Sequence& seq);\n\n    void UpdateAndSetUnlock(const Sequence& seq);\n\n    struct Outcome {\n        int allocation;\n        int swap_in;\n        int swap_out;\n    };\n\n    using AdjustInputCount = std::function<int(const Sequences&, const std::vector<int>&)>;\n\n    //                50       1       0       50\n    //    context = seq_len + beta = cache + alpha + input\n    //     alpha' = input\n    //      beta' = int(is_gen)\n    //  -----------------------------------\n    //   seq_len += output\n    //     cache += input + output - 1  or  cache = seq_len - 1\n\n    [[maybe_unused]] Outcome Materialize(Sequences             sequences,\n                                         std::vector<int>      context_length,\n                                         std::vector<int>      alpha,\n                                         std::vector<uint64_t> priorities,\n                                         int                   max_fwd_tokens,\n                                         int                   max_tmp_tokens);\n\n    /** @brief cache the input prompt tokens of each 
seq in sequences[0:active_size-1]\n     *\n     * @param sequences The sequence list\n     * @param active_size the number of active sequences in the list\n     */\n    void CachePrompt(const Sequences& sequences, int active_size);\n\n    /** @brief cache the generated tokens of a given sequence\n     *\n     * @param sequence the given sequence\n     *\n     * @note This function can only be called after the sequence finishes generation\n     * and all tokens, including the prompt tokens and generated tokens, have been put into\n     * `seq.tokens`\n     */\n    void CacheGeneration(const Sequence& sequence);\n\n    [[nodiscard]] void* GetBlockPtr(int block_id)\n    {\n        return block_manager_->block(block_id).data;\n    }\n\n    int max_block_count() const noexcept\n    {\n        return block_manager_->max_block_count();\n    }\n\n    int total_count() const noexcept\n    {\n        return block_manager_->total_count();\n    }\n\n    int active_count() const noexcept\n    {\n        return block_manager_->active_count();\n    }\n\n    int free_count() const noexcept\n    {\n        return block_manager_->free_count();\n    }\n\n    int cached_count() const noexcept\n    {\n        return block_manager_->cached_count();\n    }\n\n    // return #total_seq, #active_seq, #cached_seq\n    std::tuple<int, int, int> seq_stats() const noexcept;\n\nprivate:\n    void Erase(std::map<uint64_t, Sequence>::iterator& it);\n\n    void CommitUnlockAndFree();\n\n    void InvalidateStatesAndCache(const Sequence& seq, BlockIds& freed_blocks);\n\n    void VerifyAndLockCached(const Sequences& sequences);\n\n    std::vector<int> CountRequiredBlocks(const Sequences&        sequences,  //\n                                         const std::vector<int>& context_length);\n\n    static void AssignAndActivate(const Sequences&        sequences,  //\n                                  const std::vector<int>& counts,\n                                  const BlockIds&         blocks,\n          
                        const UniqueIds&        unique_ids);\n\n    void PrefixMatch(Sequences& sequences, const std::vector<int>& alpha);\n\nprivate:\n    int block_seq_len_;\n    int rank_;\n    int attn_cp_size_;\n\n    // Use `std::map` to avoid reference invalidation\n    std::map<uint64_t, Sequence> sequences_;\n\n    std::shared_ptr<BlockManager> block_manager_;\n    std::shared_ptr<BlockTrie>    block_trie_;\n\n    Tensor                            pooled_conv_states_;\n    Tensor                            pooled_recurrent_states_;\n    std::vector<int>                  free_linear_state_slots_;\n    std::unordered_map<uint64_t, int> seq_to_linear_state_slot_;\n\n    BlockIds unlocked_;\n    BlockIds freed_;\n};\n\ninline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc)\n{\n    os << \"allocation: \" << oc.allocation << \", swap-in: \" << oc.swap_in << \", swap-out: \" << oc.swap_out;\n    return os;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/bench_conv1d_silu.cc",
    "content": "\n#include <cmath>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <functional>\n#include <vector>\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/kernels/gemm/test/test_utils.h\"\n#include \"src/turbomind/models/llama/gated_delta_net_kernels.h\"\n\nusing namespace turbomind;\nusing namespace turbomind::core;\n\nstruct Args {\n    int      batch_size  = 32;\n    int      seq_len     = 1;\n    int      num_v_heads = 64;\n    int      num_k_heads = 16;\n    int      d_conv      = 4;\n    int      warmup      = 10;\n    int      iters       = 100;\n    DataType dtype       = kFloat16;\n\n    static DataType ParseDtype(const char* s)\n    {\n        if (strcmp(s, \"half\") == 0 || strcmp(s, \"fp16\") == 0)\n            return kFloat16;\n        if (strcmp(s, \"bf16\") == 0)\n            return kBfloat16;\n        fprintf(stderr, \"Unknown dtype: %s (expected half/fp16/bf16)\\n\", s);\n        exit(1);\n    }\n\n    static Args Parse(int argc, char** argv)\n    {\n        Args a;\n        for (int i = 1; i < argc; i += 2) {\n            if (i + 1 >= argc) {\n                fprintf(stderr, \"Missing value for %s\\n\", argv[i]);\n                exit(1);\n            }\n            if (strcmp(argv[i], \"--batch_size\") == 0)\n                a.batch_size = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--seq_len\") == 0)\n                a.seq_len = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--num_v_heads\") == 0)\n                a.num_v_heads = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--num_k_heads\") == 0)\n                a.num_k_heads = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--d_conv\") == 0)\n                a.d_conv = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--warmup\") == 0)\n                a.warmup = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--iters\") == 0)\n                
a.iters = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--dtype\") == 0)\n                a.dtype = ParseDtype(argv[i + 1]);\n            else {\n                fprintf(stderr, \"Unknown arg: %s\\n\", argv[i]);\n                exit(1);\n            }\n        }\n        return a;\n    }\n\n    void Print() const\n    {\n        printf(\"batch_size=%d  seq_len=%d  num_v_heads=%d  num_k_heads=%d  d_conv=%d  \"\n               \"warmup=%d  iters=%d  dtype=%s\\n\",\n               batch_size,\n               seq_len,\n               num_v_heads,\n               num_k_heads,\n               d_conv,\n               warmup,\n               iters,\n               to_string(dtype));\n    }\n};\n\nstatic float\nbenchmark_kernel(const char* name, std::function<void()> launch, cudaStream_t stream, int warmup, int iters)\n{\n    for (int i = 0; i < warmup; ++i)\n        launch();\n    cudaStreamSynchronize(stream);\n\n    cudaEvent_t start, stop;\n    cudaEventCreate(&start);\n    cudaEventCreate(&stop);\n\n    cudaEventRecord(start, stream);\n    for (int i = 0; i < iters; ++i)\n        launch();\n    cudaEventRecord(stop, stream);\n    cudaEventSynchronize(stop);\n\n    float ms = 0;\n    cudaEventElapsedTime(&ms, start, stop);\n    float avg_ms = ms / iters;\n\n    printf(\"  %-45s  %8.3f ms (avg over %d iters)\\n\", name, avg_ms, iters);\n\n    cudaEventDestroy(start);\n    cudaEventDestroy(stop);\n    return avg_ms;\n}\n\n// CPU reference for depthwise causal conv1d + SiLU.\n//\n//   y(t, c) = SiLU( sum_{d=0}^{D-1} w(d, c) * x(t - D + 1 + d, c) )\n//\n// where x(i, c) falls back to the conv state for i < 0 (history from the\n// previous inference step).  After the sequence, the state is updated to the\n// last D-1 inputs.\n//\n// State is a ring buffer: slot j holds the input written at absolute time t\n// where t % d_conv == j.  history_len = k_offsets[b+1] - k_offsets[b] - seq_len.\n//\n// Weight layout: [d_conv, conv_dim].  
State layout: [d_conv, conv_dim] per batch.\ntemplate<typename T>\nstatic void cpu_conv1d_silu(T*         h_out,\n                            const T*   h_in,\n                            const T*   h_weight,\n                            T*         h_state,\n                            const int* h_q_offsets,\n                            const int* h_k_offsets,\n                            int        batch_size,\n                            int        conv_dim,\n                            int        d_conv,\n                            int        in_stride)\n{\n    for (int b = 0; b < batch_size; ++b) {\n        const int seq_off     = h_q_offsets[b];\n        const int seq_len     = h_q_offsets[b + 1] - seq_off;\n        const int history_len = (h_k_offsets[b + 1] - h_k_offsets[b]) - seq_len;\n        T*        state       = h_state + b * d_conv * conv_dim;\n\n        auto x = [&](int i, int c) -> float {\n            if (i >= 0)\n                return static_cast<float>(h_in[(seq_off + i) * in_stride + c]);\n            int ring_idx = ((history_len + i) % d_conv + d_conv) % d_conv;\n            return static_cast<float>(state[ring_idx * conv_dim + c]);\n        };\n\n        for (int t = 0; t < seq_len; ++t) {\n            for (int c = 0; c < conv_dim; ++c) {\n                float acc = 0.f;\n                for (int d = 0; d < d_conv; ++d)\n                    acc += static_cast<float>(h_weight[d * conv_dim + c]) * x(t - d_conv + 1 + d, c);\n                h_out[(seq_off + t) * conv_dim + c] = static_cast<T>(acc / (1.f + std::exp(-acc)));\n            }\n        }\n\n        for (int d = 0; d < d_conv; ++d) {\n            int src = seq_len - d_conv + d;\n            if (src >= 0) {\n                int ring_d = (history_len + src) % d_conv;\n                for (int c = 0; c < conv_dim; ++c)\n                    state[ring_d * conv_dim + c] = h_in[(seq_off + src) * in_stride + c];\n            }\n        }\n    }\n}\n\nint main(int argc, char** argv)\n{\n    
auto args = Args::Parse(argc, argv);\n    args.Print();\n\n    constexpr int kHeadDim = 128;\n\n    const int num_v_heads = args.num_v_heads;\n    const int num_k_heads = args.num_k_heads;\n    const int batch_size  = args.batch_size;\n    const int seq_len     = args.seq_len;\n    const int d_conv      = args.d_conv;\n\n    const int k_dim     = num_k_heads * kHeadDim;\n    const int v_dim     = num_v_heads * kHeadDim;\n    const int conv_dim  = 2 * k_dim + v_dim;\n    const int in_stride = conv_dim + v_dim + 2 * num_v_heads;\n    const int total_tok = batch_size * seq_len;\n\n    const int      conv_state_size = conv_dim * d_conv;\n    const DataType dtype           = args.dtype;\n    const auto     elem_bytes      = byte_size(dtype);\n\n    auto         stream = Stream::create();\n    ContextGuard ctx{stream, Allocator{kCPU}, Allocator{kCPUpinned}, Allocator{stream, false}};\n    cudaStream_t cu_stream = stream.handle();\n\n    int sm_count = 1;\n    {\n        int device = 0;\n        cudaGetDevice(&device);\n        cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);\n    }\n    Buffer_<int> work_counter{1, kDEVICE};\n\n    printf(\"\\nconv_dim=%d  d_conv=%d  in_stride=%d  total_tokens=%d\\n\", conv_dim, d_conv, in_stride, total_tok);\n\n    Tensor all_proj{Layout{{total_tok, in_stride}}, dtype, kDEVICE};\n    Tensor weight{Layout{{d_conv, conv_dim}}, dtype, kDEVICE};\n\n    Tensor out_ref{Layout{{total_tok, conv_dim}}, dtype, kDEVICE};\n    Tensor out_v2{Layout{{total_tok, conv_dim}}, dtype, kDEVICE};\n\n    Tensor state_ref{Layout{{batch_size, conv_state_size}}, dtype, kDEVICE};\n    Tensor state_v2{Layout{{batch_size, conv_state_size}}, dtype, kDEVICE};\n\n    Buffer_<void*> state_ptrs_v2_host{batch_size, kCPUpinned};\n    Buffer_<void*> state_ptrs_v2_dev{batch_size, kDEVICE};\n\n    Buffer_<int> q_offsets_host{batch_size + 1, kCPUpinned};\n    Buffer_<int> q_offsets_dev{batch_size + 1, kDEVICE};\n    Buffer_<int> 
k_offsets_dev{batch_size + 1, kDEVICE};\n\n    RNG rng;\n    rng.UniformFloat(all_proj, 0.1f);\n    rng.UniformFloat(weight, 0.1f);\n\n    for (int i = 0; i <= batch_size; ++i)\n        q_offsets_host.data()[i] = i * seq_len;\n    Copy(q_offsets_host, batch_size + 1, q_offsets_dev);\n    Copy(q_offsets_host, batch_size + 1, k_offsets_dev);  // no history in bench\n\n    for (int i = 0; i < batch_size; ++i) {\n        state_ptrs_v2_host.data()[i] = (char*)state_v2.raw_data() + i * conv_state_size * elem_bytes;\n    }\n    Copy(state_ptrs_v2_host, batch_size, state_ptrs_v2_dev);\n    stream.Sync();\n\n    auto launch_v2 = [&] {\n        invokeFusedConv1dSiLU(out_v2,\n                              all_proj,\n                              weight,\n                              Tensor{},\n                              state_ptrs_v2_dev,\n                              q_offsets_dev,\n                              k_offsets_dev,\n                              batch_size,\n                              0,\n                              sm_count,\n                              work_counter.data(),\n                              cu_stream);\n    };\n\n    // === Benchmark ===\n    printf(\"\\n=== Benchmark ===\\n\");\n    float v2_ms = benchmark_kernel(\"v2   (templated + vectorized)\", launch_v2, cu_stream, args.warmup, args.iters);\n\n    // === Bandwidth ===\n    {\n        double in_bytes         = (double)total_tok * conv_dim * elem_bytes;\n        double out_bytes        = (double)total_tok * conv_dim * elem_bytes;\n        double wt_bytes         = (double)conv_dim * d_conv * elem_bytes;\n        int    state_write_rows = std::min(seq_len, d_conv);\n        double state_rd_bytes   = (double)batch_size * d_conv * conv_dim * elem_bytes;\n        double state_wr_bytes   = (double)batch_size * state_write_rows * conv_dim * elem_bytes;\n        double state_bytes      = state_rd_bytes + state_wr_bytes;\n        double total_bytes      = in_bytes + out_bytes + wt_bytes + 
state_bytes;\n\n        printf(\"\\n=== Bandwidth ===\\n\");\n        printf(\"  in:     %.1f MB\\n\", in_bytes / 1e6);\n        printf(\"  out:    %.1f MB\\n\", out_bytes / 1e6);\n        printf(\"  weight: %.3f MB\\n\", wt_bytes / 1e6);\n        printf(\"  state:  %.1f MB  (R %.1f + W %.1f)\\n\", state_bytes / 1e6, state_rd_bytes / 1e6, state_wr_bytes / 1e6);\n        printf(\"  total:  %.1f MB\\n\", total_bytes / 1e6);\n        printf(\"  v2  BW: %.1f GB/s\\n\", total_bytes / (v2_ms * 1e6));\n    }\n\n    // === Cross-comparison (correctness): CPU ref vs GPU v2 ===\n    printf(\"\\n=== Cross-comparison (CPU ref vs GPU v2) ===\\n\");\n\n    Clear(state_ref);\n    Clear(state_v2);\n    Clear(out_ref);\n    Clear(out_v2);\n    stream.Sync();\n\n    // Run GPU kernel\n    launch_v2();\n    stream.Sync();\n\n    // Run CPU reference\n    {\n        const size_t in_bytes    = (size_t)total_tok * in_stride * elem_bytes;\n        const size_t wt_bytes    = (size_t)d_conv * conv_dim * elem_bytes;\n        const size_t state_bytes = (size_t)batch_size * conv_state_size * elem_bytes;\n        const size_t out_bytes   = (size_t)total_tok * conv_dim * elem_bytes;\n\n        std::vector<char> h_in(in_bytes), h_wt(wt_bytes), h_state(state_bytes), h_out(out_bytes);\n\n        cudaMemcpy(h_in.data(), all_proj.raw_data(), in_bytes, cudaMemcpyDeviceToHost);\n        cudaMemcpy(h_wt.data(), weight.raw_data(), wt_bytes, cudaMemcpyDeviceToHost);\n        std::memset(h_state.data(), 0, state_bytes);\n        std::memset(h_out.data(), 0, out_bytes);\n\n        auto run_cpu = [&](auto t) {\n            using T = decltype(t);\n            cpu_conv1d_silu((T*)h_out.data(),\n                            (const T*)h_in.data(),\n                            (const T*)h_wt.data(),\n                            (T*)h_state.data(),\n                            q_offsets_host.data(),\n                            q_offsets_host.data(),  // k_offsets == q_offsets (no history in bench)\n               
             batch_size,\n                            conv_dim,\n                            d_conv,\n                            in_stride);\n        };\n\n        if (dtype == kFloat16)\n            run_cpu(half{});\n        else\n            run_cpu(nv_bfloat16{});\n\n        cudaMemcpy(out_ref.raw_data(), h_out.data(), out_bytes, cudaMemcpyHostToDevice);\n        cudaMemcpy(state_ref.raw_data(), h_state.data(), state_bytes, cudaMemcpyHostToDevice);\n    }\n\n    printf(\"  output comparison:\\n\");\n    FC_Header();\n    FC_Print(FastCompare(out_ref, out_v2, cu_stream));\n\n    printf(\"  state comparison:\\n\");\n    FC_Header();\n    FC_Print(FastCompare(state_ref, state_v2, cu_stream));\n\n    printf(\"\\nDone.\\n\");\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/models/llama/bench_gated_delta_net.cc",
    "content": "\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <functional>\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/kernels/gemm/test/test_utils.h\"\n#include \"src/turbomind/models/llama/gated_delta_net_kernels.h\"\n\nusing namespace turbomind;\nusing namespace turbomind::core;\n\nstruct Args {\n    int      batch_size  = 32;\n    int      seq_len     = 64;\n    int      num_v_heads = 16;\n    int      num_k_heads = 4;\n    int      warmup      = 10;\n    int      iters       = 100;\n    DataType dtype       = kFloat16;\n    DataType state_dtype = kFloat32;\n\n    static DataType ParseDtype(const char* s)\n    {\n        if (strcmp(s, \"half\") == 0 || strcmp(s, \"fp16\") == 0)\n            return kFloat16;\n        if (strcmp(s, \"bf16\") == 0)\n            return kBfloat16;\n        if (strcmp(s, \"fp32\") == 0 || strcmp(s, \"float\") == 0)\n            return kFloat32;\n        fprintf(stderr, \"Unknown dtype: %s (expected half/fp16/bf16/fp32/float)\\n\", s);\n        exit(1);\n    }\n\n    static Args Parse(int argc, char** argv)\n    {\n        Args a;\n        for (int i = 1; i < argc; i += 2) {\n            if (i + 1 >= argc) {\n                fprintf(stderr, \"Missing value for %s\\n\", argv[i]);\n                exit(1);\n            }\n            if (strcmp(argv[i], \"--batch_size\") == 0)\n                a.batch_size = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--seq_len\") == 0)\n                a.seq_len = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--num_v_heads\") == 0)\n                a.num_v_heads = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--num_k_heads\") == 0)\n                a.num_k_heads = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--warmup\") == 0)\n                a.warmup = atoi(argv[i + 1]);\n            else if (strcmp(argv[i], \"--iters\") == 0)\n                a.iters = atoi(argv[i + 1]);\n            else if 
(strcmp(argv[i], \"--dtype\") == 0)\n                a.dtype = ParseDtype(argv[i + 1]);\n            else if (strcmp(argv[i], \"--state_dtype\") == 0)\n                a.state_dtype = ParseDtype(argv[i + 1]);\n            else {\n                fprintf(stderr, \"Unknown arg: %s\\n\", argv[i]);\n                exit(1);\n            }\n        }\n        return a;\n    }\n\n    void Print() const\n    {\n        printf(\n            \"batch_size=%d  seq_len=%d  num_v_heads=%d  num_k_heads=%d  warmup=%d  iters=%d  dtype=%s  state_dtype=%s\\n\",\n            batch_size,\n            seq_len,\n            num_v_heads,\n            num_k_heads,\n            warmup,\n            iters,\n            to_string(dtype),\n            to_string(state_dtype));\n    }\n};\n\nstatic float\nbenchmark_kernel(const char* name, std::function<void()> launch, cudaStream_t stream, int warmup, int iters)\n{\n    for (int i = 0; i < warmup; ++i)\n        launch();\n    cudaStreamSynchronize(stream);\n\n    cudaEvent_t start, stop;\n    cudaEventCreate(&start);\n    cudaEventCreate(&stop);\n\n    cudaEventRecord(start, stream);\n    for (int i = 0; i < iters; ++i)\n        launch();\n    cudaEventRecord(stop, stream);\n    cudaEventSynchronize(stop);\n\n    float ms = 0;\n    cudaEventElapsedTime(&ms, start, stop);\n    float avg_ms = ms / iters;\n\n    printf(\"  %-45s  %8.3f ms (avg over %d iters)\\n\", name, avg_ms, iters);\n\n    cudaEventDestroy(start);\n    cudaEventDestroy(stop);\n    return avg_ms;\n}\n\nint main(int argc, char** argv)\n{\n    auto args = Args::Parse(argc, argv);\n    args.Print();\n\n    constexpr int kHeadDim = 128;\n\n    const int num_v_heads = args.num_v_heads;\n    const int num_k_heads = args.num_k_heads;\n    const int batch_size  = args.batch_size;\n    const int seq_len     = args.seq_len;\n\n    const int k_dim     = num_k_heads * kHeadDim;\n    const int v_dim     = num_v_heads * kHeadDim;\n    const int conv_dim  = 2 * k_dim + v_dim;\n    const int 
total_tok = batch_size * seq_len;\n\n    const int state_size = num_v_heads * kHeadDim * kHeadDim;  // per request\n\n    const DataType dtype       = args.dtype;\n    const DataType state_dtype = args.state_dtype;\n\n    // --- Context setup ---\n    auto         stream = Stream::create();\n    ContextGuard ctx{stream, Allocator{kCPU}, Allocator{kCPUpinned}, Allocator{stream, false}};\n    cudaStream_t cu_stream = stream.handle();\n\n    const bool is_decode = (seq_len == 1);\n\n    // --- Allocate tensors ---\n    Tensor qkv_in{Layout{{total_tok, conv_dim}}, dtype, kDEVICE};\n    Tensor v_out_v2{Layout{{total_tok, v_dim}}, dtype, kDEVICE};\n    Tensor v_out_chunked{Layout{{total_tok, v_dim}}, dtype, kDEVICE};\n    Tensor v_out_v3{Layout{{total_tok, v_dim}}, dtype, kDEVICE};\n    Tensor beta{Layout{{total_tok, num_v_heads}}, dtype, kDEVICE};\n    Tensor g{Layout{{total_tok, num_v_heads}}, dtype, kDEVICE};\n\n    // State buffers — all three kernels use state_dtype\n    Tensor state_v2{Layout{{batch_size, state_size}}, state_dtype, kDEVICE};\n    Tensor state_chunked{Layout{{batch_size, state_size}}, state_dtype, kDEVICE};\n    Tensor state_v3{Layout{{batch_size, state_size}}, state_dtype, kDEVICE};\n\n    // State pointer arrays: host pinned + device\n    Buffer_<void*> state_ptrs_v2_host{batch_size, kCPUpinned};\n    Buffer_<void*> state_ptrs_v2_dev{batch_size, kDEVICE};\n    Buffer_<void*> state_ptrs_chunked_host{batch_size, kCPUpinned};\n    Buffer_<void*> state_ptrs_chunked_dev{batch_size, kDEVICE};\n    Buffer_<void*> state_ptrs_v3_host{batch_size, kCPUpinned};\n    Buffer_<void*> state_ptrs_v3_dev{batch_size, kDEVICE};\n\n    // q_offsets: host + device\n    Buffer_<int> q_offsets_host{batch_size + 1, kCPUpinned};\n    Buffer_<int> q_offsets_dev{batch_size + 1, kDEVICE};\n\n    // --- Fill random data ---\n    RNG rng;\n    rng.UniformFloat(qkv_in, 0.1f);\n    rng.UniformFloat(beta, 1.0f);        // will be passed through sigmoid inside kernel\n    
rng.UniformFloat(g, 0.02f, -0.01f);  // small values around 0\n    Clear(state_v2);\n    Clear(state_chunked);\n    Clear(state_v3);\n\n    // --- Build q_offsets ---\n    for (int i = 0; i <= batch_size; ++i)\n        q_offsets_host.data()[i] = i * seq_len;\n    Copy(q_offsets_host, batch_size + 1, q_offsets_dev);\n\n    // --- Build state_ptrs ---\n    const auto state_elem_bytes    = byte_size(state_dtype);\n    const auto state_elem_bytes_v3 = byte_size(state_dtype);\n    for (int i = 0; i < batch_size; ++i) {\n        state_ptrs_v2_host.data()[i]      = (char*)state_v2.raw_data() + i * state_size * state_elem_bytes;\n        state_ptrs_chunked_host.data()[i] = (char*)state_chunked.raw_data() + i * state_size * state_elem_bytes;\n        state_ptrs_v3_host.data()[i]      = (char*)state_v3.raw_data() + i * state_size * state_elem_bytes_v3;\n    }\n    Copy(state_ptrs_v2_host, batch_size, state_ptrs_v2_dev);\n    Copy(state_ptrs_chunked_host, batch_size, state_ptrs_chunked_dev);\n    Copy(state_ptrs_v3_host, batch_size, state_ptrs_v3_dev);\n    stream.Sync();\n\n    // Shared resources for all three kernel launchers\n    int sm_count = 1;\n    {\n        int device = 0;\n        cudaGetDevice(&device);\n        cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);\n    }\n    Buffer_<int> work_counter_buf{1, kDEVICE};\n    int*         work_counter = work_counter_buf.data();\n\n    // --- Benchmark recurrent (v2) kernel ---\n    printf(\"\\n=== Benchmarks ===\\n\");\n    auto launch_v2 = [&] {\n        invokeGatedDeltaRuleBatched_v2(v_out_v2,\n                                       qkv_in,\n                                       beta,\n                                       g,\n                                       state_ptrs_v2_dev,\n                                       q_offsets_dev,\n                                       batch_size,\n                                       num_k_heads,\n                                       0,\n        
                               state_dtype,\n                                       sm_count,\n                                       work_counter,\n                                       cu_stream);\n    };\n    float v2_ms = benchmark_kernel(\"invokeGatedDeltaRuleBatched_v2\", launch_v2, cu_stream, args.warmup, args.iters);\n\n    // --- Benchmark chunked kernel ---\n    auto launch_chunked = [&] {\n        invokeChunkedGatedDeltaRuleBatched(v_out_chunked,\n                                           qkv_in,\n                                           beta,\n                                           g,\n                                           state_ptrs_chunked_dev,\n                                           q_offsets_dev,\n                                           batch_size,\n                                           num_k_heads,\n                                           0,\n                                           state_dtype,\n                                           sm_count,\n                                           work_counter,\n                                           cu_stream);\n    };\n    float chunked_ms =\n        benchmark_kernel(\"invokeChunkedGatedDeltaRuleBatched\", launch_chunked, cu_stream, args.warmup, args.iters);\n\n    // --- Benchmark v3 persistent decode kernel (seq_len == 1 only) ---\n    float v3_ms     = -1.f;\n    auto  launch_v3 = [&] {\n        invokeGatedDeltaRuleBatched_v3(v_out_v3,\n                                       qkv_in,\n                                       beta,\n                                       g,\n                                       state_ptrs_v3_dev,\n                                       q_offsets_dev,\n                                       batch_size,\n                                       num_k_heads,\n                                       0,\n                                       state_dtype,\n                                       sm_count,\n                                      
 work_counter,\n                                       cu_stream);\n    };\n    if (is_decode) {\n        v3_ms = benchmark_kernel(\n            \"invokeGatedDeltaRuleBatched_v3 (persistent)\", launch_v3, cu_stream, args.warmup, args.iters);\n    }\n    else {\n        printf(\"  %-45s  (skipped — seq_len > 1)\\n\", \"invokeGatedDeltaRuleBatched_v3 (persistent)\");\n    }\n\n    printf(\"\\n  Speedup v2 / chunked:  %.2fx\\n\", v2_ms / chunked_ms);\n    if (is_decode)\n        printf(\"  Speedup v2 / v3:       %.2fx\\n\", v2_ms / v3_ms);\n\n    // --- Bandwidth stats ---\n    {\n        double state_bytes    = (double)batch_size * state_size * state_elem_bytes * 2.0;\n        double state_bytes_v3 = (double)batch_size * state_size * state_elem_bytes_v3 * 2.0;\n        printf(\"\\n=== Bandwidth ===\\n\");\n        printf(\"  v2:      state BW = %.1f GB/s\\n\", state_bytes / (v2_ms * 1e6));\n        printf(\"  chunked: state BW = %.1f GB/s\\n\", state_bytes / (chunked_ms * 1e6));\n        if (is_decode)\n            printf(\"  v3:      state BW = %.1f GB/s\\n\", state_bytes_v3 / (v3_ms * 1e6));\n        printf(\"  total_tokens = %d\\n\", total_tok);\n    }\n\n    // === Cross-comparison: run both kernels on identical input, compare outputs ===\n    printf(\"\\n=== Cross-comparison (v2 vs chunked) ===\\n\");\n\n    // Reset states to identical initial values (zero)\n    Clear(state_v2);\n    Clear(state_chunked);\n    Clear(v_out_v2);\n    Clear(v_out_chunked);\n    stream.Sync();\n\n    // Single invocation of each kernel\n    launch_v2();\n    launch_chunked();\n    stream.Sync();\n\n    // Compare v_out\n    printf(\"  v_out comparison:\\n\");\n    FC_Header();\n    auto v_out_stats = FastCompare(v_out_v2, v_out_chunked, cu_stream);\n    FC_Print(v_out_stats);\n\n    // Compare final states\n    printf(\"  state comparison:\\n\");\n    FC_Header();\n    auto state_stats = FastCompare(state_v2, state_chunked, cu_stream);\n    FC_Print(state_stats);\n\n    // === 
Cross-comparison: v2 vs v3 (decode only) ===\n    if (is_decode) {\n        printf(\"\\n=== Cross-comparison (v2 vs v3, state_dtype=%s) ===\\n\", to_string(state_dtype));\n\n        Clear(state_v2);\n        Clear(state_v3);\n        Clear(v_out_v2);\n        Clear(v_out_v3);\n        stream.Sync();\n\n        launch_v2();\n        launch_v3();\n        stream.Sync();\n\n        printf(\"  v_out comparison:\\n\");\n        FC_Header();\n        FC_Print(FastCompare(v_out_v2, v_out_v3, cu_stream));\n\n        printf(\"  state comparison:\\n\");\n        FC_Header();\n        FC_Print(FastCompare(state_v2, state_v3, cu_stream));\n    }\n\n    printf(\"\\nDone.\\n\");\n    return 0;\n}\n"
  },
  {
    "path": "src/turbomind/models/llama/context.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <memory>\n\n#include <cuda_runtime.h>\n\n#include <cublasLt.h>\n#include <cublas_v2.h>\n\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/models/llama/LlamaLinear.h\"\n\nnamespace turbomind {\n\nstruct Communicators {\n    comm::HostComm h_global;\n    comm::HostComm h_comm;\n    comm::HostComm h_tp_group;\n    comm::HostComm h_dp_group;\n\n    comm::DeviceComm d_comm;\n    int              d_tp_group;\n    int              d_cp_group;\n};\n\n// Execution context for the model\nstruct Context {\n    core::Stream                 core_stream;\n    core::Allocator              allocator;\n    cudaStream_t                 stream;\n    std::unique_ptr<LlamaLinear> linear;\n    cudaDeviceProp               device_prop;\n    Communicators                comm;  // initialize later\n    std::unique_ptr<int>         is_warm_up;\n\n    Context(int device_id):\n        core_stream{core::Stream::create()},\n        allocator{core::Allocator(core_stream, false)},\n        stream{core_stream.handle()},\n        comm{},  // value initialize\n        is_warm_up{std::make_unique<int>()}\n    {\n        core::ContextGuard guard{core_stream};\n        linear = std::make_unique<LlamaLinear>();\n        check_cuda_error(cudaGetDeviceProperties(&device_prop, device_id));\n    }\n};\n\ninline Allocator GetSymmAllocator(const comm::DeviceComm& comm)\n{\n    TM_CHECK(comm);\n    return core::SimpleAllocator::Create(\n        [&comm](auto size) {\n            auto p = comm->Allocate(size);\n            comm->Register(p, size);\n            return p;\n        },\n        [&comm](void* p, auto size) {\n            comm->Deregister(p);\n            comm->Free(p);\n        },\n        kDEVICE);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/gated_delta_net_kernels.cu",
    "content": "\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/models/llama/gated_delta_net_kernels.h\"\n\n#include <algorithm>\n#include <cmath>\n#include <cuda_bf16.h>\n#include <type_traits>\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/layout.h\"\n#include \"src/turbomind/kernels/gemm/thread_map.h\"\n\nnamespace turbomind {\n\nusing namespace gemm;\n\ntemplate<int k_head_dim, int v_head_dim, int block_dim, class T, class S>\n__global__ void recurrent_gated_delta_rule_kernel_v2(T*         v_out,\n                                                     const T*   qkv_in,\n                                                     const T*   beta_in,\n                                                     const T*   g_in,\n                                                     S* const*  state_ptrs,\n                                                     const int* q_offsets,\n                                                     int        num_v_heads,\n                                                     int        num_k_heads,\n                                                     int        k_dim_total,\n                                                     int        state_layer_offset)\n{\n    const int bh    = blockIdx.x;\n    const int b     = bh / num_v_heads;\n    const int h     = bh % num_v_heads;\n    const int ratio = num_v_heads / num_k_heads;\n    const int kh    = h / ratio;\n\n    const int tok_off    = q_offsets[b];\n    const int seq_len    = q_offsets[b + 1] - tok_off;\n    const int state_size = k_head_dim * v_head_dim;\n    const int conv_dim   = 2 * k_dim_total + num_v_heads * v_head_dim;\n    const int v_dim      = num_v_heads * v_head_dim;\n\n    S* s_ptr = state_ptrs[b] + state_layer_offset + h * state_size;\n\n    const float scale = rsqrtf((float)k_head_dim);\n\n    // DimC = 
v_head_dim (memory-contiguous), DimS = k_head_dim (strided)\n    using Map_S = ThreadMap_V2<v_head_dim, k_head_dim, sizeof(uint4) / sizeof(S), Raked, block_dim / WARP_SIZE>;\n\n    extern __shared__ __align__(16) char smem_buf[];\n\n    // XOR swizzle: bits [10,13] (offset_k) XOR into column access-group index\n    constexpr int kBase  = (sizeof(S) == 4) ? 2 : 3;  // log2(kAccessC)\n    constexpr int kShift = 10 - kBase;\n    using Layout         = SmemLayoutV2<k_head_dim, v_head_dim, -1, -1, Swizzle<4, kBase, kShift>>;\n    SmemAccessor<S, Layout> smem_S{(S*)smem_buf};\n\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    const int lane_id = threadIdx.x % WARP_SIZE;\n\n    constexpr int tile_k = 16;\n    constexpr int tile_v = 4;\n\n    constexpr int k_tiles = k_head_dim / tile_k;  // 8\n    constexpr int v_tiles = v_head_dim / tile_v;  // 32\n\n    constexpr int k_threads = k_tiles;\n    constexpr int v_threads = block_dim / k_threads;\n\n    constexpr int v_iters = cdiv(v_tiles, v_threads);\n\n    Array<float, tile_v> vec_S[v_iters][tile_k];\n\n    const int offset_k = threadIdx.x % k_tiles;\n    const int offset_v = threadIdx.x / k_tiles;\n\n    constexpr int kAccessC = Map_S::kAccessC;\n\n    PRAGMA_UNROLL\n    for (int s = 0; s < Map_S::kIterS; ++s) {\n        Array<S, kAccessC> vec;\n        PRAGMA_UNROLL\n        for (int c = 0; c < Map_S::kIterC; ++c) {\n            const auto [vd, kd] = Map_S::get_offset(warp_id, lane_id);\n            const int final_vd  = vd + c * Map_S::kDeltaC;\n            const int final_kd  = kd + s * Map_S::kDeltaS;\n            Load(vec, s_ptr + final_kd * v_head_dim + final_vd);\n            Store(&smem_S(final_kd, final_vd), vec);\n        }\n    }\n\n    __syncthreads();\n\n    PRAGMA_UNROLL\n    for (int v_iter = 0; v_iter < v_iters; ++v_iter) {\n        PRAGMA_UNROLL\n        for (int k = 0; k < tile_k; ++k) {\n            constexpr int kTileAccessC = (tile_v >= kAccessC) ? 
kAccessC : tile_v;\n            static_assert(tile_v % kTileAccessC == 0);\n            PRAGMA_UNROLL\n            for (int c = 0; c < tile_v / kTileAccessC; ++c) {\n                Array<S, kTileAccessC> tmp;\n                Load(tmp, &smem_S(offset_k * tile_k + k, (offset_v + v_iter * v_threads) * tile_v + c * kTileAccessC));\n                (Array<float, kTileAccessC>&)vec_S[v_iter][k][c * kTileAccessC] = cast<float>(tmp);\n            }\n        }\n    }\n\n    for (int t = 0; t < seq_len; ++t) {\n        const int global_t = tok_off + t;\n\n        const T* q_ptr = qkv_in + global_t * conv_dim + kh * k_head_dim;\n        const T* k_ptr = qkv_in + global_t * conv_dim + k_dim_total + kh * k_head_dim;\n        const T* v_ptr = qkv_in + global_t * conv_dim + 2 * k_dim_total + h * v_head_dim;\n        T*       o_ptr = v_out + global_t * v_dim + h * v_head_dim;\n\n        const float beta_val = (float)beta_in[global_t * num_v_heads + h];\n        const float decay    = expf((float)g_in[global_t * num_v_heads + h]);\n\n        Array<float, tile_k> vec_K;\n        Array<float, tile_k> vec_Q;\n\n        // --- In-kernel L2-normalize K/Q (Vectorized) ---\n        {\n            {\n                Array<T, tile_k> tmp_K;\n                Array<T, tile_k> tmp_Q;\n                Load(tmp_K, &k_ptr[offset_k * tile_k]);\n                Load(tmp_Q, &q_ptr[offset_k * tile_k]);\n                vec_K = cast<float>(tmp_K);\n                vec_Q = cast<float>(tmp_Q);\n            }\n\n            float k_sum = 0.f;\n            float q_sum = 0.f;\n\n            PRAGMA_UNROLL\n            for (int k = 0; k < tile_k; ++k) {\n                k_sum += vec_K[k] * vec_K[k];\n                q_sum += vec_Q[k] * vec_Q[k];\n            }\n\n            PRAGMA_UNROLL\n            for (int mask = k_threads / 2; mask > 0; mask /= 2) {\n                k_sum += __shfl_xor_sync(0xffffffff, k_sum, mask);\n                q_sum += __shfl_xor_sync(0xffffffff, q_sum, mask);\n            }\n\n 
           const float k_inv_norm = rsqrtf(k_sum + 1e-6f);\n            const float q_inv_norm = rsqrtf(q_sum + 1e-6f);\n\n            PRAGMA_UNROLL\n            for (int i = 0; i < tile_k; ++i) {\n                vec_K[i] = vec_K[i] * k_inv_norm;\n                vec_Q[i] = vec_Q[i] * q_inv_norm;\n            }\n        }\n\n        // Precompute KQ = dot(K, Q) — invariant across all v elements\n        float KQ = 0.f;\n        PRAGMA_UNROLL\n        for (int k = 0; k < tile_k; ++k)\n            KQ += vec_K[k] * vec_Q[k];\n        PRAGMA_UNROLL\n        for (int mask = k_threads / 2; mask > 0; mask /= 2)\n            KQ += __shfl_xor_sync(0xffffffff, KQ, mask);\n\n        Array<T, tile_v> vec_V[v_iters];\n\n        PRAGMA_UNROLL\n        for (int v_iter = 0; v_iter < v_iters; ++v_iter) {\n            Load(vec_V[v_iter], &v_ptr[(offset_v + v_iter * v_threads) * tile_v]);\n        }\n\n        PRAGMA_UNROLL\n        for (int v_iter = 0; v_iter < v_iters; ++v_iter) {\n            Array<T, tile_v> vec_O;\n            PRAGMA_UNROLL\n            for (int v = 0; v < tile_v; ++v) {\n                // Fused: decay + dual dot product (kv_mem and SQ simultaneously)\n                float kv_mem = 0.f, SQ = 0.f;\n                PRAGMA_UNROLL\n                for (int k = 0; k < tile_k; ++k) {\n                    float s_decayed     = vec_S[v_iter][k][v] * decay;\n                    vec_S[v_iter][k][v] = s_decayed;\n                    kv_mem += s_decayed * vec_K[k];\n                    SQ += s_decayed * vec_Q[k];\n                }\n\n                // Single interleaved reduction (2 independent values -> good ILP)\n                PRAGMA_UNROLL\n                for (int mask = k_threads / 2; mask > 0; mask /= 2) {\n                    kv_mem += __shfl_xor_sync(0xffffffff, kv_mem, mask);\n                    SQ += __shfl_xor_sync(0xffffffff, SQ, mask);\n                }\n\n                const float delta = ((float)vec_V[v_iter][v] - kv_mem) * beta_val;\n\n            
    // State update\n                PRAGMA_UNROLL\n                for (int k = 0; k < tile_k; ++k) {\n                    vec_S[v_iter][k][v] += vec_K[k] * delta;\n                }\n\n                // Output: algebraic computation, NO reduction needed\n                vec_O[v] = static_cast<T>((SQ + delta * KQ) * scale);\n            }\n            if (offset_k == 0)\n                Store(&o_ptr[(offset_v + v_iter * v_threads) * tile_v], vec_O);\n        }\n    }\n\n    __syncthreads();\n\n    PRAGMA_UNROLL\n    for (int v_iter = 0; v_iter < v_iters; ++v_iter) {\n        PRAGMA_UNROLL\n        for (int k = 0; k < tile_k; ++k) {\n            constexpr int kTileAccessC = (tile_v >= kAccessC) ? kAccessC : tile_v;\n            PRAGMA_UNROLL\n            for (int c = 0; c < tile_v / kTileAccessC; ++c) {\n                auto tmp = cast<S>((Array<float, kTileAccessC>&)vec_S[v_iter][k][c * kTileAccessC]);\n                Store(&smem_S(offset_k * tile_k + k, (offset_v + v_iter * v_threads) * tile_v + c * kTileAccessC), tmp);\n            }\n        }\n    }\n\n    __syncthreads();\n\n    PRAGMA_UNROLL\n    for (int s = 0; s < Map_S::kIterS; ++s) {\n        Array<S, Map_S::kAccessC> vec;\n        PRAGMA_UNROLL\n        for (int c = 0; c < Map_S::kIterC; ++c) {\n            const auto [vd, kd] = Map_S::get_offset(warp_id, lane_id);\n            const int final_vd  = vd + c * Map_S::kDeltaC;\n            const int final_kd  = kd + s * Map_S::kDeltaS;\n            Load(vec, &smem_S(final_kd, final_vd));\n            Store(s_ptr + final_kd * v_head_dim + final_vd, vec);\n        }\n    }\n}\n\nvoid invokeGatedDeltaRuleBatched_v2(Ref<Tensor>           v_out_,\n                                    const Tensor&         qkv_in,\n                                    const Tensor&         beta,\n                                    const Tensor&         g,\n                                    const Buffer_<void*>& state_ptrs,\n                                    const 
Buffer_<int>&   q_offsets,\n                                    int                   batch_size,\n                                    int                   num_k_heads,\n                                    int                   state_layer_offset,\n                                    DataType              state_dtype,\n                                    int /*sm_count*/,\n                                    int* /*work_counter*/,\n                                    cudaStream_t stream)\n{\n    auto& v_out = v_out_.get();\n\n    const int num_v_heads    = beta.shape(1);\n    const int v_dim          = v_out.shape(1);\n    const int value_head_dim = v_dim / num_v_heads;\n    const int k_dim_total    = (qkv_in.shape(1) - v_dim) / 2;\n\n    if (batch_size == 0 || num_v_heads == 0)\n        return;\n\n    constexpr int kHeadDim  = 128;\n    constexpr int kBlockDim = 256;\n\n    TM_CHECK_EQ(value_head_dim, kHeadDim);\n    TM_CHECK_EQ(k_dim_total / num_k_heads, kHeadDim);\n\n    const int num_blocks = batch_size * num_v_heads;\n\n    auto invoke = [&](auto t) {\n        using T     = decltype(t);\n        auto launch = [&](auto s) {\n            using S = decltype(s);\n\n            auto kernel = recurrent_gated_delta_rule_kernel_v2<kHeadDim, kHeadDim, kBlockDim, T, S>;\n\n            const size_t smem_sz = kHeadDim * kHeadDim * sizeof(S);\n            if (smem_sz > 48 << 10) {\n                cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);\n            }\n\n            kernel<<<num_blocks, kBlockDim, smem_sz, stream>>>(v_out.data<T>(),\n                                                               qkv_in.data<T>(),\n                                                               beta.data<T>(),\n                                                               g.data<T>(),\n                                                               (S* const*)state_ptrs.data(),\n                                                               
q_offsets.data(),\n                                                               num_v_heads,\n                                                               num_k_heads,\n                                                               k_dim_total,\n                                                               state_layer_offset);\n        };\n        if (state_dtype == kFloat32) {\n            launch(float{});\n        }\n        else {\n            launch(T{});\n        }\n    };\n    TM_DISPATCH_PRIMARY_DTYPES(v_out.dtype(), invoke);\n}\n\n// =============================================================================\n// Recurrent Gated Delta Rule — Persistent decode kernel (seq_len == 1 only).\n//\n// Designed for large-batch decode (e.g., bs=1024, 64 heads = 65536 work-items).\n// Instead of launching one block per (b, h) pair, we launch only as many blocks\n// as can be simultaneously resident (determined via the CUDA occupancy API), and\n// each block iterates over multiple (b, h) work-items in a persistent loop.\n//\n// State is loaded/stored directly between global memory and registers (no smem\n// staging), so the loop body needs no smem-staging barriers (a __syncthreads()\n// is still used to broadcast the claimed work index). Each thread owns a\n// [tile_k, tile_v] register tile and issues strided tile_v-element tile loads\n// (8 bytes for 16-bit S, 16 bytes for fp32 S) directly from global memory. 
smem_sz = 0 in the host launcher.\n// =============================================================================\ntemplate<int k_head_dim, int v_head_dim, int block_dim, class T, class S>\n__global__ __launch_bounds__(block_dim, 2) void recurrent_gated_delta_rule_kernel_v3(T*         v_out,\n                                                                                     const T*   qkv_in,\n                                                                                     const T*   beta_in,\n                                                                                     const T*   g_in,\n                                                                                     S* const*  state_ptrs,\n                                                                                     const int* q_offsets,\n                                                                                     int*       work_counter,\n                                                                                     int        total_work,\n                                                                                     int        num_v_heads,\n                                                                                     int        num_k_heads,\n                                                                                     int        k_dim_total,\n                                                                                     int        state_layer_offset)\n{\n    constexpr int state_size = k_head_dim * v_head_dim;\n    const int     conv_dim   = 2 * k_dim_total + num_v_heads * v_head_dim;\n    const int     v_dim      = num_v_heads * v_head_dim;\n    const float   scale      = rsqrtf((float)k_head_dim);\n\n    // Compile-time thread partition (identical to v2)\n    constexpr int tile_k    = 16;\n    constexpr int tile_v    = 4;\n    constexpr int k_tiles   = k_head_dim / tile_k;\n    constexpr int v_tiles   = v_head_dim / tile_v;\n    constexpr int 
k_threads = k_tiles;\n    constexpr int v_threads = block_dim / k_threads;\n    constexpr int v_iters   = cdiv(v_tiles, v_threads);\n\n    const int offset_k = threadIdx.x % k_tiles;\n    const int offset_v = threadIdx.x / k_tiles;\n\n    // Persistent loop: each block atomically claims the next (b, h) work-item.\n    // Thread 0 issues the atomic; result is broadcast to all threads via smem.\n    __shared__ int s_work_idx;\n    while (true) {\n        if (threadIdx.x == 0)\n            s_work_idx = atomicAdd(work_counter, 1);\n        __syncthreads();\n        const int work_idx = s_work_idx;\n        __syncthreads();  // every thread has read s_work_idx before thread 0 may overwrite it\n        if (work_idx >= total_work)\n            break;\n        const int b     = work_idx / num_v_heads;\n        const int h     = work_idx % num_v_heads;\n        const int ratio = num_v_heads / num_k_heads;\n        const int kh    = h / ratio;\n\n        const int global_t = q_offsets[b];  // seq_len == 1 guaranteed\n\n        S* s_ptr = state_ptrs[b] + state_layer_offset + h * state_size;\n\n        // --- Load state: global → registers (direct strided tile loads, tile_v contiguous) ---\n        Array<float, tile_v> vec_S[v_iters][tile_k];\n\n        PRAGMA_UNROLL\n        for (int v_iter = 0; v_iter < v_iters; ++v_iter) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < tile_k; ++k) {\n                Array<S, tile_v> tmp;\n                Load(tmp, &s_ptr[(offset_k * tile_k + k) * v_head_dim + (offset_v + v_iter * v_threads) * tile_v]);\n                vec_S[v_iter][k] = cast<float>(tmp);\n            }\n        }\n\n        // --- Process single token (seq_len == 1) ---\n        {\n            const T* q_ptr = qkv_in + global_t * conv_dim + kh * k_head_dim;\n            const T* k_ptr = qkv_in + global_t * conv_dim + k_dim_total + kh * k_head_dim;\n            const T* v_ptr = qkv_in + global_t * conv_dim + 2 * k_dim_total + h * v_head_dim;\n            T*       o_ptr = v_out + global_t * v_dim + h * v_head_dim;\n\n            const float beta_val = 
(float)beta_in[global_t * num_v_heads + h];\n            const float decay    = expf((float)g_in[global_t * num_v_heads + h]);\n\n            Array<float, tile_k> vec_K;\n            Array<float, tile_k> vec_Q;\n\n            // L2-normalize K and Q in registers\n            {\n                Array<T, tile_k> tmp_K;\n                Array<T, tile_k> tmp_Q;\n                Load(tmp_K, &k_ptr[offset_k * tile_k]);\n                Load(tmp_Q, &q_ptr[offset_k * tile_k]);\n                vec_K = cast<float>(tmp_K);\n                vec_Q = cast<float>(tmp_Q);\n            }\n\n            float k_sum = 0.f, q_sum = 0.f;\n            PRAGMA_UNROLL\n            for (int k = 0; k < tile_k; ++k) {\n                k_sum += vec_K[k] * vec_K[k];\n                q_sum += vec_Q[k] * vec_Q[k];\n            }\n            PRAGMA_UNROLL\n            for (int mask = k_threads / 2; mask > 0; mask /= 2) {\n                k_sum += __shfl_xor_sync(0xffffffff, k_sum, mask);\n                q_sum += __shfl_xor_sync(0xffffffff, q_sum, mask);\n            }\n            const float k_inv_norm = rsqrtf(k_sum + 1e-6f);\n            const float q_inv_norm = rsqrtf(q_sum + 1e-6f);\n            PRAGMA_UNROLL\n            for (int i = 0; i < tile_k; ++i) {\n                vec_K[i] *= k_inv_norm;\n                vec_Q[i] *= q_inv_norm;\n            }\n\n            // KQ dot product (invariant across v elements)\n            float KQ = 0.f;\n            PRAGMA_UNROLL\n            for (int k = 0; k < tile_k; ++k)\n                KQ += vec_K[k] * vec_Q[k];\n            PRAGMA_UNROLL\n            for (int mask = k_threads / 2; mask > 0; mask /= 2)\n                KQ += __shfl_xor_sync(0xffffffff, KQ, mask);\n\n            Array<T, tile_v> vec_V[v_iters];\n            PRAGMA_UNROLL\n            for (int v_iter = 0; v_iter < v_iters; ++v_iter)\n                Load(vec_V[v_iter], &v_ptr[(offset_v + v_iter * v_threads) * tile_v]);\n\n            PRAGMA_UNROLL\n            for (int v_iter = 0; 
v_iter < v_iters; ++v_iter) {\n                Array<T, tile_v> vec_O;\n                PRAGMA_UNROLL\n                for (int v = 0; v < tile_v; ++v) {\n                    // Fused: decay + dual dot product (kv_mem and SQ simultaneously)\n                    float kv_mem = 0.f, SQ = 0.f;\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < tile_k; ++k) {\n                        float s_decayed     = vec_S[v_iter][k][v] * decay;\n                        vec_S[v_iter][k][v] = s_decayed;\n                        kv_mem += s_decayed * vec_K[k];\n                        SQ += s_decayed * vec_Q[k];\n                    }\n                    PRAGMA_UNROLL\n                    for (int mask = k_threads / 2; mask > 0; mask /= 2) {\n                        kv_mem += __shfl_xor_sync(0xffffffff, kv_mem, mask);\n                        SQ += __shfl_xor_sync(0xffffffff, SQ, mask);\n                    }\n                    const float delta = ((float)vec_V[v_iter][v] - kv_mem) * beta_val;\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < tile_k; ++k)\n                        vec_S[v_iter][k][v] += vec_K[k] * delta;\n                    vec_O[v] = static_cast<T>((SQ + delta * KQ) * scale);\n                }\n                if (offset_k == 0)\n                    Store(&o_ptr[(offset_v + v_iter * v_threads) * tile_v], vec_O);\n            }\n        }\n\n        // --- Store state: registers → global (direct strided tile stores, tile_v contiguous) ---\n        PRAGMA_UNROLL\n        for (int v_iter = 0; v_iter < v_iters; ++v_iter) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < tile_k; ++k) {\n                auto tmp = cast<S>(vec_S[v_iter][k]);\n                Store(&s_ptr[(offset_k * tile_k + k) * v_head_dim + (offset_v + v_iter * v_threads) * tile_v], tmp);\n            }\n        }\n    }\n}\n\nvoid invokeGatedDeltaRuleBatched_v3(Ref<Tensor>           v_out_,\n                                    const 
Tensor&         qkv_in,\n                                    const Tensor&         beta,\n                                    const Tensor&         g,\n                                    const Buffer_<void*>& state_ptrs,\n                                    const Buffer_<int>&   q_offsets,\n                                    int                   batch_size,\n                                    int                   num_k_heads,\n                                    int                   state_layer_offset,\n                                    DataType              state_dtype,\n                                    int                   sm_count,\n                                    int*                  work_counter,\n                                    cudaStream_t          stream)\n{\n    auto& v_out = v_out_.get();\n\n    const int num_v_heads = beta.shape(1);\n    const int v_dim       = v_out.shape(1);\n    const int k_dim_total = (qkv_in.shape(1) - v_dim) / 2;\n\n    if (batch_size == 0 || num_v_heads == 0)\n        return;\n\n    constexpr int kHeadDim  = 128;\n    constexpr int kBlockDim = 256;\n\n    TM_CHECK_EQ(v_dim / num_v_heads, kHeadDim);\n    TM_CHECK_EQ(k_dim_total / num_k_heads, kHeadDim);\n\n    const int total_work = batch_size * num_v_heads;\n\n    auto invoke = [&](auto t) {\n        using T     = decltype(t);\n        auto launch = [&](auto s) {\n            using S = decltype(s);\n\n            auto         kernel        = recurrent_gated_delta_rule_kernel_v3<kHeadDim, kHeadDim, kBlockDim, T, S>;\n            const size_t smem_sz       = 0;  // s_work_idx is a static __shared__ variable; no dynamic smem needed\n            int          blocks_per_sm = 1;\n            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, kBlockDim, smem_sz);\n            const int grid_blocks = min(total_work, blocks_per_sm * sm_count);\n\n            cudaMemsetAsync(work_counter, 0, sizeof(int), stream);\n            kernel<<<grid_blocks, kBlockDim, smem_sz, stream>>>(v_out.data<T>(),\n  
                                                              qkv_in.data<T>(),\n                                                                beta.data<T>(),\n                                                                g.data<T>(),\n                                                                (S* const*)state_ptrs.data(),\n                                                                q_offsets.data(),\n                                                                work_counter,\n                                                                total_work,\n                                                                num_v_heads,\n                                                                num_k_heads,\n                                                                k_dim_total,\n                                                                state_layer_offset);\n        };\n        if (state_dtype == kFloat32)\n            launch(float{});\n        else\n            launch(T{});\n    };\n    TM_DISPATCH_PRIMARY_DTYPES(v_out.dtype(), invoke);\n}\n\n// =============================================================================\n// Chunked Gated Delta Rule kernel — register-centric, small chunk size.\n//\n// Grid = batch_size * num_v_heads blocks, one block per (b, h) pair.\n// Cooperative QKV load to smem per chunk, then sequential per-token\n// processing (same recurrence as v2) reading from smem.\n// State load/store uses the full swizzled smem buffer (same as v2).\n// =============================================================================\ntemplate<int kHeadDim, int kChunkSize, int kBlockDim, class T, class S>\n__global__ void chunked_gated_delta_rule_kernel(T*         v_out,\n                                                const T*   qkv_in,\n                                                const T*   beta_in,\n                                                const T*   g_in,\n                                                S* const*  
state_ptrs,\n                                                const int* q_offsets,\n                                                int        num_v_heads,\n                                                int        num_k_heads,\n                                                int        k_dim_total,\n                                                int        state_layer_offset)\n{\n    constexpr int C = kChunkSize;\n    constexpr int D = kHeadDim;\n\n    const int bh    = blockIdx.x;\n    const int b     = bh / num_v_heads;\n    const int h     = bh % num_v_heads;\n    const int ratio = num_v_heads / num_k_heads;\n    const int kh    = h / ratio;\n\n    const int tok_off    = q_offsets[b];\n    const int seq_len    = q_offsets[b + 1] - tok_off;\n    const int state_size = D * D;\n    const int conv_dim   = 2 * k_dim_total + num_v_heads * D;\n    const int v_dim      = num_v_heads * D;\n\n    if (seq_len == 0)\n        return;\n\n    S*          s_ptr = state_ptrs[b] + state_layer_offset + h * state_size;\n    const float scale = rsqrtf((float)D);\n\n    // ── State tiling (same as v2) ──\n    constexpr int tile_k    = 8;\n    constexpr int tile_v    = 8;\n    constexpr int k_tiles   = D / tile_k;                // 16\n    constexpr int k_threads = k_tiles;                   // 16\n    constexpr int v_threads = kBlockDim / k_threads;     // 16\n    constexpr int v_tiles   = D / tile_v;                // 16\n    constexpr int v_iters   = cdiv(v_tiles, v_threads);  // 1\n\n    const int offset_k = threadIdx.x % k_threads;\n    const int offset_v = threadIdx.x / k_threads;\n\n    Array<float, tile_v> vec_S[v_iters][tile_k];\n\n    extern __shared__ __align__(16) char smem_buf[];\n\n    // ================================================================\n    //  LOAD STATE  global → smem (swizzled) → registers   (same as v2)\n    // ================================================================\n    {\n        using Map_S          = ThreadMap_V2<D, D, sizeof(uint4) / 
sizeof(S), Raked, kBlockDim / WARP_SIZE>;\n        constexpr int kBase  = (sizeof(S) == 4) ? 2 : 3;\n        constexpr int kShift = 10 - kBase;\n        using Layout         = SmemLayoutV2<D, D, -1, -1, Swizzle<4, kBase, kShift>>;\n        SmemAccessor<S, Layout> smem_S{(S*)smem_buf};\n\n        const int     warp_id  = threadIdx.x / WARP_SIZE;\n        const int     lane_id  = threadIdx.x % WARP_SIZE;\n        constexpr int kAccessC = Map_S::kAccessC;\n\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map_S::kIterS; ++s) {\n            Array<S, kAccessC> vec;\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map_S::kIterC; ++c) {\n                const auto [vd, kd] = Map_S::get_offset(warp_id, lane_id);\n                const int fvd       = vd + c * Map_S::kDeltaC;\n                const int fkd       = kd + s * Map_S::kDeltaS;\n                Load(vec, s_ptr + fkd * D + fvd);\n                Store(&smem_S(fkd, fvd), vec);\n            }\n        }\n        __syncthreads();\n\n        PRAGMA_UNROLL\n        for (int vi = 0; vi < v_iters; ++vi) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < tile_k; ++k) {\n                static_assert(tile_v % Map_S::kAccessC == 0);\n                PRAGMA_UNROLL\n                for (int c = 0; c < tile_v / Map_S::kAccessC; ++c) {\n                    Array<S, Map_S::kAccessC> tmp;\n                    Load(tmp,\n                         &smem_S(offset_k * tile_k + k, (offset_v + vi * v_threads) * tile_v + c * Map_S::kAccessC));\n                    (Array<float, Map_S::kAccessC>&)vec_S[vi][k][c * Map_S::kAccessC] = cast<float>(tmp);\n                }\n            }\n        }\n    }\n    __syncthreads();\n\n    // ================================================================\n    //  CHUNK PROCESSING  — sequential per-token (same as v2) with\n    //  smem-cached QKV.  
Eliminates resolvent/intra-attention overhead.\n    // ================================================================\n    // Shared memory layout for chunk processing (overlaps state staging buffer):\n    //   k_norm_smem[C][kSmemStride]  — pre-normalized K\n    //   q_norm_smem[C][kSmemStride]  — pre-normalized Q\n    //   v_smem[C][kSmemStride]       — raw V (as float)\n    //   scalars[2*C]                 — beta[C], g[C]\n    constexpr int kSmemStride = D + 4;  // pad rows by 4 to avoid 4-way bank conflicts\n\n    float* k_norm_smem = (float*)smem_buf;\n    float* q_norm_smem = k_norm_smem + C * kSmemStride;\n    float* v_smem      = q_norm_smem + C * kSmemStride;\n    float* beta_vals   = v_smem + C * kSmemStride;\n    float* g_vals      = beta_vals + C;\n\n    // Thread-to-token mapping for cooperative loads: a group of kBlockDim / C threads per token\n    constexpr int kThreadsPerTok = kBlockDim / C;                 // 256/16 = 16 for C = 16\n    constexpr int kElemsPerThr   = D / kThreadsPerTok;            // 128/16 = 8\n    const int     load_tok       = threadIdx.x / kThreadsPerTok;  // which token (0..C-1)\n    const int     load_lane      = threadIdx.x % kThreadsPerTok;  // lane within the token's thread group\n    // The xor-shuffle norm reduction needs a power-of-two thread group that fits in one warp.\n    static_assert(kThreadsPerTok <= WARP_SIZE && (kThreadsPerTok & (kThreadsPerTok - 1)) == 0);\n\n    const int num_chunks = (seq_len + C - 1) / C;\n\n    for (int ci = 0; ci < num_chunks; ++ci) {\n        const int chunk_start = tok_off + ci * C;\n        const int valid_len   = min(C, seq_len - ci * C);\n\n        // ────────────────────────────────────────────────────\n        //  Phase 0: Cooperative load K, Q, V → smem (pre-normalized)\n        //  kThreadsPerTok threads per token, kElemsPerThr elements per thread.\n        //  Norms computed via warp shuffle, K/Q normalized in registers\n        //  before writing to smem → eliminates one __syncthreads.\n        // ────────────────────────────────────────────────────\n        {\n            float K_reg[kElemsPerThr], Q_reg[kElemsPerThr];\n            float k_sq = 0.f, q_sq = 0.f;\n            if (load_tok < valid_len) {\n            
    const int gt    = chunk_start + load_tok;\n                const T*  k_ptr = qkv_in + gt * conv_dim + k_dim_total + kh * D;\n                const T*  q_ptr = qkv_in + gt * conv_dim + kh * D;\n                const T*  v_ptr = qkv_in + gt * conv_dim + 2 * k_dim_total + h * D;\n                PRAGMA_UNROLL\n                for (int e = 0; e < kElemsPerThr; ++e) {\n                    const int d = load_lane * kElemsPerThr + e;\n                    K_reg[e]    = (float)k_ptr[d];\n                    Q_reg[e]    = (float)q_ptr[d];\n                    k_sq += K_reg[e] * K_reg[e];\n                    q_sq += Q_reg[e] * Q_reg[e];\n                    v_smem[load_tok * kSmemStride + d] = (float)v_ptr[d];\n                }\n                if (load_lane == 0) {\n                    beta_vals[load_tok] = (float)beta_in[gt * num_v_heads + h];\n                    g_vals[load_tok]    = (float)g_in[gt * num_v_heads + h];\n                }\n            }\n            else {\n                PRAGMA_UNROLL\n                for (int e = 0; e < kElemsPerThr; ++e) {\n                    K_reg[e]                                                      = 0.f;\n                    Q_reg[e]                                                      = 0.f;\n                    v_smem[load_tok * kSmemStride + load_lane * kElemsPerThr + e] = 0.f;\n                }\n            }\n            // Reduce norms across the kThreadsPerTok lanes that share a token\n            PRAGMA_UNROLL\n            for (int mask = kThreadsPerTok / 2; mask > 0; mask >>= 1) {\n                k_sq += __shfl_xor_sync(0xffffffff, k_sq, mask);\n                q_sq += __shfl_xor_sync(0xffffffff, q_sq, mask);\n            }\n            const float k_inv = (load_tok < valid_len) ? rsqrtf(k_sq + 1e-6f) : 0.f;\n            const float q_inv = (load_tok < valid_len) ? 
rsqrtf(q_sq + 1e-6f) : 0.f;\n            // Write normalized K, Q to smem\n            PRAGMA_UNROLL\n            for (int e = 0; e < kElemsPerThr; ++e) {\n                const int d                             = load_lane * kElemsPerThr + e;\n                k_norm_smem[load_tok * kSmemStride + d] = K_reg[e] * k_inv;\n                q_norm_smem[load_tok * kSmemStride + d] = Q_reg[e] * q_inv;\n            }\n        }\n        __syncthreads();  // [sync 1] all smem data ready\n\n        // ────────────────────────────────────────────────────\n        //  Sequential per-token loop (same computation as v2)\n        //  Reads K, Q, V from smem instead of global memory.\n        // ────────────────────────────────────────────────────\n        PRAGMA_UNROLL\n        for (int t = 0; t < C; ++t) {\n            if (t >= valid_len)\n                break;\n\n            const int   gt       = chunk_start + t;\n            const float beta_val = beta_vals[t];\n            const float decay    = expf(g_vals[t]);\n\n            float vec_K[tile_k];\n            float vec_Q[tile_k];\n            PRAGMA_UNROLL\n            for (int k = 0; k < tile_k; ++k) {\n                vec_K[k] = k_norm_smem[t * kSmemStride + offset_k * tile_k + k];\n                vec_Q[k] = q_norm_smem[t * kSmemStride + offset_k * tile_k + k];\n            }\n\n            PRAGMA_UNROLL\n            for (int vi = 0; vi < v_iters; ++vi) {\n                const int v_base = (offset_v + vi * v_threads) * tile_v;\n\n                float vec_V[tile_v];\n                PRAGMA_UNROLL\n                for (int v = 0; v < tile_v; ++v)\n                    vec_V[v] = v_smem[t * kSmemStride + v_base + v];\n\n                Array<T, tile_v> vec_O;\n                PRAGMA_UNROLL\n                for (int v = 0; v < tile_v; ++v) {\n                    // Step 1: state *= decay\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < tile_k; ++k)\n                        vec_S[vi][k][v] *= 
decay;\n\n                    // Step 2: delta rule update\n                    float kv_mem = 0.f;\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < tile_k; ++k)\n                        kv_mem += vec_S[vi][k][v] * vec_K[k];\n                    PRAGMA_UNROLL\n                    for (int mask = k_threads / 2; mask > 0; mask /= 2)\n                        kv_mem += __shfl_xor_sync(0xffffffff, kv_mem, mask);\n                    const float delta = (vec_V[v] - kv_mem) * beta_val;\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < tile_k; ++k)\n                        vec_S[vi][k][v] += vec_K[k] * delta;\n\n                    // Step 3: output = (S^T @ q) * scale\n                    float O = 0.f;\n                    PRAGMA_UNROLL\n                    for (int k = 0; k < tile_k; ++k)\n                        O += vec_S[vi][k][v] * vec_Q[k];\n                    PRAGMA_UNROLL\n                    for (int mask = k_threads / 2; mask > 0; mask /= 2)\n                        O += __shfl_xor_sync(0xffffffff, O, mask);\n                    vec_O[v] = static_cast<T>(O * scale);\n                }\n                if (offset_k == 0)\n                    Store(&v_out[gt * v_dim + h * D + v_base], vec_O);\n            }\n        }\n        __syncthreads();  // [sync 2] ensure all reads done before next chunk overwrites smem\n    }                     // chunk loop\n\n    // ================================================================\n    //  STORE STATE  registers → smem (swizzled) → global   (same as v2)\n    // ================================================================\n    {\n        using Map_S          = ThreadMap_V2<D, D, sizeof(uint4) / sizeof(S), Raked, kBlockDim / WARP_SIZE>;\n        constexpr int kBase  = (sizeof(S) == 4) ? 
2 : 3;\n        constexpr int kShift = 10 - kBase;\n        using Layout         = SmemLayoutV2<D, D, -1, -1, Swizzle<4, kBase, kShift>>;\n        SmemAccessor<S, Layout> smem_S{(S*)smem_buf};\n        constexpr int           kAccessC = Map_S::kAccessC;\n\n        PRAGMA_UNROLL\n        for (int vi = 0; vi < v_iters; ++vi) {\n            PRAGMA_UNROLL\n            for (int k = 0; k < tile_k; ++k) {\n                PRAGMA_UNROLL\n                for (int c = 0; c < tile_v / kAccessC; ++c) {\n                    auto tmp = cast<S>((Array<float, kAccessC>&)vec_S[vi][k][c * kAccessC]);\n                    Store(&smem_S(offset_k * tile_k + k, (offset_v + vi * v_threads) * tile_v + c * kAccessC), tmp);\n                }\n            }\n        }\n        __syncthreads();\n\n        const int warp_id = threadIdx.x / WARP_SIZE;\n        const int lane_id = threadIdx.x % WARP_SIZE;\n\n        PRAGMA_UNROLL\n        for (int s = 0; s < Map_S::kIterS; ++s) {\n            Array<S, Map_S::kAccessC> vec;\n            PRAGMA_UNROLL\n            for (int c = 0; c < Map_S::kIterC; ++c) {\n                const auto [vd, kd] = Map_S::get_offset(warp_id, lane_id);\n                const int fvd       = vd + c * Map_S::kDeltaC;\n                const int fkd       = kd + s * Map_S::kDeltaS;\n                Load(vec, &smem_S(fkd, fvd));\n                Store(s_ptr + fkd * D + fvd, vec);\n            }\n        }\n    }\n}\n\n// Host-side launcher\nvoid invokeChunkedGatedDeltaRuleBatched(Ref<Tensor>           v_out_,\n                                        const Tensor&         qkv_in,\n                                        const Tensor&         beta,\n                                        const Tensor&         g,\n                                        const Buffer_<void*>& state_ptrs,\n                                        const Buffer_<int>&   q_offsets,\n                                        int                   batch_size,\n                                        
int                   num_k_heads,\n                                        int                   state_layer_offset,\n                                        DataType              state_dtype,\n                                        int /*sm_count*/,\n                                        int* /*work_counter*/,\n                                        cudaStream_t stream)\n{\n    auto& v_out = v_out_.get();\n\n    const int num_v_heads    = beta.shape(1);\n    const int v_dim          = v_out.shape(1);\n    const int value_head_dim = v_dim / num_v_heads;\n    const int k_dim_total    = (qkv_in.shape(1) - v_dim) / 2;\n\n    if (batch_size == 0 || num_v_heads == 0)\n        return;\n\n    constexpr int kHeadDim   = 128;\n    constexpr int kChunkSize = 16;\n    constexpr int kBlockDim  = 256;\n\n    TM_CHECK_EQ(value_head_dim, kHeadDim);\n    TM_CHECK_EQ(k_dim_total / num_k_heads, kHeadDim);\n\n    const int num_blocks = batch_size * num_v_heads;\n\n    auto invoke = [&](auto t) {\n        using T     = decltype(t);\n        auto launch = [&](auto s) {\n            using S = decltype(s);\n\n            auto kernel = chunked_gated_delta_rule_kernel<kHeadDim, kChunkSize, kBlockDim, T, S>;\n\n            // smem = max(state staging, chunk working buffers)\n            // State staging: D*D*sizeof(S) (64KB for fp32)\n            // Chunk buffers: QKV cache [3*C*(D+4)] + scalars[2*C]\n            const size_t state_smem  = kHeadDim * kHeadDim * sizeof(S);\n            const int    kSmemStride = kHeadDim + 4;\n            const size_t chunk_smem  = 3 * kChunkSize * kSmemStride * sizeof(float)  // k_norm, q_norm, v\n                                      + 2 * kChunkSize * sizeof(float);              // beta, g\n            const size_t smem_sz = state_smem > chunk_smem ? 
state_smem : chunk_smem;\n\n            if (smem_sz > 48 << 10) {\n                cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);\n            }\n\n            kernel<<<num_blocks, kBlockDim, smem_sz, stream>>>(v_out.data<T>(),\n                                                               qkv_in.data<T>(),\n                                                               beta.data<T>(),\n                                                               g.data<T>(),\n                                                               (S* const*)state_ptrs.data(),\n                                                               q_offsets.data(),\n                                                               num_v_heads,\n                                                               num_k_heads,\n                                                               k_dim_total,\n                                                               state_layer_offset);\n        };\n        if (state_dtype == kFloat32) {\n            launch(float{});\n        }\n        else {\n            launch(T{});\n        }\n    };\n    TM_DISPATCH_PRIMARY_DTYPES(v_out.dtype(), invoke);\n}\n\ntemplate<class T>\n__global__ void compute_beta_g_kernel_v2(T*       beta_out,\n                                         T*       g_out,\n                                         const T* b_in,\n                                         int      b_stride,\n                                         const T* a_in,\n                                         int      a_stride,\n                                         const T* A_log,\n                                         const T* dt_bias,\n                                         int      total,\n                                         int      num_v_heads)\n{\n    const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    if (idx >= total)\n        return;\n\n    const int hi = idx % num_v_heads;\n    const int ti = idx / 
num_v_heads;\n\n    float b_val       = static_cast<float>(b_in[ti * b_stride + hi]);\n    float a_val       = static_cast<float>(a_in[ti * a_stride + hi]);\n    float A_log_val   = static_cast<float>(A_log[hi]);\n    float dt_bias_val = static_cast<float>(dt_bias[hi]);\n\n    float beta  = 1.0f / (1.0f + expf(-b_val));\n    float sum   = a_val + dt_bias_val;\n    float sp    = sum > 20.0f ? sum : logf(1.0f + expf(sum));\n    float g_val = -expf(A_log_val) * sp;\n\n    beta_out[idx] = static_cast<T>(beta);\n    g_out[idx]    = static_cast<T>(g_val);\n}\n\nvoid ComputeBetaG_v2(Ref<Tensor>   beta_out_,\n                     Ref<Tensor>   g_out_,\n                     const Tensor& b_in,\n                     const Tensor& a_in,\n                     const Tensor& A_log,\n                     const Tensor& dt_bias,\n                     cudaStream_t  stream)\n{\n\n    auto& beta_out = beta_out_.get();\n    auto& g_out    = g_out_.get();\n\n    const int threads = 256;\n    const int blocks  = cdiv<ssize_t>(beta_out.size(), threads);\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        compute_beta_g_kernel_v2<<<blocks, threads, 0, stream>>>(beta_out.data<T>(),\n                                                                 g_out.data<T>(),\n                                                                 b_in.data<T>(),\n                                                                 b_in.stride(0),\n                                                                 a_in.data<T>(),\n                                                                 a_in.stride(0),\n                                                                 A_log.data<T>(),\n                                                                 dt_bias.data<T>(),\n                                                                 beta_out.size(),\n                                                                 A_log.size());\n    };\n\n    
TM_DISPATCH_PRIMARY_DTYPES(beta_out.dtype(), invoke);\n}\n\n// =============================================================================\n// RMSNorm * SiLU-Gate (fused output normalization)\n// =============================================================================\ntemplate<typename T>\n__global__ void rms_norm_gated_kernel(\n    T* hidden, const T* gate, const T* weight, float eps, int N, int head_dim, int gate_stride, int num_heads)\n{\n    const int row = blockIdx.x;\n    if (row >= N)\n        return;\n\n    T*        h         = hidden + row * head_dim;\n    const int token_idx = row / num_heads;\n    const int head_idx  = row % num_heads;\n    const T*  g         = gate + token_idx * gate_stride + head_idx * head_dim;\n\n    __shared__ float smem[32];\n    float            sum_sq = 0.0f;\n    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {\n        float val = static_cast<float>(h[d]);\n        sum_sq += val * val;\n    }\n    for (int mask = 16; mask > 0; mask >>= 1)\n        sum_sq += __shfl_xor_sync(0xffffffff, sum_sq, mask);\n    if ((threadIdx.x & 31) == 0)\n        smem[threadIdx.x >> 5] = sum_sq;\n    __syncthreads();\n    if (threadIdx.x >> 5 == 0) {\n        sum_sq = (threadIdx.x < (blockDim.x + 31) / 32) ? 
smem[threadIdx.x] : 0.0f;\n        for (int mask = 16; mask > 0; mask >>= 1)\n            sum_sq += __shfl_xor_sync(0xffffffff, sum_sq, mask);\n        if (threadIdx.x == 0)\n            smem[0] = sum_sq;\n    }\n    __syncthreads();\n    sum_sq = smem[0];\n\n    float inv_rms = rsqrtf(sum_sq / (float)head_dim + eps);\n    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {\n        float h_val  = static_cast<float>(h[d]) * inv_rms * static_cast<float>(weight[d]);\n        float g_val  = static_cast<float>(g[d]);\n        float silu_g = g_val / (1.0f + expf(-g_val));\n        h[d]         = static_cast<T>(h_val * silu_g);\n    }\n}\n\nvoid invokeRMSNormGated(Ref<Tensor> hidden_, const Tensor& gate, const Tensor& weight, float eps, cudaStream_t stream)\n{\n    auto& hidden = hidden_.get();\n\n    const int N           = hidden.shape(0);\n    const int head_dim    = hidden.shape(1);\n    const int token_num   = gate.shape(0);\n    const int gate_stride = gate.stride(0);\n    const int num_heads   = N / token_num;\n\n    if (N == 0)\n        return;\n\n    const int threads = std::min(256, head_dim);\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        rms_norm_gated_kernel<<<N, threads, 0, stream>>>(\n            hidden.data<T>(), gate.data<T>(), weight.data<T>(), eps, N, head_dim, gate_stride, num_heads);\n    };\n    TM_DISPATCH_PRIMARY_DTYPES(hidden.dtype(), invoke);\n}\n\n// =============================================================================\n// Fused Conv1d + SiLU — persistent batched kernel\n//\n// Weight layout: [d_conv, conv_dim], State layout: [d_conv, conv_dim] per batch.\n//\n// Persistent 1D grid. Each block has a fixed channel tile\n// (blockIdx.x % num_ch_tiles) and atomically claims single-token work items\n// via a global counter. 
Token-major work ordering with grid size a multiple of\n// num_ch_tiles guarantees monotonically increasing tokens and a fixed channel\n// tile per block.\n// =============================================================================\ntemplate<int D_CONV, int CHANNELS_PER_THREAD, int BLOCK_DIM, int NUM_TOKENS, typename T>\n__global__ void __launch_bounds__(BLOCK_DIM) fused_conv1d_batched_kernel_v2(T*           out,\n                                                                            const T*     in,\n                                                                            const T*     weight,\n                                                                            const T*     bias,\n                                                                            void* const* conv_state_ptrs,\n                                                                            const int*   q_offsets,\n                                                                            const int*   k_offsets,\n                                                                            int*         work_counter,\n                                                                            int          batch_size,\n                                                                            int          conv_dim,\n                                                                            int          in_stride,\n                                                                            int          num_token_tiles,\n                                                                            int          state_layer_offset,\n                                                                            int          total_work,\n                                                                            int          num_ch_tiles)\n{\n    static_assert(BLOCK_DIM * CHANNELS_PER_THREAD > 0);\n\n    int prev_ch_tile = -1;\n    int c_base       = 0;\n\n    Array<T, 
CHANNELS_PER_THREAD> w_tap[D_CONV];\n    Array<T, CHANNELS_PER_THREAD> bias_vals;\n\n    __shared__ int  s_work_id;\n    __shared__ int4 s_batch_info;\n    int             b_start = 0;\n\n    while (true) {\n        if (threadIdx.x == 0)\n            s_work_id = atomicAdd(work_counter, 1);\n        __syncthreads();\n\n        if (s_work_id >= total_work)\n            break;\n\n        const int t_tile  = s_work_id % num_token_tiles;\n        const int ch_tile = s_work_id / num_token_tiles;\n\n        if (ch_tile != prev_ch_tile) {\n            prev_ch_tile = ch_tile;\n            b_start      = 0;\n        }\n\n        c_base = (ch_tile * BLOCK_DIM + threadIdx.x) * CHANNELS_PER_THREAD;\n\n        const bool ch_active = (c_base < conv_dim);\n\n        if (ch_active) {\n            PRAGMA_UNROLL\n            for (int d = 0; d < D_CONV; ++d) {\n                Load(w_tap[d], weight + d * conv_dim + c_base);\n            }\n            if (bias)\n                Load(bias_vals, bias + c_base);\n        }\n\n        if constexpr (NUM_TOKENS == 1) {\n            for (int b = b_start + threadIdx.x; b < batch_size; b += BLOCK_DIM) {\n                int lo = __ldg(&q_offsets[b]);\n                if (lo > t_tile)\n                    break;\n                int hi = __ldg(&q_offsets[b + 1]);\n                if (t_tile < hi) {\n                    int seq      = hi - lo;\n                    int hist     = (__ldg(&k_offsets[b + 1]) - __ldg(&k_offsets[b])) - seq;\n                    s_batch_info = make_int4(b, lo, seq, hist);\n                }\n            }\n        }\n        else {\n            for (int b = b_start + threadIdx.x; b < batch_size; b += BLOCK_DIM) {\n                int tile_off = __ldg(&q_offsets[b]) / NUM_TOKENS + b;\n                if (tile_off > t_tile)\n                    break;\n                int tile_off_next = __ldg(&q_offsets[b + 1]) / NUM_TOKENS + b + 1;\n                if (t_tile < tile_off_next) {\n                    int lo       = 
__ldg(&q_offsets[b]);\n                    int seq      = __ldg(&q_offsets[b + 1]) - lo;\n                    int hist     = (__ldg(&k_offsets[b + 1]) - __ldg(&k_offsets[b])) - seq;\n                    s_batch_info = make_int4(b, lo, seq, hist);\n                }\n            }\n        }\n        __syncthreads();\n\n        b_start = s_batch_info.x;\n\n        const int4 bi          = s_batch_info;\n        const int  b           = bi.x;\n        const int  seq_off     = bi.y;\n        const int  seq_len     = bi.z;\n        const int  history_len = bi.w;\n\n        int t_local_start;\n        int n_tokens;\n        if constexpr (NUM_TOKENS == 1) {\n            t_local_start = t_tile - seq_off;\n            n_tokens      = 1;\n        }\n        else {\n            const int tile_off_b = seq_off / NUM_TOKENS + b;\n            t_local_start        = (t_tile - tile_off_b) * NUM_TOKENS;\n            if (t_local_start >= seq_len)\n                continue;\n            n_tokens = min(NUM_TOKENS, seq_len - t_local_start);\n        }\n\n        const int ring_start = (history_len + t_local_start + 1) % D_CONV;\n        T*        state_base = (T*)conv_state_ptrs[b] + state_layer_offset;\n\n        if (ch_active) {\n            constexpr int                 VALS_SIZE = NUM_TOKENS + D_CONV - 1;\n            Array<T, CHANNELS_PER_THREAD> vals[VALS_SIZE];\n            const int                     n_vals = n_tokens + D_CONV - 1;\n\n            PRAGMA_UNROLL\n            for (int i = 0; i < VALS_SIZE; ++i) {\n                if (i < n_vals) {\n                    int pos = t_local_start - (D_CONV - 1) + i;\n                    if (pos >= 0) {\n                        Load(vals[i], in + (seq_off + pos) * in_stride + c_base);\n                    }\n                    else {\n                        int ring_d = (ring_start + i) % D_CONV;\n                        Load(vals[i], state_base + ring_d * conv_dim + c_base);\n                    }\n                }\n            
}\n\n            PRAGMA_UNROLL\n            for (int tok = 0; tok < NUM_TOKENS; ++tok) {\n                if (tok < n_tokens) {\n                    float acc[CHANNELS_PER_THREAD] = {};\n                    PRAGMA_UNROLL\n                    for (int d = 0; d < D_CONV; ++d) {\n                        PRAGMA_UNROLL\n                        for (int ch = 0; ch < CHANNELS_PER_THREAD; ++ch) {\n                            acc[ch] += static_cast<float>(vals[tok + d][ch]) * static_cast<float>(w_tap[d][ch]);\n                        }\n                    }\n\n                    Array<T, CHANNELS_PER_THREAD> out_vals;\n                    PRAGMA_UNROLL\n                    for (int ch = 0; ch < CHANNELS_PER_THREAD; ++ch) {\n                        if (bias)\n                            acc[ch] += static_cast<float>(bias_vals[ch]);\n                        out_vals[ch] = static_cast<T>(acc[ch] / (1.0f + expf(-acc[ch])));\n                    }\n\n                    Store(out + (seq_off + t_local_start + tok) * conv_dim + c_base, out_vals);\n                }\n            }\n\n            if (t_local_start + n_tokens >= seq_len) {\n                PRAGMA_UNROLL\n                for (int i = 0; i < VALS_SIZE; ++i) {\n                    int pos = t_local_start - (D_CONV - 1) + i;\n                    if (pos >= 0 && pos >= seq_len - D_CONV && pos < seq_len) {\n                        int ring_d = (ring_start + i) % D_CONV;\n                        Store(state_base + ring_d * conv_dim + c_base, vals[i]);\n                    }\n                }\n            }\n        }\n    }\n}\n\nvoid invokeFusedConv1dSiLU(Ref<Tensor>           out_,\n                           const Tensor&         in,\n                           const Tensor&         weight,\n                           const Tensor&         bias,\n                           const Buffer_<void*>& conv_state_ptrs,\n                           const Buffer_<int>&   q_offsets,\n                           const Buffer_<int>& 
  k_offsets,\n                           int                   batch_size,\n                           int                   state_layer_offset,\n                           int                   sm_count,\n                           int*                  work_counter,\n                           cudaStream_t          stream)\n{\n    auto& out = out_.get();\n\n    const int total_tokens = in.shape(0);\n    const int d_conv       = weight.shape(0);\n    const int conv_dim     = weight.shape(1);\n    const int in_stride    = in.stride(0);\n\n    constexpr int threads = 128;\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        if (d_conv == 4) {\n            constexpr int kDConv     = 4;\n            constexpr int kChPerT    = 8;\n            const int     ch_per_blk = threads * kChPerT;\n            TM_CHECK(conv_dim % kChPerT == 0);\n            const int num_ch_tiles = cdiv(conv_dim, ch_per_blk);\n\n            auto launch = [&](auto num_tok_tag) {\n                constexpr int kNumTok         = decltype(num_tok_tag)::value;\n                const int     num_token_tiles = (kNumTok == 1) ? total_tokens : total_tokens / kNumTok + batch_size;\n                const int     total_work      = num_token_tiles * num_ch_tiles;\n\n                auto kernel        = fused_conv1d_batched_kernel_v2<kDConv, kChPerT, threads, kNumTok, T>;\n                int  blocks_per_sm = 1;\n                cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, threads, 0);\n                int grid = min(total_work, blocks_per_sm * sm_count);\n\n                cudaMemsetAsync(work_counter, 0, sizeof(int), stream);\n                kernel<<<grid, threads, 0, stream>>>(out.data<T>(),\n                                                     in.data<T>(),\n                                                     weight.data<T>(),\n                                                     bias ? 
bias.data<T>() : (T*)nullptr,\n                                                     conv_state_ptrs.data(),\n                                                     q_offsets.data(),\n                                                     k_offsets.data(),\n                                                     work_counter,\n                                                     batch_size,\n                                                     conv_dim,\n                                                     in_stride,\n                                                     num_token_tiles,\n                                                     state_layer_offset,\n                                                     total_work,\n                                                     num_ch_tiles);\n            };\n\n            int avg_seq = total_tokens / batch_size;\n            if (avg_seq >= 4)\n                launch(std::integral_constant<int, 5>{});\n            else\n                launch(std::integral_constant<int, 1>{});\n        }\n        else {\n            TM_CHECK(0) << \"Only d_conv == 4 is supported by fused_conv1d_batched_kernel_v2\";\n        }\n    };\n    TM_DISPATCH_PRIMARY_DTYPES(out.dtype(), invoke);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/gated_delta_net_kernels.h",
    "content": "#pragma once\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\n// Fused Conv1d + SiLU — unified batched launcher (row-major layout).\n//\n// Processes all requests in a single kernel launch.  Decode (seq_len == 1)\n// and prefill (seq_len > 1) requests may be mixed freely within the batch.\n//\n// out:             (total_tokens, conv_dim)       row-major output\n// in:              (total_tokens, in_stride)      non-contiguous slice of all_proj\n// weight:          (d_conv, conv_dim)\n// bias:            (conv_dim) or empty Tensor\n// conv_state_ptrs: device array[batch_size] of per-request state pointers\n// q_offsets:       device int[batch_size+1] cumulative token offsets\n// k_offsets:       device int[batch_size+1] cumulative key (history+input) offsets\nvoid invokeFusedConv1dSiLU(Ref<Tensor>           out,\n                           const Tensor&         in,\n                           const Tensor&         weight,\n                           const Tensor&         bias,\n                           const Buffer_<void*>& conv_state_ptrs,\n                           const Buffer_<int>&   q_offsets,\n                           const Buffer_<int>&   k_offsets,\n                           int                   batch_size,\n                           int                   state_layer_offset,\n                           int                   sm_count,\n                           int*                  work_counter,\n                           cudaStream_t          stream);\n\n// All three recurrent-rule launchers share the same trailing parameters for\n// interface consistency:\n//   sm_count      — multiprocessor count, queried once by the caller at init\n//   work_counter  — device int* (1 element), owned by caller; v3 uses it for\n//                   atomic workload claiming, v2/chunked ignore it\n//   stream        — CUDA stream\n//\n// v2: standard 
one-block-per-(b,h) grid launch; sm_count and work_counter ignored.\nvoid invokeGatedDeltaRuleBatched_v2(Ref<Tensor>           v_out,\n                                    const Tensor&         qkv_in,\n                                    const Tensor&         beta,\n                                    const Tensor&         g,\n                                    const Buffer_<void*>& state_ptrs,\n                                    const Buffer_<int>&   q_offsets,\n                                    int                   batch_size,\n                                    int                   num_k_heads,\n                                    int                   state_layer_offset,\n                                    DataType              state_dtype,\n                                    int                   sm_count,\n                                    int*                  work_counter,\n                                    cudaStream_t          stream);\n\n// v3: persistent decode kernel, seq_len == 1 only.\n// Launches min(total_work, blocks_per_sm * sm_count) blocks; each block claims\n// work items atomically via work_counter (zeroed via cudaMemsetAsync per launch).\n// state_dtype controls state precision: kFloat32 → S=float, otherwise S=T.\nvoid invokeGatedDeltaRuleBatched_v3(Ref<Tensor>           v_out,\n                                    const Tensor&         qkv_in,\n                                    const Tensor&         beta,\n                                    const Tensor&         g,\n                                    const Buffer_<void*>& state_ptrs,\n                                    const Buffer_<int>&   q_offsets,\n                                    int                   batch_size,\n                                    int                   num_k_heads,\n                                    int                   state_layer_offset,\n                                    DataType              state_dtype,\n                                  
  int                   sm_count,\n                                    int*                  work_counter,\n                                    cudaStream_t          stream);\n\n// =============================================================================\n// Chunked Gated Delta Rule — for accelerating prefill\n//\n// Processes sequences in chunks of size C (default 64), parallelizing\n// intra-chunk computation while maintaining sequential inter-chunk state\n// updates. Reduces sequential depth from L to L/C.\n//\n// Same tensor layouts as invokeGatedDeltaRuleBatched_v2.\n// sm_count and work_counter accepted for interface parity; ignored internally.\nvoid invokeChunkedGatedDeltaRuleBatched(Ref<Tensor>           v_out,\n                                        const Tensor&         qkv_in,\n                                        const Tensor&         beta,\n                                        const Tensor&         g,\n                                        const Buffer_<void*>& state_ptrs,\n                                        const Buffer_<int>&   q_offsets,\n                                        int                   batch_size,\n                                        int                   num_k_heads,\n                                        int                   state_layer_offset,\n                                        DataType              state_dtype,\n                                        int                   sm_count,\n                                        int*                  work_counter,\n                                        cudaStream_t          stream);\n\n// =============================================================================\n// Helper kernels\n// =============================================================================\n\nvoid ComputeBetaG_v2(Ref<Tensor>   beta_out_,\n                     Ref<Tensor>   g_out_,\n                     const Tensor& b_in,\n                     const Tensor& a_in,\n                 
    const Tensor& A_log,\n                     const Tensor& dt_bias,\n                     cudaStream_t  stream);\n\n// RMSNorm * SiLU-gate (fused output normalization)\nvoid invokeRMSNormGated(Ref<Tensor> hidden, const Tensor& gate, const Tensor& weight, float eps, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/llama_kernels.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <cstdint>\n#include <numeric>\n#include <type_traits>\n#include <utility>\n\n#include <cub/block/block_reduce.cuh>\n#include <cub/block/block_scan.cuh>\n\n#include \"src/turbomind/kernels/core/array.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/macro.h\"\n#include \"src/turbomind/models/llama/llama_kernels.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/dispatch.h\"\n\nnamespace turbomind {\n\n__global__ void gatherOutput(int*       output_ids,\n                             const int* ids,\n                             const int* context_length,\n                             int        max_context_len,\n                             int        max_gen_step,\n                             int        max_output_len,\n                             int        batch_size)\n{\n    const int batch_id    = blockIdx.x;\n    const int context_len = context_length[batch_id];\n    output_ids += batch_id * max_output_len;\n    for (int src_idx = threadIdx.x; src_idx < max_gen_step; src_idx += blockDim.x) {\n        // skip padding for src\n        if (context_len <= src_idx && src_idx < max_context_len) {\n            continue;\n        }\n        // skip padding for dst\n        const int dst_idx = src_idx < context_len ? 
src_idx : src_idx - (max_context_len - context_len);\n        if (dst_idx < max_output_len) {\n            output_ids[dst_idx] = ids[src_idx * batch_size + batch_id];\n        }\n    }\n}\n\nvoid invokeGatherOutput(int*         output_ids,\n                        const int*   ids,\n                        const int*   context_length,\n                        int          max_context_len,\n                        int          max_gen_step,\n                        int          max_output_len,\n                        int          batch_size,\n                        cudaStream_t stream)\n{\n    int block_size = 128;\n    int grid_size  = batch_size;\n    gatherOutput<<<grid_size, block_size, 0, stream>>>(\n        output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size);\n}\n\n__global__ void updateOutput(int**      request_output_ids_ptrs,\n                             int**      request_seqlen_ptrs,\n                             const int* output_ids,\n                             const int* sequence_lengths,\n                             const int* request_output_ids_lens,\n                             int        max_session_len,\n                             bool       token_generated)\n{\n    const int batch_id = blockIdx.x;\n\n    auto request_output_ids = request_output_ids_ptrs[batch_id];\n    auto request_seqlen     = request_seqlen_ptrs[batch_id];\n\n    output_ids += max_session_len * batch_id;\n\n    const int seqlen     = sequence_lengths[batch_id] + (int)token_generated;\n    const int output_len = min(seqlen, request_output_ids_lens[batch_id]);\n\n    for (int i = threadIdx.x; i < output_len; i += blockDim.x) {\n        request_output_ids[i] = output_ids[i];\n    }\n\n    *request_seqlen = seqlen;\n}\n\nvoid invokeUpdateOutput(int**        request_output_ids_ptrs,\n                        int**        request_seqlen_ptrs,\n                        const int*   output_ids,\n                        const int*   
sequence_lengths,\n                        const int*   request_output_ids_lens,\n                        int          max_session_len,\n                        bool         token_generated,\n                        int          batch_size,\n                        cudaStream_t stream)\n{\n    constexpr int block_size = 128;\n    const int     grid_size  = batch_size;\n\n    updateOutput<<<grid_size, block_size, 0, stream>>>(request_output_ids_ptrs,\n                                                       request_seqlen_ptrs,\n                                                       output_ids,\n                                                       sequence_lengths,\n                                                       request_output_ids_lens,\n                                                       max_session_len,\n                                                       token_generated);\n}\n\ntemplate<int BLOCK_DIM>\n__global__ void compactOutputIds(\n    int* cu_output_ids, const int* output_ids, const int* sequence_lengths, int session_len, bool token_generated)\n{\n    typedef cub::BlockReduce<int, BLOCK_DIM>     BlockReduce;\n    __shared__ typename BlockReduce::TempStorage temp_storage;\n\n    const int batch_idx = blockIdx.x;\n\n    int end   = (batch_idx + BLOCK_DIM - 1) / BLOCK_DIM * BLOCK_DIM;  // align to BLOCK_DIM boundary\n    int count = 0;\n    for (int i = threadIdx.x; i < end; i += blockDim.x) {\n        // sum sequence_lengths[0, batch_idx) one BLOCK_DIM-wide chunk per iteration (index with i)\n        int x = i < batch_idx ? sequence_lengths[i] : 0;\n        count += BlockReduce(temp_storage).Sum(x);\n        // https://nvlabs.github.io/cub/classcub_1_1_block_reduce.html\n        __syncthreads();\n    }\n\n    __shared__ int offset;\n\n    if (threadIdx.x == 0) {\n        offset = count;\n    }\n\n    __syncthreads();\n\n    auto dst = cu_output_ids + offset;\n\n    const int seq_len = sequence_lengths[batch_idx];\n\n    for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {\n        dst[i] = output_ids[batch_idx * session_len + i];\n    }\n}
\n\nvoid invokeCompactOutputIds(int*         cu_output_ids,\n                            const int*   output_ids,\n                            const int*   sequence_lengths,\n                            int          max_session_len,\n                            bool         token_generated,\n                            int          batch_size,\n                            cudaStream_t stream)\n{\n    constexpr int BLOCK_DIM = 128;\n    compactOutputIds<BLOCK_DIM><<<batch_size, BLOCK_DIM, 0, stream>>>(\n        cu_output_ids, output_ids, sequence_lengths, max_session_len, token_generated);\n}\n\ntemplate<int N, int C>\nstruct IndexedCopyParam {\n    Array<void*, N> src_ptr;\n    Array<void*, N> dst_ptr;\n    Array<int, N>   stride;\n    Array<int, C>   src_idx;\n    Array<int, C>   dst_idx;\n    int             max_stride;\n};\n\ntemplate<class T, int N, int C>\n__global__ void indexedCopy(IndexedCopyParam<N, C> param)\n{\n    const int bi = blockIdx.x;\n    const int si = param.src_idx[bi];\n    const int di = param.dst_idx[bi];\n    for (int i = threadIdx.x; i < param.max_stride; i += blockDim.x) {\n        PRAGMA_UNROLL\n        for (int k = 0; k < N; ++k) {\n            if (i < param.stride[k]) {\n                *((T*)param.dst_ptr[k] + param.stride[k] * di + i) =\n                    *((const T*)param.src_ptr[k] + param.stride[k] * si + i);\n            }\n        }\n    }\n}\n\ntemplate<class T, int N>\nvoid invokeIndexedCopyImpl(void**       h_src_ptr,\n                           void**       h_dst_ptr,\n                           const int*   h_elem_sz,\n                           const int*   h_src_idx,\n                           const int*   h_dst_idx,\n                           int          count,\n                           cudaStream_t st)\n{\n    dispatch(  // dispatch for num of copy operations\n        std::integer_sequence<int, 4, 8, 16, 32, 64, 128, 256>{},\n        [&](auto C) { return count <= C; },\n        [&](auto C) {\n            // maximum parameter size: sm<70: 4kB, sm>=70: 32kB\n            static_assert(sizeof(IndexedCopyParam<N, C>) <= 4096);\n            IndexedCopyParam<N, C> param{};\n            std::copy_n(h_src_ptr, N, param.src_ptr.data());\n            std::copy_n(h_dst_ptr, N, param.dst_ptr.data());\n            std::transform(h_elem_sz, h_elem_sz + N, param.stride.data(), [](int size) {\n                // Basic alignment check\n                FT_CHECK_WITH_INFO(size % sizeof(T) == 0, fmtstr(\"misalignment: %d %% %d\", size, (int)sizeof(T)));\n                return size / sizeof(T);\n            });\n            param.max_stride = *std::max_element(param.stride.begin(), param.stride.end());\n            auto copy_idx    = [](const int* src, int offset, int n, auto dst) {\n                return src ? 
(void)std::copy_n(src + offset, n, dst) : std::iota(dst, dst + n, offset);\n            };\n            for (int c = 0; c < count; c += C) {\n                int batch_size = std::min(count - c, (int)C);\n                copy_idx(h_src_idx, c, batch_size, param.src_idx.data());\n                copy_idx(h_dst_idx, c, batch_size, param.dst_idx.data());\n                indexedCopy<T><<<batch_size, 128, 0, st>>>(param);\n            }\n        });\n}\n\nvoid invokeIndexedCopy(void**       h_src_ptr,\n                       void**       h_dst_ptr,\n                       const int*   h_elem_sz,\n                       const int*   h_src_idx,\n                       const int*   h_dst_idx,\n                       int          count,\n                       int          n_copys,\n                       cudaStream_t st)\n{\n    auto success = dispatch(std::integer_sequence<int, 1, 2, 3, 4>{}, [&](auto N) {\n        if (N == n_copys) {\n            invokeIndexedCopyImpl<uint32_t, N>(h_src_ptr, h_dst_ptr, h_elem_sz, h_src_idx, h_dst_idx, count, st);\n            return true;\n        }\n        return false;\n    });\n    FT_CHECK(success);\n}\n\n__global__ void padLastTokenIds(int* token_ids, const int* context_length, int max_context_len, int batch_size)\n{\n    for (int bi = threadIdx.x; bi < batch_size; bi += blockDim.x) {\n        token_ids[(max_context_len - 1) * batch_size + bi] = token_ids[(context_length[bi] - 1) * batch_size + bi];\n    }\n}\n\nvoid invokePadLastTokenIds(\n    int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream)\n{\n    padLastTokenIds<<<1, 512, 0, stream>>>(token_ids, context_length, max_context_len, batch_size);\n}\n\ntemplate<typename T>\n__global__ void getFeatureOfLastToken(T* output, const T* input, const int* cu_seqlens, int dims)\n{\n    int bi = blockIdx.x;\n    int ti = cu_seqlens[bi + 1] - 1;\n    for (int i = threadIdx.x; i < dims; i += blockDim.x) {\n        output[dims * bi + i] = 
input[dims * ti + i];\n    }\n}\n\nvoid invokeGetFeatureOfLastToken(\n    uint16_t* output, const uint16_t* input, const int* cu_seqlens, int dims, int batch_size, cudaStream_t stream)\n{\n    getFeatureOfLastToken<<<batch_size, 256, 0, stream>>>(output, input, cu_seqlens, dims);\n}\n\ntemplate<class T, int C>\nstruct BatchedCopyParam {\n    Array<T*, C>  src_ptr;\n    Array<T*, C>  dst_ptr;\n    Array<int, C> size;\n    int           count;\n};\n\ntemplate<int kThrPerCpy, class T, int C>\n__global__ void batchedCopy(BatchedCopyParam<T, C> param)\n{\n    const int ti = threadIdx.x + blockIdx.x * blockDim.x;\n    const int bi = ti / kThrPerCpy;\n    if (bi >= param.count) {\n        return;\n    }\n    const T* __restrict__ src = param.src_ptr[bi];\n    T* __restrict__ dst       = param.dst_ptr[bi];\n    int size                  = param.size[bi];\n    for (int i = ti % kThrPerCpy; i < size; i += kThrPerCpy) {\n        dst[i] = src[i];\n    }\n}\n\n// MSVC does not like CUDA kernel launch inside nested lambdas\ntemplate<class P>\nstruct BatchedCopyLauncher {\n    int          max_size;\n    int          count;\n    const P*     params;\n    cudaStream_t st;\n\n    template<int S>\n    void operator()(std::integral_constant<int, S>) const\n    {\n        constexpr int threads         = 128;\n        constexpr int items_per_block = threads / S;\n        const int     blocks          = (count + items_per_block - 1) / items_per_block;\n        batchedCopy<S><<<blocks, threads, 0, st>>>(*params);\n    }\n};\n\nvoid invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, cudaStream_t st)\n{\n    dispatch(\n        std::integer_sequence<int, 1, 8, 32, 128>{},\n        [&](auto C) { return count <= C; },\n        [&](auto C) {\n            using T = uint32_t;\n            BatchedCopyParam<T, C> params{};\n            // TODO: on CUDA 12.1 and sm_70+ this can be 32K\n            static_assert(sizeof(params) <= 4096);\n            for (int c = 0; c < count; c 
+= C) {\n                const int bsz = std::min<int>(count - c, C);\n                params.count  = bsz;\n                for (int i = 0; i < bsz; ++i) {\n                    params.src_ptr[i] = (T*)src_ptr[c + i];\n                    params.dst_ptr[i] = (T*)dst_ptr[c + i];\n                    FT_CHECK(size[c + i] % sizeof(T) == 0);\n                    params.size[i] = size[c + i] / sizeof(T);\n                }\n                const int max_size = *std::max_element(params.size.begin(), params.size.end());\n                dispatch(\n                    std::integer_sequence<int, 1, 2, 4, 8, 16, 32, 64, 128>{},\n                    [&](auto S) { return max_size <= S; },\n                    BatchedCopyLauncher<BatchedCopyParam<T, C>>{max_size, count, &params, st});\n            }\n        });\n}\n\ntemplate<typename T>\n__global__ void maskOutput(T* output, const int* mask, int dim)\n{\n    int batch_idx = blockIdx.x;\n    output += dim * batch_idx;\n    int masked = mask[batch_idx];\n    for (int i = threadIdx.x; i < dim; i += blockDim.x) {\n        output[i] = (masked) ? 
output[i] : T();\n    }\n}\n\ntemplate<typename T>\nvoid invokeMask(T* output, const int* mask, int batch_size, int dim, cudaStream_t stream)\n{\n    maskOutput<<<batch_size, 1024, 0, stream>>>(output, mask, dim);\n}\n\n#ifdef ENABLE_FP32\ntemplate void invokeMask(float* output, const int* mask, int batch_size, int dim, cudaStream_t stream);\n#endif\ntemplate void invokeMask(half* output, const int* mask, int batch_size, int dim, cudaStream_t stream);\n#ifdef ENABLE_BF16\ntemplate void invokeMask(__nv_bfloat16* output, const int* mask, int batch_size, int dim, cudaStream_t stream);\n#endif\n\ntemplate<typename T, int vec_size>\n__global__ void castFloat2D(const T* input, float* output, int channels)\n{\n    const int vi = blockIdx.x * blockDim.x + threadIdx.x;\n    const int bi = blockIdx.y;\n    input += (size_t)bi * channels;\n    output += (size_t)bi * channels;\n\n    const int step = gridDim.x * blockDim.x * vec_size;\n\n    for (int i = vi * vec_size; i < channels; i += step) {\n        Array<T, vec_size> src;\n\n        if constexpr (sizeof(src) >= sizeof(uint)) {\n            Load(src, input + i);\n        }\n        else {\n            PRAGMA_UNROLL\n            for (int j = 0; j < vec_size; ++j) {\n                src[j] = input[i + j];\n            }\n        }\n\n        auto dst = cast<float>(src);\n\n        // store\n        Store(output + i, dst);\n    }\n}\n\nvoid invokeCastFloat2D(const core::Tensor& src, core::Tensor& dst, cudaStream_t stream)\n{\n    TM_CHECK(src.is_contiguous());\n    TM_CHECK(dst.is_contiguous());\n    TM_CHECK(src.shape() == dst.shape());\n\n    auto batch_size = src.shape(0);\n    auto channels   = src.shape(1);\n\n    auto invoke = [&](auto t, auto vec_size) {\n        using T                      = decltype(t);\n        constexpr int threads        = 256;\n        const int     blocks_per_tok = (channels + threads * vec_size - 1) / (threads * vec_size);\n        const dim3    blocks(blocks_per_tok, batch_size);\n        
castFloat2D<T, vec_size.value><<<blocks, threads, 0, stream>>>(  //\n            src.data<T>(),\n            dst.data<float>(),\n            channels);\n    };\n\n    auto dispatch_t = [&](auto vec_size) {\n        switch (src.dtype()) {\n            case kFloat32:\n                return invoke(float{}, vec_size);\n                break;\n            case kFloat16:\n                return invoke(half{}, vec_size);\n                break;\n#ifdef ENABLE_BF16\n            case kBfloat16:\n                return invoke(__nv_bfloat16{}, vec_size);\n                break;\n#endif\n            default:\n                TM_UNREACHABLE;\n        }\n    };\n\n    if (channels % 4 == 0) {\n        return dispatch_t(std::integral_constant<int, 4>{});\n    }\n    else if (channels % 2 == 0) {\n        return dispatch_t(std::integral_constant<int, 2>{});\n    }\n    else {\n        return dispatch_t(std::integral_constant<int, 1>{});\n    }\n}\n\ntemplate<class T>\n__global__ void CollectHiddenStates_Kernel(const T* src, const int* idxs, T* dst, int dim)\n{\n    const int bi = blockIdx.x;\n    const int ti = idxs[bi];\n\n    if (ti < 0) {\n        return;\n    }\n\n    src += ti * dim;\n    dst += bi * dim;\n\n    for (int di = threadIdx.x; di < dim; di += blockDim.x) {\n        dst[di] = src[di];\n    }\n}\n\nvoid CollectHiddenStates(const Tensor& src, const Buffer_<int>& idxs, Ref<Tensor> dst, cudaStream_t st)\n{\n    const auto stride = byte_size(src.dtype(), src.stride(0));\n\n    auto invoke = [&](auto t) {\n        using T           = decltype(t);\n        const int dim     = stride / sizeof(T);\n        const int threads = round_up(min(dim, 1024), WARP_SIZE);\n        const int blocks  = idxs.size();\n        CollectHiddenStates_Kernel<<<blocks, threads, 0, st>>>(\n            (const T*)src.raw_data(), idxs.data(), (T*)dst.get().raw_data(), dim);\n    };\n\n    if (stride % sizeof(uint4) == 0) {\n        invoke(uint4{});\n    }\n    else if (stride % sizeof(uint2) == 0) 
{\n        invoke(uint2{});\n    }\n    else if (stride % sizeof(uint1) == 0) {\n        invoke(uint1{});\n    }\n    else if (stride % sizeof(ushort) == 0) {\n        invoke(ushort{});\n    }\n    else {\n        TM_CHECK(0) << \"unsupported byte stride: \" << stride;\n    }\n}\n\ntemplate<int BLOCK_DIM, int MAX_COUNT>\n__global__ void\nBatchPrefixSumKernel(Array<const int*, MAX_COUNT> srcs, Array<int, MAX_COUNT> ns, Array<int*, MAX_COUNT> dsts)\n{\n    const int  bi  = blockIdx.x;\n    const int* src = srcs[bi];\n    int*       dst = dsts[bi];\n    const int  n   = ns[bi];\n\n    using BlockScan = cub::BlockScan<int, BLOCK_DIM>;\n\n    __shared__ typename BlockScan::TempStorage temp_storage;\n\n    int prefix{};\n    for (int i = threadIdx.x; i < round_up(n, BLOCK_DIM); i += BLOCK_DIM) {\n        if (i >= BLOCK_DIM) {\n            __syncthreads();\n        }\n        int data = i < n ? src[i] : 0;\n        int sum{};\n        BlockScan{temp_storage}.ExclusiveSum(data, data, sum);\n        if (i < n) {\n            dst[i] = prefix + data;\n        }\n        prefix += sum;\n    }\n\n    if (threadIdx.x == 0) {\n        dst[n] = prefix;\n    }\n}\n\nvoid BatchPrefixSum(const int** srcs, const int* ns, int** dsts, int count, cudaStream_t st)\n{\n    constexpr int max_count = 1;\n\n    // validate before filling the fixed-size arrays to avoid out-of-bounds writes\n    TM_CHECK_LE(count, max_count);\n\n    Array<const int*, max_count> p_srcs{};\n    Array<int*, max_count>       p_dsts{};\n    Array<int, max_count>        p_ns{};\n\n    for (int i = 0; i < count; ++i) {\n        p_srcs[i] = srcs[i];\n        p_dsts[i] = dsts[i];\n        p_ns[i]   = ns[i];\n    }\n\n    constexpr int block = 256;\n    const int     grid  = count;\n\n    BatchPrefixSumKernel<block><<<grid, block, 0, st>>>(p_srcs, p_ns, p_dsts);\n}\n\n__global__ void AppendTokenIdsKernel(int** token_ids_ptrs, const int* output_ids, const int* positions, int batch_size)\n{\n    int i = threadIdx.x + blockIdx.x * blockDim.x;\n    if (i < batch_size) {\n        int* token_ids = 
token_ids_ptrs[i];\n        int  pos       = positions[i];\n        token_ids[pos] = output_ids[i];\n    }\n}\n\nvoid AppendTokenIds(\n    int** token_ids_ptrs, const int* output_ids, const int* positions, int batch_size, cudaStream_t stream)\n{\n    constexpr int block = 128;\n    const int     grid  = cdiv(batch_size, block);\n    AppendTokenIdsKernel<<<grid, block, 0, stream>>>(token_ids_ptrs, output_ids, positions, batch_size);\n}\n\ntemplate<typename T>\n__global__ void SigmoidGateMultiplyKernel(T* attn, const T* gate_base, int dim, int gate_stride, int num_tokens)\n{\n    const int ti = blockIdx.x;\n    const int di = threadIdx.x + blockIdx.y * blockDim.x;\n    if (ti >= num_tokens || di >= dim) {\n        return;\n    }\n    float g             = (float)gate_base[ti * gate_stride + di];\n    float s             = 1.0f / (1.0f + __expf(-g));\n    float a             = (float)attn[ti * dim + di];\n    attn[ti * dim + di] = (T)(a * s);\n}\n\nvoid invokeSigmoidGateMultiply(\n    void* attn, const void* gate_base, int dim, int gate_stride, int num_tokens, DataType dtype, cudaStream_t stream)\n{\n    constexpr int block = 256;\n    const dim3    grid(num_tokens, cdiv(dim, block));\n\n    auto invoke = [&](auto t) {\n        using T = decltype(t);\n        SigmoidGateMultiplyKernel<<<grid, block, 0, stream>>>(\n            (T*)attn, (const T*)gate_base, dim, gate_stride, num_tokens);\n    };\n\n    TM_DISPATCH_PRIMARY_DTYPES(dtype, invoke);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/llama_kernels.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/core/core.h\"\n\n#include <cstdint>\n\n#include <cuda_runtime.h>\nnamespace turbomind {\n\nvoid invokeGatherOutput(int*         output_ids,\n                        const int*   ids,\n                        const int*   context_length,\n                        int          max_context_len,\n                        int          max_gen_step,\n                        int          max_output_len,\n                        int          batch_size,\n                        cudaStream_t stream);\n\nvoid invokeUpdateOutput(int**        request_output_ids_ptrs,\n                        int**        request_seqlen_ptrs,\n                        const int*   output_ids,\n                        const int*   sequence_lengths,\n                        const int*   request_output_ids_lens,\n                        int          max_session_len,\n                        bool         token_generated,\n                        int          batch_size,\n                        cudaStream_t stream);\n\n// [aaa, bbbb, cc, ddd] -> [aaabbbbccddd]\nvoid invokeCompactOutputIds(int*         cu_output_ids,\n                            const int*   output_ids,\n                            const int*   sequence_lengths,\n                            int          max_session_len,\n                            bool         token_generated,\n                            int          batch_size,\n                            cudaStream_t stream);\n\nvoid invokeIndexedCopy(void**       h_src_ptr,\n                       void**       h_dst_ptr,\n                       const int*   h_elem_sz,\n                       const int*   h_src_idx,\n                       const int*   h_dst_idx,\n                       int          count,\n                       int          n_copys,\n                       cudaStream_t st);\n\nvoid invokeBatchedCopy(void** src_ptr, void** dst_ptr, int* size, int count, 
cudaStream_t st);\n\n// ABCDe            ABCDe     e\n// ABCDEFGHIJk      ABCDEFGHIJk\n// ABCDEFGHi    ->  ABCDEFGHi i\n// ABCDEFGh         ABCDEFGh  h\n// ABCd             ABCd      d\nvoid invokePadLastTokenIds(\n    int* token_ids, const int* context_length, int max_context_len, int batch_size, cudaStream_t stream);\n\nvoid invokeGetFeatureOfLastToken(\n    uint16_t* output, const uint16_t* input, const int* cu_seqlens, int dims, int batch_size, cudaStream_t stream);\n\ntemplate<typename T>\nvoid invokeMask(T* output, const int* mask, int batch_size, int dim, cudaStream_t stream);\n\nvoid invokeCastFloat2D(const core::Tensor& src, core::Tensor& dst, cudaStream_t stream);\n\nvoid CollectHiddenStates(const Tensor& src, const Buffer_<int>& idxs, Ref<Tensor> dst, cudaStream_t st);\n\nvoid BatchPrefixSum(const int** srcs, const int* ns, int** dsts, int count, cudaStream_t st);\n\ninline void PrefixSum(const int* src, int n, int* dst, cudaStream_t st)\n{\n    return BatchPrefixSum(&src, &n, &dst, 1, st);\n}\n\nvoid AppendTokenIds(int**        token_ids_ptrs,  //\n                    const int*   output_ids,\n                    const int*   positions,\n                    int          batch_size,\n                    cudaStream_t stream);\n\n// Apply sigmoid gating: attn[i] *= sigmoid(gate[i])\n// attn:        [num_tokens, dim], contiguous\n// gate_base:   pointer to first gate element in QKV buffer\n// gate_stride: stride between tokens in QKV buffer (elements)\nvoid invokeSigmoidGateMultiply(\n    void* attn, const void* gate_base, int dim, int gate_stride, int num_tokens, DataType dtype, cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/llama_params.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cstddef>\n#include <map>\n#include <regex>\n#include <set>\n#include <string>\n#include <vector>\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/kernels/activation.h\"\n#include \"src/turbomind/models/llama/llama_rope.h\"\n\nnamespace turbomind {\n\nstruct MLAParam {\n    int q_lora_rank;\n    int kv_lora_rank;\n    int qk_rope_dim;\n    int v_head_dim;\n};\n\nstruct ModelParam {\n    size_t   head_num;\n    size_t   head_dim;\n    size_t   kv_head_num;\n    size_t   hidden_units;\n    size_t   layer_num;\n    size_t   vocab_size;\n    size_t   embedding_size;\n    float    norm_eps;\n    int      quant_policy;\n    bool     attn_bias;\n    bool     attn_sink;\n    bool     mlp_bias;\n    DataType data_type;\n\n    // Weight types for mixed quantization support.\n    // Models like mixed AWQ (e.g. QuantTrio GLM-4.7-Flash) quantize FFN/expert\n    // weights to int4 but keep attention weights as fp16. 
GptOss mxfp4 quantizes\n    // only MoE experts to e2m1 while keeping attention and shared experts as fp16.\n    //\n    //                  weight_type   ffn_weight_type   expert_weight_type\n    //  Pure fp16       float16       float16           float16\n    //  Full AWQ        int4          int4              int4\n    //  Mixed AWQ       float16       int4              int4\n    //  GptOss mxfp4    bfloat16      bfloat16          e2m1\n    DataType weight_type;         // attention weights\n    DataType expert_weight_type;  // MoE routed expert weights\n    DataType ffn_weight_type;     // dense FFN / shared expert weights\n\n    int      group_size;\n    MLAParam mla;\n    bool     qk_norm;\n    int      tune_layer_num;\n\n    ActivationType act_type;\n\n    std::vector<int> window_size;\n    std::vector<int> inter_size;\n    std::vector<int> layer_types;\n\n    // Qwen3.5 Gated DeltaNet linear attention params\n    int linear_key_head_dim    = 0;\n    int linear_value_head_dim  = 0;\n    int linear_conv_kernel_dim = 0;\n    int linear_num_key_heads   = 0;\n    int linear_num_value_heads = 0;\n\n    DataType linear_state_dtype = {};\n\n    bool attn_output_gate = false;  // Qwen3.5: doubles Q projection in full-attention layers\n\n    // Layer indices whose MoE experts use data_type (fp16) instead of\n    // expert_weight_type (e.g. int4).  
Populated from modules_to_not_convert\n    // patterns like 'model.layers.0.'.\n    std::set<int> unquantized_expert_layers;\n};\n\ninline bool HasLinearAttention(const ModelParam& model_param)\n{\n    for (int type : model_param.layer_types) {\n        if (type == 1) {\n            return true;\n        }\n    }\n    return false;\n}\n\n/// TODO: rename all `gate` in the context of MoE router to `router`\nstruct MoeParam {\n    enum Method\n    {\n        kNaive,\n        kFused\n    } method;\n\n    int   experts_per_token;\n    int   inter_size;\n    bool  norm_topk_prob;\n    bool  shared_gate;\n    float routed_scale;\n\n    bool router_bias;\n\n    int         topk_group;\n    std::string topk_method;\n    int         n_group;\n    std::string scoring_func;\n    int         router_n_groups;\n\n    std::vector<int> expert_num;\n};\n\nstruct AttentionParam {\n    float softmax_scale;\n    int   cache_block_seq_len;\n    // logn attention\n    bool use_logn_attn;\n    int  max_position_embeddings;\n    // rotary embedding\n    RopeParam rope;\n};\n\nstruct EngineParam {\n    // batch params\n    int max_batch_size;\n    int session_len;\n    int step_length;\n\n    // cache params\n    float cache_max_block_count;\n    int   cache_chunk_size;\n    bool  enable_prefix_caching;\n    bool  enable_metrics;\n\n    // chunking params\n    int max_forward_token_num;\n    int max_context_token_num;\n    int num_tokens_per_iter;\n    int max_prefill_iters;\n\n    // parallel params\n    int outer_dp_size;\n    int outer_dp_rank;\n    int attn_dp_size;\n    int attn_dp_rank;\n    int attn_tp_size;\n    int attn_tp_rank;\n    int attn_cp_size;\n    int attn_cp_rank;\n    int mlp_tp_size;\n    int mlp_tp_rank;\n\n    // multi-node\n    int nnodes;\n    int node_rank;\n\n    std::vector<int> devices;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/llama_rope.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cmath>\n#include <map>\n#include <string>\n\n#include <cuda_runtime.h>\n\nnamespace turbomind {\n\nenum class RopeType\n{\n    kNull,\n    kDefault,\n    kLinear,\n    kDynamic,\n    kYarn,\n    kLlama3,\n    kMrope,\n};\n\ninline RopeType GetRoPEType(const std::string& type)\n{\n    std::map<std::string, RopeType> lookup = {{\"default\", RopeType::kDefault},\n                                              {\"linear\", RopeType::kLinear},\n                                              {\"dynamic\", RopeType::kDynamic},\n                                              {\"yarn\", RopeType::kYarn},\n                                              {\"llama3\", RopeType::kLlama3},\n                                              {\"mrope\", RopeType::kMrope}};\n    return lookup.at(type);\n}\n\nstruct YarnRopeParam {\n    float attention_factor;\n    float beta_fast;\n    float beta_slow;\n};\n\nstruct Llama3RopeParam {\n    float low_freq_factor;\n    float high_freq_factor;\n    int   original_max_position_embeddings;\n};\n\nstruct MropeRopeParam {\n    int3 section;\n};\n\nstruct RopeParam {\n    RopeType type;\n    // common\n    float base;\n    int   dim;\n    float factor;\n    int   max_position_embeddings;\n    // unique\n    union {\n        YarnRopeParam   yarn;\n        Llama3RopeParam llama3;\n        MropeRopeParam  mrope;\n    };\n};\n\nstruct YarnRopeKernelParam {\n    float scale_factor;\n    float attention_factor;\n    float ramp_inv_factor_div_2;\n    float ramp_inv_factor_mul_min;\n};\n\nstruct Llama3RopeKernelParam {\n    float scale_factor;\n    float alpha;\n    float beta;\n};\n\nstruct MropeRopeKernelParam {\n    int3 section;\n\n    int  stride{};\n    int* position_ids{};\n    int* position_delta{};\n    int* length{};\n};\n\nstruct RopeKernelParam {\n    RopeType type;\n\n    float* base{};  // for dynamic ntk\n    int    dim;\n    float  scale_factor;\n   
 float  inv_factor;\n\n    YarnRopeKernelParam   yarn;\n    Llama3RopeKernelParam llama3;\n    MropeRopeKernelParam  mrope;\n};\n\ninline void init_rope_kernel_param(const RopeParam& rope, RopeKernelParam& rope_kernel)\n{\n    rope_kernel.type         = rope.type;\n    rope_kernel.dim          = rope.dim;\n    rope_kernel.scale_factor = -std::log2(rope.base) / rope.dim;\n    if (rope.type == RopeType::kDynamic) {\n        rope_kernel.inv_factor = 1.f;\n    }\n    else {\n        rope_kernel.inv_factor = (rope.factor != 0.f) ? 1.0 / rope.factor : 1.f;\n    }\n\n    if (rope.type == RopeType::kYarn) {\n        auto&        src = rope.yarn;\n        auto&        dst = rope_kernel.yarn;\n        const double PI  = 3.14159265358979323846;\n\n        auto find_correction_dim = [&](float num_rotations) {\n            return (rope.dim * std::log(rope.max_position_embeddings / (num_rotations * 2 * PI)))\n                   / (2 * std::log(rope.base));\n        };\n\n        auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {\n            low  = std::floor(find_correction_dim(low_rot));\n            high = std::ceil(find_correction_dim(high_rot));\n            low  = std::max(low, 0.f);\n            high = std::min(high, rope.dim - 1.f);\n        };\n\n        float low, high;\n        find_correction_range(src.beta_fast, src.beta_slow, low, high);\n        // https://github.com/huggingface/transformers/blob/6c3f168b36882f0beebaa9121eafa1928ba29633/src/transformers/modeling_rope_utils.py#L216\n        if (low == high) {\n            high += 0.001f;\n        }\n        dst.ramp_inv_factor_div_2   = 1.0 / (high - low) / 2.0;\n        dst.ramp_inv_factor_mul_min = 1.0 / (high - low) * low;\n        dst.attention_factor        = src.attention_factor;\n    }\n    else if (rope.type == RopeType::kLlama3) {\n        auto& src = rope.llama3;\n        auto& dst = rope_kernel.llama3;\n\n        const double PI                   = 
3.14159265358979323846;\n        float        inv_diff_freq_factor = 1.0 / (src.high_freq_factor - src.low_freq_factor);\n        dst.alpha                         = src.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;\n        dst.beta                          = src.low_freq_factor * inv_diff_freq_factor;\n    }\n\n    else if (rope.type == RopeType::kMrope) {\n        auto& src     = rope.mrope;\n        auto& dst     = rope_kernel.mrope;\n        dst.section.x = src.section.x * 2;\n        dst.section.y = src.section.y * 2 + dst.section.x;\n        dst.section.z = src.section.z * 2 + dst.section.y;\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/llama_utils.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <cmath>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <fstream>\n#include <iostream>\n#include <string>\n#include <type_traits>\n#include <vector>\n\n#include <cuda_fp16.h>\n#include <curand_kernel.h>\n#include <thrust/device_vector.h>\n#include <thrust/execution_policy.h>\n#include <thrust/host_vector.h>\n\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\nnamespace turbomind {\n\nCmpMode compare_mode = kCmpRead;\n// CmpMode compare_mode = kCmpWrite;\n\ntemplate<typename T>\nvoid CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream)\n{\n    std::vector<T> h_data(size);\n    check_cuda_error(cudaMemcpyAsync(h_data.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));\n\n    check_cuda_error(cudaStreamSynchronize(stream));\n\n    size_t nan_cnt = 0;\n    for (const auto& x : h_data) {\n        nan_cnt += std::isnan(static_cast<float>(x));\n    }\n    if (nan_cnt) {\n        std::cerr << key << \": NaN count \" << nan_cnt << \"\\n\";\n    }\n}\n\ntemplate<typename T>\nvoid CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)\n{\n    // read reference data `a` from file\n    std::vector<T> h_a(size);\n    {\n        const auto    filename = \"tmp/\" + key + \".cmp\";\n        std::ifstream ifs(filename, std::ios::binary);\n        if (!ifs.is_open()) {\n            std::cerr << key << \": failed to open \" + filename << \"\\n\";\n            return;\n        }\n        ifs.seekg(0, ifs.end);\n        const auto actual_size_in_bytes = ifs.tellg();\n        ifs.seekg(0, ifs.beg);\n        const auto expect_size_in_bytes = sizeof(T) * size;\n        if (actual_size_in_bytes != expect_size_in_bytes) {\n            std::cerr << key << \": file size in bytes mismatch, expect \" << expect_size_in_bytes << \", got \"\n                      << actual_size_in_bytes << \"\\n\";\n            return;\n        }\n        ifs.read((char*)h_a.data(), sizeof(T) * h_a.size());\n    }\n    
std::vector<T> h_b(size);\n    check_cuda_error(cudaMemcpyAsync(h_b.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));\n    check_cuda_error(cudaStreamSynchronize(stream));\n\n    using Tacc         = std::conditional_t<std::is_integral_v<T>, int64_t, float>;\n    constexpr Tacc eps = std::is_integral_v<T> ? 1 : 1e-8f;\n\n    Tacc asum{};\n    Tacc rsum{};\n    Tacc amean_r{};\n    Tacc amean_x{};\n    for (size_t i = 0; i < size; ++i) {\n        Tacc x        = (Tacc)h_b[i];\n        Tacc r        = (Tacc)h_a[i];\n        Tacc abs_diff = std::abs(x - r);\n        Tacc rel_diff = abs_diff / std::max(std::max(std::abs(r), std::abs(x)), eps);\n        asum += abs_diff;\n        rsum += rel_diff;\n        amean_x += std::abs(x);\n        amean_r += std::abs(r);\n    }\n\n    fprintf(stderr,\n            \"%15s%15f%15f%15f%15f%15f\\n\",\n            key.c_str(),\n            (float)amean_x / (float)size,\n            (float)amean_r / (float)size,\n            (float)asum,\n            (float)asum / (float)size,\n            (float)rsum / (float)size);\n\n    check_cuda_error(cudaMemcpyAsync(ptr, h_a.data(), sizeof(T) * h_a.size(), cudaMemcpyDefault, stream));\n    check_cuda_error(cudaStreamSynchronize(stream));\n}\n\ntemplate<typename T>\nvoid CmpWrite(T* ptr, size_t size, std::string key, cudaStream_t stream)\n{\n    std::vector<T> a(size);\n    // copy a to host\n    check_cuda_error(cudaMemcpyAsync(a.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));\n    check_cuda_error(cudaStreamSynchronize(stream));\n    // write to file\n    {\n        std::ofstream ofs(\"tmp/\" + key + \".cmp\", std::ios::binary);\n        ofs.write((char*)a.data(), sizeof(T) * a.size());\n    }\n}\n\ntemplate<typename T>\nvoid Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream)\n{\n    // std::cerr << \"Comparing \" << key << \"\\n\";\n    if (mode == kCmpRead) {\n        CmpRead(ptr, size, key, stream);\n    }\n    else if (mode == kCmpWrite) 
{\n        CmpWrite(ptr, size, key, stream);\n    }\n    else {\n        // kCmpNone\n    }\n}\n\ntemplate void Compare(int* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);\ntemplate void Compare(float* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);\ntemplate void Compare(half* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);\ntemplate void Compare(__nv_bfloat16* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);\n\ntemplate void CheckNan(const float* ptr, size_t size, std::string key, cudaStream_t stream);\ntemplate void CheckNan(const half* ptr, size_t size, std::string key, cudaStream_t stream);\n\nsize_t curandStateGetSize()\n{\n    return sizeof(curandState_t);\n}\n\nbool isDebug()\n{\n    static const bool is_debug = [] {\n        const auto level = std::getenv(\"TM_DEBUG_LEVEL\");\n        if (level && level == std::string(\"DEBUG\")) {\n            return true;\n        }\n        return false;\n    }();\n    return is_debug;\n}\n\nint64_t& gSequenceIds(int batch_idx)\n{\n    thread_local std::vector<int64_t> ids{};\n    if (batch_idx >= ids.size()) {\n        ids.resize(batch_idx + 1, -1);\n    }\n    return ids.at(batch_idx);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/llama_utils.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n#include \"src/turbomind/utils/nvtx_utils.h\"\n#include <cuda_runtime.h>\n#include <sstream>\n#include <string>\n#include <vector>\n\nnamespace turbomind {\n\nenum QuantPolicy\n{\n    kNone = 0x00,\n    // reserve 0x01 and 0x02 for backward compatibility\n    kReserve1 = 0x01,\n    kReserve2 = 0x02,\n    // quantize cache kv\n    kCacheKVInt8 = 0x08,\n    kCacheKVInt4 = 0x04,\n};\n\nenum CmpMode\n{\n    kCmpNone,\n    kCmpRead,\n    kCmpWrite,\n};\n\nextern CmpMode compare_mode;\n\ntemplate<typename T>\nvoid Compare(T* ptr, size_t size, std::string key, CmpMode mode, cudaStream_t stream);\n\ntemplate<typename T>\nvoid CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream);\n\nnamespace detail {\n\ntemplate<typename T>\nstd::string to_string(T x)\n{\n    return std::to_string(x);\n}\n\ninline std::string to_string(std::string x)\n{\n    return x;\n}\n\n}  // namespace detail\n\ntemplate<typename... Args>\nstd::string Concat(std::string key, Args&&... args)\n{\n    std::vector<std::string> args_str{detail::to_string((Args &&) args)...};\n    for (const auto& s : args_str) {\n        key.append(\"_\");\n        key.append(s);\n    }\n    return key;\n}\n\nsize_t curandStateGetSize();\n\nbool isDebug();\n\nstruct NvtxScope {\n    explicit NvtxScope(const std::string& name)\n    {\n        PUSH_RANGE(name.c_str());\n    }\n\n    ~NvtxScope()\n    {\n        POP_RANGE;\n    }\n};\n\nint64_t& gSequenceIds(int batch_idx);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/mla_utils.cu",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cuda_bf16.h>\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/kernels/core/array_ops.h\"\n#include \"src/turbomind/kernels/core/common.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n\nnamespace turbomind {\n\ntemplate<class T, int vec_size>\n__global__ void mla_copy_qkv_kernel(T*       qkv,        // [s, head_num + 1, kv_lora_rank + rope_dim]\n                                    const T* q,          // [s, head_num,     kv_lora_rank + rope_dim]\n                                    const T* kv_a_k_pe,  // [s, kv_lora_rank + rope_dim]\n                                    int      head_num,   // q head num\n                                    int      head_dim,   // kv_lora_rank + rope_dim\n                                    int      kv_lora_rank,\n                                    int      rope_dim)\n{\n    const int type = blockIdx.y;\n\n    const int64_t ti = blockIdx.x;\n    const int     di = threadIdx.x;\n\n    const int offset = di * vec_size < rope_dim ? 
kv_lora_rank : -rope_dim;\n\n    Array<T, vec_size> data;\n\n    if (type == 0) {  // Q\n        for (int hi = threadIdx.y; hi < head_num; hi += blockDim.y) {\n            if (di * vec_size < head_dim) {\n                Load(data, &q[ti * head_num * head_dim + hi * head_dim + di * vec_size + offset]);\n                Store(&qkv[ti * (head_num + 1) * head_dim + hi * head_dim + di * vec_size], data);\n            }\n        }\n    }\n    else if (type == 1) {  // K/V\n        if (threadIdx.y == 0) {\n            if (di * vec_size < head_dim) {\n                Ldg(data, &kv_a_k_pe[ti * head_dim + di * vec_size + offset]);\n                Store(&qkv[ti * (head_num + 1) * head_dim + (head_num + 0) * head_dim + di * vec_size], data);\n            }\n        }\n    }\n}\n\ntemplate<class T>\nvoid invokeMLACopyQKV(T*           qkv,\n                      const T*     q,\n                      const T*     kv_a_k_pe,\n                      int          token_num,\n                      int          head_num,\n                      int          kv_lora_rank,\n                      int          rope_dim,\n                      cudaStream_t stream)\n{\n    constexpr int vec_size = 16 / sizeof(T);\n\n    const int head_dim = kv_lora_rank + rope_dim;  // 512 + 64 = 576\n\n    dim3 block(round_up(head_dim / vec_size, WARP_SIZE), head_num);\n\n    // make sure block size <= 1024\n    while (block.x * block.y > 1024) {\n        block.y /= 2;\n    }\n\n    const dim3 grid(token_num, 2);\n\n    mla_copy_qkv_kernel<T, vec_size>\n        <<<grid, block, 0, stream>>>(qkv, q, kv_a_k_pe, head_num, head_dim, kv_lora_rank, rope_dim);\n}\n\nvoid MLACopyQKV(DataType     dtype,\n                void*        qkv,\n                const void*  q,\n                const void*  kv_a_k_pe,\n                int          token_num,\n                int          head_num,\n                int          kv_lora_rank,\n                int          rope_dim,\n                cudaStream_t stream)\n{\n  
  auto invoke = [&](auto t) {\n        using T = decltype(t);\n        invokeMLACopyQKV(\n            (T*)qkv, (const T*)q, (const T*)kv_a_k_pe, token_num, head_num, kv_lora_rank, rope_dim, stream);\n    };\n\n    TM_CHECK_EQ(byte_size(dtype, 1), 2) << \"unsupported data type: \" << dtype;\n\n    return invoke(uint16_t{});\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/mla_utils.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n#pragma once\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/data_type.h\"\n\nnamespace turbomind {\n\nvoid MLACopyQKV(DataType     dtype,\n                void*        qkv,\n                const void*  q,\n                const void*  kv_a,\n                int          token_num,\n                int          head_num,\n                int          kv_lora_rank,\n                int          rope_dim,\n                cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/moe_ffn_layer.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/kernels/activation.h\"\n#include \"src/turbomind/kernels/norm/rms_norm.h\"\n\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/LlamaLinear.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/models/llama/moe_ffn_layer.h\"\n\n#include \"src/turbomind/utils/anomaly_handler.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nMoeFfnLayer::MoeFfnLayer(const ModelParam& model, const MoeParam& param, const EngineParam& engine, const Context& ctx):\n    inter_size_(param.inter_size / engine.mlp_tp_size),\n    hidden_dim_(model.hidden_units),\n    tp_size_(engine.mlp_tp_size),\n    param_(param),\n    is_warm_up_{*ctx.is_warm_up},\n    linear_(*ctx.linear)\n{\n    TM_CHECK(!param.expert_num.empty());\n\n    const int max_expert_num = *std::max_element(param.expert_num.begin(), param.expert_num.end());\n\n    if (param_.method == MoeParam::kFused) {\n        // pass\n    }\n    else {\n        expert_ffn_ = std::make_unique<LlamaFfnLayer>(model, ctx);\n    }\n\n    h_offsets_ = {max_expert_num + 1, kCPUpinned};\n\n    const int max_token_num = engine.max_forward_token_num * engine.attn_dp_size;\n    const int pad_token_num = (max_token_num + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize;\n\n    // dbg(inter_size_,\n    //     hidden_dim_,\n    //     tp_size_,\n    //     param_.method,\n    //     param.expert_num,\n    //     max_expert_num,\n    //     max_token_num,\n    //     pad_token_num,\n    //     param_.experts_per_token);\n\n    masks_   = {max_expert_num * pad_token_num, kDEVICE};\n    f2n_     = {param_.experts_per_token * max_token_num, kDEVICE};\n    f2E_     = {param_.experts_per_token * 
max_token_num, kDEVICE};\n    en2f_    = {param_.experts_per_token * max_token_num, kDEVICE};\n    scales_  = {param_.experts_per_token * max_token_num, kDEVICE};\n    offsets_ = {max_expert_num + 1, kDEVICE};\n    accum_   = {max_expert_num * kMoeGateMaxTiles, kDEVICE};\n}\n\nTensor_<float> MoeFfnLayer::Gate(const Tensor& input, const LlamaDenseWeight& gate)\n{\n    auto& weight = gate.weight;\n    TM_CHECK_EQ(input.shape(1), weight.shape(0));\n    Tensor_<float> logits{{input.shape(0), weight.shape(1)}, kDEVICE};\n    linear_.Forward(input, gate, logits);\n    sync_check_cuda_error();\n    ApplyBias(logits, gate.bias, core::Context::stream().handle());\n    sync_check_cuda_error();\n    return logits;\n}\n\nvoid MoeFfnLayer::Forward(ForwardParam& p)\n{\n    const int   tokens = p.input.shape(0);\n    const auto& moe    = *p.weights;\n\n    const size_t padded     = (tokens + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize;\n    const int    expert_num = moe.experts.size();\n\n    FT_CHECK(expert_num);\n\n    auto logits = Gate(p.input, moe.gate);\n\n    TM_DEBUG_TENSOR(logits, \"logits\", 2);\n\n    const auto st = core::Context::stream().handle();\n\n    // dump_logits(tokens, layer_id);\n\n    if (param_.topk_method == \"noaux_tc\") {\n        // invokeMoeGate_NoAuxTC clears accum and masks internally\n        TM_CHECK_EQ(param_.n_group, 1);\n        TM_CHECK_EQ(param_.topk_group, 1);\n        const float* correction_bias =\n            (moe.score_correction_bias.size() > 0) ? 
moe.score_correction_bias.data<float>() : nullptr;\n        invokeMoeGate_NoAuxTC(f2n_.data(),\n                              f2E_.data(),\n                              en2f_.data(),\n                              offsets_.data(),\n                              scales_.data(),\n                              masks_.data(),\n                              accum_.data(),\n                              logits.data(),\n                              correction_bias,\n                              tokens,\n                              padded,\n                              expert_num,\n                              param_.experts_per_token,\n                              param_.norm_topk_prob,\n                              param_.routed_scale,\n                              param_.scoring_func == \"sigmoid\",\n                              st);\n    }\n    else {\n        // V2: accum must be cleared by caller; masks cleared internally\n        check_cuda_error(cudaMemsetAsync(accum_.data(), 0, sizeof(int) * expert_num * kMoeGateMaxTiles, st));\n\n        bool softmax = true;\n        if (param_.topk_method == \"group_limited_greedy\") {\n            invokeMoeSoftmaxMaskTopKGroups(\n                logits.data(), tokens, expert_num, expert_num / param_.n_group, param_.topk_group, st);\n            sync_check_cuda_error();\n            softmax = false;\n        }\n\n        /// TODO: fix illegal memory access even if NaN are present in logits\n        invokeMoeGate_V2(f2n_.data(),\n                         f2E_.data(),\n                         en2f_.data(),\n                         offsets_.data(),\n                         scales_.data(),\n                         masks_.data(),\n                         accum_.data(),\n                         logits.data(),\n                         tokens,\n                         padded,\n                         expert_num,\n                         param_.experts_per_token,\n                         softmax,\n                   
      param_.norm_topk_prob,\n                         param_.routed_scale,\n                         st);\n    }\n    sync_check_cuda_error();\n\n    if (is_warm_up_) {\n        std::mt19937     g;\n        const auto       expert_ids = SampleUniform(tokens, expert_num, param_.experts_per_token, g);\n        std::vector<int> cnt(expert_num);\n        for (const auto& x : expert_ids) {\n            ++cnt[x];\n        }\n        h_offsets_[0] = 0;\n        for (int i = 0; i < expert_num; ++i) {\n            h_offsets_[i + 1] = h_offsets_[i] + cnt[i];\n        }\n        check_cuda_error(\n            cudaMemcpyAsync(offsets_.data(), h_offsets_.data(), sizeof(int) * (expert_num + 1), cudaMemcpyDefault, st));\n    }\n\n    temp_ = Tensor{{param_.experts_per_token * tokens, hidden_dim_}, p.input.dtype(), p.input.device()};\n\n    if (param_.method == MoeParam::kNaive) {\n\n        invokeMoeDispatch(temp_, p.input, f2n_.data(), param_.experts_per_token, st);\n        sync_check_cuda_error();\n\n        check_cuda_error(\n            cudaMemcpyAsync(h_offsets_.data(), offsets_.data(), sizeof(int) * (expert_num + 1), cudaMemcpyDefault, st));\n\n        check_cuda_error(cudaStreamSynchronize(st));\n\n        TM_CHECK_EQ(h_offsets_[expert_num], tokens * param_.experts_per_token);\n\n        for (int i = 0; i < expert_num; ++i) {\n            if (int count = h_offsets_[i + 1] - h_offsets_[i]) {\n                auto io = temp_.slice({h_offsets_[i], 0}, {count, -1});\n                expert_ffn_->forward({io, io, moe.experts.at(i).get(), p.layer_id});\n            }\n        }\n    }\n    else {\n\n        auto& block = moe.block;\n\n        auto indices = f2n_.slice(0, tokens * param_.experts_per_token);\n        auto offsets = offsets_.slice(0, expert_num + 1);\n\n        Tensor inter = linear_.Forward(p.input, block.fused_gating_intermediate, indices, offsets_);\n        sync_check_cuda_error();\n\n        if (!block.is_fused_silu) {\n            Activation(inter, 
block.fused_gating_intermediate.bias, f2E_, moe.block.act_type, st);\n            sync_check_cuda_error();\n        }\n\n        linear_.Forward(inter.slice({0, 0}, {-1, inter_size_}), block.output, {}, offsets, temp_);\n        sync_check_cuda_error();\n    }\n\n    if (moe.shared_gate.weight) {\n        shared_scales_ = Gate(p.input, moe.shared_gate);\n    }\n}\n\nvoid MoeFfnLayer::Combine(ForwardParam& p)\n{\n    auto& moe = *p.weights;\n\n    invokeMoeCombine(p.output,\n                     temp_,\n                     moe.block.output.bias,\n                     scales_.data(),\n                     en2f_.data(),\n                     f2E_.data(),\n                     shared_scales_.data_or((float*)nullptr),\n                     param_.experts_per_token,\n                     1.f / tp_size_,\n                     p.scale,\n                     core::Context::stream().handle());\n    sync_check_cuda_error();\n\n    temp_          = {};\n    shared_scales_ = {};\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/moe_ffn_layer.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include \"src/turbomind/kernels/gemm/context.h\"\n#include \"src/turbomind/kernels/gemm/moe_utils_v2.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/LlamaFfnLayer.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nclass MoeFfnLayer {\npublic:\n    MoeFfnLayer(const ModelParam& model, const MoeParam& param, const EngineParam& engine, const Context& ctx);\n\n    struct ForwardParam {\n        Tensor              input;\n        Tensor              output;\n        const MoeFfnWeight* weights;\n        float               scale;\n        int                 layer_id;\n    };\n\n    void Forward(ForwardParam& p);\n\n    void Combine(ForwardParam& p);\n\nprivate:\n    Tensor_<float> Gate(const Tensor& input, const LlamaDenseWeight& gate);\n\n    void dump_logits(int token_num, int layer_id, int expert_num);\n\n    const int inter_size_;\n    const int hidden_dim_;\n    const int tp_size_;\n\n    const MoeParam param_;\n\n    int& is_warm_up_;\n\n    LlamaLinear& linear_;\n\n    std::unique_ptr<LlamaFfnLayer> expert_ffn_;\n\n    ///////////////////////////////////////////////////////\n    /// runtime states\n    Buffer_<int> h_offsets_;\n\n    Buffer_<int>   masks_;\n    Buffer_<int>   f2n_;\n    Buffer_<int>   f2E_;\n    Buffer_<int>   en2f_;\n    Buffer_<float> scales_;\n    Buffer_<int>   accum_;\n    Buffer_<int>   offsets_;\n\n    Tensor         temp_;\n    Tensor_<float> shared_scales_;\n    ///////////////////////////////////////////////////////\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/test_cache_manager.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include \"BlockManager.h\"\n#include \"SequenceManager.h\"\n\n#include \"src/turbomind/utils/allocator.h\"\n\n#include \"src/turbomind/utils/debug_utils.h\"\n#include <catch2/catch_test_macros.hpp>\n#include <iterator>\n\nusing namespace turbomind;\n\nstd::ostream& operator<<(std::ostream& os, const Block* b)\n{\n    os << \"(\" << b->id << \",\" << b->timestamp << \")\";\n    return os;\n}\n\nTEST_CASE(\"BlockManager\")\n{\n    Allocator<AllocatorType::CUDA> allocator(0);\n\n    BlockManager m(1024, 32, 8, &allocator);\n    REQUIRE(m.max_block_count() == 32);\n    REQUIRE(m.free_count() == 32);\n\n    auto blocks1 = m.Allocate(10);\n\n    dbg(blocks1);\n\n    REQUIRE(blocks1.size() == 10);\n    REQUIRE(m.active_count() == blocks1.size());\n    REQUIRE(m.free_count() == 22);\n\n    auto blocks2 = m.Allocate(6);\n    REQUIRE(blocks2.size() == 6);\n    REQUIRE(m.active_count() == blocks1.size() + blocks2.size());\n    REQUIRE(m.free_count() == 16);\n\n    auto blocks3 = m.Allocate(16);\n    REQUIRE(blocks3.size() == 16);\n    REQUIRE(m.active_count() == 32);\n    REQUIRE(m.free_count() == 0);\n\n    std::copy(blocks3.begin(), blocks3.end(), std::back_inserter(blocks1));\n    std::copy(blocks2.begin(), blocks2.end(), std::back_inserter(blocks1));\n\n    m.Touch(blocks1);\n\n    REQUIRE(m.Unlock(blocks1) == 32);\n    REQUIRE(m.active_count() == 0);\n    REQUIRE(m.free_count() == 0);\n    REQUIRE(m.cached_count() == 32);\n\n    m.Evict(16);\n    REQUIRE(m.active_count() == 0);\n    REQUIRE(m.free_count() == 16);\n    REQUIRE(m.cached_count() == 16);\n\n    auto blocks4 = m.Allocate(14);\n    REQUIRE(m.active_count() == 14);\n    REQUIRE(m.free_count() == 2);\n    REQUIRE(m.cached_count() == 16);\n}\n\nTEST_CASE(\"SequenceManager basic test\")\n{\n    Allocator<AllocatorType::CUDA> allocator(0);\n\n    SequenceManager manager(32, 32, 128, 128, 20, 4, 16, 0, &allocator);\n\n    
REQUIRE(manager.max_block_count() == 20);\n    REQUIRE(manager.Contains(1) == false);\n\n    auto s1 = manager.Create(1);\n    dbg(*s1);\n    REQUIRE(manager.Contains(1) == true);\n\n    manager.Erase(1);\n    REQUIRE(manager.Contains(1) == false);\n\n    s1 = manager.Create(1);\n    REQUIRE(manager.Contains(1) == true);\n\n    auto outcome = manager.Materialize({s1}, {128}, {100}, 1);\n    dbg(s1->blocks);\n    REQUIRE(s1->blocks.size() == 2);\n\n    auto s2 = manager.Create(2);\n    REQUIRE(manager.Contains(2));\n\n    outcome = manager.Materialize({s1, s2}, {128, 2559}, {2, 1}, 1);\n    dbg(outcome);\n    REQUIRE(outcome.allocation == 20);\n    REQUIRE(outcome.swap_in == 1);\n    REQUIRE(outcome.swap_out == 1);\n\n    auto s3 = manager.Create(3);\n    outcome = manager.Materialize({s1, s2, s3}, {127, 2559, 255}, {1, 100, 2}, 1);\n    dbg(outcome);\n}\n\nTEST_CASE(\"SequenceManager functional test\")\n{\n    Allocator<AllocatorType::CUDA> allocator(0);\n    SequenceManager                manager(32, 32, 128, 128, 20, 4, 16, 0, &allocator);\n\n    auto seq = manager.Create(1);\n    for (int i = 0; i < 1024; ++i) {\n        auto outcome = manager.Materialize({seq}, {i}, {0}, 1);\n        if (outcome.allocation) {\n            dbg(i, outcome);\n        }\n    }\n}\n"
  },
  {
    "path": "src/turbomind/models/llama/unified_attention_layer.cc",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.\n * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from\n// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc\n\n#include <algorithm>\n#include <functional>\n#include <math.h>\n#include <numeric>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/tensor.h\"\n#include \"src/turbomind/engine/request.h\"\n\n#include \"src/turbomind/kernels/attention/attention.h\"\n#include \"src/turbomind/kernels/attention/decoding.h\"\n#include \"src/turbomind/kernels/attention/kv_cache_utils_v2.h\"\n#include \"src/turbomind/kernels/norm/rms_norm.h\"\n\n#include \"src/turbomind/macro.h\"\n\n#include \"src/turbomind/models/llama/llama_kernels.h\"\n#include \"src/turbomind/models/llama/llama_rope.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/models/llama/mla_utils.h\"\n#include \"src/turbomind/models/llama/unified_attention_layer.h\"\n\n#include \"src/turbomind/utils/anomaly_handler.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include 
\"src/turbomind/utils/logger.h\"\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nstruct AttentionData {\n    struct Stat {\n        int n;\n        int q_sum;\n        int q_max;\n        int k_sum;\n        int k_max;\n    } decode, prefill;\n\n    Buffer_<void*> block_ptrs;\n    Buffer_<int>   block_ptrs_offsets;\n\n    Buffer_<float> rope_base;\n\n    Tensor_<int> mrope_position_ids;\n    Buffer_<int> mrope_position_delta;\n    Buffer_<int> mrope_length;\n\n    // borrowed from env\n    Buffer_<bool> finished;\n    Buffer_<int>  q_offsets;\n    Buffer_<int>  k_offsets;\n\n    // int dbg_offset;\n    // int dbg_size;\n};\n\nUnifiedAttentionLayer::~UnifiedAttentionLayer()\n{\n\n    check_cuda_error(cudaEventDestroy(aux_event_));\n    check_cuda_error(cudaEventDestroy(qkv_event_));\n    check_cuda_error(cudaStreamDestroy(aux_stream_));\n\n    aux_event_ = qkv_event_ = {};\n    aux_stream_             = {};\n}\n\nUnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam&     model,\n                                             const AttentionParam& attn,\n                                             const EngineParam&    engine,\n                                             int                   tp_size,\n                                             const Context&        ctx,\n                                             int                   phases,\n                                             bool                  init):\n    head_num_(model.head_num),\n    kv_head_num_(model.kv_head_num),\n    size_per_head_(model.head_dim),\n    hidden_units_(model.hidden_units),\n    local_head_num_(head_num_ / tp_size),\n    local_kv_head_num_(model.kv_head_num / tp_size),\n    param_(attn),\n    model_param_(model),\n    engine_param_(engine),\n    cp_fn_ctx_(ctx.comm.d_comm, ctx.comm.d_cp_group),\n    is_warm_up_{*ctx.is_warm_up},\n    context_(ctx),\n    linear_(*ctx.linear),\n    arch_(getSMVersion())\n{\n    TM_CHECK_EQ(head_num_ % tp_size, 0) << head_num_ << \" 
\" << tp_size;\n    TM_CHECK_EQ(head_num_ % kv_head_num_, 0) << head_num_ << \" \" << kv_head_num_;\n\n    check_cuda_error(cudaStreamCreateWithFlags(&aux_stream_, cudaStreamNonBlocking));\n    check_cuda_error(cudaEventCreateWithFlags(&qkv_event_, cudaEventDisableTiming));\n    check_cuda_error(cudaEventCreateWithFlags(&aux_event_, cudaEventDisableTiming));\n\n    init_rope_kernel_param(param_.rope, rope_param_);\n\n    // Skip other attention layer types\n    std::vector<int> layer_types = model_param_.layer_types;\n    layer_types.resize(model_param_.layer_num);\n    cache_layer_ids_.resize(layer_types.size(), -1);\n    int next_cache_id = 0;\n    for (size_t i = 0; i < layer_types.size(); ++i) {\n        if (layer_types[i] == 0) {\n            cache_layer_ids_[i] = next_cache_id++;\n        }\n    }\n\n    Allocator alloc            = core::Context::device_alloc();\n    ssize_t   workspace_tokens = kMaxWorkspaceTokens;\n    if (engine_param_.attn_cp_size > 1) {\n        alloc = GetSymmAllocator(ctx.comm.d_comm);\n        workspace_tokens += engine_param_.max_forward_token_num;\n    }\n    // partial_O layout:\n    //   w/  cp, decode(q, h, k, 2) + prefill(q, h, 1, 2)\n    //   w/o cp, decode(q, h, k, 2)\n    partial_O_  = Tensor_<float>({workspace_tokens, local_head_num_, size_per_head_}, kDEVICE);\n    partial_ML_ = Tensor_<float>({engine_param_.attn_cp_size, workspace_tokens, local_head_num_, 2}, alloc);\n    split_cnt_  = Tensor_<int>({workspace_tokens}, kDEVICE);\n    if (init) {\n        const int dim = (int)local_head_num_ * (int)size_per_head_;\n        tmp_attn_     = Tensor{{engine_param_.max_forward_token_num, dim}, model.data_type, kDEVICE};\n    }\n\n    Clear(split_cnt_.buffer());\n\n    const int bsz = engine.max_batch_size;\n\n    if (rope_param_.type == RopeType::kDynamic) {\n        rope_base_buf_ = {bsz + 1, kCPUpinned};\n    }\n    else if (rope_param_.type == RopeType::kMrope) {\n        // `mrope_position_ids` is not buffered\n        
mrope_position_delta_buf_ = {bsz, kCPUpinned};\n        mrope_length_buf_         = {bsz, kCPUpinned};\n    }\n    const int max_blocks = bsz * cdiv(engine.session_len, param_.cache_block_seq_len);\n    for (int i = 0; i < phases; ++i) {\n        auto& d               = data_.emplace_back(std::make_shared<AttentionData>());\n        d->block_ptrs         = {max_blocks + 16, kDEVICE};\n        d->block_ptrs_offsets = {bsz + 1, kDEVICE};\n        if (rope_param_.type == RopeType::kDynamic) {\n            d->rope_base = empty_like(rope_base_buf_, kDEVICE);\n        }\n        else if (rope_param_.type == RopeType::kMrope) {\n            /// TODO: total space for `mrope_position_ids` can be reduced to (max_fwd_tokens, 3)\n            d->mrope_position_ids    = {{bsz, engine.session_len, 3}, kDEVICE};\n            d->mrope_position_delta  = empty_like(mrope_position_delta_buf_, kDEVICE);\n            d->mrope_length          = empty_like(mrope_length_buf_, kDEVICE);\n            rope_param_.mrope.stride = d->mrope_position_ids.stride(0);\n        }\n    }\n}\n\nstatic void init_dynamic_ntk(RequestCache& cache, const RopeParam& rope)\n{\n    cache.rope_base = rope.base;\n    if (auto scaling_factor = rope.factor; scaling_factor > 1.f) {\n        const auto max_seq_len = cache.prompt_len;\n        const auto max_pos_emb = rope.max_position_embeddings;\n        if (max_seq_len > max_pos_emb) {\n            scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);\n            cache.rope_base *= powf(scaling_factor, rope.dim / (rope.dim - 2.f));\n            // clang-format off\n            TM_LOG_INFO(\"[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f\",\n                        (long)cache.req->id, scaling_factor, cache.rope_base);\n            // clang-format on\n        }\n    }\n}\n\nvoid UnifiedAttentionLayer::Run(BatchOp op, int phase, TensorMap& env)\n{\n    if (op == BatchOp::kAdd) {\n        Buffer_<RequestCache*> rc = 
env.at(\"requests\").buffer();\n        if (rope_param_.type == RopeType::kDynamic) {\n            for (int i = 0; i < rc.size(); ++i) {\n                init_dynamic_ntk(*rc[i], param_.rope);\n            }\n        }\n    }\n    else if (op == BatchOp::kSetup) {\n        Setup(phase, env);\n    }\n    else if (op == BatchOp::kPrepare) {\n        data_.at(phase)->finished  = env.at(\"finished\").buffer().borrow();\n        data_.at(phase)->q_offsets = env.at(\"q_offsets\").buffer().borrow();\n        data_.at(phase)->k_offsets = env.at(\"k_offsets\").buffer().borrow();\n\n        // This is needed in async mode to clear the `attn` buffer for the finished sequences. Otherwise random NaNs\n        // will crash the MoE router later\n        /// TODO: use a better solution; this increases memory usage, and heterogeneous attention layers may still\n        /// break it\n        if (tmp_attn_) {\n            auto& d = data_.at(phase);\n            Clear(tmp_attn_.slice(0, d->decode.n + d->prefill.q_sum));\n            Clear(split_cnt_);\n        }\n    }\n}\n\nvoid UnifiedAttentionLayer::Setup(int phase, TensorMap& env)\n{\n    const auto& rc  = env.at(\"batch\").data<BatchData*>()[0]->rc;\n    const int   bsz = rc.size();\n\n    auto& d    = *data_.at(phase);\n    auto& copy = *env.at(\"copy\").data<BatchCopy*>()[0];\n\n    {  /// Upload KV cache ptrs\n        const Buffer_<int> offsets = env.at(\"block_ptrs_offsets\").buffer();\n        copy(env.at(\"block_ptrs\").buffer(), offsets[bsz], d.block_ptrs);\n        copy(offsets, bsz + 1, d.block_ptrs_offsets);\n    }\n\n    /// prepare Q/K stats for decode/prefill\n    d.decode = d.prefill = {};\n\n    d.decode.n  = std::find_if(rc.begin(), rc.end(), [](auto r) { return r->input_len > 1; }) - rc.begin();\n    d.prefill.n = bsz - d.decode.n;\n\n    // d.dbg_offset = d.dbg_size = 0;\n\n    for (int i = 0; i < bsz; ++i) {\n        const auto& c = *rc[i];\n\n        // if (c.request->id == 4 && c.input_len > 1) {\n        //     d.dbg_offset 
= d.decode.q_sum + d.prefill.q_sum;\n        //     d.dbg_size   = c.input_len;\n        // }\n\n        auto& s = i < d.decode.n ? d.decode : d.prefill;\n        s.q_sum += c.input_len;\n        s.k_sum += c.history_len + c.alpha + c.input_len;\n        s.q_max = std::max(s.q_max, c.input_len);\n        s.k_max = std::max(s.k_max, c.history_len + c.alpha + c.input_len);\n    }\n\n    // auto &D = d.decode, &P = d.prefill;\n    // dbg(D.n, D.k_sum, D.k_max, P.n, P.q_sum, P.q_max, P.k_sum, P.k_max);\n\n    /// handling different RoPE types\n    if (rope_param_.type == RopeType::kDynamic) {\n        for (int i = 0; i < bsz; ++i) {\n            rope_base_buf_[i] = rc[i]->rope_base;\n        }\n        copy(rope_base_buf_, bsz, d.rope_base);\n    }\n    else if (rope_param_.type == RopeType::kMrope) {\n        const auto stride = d.mrope_position_ids.stride(0);\n        for (int i = 0; i < rc.size(); ++i) {\n            auto& c = *rc[i];\n            auto& r = *c.req;\n            if (auto pos_ids = r.inputs.try_(\"mrope_position_ids\")) {\n                int length                   = pos_ids->shape(0);\n                mrope_length_buf_[i]         = length;\n                mrope_position_delta_buf_[i] = *r.inputs.at(\"mrope_position_delta\").data<int>();\n                if (auto o = Interval{0, length} & Interval{c.history_len + c.alpha, Interval::Size{c.input_len}}) {\n                    copy(pos_ids->data<int>() + o.begin() * 3,\n                         (int)o.size() * 3,\n                         d.mrope_position_ids.data() + i * stride + o.begin() * 3);\n                }\n            }\n            else {\n                mrope_length_buf_[i] = mrope_position_delta_buf_[i] = 0;\n            }\n        }\n        copy(mrope_length_buf_, rc.size(), d.mrope_length);\n        copy(mrope_position_delta_buf_, rc.size(), d.mrope_position_delta);\n    }\n}\n\nvoid UnifiedAttentionLayer::Forward(ForwardParam p)\n{\n    TM_LOG_DEBUG(__PRETTY_FUNCTION__);\n\n    
/////////////////////////////////////////////\n    /// parse inputs\n    const int token_num = p.input.shape(0);\n\n    if (token_num == 0) {\n        return;\n    }\n\n    const int layer_id = p.layer_id;\n\n    const auto& weights = *p.weights;\n\n    Tensor qkv;\n\n    auto& d = *data_.at(p.phase);\n\n    // if (d.dbg_size) {\n    //     DebugTensor(p.input.slice(d.dbg_offset, d.dbg_size), Concat(\"attn_in\", p.layer_id), 0);\n    // }\n\n    if (weights.qkv.output_dim) {\n        // [token_num, hidden_dim] -> [token_num, local_q_kv_head_num, head_dim]\n        qkv = linear_.Forward(p.input, weights.qkv);\n        sync_check_cuda_error();\n\n        if (model_param_.qk_norm) {\n            qk_norm(qkv, weights);\n        }\n    }\n    else {\n        qkv = forward_mla(p.input, weights);\n    }\n\n    TM_DEBUG_TENSOR(qkv, Concat(\"qkv\", layer_id), 3);\n\n    auto invoke = [&](auto t) -> Tensor {\n        using T = decltype(t);\n        return core_attention<T>(qkv, p, weights);\n    };\n\n    Tensor attn = [&]() -> Tensor { TM_DISPATCH_PRIMARY_DTYPES_RET(qkv.dtype(), invoke); }();\n\n    // Apply sigmoid gating: attn *= sigmoid(gate)\n    // Gate is stored at the end of each token's QKV: [Q|K|V|Gate]\n    if (model_param_.attn_output_gate) {\n        const int  q_count     = qkv.shape(0);\n        const int  attn_dim    = local_head_num_ * size_per_head_;\n        const int  gate_offset = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;\n        const int  qkv_stride  = (2 * local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;\n        const auto stream      = core::Context::stream().handle();\n        invokeSigmoidGateMultiply(attn.raw_data(),\n                                  (const char*)qkv.raw_data() + gate_offset * byte_size(qkv.dtype(), 1),\n                                  attn_dim,\n                                  qkv_stride,\n                                  q_count,\n                                  qkv.dtype(),\n              
                    stream);\n        sync_check_cuda_error();\n    }\n\n    TM_DEBUG_TENSOR(attn, Concat(\"attn\", layer_id), 3);\n\n    // if (d.dbg_size) {\n    //     DebugTensor(attn.slice(d.dbg_offset, d.dbg_size), Concat(\"attn_out\", p.layer_id), 0);\n    // }\n\n    //////////////////////////////////////////////\n    /// output gemm <Bs,HD> -> <Bs,HD>\n    (void)linear_.Forward(attn, weights.output, p.output);\n    sync_check_cuda_error();\n}\n\ntemplate<class T>\nTensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, const WeightType& weights)\n{\n    const auto device = qkv.device();\n    const auto dtype  = qkv.dtype();\n\n    auto& d = *data_.at(p.phase);\n\n    const int batch_size = d.decode.n + d.prefill.n;\n    const int q_count    = qkv.shape(0);\n\n    TM_CHECK_EQ(d.prefill.q_sum + d.decode.n, q_count);\n\n    const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;\n\n    Tensor attn;\n    if (tmp_attn_) {\n        attn = tmp_attn_.slice(0, q_count);\n    }\n    else {\n        attn = {{q_count, (int)local_head_num_ * (int)size_per_head_}, dtype, device};\n    }\n\n    const bool is_mla = model_param_.mla.kv_lora_rank > 0;\n\n    Tensor tmp_kv{\n        {(int)local_kv_head_num_, is_mla ? 
1 : 2, d.prefill.k_sum + MAX_CTA_S, (int)size_per_head_}, dtype, device};\n\n    const int cache_layer_id = cache_layer_ids_[p.layer_id];\n\n    auto CreateParams = [&](int offset, AttentionData::Stat stat, int max_kv_splits, cudaStream_t stream) {\n        AttentionParams<T> params{};\n\n        // Batch offset for `out` and `q` are computed inside the kernel\n        params.out = (T*)attn.raw_data();\n\n        params.q = (T*)qkv.raw_data();\n        params.k = params.q + local_head_num_ * size_per_head_;\n        if (is_mla) {\n            params.v      = params.k;\n            params.stride = (local_head_num_ + 1 * local_kv_head_num_) * size_per_head_;\n        }\n        else {\n            params.v = params.k + local_kv_head_num_ * size_per_head_;\n            // When attn_output_gate, QKV layout is [Q|K|V|Gate] per token\n            // stride must account for the extra gate portion at the end\n            if (model_param_.attn_output_gate) {\n                params.stride = (2 * local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;\n            }\n            else {\n                params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_;\n            }\n        }\n\n        if (weights.qkv.bias) {\n            params.q_bias = (T*)weights.qkv.bias.data_or<T>(nullptr);\n            params.k_bias = params.q_bias + local_head_num_ * size_per_head_;\n            params.v_bias = params.k_bias + local_kv_head_num_ * size_per_head_;\n        }\n\n        params.batch_size = stat.n;\n\n        params.token_num = stat.q_sum;\n        params.max_q_len = stat.q_max;\n        params.max_k_len = stat.k_max;\n\n        // decode only\n        params.block_iter_params = BlockIteratorParams{(char**)d.block_ptrs.data(),  //\n                                                       d.block_ptrs_offsets.data() + offset,\n                                                       cache_layer_id,\n                                                       
(int)param_.cache_block_seq_len};\n\n        // prefill only\n        if (is_mla) {\n            params.linear_iter_params = LinearIteratorParams{\n                tmp_kv.raw_data(),            // flattened KV\n                stat.k_sum * size_per_head_,  // stride to next head\n                0                             // stride from K to V\n            };\n        }\n        else {\n            params.linear_iter_params = LinearIteratorParams{\n                tmp_kv.raw_data(),                // flattened KV\n                stat.k_sum * size_per_head_ * 2,  // stride to next head\n                stat.k_sum * size_per_head_       // stride from K to V\n            };\n        }\n\n        params.finished = d.finished.data() + offset;\n        params.cu_q_len = d.q_offsets.data() + offset;\n        params.cu_k_len = d.k_offsets.data() + offset;\n\n        params.num_heads     = local_head_num_;\n        params.num_kv_heads  = local_kv_head_num_;\n        params.size_per_head = size_per_head_;\n        params.layer_id      = cache_layer_id;\n\n        double scaling = 1.;\n        if (param_.softmax_scale) {  // model predefined softmax scale\n            scaling *= param_.softmax_scale;\n        }\n        else {  // default value\n            scaling /= std::sqrt((float)params.size_per_head);\n        }\n        params.inv_sqrt_dh = scaling * std::log2(std::exp(1.));\n\n        params.sinks       = weights.sinks.data_or((T*)nullptr);\n        params.scale_sinks = scaling;\n\n        params.window_size = weights.window_size;\n        if (!params.window_size) {\n            params.window_size = 256 << 20;  // 256 M\n        }\n\n        params.rope_param = rope_param_;\n        if (rope_param_.type == RopeType::kDynamic) {\n            params.rope_param.base = d.rope_base.data() + offset;\n        }\n        else if (rope_param_.type == RopeType::kMrope) {\n            params.rope_param.mrope.position_ids   = d.mrope_position_ids.data() + offset * 
rope_param_.mrope.stride;\n            params.rope_param.mrope.position_delta = d.mrope_position_delta.data() + offset;\n            params.rope_param.mrope.length         = d.mrope_length.data() + offset;\n        }\n\n        // logn attn\n        params.use_logn_attn           = param_.use_logn_attn;\n        params.max_position_embeddings = param_.max_position_embeddings;\n\n        // Used by decoding only for now\n        params.split_cnt   = split_cnt_.data();\n        params.partial_ML  = partial_ML_.data();\n        params.partial_O   = partial_O_.data();\n        params.max_split_k = std::min(std::max(1, kMaxWorkspaceTokens / params.token_num), max_kv_splits);\n\n        // context parallel\n        params.cp_rank = engine_param_.attn_cp_rank;\n        params.cp_size = engine_param_.attn_cp_size;\n        if (params.cp_size > 1) {\n            params.cp_size = cutlass::FastDivmod(params.cp_size);\n\n            // Update ML/O offsets if both prefill and decode are present\n            const int offset_ML_stage =\n                engine_param_.attn_cp_size * (offset ? kMaxWorkspaceTokens * local_head_num_ * 2 : 0);\n            const int offset_ML_rank = params.cp_rank * params.token_num * local_head_num_ * params.max_split_k * 2;\n            const int offset_O       = offset ? 
kMaxWorkspaceTokens * local_head_num_ * size_per_head_ : 0;\n\n            params.partial_ML = partial_ML_.data() + offset_ML_stage + offset_ML_rank;\n            params.partial_O  = partial_O_.data() + offset_O;\n            params.offset_q   = offset;\n\n            // postprocess func\n            params.cp_fn          = CpPost;\n            params.cp_fn_ctx      = (void*)&cp_fn_ctx_;\n            cp_fn_ctx_.cp_rank    = params.cp_rank;\n            cp_fn_ctx_.count      = params.token_num * local_head_num_ * params.max_split_k * 2;\n            cp_fn_ctx_.partial_ML = partial_ML_.data() + offset_ML_stage;\n            cp_fn_ctx_.stream     = stream;\n        }\n\n        params.arch   = arch_;\n        params.stream = stream;\n\n        params.quant_policy = model_param_.quant_policy;\n        return params;\n    };\n\n    const cudaStream_t stream = core::Context::stream().handle();\n\n    cudaStream_t pf_stream = stream;\n    cudaStream_t dc_stream = pf_stream;\n\n    if (d.decode.n && d.prefill.n) {\n        pf_stream = aux_stream_;\n        check_cuda_error(cudaEventRecord(qkv_event_, stream));\n        check_cuda_error(cudaStreamWaitEvent(aux_stream_, qkv_event_));\n    }\n\n    if (d.prefill.n && !is_warm_up_) {\n        const int offset = d.decode.n;\n        // We are executing prefill & decoding kernels concurrently, but only have 1 workspace\n        // disable split kv for prefill for now\n        auto params = CreateParams(offset, d.prefill, 1, pf_stream);\n        if constexpr (sizeof(T) == 2) {\n            invokeProcessKV_v2_(params);\n            sync_check_cuda_error();\n\n            /// TODO: skip flattening for `sm_80`\n            invokeFlattenKV_v2_(params, d.prefill.k_sum);\n            sync_check_cuda_error();\n\n            dispatchAttention(params);\n            sync_check_cuda_error();\n        }\n    }\n\n    if (d.decode.n && !is_warm_up_) {\n        auto params = CreateParams(0, d.decode, kMaxKVSplits, dc_stream);\n        if 
constexpr (sizeof(T) == 2) {\n            dispatchDecoding<T>(params);\n            sync_check_cuda_error();\n        }\n    }\n\n    if (d.decode.n && d.prefill.n) {\n        check_cuda_error(cudaEventRecord(aux_event_, aux_stream_));\n        check_cuda_error(cudaStreamWaitEvent(stream, aux_event_));\n    }\n\n    if (is_warm_up_) {\n        rng_.set_stream(stream);\n        rng_.GenerateUniform(attn.data<T>(), attn.size(), .02f, -.01f);\n    }\n\n    return attn;\n}\n\nTensor UnifiedAttentionLayer::forward_mla(const Tensor& hidden_state, const WeightType& w)\n{\n\n    const auto token_num = hidden_state.shape(0);\n    const auto dtype     = hidden_state.dtype();\n\n    const int q_lora_rank  = w.q_a_proj.output_dim;\n    const int kv_lora_rank = w.kv_a_layernorm.size();\n    const int qk_rope_dim  = w.kv_a_proj.output_dim - kv_lora_rank;\n\n    Tensor q;\n\n    const auto stream = core::Context::stream().handle();\n\n    if (w.q_proj.weight) {\n        q = linear_.Forward(hidden_state, w.q_proj);\n        sync_check_cuda_error();\n    }\n    else {\n        Tensor q_a = linear_.Forward(hidden_state, w.q_a_proj);\n        sync_check_cuda_error();\n\n        invokeRMSNorm(q_a, q_a, w.q_a_layernorm, model_param_.norm_eps, stream);\n        sync_check_cuda_error();\n\n        q = linear_.Forward(q_a, w.q_b_proj);\n        sync_check_cuda_error();\n    }\n\n    Tensor kv_a_k_pe = linear_.Forward(hidden_state, w.kv_a_proj);\n    sync_check_cuda_error();\n\n    auto kv_a = kv_a_k_pe.slice({0, 0}, {-1, kv_lora_rank});\n    invokeRMSNorm(kv_a, kv_a, w.kv_a_layernorm, model_param_.norm_eps, stream);\n    sync_check_cuda_error();\n\n    const int local_q_kv_head_num = local_head_num_ + 1 * local_kv_head_num_;\n\n    Tensor qkv{{token_num, local_q_kv_head_num, size_per_head_}, dtype, hidden_state.device()};\n    MLACopyQKV(dtype,\n               qkv.raw_data(),\n               q.raw_data(),\n               kv_a_k_pe.raw_data(),\n               token_num,\n               
local_head_num_,\n               kv_lora_rank,\n               qk_rope_dim,\n               stream);\n    sync_check_cuda_error();\n\n    return qkv;\n}\n\nvoid UnifiedAttentionLayer::qk_norm(Tensor& qkv, const WeightType& weights)\n{\n    const auto stream = core::Context::stream().handle();\n\n    check_cuda_error(cudaEventRecord(qkv_event_, stream));\n    check_cuda_error(cudaStreamWaitEvent(aux_stream_, qkv_event_));\n\n    TM_CHECK(model_param_.attn_bias == false) << \"not implemented\";\n\n    const auto token_num = qkv.shape(0);\n\n    auto qkv3 = qkv.view({token_num, -1, (int)size_per_head_});\n\n    auto q = qkv3.slice({0, 0, 0}, {-1, (int)local_head_num_, -1});\n    invokeRMSNormQK(q, weights.q_a_layernorm, model_param_.norm_eps, stream);\n    sync_check_cuda_error();\n\n    auto k = qkv3.slice({0, (int)local_head_num_, 0}, {-1, (int)local_kv_head_num_, -1});\n    invokeRMSNormQK(k, weights.kv_a_layernorm, model_param_.norm_eps, aux_stream_);\n    sync_check_cuda_error();\n\n    check_cuda_error(cudaEventRecord(aux_event_, aux_stream_));\n    check_cuda_error(cudaStreamWaitEvent(stream, aux_event_));\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/unified_attention_layer.h",
    "content": "/*\n * Copyright (c) OpenMMLab. All rights reserved.\n * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.\n * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n// Modified from\n// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h\n\n#pragma once\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/kernels/attention/cp_utils.h\"\n#include \"src/turbomind/kernels/gemm/test/test_utils.h\"\n#include \"src/turbomind/models/llama/LlamaDenseWeight.h\"\n#include \"src/turbomind/models/llama/LlamaLinear.h\"\n#include \"src/turbomind/models/llama/context.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nstruct AttentionData;\n\nclass UnifiedAttentionLayer {\npublic:\n    using WeightType = LlamaAttentionWeight;\n\n    static constexpr int kMaxKVSplits        = 128;\n    static constexpr int kMaxWorkspaceTokens = 4096;\n\n    struct ForwardParam {\n        int               phase;\n        Tensor            input;\n        Tensor            output;\n        const WeightType* weights;\n        int               layer_id;\n    };\n\n    ~UnifiedAttentionLayer();\n\n    UnifiedAttentionLayer(const ModelParam&     model,\n                          const AttentionParam& 
attn,\n                          const EngineParam&    engine,\n                          int                   tp_size,\n                          const Context&        context,\n                          int                   phases,\n                          bool                  init);\n\n    void Run(BatchOp op, int phase, TensorMap& env);\n\n    void Forward(ForwardParam p);\n\nprivate:\n    void Setup(int phase, TensorMap& env);\n\n    Tensor forward_mla(const Tensor& hidden_state, const WeightType& weights);\n\n    /// TODO: dropping the `T` here requires deep refactor of attention dispatch\n    template<class T>\n    Tensor core_attention(Tensor& qkv, const ForwardParam& p, const WeightType& weights);\n\n    void qk_norm(Tensor& qkv, const WeightType& weights);\n\nprivate:\n    const int head_num_;\n    const int kv_head_num_;\n    const int size_per_head_;\n    const int hidden_units_;\n    const int local_head_num_;\n    const int local_kv_head_num_;\n\n    const AttentionParam param_;\n    const EngineParam    engine_param_;\n    const ModelParam     model_param_;\n    const Context&       context_;\n\n    int& is_warm_up_;\n\n    LlamaLinear& linear_;\n    const int    arch_{};\n\n    cudaStream_t aux_stream_;\n    cudaEvent_t  qkv_event_;\n    cudaEvent_t  aux_event_;\n\n    RNG rng_;\n\n    RopeKernelParam rope_param_{};\n\n    std::vector<std::shared_ptr<AttentionData>> data_;\n\n    std::vector<int> cache_layer_ids_;\n\n    ///////////////////////////////////////////////////////\n    /// temp runtime buffers\n    Tensor_<float> partial_O_;\n    Tensor_<float> partial_ML_;\n    Tensor_<int>   split_cnt_;\n    Tensor         tmp_attn_;\n\n    Buffer_<float> rope_base_buf_;\n    Buffer_<int>   mrope_position_delta_buf_;\n    Buffer_<int>   mrope_length_buf_;\n\n    CpPostContext cp_fn_ctx_;  // context parallel\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/unified_decoder.cc",
    "content": "\n\n#include <numeric>\n#include <optional>\n\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/core/allocator.h\"\n#include \"src/turbomind/kernels/core/math.h\"\n#include \"src/turbomind/kernels/norm/rms_norm.h\"\n#include \"src/turbomind/models/llama/llama_kernels.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/models/llama/moe_ffn_layer.h\"\n#include \"src/turbomind/models/llama/unified_attention_layer.h\"\n#include \"src/turbomind/models/llama/unified_decoder.h\"\n#include \"src/turbomind/utils/anomaly_handler.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n\n#include \"src/turbomind/engine/request.h\"\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nvoid UnifiedDecoder::Run(BatchOp op, int phase, TensorMap& env)\n{\n    attn_layer_->Run(op, phase, env);\n    if (linear_attn_layer_) {\n        linear_attn_layer_->Run(op, phase, env);\n    }\n}\n\nUnifiedDecoder::UnifiedDecoder(const ModelParam&     model,\n                               const EngineParam&    engine,\n                               const AttentionParam& attn,\n                               const MoeParam&       moe,\n                               const Context&        ctx,\n                               int                   phases):\n    layer_num_(model.layer_num),\n    hidden_units_(model.hidden_units),\n    attn_tp_size_(engine.attn_tp_size),\n    attn_dp_size_(engine.attn_dp_size),\n    attn_dp_rank_(engine.attn_dp_rank),\n    mlp_tp_size_(engine.mlp_tp_size),\n    attn_tp_group_(ctx.comm.d_tp_group),\n    rmsnorm_eps_(model.norm_eps),\n    d_comm_(ctx.comm.d_comm),\n    tune_layer_num_(model.tune_layer_num),\n    is_warm_up_{*ctx.is_warm_up}\n{\n    if (std::accumulate(moe.expert_num.begin(), moe.expert_num.end(), 0LL)) {\n        moe_ffn_layer_ = std::make_unique<MoeFfnLayer>(model, moe, engine, ctx);\n    }\n\n    attn_layer_ =\n        std::make_unique<UnifiedAttentionLayer>(model, attn, engine, attn_tp_size_, ctx, 
phases, (bool)moe_ffn_layer_);\n\n    if (std::find(model.layer_types.begin(), model.layer_types.end(), 1) != model.layer_types.end()) {\n        linear_attn_layer_ = std::make_unique<GatedDeltaNetLayer>(model, attn, engine, attn_tp_size_, ctx, phases);\n    }\n\n    if (std::accumulate(model.inter_size.begin(), model.inter_size.end(), 0LL)) {\n        ffn_layer_ = std::make_unique<LlamaFfnLayer>(model, ctx);\n    }\n}\n\nvoid UnifiedDecoder::AllreduceResidualRMSnorm(Tensor&       hidden_states,\n                                              Tensor&       residual,\n                                              const Tensor& bias,\n                                              const Tensor& weight,\n                                              int           token_num,\n                                              int           group0,\n                                              int           group1,\n                                              const int*    local_token_nums)\n{\n    const auto dtype = hidden_states.dtype();\n\n    const auto stream = core::Context::stream().handle();\n\n    if (0) {}\n    else if (group0 || group1) {\n        d_comm_->AllreduceResidualBiasRMSnormEx(hidden_states.raw_data(),\n                                                residual.data_or((void*)nullptr),\n                                                bias.data_or((void*)nullptr),\n                                                weight.raw_data(),\n                                                rmsnorm_eps_,\n                                                hidden_units_,\n                                                dtype,\n                                                group0,\n                                                group1,\n                                                local_token_nums,\n                                                stream);\n        sync_check_cuda_error();\n    }\n    else if (d_comm_) {\n        
d_comm_->AllreduceResidualBiasRMSnorm(hidden_states.raw_data(),\n                                              residual.data_or((void*)nullptr),\n                                              bias.data_or((void*)nullptr),\n                                              weight.raw_data(),\n                                              rmsnorm_eps_,\n                                              hidden_units_,\n                                              token_num,\n                                              dtype,\n                                              0,\n                                              stream);\n        sync_check_cuda_error();\n    }\n    else {\n        invokeResidualBiasRMSNorm(hidden_states.raw_data(),\n                                  residual.data_or((void*)nullptr),\n                                  weight.raw_data(),\n                                  bias.data_or((void*)nullptr),\n                                  dtype,\n                                  hidden_units_,\n                                  token_num,\n                                  rmsnorm_eps_,\n                                  stream);\n        sync_check_cuda_error();\n    }\n}\n\nvoid UnifiedDecoder::Forward(int phase, TensorMap& args, const std::vector<WeightType*>& weights)\n{\n    /**\n     * input tensors:\n     *   \\param decoder_input [token_num, hidden_units], float\n     *   \\param output_norm_weight [hidden_dims], float\n     *   \\param cu_block_counts [batch_size+1], int\n     *   \\param finished [batch_size], bool\n     *   \\param rope_theta [batch_size], float\n     *   \\param h_q_len [batch_size], int on cpu\n     *   \\param h_k_len [batch_size], int on cpu\n     *   \\param pf_batch_size [1], int on cpu\n     *   \\param dc_batch_size [1], int on cpu\n     *\n     * output tensors:\n     *   \\param decoder_output [num_token, hidden_units],\n     *   \\param last_token_hidden_units [batch_size, hidden_units]\n     *   \\param 
block_ptrs [total_block_counts], void*\n     */\n\n    constexpr auto device = kDEVICE;\n\n    Tensor      local_residual   = args.try_consume(\"input_embeds\");\n    const auto& local_token_nums = args.at(\"batch\").data<BatchData*>()[0]->local_token_num;\n\n    const auto local_token_num  = local_residual.shape(0);\n    const auto global_token_num = std::accumulate(local_token_nums.begin(), local_token_nums.end(), ssize_t{});\n\n    TM_CHECK_EQ(local_token_num, local_token_nums[attn_dp_rank_]);\n\n    const DataType dtype = local_residual.dtype();\n\n    Tensor global_hidden_states;\n    if (d_comm_) {\n        Buffer symm_buf      = args.at(\"symm_buf\").buffer();\n        global_hidden_states = {symm_buf.view(dtype), {global_token_num, (int)hidden_units_}};\n    }\n    else {\n        global_hidden_states = {{global_token_num, (int)hidden_units_}, local_residual.dtype(), kDEVICE};\n    }\n\n    Tensor local_hidden_states;\n    if (attn_dp_size_ > 1) {  // Offset hidden states buffer for mixed DP\n        TM_CHECK_EQ(local_token_nums.size(), attn_dp_size_);\n        std::vector offsets(attn_dp_size_ + 1, 0);\n        std::inclusive_scan(local_token_nums.data(), local_token_nums.data() + attn_dp_size_, offsets.begin() + 1);\n        const int offset    = offsets[attn_dp_rank_];\n        local_hidden_states = global_hidden_states.slice({offset, 0}, {local_token_num, -1});\n\n        // dbg(attn_dp_size_, attn_dp_rank_, local_token_nums, local_token_num, global_token_num);\n    }\n    else {\n        local_hidden_states = global_hidden_states;\n    }\n\n    TM_DEBUG_TENSOR(local_residual, \"res\", 1);\n    TM_DEBUG_TENSOR(weights.at(0)->self_attn_norm, \"norm_weight\", 2);\n\n    const auto stream = core::Context::stream().handle();\n\n    invokeRMSNorm(local_hidden_states, local_residual, weights.at(0)->self_attn_norm, rmsnorm_eps_, stream);\n    sync_check_cuda_error();\n\n    TM_DEBUG_TENSOR(local_hidden_states, Concat(\"norm0\", 0), 2);\n\n    // auto 
stack_alloc{core::Context::device_alloc().adapt<core::StackAllocatorImpl>()};\n    // core::ContextGuard ctx{Allocator{stack_alloc}};\n\n    for (int layer = 0; layer < layer_num_; ++layer) {\n\n        // stack_alloc->iter();\n\n        if (global_token_num == 0) {\n            break;\n        }\n\n        if (is_warm_up_ && layer >= tune_layer_num_) {\n            continue;\n        }\n\n        /////////////////////////////////////////////\n        /// self-attention or linear-attention\n        if (weights.at(layer)->linear_attn_weights) {\n            linear_attn_layer_->Forward(\n                {phase, local_hidden_states, local_hidden_states, weights.at(layer)->linear_attn_weights.get(), layer});\n        }\n        else {\n            attn_layer_->Forward(\n                {phase, local_hidden_states, local_hidden_states, weights.at(layer)->self_attn_weights.get(), layer});\n        }\n\n        TM_DEBUG_TENSOR(local_hidden_states, Concat(\"attn_block\", layer), 2);\n\n        // Gated delta net layers keep their output projection bias in `out_proj.bias` (possibly empty),\n        // while self-attention layers use `output.bias`; select whichever applies to this layer.\n        Tensor out_bias;\n        if (weights.at(layer)->linear_attn_weights) {\n            out_bias = weights.at(layer)->linear_attn_weights->out_proj.bias;\n        }\n        else {\n            out_bias = weights.at(layer)->self_attn_weights->output.bias;\n        }\n\n        AllreduceResidualRMSnorm(global_hidden_states,\n                                 local_residual,\n                                 out_bias,\n                                 weights.at(layer)->ffn_norm,\n                                 local_token_num,\n                                 attn_tp_group_,\n                                 0,\n                                 local_token_nums.data());\n\n        TM_DEBUG_TENSOR(local_residual, Concat(\"residual0\", layer), 2);\n        TM_DEBUG_TENSOR(local_hidden_states, Concat(\"norm1\", layer), 2);\n\n  
      ////////////////////////////////////////////\n        /// feed-forward network\n\n        std::optional<MoeFfnLayer::ForwardParam> moe_fwd_param;\n\n        if (weights.at(layer)->moe_weights) {\n            moe_fwd_param = MoeFfnLayer::ForwardParam{global_hidden_states,\n                                                      global_hidden_states,\n                                                      weights.at(layer)->moe_weights.get(),\n                                                      ffn_layer_ ? 1.f : 0.f,\n                                                      layer};\n            moe_ffn_layer_->Forward(*moe_fwd_param);\n        }\n\n        if (weights.at(layer)->ffn_weights) {\n            ffn_layer_->forward(\n                {global_hidden_states, global_hidden_states, weights.at(layer)->ffn_weights.get(), (int)layer});\n        }\n\n        if (moe_fwd_param) {\n            moe_ffn_layer_->Combine(*moe_fwd_param);\n        }\n\n        TM_DEBUG_TENSOR(global_hidden_states, Concat(\"ffn_block\", layer), 2);\n\n        const bool last = layer == layer_num_ - 1;\n\n        auto& scale_weight = !last ? 
weights.at(layer + 1)->self_attn_norm : args.at(\"output_norm_weight\");\n\n        AllreduceResidualRMSnorm(global_hidden_states,\n                                 local_residual,\n                                 {},\n                                 scale_weight,\n                                 local_token_num,\n                                 0,\n                                 attn_tp_group_,\n                                 local_token_nums.data());\n        sync_check_cuda_error();\n\n        TM_DEBUG_TENSOR(local_residual, Concat(\"residual1\", layer), 2);\n        TM_DEBUG_TENSOR(local_hidden_states, Concat(\"norm0\", layer + 1), 2);\n\n        // if (layer == layer_num_ - 1) {\n        //     args.at(\"batch\").data<BatchData*>()[0]->Notify();\n        // }\n    }\n\n    // Token indices selected for decoding\n    const Buffer selected_pos = args.consume(\"selected_token_pos\").buffer();\n    // dbg(selected_pos);\n    // When there are no prefill sequences, token selection is not needed\n    const bool reuse_hidden_states = selected_pos.size() == local_token_num;\n\n    const bool output_hidden_states = args.try_(\"output_hidden_states\");\n\n    Tensor hidden_states{local_hidden_states};\n\n    if (d_comm_ && (output_hidden_states || reuse_hidden_states)) {\n        // The full `hidden_states` buffer is needed for output but it's a ref into `symm_buf` atm.\n        // Copy to residual buf so that `symm_buf` may be reused safely later\n        Copy(hidden_states, local_residual);\n        hidden_states = local_residual;\n    }\n\n    Tensor selected_states;\n    if (reuse_hidden_states) {\n        selected_states = hidden_states;\n    }\n    else {\n        selected_states = {{selected_pos.size(), (int)hidden_units_}, dtype, kDEVICE};\n        CollectHiddenStates(hidden_states, selected_pos, selected_states, stream);\n    }\n    args.produce(\"hidden_states\", selected_states);\n\n    // TM_DEBUG_TENSOR(selected_states.slice(0, selected_pos.size()), 
\"out\", 1);\n\n    if (output_hidden_states) {\n        args.produce(\"full_hidden_states\", hidden_states);\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/llama/unified_decoder.h",
    "content": "#pragma once\n\n#include \"src/turbomind/comm/device_comm.h\"\n#include \"src/turbomind/models/llama/GatedDeltaNetLayer.h\"\n#include \"src/turbomind/models/llama/LlamaDecoderLayerWeight.h\"\n#include \"src/turbomind/models/llama/LlamaFfnLayer.h\"\n#include \"src/turbomind/models/llama/context.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n#include \"src/turbomind/models/llama/moe_ffn_layer.h\"\n#include \"src/turbomind/models/llama/unified_attention_layer.h\"\n\nnamespace turbomind {\n\nclass UnifiedDecoder {\npublic:\n    using WeightType = LlamaDecoderLayerWeight;\n\n    UnifiedDecoder(const ModelParam&     model,\n                   const EngineParam&    engine,\n                   const AttentionParam& attn,\n                   const MoeParam&       moe,\n                   const Context&        ctx,\n                   int                   phases);\n\n    void Run(BatchOp op, int phase, TensorMap& env);\n\n    void Forward(int phase, TensorMap& env, const std::vector<WeightType*>& weights);\n\nprivate:\n    const size_t layer_num_;\n    const size_t hidden_units_;\n\n    const int attn_tp_size_;\n    const int attn_dp_size_;\n    const int attn_dp_rank_;\n    const int mlp_tp_size_;\n\n    const int attn_tp_group_;\n\n    const float rmsnorm_eps_;\n\n    comm::DeviceCommImpl* const d_comm_;\n\n    const int tune_layer_num_;\n\n    int& is_warm_up_;\n\n    std::unique_ptr<UnifiedAttentionLayer> attn_layer_;\n    std::unique_ptr<GatedDeltaNetLayer>    linear_attn_layer_;\n    std::unique_ptr<LlamaFfnLayer>         ffn_layer_;\n    std::unique_ptr<MoeFfnLayer>           moe_ffn_layer_;\n\n    void AllreduceResidualRMSnorm(Tensor&       hidden_states,\n                                  Tensor&       residual,\n                                  const Tensor& bias,\n                                  const Tensor& weight,\n                                  int           token_num,\n                                  int           t0,\n  
                                int           t1,\n                                  const int*    local_token_nums);\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/output_processor.cc",
    "content": "\n#include \"src/turbomind/models/output_processor.h\"\n\n#include <functional>\n\n#include \"src/turbomind/engine/request.h\"\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nusing std::vector;\nusing std::shared_ptr;\n\nstruct OutputProcessor::Impl {\n\n    static constexpr auto kAll = GenerationConfig::kAll;\n\n    const int vocab_size_;\n    const int max_logits_len_;\n    const int tp_rank_;\n\n    std::function<Tensor(const Tensor&)> lm_head_;\n\n    Impl(const ModelParam&                    model,\n         int                                  max_logits_len,\n         int                                  tp_rank,\n         int                                  phases,\n         std::function<Tensor(const Tensor&)> lm_head):\n        vocab_size_{(int)model.vocab_size},\n        max_logits_len_{max_logits_len},\n        tp_rank_{tp_rank},\n        lm_head_{std::move(lm_head)}\n    {\n        for (int i = 0; i < phases; ++i) {\n            data_.emplace_back();\n        }\n    }\n\n    struct Data {\n        Interval full_states;  // requested range for full hidden states\n        Interval full_logits;  // requested range for full logits\n\n        vector<std::tuple<int, int, Interval, Interval>> output_states;\n        vector<std::tuple<int, int, Interval, Interval>> output_logits;\n    };\n\n    vector<Data> data_;\n\n    struct Matching {\n        Interval& target;\n        const int offset_d;\n        Interval  src;\n        Interval  dst;\n\n        bool operator()(const Interval& x, int offset_s, Interval& merged)\n        {\n            if (auto y = target & x; y && y.begin() == target.begin()) {\n                dst    = {y.begin() - offset_d, y.size()};\n                src    = {offset_s + (y.begin() - x.begin()), y.size()};\n                merged = merged | src;\n                target = -(int)y.size() | target;\n                return true;\n            }\n            return false;\n        }\n    };\n\n    void Add(int phase, 
TensorMap& env)\n    {\n        const Buffer_<RequestCache*> rc = env.at(\"requests\").buffer();\n\n        for (int i = 0; i < rc.size(); ++i) {\n            auto& c = *rc[i];\n            auto& r = *c.req;\n            auto& g = r.gen_cfg;\n            if (g.output_logits) {\n                c.output_logits = g.output_logits == kAll ? Interval{c.step0} : Interval{c.prompt_len - 1};\n                c.logits_offset = c.output_logits.begin();\n            }\n            if (g.output_last_hidden_state) {\n                c.output_hidden_states =\n                    g.output_last_hidden_state == kAll ? Interval{c.step0} : Interval{c.prompt_len - 1};\n                c.hidden_states_offset = c.output_hidden_states.begin();\n                // dbg(&c.output_hidden_states, c.hidden_states_offset);\n            }\n        }\n    }\n\n    void Setup(int phase, TensorMap& env)\n    {\n        auto& d = data_.at(phase);\n\n        const auto& rc = env.at(\"batch\").data<BatchData*>()[0]->rc;\n\n        vector<Interval> all_tokens;\n        vector<Interval> sel_tokens;\n        for (int i = 0; i < rc.size(); ++i) {\n            using Size = Interval::Size;\n            auto& c    = *rc[i];\n            all_tokens.emplace_back(c.history_len + c.alpha, Size{c.input_len});\n            sel_tokens.emplace_back(c.history_len + c.alpha + c.input_len - 1, Size{1});\n            if (!c.generating) {\n                sel_tokens.back() = {};\n            }\n            // dbg(&all_tokens.back(), &sel_tokens.back());\n        }\n\n        const int token_num = *env.at(\"token_num\").data<int>();\n\n        d.full_logits = {INT_MAX, 0};\n        d.full_states = {INT_MAX, 0};\n\n        Interval select_states{INT_MAX, 0};\n        Interval select_logits{INT_MAX, 0};\n\n        d.output_logits = {};\n        d.output_states = {};\n\n        int offset = 0;\n\n        for (int i = 0; i < rc.size(); ++i) {\n            auto& c = *rc[i];\n            auto& g = c.req->gen_cfg;\n            
if (c.output_hidden_states) {\n                Matching m{c.output_hidden_states, c.hidden_states_offset};\n                int      type = 0;\n                if (m(sel_tokens[i], i, select_states)) {\n                    type = 1;\n                }\n                else if (m(all_tokens[i], offset, d.full_states)) {\n                    type = 2;\n                }\n                if (type) {\n                    d.output_states.emplace_back(i, type, m.src, m.dst);\n                    // dbg(type, &m.src, &m.dst);\n                }\n            }\n            if (c.output_logits) {\n                Matching m{c.output_logits, c.logits_offset};\n                int      type = 0;\n                if (m(sel_tokens[i], i, select_logits)) {\n                    type = 1;\n                }\n                else if (m(all_tokens[i], offset, d.full_logits)) {\n                    type = 2;\n                }\n                if (type) {\n                    d.output_logits.emplace_back(i, type, m.src, m.dst);\n                }\n            }\n            offset += c.input_len;\n        }\n\n        // logits depends on hidden states\n        d.full_states = d.full_states | d.full_logits;\n    }\n\n    void Prepare(int phase, TensorMap& env)\n    {\n        auto& d = data_.at(phase);\n        if (d.full_states) {\n            env.produce(\"output_hidden_states\", Tensor{});\n        }\n    }\n\n    template<class Ranges>\n    void OutputHiddenStates(const Ranges& ranges, const Tensor& h, int type, const vector<shared_ptr<RequestCache>>& rs)\n    {\n        for (const auto& [i, t, src, dst] : ranges) {\n            if (t == type) {\n                auto& out = rs[i]->req->outputs.at(\"last_hidden_state\");\n                if (tp_rank_ == 0) {\n                    // dbg(&src, &dst);\n                    Copy(h.slice(src.begin(), (int)src.size()), out.slice(dst.begin(), (int)dst.size()));\n                }\n            }\n        }\n    }\n\n    void 
ComputeAndOutputLogits(const Data& data, const Tensor& h, const vector<shared_ptr<RequestCache>>& rs)\n    {\n        const int step_size = max_logits_len_;\n\n        // Coroutine frame\n        int  p      = 0;\n        auto ranges = data.output_logits;\n\n        using Size = Interval::Size;\n\n        bool success = false;\n        // Erode the range iteratively until empty\n        for (auto r = data.full_logits; r; r = -step_size | r) {\n            // dbg(&r);\n            if (auto chunk = r & Interval{r.begin(), Size{step_size}}) {\n                // dbg(&chunk);\n                // Compute & output full logits by chunks\n                auto logits = lm_head_(h.slice(chunk.begin(), (int)chunk.size()));\n                success     = OutputLogitsImpl(ranges, p, logits, chunk.begin(), 2, rs);\n                if (success) {  // all requests satisfied, exit early\n                    break;\n                }\n            }\n        }\n\n        TM_CHECK(success);  // all requests must be satisfied at the end\n    }\n\n    template<class Ranges>\n    void OutputLogits(Ranges& ranges_, const Tensor& l, int type, const vector<shared_ptr<RequestCache>>& rs)\n    {\n        // Coroutine frame\n        int  p      = 0;\n        auto ranges = ranges_;\n\n        TM_CHECK(OutputLogitsImpl(ranges, p, l, /* base */ 0, type, rs));\n    }\n\n    template<class Ranges>\n    bool OutputLogitsImpl(\n        Ranges& ranges, int& p, const Tensor& l, int base, int type, const vector<shared_ptr<RequestCache>>& rs)\n    {\n        // dbg(\"OutputLogitsImpl\");\n        const auto stream = core::Context::stream().handle();\n        for (; p < ranges.size(); ++p) {\n            if (auto& [i, t, src, dst] = ranges[p]; t == type) {\n                Tensor&        out   = rs[i]->req->outputs.at(\"logits\");\n                const DataType dtype = out.dtype();\n                TM_CHECK_LE(base, src.begin());  // logical error\n                if (Interval msrc = src & Interval{base, 
Interval::Size{(int)l.shape(0)}}) {\n                    const int tokens = (int)msrc.size();\n                    Interval  mdst{dst.begin(), msrc.size()};\n                    // TODO: support strides in `DLTensor`, so that batched 1D copy can be used\n                    if (tp_rank_ == 0) {\n                        // dbg(&mdst, &msrc, tokens, out, base, l);\n                        TM_CHECK_EQ(cudaMemcpy2DAsync(out.slice(mdst.begin(), tokens).raw_data(),\n                                                      byte_size(dtype, out.stride(0)),\n                                                      l.slice(msrc.begin() - base, tokens).raw_data(),\n                                                      byte_size(dtype, l.stride(0)),\n                                                      byte_size(dtype, vocab_size_),\n                                                      tokens,\n                                                      cudaMemcpyDefault,\n                                                      stream),\n                                    0);\n                    }\n                    // move to next request if they are empty after the erosion\n                    src = -(int)msrc.size() | src;\n                    dst = -(int)mdst.size() | dst;\n                }\n                // dbg(&src, (int)src.size(), &dst, (int)dst.size());\n                if (src) {\n                    // request not completed, suspend and wait for next chunk\n                    return false;\n                }\n            }\n        }\n        return true;\n    }\n\n    void OutputHiddenStatesAndLogits(int phase, TensorMap& env, int type)\n    {\n        auto& d = data_.at(phase);\n        auto& b = *env.at(\"batch\").data<BatchData*>()[0];\n\n        if (type == 2 && d.full_states) {\n            auto hidden_states = env.consume(\"full_hidden_states\");\n            if (!d.output_states.empty()) {\n                OutputHiddenStates(d.output_states, hidden_states, 2, 
b.rc);\n            }\n            if (!d.output_logits.empty() && d.full_logits) {\n                ComputeAndOutputLogits(d, hidden_states, b.rc);\n            }\n        }\n\n        if (type == 1) {\n            if (!d.output_states.empty()) {\n                OutputHiddenStates(d.output_states, env.at(\"hidden_states\"), 1, b.rc);\n            }\n            if (!d.output_logits.empty()) {\n                OutputLogits(d.output_logits, env.at(\"logits\"), 1, b.rc);\n            }\n        }\n    }\n};\n\nOutputProcessor::~OutputProcessor() = default;\n\nOutputProcessor::OutputProcessor(\n    const ModelParam& model, int max_logits_len, int tp_rank, int phases, std::function<Tensor(const Tensor&)> lm_head):\n    impl_{std::make_unique<Impl>(model, max_logits_len, tp_rank, phases, std::move(lm_head))}\n{\n}\n\nvoid OutputProcessor::Run(BatchOp op, int phase, TensorMap& env)\n{\n    switch (op) {\n        case BatchOp::kAdd:\n            return impl_->Add(phase, env);\n        case BatchOp::kSetup:\n            return impl_->Setup(phase, env);\n        case BatchOp::kPrepare:\n            return impl_->Prepare(phase, env);\n        default:\n            return;\n    }\n}\n\nvoid OutputProcessor::OutputHiddenStatesAndLogits(int phase, TensorMap& env, int type)\n{\n    return impl_->OutputHiddenStatesAndLogits(phase, env, type);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/models/output_processor.h",
    "content": "#pragma once\n\n#include \"src/turbomind/engine/batch.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n\nnamespace turbomind {\n\nclass OutputProcessor {\npublic:\n    ~OutputProcessor();\n\n    OutputProcessor(const ModelParam&                    model,  //\n                    int                                  max_logits_len,\n                    int                                  tp_rank,\n                    int                                  phases,\n                    std::function<Tensor(const Tensor&)> lm_head);\n\n    void Run(BatchOp op, int phase, TensorMap& env);\n\n    void OutputHiddenStatesAndLogits(int phase, TensorMap& env, int type);\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/python/CMakeLists.txt",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n\ncmake_minimum_required(VERSION 3.11)\nproject(_turbomind LANGUAGES CXX CUDA)\n\nfind_package(pybind11 CONFIG)\nif(NOT pybind11_FOUND)\n    execute_process(COMMAND \"pybind11-config\" \"--cmakedir\"\n                    RESULT_VARIABLE _COMMAND_SUCCESS\n                    OUTPUT_VARIABLE pybind11_DIR\n                    OUTPUT_STRIP_TRAILING_WHITESPACE)\n    find_package(pybind11 CONFIG)\nendif()\n\npybind11_add_module(${PROJECT_NAME} bind.cpp)\ntarget_link_libraries(${PROJECT_NAME} PRIVATE turbomind xgrammar)\n\npybind11_add_module(_xgrammar xgrammar_bind.cpp)\ntarget_link_libraries(_xgrammar PRIVATE core xgrammar)\ntarget_compile_features(_xgrammar PRIVATE cxx_std_14)\n\nif (CALL_FROM_SETUP_PY)\n  string(REPLACE \".\" \";\" _ver ${CMAKE_CUDA_COMPILER_VERSION})\n  list(GET _ver 0 CUDA_MAJOR)\n\n  if(CUDA_MAJOR GREATER_EQUAL \"13\")\n    set(_INSTALL_CUDA_RPATH\n        \"\\$ORIGIN\"\n        \"\\$ORIGIN/../../nvidia/nccl/lib/\"\n        \"\\$ORIGIN/../../nvidia/cu${CUDA_MAJOR}/lib/\"\n    )\n  else()\n    set(_INSTALL_CUDA_RPATH\n        \"\\$ORIGIN\"\n        \"\\$ORIGIN/../../nvidia/nccl/lib/\"\n        \"\\$ORIGIN/../../nvidia/cuda_runtime/lib/\"\n        \"\\$ORIGIN/../../nvidia/cublas/lib/\"\n        \"\\$ORIGIN/../../nvidia/curand/lib/\"\n    )\n  endif()\n  set_target_properties(${PROJECT_NAME} PROPERTIES\n      BUILD_RPATH \"\\$ORIGIN\"\n      INSTALL_RPATH \"${_INSTALL_CUDA_RPATH}\"\n  )\nendif ()\n"
  },
  {
    "path": "src/turbomind/python/bind.cpp",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <memory>\n#include <sstream>\n#include <stdexcept>\n\n#include <cuda_runtime.h>\n\n#include <pybind11/functional.h>\n#include <pybind11/pybind11.h>\n#include <pybind11/pytypes.h>\n#include <pybind11/stl.h>\n#include <pybind11/stl_bind.h>\n\n#include \"xgrammar/compiler.h\"\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/core/tensor.h\"\n#include \"src/turbomind/engine/model_request.h\"\n#include \"src/turbomind/python/dlpack.h\"\n#include \"src/turbomind/turbomind.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/metrics.h\"\n\nnamespace py = pybind11;\nnamespace ft = turbomind;\nusing namespace pybind11::literals;\n\nusing ft::core::Tensor;\n\n// prepare to bind container\nusing TensorMap = ft::core::TensorMap;\nPYBIND11_MAKE_OPAQUE(TensorMap);\nstatic const char kDlTensorCapsuleName[] = \"dltensor\";\n\nDLDevice getDLDevice(const Tensor& tensor)\n{\n    int device_id = 0;\n    if (tensor.device().type == ft::kDEVICE) {\n        cudaPointerAttributes ptr_attr{};\n        cudaPointerGetAttributes(&ptr_attr, tensor.raw_data());\n        device_id = ptr_attr.device;\n    }\n\n    DLDevice device{kDLCPU, device_id};\n\n    switch (tensor.device().type) {\n        case ft::kCPU:\n            device.device_type = DLDeviceType::kDLCPU;\n            break;\n        case ft::kCPUpinned:\n            device.device_type = DLDeviceType::kDLCUDAHost;\n            break;\n        case ft::kDEVICE:\n            device.device_type = DLDeviceType::kDLCUDA;\n            break;\n        default:\n            break;\n    }\n\n    return device;\n}\n\nDLManagedTensor* TritonTensorToDLManagedTensor(Tensor& tensor)\n{\n    DLDevice   device = getDLDevice(tensor);\n    DLDataType data_type{0, 0, 1};\n    using ft::data_type_v;\n    switch (tensor.dtype()) {\n        case data_type_v<bool>:\n            data_type.code = DLDataTypeCode::kDLBool;\n           
 data_type.bits = 8;\n            break;\n        case data_type_v<uint8_t>:\n            data_type.code = DLDataTypeCode::kDLUInt;\n            data_type.bits = 8;\n            break;\n        case data_type_v<uint16_t>:\n            data_type.code = DLDataTypeCode::kDLUInt;\n            data_type.bits = 16;\n            break;\n        case data_type_v<uint32_t>:\n            data_type.code = DLDataTypeCode::kDLUInt;\n            data_type.bits = 32;\n            break;\n        case data_type_v<uint64_t>:\n            data_type.code = DLDataTypeCode::kDLUInt;\n            data_type.bits = 64;\n            break;\n        case data_type_v<int8_t>:\n            data_type.code = DLDataTypeCode::kDLInt;\n            data_type.bits = 8;\n            break;\n        case data_type_v<int16_t>:\n            data_type.code = DLDataTypeCode::kDLInt;\n            data_type.bits = 16;\n            break;\n        case data_type_v<int32_t>:\n            data_type.code = DLDataTypeCode::kDLInt;\n            data_type.bits = 32;\n            break;\n        case data_type_v<int64_t>:\n            data_type.code = DLDataTypeCode::kDLInt;\n            data_type.bits = 64;\n            break;\n        case data_type_v<turbomind::half_t>:\n            data_type.code = DLDataTypeCode::kDLFloat;\n            data_type.bits = 16;\n            break;\n        case data_type_v<float>:\n            data_type.code = DLDataTypeCode::kDLFloat;\n            data_type.bits = 32;\n            break;\n        case data_type_v<double>:\n            data_type.code = DLDataTypeCode::kDLFloat;\n            data_type.bits = 64;\n            break;\n        case data_type_v<turbomind::bfloat16_t>:\n            data_type.code = DLDataTypeCode::kDLBfloat;\n            data_type.bits = 16;\n            break;\n        default:\n            break;\n    }\n\n    static_assert(sizeof(int64_t) == sizeof(tensor.shape(0)));\n\n    Tensor*  ctx = new Tensor(tensor);\n    DLTensor 
dl_tensor{const_cast<void*>(ctx->raw_data()),\n                       device,\n                       (int32_t)(ctx->ndim()),\n                       data_type,\n                       (int64_t*)ctx->shape().data(),\n                       (int64_t*)(nullptr),\n                       0};\n    return new DLManagedTensor{dl_tensor, ctx, [](DLManagedTensor* dlmt) {  //\n                                   delete (Tensor*)dlmt->manager_ctx;\n                                   delete dlmt;\n                               }};\n}\n\nft::DeviceType getMemoryType(DLDevice device)\n{\n    switch (device.device_type) {\n        case DLDeviceType::kDLCUDAHost:\n            return ft::DeviceType::kCPUpinned;\n        case DLDeviceType::kDLCUDA:\n            return ft::DeviceType::kDEVICE;\n        case DLDeviceType::kDLCPU:\n        default:\n            return ft::DeviceType::kCPU;\n    }\n}\n\nft::DataType getDataType(DLDataType data_type)\n{\n    using ft::data_type_v;\n    switch (data_type.code) {\n        case DLDataTypeCode::kDLUInt:\n            switch (data_type.bits) {\n                case 8:\n                    return data_type_v<uint8_t>;\n                case 16:\n                    return data_type_v<uint16_t>;\n                case 32:\n                    return data_type_v<uint32_t>;\n                case 64:\n                    return data_type_v<uint64_t>;\n                default:\n                    return data_type_v<void>;\n            }\n            break;\n        case DLDataTypeCode::kDLInt:\n            switch (data_type.bits) {\n                case 8:\n                    return data_type_v<int8_t>;\n                case 16:\n                    return data_type_v<int16_t>;\n                case 32:\n                    return data_type_v<int32_t>;\n                case 64:\n                    return data_type_v<int64_t>;\n                default:\n                    return data_type_v<void>;\n            }\n            break;\n        case 
DLDataTypeCode::kDLFloat:\n            switch (data_type.bits) {\n                case 16:\n                    return data_type_v<turbomind::half_t>;\n                case 32:\n                    return data_type_v<float>;\n                case 64:\n                    return data_type_v<double>;\n                default:\n                    return data_type_v<void>;\n            }\n            break;\n        case DLDataTypeCode::kDLBfloat:\n            switch (data_type.bits) {\n                case 16:\n                    return data_type_v<turbomind::bfloat16_t>;\n                default:\n                    return data_type_v<void>;\n            }\n            break;\n        case DLDataTypeCode::kDLBool:\n            return data_type_v<bool>;\n        default:\n            return data_type_v<void>;\n    }\n}\n\nstd::shared_ptr<Tensor> DLManagedTensorToTritonTensor(DLManagedTensor* tensor)\n{\n    auto& dl_tensor = tensor->dl_tensor;\n    auto  where     = getMemoryType(dl_tensor.device);\n    auto  dtype     = getDataType(dl_tensor.dtype);\n    assert(dl_tensor.ndim > 0);\n    std::vector<ft::core::ssize_t> shape(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim);\n\n    std::shared_ptr<void> ptr{dl_tensor.data, [tensor](void* p) {\n                                  if (tensor->deleter) {\n                                      tensor->deleter(tensor);\n                                  }\n                              }};\n    return std::make_shared<Tensor>(ptr, std::move(shape), dtype, where);\n}\n\nstatic void safe_memcpy(void* dst, const void* src, size_t size)\n{\n    cudaPointerAttributes dat{};\n    cudaPointerAttributes sat{};\n    ft::check_cuda_error(cudaPointerGetAttributes(&dat, dst));\n    ft::check_cuda_error(cudaPointerGetAttributes(&sat, src));\n    try {\n        if (dat.devicePointer && sat.devicePointer) {\n            // Both can be accessed from current context\n            ft::check_cuda_error(cudaMemcpy(dst, src, size, 
cudaMemcpyDefault));\n        }\n        else if (dat.type == cudaMemoryTypeDevice && sat.type == cudaMemoryTypeDevice) {\n            if (dat.device != sat.device) {\n                // On different devices, try peer memcpy\n                ft::check_cuda_error(cudaMemcpyPeer(dst, dat.device, src, sat.device, size));\n            }\n            else {\n                // Same device, switch to the device first (this is unlikely)\n                ft::CudaDeviceGuard guard(dat.device);\n                ft::check_cuda_error(cudaMemcpy(dst, src, size, cudaMemcpyDefault));\n            }\n        }\n        else {\n            // Unknown case, give it a try anyway\n            ft::check_cuda_error(cudaMemcpy(dst, src, size, cudaMemcpyDefault));\n        }\n    }\n    catch (...) {\n        int device_id{-1};\n        cudaGetDevice(&device_id);\n        TM_LOG_ERROR(\"cudaMemcpy failed: dst=(%d, %d, %p, %p), src=(%d, %d, %p, %p), size=%s, device=%d\",\n                     (int)dat.type,\n                     dat.device,\n                     dat.devicePointer,\n                     dat.hostPointer,\n                     (int)sat.type,\n                     sat.device,\n                     sat.devicePointer,\n                     sat.hostPointer,\n                     std::to_string(size).c_str(),\n                     device_id);\n        throw;\n    }\n}\n\nnamespace {\n\nstruct ScopedGIL {\n    ScopedGIL(const ScopedGIL&) = delete;\n    ScopedGIL& operator=(const ScopedGIL&) = delete;\n    ScopedGIL(ScopedGIL&&)                 = delete;\n    ScopedGIL& operator=(ScopedGIL&&) = delete;\n    ScopedGIL()\n    {\n        state = PyGILState_Ensure();\n    }\n    ~ScopedGIL()\n    {\n        PyGILState_Release(state);\n    }\n    PyGILState_STATE state;\n};\n\n}  // namespace\n\nPYBIND11_MODULE(_turbomind, m)\n{\n    py::class_<ft::RequestMetrics, std::shared_ptr<ft::RequestMetrics>>(m, \"RequestMetrics\")\n        .def(py::init())\n        
.def_property_readonly(\"enqueue_time\",\n                               [](ft::RequestMetrics& m) { return m.enqueue_time.load(std::memory_order_relaxed); })\n        .def_property_readonly(\"scheduled_time\",\n                               [](ft::RequestMetrics& m) { return m.scheduled_time.load(std::memory_order_relaxed); });\n\n    py::class_<ft::ScheduleMetrics, std::shared_ptr<ft::ScheduleMetrics>>(m, \"ScheduleMetrics\")\n        .def(py::init())\n        .def_readonly(\"total_seqs\", &ft::ScheduleMetrics::total_seqs)\n        .def_readonly(\"active_seqs\", &ft::ScheduleMetrics::active_seqs)\n        .def_readonly(\"waiting_seqs\", &ft::ScheduleMetrics::waiting_seqs)\n        .def_readonly(\"total_blocks\", &ft::ScheduleMetrics::total_blocks)\n        .def_readonly(\"active_blocks\", &ft::ScheduleMetrics::active_blocks)\n        .def_readonly(\"cached_blocks\", &ft::ScheduleMetrics::cached_blocks)\n        .def_readonly(\"free_blocks\", &ft::ScheduleMetrics::free_blocks);\n\n    py::class_<ft::SessionParam>(m, \"SessionParam\")\n        .def(py::init([](uint64_t id, int step, bool start, bool end) {\n                 if (!start && end) {\n                     throw std::logic_error(\"unsupported arguments: start=false, end=true\");\n                 }\n                 ft::SessionParam param{};\n                 param.id         = id;\n                 param.step       = step;\n                 param.start_flag = start;\n                 param.end_flag   = end;\n                 return param;\n             }),\n             \"id\"_a,\n             \"step\"_a,\n             \"start\"_a,\n             \"end\"_a)\n        .def_readwrite(\"id\", &ft::SessionParam::id)\n        .def_readwrite(\"step\", &ft::SessionParam::step)\n        .def_readwrite(\"start\", &ft::SessionParam::start_flag)\n        .def_readwrite(\"end\", &ft::SessionParam::end_flag);\n\n    py::class_<ft::GenerationConfig>(m, \"GenerationConfig\")\n        .def(py::init())\n        
.def_readwrite(\"max_new_tokens\", &ft::GenerationConfig::max_new_tokens)\n        .def_readwrite(\"min_new_tokens\", &ft::GenerationConfig::min_new_tokens)\n        .def_readwrite(\"eos_ids\", &ft::GenerationConfig::eos_ids)\n        .def_readwrite(\"stop_ids\", &ft::GenerationConfig::stop_ids)\n        .def_readwrite(\"bad_ids\", &ft::GenerationConfig::bad_ids)\n        .def_readwrite(\"top_p\", &ft::GenerationConfig::top_p)\n        .def_readwrite(\"top_k\", &ft::GenerationConfig::top_k)\n        .def_readwrite(\"min_p\", &ft::GenerationConfig::min_p)\n        .def_readwrite(\"temperature\", &ft::GenerationConfig::temperature)\n        .def_readwrite(\"repetition_penalty\", &ft::GenerationConfig::repetition_penalty)\n        .def_readwrite(\"random_seed\", &ft::GenerationConfig::random_seed)\n        .def_readwrite(\"output_logprobs\", &ft::GenerationConfig::output_logprobs)\n        .def_readwrite(\"output_last_hidden_state\", &ft::GenerationConfig::output_last_hidden_state)\n        .def_readwrite(\"output_logits\", &ft::GenerationConfig::output_logits)\n        .def(\"__repr__\", [](const ft::GenerationConfig& c) {\n            std::ostringstream oss;\n            oss << c;\n            return oss.str();\n        });\n\n    py::class_<ft::RequestState, std::unique_ptr<ft::RequestState>>(m, \"RequestState\")\n        .def_readonly(\"status\", &ft::RequestState::status)\n        .def_readonly(\"seq_len\", &ft::RequestState::seq_len);\n\n    py::class_<ft::AtomicRequestState, std::shared_ptr<ft::AtomicRequestState>>(m, \"AtomicRequestState\")\n        .def(\"consume\", [](ft::AtomicRequestState& s) { return s.exchange(nullptr); });\n\n    // data type\n    {\n        using namespace turbomind;\n        py::enum_<ft::DataType>(m, \"DataType\")\n            .value(\"TYPE_INVALID\", kNull)\n            .value(\"TYPE_BOOL\", kBool)\n            .value(\"TYPE_UINT8\", kUint8)\n            .value(\"TYPE_UINT16\", kUint16)\n            .value(\"TYPE_UINT32\", 
kUint32)\n            .value(\"TYPE_UINT64\", kUint64)\n            .value(\"TYPE_INT8\", kInt8)\n            .value(\"TYPE_INT16\", kInt16)\n            .value(\"TYPE_INT32\", kInt32)\n            .value(\"TYPE_INT64\", kInt64)\n            .value(\"TYPE_FP16\", kFloat16)\n            .value(\"TYPE_FP32\", kFloat32)\n            .value(\"TYPE_FP64\", kFloat64)\n            .value(\"TYPE_BF16\", kBfloat16);\n\n        // memory type\n        py::enum_<ft::DeviceType>(m, \"MemoryType\")\n            .value(\"MEMORY_CPU\", ft::DeviceType::kCPU)\n            .value(\"MEMORY_CPU_PINNED\", ft::DeviceType::kCPUpinned)\n            .value(\"MEMORY_GPU\", ft::DeviceType::kDEVICE);\n    }\n\n    // tensor\n    py::class_<Tensor, std::shared_ptr<Tensor>>(m, \"Tensor\")\n        .def_property_readonly(\"where\", [](const Tensor& t) { return t.device().type; })\n        .def_property_readonly(\"type\", [](const Tensor& t) { return t.dtype(); })\n        .def_property_readonly(\"shape\", [](const Tensor& t) { return t.shape(); })\n        .def_property_readonly(\"data\", [](const Tensor& t) { return t.raw_data(); })\n        .def(\n            \"copy_from\",\n            [](Tensor& self, py::object obj) {\n                py::capsule      cap = obj.attr(\"__dlpack__\")();\n                DLManagedTensor* dlmt =\n                    static_cast<DLManagedTensor*>(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));\n                auto src = DLManagedTensorToTritonTensor(dlmt);\n                // take ownership of capsule's payload\n                cap.set_name(\"used_dltensor\");\n\n                TM_CHECK_EQ(self.byte_size(), src->byte_size()) << self << \" \" << *src;\n                safe_memcpy(self.raw_data(), src->raw_data(), self.byte_size());\n            },\n            \"tensor\"_a)\n        .def(\n            \"__dlpack__\",\n            [](Tensor& self, long stream) {\n                DLManagedTensor* dlmt = TritonTensorToDLManagedTensor(self);\n               
 return py::capsule(dlmt, kDlTensorCapsuleName, [](PyObject* obj) {\n                    DLManagedTensor* dlmt =\n                        static_cast<DLManagedTensor*>(PyCapsule_GetPointer(obj, kDlTensorCapsuleName));\n                    if (dlmt) {\n                        dlmt->deleter(dlmt);\n                    }\n                    else {\n                        // The tensor has been deleted. Clear any error from\n                        // PyCapsule_GetPointer.\n                        PyErr_Clear();\n                    }\n                });\n            },\n            \"stream\"_a = 0)\n        .def(\"__dlpack_device__\", [](const Tensor& self) {\n            auto device = getDLDevice(self);\n            return std::tuple<int, int>(int(device.device_type), device.device_id);\n        });\n    m.def(\n        \"from_dlpack\",\n        [](py::object obj) {\n            py::capsule      cap = obj.attr(\"__dlpack__\")();\n            DLManagedTensor* dlmt =\n                static_cast<DLManagedTensor*>(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName));\n            auto ret = DLManagedTensorToTritonTensor(dlmt);\n            // take ownership of capsule's payload\n            cap.set_name(\"used_dltensor\");\n            return ret;\n        },\n        \"dl_managed_tensor\"_a);\n\n    py::bind_map<TensorMap, std::shared_ptr<TensorMap>>(m, \"TensorMap\");\n\n    using ft::ModelRequest;\n    py::class_<ModelRequest>(m, \"ModelRequest\")\n        .def(\n            \"forward\",\n            [](ModelRequest*               model_request,\n               std::shared_ptr<TensorMap>  input_tensors,\n               const ft::SessionParam&     session,\n               const ft::GenerationConfig& gen_cfg,\n               bool                        stream_output,\n               bool                        enable_metrics,\n               std::function<void()>       cb) {\n                ModelRequest::InputParam param{};\n                param.tensors        
= std::move(input_tensors);\n                param.session        = session;\n                param.gen_cfg        = gen_cfg;\n                param.stream_output  = stream_output;\n                param.enable_metrics = enable_metrics;\n\n                auto ret = model_request->Forward(std::move(param), [cb = std::move(cb)]() {\n                    try {\n                        cb();\n                    }\n                    catch (const py::error_already_set& e) {\n                        std::cerr << e.what() << std::endl;\n                    }\n                });\n                return std::make_tuple(std::move(ret.tensors), std::move(ret.state), std::move(ret.metrics));\n            },\n            py::call_guard<py::gil_scoped_release>(),\n            \"input_tensors\"_a,\n            \"session\"_a,\n            \"gen_cfg\"_a,\n            \"stream_output\"_a,\n            \"enable_metrics\"_a,\n            \"cb\"_a)\n        .def(\n            \"cancel\",\n            [](ModelRequest* model_request) {\n                model_request->Cancel();  //\n            },\n            py::call_guard<py::gil_scoped_release>())\n        .def(\n            \"end\",\n            [](ModelRequest* model_request, std::function<void(int)> cb, uint64_t session_id) {\n                model_request->End(std::move(cb), session_id);  //\n            },\n            py::call_guard<py::gil_scoped_release>(),\n            \"cb\"_a,\n            \"session_id\"_a)\n        .def(\n            \"set_grammar\",\n            [](ModelRequest* model_request, const xgrammar::CompiledGrammar& grammar) {\n                TM_LOG_INFO(\"Set grammar for model_request\");\n                model_request->setGrammar(grammar);\n            },\n            py::call_guard<py::gil_scoped_release>(),\n            \"grammar\"_a);\n\n    // transformer model\n    using ft::TurboMind;\n    py::class_<TurboMind, std::shared_ptr<TurboMind>>(m, \"TurboMind\")\n        .def_static(\n            
\"create\",\n            [](std::string model_dir, std::string config, std::string weight_type) -> std::shared_ptr<TurboMind> {\n                auto gil_factory = [] {  //\n                    // erase the type\n                    return std::static_pointer_cast<void>(std::make_shared<ScopedGIL>());\n                };\n                auto no_gil_deleter = [](TurboMind* ptr) {\n                    pybind11::gil_scoped_release release;\n                    delete ptr;\n                };\n\n                std::shared_ptr<TurboMind> model(new TurboMind(model_dir, config, gil_factory), no_gil_deleter);\n                return model;\n            },\n            \"model_dir\"_a,\n            \"config\"_a      = \"\",\n            \"weight_type\"_a = \"half\")\n        .def(\n            \"create_request\",\n            [](TurboMind* model) { return model->CreateRequest(); },\n            py::call_guard<py::gil_scoped_release>())\n        .def(\"create_weights\", &TurboMind::CreateWeights, py::call_guard<py::gil_scoped_release>(), \"index\"_a)\n        .def(\n            \"get_weights\",\n            [](TurboMind* model, int index) { return model->GetWeights(index); },\n            py::call_guard<py::gil_scoped_release>(),\n            \"index\"_a)\n        .def(\n            \"process_weight\",\n            [](TurboMind* model, int index) { model->ProcessWeights(index); },\n            py::call_guard<py::gil_scoped_release>(),\n            \"index\"_a)\n        .def(\n            \"create_engine\",\n            [](TurboMind* model, int index) { model->CreateEngine(index); },\n            py::call_guard<py::gil_scoped_release>(),\n            \"index\"_a)\n        .def(\n            \"get_schedule_metrics\",\n            [](TurboMind* model, int index) { return model->GetScheduleMetrics(index); },\n            py::call_guard<py::gil_scoped_release>(),\n            \"index\"_a)\n        .def(\n            \"sleep\",\n            [](TurboMind* model, int index, int 
level) { model->Sleep(index, level); },\n            py::call_guard<py::gil_scoped_release>(),\n            \"index\"_a,\n            \"level\"_a)\n        .def(\n            \"wakeup\",\n            [](TurboMind* model, int index, const std::vector<std::string>& tags) { model->WakeUp(index, tags); },\n            py::call_guard<py::gil_scoped_release>(),\n            \"index\"_a,\n            \"tags\"_a)\n        .def(\"is_dummy_node\", [](TurboMind* model) { return model->is_dummy_node(); });\n}\n"
  },
  {
    "path": "src/turbomind/python/dlpack.h",
    "content": "/*!\n *  Copyright (c) 2017 by Contributors\n * \\file dlpack.h\n * \\brief The common header of DLPack.\n */\n#ifndef DLPACK_DLPACK_H_\n#define DLPACK_DLPACK_H_\n\n/**\n * \\brief Compatibility with C++\n */\n#ifdef __cplusplus\n#define DLPACK_EXTERN_C extern \"C\"\n#else\n#define DLPACK_EXTERN_C\n#endif\n\n/*! \\brief The current major version of dlpack */\n#define DLPACK_MAJOR_VERSION 1\n\n/*! \\brief The current minor version of dlpack */\n#define DLPACK_MINOR_VERSION 0\n\n/*! \\brief DLPACK_DLL prefix for windows */\n#ifdef _WIN32\n#ifdef DLPACK_EXPORTS\n#define DLPACK_DLL __declspec(dllexport)\n#else\n#define DLPACK_DLL __declspec(dllimport)\n#endif\n#else\n#define DLPACK_DLL\n#endif\n\n#include <stddef.h>\n#include <stdint.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/*!\n * \\brief The DLPack version.\n *\n * A change in major version indicates that we have changed the\n * data layout of the ABI - DLManagedTensorVersioned.\n *\n * A change in minor version indicates that we have added new\n * code, such as a new device type, but the ABI is kept the same.\n *\n * If an obtained DLPack tensor has a major version that disagrees\n * with the version number specified in this header file\n * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter\n * (and it is safe to do so). It is not safe to access any other fields\n * as the memory layout will have changed.\n *\n * In the case of a minor version mismatch, the tensor can be safely used as\n * long as the consumer knows how to interpret all fields. Minor version\n * updates indicate the addition of enumeration values.\n */\ntypedef struct {\n    /*! \\brief DLPack major version. */\n    uint32_t major;\n    /*! \\brief DLPack minor version. */\n    uint32_t minor;\n} DLPackVersion;\n\n/*!\n * \\brief The device type in DLDevice.\n */\n#ifdef __cplusplus\ntypedef enum: int32_t\n{\n#else\ntypedef enum\n{\n#endif\n    /*! \\brief CPU device */\n    kDLCPU = 1,\n    /*! 
\\brief CUDA GPU device */\n    kDLCUDA = 2,\n    /*!\n     * \\brief Pinned CUDA CPU memory by cudaMallocHost\n     */\n    kDLCUDAHost = 3,\n    /*! \\brief OpenCL devices. */\n    kDLOpenCL = 4,\n    /*! \\brief Vulkan buffer for next generation graphics. */\n    kDLVulkan = 7,\n    /*! \\brief Metal for Apple GPU. */\n    kDLMetal = 8,\n    /*! \\brief Verilog simulator buffer */\n    kDLVPI = 9,\n    /*! \\brief ROCm GPUs for AMD GPUs */\n    kDLROCM = 10,\n    /*!\n     * \\brief Pinned ROCm CPU memory allocated by hipMallocHost\n     */\n    kDLROCMHost = 11,\n    /*!\n     * \\brief Reserved extension device type,\n     * used for quickly test extension device\n     * The semantics can differ depending on the implementation.\n     */\n    kDLExtDev = 12,\n    /*!\n     * \\brief CUDA managed/unified memory allocated by cudaMallocManaged\n     */\n    kDLCUDAManaged = 13,\n    /*!\n     * \\brief Unified shared memory allocated on a oneAPI non-partititioned\n     * device. Call to oneAPI runtime is required to determine the device\n     * type, the USM allocation type and the sycl context it is bound to.\n     *\n     */\n    kDLOneAPI = 14,\n    /*! \\brief GPU support for next generation WebGPU standard. */\n    kDLWebGPU = 15,\n    /*! \\brief Qualcomm Hexagon DSP */\n    kDLHexagon = 16,\n} DLDeviceType;\n\n/*!\n * \\brief A Device for Tensor and operator.\n */\ntypedef struct {\n    /*! \\brief The device type used in the device. */\n    DLDeviceType device_type;\n    /*!\n     * \\brief The device index.\n     * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.\n     */\n    int32_t device_id;\n} DLDevice;\n\n/*!\n * \\brief The type code options DLDataType.\n */\ntypedef enum\n{\n    /*! \\brief signed integer */\n    kDLInt = 0U,\n    /*! \\brief unsigned integer */\n    kDLUInt = 1U,\n    /*! 
\\brief IEEE floating point */\n    kDLFloat = 2U,\n    /*!\n     * \\brief Opaque handle type, reserved for testing purposes.\n     * Frameworks need to agree on the handle data type for the exchange to be well-defined.\n     */\n    kDLOpaqueHandle = 3U,\n    /*! \\brief bfloat16 */\n    kDLBfloat = 4U,\n    /*!\n     * \\brief complex number\n     * (C/C++/Python layout: compact struct per complex number)\n     */\n    kDLComplex = 5U,\n    /*! \\brief boolean */\n    kDLBool = 6U,\n} DLDataTypeCode;\n\n/*!\n * \\brief The data type the tensor can hold. The data type is assumed to follow the\n * native endian-ness. An explicit error message should be raised when attempting to\n * export an array with non-native endianness\n *\n *  Examples\n *   - float: type_code = 2, bits = 32, lanes = 1\n *   - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4\n *   - int8: type_code = 0, bits = 8, lanes = 1\n *   - std::complex<float>: type_code = 5, bits = 64, lanes = 1\n *   - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of\n * bool is 8 bits)\n */\ntypedef struct {\n    /*!\n     * \\brief Type code of base types.\n     * We keep it uint8_t instead of DLDataTypeCode for minimal memory\n     * footprint, but the value should be one of DLDataTypeCode enum values.\n     * */\n    uint8_t code;\n    /*!\n     * \\brief Number of bits, common choices are 8, 16, 32.\n     */\n    uint8_t bits;\n    /*! \\brief Number of lanes in the type, used for vector types. */\n    uint16_t lanes;\n} DLDataType;\n\n/*!\n * \\brief Plain C Tensor object, does not manage memory.\n */\ntypedef struct {\n    /*!\n     * \\brief The data pointer points to the allocated data. This will be CUDA\n     * device pointer or cl_mem handle in OpenCL. It may be opaque on some device\n     * types. This pointer is always aligned to 256 bytes as in CUDA. 
The\n     * `byte_offset` field should be used to point to the beginning of the data.\n     *\n     * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow,\n     * TVM, perhaps others) do not adhere to this 256 byte alignment requirement\n     * on CPU/CUDA/ROCm, and always use `byte_offset=0`.  This must be fixed\n     * (after which this note will be updated); at the moment it is recommended\n     * to not rely on the data pointer being correctly aligned.\n     *\n     * For a given DLTensor, the size of memory required to store the contents of\n     * data is calculated as follows:\n     *\n     * \\code{.c}\n     * static inline size_t GetDataSize(const DLTensor* t) {\n     *   size_t size = 1;\n     *   for (int32_t i = 0; i < t->ndim; ++i) {\n     *     size *= t->shape[i];\n     *   }\n     *   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;\n     *   return size;\n     * }\n     * \\endcode\n     */\n    void* data;\n    /*! \\brief The device of the tensor */\n    DLDevice device;\n    /*! \\brief Number of dimensions */\n    int32_t ndim;\n    /*! \\brief The data type of the pointer */\n    DLDataType dtype;\n    /*! \\brief The shape of the tensor */\n    int64_t* shape;\n    /*!\n     * \\brief strides of the tensor (in number of elements, not bytes)\n     *  can be NULL, indicating tensor is compact and row-major.\n     */\n    int64_t* strides;\n    /*! \\brief The offset in bytes to the beginning pointer to data */\n    uint64_t byte_offset;\n} DLTensor;\n\n/*!\n * \\brief C Tensor object, manages memory of DLTensor. This data structure is\n *  intended to facilitate the borrowing of DLTensor by another framework. It is\n *  not meant to transfer the tensor. 
When the borrowing framework doesn't need\n *  the tensor, it should call the deleter to notify the host that the resource\n *  is no longer needed.\n *\n * \\note This data structure is used as Legacy DLManagedTensor\n *       in DLPack exchange and is deprecated after DLPack v0.8.\n *       Use DLManagedTensorVersioned instead.\n *       This data structure may get renamed or deleted in future versions.\n *\n * \\sa DLManagedTensorVersioned\n */\ntypedef struct DLManagedTensor {\n    /*! \\brief DLTensor which is being memory managed */\n    DLTensor dl_tensor;\n    /*! \\brief the context of the original host framework in which\n     *   DLManagedTensor is used. It can also be NULL.\n     */\n    void* manager_ctx;\n    /*!\n     * \\brief Destructor - this should be called\n     * to destruct the manager_ctx which backs the DLManagedTensor. It can be\n     * NULL if there is no way for the caller to provide a reasonable destructor.\n     * The destructor deletes the argument self as well.\n     */\n    void (*deleter)(struct DLManagedTensor* self);\n} DLManagedTensor;\n\n// bit masks used in the DLManagedTensorVersioned\n\n/*! \\brief bit mask to indicate that the tensor is read only. */\n#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)\n\n/*!\n * \\brief A versioned and managed C Tensor object, manages memory of DLTensor.\n *\n * This data structure is intended to facilitate the borrowing of DLTensor by\n * another framework. It is not meant to transfer the tensor. 
When the borrowing\n * framework doesn't need the tensor, it should call the deleter to notify the\n * host that the resource is no longer needed.\n *\n * \\note This is the current standard DLPack exchange data structure.\n */\nstruct DLManagedTensorVersioned {\n    /*!\n     * \\brief The API and ABI version of the current managed Tensor\n     */\n    DLPackVersion version;\n    /*!\n     * \\brief the context of the original host framework.\n     *\n     * Stores the context in which DLManagedTensorVersioned is used in the\n     * framework. It can also be NULL.\n     */\n    void* manager_ctx;\n    /*!\n     * \\brief Destructor.\n     *\n     * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned.\n     * It can be NULL if there is no way for the caller to provide a reasonable\n     * destructor. The destructor deletes the argument self as well.\n     */\n    void (*deleter)(struct DLManagedTensorVersioned* self);\n    /*!\n     * \\brief Additional bitmask flags with information about the tensor.\n     *\n     * By default the flags should be set to 0.\n     *\n     * \\note Future ABI changes should keep everything until this field\n     *       stable, to ensure that deleter can be correctly called.\n     *\n     * \\sa DLPACK_FLAG_BITMASK_READ_ONLY\n     */\n    uint64_t flags;\n    /*! \\brief DLTensor which is being memory managed */\n    DLTensor dl_tensor;\n};\n\n#ifdef __cplusplus\n}  // DLPACK_EXTERN_C\n#endif\n#endif  // DLPACK_DLPACK_H_\n"
  },
  {
    "path": "src/turbomind/python/xgrammar_bind.cpp",
    "content": "// Modified from xgrammar/nanobind/nanobind.cc from xgrammar project.\n/*!\n *  Copyright (c) 2024 by Contributors\n * \\file xgrammar/nanobind/nanobind.cc\n */\n\n#include <memory>\n#include <sstream>\n#include <stdexcept>\n\n#include <pybind11/functional.h>\n#include <pybind11/pybind11.h>\n#include <pybind11/pytypes.h>\n#include <pybind11/stl.h>\n#include <pybind11/stl_bind.h>\n\n#include <xgrammar/xgrammar.h>\n\n#include \"src/turbomind/core/check.h\"\n\nnamespace py = pybind11;\nusing namespace xgrammar;\nusing namespace pybind11::literals;\n\nnamespace {\n\nstatic const std::vector<std::string>\nCommonEncodedVocabType(const py::typing::List<std::variant<std::string, py::bytes>>& lst)\n{\n    std::vector<std::string> out;\n    out.reserve(lst.size());\n    for (const auto& h : lst) {\n        if (py::isinstance<py::str>(h)) {\n            out.emplace_back(h.cast<std::string>());\n        }\n        else if (py::isinstance<py::bytes>(h)) {\n            out.emplace_back(h.cast<py::bytes>());\n        }\n        else {\n            throw std::invalid_argument(\"encoded_vocab items must be str or bytes\");\n        }\n    }\n    return out;\n}\n\nTokenizerInfo TokenizerInfo_Init(const std::vector<std::string>&     encoded_vocab,\n                                 int                                 vocab_type,\n                                 std::optional<int>                  vocab_size,\n                                 std::optional<std::vector<int32_t>> stop_token_ids,\n                                 bool                                add_prefix_space)\n{\n    TM_CHECK(vocab_type == 0 || vocab_type == 1 || vocab_type == 2) << \"Invalid vocab type: \" << vocab_type;\n    return TokenizerInfo(\n        encoded_vocab, static_cast<VocabType>(vocab_type), vocab_size, stop_token_ids, add_prefix_space);\n}\n\nint TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer)\n{\n    return 
static_cast<int>(tokenizer.GetVocabType());\n}\n\nstd::vector<py::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer)\n{\n    const auto&            decoded_vocab = tokenizer.GetDecodedVocab();\n    std::vector<py::bytes> py_result;\n    py_result.reserve(decoded_vocab.size());\n    for (const auto& item : decoded_vocab) {\n        // Pass the length explicitly so entries containing NUL bytes are not truncated.\n        py_result.emplace_back(py::bytes(item.data(), item.size()));\n    }\n    return py_result;\n}\n\n}  // namespace\n\nPYBIND11_MODULE(_xgrammar, m)\n{\n    py::class_<TokenizerInfo, std::shared_ptr<TokenizerInfo>>(m, \"TokenizerInfo\")\n        .def(py::init([](const py::typing::List<std::variant<std::string, py::bytes>>& encoded_vocab,\n                         int                                                           vocab_type,\n                         std::optional<int>                                            vocab_size,\n                         std::optional<std::vector<int32_t>>                           stop_token_ids,\n                         bool                                                          add_prefix_space) {\n                 return TokenizerInfo{TokenizerInfo_Init(CommonEncodedVocabType(encoded_vocab),\n                                                         vocab_type,\n                                                         vocab_size,\n                                                         std::move(stop_token_ids),\n                                                         add_prefix_space)};\n             }),\n             py::arg(\"encoded_vocab\"),\n             py::arg(\"vocab_type\"),\n             py::arg(\"vocab_size\")     = py::none(),\n             py::arg(\"stop_token_ids\") = py::none(),\n             py::arg(\"add_prefix_space\"))\n\n        .def_property_readonly(\"vocab_type\", &TokenizerInfo_GetVocabType)\n        .def_property_readonly(\"vocab_size\", &TokenizerInfo::GetVocabSize)\n        .def_property_readonly(\"add_prefix_space\", &TokenizerInfo::GetAddPrefixSpace)\n        
.def_property_readonly(\"decoded_vocab\", &TokenizerInfo_GetDecodedVocab)\n        .def_property_readonly(\"stop_token_ids\", &TokenizerInfo::GetStopTokenIds)\n        .def_property_readonly(\"special_token_ids\", &TokenizerInfo::GetSpecialTokenIds)\n\n        .def(\"dump_metadata\", &TokenizerInfo::DumpMetadata)\n\n        .def_static(\"from_vocab_and_metadata\",\n                    [](const py::typing::List<std::variant<std::string, py::bytes>>& encoded_vocab,\n                       const std::string&                                            metadata) {\n                        return TokenizerInfo::FromVocabAndMetadata(CommonEncodedVocabType(encoded_vocab), metadata);\n                    })\n\n        .def_static(\"_detect_metadata_from_hf\", &TokenizerInfo::DetectMetadataFromHF);\n\n    py::class_<CompiledGrammar>(m, \"CompiledGrammar\");\n\n    py::class_<GrammarCompiler> pyGrammarCompiler(m, \"GrammarCompiler\");\n    pyGrammarCompiler\n        .def(py::init<const TokenizerInfo&, int, bool, int64_t>(),\n             py::arg(\"tokenizer_info\"),\n             py::arg(\"max_threads\")      = 8,\n             py::arg(\"cache_enabled\")    = true,\n             py::arg(\"max_memory_bytes\") = -1)\n        .def(\"compile_json_schema\",\n             &GrammarCompiler::CompileJSONSchema,\n             py::call_guard<py::gil_scoped_release>(),\n             py::arg(\"schema\"),\n             py::arg(\"any_whitespace\")     = false,\n             py::arg(\"indent\")             = py::none(),\n             py::arg(\"separators\")         = py::none(),\n             py::arg(\"strict_mode\")        = true,\n             py::arg(\"max_whitespace_cnt\") = py::none())\n        .def(\"compile_regex\",\n             &GrammarCompiler::CompileRegex,\n             py::call_guard<py::gil_scoped_release>(),\n             py::arg(\"schema\"));\n}\n"
  },
  {
    "path": "src/turbomind/turbomind.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <filesystem>\n#include <future>\n#include <random>\n\n#include \"src/turbomind/turbomind.h\"\n\n#include \"src/turbomind/comm/host_comm.h\"\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/core/context.h\"\n#include \"src/turbomind/core/core.h\"\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/engine/engine.h\"\n#include \"src/turbomind/engine/gateway.h\"\n#include \"src/turbomind/engine/model_executor.h\"\n#include \"src/turbomind/engine/model_request.h\"\n\n#include \"src/turbomind/models/language_model.h\"\n#include \"src/turbomind/models/llama/LlamaWeight.h\"\n#include \"src/turbomind/models/llama/context.h\"\n#include \"src/turbomind/models/llama/llama_params.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n\n#include \"src/turbomind/kernels/gemm/tuner/params.h\"\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/metrics.h\"\n\n#include <yaml-cpp/yaml.h>\n\n// #include \"dbg.h\"\n\nnamespace turbomind {\n\nusing std::vector;\nusing std::string;\nusing std::shared_ptr;\nusing std::unique_ptr;\n\nstatic std::optional<MoeParam::Method> get_moe_method()\n{\n    static const auto value = []() -> std::optional<MoeParam::Method> {\n        const auto p = std::getenv(\"TM_MOE_METHOD\");\n        if (p) {\n            std::string str(p);\n            for (auto& x : str) {\n                x = std::tolower(x);\n            }\n            if (str == \"naive\") {\n                return MoeParam::kNaive;\n            }\n            else if (str == \"fused\") {\n                return MoeParam::kFused;\n            }\n            else {\n                std::cerr << \"[WARNING] unrecognised MoE method: \" << str << \"\\n\";\n            }\n        }\n        return {};\n    }();\n    return value;\n}\n\n/// TODO: move config parsing to suitable place\nstatic void parse_default_rope_param(const YAML::Node& node, 
RopeParam& param)\n{\n    param.base = node[\"base\"].as<float>();\n    param.dim  = node[\"dim\"].as<int>();\n    if (param.base == 0.f || param.dim == 0) {\n        TM_LOG_ERROR(\"invalid rope param: base = %f, dim = %d\", param.base, param.dim);\n        FT_CHECK(0);\n    }\n}\n\nstatic void parse_linear_rope_param(const YAML::Node& node, RopeParam& param)\n{\n    parse_default_rope_param(node, param);\n    param.factor = node[\"factor\"].as<float>();\n}\n\nstatic void parse_dynamic_rope_param(const YAML::Node& node, RopeParam& param)\n{\n    parse_linear_rope_param(node, param);\n    param.max_position_embeddings = node[\"max_position_embeddings\"].as<int>();\n}\n\nstatic void parse_yarn_rope_param(const YAML::Node& node, RopeParam& param)\n{\n    parse_dynamic_rope_param(node, param);\n    param.yarn.attention_factor = node[\"attention_factor\"].as<float>();\n    param.yarn.beta_fast        = node[\"beta_fast\"].as<float>();\n    param.yarn.beta_slow        = node[\"beta_slow\"].as<float>();\n}\n\nstatic void parse_llama3_rope_param(const YAML::Node& node, RopeParam& param)\n{\n    parse_linear_rope_param(node, param);\n    param.llama3.low_freq_factor                  = node[\"low_freq_factor\"].as<float>();\n    param.llama3.high_freq_factor                 = node[\"high_freq_factor\"].as<float>();\n    param.llama3.original_max_position_embeddings = node[\"original_max_position_embeddings\"].as<int>();\n}\n\nstatic void parse_mrope_rope_param(const YAML::Node& node, RopeParam& param)\n{\n    parse_default_rope_param(node, param);\n    auto mrope_section = node[\"mrope_section\"].as<std::vector<int>>();\n    FT_CHECK(mrope_section.size() == 3);\n    param.mrope.section = {mrope_section[0], mrope_section[1], mrope_section[2]};\n}\n\nstatic void parse_rope_param(const YAML::Node& node, RopeParam& rope)\n{\n    rope.type = GetRoPEType(node[\"type\"].as<std::string>());\n\n    switch (rope.type) {\n        case RopeType::kDefault:\n            
parse_default_rope_param(node, rope);\n            break;\n        case RopeType::kLinear:\n            parse_linear_rope_param(node, rope);\n            break;\n        case RopeType::kDynamic:\n            parse_dynamic_rope_param(node, rope);\n            break;\n        case RopeType::kYarn:\n            parse_yarn_rope_param(node, rope);\n            break;\n        case RopeType::kLlama3:\n            parse_llama3_rope_param(node, rope);\n            break;\n        case RopeType::kMrope:\n            parse_mrope_rope_param(node, rope);\n            break;\n        default:\n            FT_CHECK(0);\n            break;\n    }\n}\n\nstatic DataType data_type_from_string(std::string str)\n{\n    if (str == \"fp16\" || str == \"float16\") {\n        return kFloat16;\n    }\n    else if (str == \"bf16\" || str == \"bfloat16\") {\n        return kBfloat16;\n    }\n    else if (str == \"fp32\") {\n        return kFloat32;\n    }\n    else if (str == \"int8\") {\n        return kUint8;\n    }\n    else if (str == \"int4\") {\n        return kUint4;\n    }\n    else if (str == \"fp8\") {\n        return kFloat8_e4m3;\n    }\n    else if (str == \"e2m1\") {\n        return kFloat4_e2m1;\n    }\n    TM_CHECK(0) << \"unsupported weight type: \" << str;\n    return {};\n}\n\nstruct TurboMind::Impl {\n    DataType       data_type_;\n    ModelParam     model_param_;\n    AttentionParam attn_param_;\n    MoeParam       moe_param_;\n    EngineParam    engine_param_;\n    size_t         comm_size_;\n\n    vector<EngineParam> engine_params_;\n\n    string communicator_type_;  // communicator backend\n\n    unique_ptr<comm::HostGroupId> group_id_;\n\n    shared_ptr<Gateway> gateway_;\n\n    FFICtxFactory ffi_ctx_factory_;\n\n    vector<int> global_rank_;\n\n    // Weights & engine instances for the ranks\n    vector<shared_ptr<LlamaWeight>> weights_;\n    vector<shared_ptr<Context>>     contexts_;\n    vector<Engine>                  engines_;\n\n    string model_name_;\n    
string model_dir_;\n\n    vector<int> queue_id_;\n    int         n_queues_{0};\n\n    int need_warm_up_{1};\n    int phases_{1};\n\n    ~Impl();\n\n    Impl(string model_dir, string config, FFICtxFactory ffi_ctx_factory);\n\n    unique_ptr<ModelRequest> CreateRequest()\n    {\n        return std::make_unique<ModelRequest>(gateway_.get(),  //\n                                              data_type_,\n                                              engine_param_.session_len,\n                                              model_param_.vocab_size,\n                                              model_param_.hidden_units);\n    }\n\n    void CreateWeights(int index)\n    {\n        CudaDeviceGuard dev_guard(engine_param_.devices[index]);\n\n        CreateContext(index);\n\n        weights_[index] = std::make_shared<LlamaWeight>(data_type_,  //\n                                                        model_param_,\n                                                        engine_params_.at(index),\n                                                        moe_param_);\n    }\n\n    TensorMap GetWeights(int index)\n    {\n        const auto& tensor_ptr_map = TM_CHECK_NOTNULL(weights_[index])->get_parameters();\n        TensorMap   params;\n        for (const auto& [name, tensor_ptr] : tensor_ptr_map) {\n            params[name] = *tensor_ptr;\n        }\n        return params;\n    }\n\n    void ProcessWeights(int index)\n    {\n        CudaDeviceGuard dev_guard(engine_param_.devices[index]);\n        FT_CHECK(weights_[index] != nullptr);\n\n        cudaDeviceProp props{};\n        check_cuda_error(cudaGetDeviceProperties(&props, engine_param_.devices[index]));\n\n        weights_[index]->prepare(props);\n        sync_check_cuda_error();\n    }\n\n    void CreateEngine(int index);\n\n    void CreateContext(int index);\n\n    void WarmUp(int index);\n\n    void Sleep(int index, int level)\n    {\n        CudaDeviceGuard dev_guard(engine_param_.devices[index]);\n\n        if 
(level == 2) {\n            // free weights\n            weights_[index]->release();\n        }\n        else {\n            // offload weights to CPU\n            TM_CHECK(moe_param_.experts_per_token == 0) << \"level 1 sleep not supported for MoE model\";\n            weights_[index]->to_device(kCPU);\n        }\n\n        // free model (kv cache and buffer)\n        if (index == 0) {\n            gateway_->shutdown();\n            gateway_.reset();\n        }\n\n        engines_[index] = {};\n        contexts_[index]->allocator->trim(0);\n\n        trim_default_mempool(engine_param_.devices[index]);\n    }\n\n    void WakeUp(int index, const std::vector<std::string>& tags)\n    {\n        CudaDeviceGuard dev_guard(engine_param_.devices[index]);\n\n        std::set<std::string> keys(tags.begin(), tags.end());\n\n        auto& ctx = *TM_CHECK_NOTNULL(contexts_[index]);\n\n        if (keys.find(\"weights\") != keys.end()) {\n            TM_CHECK(weights_[index] != nullptr);\n            if (weights_[index]->is_initialized()) {\n                weights_[index]->to_device(kDEVICE);\n            }\n            else {\n                weights_[index]->initialize();\n            }\n        }\n\n        if (keys.find(\"kv_cache\") != keys.end()) {\n            if (index == 0) {\n                gateway_ = std::make_shared<Gateway>(n_queues_, ffi_ctx_factory_);\n            }\n            CreateEngine(index);\n        }\n    }\n\n    void HandleMissingParams()\n    {\n        if (!engine_param_.max_context_token_num) {\n            engine_param_.max_context_token_num = engine_param_.session_len;\n            TM_LOG_WARNING(\"[TM] `max_context_token_num` is not set, default to %d.\",\n                           (int)engine_param_.max_context_token_num);\n        }\n\n        if (engine_param_.max_context_token_num <= engine_param_.max_batch_size) {\n            engine_param_.max_context_token_num *= engine_param_.session_len;\n            TM_LOG_WARNING(\"[TM] 
`max_context_token_num` = %d.\", (int)engine_param_.max_context_token_num);\n        }\n    }\n};\n\nTurboMind::Impl::~Impl()\n{\n    TM_LOG_INFO(__PRETTY_FUNCTION__);\n    if (gateway_) {\n        gateway_->shutdown();\n    }\n    for (int i = 0; i < (int)engines_.size(); ++i) {\n        /// TODO: make device part of core::Context\n        CudaDeviceGuard device(engine_param_.devices[i]);\n        {\n            core::ContextGuard context{contexts_[i]->core_stream};\n            engines_[i]  = {};\n            contexts_[i] = {};\n        }\n        weights_[i] = {};\n    }\n}\n\nTurboMind::Impl::Impl(string model_dir, string config, FFICtxFactory ffi_ctx_factory):\n    data_type_{}, model_param_{}, attn_param_{}, moe_param_{}, engine_param_{}, ffi_ctx_factory_{ffi_ctx_factory}\n{\n    TM_CHECK(!config.empty());\n\n    YAML::Node node;\n    try {\n        node = YAML::Load(config);\n    }\n    catch (const YAML::Exception& e) {\n        TM_CHECK(0) << \"Error loading YAML config: \" << e.what() << \"\\nconfig:\\n\" << config;\n    }\n\n    /// TODO: move config parsing to suitable place\n    const auto model     = node[\"model_config\"];\n    const auto attention = node[\"attention_config\"];\n    const auto engine    = node[\"engine_config\"];\n\n    data_type_ = model_param_.data_type = data_type_from_string(model[\"data_type\"].as<std::string>());\n    TM_CHECK(data_type_ == kBfloat16 || data_type_ == kHalf);\n\n    model_name_                     = model[\"model_name\"].as<std::string>();\n    model_param_.head_num           = model[\"head_num\"].as<int>();\n    model_param_.head_dim           = model[\"size_per_head\"].as<int>();\n    model_param_.kv_head_num        = model[\"kv_head_num\"].as<int>(0);\n    model_param_.hidden_units       = model[\"hidden_units\"].as<int>();\n    model_param_.layer_num          = model[\"num_layer\"].as<int>();\n    model_param_.vocab_size         = model[\"vocab_size\"].as<int>();\n    model_param_.embedding_size     = 
model[\"embedding_size\"].as<int>();\n    model_param_.norm_eps           = model[\"norm_eps\"].as<float>();\n    model_param_.tune_layer_num     = model[\"tune_layer_num\"].as<int>(1);\n    model_param_.mla.q_lora_rank    = model[\"q_lora_rank\"].as<int>();\n    model_param_.mla.kv_lora_rank   = model[\"kv_lora_rank\"].as<int>();\n    model_param_.mla.qk_rope_dim    = model[\"qk_rope_dim\"].as<int>();\n    model_param_.mla.v_head_dim     = model[\"v_head_dim\"].as<int>();\n    attn_param_.cache_block_seq_len = attention[\"cache_block_seq_len\"].as<int>(0);\n    model_param_.quant_policy       = engine[\"quant_policy\"].as<int>(0);\n\n    auto inter_size = model[\"inter_size\"];\n    for (auto it = inter_size.begin(); it != inter_size.end(); ++it) {\n        model_param_.inter_size.push_back(it->as<int>());\n    }\n\n    if (auto layer_types = model[\"layer_types\"]) {\n        for (auto it = layer_types.begin(); it != layer_types.end(); ++it) {\n            auto type_str = it->as<std::string>(\"\");\n            if (type_str == \"linear_attention\") {\n                model_param_.layer_types.push_back(1);\n            }\n            else if (type_str == \"full_attention\" || type_str.empty()) {\n                model_param_.layer_types.push_back(0);\n            }\n            else {\n                TM_LOG_WARNING(\"[TM] Unknown layer_type '%s', treating as full_attention.\", type_str.c_str());\n                model_param_.layer_types.push_back(0);\n            }\n        }\n    }\n\n    // Qwen3.5 Gated DeltaNet linear attention parameters\n    model_param_.linear_key_head_dim    = model[\"linear_key_head_dim\"].as<int>(0);\n    model_param_.linear_value_head_dim  = model[\"linear_value_head_dim\"].as<int>(0);\n    model_param_.linear_conv_kernel_dim = model[\"linear_conv_kernel_dim\"].as<int>(0);\n    model_param_.linear_num_key_heads   = model[\"linear_num_key_heads\"].as<int>(0);\n    model_param_.linear_num_value_heads = 
model[\"linear_num_value_heads\"].as<int>(0);\n    model_param_.attn_output_gate       = model[\"attn_output_gate\"].as<bool>(false);\n    model_param_.linear_state_dtype     = data_type_;\n\n    if (auto uqel = model[\"unquantized_expert_layers\"]) {\n        for (auto it = uqel.begin(); it != uqel.end(); ++it) {\n            model_param_.unquantized_expert_layers.insert(it->as<int>());\n        }\n    }\n    model_param_.attn_sink = model[\"attn_sink\"].as<bool>();\n    model_param_.mlp_bias  = model[\"mlp_bias\"].as<bool>();\n    if (model[\"activation_type\"].as<std::string>(\"\") == \"gpt-oss\") {\n        model_param_.act_type = ActivationType::kSiluGptOss;\n    }\n\n    auto window_size = model[\"window_size\"];\n    for (auto it = window_size.begin(); it != window_size.end(); ++it) {\n        model_param_.window_size.push_back(it->as<int>());\n    }\n\n    model_param_.attn_bias  = model[\"attn_bias\"].as<int>(0);\n    model_param_.qk_norm    = model[\"qk_norm\"].as<bool>();\n    model_param_.group_size = model[\"group_size\"].as<int>(0);\n\n    attn_param_.softmax_scale = attention[\"softmax_scale\"].as<float>(0);\n    // logn attn for qwen model\n    attn_param_.use_logn_attn           = attention[\"use_logn_attn\"].as<int>(0);\n    attn_param_.max_position_embeddings = attention[\"max_position_embeddings\"].as<int>(0);\n    // rotary embedding parameters\n    parse_rope_param(attention[\"rope_param\"], attn_param_.rope);\n\n    engine_param_.max_batch_size = engine[\"max_batch_size\"].as<int>(0);\n    auto max_forward_token_num   = engine[\"max_prefill_token_num\"].as<int>(0);\n    max_forward_token_num += engine_param_.max_batch_size;\n\n    engine_param_.max_context_token_num = engine[\"max_context_token_num\"].as<int>(0);\n    engine_param_.session_len           = model[\"session_len\"].as<int>(0);\n\n    engine_param_.cache_max_block_count = engine[\"cache_max_entry_count\"].as<float>(0);\n    engine_param_.cache_chunk_size      = 
engine[\"cache_chunk_size\"].as<int>(0);\n    engine_param_.enable_prefix_caching = engine[\"enable_prefix_caching\"].as<bool>(false);\n    engine_param_.enable_metrics        = engine[\"enable_metrics\"].as<bool>(false);\n\n    if (engine_param_.enable_prefix_caching && HasLinearAttention(model_param_)) {\n        TM_CHECK(0) << \"Prefix caching is unsupported when linear attention is present\";\n    }\n\n    engine_param_.num_tokens_per_iter = engine[\"num_tokens_per_iter\"].as<int>(0);\n    engine_param_.max_prefill_iters   = engine[\"max_prefill_iters\"].as<int>(1);\n\n    phases_ = engine[\"async_\"].as<int>() ? 2 : 1;\n\n    engine_param_.outer_dp_size = engine[\"outer_dp_size\"].as<int>();\n\n    engine_param_.attn_dp_size = engine[\"attn_dp_size\"].as<int>();\n    engine_param_.attn_tp_size = engine[\"attn_tp_size\"].as<int>();\n    engine_param_.attn_cp_size = engine[\"attn_cp_size\"].as<int>();\n\n    engine_param_.mlp_tp_size = engine[\"mlp_tp_size\"].as<int>();\n\n    engine_param_.devices = engine[\"devices\"].as<std::vector<int>>();\n\n    // multi-node information\n    engine_param_.nnodes    = engine[\"nnodes\"].as<int>();\n    engine_param_.node_rank = engine[\"node_rank\"].as<int>();\n\n    {\n        auto sp                             = engine_param_.attn_tp_size * engine_param_.attn_cp_size;\n        engine_param_.max_forward_token_num = ((size_t)max_forward_token_num + sp - 1) / sp * sp;\n    }\n\n    comm_size_ = engine_param_.attn_dp_size * engine_param_.attn_tp_size * engine_param_.attn_cp_size;\n    FT_CHECK(engine_param_.mlp_tp_size == comm_size_);\n\n    communicator_type_ = engine[\"communicator\"].as<std::string>();\n\n    moe_param_.experts_per_token = model[\"experts_per_token\"].as<int>(0);\n    moe_param_.inter_size        = model[\"expert_inter_size\"].as<int>(0);\n    moe_param_.shared_gate       = model[\"moe_shared_gate\"].as<bool>();\n    moe_param_.norm_topk_prob    = model[\"norm_topk_prob\"].as<bool>();\n    
moe_param_.routed_scale      = model[\"routed_scale\"].as<float>(1.f);\n    moe_param_.topk_group        = model[\"topk_group\"].as<int>(1);\n    moe_param_.topk_method       = model[\"topk_method\"].as<std::string>(\"greedy\");\n    moe_param_.n_group           = model[\"moe_group_num\"].as<int>(1);\n    moe_param_.scoring_func      = model[\"scoring_func\"].as<std::string>(\"softmax\");\n    moe_param_.router_n_groups   = model[\"router_n_groups\"].as<int>(-1);\n    moe_param_.router_bias       = model[\"expert_router_bias\"].as<bool>();\n    YAML::Node expert_num        = model[\"expert_num\"];\n    for (auto it = expert_num.begin(); it != expert_num.end(); ++it) {\n        moe_param_.expert_num.push_back(it->as<int>());\n    }\n\n    HandleMissingParams();\n\n    weights_.resize(engine_param_.devices.size());\n    engines_.resize(engine_param_.devices.size());\n    contexts_.resize(engine_param_.devices.size());\n\n    model_param_.weight_type        = data_type_from_string(model[\"weight_type\"].as<std::string>());\n    model_param_.expert_weight_type = data_type_from_string(model[\"expert_weight_type\"].as<std::string>());\n    model_param_.ffn_weight_type =\n        data_type_from_string(model[\"ffn_weight_type\"].as<std::string>(model[\"weight_type\"].as<std::string>()));\n\n    if (auto method = get_moe_method()) {\n        moe_param_.method = *method;\n    }\n    else {\n        moe_param_.method = MoeParam::kFused;\n    }\n\n    // NOTE: This runs on Python main thread\n    group_id_ = comm::CreateHostGroupId((engine_param_.nnodes == 1) ? 
\"\" : \"hybrid\");\n    group_id_->Initialize();\n\n    const int devices = engine_param_.devices.size();\n\n    for (int i = 0; i < devices; ++i) {\n        global_rank_.push_back(engine_param_.node_rank * devices + i);\n    }\n\n    queue_id_.resize(devices);\n    engine_params_.resize(devices, engine_param_);\n}\n\nvoid TurboMind::Impl::CreateContext(int index)\n{\n    auto& p = engine_params_[index];\n\n    CudaDeviceGuard dev_guard(p.devices[index]);\n\n    TM_CHECK(contexts_[index] == nullptr);\n\n    auto& ctx = contexts_[index] = std::make_shared<Context>(p.devices[index]);\n\n    // Layout: (outer, dp, tp, cp)\n\n    const int global_rank = global_rank_[index];\n\n    const int outer_rank = global_rank / comm_size_;\n    const int inner_rank = global_rank % comm_size_;\n\n    p.outer_dp_rank = outer_rank;\n\n    const int tp_cp_size = p.attn_tp_size * p.attn_cp_size;\n\n    const int tp_color = inner_rank / tp_cp_size;\n    const int dp_color = inner_rank % tp_cp_size;\n    const int cp_color = inner_rank / p.attn_cp_size;\n\n    auto& c = ctx->comm;\n\n    c.h_global = group_id_->CreateCommunicator(comm_size_, global_rank, p.node_rank);\n\n    c.h_comm = c.h_global->Split(outer_rank, 0);\n\n    c.h_tp_group = c.h_comm->Split(tp_color, 0);\n    c.h_dp_group = c.h_comm->Split(dp_color, 0);\n\n    if (comm_size_ > 1) {\n        c.d_comm = CreateDeviceCommunicator(communicator_type_, comm_size_, inner_rank, c.h_comm);\n\n        c.d_tp_group = 0;\n        c.d_cp_group = 0;\n\n        if (p.attn_dp_size > 1) {  // has attn_dp\n            c.d_tp_group   = c.d_comm->Split(tp_color, 0, 0);\n            p.attn_dp_rank = c.h_dp_group->rank();\n        }\n\n        if (p.attn_cp_size > 1) {  // has attn_cp\n            c.d_cp_group   = c.d_comm->Split(cp_color, 0, 0);\n            p.attn_cp_rank = c.d_comm->rank(c.d_cp_group);\n        }\n\n        p.attn_tp_rank = c.d_comm->rank(c.d_tp_group) / p.attn_cp_size;\n        p.mlp_tp_rank  = c.d_comm->rank(0);\n    
}\n\n    if (c.h_tp_group->rank() == 0) {\n        queue_id_[index] = 1;\n    }\n\n    c.h_global->Sync();\n\n    if (index == 0) {\n        n_queues_ = 0;\n        for (size_t i = 0; i < queue_id_.size(); ++i) {\n            queue_id_[i] = queue_id_[i] ? n_queues_++ : -1;\n        }\n        gateway_ = std::make_shared<Gateway>(n_queues_, ffi_ctx_factory_);\n    }\n\n    c.h_global->Sync();\n}\n\nvoid TurboMind::Impl::CreateEngine(int index)\n{\n    CudaDeviceGuard dev_guard(engine_param_.devices[index]);\n\n    auto& ctx = *TM_CHECK_NOTNULL(contexts_[index]);\n\n    core::ContextGuard guard{ctx.core_stream, ctx.allocator, Allocator{kCPUpinned}};\n\n    const auto& param = engine_params_.at(index);\n\n    ctx.comm.h_comm->Sync();\n\n    // create model\n    LanguageModel model{data_type_,  //\n                        model_param_,\n                        param,\n                        attn_param_,\n                        moe_param_,\n                        ctx,\n                        *weights_[index],\n                        phases_};\n\n    // create engine\n    engines_[index] = Engine{data_type_,  //\n                             param,\n                             std::move(model),\n                             ctx,\n                             *gateway_,\n                             engine_param_.devices[index],\n                             queue_id_[index],\n                             phases_};\n\n    core::Context::stream().Sync();\n\n    ctx.comm.h_comm->Sync();\n\n    engines_[index].Start();\n\n    if (need_warm_up_) {\n        WarmUp(index);\n    }\n}\n\ntemplate<class Iter>\nstatic std::string Join(Iter first, Iter last, const std::string& delim)\n{\n    if (first == last) {\n        return {};\n    }\n    std::ostringstream oss;\n    oss << *first++;\n    while (first != last) {\n        oss << delim << *first++;\n    }\n    return oss.str();\n}\n\nvoid TurboMind::Impl::WarmUp(int index)\n{\n    auto& ctx = 
*TM_CHECK_NOTNULL(contexts_[index]);\n\n    auto& global = ctx.comm.h_global;\n    auto& linear = *ctx.linear;\n\n    if (auto str = std::getenv(\"TM_GEMM_IMPORT\")) {\n        std::ifstream ifs(str);\n        const int     n_imported = linear.Import(ifs);\n        if (index == 0) {\n            TM_LOG_INFO(\"[GEMM] %d records imported\", n_imported);\n        }\n        return;\n    }\n\n    global->Sync();\n\n    *ctx.is_warm_up = 1;\n    linear.set_measure(true);\n\n    if (index == 0) {\n        gateway_->set_threshold(engine_param_.attn_dp_size);\n    }\n\n    global->Sync();\n\n    if (ctx.comm.h_tp_group->rank() == 0) {\n\n        std::vector<int> bss = linear.GetTuningSeq();\n        if (bss.empty()) {\n            bss = gemm::GenerateTuningSequence(gemm::GetDefaultTuningGenerators());\n        }\n\n        const int max_fwd_token_num = engine_param_.max_forward_token_num;\n\n        // remove bs that is too large\n        bss.erase(std::remove_if(bss.begin(), bss.end(), [&](auto x) { return x > max_fwd_token_num; }), bss.end());\n\n        if (bss.empty() || bss.back() < max_fwd_token_num) {\n            bss.push_back(max_fwd_token_num);\n        }\n\n        auto str = Join(bss.begin(), bss.end(), \", \");\n        TM_LOG_INFO(\"[Engine] Warm-up lengths: %s\", str.c_str());\n\n        if (!bss.empty()) {\n            const auto                         max_bs = *std::max_element(bss.begin(), bss.end());\n            Buffer_<int>                       input_ids(max_bs, kCPU);\n            std::mt19937                       g{};\n            std::uniform_int_distribution<int> d{0, (int)model_param_.vocab_size - 1};\n            for (auto& x : input_ids) {\n                x = d(g);\n            }\n\n            auto tick = std::chrono::steady_clock::now();\n\n            for (auto token_num : bss) {\n\n                TM_LOG_INFO(\"[WarmUp] %d\", token_num);\n\n                auto r = CreateRequest();\n\n                TensorMap inputs{{\"input_ids\", 
input_ids.slice(0, token_num)}};\n\n                ModelRequest::InputParam param{};\n                param.session.start_flag     = true;\n                param.session.end_flag       = true;\n                param.gen_cfg.max_new_tokens = 1;\n                param.tensors                = std::make_shared<TensorMap>(inputs);\n\n                struct Channel {\n                    int                flag = 1;\n                    std::promise<void> promise;\n                };\n                auto c = std::make_shared<Channel>();\n\n                ModelRequest::OutputParam out = r->Forward(std::move(param), [c] {\n                    /// NOTE: It's risky to set `out.state` here, `out` may not be initialized at this point\n                    if (std::exchange(c->flag, 0)) {\n                        c->promise.set_value();\n                    }\n                });\n\n                c->promise.get_future().get();\n\n                int status = -1;\n                if (auto state = out.state->exchange(nullptr)) {\n                    status = state->status;\n                }\n\n                if (status != Request::kFinish) {\n                    TM_LOG_ERROR(\"[Engine] Warm-up for %d tokens failed with status %d\", (int)token_num, (int)status);\n                }\n            }\n\n            auto tock = std::chrono::steady_clock::now();\n\n            TM_LOG_INFO(\"[WarmUp] Warm-up finished in %.2f seconds.\",\n                        std::chrono::duration<float, std::ratio<1, 1>>(tock - tick).count());\n        }\n    }\n\n    global->Sync();\n\n    linear.set_measure(false);\n    *ctx.is_warm_up = 0;\n\n    if (index == 0) {\n        if (auto path = std::getenv(\"TM_GEMM_EXPORT\")) {\n            std::ofstream ofs(path);\n            const auto    n_records = linear.Export(ofs);\n            TM_LOG_INFO(\"[GEMM] %d records exported.\", n_records);\n        }\n\n        gateway_->set_threshold(1);\n        need_warm_up_ = 0;\n    }\n\n    
global->Sync();\n}\n\nTurboMind::~TurboMind() = default;\n\nTurboMind::TurboMind(string model_dir, string config, FFICtxFactory ffi_ctx_factory):\n    impl_{std::make_unique<Impl>(model_dir, config, ffi_ctx_factory)}\n{\n}\n\nvoid TurboMind::CreateWeights(int index)\n{\n    return impl_->CreateWeights(index);\n}\n\nTensorMap TurboMind::GetWeights(int index)\n{\n    return impl_->GetWeights(index);\n}\n\nvoid TurboMind::ProcessWeights(int index)\n{\n    return impl_->ProcessWeights(index);\n}\n\nvoid TurboMind::CreateEngine(int index)\n{\n    return impl_->CreateEngine(index);\n}\n\nvoid TurboMind::Sleep(int index, int level)\n{\n    return impl_->Sleep(index, level);\n}\n\nvoid TurboMind::WakeUp(int index, const vector<string>& tags)\n{\n    return impl_->WakeUp(index, tags);\n}\n\nshared_ptr<ScheduleMetrics> TurboMind::GetScheduleMetrics(int index)\n{\n    return impl_->engines_[index].GetScheduleMetrics();\n}\n\nunique_ptr<ModelRequest> TurboMind::CreateRequest()\n{\n    return impl_->CreateRequest();\n}\n\nbool TurboMind::is_dummy_node() const noexcept\n{\n    return impl_->n_queues_ == 0;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/turbomind.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <functional>\n#include <memory>\n#include <string>\n\n#include \"src/turbomind/core/core.h\"\n#include \"src/turbomind/engine/model_request.h\"\n#include \"src/turbomind/utils/metrics.h\"\n\nnamespace turbomind {\n\nclass TurboMind {\npublic:\n    using FFICtxFactory = std::function<std::shared_ptr<void>()>;\n\n    ~TurboMind();\n\n    TurboMind(std::string model_dir, std::string config, FFICtxFactory ffi_ctx_factory);\n\n    void CreateWeights(int index);\n\n    TensorMap GetWeights(int index);\n\n    void ProcessWeights(int index);\n\n    void CreateEngine(int index);\n\n    void Sleep(int index, int level);\n\n    void WakeUp(int index, const std::vector<std::string>& tags);\n\n    bool is_dummy_node() const noexcept;\n\n    std::shared_ptr<ScheduleMetrics> GetScheduleMetrics(int index);\n\n    std::unique_ptr<ModelRequest> CreateRequest();\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/CMakeLists.txt",
    "content": "# Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ncmake_minimum_required(VERSION 3.11)\n\nfind_package(CUDAToolkit REQUIRED)\n\nadd_library(logger STATIC logger.cc)\nset_property(TARGET logger PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET logger PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\ntarget_link_libraries(logger PUBLIC CUDA::cudart)\n\n\nadd_library(cuda_utils STATIC cuda_utils.cc)\nset_property(TARGET cuda_utils PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET cuda_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\ntarget_link_libraries(cuda_utils PUBLIC logger CUDA::cudart CUDA::cuda_driver)\n\n\nadd_library(nvtx_utils STATIC nvtx_utils.cc)\nset_property(TARGET nvtx_utils PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET nvtx_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\nif(${CMAKE_VERSION} VERSION_LESS \"3.25\")\n    target_link_libraries(nvtx_utils PUBLIC CUDA::nvToolsExt -ldl)\nelse()\n    target_link_libraries(nvtx_utils PUBLIC CUDA::nvtx3 -ldl)\nendif()\n\nadd_library(memory_utils STATIC memory_utils.cu)\nset_property(TARGET memory_utils PROPERTY POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET memory_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\ntarget_link_libraries(memory_utils PUBLIC cuda_utils logger)\n\nadd_library(anomaly_handler STATIC anomaly_handler.cu)\nset_property(TARGET anomaly_handler PROPERTY 
POSITION_INDEPENDENT_CODE  ON)\nset_property(TARGET anomaly_handler PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)\ntarget_link_libraries(anomaly_handler PUBLIC cuda_utils logger)\n\nadd_library(parser STATIC parser.cc)\nset_property(TARGET parser PROPERTY POSITION_INDEPENDENT_CODE  ON)\n"
  },
  {
    "path": "src/turbomind/utils/anomaly_handler.cu",
    "content": "\n\n#include <cmath>\n#include <cub/block/block_reduce.cuh>\n#include <optional>\n#include <string>\n#include <thrust/device_vector.h>\n#include <thrust/host_vector.h>\n\n#include \"src/turbomind/core/data_type.h\"\n#include \"src/turbomind/models/llama/llama_utils.h\"\n#include \"src/turbomind/utils/anomaly_handler.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\nnamespace turbomind {\n\nstatic std::optional<float> parse_float(const std::string& s, const std::string& key)\n{\n    if (auto pos = s.find(key); pos != std::string::npos) {\n        float value{};\n        if (sscanf(s.c_str() + pos + key.size(), \"%f\", &value) != EOF) {\n            return value;\n        }\n    }\n    return {};\n}\n\ntemplate<class T, int BLOCK_SIZE>\n__global__ void CountAndFixAnormaly(\n    T* data, int64_t size, unsigned long long* n_inf, unsigned long long* n_nan, T pinf_val, T ninf_val, T nan_val)\n{\n    int inf_count{};\n    int nan_count{};\n\n    for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x) {\n        auto x = static_cast<float>(data[i]);\n        if (isinf(x)) {\n            ++inf_count;\n            data[i] = x > 0.f ? 
pinf_val : ninf_val;\n        }\n        else if (isnan(x)) {\n            ++nan_count;\n            data[i] = nan_val;\n        }\n    }\n\n    typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;\n\n    __shared__ typename BlockReduce::TempStorage temp_storage;\n\n    if (n_inf) {\n        inf_count = BlockReduce(temp_storage).Sum(inf_count);\n        if (threadIdx.x == 0) {\n            atomicAdd(n_inf, inf_count);\n        }\n    }\n\n    // Wait for last use of `temp_storage`\n    __syncthreads();\n\n    if (n_nan) {\n        nan_count = BlockReduce(temp_storage).Sum(nan_count);\n        if (threadIdx.x == 0) {\n            atomicAdd(n_nan, nan_count);\n        }\n    }\n}\n\ntemplate<class T, int BLOCK_SIZE>\n__global__ void FixLogitsAnomaly(T*   logits,  //\n                                 int* is_anomaly,\n                                 int  vocab_size,\n                                 int  batch_size,\n                                 int  fallback)\n{\n    const int bi = blockIdx.x;\n\n    T* ptr = logits + vocab_size * bi;\n\n    int count = 0;\n\n    // Accumulate per thread anomaly count\n    for (int i = threadIdx.x; i < vocab_size; i += BLOCK_SIZE) {\n        const float val = static_cast<float>(ptr[i]);\n        count += static_cast<int>(isnan(val) || isinf(val));\n    }\n\n    // If anything goes wrong\n    int error = __syncthreads_or(count);\n\n    if (!error) {\n        return;\n    }\n\n    // Clear all logits\n    for (int i = threadIdx.x; i < vocab_size; i += BLOCK_SIZE) {\n        ptr[i] = T(0.f);\n    }\n\n    // Set the fallback token\n    if (fallback % BLOCK_SIZE == threadIdx.x) {\n        // Ideally we want INF here, but it leads to `INF - INF -> NaN` in the sampling kernels\n        // Setting other logits to -INF has similar problem when banning bad words (same -INF)\n        ptr[fallback] = T(65504.f);  // Maximum finite value of half\n    }\n\n    if (threadIdx.x == 0 && is_anomaly) {\n        is_anomaly[bi] = 1;\n    
}\n}\n\nstruct AnomalyHandler::Impl {\n\n    Impl()\n    {\n        GlobalInit();\n\n        if (g_level) {\n            d_count_.resize(max_entries * 2);\n            h_count_.resize(d_count_.size());\n        }\n    }\n\n    // Process level initialization from environment variable\n    static void GlobalInit()\n    {\n        [[maybe_unused]] static const auto _ = []() -> bool {\n            const auto var = std::getenv(\"TM_ANOMALY_HANDLER\");\n            if (!var) {\n                return false;\n            }\n            const std::string str{var};\n\n            const auto level = parse_float(str, \"level=\");\n            if (level) {\n                g_level = static_cast<int>(*level);\n            }\n\n            TM_LOG_WARNING(\"[AnomalyHandler] level: %d\", g_level);\n\n            if (!g_level) {\n                return {};\n            }\n\n            const auto pos_inf = parse_float(str, \"pinf=\");\n            if (pos_inf) {\n                g_pinf_val_ = *pos_inf;\n                TM_LOG_WARNING(\"[AnomalyHandler] +INF -> %f\", g_pinf_val_);\n            }\n\n            const auto neg_inf = parse_float(str, \"ninf=\");\n            if (neg_inf) {\n                g_ninf_val_ = *neg_inf;\n                TM_LOG_WARNING(\"[AnomalyHandler] -INF -> %f\", g_ninf_val_);\n            }\n\n            if (!pos_inf && !neg_inf) {\n                if (const auto flush_inf = parse_float(str, \"inf=\")) {\n                    g_pinf_val_ = *flush_inf;\n                    g_ninf_val_ = -g_pinf_val_;\n                    TM_LOG_WARNING(\"[AnomalyHandler] +INF -> %f\", g_pinf_val_);\n                    TM_LOG_WARNING(\"[AnomalyHandler] -INF -> %f\", g_ninf_val_);\n                }\n            }\n\n            if (const auto nan = parse_float(str, \"nan=\")) {\n                g_nan_val_ = *nan;\n                TM_LOG_WARNING(\"[AnomalyHandler] NaN -> %f\", g_nan_val_);\n            }\n\n            const auto fallback = parse_float(str, 
\"fallback=\");\n            if (fallback) {\n                g_fallback = *fallback;\n                TM_LOG_WARNING(\"[AnomalyHandler] fallback -> %d\", g_fallback);\n            }\n\n            return {};\n        }();\n    }\n\n    void Init(int rank, int vocab_size, int fallback, int max_batch_size, cudaStream_t stream)\n    {\n        if (g_level) {\n            rank_       = rank;\n            stream_     = stream;\n            vocab_size_ = vocab_size;\n\n            max_batch_size_ = max_batch_size;\n\n            d_is_anomaly_.resize(max_batch_size);\n            h_is_anomaly_.resize(max_batch_size);\n\n            fallback_ = g_fallback;\n\n            // When fallback is not set from env\n            if (fallback_ == -1) {\n                fallback_ = fallback;\n                TM_LOG_WARNING(\"[AnomalyHandler] fallback: %d\", fallback_);\n            }\n\n            FT_CHECK(0 <= fallback_);\n            FT_CHECK(fallback_ < vocab_size);\n\n            TM_LOG_WARNING(\"[AnomalyHandler] max_batch_size: %d\", max_batch_size);\n            TM_LOG_WARNING(\"[AnomalyHandler] vocab_size: %d\", vocab_size);\n        }\n    }\n\n    void Summarize(std::function<void(const int*, int)> handler)\n    {\n        if (g_level) {\n            check_cuda_error(cudaMemcpyAsync(h_count_.data(),\n                                             d_count_.data().get(),\n                                             sizeof(size_type) * info_.size() * 2,\n                                             cudaMemcpyDefault,\n                                             stream_));\n\n            check_cuda_error(cudaMemcpyAsync(h_is_anomaly_.data(),\n                                             d_is_anomaly_.data().get(),\n                                             sizeof(int) * batch_size_,\n                                             cudaMemcpyDefault,\n                                             stream_));\n\n            check_cuda_error(cudaStreamSynchronize(stream_));\n\n#if 
0\n            int die = 0;\n            for (size_t i = 0; i < info_.size(); ++i) {\n                const auto& n_inf = h_count_[i * 2];\n                const auto& n_nan = h_count_[i * 2 + 1];\n                if (n_inf || n_nan) {\n                    TM_LOG_WARNING(\"[AnomalyHandler][rank=%d] (%s) INF: %s, NaN: %s\",\n                                   rank_,\n                                   info_[i].c_str(),\n                                   std::to_string(n_inf).c_str(),\n                                   std::to_string(n_nan).c_str());\n                    ++die;\n                }\n            }\n            TM_CHECK_EQ(die, 0);\n#endif\n\n            handler(h_is_anomaly_.data(), batch_size_);\n        }\n    }\n\n    void Reset()\n    {\n        if (g_level) {\n            if (!info_.empty()) {\n                std::fill_n(h_count_.data(), info_.size() * 2, 0);\n                check_cuda_error(\n                    cudaMemsetAsync(d_count_.data().get(), 0, sizeof(size_type) * info_.size() * 2, stream_));\n                info_.clear();\n            }\n\n            if (batch_size_) {\n                std::fill_n(h_is_anomaly_.data(), batch_size_, 0);\n                check_cuda_error(cudaMemsetAsync(d_is_anomaly_.data().get(), 0, sizeof(int) * batch_size_, stream_));\n                batch_size_ = 0;\n            }\n        }\n    }\n\n    template<class T>\n    void invokeCountAndFixAnomaly(T* data, int64_t size, const std::string& key, int level)\n    {\n        if (g_level && level <= g_level) {\n            FT_CHECK(size >= 0);\n\n            constexpr int block = 512;\n            const int     grid  = (size + block - 1) / block;\n\n            auto idx = info_.size();\n            auto ptr = d_count_.data().get() + idx * 2;\n\n            info_.push_back(key);\n\n            FT_CHECK(info_.size() <= max_entries);\n\n            CountAndFixAnormaly<T, block><<<grid, block, 0, stream_>>>(data,  //\n                                             
                          size,\n                                                                       ptr,\n                                                                       ptr + 1,\n                                                                       g_pinf_val_,\n                                                                       g_ninf_val_,\n                                                                       g_nan_val_);\n\n            sync_check_cuda_error();\n        }\n    }\n\n    template<class T>\n    void invokeFixLogitsAnomaly(T* logits, int batch_size, int level)\n    {\n        if (g_level && level <= g_level) {\n            FT_CHECK(batch_size <= max_batch_size_);\n\n            batch_size_ = batch_size;\n\n            constexpr int block = 256;\n\n            FixLogitsAnomaly<T, block><<<batch_size, block, 0, stream_>>>(logits,  //\n                                                                          d_is_anomaly_.data().get(),\n                                                                          vocab_size_,\n                                                                          batch_size,\n                                                                          fallback_);\n\n            sync_check_cuda_error();\n        }\n    }\n\n    static int   g_level;\n    static int   g_fallback;\n    static float g_pinf_val_;\n    static float g_ninf_val_;\n    static float g_nan_val_;\n\n    cudaStream_t stream_{};\n    int          rank_{};\n    int          vocab_size_{};\n    int          fallback_{};\n    int          max_batch_size_{};\n\n    ////////////////////////////////////////////////////////////////////////////////\n    /// Members below are valid for a SINGLE iteration only and must be cleared in `Reset`\n\n    // Data for tracing anomalies\n    thrust::device_vector<size_type> d_count_;\n    thrust::host_vector<size_type>   h_count_;\n    std::vector<std::string>         info_;\n\n    // Data for fixing logits\n    
thrust::device_vector<int> d_is_anomaly_;\n    thrust::host_vector<int>   h_is_anomaly_;\n    int                        batch_size_{};\n};\n\nint   AnomalyHandler::Impl::g_level     = 0;\nint   AnomalyHandler::Impl::g_fallback  = -1;\nfloat AnomalyHandler::Impl::g_pinf_val_ = INFINITY;\nfloat AnomalyHandler::Impl::g_ninf_val_ = -INFINITY;\nfloat AnomalyHandler::Impl::g_nan_val_  = NAN;\n\nAnomalyHandler::AnomalyHandler(): impl_{new Impl{}} {}\n\nAnomalyHandler::~AnomalyHandler() = default;\n\nAnomalyHandler& AnomalyHandler::instance()\n{\n    thread_local AnomalyHandler inst{};\n    return inst;\n}\n\nvoid AnomalyHandler::Init(int rank, int vocab_size, int fallback, int max_batch_size, cudaStream_t stream) noexcept\n{\n    impl_->Init(rank, vocab_size, fallback, max_batch_size, stream);\n}\n\nvoid AnomalyHandler::Summarize(std::function<void(const int*, int)> handler)\n{\n    impl_->Summarize(handler);\n}\n\nvoid AnomalyHandler::Reset()\n{\n    impl_->Reset();\n}\n\ntemplate<class T>\nvoid AnomalyHandler::CountAndFix(T* data, int64_t size, std::string key, int level)\n{\n    return impl_->invokeCountAndFixAnomaly(data, size, key, level);\n}\n\ntemplate void AnomalyHandler::CountAndFix(float*, int64_t, std::string, int);\ntemplate void AnomalyHandler::CountAndFix(half*, int64_t, std::string, int);\n#ifdef ENABLE_BF16\ntemplate void AnomalyHandler::CountAndFix(__nv_bfloat16*, int64_t, std::string, int);\n#endif\n\ntemplate<class T>\nvoid AnomalyHandler::FixLogits(T* logits, int batch_size, int level)\n{\n    impl_->invokeFixLogitsAnomaly(logits, batch_size, level);\n}\n\nint AnomalyHandler::level() noexcept\n{\n    return Impl::g_level;\n}\n\ntemplate void AnomalyHandler::FixLogits(float*, int, int);\ntemplate void AnomalyHandler::FixLogits(half*, int, int);\n#ifdef ENABLE_BF16\ntemplate void AnomalyHandler::FixLogits(__nv_bfloat16*, int, int);\n#endif\n\nvoid DebugTensor(Tensor& tensor, const std::string& key, int level)\n{\n    auto invoke = [&](auto t) {\n        
using T = decltype(t);\n        AnomalyHandler::instance().CountAndFix((T*)tensor.raw_data(), tensor.size(), key, level);\n        // Compare((T*)tensor.raw_data(), tensor.size(), key, kCmpRead, core::Context::stream().handle());\n    };\n    if (tensor.size() == 0) {\n        return;\n    }\n    TM_DISPATCH_DTYPES(tensor.dtype(), invoke, float, half_t, bfloat16_t);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/anomaly_handler.h",
    "content": "\n// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <cstdint>\n#include <cuda_bf16.h>\n#include <cuda_runtime.h>\n#include <functional>\n#include <memory>\n#include <string>\n\n#include \"src/turbomind/core/core.h\"\n\nnamespace turbomind {\n\nclass AnomalyHandler {\npublic:\n    static constexpr size_t max_entries = 65536;\n\n    using size_type = unsigned long long;\n\n    ~AnomalyHandler();\n\n    static AnomalyHandler& instance();\n\n    static int level() noexcept;\n\n    void Init(int rank, int vocab_size, int fallback, int max_batch_size, cudaStream_t stream) noexcept;\n\n    template<class T>\n    void CountAndFix(T* data, int64_t size, std::string key, int level);\n\n    template<class T>\n    void FixLogits(T* logits, int batch_size, int level);\n\n    void Summarize(std::function<void(const int*, int)> handler);\n\n    void Reset();\n\nprivate:\n    AnomalyHandler();\n\nprivate:\n    struct Impl;\n    std::unique_ptr<Impl> impl_;\n};\n\ntemplate<class T>\nvoid count_and_fix(T* data, size_t size, std::string key, int level)\n{\n    AnomalyHandler::instance().CountAndFix(data, size, key, level);\n}\n\nvoid DebugTensor(Tensor& tensor, const std::string& key, int level);\n\ninline void DebugTensor(Tensor&& tensor, const std::string& key, int level)\n{\n    DebugTensor(tensor, key, level);\n}\n\n#define TM_DEBUG_RAW(ptr, size, key, __level)                                                                          \\\n    if (::turbomind::AnomalyHandler::level() >= __level) {                                                             \\\n        ::turbomind::count_and_fix(ptr, size, key, __level);                                                           \\\n    }\n\n#define TM_DEBUG_TENSOR(tensor, key, __level)                                                                          \\\n    if (::turbomind::AnomalyHandler::level() >= __level) {                                                             \\\n        
::turbomind::DebugTensor(tensor, key, __level);                                                                \\\n    }\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/constant.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\nnamespace turbomind {\n\nconst int kMaxLogProb = 1024;\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/cuda_bf16_fallbacks.cuh",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include \"src/turbomind/utils/cuda_bf16_wrapper.h\"\n#include <cuda_fp16.h>\n\nnamespace turbomind {\n\n#ifdef ENABLE_BF16\ninline __device__ float2 bf1622float2(const __nv_bfloat162 val)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float2 f_val;\n    f_val.x = __low2float(val);\n    f_val.y = __high2float(val);\n    return f_val;\n#else\n    return __bfloat1622float2(val);\n#endif\n}\n\ninline __device__ int16_t bf1622int16(__nv_bfloat162 val)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float2 f_val;\n    f_val.x = max(min(__low2float(val), 127.f), -128.f);\n    f_val.y = max(min(__high2float(val), 127.f), -128.f);\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));\n    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));\n    return int16;\n#else\n    val = __hmin2(val, make_bfloat162(127., 127.));\n    val = __hmax2(val, make_bfloat162(-128., -128.));\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));\n    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));\n    return int16;\n#endif\n}\n\ninline __device__ __nv_bfloat162 float22bf162(const float2 val)\n{\n#if 
defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return __floats2bfloat162_rn(val.x, val.y);\n#else\n    return __float22bfloat162_rn(val);\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    __nv_bfloat162 val2;\n    val2.x = val;\n    val2.y = val;\n    return val2;\n#else\n    return __bfloat162bfloat162(val);\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float fxl, fxh, fyl, fyh;\n    fxl = __low2float(x);\n    fxh = __high2float(x);\n    fyl = __low2float(y);\n    fyh = __high2float(y);\n    return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);\n#else\n    return __hadd2(x, y);\n#endif\n}\n\ninline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));\n#else\n    return __hadd(x, y);\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float fxl, fxh, fyl, fyh;\n    fxl = __low2float(x);\n    fxh = __high2float(x);\n    fyl = __low2float(y);\n    fyh = __high2float(y);\n    return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);\n#else\n    return __hsub2(x, y);\n#endif\n}\n\ninline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));\n#else\n    return __hsub(x, y);\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float fxl, fxh, fyl, fyh;\n    fxl = __low2float(x);\n    fxh = __high2float(x);\n    fyl = __low2float(y);\n    fyh = 
__high2float(y);\n    return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);\n#else\n    return __hmul2(x, y);\n#endif\n}\n\ninline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));\n#else\n    return __hmul(x, y);\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float fxl, fxh, fyl, fyh, fzl, fzh;\n    fxl = __low2float(x);\n    fxh = __high2float(x);\n    fyl = __low2float(y);\n    fyh = __high2float(y);\n    fzl = __low2float(z);\n    fzh = __high2float(z);\n    return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);\n#else\n    return __hfma2(x, y, z);\n#endif\n}\n\ninline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));\n#else\n    return __hfma(x, y, z);\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float fxl, fxh;\n    fxl = __low2float(x);\n    fxh = __high2float(x);\n    return __floats2bfloat162_rn(expf(fxl), expf(fxh));\n#else\n    return h2exp(x);\n#endif\n}\n\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)\ninline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)\n{\n    return bf16hmul2(x, y);\n}\ninline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)\n{\n    return bf16hadd2(x, y);\n}\n\ninline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)\n{\n    __nv_bfloat162 t;\n    t.x = x;\n    t.y = y;\n    return t;\n}\n\n#endif\n\ninline __device__ 
__nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));\n#else\n    return a + b + c;\n#endif\n}\n\ninline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));\n#else\n    return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float fal, fah, fbl, fbh, fcl, fch;\n    fal = __low2float(a);\n    fah = __high2float(a);\n    fbl = __low2float(b);\n    fbh = __high2float(b);\n    fcl = __low2float(c);\n    fch = __high2float(c);\n    return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);\n#else\n    return a + b + c;\n#endif\n}\n\ninline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));\n#else\n    return a * b * c;\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float fal, fah, fbl, fbh, fcl, fch;\n    fal = __low2float(a);\n    fah = __high2float(a);\n    fbl = __low2float(b);\n    fbh = __high2float(b);\n    fcl = __low2float(c);\n    fch = __high2float(c);\n    return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);\n#else\n    return a * b * c;\n#endif\n}\n\ninline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)\n{\n#if 
defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;\n    fal = __low2float(a);\n    fah = __high2float(a);\n    fbl = __low2float(b);\n    fbh = __high2float(b);\n    fcl = __low2float(c);\n    fch = __high2float(c);\n    fdl = __low2float(d);\n    fdh = __high2float(d);\n    return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);\n#else\n    return a * b * c + d;\n#endif\n}\n\n#endif  // ENABLE_BF16\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/cuda_bf16_wrapper.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#ifdef ENABLE_BF16\n#include <cuda_bf16.h>\n#endif\n"
  },
  {
    "path": "src/turbomind/utils/cuda_type_utils.cuh",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include \"src/turbomind/utils/cuda_bf16_fallbacks.cuh\"\n#include \"src/turbomind/utils/cuda_bf16_wrapper.h\"\n#include <cuda.h>\n#include <cuda_fp16.h>\n\nnamespace turbomind {\n\ntemplate<typename T>\ninline __device__ T ldg(const T* val)\n{\n    return __ldg(val);\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return val[0];\n#else\n    return __ldg(val);\n#endif\n}\n\ntemplate<>\ninline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val)\n{\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n    return val[0];\n#else\n    return __ldg(val);\n#endif\n}\n#endif  // ENABLE_BF16\n\n// Get type2 from type or vice versa (applied to half and bfloat16)\ntemplate<typename T>\nstruct TypeConverter {\n    using Type = half2;\n};  // keep for generality\n\ntemplate<>\nstruct TypeConverter<half2> {\n    using Type = half;\n};\n\ntemplate<>\nstruct TypeConverter<half> {\n    using Type = half2;\n};\n\n#if ENABLE_BF16\ntemplate<>\nstruct TypeConverter<__nv_bfloat162> {\n    using Type = __nv_bfloat16;\n};\n\ntemplate<>\nstruct TypeConverter<__nv_bfloat16> {\n    using Type = __nv_bfloat162;\n};\n#endif  // ENABLE_BF16\n\n// Defined math operations (bfloat16 fallback to fp32 when it is not 
supported)\ntemplate<typename T>\ninline __device__ T hadd2(T a, T b)\n{\n    return __hadd2(a, b);\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b)\n{\n    return bf16hadd2(a, b);\n}\n#endif  // ENABLE_BF16\n\ntemplate<typename T>\ninline __device__ T add(T a, T b)\n{\n    return a + b;\n}\n\ntemplate<>\ninline __device__ half2 add(half2 a, half2 b)\n{\n    return __hadd2(a, b);\n}\n\ntemplate<>\ninline __device__ half add(half a, half b)\n{\n    return __hadd(a, b);\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)\n{\n    return bf16hadd2(a, b);\n}\n\ntemplate<>\ninline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)\n{\n    return bf16hadd(a, b);\n}\n\ninline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b)\n{\n    return bf16hadd(a, __float2bfloat16(b));\n}\n#endif  // ENABLE_BF16\n\n// 3-operand addition\ntemplate<typename T>\ninline __device__ T add(T a, T b, T c)\n{\n    return a + b + c;\n}\n\n#if ENABLE_BF16\ninline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)\n{\n    return bf16hadd(a, b, c);\n}\n\ninline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)\n{\n    return bf16hadd2(a, b, c);\n}\n#endif  // ENABLE_BF16\n\n// 4-operand addition (the generic version accumulates in fp32)\ntemplate<typename T>\ninline __device__ T add(T a, T b, T c, T d)\n{\n    return (T)((float)a + (float)b + (float)c + (float)d);\n}\n\n#if ENABLE_BF16\ninline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)\n{\n    return bf16hadd(a, b, c, d);\n}\n#endif  // ENABLE_BF16\n\ntemplate<typename T>\ninline __device__ T hsub2(T a, T b)\n{\n    return __hsub2(a, b);\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b)\n{\n    return bf16hsub2(a, b);\n}\n#endif  // 
ENABLE_BF16\n\ntemplate<typename T>\ninline __device__ T hmul2(T a, T b)\n{\n    return __hmul2(a, b);\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b)\n{\n    return bf16hmul2(a, b);\n}\n#endif  // ENABLE_BF16\n\ntemplate<typename T>\ninline __device__ T hmul2(T a, T b, T c)\n{\n    return a * b * c;\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)\n{\n    return bf16hmul2(a, b, c);\n}\n#endif  // ENABLE_BF16\n\ntemplate<typename T>\ninline __device__ T mul(T a, T b, T c)\n{\n    return a * b * c;\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)\n{\n    return bf16hmul(a, b, c);\n}\n\ninline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)\n{\n    return bf16hmul2(a, b, c);\n}\n#endif  // ENABLE_BF16\n\ntemplate<typename T>\ninline __device__ T fma(T a, T b, T c, T d)\n{\n    return a * b * c + d;\n}\n\n#if ENABLE_BF16\ninline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)\n{\n    return bf16hfma2(a, b, c, d);\n}\n#endif  // ENABLE_BF16\n\ntemplate<typename T>\ninline __device__ T fma(T a, T b, T c)\n{\n    return a * b + c;\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)\n{\n    return bf16hfma2(a, b, c);\n}\n\ntemplate<>\ninline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)\n{\n    return bf16hfma(a, b, c);\n}\n#endif  // ENABLE_BF16\n\ntemplate<typename T>\ninline __device__ T hexp2(T a)\n{\n    return h2exp(a);\n}\n\n#if ENABLE_BF16\ntemplate<>\ninline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a)\n{\n    return bf16exp2(a);\n}\n#endif  // ENABLE_BF16\n\ntemplate<typename T_OUT, typename T_IN>\n__device__ inline T_OUT cuda_cast(T_IN val)\n{\n    
return val;\n}\n\ntemplate<>\n__device__ inline float2 cuda_cast<float2, int2>(int2 val)\n{\n    return make_float2(val.x, val.y);\n}\ntemplate<>\n__device__ inline float2 cuda_cast<float2, float>(float val)\n{\n    return make_float2(val, val);\n}\ntemplate<>\n__device__ inline float2 cuda_cast<float2, half2>(half2 val)\n{\n    return __half22float2(val);\n}\ntemplate<>\n__device__ inline half2 cuda_cast<half2, float2>(float2 val)\n{\n    return __float22half2_rn(val);\n}\ntemplate<>\n__device__ inline half2 cuda_cast<half2, float>(float val)\n{\n    return __float2half2_rn(val);\n}\ntemplate<>\n__device__ inline half2 cuda_cast<half2, half>(half val)\n{\n    return __half2half2(val);\n}\n\ntemplate<>\n__device__ inline int8_t cuda_cast<int8_t, half>(half val)\n{\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    union {\n        half    fp16;\n        int16_t int16_in;\n    };\n    fp16 = val;\n    asm volatile(\"cvt.rni.sat.s8.f16 %0, %1;\" : \"=h\"(int16) : \"h\"(int16_in));\n    return int8[0];\n}\n\ntemplate<>\n__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)\n{\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    int8[0] = cuda_cast<int8_t>(val.x);\n    int8[1] = cuda_cast<int8_t>(val.y);\n    return int16;\n}\n\ntemplate<>\n__device__ inline int8_t cuda_cast<int8_t, float>(float val)\n{\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    asm volatile(\"cvt.rni.sat.s8.f32 %0, %1;\" : \"=h\"(int16) : \"f\"(val));\n    return int8[0];\n}\n\ntemplate<>\n__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)\n{\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    int8[0] = cuda_cast<int8_t>(val.x);\n    int8[1] = cuda_cast<int8_t>(val.y);\n    return int16;\n}\n\ntemplate<>\n__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)\n{\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    int16 = val;\n    return 
make_half2(int8[0], int8[1]);\n}\n\ntemplate<>\n__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)\n{\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    int16 = val;\n    return make_float2(int8[0], int8[1]);\n}\n\n#ifdef ENABLE_BF16\ntemplate<>\n__device__ inline __nv_bfloat16 cuda_cast(int32_t val)\n{\n    return static_cast<float>(val);\n}\ntemplate<>\n__device__ inline __nv_bfloat16 cuda_cast(int8_t val)\n{\n    return static_cast<float>(val);\n}\ntemplate<>\n__device__ inline int8_t cuda_cast(__nv_bfloat16 val)\n{\n    return static_cast<float>(val);\n}\n\ntemplate<>\n__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)\n{\n    return __bfloat162float(val);\n}\n\ntemplate<>\n__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)\n{\n    return bf1622float2(val);\n}\n\ntemplate<>\n__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)\n{\n    return __float2half(__bfloat162float(val));\n}\n\ntemplate<>\n__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)\n{\n    return bf1622int16(val);\n}\n\ntemplate<>\n__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)\n{\n    return __float2bfloat16(val);\n}\ntemplate<>\n__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)\n{\n    return __float2bfloat16(__half2float(val));\n}\n\ntemplate<>\n__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)\n{\n    return bf162bf162(val);\n}\ntemplate<>\n__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)\n{\n    return __float2bfloat162_rn(val);\n}\ntemplate<>\n__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)\n{\n    return float22bf162(val);\n}\ntemplate<>\n__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)\n{\n    union {\n        int8_t  int8[2];\n        int16_t int16;\n    };\n    int16 
= val;\n    __nv_bfloat162 res;\n    res.x = cuda_cast<__nv_bfloat16>(int8[0]);\n    res.y = cuda_cast<__nv_bfloat16>(int8[1]);\n    return res;\n}\n\ntemplate<>\n__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)\n{\n    return float22bf162(__half22float2(val));\n}\n\n#endif  // ENABLE_BF16\n\ntemplate<typename T>\n__device__ inline T cuda_abs(T val);\ntemplate<>\n__device__ inline float cuda_abs(float val)\n{\n    return fabs(val);\n}\ntemplate<>\n__device__ inline half cuda_abs(half val)\n{\n    return __habs(val);\n}\ntemplate<>\n__device__ inline half2 cuda_abs(half2 val)\n{\n    return __habs2(val);\n}\n\n#ifdef ENABLE_BF16\n\n#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)\ntemplate<>\n__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)\n{\n    return __habs(val);\n}\ntemplate<>\n__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)\n{\n    return __habs2(val);\n}\n#else\ntemplate<>\n__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)\n{\n    return fabs(cuda_cast<float>(val));\n}\ntemplate<>\n__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)\n{\n    return make_bfloat162(fabs(cuda_cast<float>(val.x)), fabs(cuda_cast<float>(val.y)));\n}\n#endif\n\n#endif  // ENABLE_BF16\n\n// Unary maximum: compute the max of a vector type\ntemplate<typename To, typename Ti>\n__device__ inline To cuda_max(Ti val)\n{\n    return cuda_cast<To>(val);\n}\n\ntemplate<>\n__device__ inline half cuda_max(half2 val)\n{\n    return (val.x > val.y) ? val.x : val.y;\n}\n#ifdef ENABLE_BF16\ntemplate<>\n__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)\n{\n    return (val.x > val.y) ? val.x : val.y;\n}\n#endif\n\n// Binary maximum: compute the max of two scalar types\ntemplate<typename T>\n__device__ inline T cuda_max(T val1, T val2)\n{\n    return (val1 > val2) ? 
val1 : val2;\n}\n\n#ifdef ENABLE_FP8\ntemplate<>\n__device__ inline float2 cuda_cast<float2, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)\n{\n    return bf1622float2(fp8x2_e4m3_to_bfloat2(&val));\n}\ntemplate<>\n__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val)\n{\n    return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val)));\n}\n\ntemplate<>\n__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val)\n{\n    return __nv_fp8_e4m3(val);\n}\ntemplate<>\n__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val)\n{\n    return __nv_fp8_e4m3(val);\n}\ntemplate<>\n__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val)\n{\n    return __nv_fp8_e4m3(val);\n}\ntemplate<>\n__device__ inline float cuda_cast<float, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)\n{\n    return (float)val;\n}\ntemplate<>\n__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val)\n{\n    return fp8x2_e4m3_to_bfloat2(&val);\n}\n\ntemplate<>\n__device__ inline int8_t cuda_cast<int8_t, __nv_fp8_e4m3>(__nv_fp8_e4m3 val)\n{\n    // no impl\n    return 0;\n}\n\ntemplate<>\n__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val)\n{\n    return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast<float>(val)));\n}\n\n#endif  // ENABLE_FP8\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/cuda_utils.cc",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/macro.h\"\n#include <driver_types.h>\n#include <regex>\n\nnamespace turbomind {\n\nvoid syncAndCheck(const char* const file, int const line)\n{\n    // When FT_DEBUG_LEVEL=DEBUG, must check error\n    static char* level_name = std::getenv(\"TM_DEBUG_LEVEL\");\n    if (level_name != nullptr) {\n        static std::string level = std::string(level_name);\n        if (level == \"DEBUG\") {\n            cudaDeviceSynchronize();\n            cudaError_t result = cudaGetLastError();\n            if (result) {\n                TM_LOG_ERROR((std::string(\"CUDA runtime error: \") + (_cudaGetErrorEnum(result)) + \" \" + file + \":\"\n                              + std::to_string(line))\n                                 .c_str());\n                std::abort();\n            }\n            TM_LOG_DEBUG(fmtstr(\"run syncAndCheck at %s:%d\", file, line));\n        }\n    }\n}\n\n/* **************************** debug tools ********************************* */\n\ntemplate<typename T>\nvoid printMatrix(T* ptr, int m, int k, int stride, bool is_device_ptr)\n{\n    T* tmp;\n    if (is_device_ptr) {\n        // k < stride ; stride = col-dimension.\n        tmp = reinterpret_cast<T*>(malloc(m * stride * sizeof(T)));\n        
check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost));\n        cudaDeviceSynchronize();\n    }\n    else {\n        tmp = ptr;\n    }\n\n    for (int ii = -1; ii < m; ++ii) {\n        if (ii >= 0) {\n            printf(\"%02d \", ii);\n        }\n        else {\n            printf(\"   \");\n        }\n\n        for (int jj = 0; jj < k; jj += 1) {\n            if (ii >= 0) {\n                printf(\"%7.3f \", (float)tmp[ii * stride + jj]);\n            }\n            else {\n                printf(\"%7d \", jj);\n            }\n        }\n        printf(\"\\n\");\n    }\n    if (is_device_ptr) {\n        free(tmp);\n    }\n}\n\ntemplate void printMatrix(float* ptr, int m, int k, int stride, bool is_device_ptr);\ntemplate void printMatrix(half* ptr, int m, int k, int stride, bool is_device_ptr);\n#ifdef ENABLE_BF16\ntemplate void printMatrix(__nv_bfloat16* ptr, int m, int k, int stride, bool is_device_ptr);\n#endif\n\nvoid printMatrix(unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr)\n{\n    typedef unsigned long long T;\n    T*                         tmp;\n    if (is_device_ptr) {\n        // k < stride ; stride = col-dimension.\n        tmp = reinterpret_cast<T*>(malloc(m * stride * sizeof(T)));\n        check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost));\n        cudaDeviceSynchronize();\n    }\n    else {\n        tmp = ptr;\n    }\n\n    for (int ii = -1; ii < m; ++ii) {\n        if (ii >= 0) {\n            printf(\"%02d \", ii);\n        }\n        else {\n            printf(\"   \");\n        }\n\n        for (int jj = 0; jj < k; jj += 1) {\n            if (ii >= 0) {\n                printf(\"%4llu \", tmp[ii * stride + jj]);\n            }\n            else {\n                printf(\"%4d \", jj);\n            }\n        }\n        printf(\"\\n\");\n    }\n    if (is_device_ptr) {\n        free(tmp);\n    }\n}\n\nvoid printMatrix(int* ptr, int m, int k, int stride, 
bool is_device_ptr)\n{\n    typedef int T;\n    T*          tmp;\n    if (is_device_ptr) {\n        // k < stride ; stride = col-dimension.\n        tmp = reinterpret_cast<T*>(malloc(m * stride * sizeof(T)));\n        check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost));\n        cudaDeviceSynchronize();\n    }\n    else {\n        tmp = ptr;\n    }\n\n    for (int ii = -1; ii < m; ++ii) {\n        if (ii >= 0) {\n            printf(\"%02d \", ii);\n        }\n        else {\n            printf(\"   \");\n        }\n\n        for (int jj = 0; jj < k; jj += 1) {\n            if (ii >= 0) {\n                printf(\"%4d \", tmp[ii * stride + jj]);\n            }\n            else {\n                printf(\"%4d \", jj);\n            }\n        }\n        printf(\"\\n\");\n    }\n    if (is_device_ptr) {\n        free(tmp);\n    }\n}\n\n// multiple definitions for msvc\n#ifndef _MSC_VER\nvoid printMatrix(size_t* ptr, int m, int k, int stride, bool is_device_ptr)\n{\n    typedef size_t T;\n    T*             tmp;\n    if (is_device_ptr) {\n        // k < stride ; stride = col-dimension.\n        tmp = reinterpret_cast<T*>(malloc(m * stride * sizeof(T)));\n        check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost));\n        cudaDeviceSynchronize();\n    }\n    else {\n        tmp = ptr;\n    }\n\n    for (int ii = -1; ii < m; ++ii) {\n        if (ii >= 0) {\n            printf(\"%02d \", ii);\n        }\n        else {\n            printf(\"   \");\n        }\n\n        for (int jj = 0; jj < k; jj += 1) {\n            if (ii >= 0) {\n                printf(\"%4zu \", tmp[ii * stride + jj]);\n            }\n            else {\n                printf(\"%4d \", jj);\n            }\n        }\n        printf(\"\\n\");\n    }\n    if (is_device_ptr) {\n        free(tmp);\n    }\n}\n#endif\n\ntemplate<typename T>\nvoid check_max_val(const T* result, const int size)\n{\n    T* tmp = new T[size];\n    
cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost);\n    float max_val = -100000;\n    for (int i = 0; i < size; i++) {\n        float val = static_cast<float>(tmp[i]);\n        if (val > max_val) {\n            max_val = val;\n        }\n    }\n    delete[] tmp;\n    printf(\"[INFO][CUDA] addr %p max val: %f \\n\", result, max_val);\n}\n\ntemplate void check_max_val(const float* result, const int size);\ntemplate void check_max_val(const half* result, const int size);\n#ifdef ENABLE_BF16\ntemplate void check_max_val(const __nv_bfloat16* result, const int size);\n#endif\n\ntemplate<typename T>\nvoid check_abs_mean_val(const T* result, const int size)\n{\n    T* tmp = new T[size];\n    cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost);\n    float sum = 0.0f;\n    for (int i = 0; i < size; i++) {\n        sum += abs(static_cast<float>(tmp[i]));\n    }\n    delete[] tmp;\n    printf(\"[INFO][CUDA] addr %p abs mean val: %f \\n\", result, sum / size);\n}\n\ntemplate void check_abs_mean_val(const float* result, const int size);\ntemplate void check_abs_mean_val(const half* result, const int size);\n#ifdef ENABLE_BF16\ntemplate void check_abs_mean_val(const __nv_bfloat16* result, const int size);\n#endif\n\n/* ***************************** common utils ****************************** */\n\nint getSMVersion()\n{\n    int device{-1};\n    check_cuda_error(cudaGetDevice(&device));\n    int sm_major = 0;\n    int sm_minor = 0;\n    check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));\n    check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));\n    return sm_major * 10 + sm_minor;\n}\n\nint getSMCount()\n{\n    int device{-1};\n    check_cuda_error(cudaGetDevice(&device));\n    int sm_count{};\n    check_cuda_error(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));\n    return sm_count;\n}\n\nstd::string getDeviceName()\n{\n    int 
device{-1};\n    check_cuda_error(cudaGetDevice(&device));\n    cudaDeviceProp props;\n    check_cuda_error(cudaGetDeviceProperties(&props, device));\n    return std::string(props.name);\n}\n\nint getDevice()\n{\n    int current_dev_id = 0;\n    check_cuda_error(cudaGetDevice(&current_dev_id));\n    return current_dev_id;\n}\n\nint getDeviceCount()\n{\n    int count = 0;\n    check_cuda_error(cudaGetDeviceCount(&count));\n    return count;\n}\n\nvoid trim_default_mempool(int device_id)\n{\n    cudaMemPool_t mempool;\n    check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id));\n    check_cuda_error(cudaMemPoolTrimTo(mempool, 0));\n}\n\n/* ************************** end of common utils ************************** */\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/cuda_utils.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <algorithm>\n#include <fstream>\n#include <iostream>\n#include <string>\n#include <vector>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <cublasLt.h>\n#include <cublas_v2.h>\n#ifdef SPARSITY_ENABLED\n#include <cusparseLt.h>\n#endif\n\n#include \"src/turbomind/core/check.h\"\n#include \"src/turbomind/macro.h\"\n#include \"src/turbomind/utils/cuda_bf16_wrapper.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace turbomind {\n\n/* **************************** debug tools ********************************* */\nstatic const char* _cudaGetErrorEnum(cudaError_t error)\n{\n    return cudaGetErrorString(error);\n}\n\nstatic const char* _cudaGetErrorEnum(cublasStatus_t error)\n{\n    switch (error) {\n        case CUBLAS_STATUS_SUCCESS:\n            return \"CUBLAS_STATUS_SUCCESS\";\n\n        case CUBLAS_STATUS_NOT_INITIALIZED:\n            return \"CUBLAS_STATUS_NOT_INITIALIZED\";\n\n        case CUBLAS_STATUS_ALLOC_FAILED:\n            return \"CUBLAS_STATUS_ALLOC_FAILED\";\n\n        case CUBLAS_STATUS_INVALID_VALUE:\n            return \"CUBLAS_STATUS_INVALID_VALUE\";\n\n        case CUBLAS_STATUS_ARCH_MISMATCH:\n            return \"CUBLAS_STATUS_ARCH_MISMATCH\";\n\n        case CUBLAS_STATUS_MAPPING_ERROR:\n            return \"CUBLAS_STATUS_MAPPING_ERROR\";\n\n    
    case CUBLAS_STATUS_EXECUTION_FAILED:\n            return \"CUBLAS_STATUS_EXECUTION_FAILED\";\n\n        case CUBLAS_STATUS_INTERNAL_ERROR:\n            return \"CUBLAS_STATUS_INTERNAL_ERROR\";\n\n        case CUBLAS_STATUS_NOT_SUPPORTED:\n            return \"CUBLAS_STATUS_NOT_SUPPORTED\";\n\n        case CUBLAS_STATUS_LICENSE_ERROR:\n            return \"CUBLAS_STATUS_LICENSE_ERROR\";\n    }\n    return \"<unknown>\";\n}\n\ntemplate<typename T>\nvoid check(T result, char const* const func, const char* const file, int const line)\n{\n    if (result) {\n        TM_LOG_ERROR((std::string(\"CUDA runtime error: \") + (_cudaGetErrorEnum(result)) + \" \" + file + \":\"\n                      + std::to_string(line))\n                         .c_str());\n        std::abort();\n    }\n}\n\n#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)\n#define check_cuda_error_2(val, file, line) check((val), #val, file, line)\n\nvoid syncAndCheck(const char* const file, int const line);\n\n#define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__)\n\n#define CUDRVCHECK(expr)                                                                                               \\\n    if (auto ec = expr; ec != CUDA_SUCCESS) {                                                                          \\\n        const char* p_str{};                                                                                           \\\n        cuGetErrorString(ec, &p_str);                                                                                  \\\n        p_str    = p_str ? 
p_str : \"Unknown error\";                                                                    \\\n        auto msg = fmtstr(\"[TM][ERROR] CUDA driver error: %s:%d '%s'\", __FILE__, __LINE__, p_str);                     \\\n        throw std::runtime_error(msg.c_str());                                                                         \\\n    }\n\ntemplate<typename T>\nvoid printMatrix(T* ptr, int m, int k, int stride, bool is_device_ptr);\n\nvoid printMatrix(unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr);\nvoid printMatrix(int* ptr, int m, int k, int stride, bool is_device_ptr);\nvoid printMatrix(size_t* ptr, int m, int k, int stride, bool is_device_ptr);\n\ntemplate<typename T>\nvoid check_max_val(const T* result, const int size);\n\ntemplate<typename T>\nvoid check_abs_mean_val(const T* result, const int size);\n\n#define PRINT_FUNC_NAME_()                                                                                             \\\n    do {                                                                                                               \\\n        std::cout << \"[TM][CALL] \" << __FUNCTION__ << \" \" << std::endl;                                                \\\n    } while (0)\n\n[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = \"\")\n{\n    throw std::runtime_error(std::string(\"[TM][ERROR] \") + info + \" Assertion fail: \" + file + \":\"\n                             + std::to_string(line) + \" \\n\");\n}\n\ninline void myAssert(bool result, const char* const file, int const line, std::string const& info = \"\")\n{\n    if (!result) {\n        throwRuntimeError(file, line, info);\n    }\n}\n\n#define FT_CHECK(val) myAssert(bool(val), __FILE__, __LINE__)\n#define FT_CHECK_WITH_INFO(val, info)                                                                                  \\\n    do {                                                                        
                                       \\\n        bool is_valid_val = bool(val);                                                                                 \\\n        if (!is_valid_val) {                                                                                           \\\n            turbomind::myAssert(is_valid_val, __FILE__, __LINE__, (info));                                             \\\n        }                                                                                                              \\\n    } while (0)\n\n#define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info)\n\n/* ***************************** common utils ****************************** */\n\nint getSMVersion();\n\nint getSMCount();\n\nstd::string getDeviceName();\n\ntemplate<class T>\ninline T div_up(T a, T n)\n{\n    return (a + n - 1) / n;\n}\n\nint getDevice();\n\nint getDeviceCount();\n\nclass CudaDeviceGuard {\npublic:\n    CudaDeviceGuard(int device)\n    {\n        check_cuda_error(cudaGetDevice(&last_device_id_));\n        if (device != last_device_id_) {\n            check_cuda_error(cudaSetDevice(device));\n        }\n    }\n\n    ~CudaDeviceGuard()\n    {\n        TM_CHECK_EQ(cudaSetDevice(last_device_id_), cudaSuccess);\n    }\n\nprivate:\n    int last_device_id_{-1};\n};\n\nvoid trim_default_mempool(int device_id);\n\n/* ************************** end of common utils ************************** */\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/debug_utils.h",
    "content": "#pragma once\n\n#if __has_include(\"3rdparty/dbg.h\")\n#include \"3rdparty/dbg.h\"\n#else\n#define dbg(...)\n#endif\n"
  },
  {
    "path": "src/turbomind/utils/dispatch.h",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#pragma once\n\n#include <utility>\n\nnamespace turbomind {\n\nnamespace detail {\n\ntemplate<int X>\ninline constexpr std::integral_constant<int, X> _Int{};\n\ntemplate<class F, class P, class G, int... Xs, std::size_t... Is>\nbool dispatch_impl(F&& f, P&& p, G g, std::integer_sequence<int, Xs...>, std::index_sequence<Is...>)\n{\n    constexpr int N = sizeof...(Xs);\n    return (((((P &&) p)(_Int<Xs>) || (g && Is == N - 1)) && (((F &&) f)(_Int<Xs>), 1)) || ...);\n}\n\n}  // namespace detail\n\ntemplate<class F, class P, int... Is, class G = std::true_type>\nbool dispatch(std::integer_sequence<int, Is...> seq, P&& p, F&& f, G g = {})\n{\n    return detail::dispatch_impl((F &&) f, (P &&) p, g, seq, std::make_index_sequence<sizeof...(Is)>{});\n}\n\ntemplate<class F, int... Is, class G = std::true_type>\nbool dispatch(std::integer_sequence<int, Is...> seq, F&& f)\n{\n    return (((F &&) f)(detail::_Int<Is>) || ...);\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/logger.cc",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/utils/logger.h\"\n#include <cuda_runtime.h>\n\nnamespace turbomind {\n\nLogger& Logger::getLogger()\n{\n    thread_local Logger instance;\n    return instance;\n}\n\nLogger::Logger()\n{\n    char* is_first_rank_only_char = std::getenv(\"TM_LOG_FIRST_RANK_ONLY\");\n    bool  is_first_rank_only =\n        (is_first_rank_only_char != nullptr && std::string(is_first_rank_only_char) == \"ON\") ? true : false;\n\n    int device_id;\n    cudaGetDevice(&device_id);\n\n    char* level_name = std::getenv(\"TM_LOG_LEVEL\");\n    if (level_name != nullptr) {\n        std::map<std::string, Level> name_to_level = {\n            {\"TRACE\", TRACE},\n            {\"DEBUG\", DEBUG},\n            {\"INFO\", INFO},\n            {\"WARNING\", WARNING},\n            {\"ERROR\", ERROR},\n        };\n        auto level = name_to_level.find(level_name);\n        // If TM_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR\n        if (is_first_rank_only && device_id != 0) {\n            level = name_to_level.find(\"ERROR\");\n        }\n        if (level != name_to_level.end()) {\n            setLevel(level->second);\n        }\n        else {\n            fprintf(stderr,\n                    \"[TM][WARNING] Invalid logger level TM_LOG_LEVEL=%s. 
\"\n                    \"Ignore the environment variable and use a default \"\n                    \"logging level.\\n\",\n                    level_name);\n            level_name = nullptr;\n        }\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/logger.h",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cstdlib>\n#include <map>\n#include <string>\n\n#include \"src/turbomind/utils/string_utils.h\"\n\nnamespace turbomind {\n\n// cub.cuh brings windows.h\n// should be included after cub.cuh\n#ifdef ERROR\n#undef ERROR\n#endif\n\nclass Logger {\n\npublic:\n    enum Level\n    {\n        TRACE   = 0,\n        DEBUG   = 10,\n        INFO    = 20,\n        WARNING = 30,\n        ERROR   = 40\n    };\n\n    static Logger& getLogger();\n    Logger(Logger const&) = delete;\n    void operator=(Logger const&) = delete;\n\n    template<typename... Args>\n    void log(const Level level, const std::string format, const Args&... args)\n    {\n        if (level_ <= level) {\n            std::string fmt = getPrefix(level) + format + \"\\n\";\n            // FILE*       out    = level_ < WARNING ? stdout : stderr;\n            std::string logstr = fmtstr(fmt, args...);\n            fprintf(stderr, \"%s\", logstr.c_str());\n        }\n    }\n\n    template<typename... Args>\n    void log(const Level level, const int rank, const std::string format, const Args&... args)\n    {\n        if (level_ <= level) {\n            std::string fmt = getPrefix(level, rank) + format + \"\\n\";\n            // FILE*       out    = level_ < WARNING ? 
stdout : stderr;\n            std::string logstr = fmtstr(fmt, args...);\n            fprintf(stderr, \"%s\", logstr.c_str());\n        }\n    }\n\n    void setLevel(const Level level)\n    {\n        level_ = level;\n        log(DEBUG, \"Set logger level by %s\", getLevelName(level).c_str());\n    }\n\n    int getLevel() const\n    {\n        return level_;\n    }\n\nprivate:\n    const std::string                              PREFIX      = \"[TM]\";\n    const std::map<const Level, const std::string> level_name_ = {\n        {TRACE, \"TRACE\"}, {DEBUG, \"DEBUG\"}, {INFO, \"INFO\"}, {WARNING, \"WARNING\"}, {ERROR, \"ERROR\"}};\n\n#ifndef NDEBUG\n    const Level DEFAULT_LOG_LEVEL = DEBUG;\n#else\n    const Level DEFAULT_LOG_LEVEL = INFO;\n#endif\n    Level level_ = DEFAULT_LOG_LEVEL;\n\n    Logger();\n\n    inline const std::string getLevelName(const Level level)\n    {\n        return level_name_.at(level);\n    }\n\n    inline const std::string getPrefix(const Level level)\n    {\n        return PREFIX + \"[\" + getLevelName(level) + \"] \";\n    }\n\n    inline const std::string getPrefix(const Level level, const int rank)\n    {\n        return PREFIX + \"[\" + getLevelName(level) + \"][\" + std::to_string(rank) + \"] \";\n    }\n};\n\n#define TM_LOG(level, ...)                                                                                             \\\n    do {                                                                                                               \\\n        if (turbomind::Logger::getLogger().getLevel() <= level) {                                                      \\\n            turbomind::Logger::getLogger().log(level, __VA_ARGS__);                                                    \\\n        }                                                                                                              \\\n    } while (0)\n\n#define TM_LOG_TRACE(...) TM_LOG(turbomind::Logger::TRACE, __VA_ARGS__)\n#define TM_LOG_DEBUG(...) 
TM_LOG(turbomind::Logger::DEBUG, __VA_ARGS__)\n#define TM_LOG_INFO(...) TM_LOG(turbomind::Logger::INFO, __VA_ARGS__)\n#define TM_LOG_WARNING(...) TM_LOG(turbomind::Logger::WARNING, __VA_ARGS__)\n#define TM_LOG_ERROR(...) TM_LOG(turbomind::Logger::ERROR, __VA_ARGS__)\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/memory_utils.cu",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include \"src/turbomind/macro.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\nnamespace turbomind {\n\ntemplate<typename T_OUT, typename T_IN>\n__global__ void transpose102(T_OUT* dst, T_IN* src, const int dim0, const int dim1, const int dim2)\n{\n    // src permutation: [0, 1, 2]\n    // dst permutation: [1, 0, 2]\n    for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2; tid += blockDim.x * gridDim.x) {\n        int       tmp_idx                                           = tid;\n        const int dim_2_idx                                         = tmp_idx % dim2;\n        tmp_idx                                                     = (tmp_idx - dim_2_idx) / dim2;\n        const int dim_1_idx                                         = tmp_idx % dim1;\n        tmp_idx                                                     = (tmp_idx - dim_1_idx) / dim1;\n        const int dim_0_idx                                         = tmp_idx % dim0;\n        dst[dim_1_idx * dim0 * dim2 + dim_0_idx * dim2 + dim_2_idx] = src[tid];\n    }\n}\n\ntemplate<typename T>\nvoid invokeInPlaceTranspose102(\n    T* data, T* workspace, const int dim0, const int dim1, const int dim2, bool copy, cudaStream_t stream)\n{\n    // copy data to workspace, and 
then transpose from workspace to data\n    // Note that this kernel is used for pre-processing and not very efficient.\n    const size_t count = dim0 * dim1 * dim2;\n    if (copy) {\n        check_cuda_error(cudaMemcpyAsync(workspace, data, sizeof(T) * count, cudaMemcpyDefault, stream));\n    }\n    const int block = 512;\n    const int grid  = std::min((count + block - 1) / block, (size_t)8192);\n    transpose102<<<grid, block, 0, stream>>>(data, workspace, dim0, dim1, dim2);\n}\n\ntemplate void invokeInPlaceTranspose102(uint16_t*    data,\n                                        uint16_t*    workspace,\n                                        const int    dim0,\n                                        const int    dim1,\n                                        const int    dim2,\n                                        bool         copy,\n                                        cudaStream_t stream);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/memory_utils.h",
    "content": "/*\n * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cuda_runtime.h>\n\nnamespace turbomind {\n\ntemplate<typename T>\nvoid invokeInPlaceTranspose102(\n    T* data, T* workspace, const int dim0, const int dim1, const int dim2, bool copy = true, cudaStream_t stream = 0);\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/metrics.h",
    "content": "#pragma once\n\n#include <atomic>\n#include <chrono>\n#include <cstdint>\n#include <ostream>\n\nnamespace turbomind {\n\nstruct ScheduleMetrics {\n    // sequences\n    int total_seqs;    // the number of received sequences\n    int active_seqs;   // the number of active sequences\n    int waiting_seqs;  // the number of waiting sequences\n\n    // kv block usage\n    int total_blocks;   // the number of kv blocks\n    int active_blocks;  // the number of active kv blocks\n    int cached_blocks;  // the number of cached kv blocks\n    int free_blocks;    // the number of free kv blocks\n};\n\nstruct RequestMetrics {\n    std::atomic<int64_t> enqueue_time{};    // when a request is enqueued\n    std::atomic<int64_t> scheduled_time{};  // when a request is scheduled for inference\n\n    static int64_t timestamp()\n    {\n        // Get current timestamp in microseconds since Unix epoch\n        // system_clock uses wall-clock time (matches Python's time.time())\n        return std::chrono::duration_cast<std::chrono::microseconds>(\n                   std::chrono::system_clock::now().time_since_epoch())\n            .count();\n    }\n};\n\ninline std::ostream& operator<<(std::ostream& os, const ScheduleMetrics& m)\n{\n    os << \"ScheduleMetrics { \";\n    os << \"total_seqs=\" << m.total_seqs;\n    os << \", active_seqs=\" << m.active_seqs;\n    os << \", waiting_seqs=\" << m.waiting_seqs;\n    os << \", total_blocks=\" << m.total_blocks;\n    os << \", active_blocks=\" << m.active_blocks;\n    os << \", cached_blocks=\" << m.cached_blocks;\n    os << \", free_blocks=\" << m.free_blocks;\n    os << \" }\";\n    return os;\n}\n\ninline std::ostream& operator<<(std::ostream& os, const RequestMetrics& m)\n{\n    os << \"RequestMetrics { \";\n    os << \"enqueue_time=\" << m.enqueue_time.load(std::memory_order_relaxed);\n    os << \", scheduled_time=\" << m.scheduled_time.load(std::memory_order_relaxed);\n    os << \" }\";\n    return os;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/monotonic.h",
    "content": "#pragma once\n\n#include <cstdint>\n#include <cstdlib>\n#include <utility>\n\nnamespace turbomind {\n\nclass Monotonic {\npublic:\n    Monotonic(void* base, size_t alignment = 256): ptr_{base}, alignment_{alignment}\n    {\n        ptr_ = align(ptr_);\n    }\n\n    template<class T>\n    void operator()(T** ptr, size_t numel) noexcept\n    {\n        *ptr = (T*)std::exchange(ptr_, align((T*)ptr_ + numel));\n    }\n\n    void* ptr() const noexcept\n    {\n        return ptr_;\n    }\n\nprivate:\n    template<class T>\n    void* align(T* p)\n    {\n        static_assert(sizeof(T*) == sizeof(uintptr_t));\n        auto x = reinterpret_cast<uintptr_t>(p);\n        if (auto remainder = x % alignment_) {\n            x += alignment_ - remainder;\n        }\n        return reinterpret_cast<void*>(x);\n    }\n\n    void*  ptr_;\n    size_t alignment_;\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/nvtx_utils.cc",
    "content": "/*\n * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <iostream>\n\n#include \"nvtx_utils.h\"\n#ifdef USE_NVTX\n#include \"nvtx3/nvToolsExt.h\"\n#endif\n\nnamespace ft_nvtx {\nstd::string getScope()\n{\n    return scope;\n}\nvoid addScope(std::string name)\n{\n    scope = scope + name + \"/\";\n    return;\n}\nvoid setScope(std::string name)\n{\n    scope = name + \"/\";\n    return;\n}\nvoid resetScope()\n{\n    scope = \"\";\n    return;\n}\nvoid setDeviceDomain(int deviceId)\n{\n    domain = deviceId;\n    return;\n}\nvoid resetDeviceDomain()\n{\n    domain = 0;\n    return;\n}\nint getDeviceDomain()\n{\n    return domain;\n}\n\nbool isEnableNvtx()\n{\n    if (!has_read_nvtx_env) {\n        static char* ft_nvtx_env_char = std::getenv(\"FT_NVTX\");\n        is_enable_ft_nvtx = (ft_nvtx_env_char != nullptr && std::string(ft_nvtx_env_char) == \"ON\") ? 
true : false;\n        has_read_nvtx_env = true;\n    }\n    return is_enable_ft_nvtx;\n}\n\nvoid ftNvtxRangePush(std::string name)\n{\n#ifdef USE_NVTX\n    nvtxStringHandle_t    nameId      = nvtxDomainRegisterStringA(NULL, (getScope() + name).c_str());\n    nvtxEventAttributes_t eventAttrib = {0};\n    eventAttrib.messageType           = NVTX_MESSAGE_TYPE_REGISTERED;\n    eventAttrib.message.registered    = nameId;\n    eventAttrib.payloadType           = NVTX_PAYLOAD_TYPE_INT32;\n    eventAttrib.payload.iValue        = getDeviceDomain();\n    nvtxRangePushEx(&eventAttrib);\n#endif\n}\n\nvoid ftNvtxRangePop()\n{\n#ifdef USE_NVTX\n    nvtxRangePop();\n#endif\n}\n\n}  // namespace ft_nvtx\n"
  },
  {
    "path": "src/turbomind/utils/nvtx_utils.h",
    "content": "/*\n * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\nnamespace ft_nvtx {\nstatic std::string scope;\nstd::string        getScope();\nvoid               addScope(std::string name);\nvoid               setScope(std::string name);\nvoid               resetScope();\nstatic int         domain = 0;\nvoid               setDeviceDomain(int deviceId);\nint                getDeviceDomain();\nvoid               resetDeviceDomain();\nbool               isEnableNvtx();\n\nstatic bool has_read_nvtx_env = false;\nstatic bool is_enable_ft_nvtx = false;\nvoid        ftNvtxRangePush(std::string name);\nvoid        ftNvtxRangePop();\n}  // namespace ft_nvtx\n\n#define PUSH_RANGE(name)                                                                                               \\\n    {                                                                                                                  \\\n        if (ft_nvtx::isEnableNvtx()) {                                                                                 \\\n            ft_nvtx::ftNvtxRangePush(name);                                                                            \\\n        }                                                                                                              \\\n    }\n\n#define POP_RANGE                                                                          
                            \\\n    {                                                                                                                  \\\n        if (ft_nvtx::isEnableNvtx()) {                                                                                 \\\n            ft_nvtx::ftNvtxRangePop();                                                                                 \\\n        }                                                                                                              \\\n    }\n"
  },
  {
    "path": "src/turbomind/utils/parser.cc",
    "content": "// Copyright (c) OpenMMLab. All rights reserved.\n\n#include <algorithm>\n#include <iterator>\n#include <regex>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"src/turbomind/utils/parser.h\"\n\nnamespace turbomind {\n\nstd::vector<std::pair<std::string, std::string>> ParseArgsList(const std::string& str)\n{\n    const std::regex regex(R\"((\\w+)=([^,\\[\\(]+|\\[.*\\]|\\(.*\\)))\");\n\n    std::sregex_iterator beg(str.begin(), str.end(), regex);\n    std::sregex_iterator end{};\n\n    std::vector<std::pair<std::string, std::string>> ret;\n    for (auto it = beg; it != end; ++it) {\n        std::smatch match = *it;\n        ret.emplace_back(match[1], match[2]);\n    }\n\n    return ret;\n}\n\nstd::vector<std::string> ParseListOrTuple(const std::string& str)\n{\n    const std::regex regex(R\"([,\\[\\]\\(\\)]+)\");\n\n    std::vector<std::string> ret;\n    std::copy_if(std::sregex_token_iterator(str.begin(), str.end(), regex, -1),\n                 std::sregex_token_iterator{},\n                 std::back_inserter(ret),\n                 [](const std::string& s) { return !s.empty(); });\n\n    return ret;\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/parser.h",
    "content": "#pragma once\n\n#include <string>\n#include <utility>\n#include <vector>\n\nnamespace turbomind {\n\nstd::vector<std::pair<std::string, std::string>> ParseArgsList(const std::string& str);\n\nstd::vector<std::string> ParseListOrTuple(const std::string& str);\n\ninline void Parse(int& value, const std::string& str)\n{\n    value = std::stoi(str);\n}\n\ninline void Parse(float& value, const std::string& str)\n{\n    value = std::stof(str);\n}\n\ntemplate<class T>\nvoid Parse(std::vector<T>& xs, const std::string& str)\n{\n    const auto ss = ParseListOrTuple(str);\n    for (const auto& s : ss) {\n        xs.emplace_back();\n        Parse(xs.back(), s);\n    }\n}\n\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/string_utils.h",
    "content": "/*\n * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cstdio>     // std::snprintf\n#include <memory>     // std::make_unique\n#include <sstream>    // std::stringstream\n#include <stdexcept>  // std::runtime_error\n#include <string>\n#include <vector>\n\nnamespace turbomind {\n\ntemplate<typename... Args>\ninline std::string fmtstr(const std::string& format, Args... args)\n{\n    // This function came from a code snippet in stackoverflow under cc-by-1.0\n    //   https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf\n\n    // Disable format-security warning in this function.\n#if defined(_MSC_VER)  // for visual studio\n#pragma warning(push)\n#pragma warning(disable : 4996)\n#elif defined(__GNUC__) || defined(__clang__)  // for gcc or clang\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wformat-security\"\n#endif\n    int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1;  // Extra space for '\\0'\n    if (size_s <= 0) {\n        throw std::runtime_error(\"Error during formatting.\");\n    }\n    auto size = static_cast<size_t>(size_s);\n    auto buf  = std::make_unique<char[]>(size);\n    std::snprintf(buf.get(), size, format.c_str(), args...);\n#if defined(_MSC_VER)\n#pragma warning(pop)\n#elif defined(__GNUC__) || defined(__clang__)\n#pragma GCC diagnostic pop\n#endif\n    return std::string(buf.get(), buf.get() + size - 1);  // We don't want the '\\0' inside\n}\n\ntemplate<typename T>\ninline std::string vec2str(std::vector<T> vec)\n{\n    std::stringstream ss;\n    ss << \"(\";\n    if (!vec.empty()) {\n        for (size_t i = 0; i < vec.size() - 1; ++i) {\n            ss << vec[i] << \", \";\n        }\n        ss << vec.back();\n    }\n    ss << \")\";\n    return ss.str();\n}\n\ntemplate<typename T>\ninline std::string arr2str(T* arr, size_t size)\n{\n    std::stringstream ss;\n    ss << \"(\";\n    if (size > 0) {\n        for (size_t i = 0; i < size - 1; ++i) {\n            ss << arr[i] << \", \";\n        }\n        ss << arr[size - 1];\n    }\n    ss << \")\";\n    return ss.str();\n}\n}  // namespace turbomind\n"
  },
  {
    "path": "src/turbomind/utils/test_utils.h",
    "content": "/*\n * Copyright (c) 2022 NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <cmath>\n#include <cuda.h>\n#include <cuda_runtime_api.h>\n\nnamespace turbomind {\n\n#define TIMEIT(print, n, stream, fn, ...)                                                                              \\\n    ({                                                                                                                 \\\n        cudaEvent_t _macro_event_start, _macro_event_stop;                                                             \\\n        cudaEventCreate(&_macro_event_start);                                                                          \\\n        cudaEventCreate(&_macro_event_stop);                                                                           \\\n        cudaEventRecord(_macro_event_start, stream);                                                                   \\\n        for (int i = 0; i < n; i++) {                                                                                  \\\n            fn(__VA_ARGS__);                                                                                           \\\n        }                                                                                                              \\\n        cudaEventRecord(_macro_event_stop, stream);                                                                 
   \\\n        cudaStreamSynchronize(stream);                                                                                 \\\n        float ms = 0.0f;                                                                                               \\\n        cudaEventElapsedTime(&ms, _macro_event_start, _macro_event_stop);                                              \\\n        ms /= n;                                                                                                       \\\n        if (print)                                                                                                     \\\n            printf(\"[TIMEIT] \" #fn \": %.2fµs\\n\", ms * 1000);                                                           \\\n        ms;                                                                                                            \\\n    })\n\ntemplate<typename T>\nstruct rel_abs_diff {\n    T operator()(const T& lhs, const T& rhs) const\n    {\n        return lhs == 0 ? 0 : static_cast<T>(fabs(lhs - rhs) / fabs(lhs));\n    }\n};\n\ntemplate<typename T>\nstruct abs_diff {\n    T operator()(const T& lhs, const T& rhs) const\n    {\n        return static_cast<T>(fabs(lhs - rhs));\n    }\n};\n\n}  // namespace turbomind\n"
  },
  {
    "path": "tests/csrc/CMakeLists.txt",
    "content": "# Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nadd_subdirectory(unittests)\n"
  },
  {
    "path": "tests/csrc/unittests/CMakeLists.txt",
    "content": "# Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# GoogleTest Preparation - Code block copied from\n#   https://google.github.io/googletest/quickstart-cmake.html\ninclude(FetchContent)\nFetchContent_Declare(\n  googletest\n  GIT_REPOSITORY https://github.com/google/googletest.git\n  GIT_TAG release-1.12.1\n)\n\nfind_package(CUDAToolkit REQUIRED)\n\nif (NOT MSVC)\n  add_definitions(-DTORCH_CUDA=1)\nendif()\n\n# For Windows: Prevent overriding the parent project's compiler/linker settings\nset(gtest_force_shared_crt ON CACHE BOOL \"\" FORCE)\nFetchContent_MakeAvailable(googletest)\n\nadd_executable(unittest\n    test_logprob_kernels.cu\n    test_penalty_kernels.cu\n    test_sampling_kernels.cu\n    test_sampling_layer.cu\n)\n\n# automatic discovery of unit tests\ntarget_link_libraries(unittest PUBLIC \"${TORCH_LIBRARIES}\" gtest_main)\ntarget_compile_features(unittest PRIVATE cxx_std_14)\n\n# Sorted by alphabetical order of test name.\ntarget_link_libraries(  # Libs for test_attention_kernels\n  unittest PUBLIC\n    CUDA::cudart CUDA::curand\n    gpt_kernels gtest memory_utils tensor unfused_attention_kernels cuda_utils logger)\ntarget_link_libraries(  # Libs for test_logprob_kernels\n  unittest PUBLIC\n    CUDA::cudart\n    logprob_kernels memory_utils cuda_utils logger)\ntarget_link_libraries(  # Libs for test_penalty_kernels\n  unittest PUBLIC\n    CUDA::cublas 
CUDA::cublasLt CUDA::cudart\n    sampling_penalty_kernels memory_utils cuda_utils logger)\ntarget_link_libraries(  # Libs for test_sampling_kernel\n  unittest PUBLIC\n    CUDA::cudart\n    sampling_topk_kernels sampling_topp_kernels memory_utils tensor cuda_utils logger)\ntarget_link_libraries(  # Libs for test_sampling_layer\n  unittest PUBLIC\n    CUDA::cublas CUDA::cublasLt CUDA::cudart\n    cublasMMWrapper memory_utils\n    DynamicDecodeLayer cuda_utils logger\n)\ntarget_link_libraries(  # Libs for test_tensor\n  unittest PUBLIC cuda_utils logger)\n"
  },
  {
    "path": "tests/csrc/unittests/gtest_utils.h",
    "content": "#include <algorithm>   // std::fill_n\n#include <assert.h>    // assert\n#include <functional>  // std::multiplies\n#include <iostream>\n#include <math.h>      // expf, log\n#include <numeric>     // std::accumulate\n#include <stdio.h>     // snprintf\n#include <stdlib.h>    // rand\n#include <string>      // std::string\n#include <vector>      // std::vector\n\n#include <cuda_runtime.h>\n#include <gtest/gtest.h>\n\n#include \"src/turbomind/utils/allocator.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n#include \"src/turbomind/utils/Tensor.h\"\n#include \"src/turbomind/utils/logger.h\"\n\nnamespace ft = turbomind;\n\nnamespace {\n\n#define EPSILON (1e-20)\n\nbool almostEqual(float a, float b, float atol = 1e-5, float rtol = 1e-8)\n{\n    // Params: a = value to compare and b = reference\n    // This function follows the implementation of numpy.isclose(), which checks\n    //   abs(a - b) <= (atol + rtol * abs(b)).\n    // Note that the inequality above is asymmetric: b is treated as the\n    // reference value. To account for both absolute and relative errors, it\n    // uses an absolute tolerance and a relative tolerance at the same time.\n    // Unlike numpy.isclose() (rtol=1e-5, atol=1e-8), the defaults here are\n    // atol=1e-5 and rtol=1e-8. 
If both a and b are NaN, or both are infinities of\n    // the same sign, the result is true.\n    if (isnan(a) && isnan(b)) {\n        return true;\n    }\n    if (isinf(a) && isinf(b) && ((a > 0 && b > 0) || (a < 0 && b < 0))) {\n        return true;\n    }\n    return fabs(a - b) <= (atol + rtol * fabs(b));\n}\n\ntemplate<typename T>\nbool checkResult(std::string name, T* out, T* ref, size_t size, float atol, float rtol)\n{\n    size_t failures     = 0;\n    float  relative_gap = 0.0f;\n\n    for (size_t i = 0; i < size; ++i) {\n        // The values for the output and the reference.\n        float a = (float)out[i];\n        float b = (float)ref[i];\n\n        bool ok = almostEqual(a, b, atol, rtol);\n        // Print details for the first few mismatches only.\n        if (!ok && failures < 4) {\n            TM_LOG_ERROR(\">> invalid result for i=%lu:\", i);\n            TM_LOG_ERROR(\">>    found......: %10.6f\", a);\n            TM_LOG_ERROR(\">>    expected...: %10.6f\", b);\n            TM_LOG_ERROR(\">>    error......: %.6f\", fabsf(a - b));\n            TM_LOG_ERROR(\">>    tol........: %.6f\", atol + rtol * fabs(b));\n        }\n        // Update the number of failures.\n        failures += ok ? 0 : 1;\n        // Update the relative gap.\n        relative_gap += fabsf(a - b) / (fabsf(b) + EPSILON);\n    }\n\n    relative_gap /= size;\n\n    // Allow up to 1% of the elements to mismatch.\n    size_t tol_failures = (size_t)(0.01 * size);\n    if (failures > tol_failures) {\n        TM_LOG_ERROR(\"%s (failures: %.2f%% atol: %.2e rtol: %.2e rel_gap: %.2e%%)\",\n                     name.c_str(), 100. * failures / size, atol, rtol, 100. * relative_gap);\n    }\n    return failures <= tol_failures;\n}\n\ntemplate<typename T>\nbool checkResult(std::string name, T* out, T* ref, size_t size,\n                 bool device_out = true, bool device_ref = false)\n{\n    bool is_fp32 = sizeof(T) == 4;\n    float atol = is_fp32 ? 1e-4f : 1e-3f;\n    float rtol = is_fp32 ? 
1e-2f : 1e-1f;\n\n    T* h_out = nullptr;\n    if (device_out) {\n        h_out = new T[size];\n        cudaMemcpy(h_out, out, sizeof(T) * size, cudaMemcpyDeviceToHost);\n        out = h_out;\n    }\n    T* h_ref = nullptr;\n    if (device_ref) {\n        h_ref = new T[size];\n        cudaMemcpy(h_ref, ref, sizeof(T) * size, cudaMemcpyDeviceToHost);\n        ref = h_ref;\n    }\n    bool is_ok = checkResult(name, out, ref, size, atol, rtol);\n    if (h_out != nullptr) {\n        delete[] h_out;\n    }\n    if (h_ref != nullptr) {\n        delete[] h_ref;\n    }\n    return is_ok;\n}\n\n// Fill ptr with values drawn uniformly from [minval, maxval].\ntemplate<typename T>\nvoid initRandom(T* ptr, size_t size, float minval, float maxval)\n{\n    for (size_t i = 0; i < size; ++i) {\n        float val = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);\n        val *= (maxval - minval);\n        ptr[i] = static_cast<T>(minval + val);\n    }\n}\n\n// Fill ptr with integers drawn from [minval, maxval).\nvoid initRandomInt(int* ptr, size_t size, int minval, int maxval)\n{\n    assert(minval < maxval);\n    int mod = maxval - minval;\n    for (size_t i = 0; i < size; ++i) {\n        ptr[i] = minval + rand() % mod;\n    }\n}\n\n// Replicate the first row (n elements) of x into rows 1..m-1.\ntemplate<typename T>\nvoid tile(T* x, int m, int n)\n{\n    for (int i = 1; i < m; ++i) {\n        for (int j = 0; j < n; ++j) {\n            x[i * n + j] = x[j];\n        }\n    }\n}\n\n// Copy src (n elements) into rows 1..m-1 of dst. Row 0 of dst is left unchanged.\ntemplate<typename T>\nvoid tile(T* dst, T* src, int m, int n)\n{\n    for (int i = 1; i < m; ++i) {\n        for (int j = 0; j < n; ++j) {\n            dst[i * n + j] = src[j];\n        }\n    }\n}\n\n// Host-side arithmetic helpers that compute in fp32 and cast the result back to T.\nnamespace math {\ntemplate<typename T>\ninline T add(T a, T b)\n{\n    return static_cast<T>((float)a + (float)b);\n}\n\ntemplate<typename T>\ninline T mul(T a, T b)\n{\n    return static_cast<T>((float)a * (float)b);\n}\n\ntemplate<typename T>\ninline T fma(T a, T b, T c)\n{\n    return static_cast<T>((float)a * (float)b + (float)c);\n}\n}  // namespace math\n\n#ifdef ENABLE_FP32\n#ifdef ENABLE_BF16\ntypedef testing::Types<float, half, __nv_bfloat16> 
SamplingTypes;\n#else\ntypedef testing::Types<float, half> SamplingTypes;\n#endif\n#else\n#ifdef ENABLE_BF16\ntypedef testing::Types<half, __nv_bfloat16> SamplingTypes;\n#else\ntypedef testing::Types<half> SamplingTypes;\n#endif\n#endif\n\ntypedef testing::Types<float> FloatType;\ntypedef testing::Types<float, half> FloatAndHalfTypes;\n#ifndef ENABLE_BF16\ntypedef FloatAndHalfTypes SupportTypes;\n#else\ntypedef testing::Types<float, half, __nv_bfloat16> FloatHalfBf16Types;\ntypedef FloatHalfBf16Types SupportTypes;\n#endif\n\nclass FtTestBase: public testing::Test {\npublic:\n    void SetUp() override\n    {\n        int device = 0;\n        cudaGetDevice(&device);\n        cudaStreamCreate(&stream);\n        allocator = new ft::Allocator<ft::AllocatorType::CUDA>(device);\n        allocator->setStream(stream);\n    }\n\n    void TearDown() override\n    {\n        // CPU buffers allocated by createTensor() must be released at the end of a test.\n        // GPU buffers do not need to be freed here because they are managed by the\n        // allocator.\n        for (auto& buffer : allocated_cpu_buffers) {\n            free(buffer);\n        }\n        allocated_cpu_buffers.clear();\n        delete allocator;\n        cudaStreamDestroy(stream);\n    }\n\nprotected:\n    cudaStream_t                            stream;\n    ft::Allocator<ft::AllocatorType::CUDA>* allocator;\n    std::vector<void*>                      allocated_cpu_buffers;\n\n    // Utilities to easily handle tensor instances in test cases.\n\n    ft::Tensor createTensor(const ft::MemoryType mtype,\n                            const ft::DataType dtype,\n                            const std::vector<size_t> shape)\n    {\n        // Use a size_t initial value so the product is accumulated in size_t, not int.\n        size_t n_elmts  = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());\n        size_t buf_size = ft::Tensor::getTypeSize(dtype) * n_elmts;\n\n        void* data = nullptr;\n        if (mtype == ft::MEMORY_CPU || mtype == 
ft::MEMORY_CPU_PINNED) {\n            // Pinned requests are also served with pageable malloc() here.\n            data = malloc(buf_size);\n            allocated_cpu_buffers.push_back(data);\n        }\n        else {\n            data = allocator->malloc(buf_size);\n        }\n        return ft::Tensor(mtype, dtype, shape, data);\n    }\n\n    template<typename T>\n    ft::Tensor toHost(ft::Tensor& device_tensor)\n    {\n        if (device_tensor.data == nullptr) {\n            return ft::Tensor();\n        }\n        ft::Tensor host_tensor = createTensor(ft::MEMORY_CPU, device_tensor.type, device_tensor.shape);\n        ft::cudaAutoCpy(host_tensor.getPtr<T>(), device_tensor.getPtr<T>(), host_tensor.size(), stream);\n        cudaStreamSynchronize(stream);\n        return host_tensor;\n    }\n\n    template<typename T>\n    ft::Tensor toDevice(ft::Tensor& host_tensor)\n    {\n        if (host_tensor.data == nullptr) {\n            return ft::Tensor();\n        }\n        ft::Tensor device_tensor = createTensor(ft::MEMORY_GPU, host_tensor.type, host_tensor.shape);\n        ft::cudaAutoCpy(device_tensor.getPtr<T>(), host_tensor.getPtr<T>(), host_tensor.size(), stream);\n        return device_tensor;\n    }\n\n    void copyTensor(ft::Tensor& dst, ft::Tensor& src)\n    {\n        FT_CHECK_WITH_INFO(\n            src.sizeBytes() == dst.sizeBytes(),\n            ft::fmtstr(\"src and dst have different sizes (%ld != %ld)\", src.sizeBytes(), dst.sizeBytes()));\n        ft::cudaAutoCpy(dst.getPtr<char>(), src.getPtr<char>(), src.sizeBytes(), stream);\n        cudaStreamSynchronize(stream);\n    }\n};\n\n}  // namespace\n"
  },
  {
    "path": "tests/csrc/unittests/test_logprob_kernels.cu",
    "content": "#include <assert.h>\n#include <float.h>\n#include <math.h>\n#include <stdexcept>\n#include <tuple>\n#include <vector>\n#ifdef __linux__\n#include <sys/time.h>\n#endif\n#include \"src/turbomind/kernels/logprob_kernels.h\"\n#include \"src/turbomind/utils/allocator.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/logger.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\n#include \"gtest_utils.h\"\n\nusing namespace turbomind;\n\n////////////////////////////////////////////////////////////////////////////////////\n\nstruct LogProbKernelTestParam {\n    size_t max_input_length;\n    size_t batch_size;\n    size_t vocab_size;\n    size_t beam_width;\n\n    std::string toString()\n    {\n        return fmtstr(\"LogProbKernelTestParam[max_input_length=%ld, batch=%ld, vocab=%ld, beam_width=%ld]\",\n                      max_input_length,\n                      batch_size,\n                      vocab_size,\n                      beam_width);\n    }\n};\n\n/////////////////////////////////// Unittests //////////////////////////////////////////\ntemplate<typename T>\nclass LogProbKernelTest: public FtTestBase {\n\nprotected:\n    void computeCumLogProbs(float*       cum_log_probs,\n                            float*       log_probs,\n                            const T*     logits,\n                            const int*   input_ids,\n                            const int*   input_lengths,\n                            const size_t max_input_length,\n                            const size_t batch_size,\n                            const size_t vocab_size,\n                            const size_t vocab_size_padded)\n    {\n        for (size_t step = 0; step < max_input_length; ++step) {\n            for (size_t i = 0; i < batch_size; ++i) {\n                if ((int)step == 0) {\n                    if (log_probs != nullptr) {\n                        log_probs[i] = 0.0f;\n                    }\n                    
cum_log_probs[i] = 0.0f;\n                }\n                else if ((int)step < input_lengths[i]) {\n                    size_t   step_offset = (step - 1) * batch_size * vocab_size_padded;\n                    const T* vec         = logits + step_offset + i * vocab_size_padded;\n                    float    max_logits  = -FLT_MAX;\n                    for (size_t v = 0; v < vocab_size; ++v) {\n                        float val = static_cast<float>(vec[v]);\n                        if (val > max_logits) {\n                            max_logits = val;\n                        }\n                    }\n                    float sum = 0.0f;\n                    for (size_t v = 0; v < vocab_size; ++v) {\n                        sum += expf(static_cast<float>(vec[v]) - max_logits);\n                    }\n                    int   token_id = input_ids[step * batch_size + i];\n                    float log_prob = static_cast<float>(vec[token_id]) - max_logits - log(sum);\n                    if (log_probs != nullptr) {\n                        log_probs[step * batch_size + i] = log_prob;\n                    }\n                    cum_log_probs[i] += log_prob;\n                }\n            }\n        }\n    }\n\n    void computeCumLogProbsBatchFirst(float*       cum_log_probs,\n                                      float*       log_probs,\n                                      const T*     logits,\n                                      const int*   input_ids,\n                                      const int*   input_lengths,\n                                      const size_t max_input_length,\n                                      const size_t batch_size,\n                                      const size_t vocab_size,\n                                      const size_t vocab_size_padded)\n    {\n        for (size_t i = 0; i < batch_size; ++i) {\n            size_t batch_offset = i * max_input_length * vocab_size_padded;\n            for (size_t step = 0; step < 
max_input_length; ++step) {\n                if ((int)step == 0) {\n                    if (log_probs != nullptr) {\n                        log_probs[i * max_input_length] = 0.0f;\n                    }\n                    cum_log_probs[i] = 0.0f;\n                }\n                else if ((int)step < input_lengths[i]) {\n                    const T* vec        = logits + batch_offset + (step - 1) * vocab_size_padded;\n                    float    max_logits = -FLT_MAX;\n                    for (size_t v = 0; v < vocab_size; ++v) {\n                        float val = static_cast<float>(vec[v]);\n                        if (val > max_logits) {\n                            max_logits = val;\n                        }\n                    }\n                    float sum = 0.0f;\n                    for (size_t v = 0; v < vocab_size; ++v) {\n                        sum += expf(static_cast<float>(vec[v]) - max_logits);\n                    }\n                    int   token_id = input_ids[i * max_input_length + step];\n                    float log_prob = static_cast<float>(vec[token_id]) - max_logits - log(sum);\n                    if (log_probs != nullptr) {\n                        log_probs[i * max_input_length + step] = log_prob;\n                    }\n                    cum_log_probs[i] += log_prob;\n                }\n            }\n        }\n    }\n\npublic:\n    void runTest(LogProbKernelTestParam param)\n    {\n        size_t max_input_length = param.max_input_length;\n        size_t batchxbeam       = param.batch_size * param.beam_width;\n        size_t vocab_size       = param.vocab_size;\n        // Pad the vocabulary size to a multiple of 8, as GPT does.\n        size_t vocab_size_padded = static_cast<size_t>(ceil(vocab_size / 8.f) * 8);\n\n        // input values; logits are laid out with the padded stride (vocab_size_padded)\n        T*   h_logits        = new T[max_input_length * batchxbeam * vocab_size_padded];\n        int* h_input_ids     = new int[max_input_length * batchxbeam];\n        int* h_input_lengths = new 
int[batchxbeam];\n\n        // output buffers\n        float* expected_cum_log_probs = new float[batchxbeam];\n\n        // initialize host buffers\n        initRandom(h_logits, max_input_length * batchxbeam * vocab_size_padded, -10.0f / vocab_size, -1.0f);\n        initRandomInt(h_input_ids, max_input_length * batchxbeam, 0, vocab_size);\n        initRandomInt(h_input_lengths, batchxbeam, 1, max_input_length + 1);\n        memset(expected_cum_log_probs, 0, sizeof(float) * batchxbeam);\n\n        // device buffers\n        T* d_logits =\n            reinterpret_cast<T*>(allocator->malloc(sizeof(T) * max_input_length * batchxbeam * vocab_size_padded));\n        int*   d_input_ids     = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_input_length * batchxbeam));\n        int*   d_input_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));\n        float* d_cum_log_probs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batchxbeam));\n\n        // initialize device buffers\n        cudaH2Dcpy(d_logits, h_logits, max_input_length * batchxbeam * vocab_size_padded);\n        cudaH2Dcpy(d_input_ids, h_input_ids, max_input_length * batchxbeam);\n        cudaH2Dcpy(d_input_lengths, h_input_lengths, batchxbeam);\n        deviceFill(d_cum_log_probs, batchxbeam, 0.0f);\n\n        size_t workspace_size = sizeof(float) * max_input_length * batchxbeam;\n        void*  workspace      = allocator->malloc(workspace_size);\n        invokeLogProbFromLogits(d_cum_log_probs,\n                                d_logits,\n                                d_input_ids,\n                                d_input_lengths,\n                                max_input_length,\n                                batchxbeam,\n                                vocab_size,\n                                vocab_size_padded,\n                                workspace,\n                                workspace_size,\n                                stream,\n                                
false);\n        computeCumLogProbs(expected_cum_log_probs,\n                           nullptr,\n                           h_logits,\n                           h_input_ids,\n                           h_input_lengths,\n                           max_input_length,\n                           batchxbeam,\n                           vocab_size,\n                           vocab_size_padded);\n        bool passed = checkResult(param.toString(), d_cum_log_probs, expected_cum_log_probs, batchxbeam);\n        EXPECT_TRUE(passed);\n\n        TM_LOG_DEBUG(\"free host buffers\");\n        delete[] expected_cum_log_probs;\n        delete[] h_input_lengths;\n        delete[] h_input_ids;\n        delete[] h_logits;\n    }\n\n    void runBatchFirstTest(LogProbKernelTestParam param)\n    {\n        size_t max_input_length = param.max_input_length;\n        size_t batchxbeam       = param.batch_size * param.beam_width;\n        size_t vocab_size       = param.vocab_size;\n        // Make multiple of 8 as GPT does.\n        size_t vocab_size_padded = static_cast<size_t>(ceil(vocab_size / 8.f) * 8);\n\n        // input values\n        T*   h_logits        = new T[max_input_length * batchxbeam * vocab_size_padded];\n        int* h_input_ids     = new int[max_input_length * batchxbeam];\n        int* h_input_lengths = new int[batchxbeam];\n\n        // output buffers\n        float* expected_cum_log_probs = new float[batchxbeam];\n\n        // initialize host buffers\n        initRandom(h_logits, max_input_length * batchxbeam * vocab_size_padded, -10.0f / vocab_size, -1.0f);\n        initRandomInt(h_input_ids, max_input_length * batchxbeam, 0, vocab_size);\n        initRandomInt(h_input_lengths, batchxbeam, 1, max_input_length + 1);\n        memset(expected_cum_log_probs, 0, sizeof(float) * batchxbeam);\n\n        // device buffers\n        T* d_logits =\n            reinterpret_cast<T*>(allocator->malloc(sizeof(T) * max_input_length * batchxbeam * vocab_size_padded));\n        
int*   d_input_ids     = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_input_length * batchxbeam));\n        int*   d_input_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));\n        float* d_cum_log_probs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batchxbeam));\n\n        // initialize device buffers\n        cudaH2Dcpy(d_logits, h_logits, max_input_length * batchxbeam * vocab_size_padded);\n        cudaH2Dcpy(d_input_ids, h_input_ids, max_input_length * batchxbeam);\n        cudaH2Dcpy(d_input_lengths, h_input_lengths, batchxbeam);\n        check_cuda_error(cudaMemset(d_cum_log_probs, 0, sizeof(float) * batchxbeam));\n\n        size_t workspace_size = sizeof(float) * max_input_length * batchxbeam;\n        void*  workspace      = allocator->malloc(workspace_size);\n        invokeLogProbFromLogits(d_cum_log_probs,\n                                d_logits,\n                                d_input_ids,\n                                d_input_lengths,\n                                max_input_length,\n                                batchxbeam,\n                                vocab_size,\n                                vocab_size_padded,\n                                workspace,\n                                workspace_size,\n                                stream,\n                                true);\n\n        computeCumLogProbsBatchFirst(expected_cum_log_probs,\n                                     nullptr,\n                                     h_logits,\n                                     h_input_ids,\n                                     h_input_lengths,\n                                     max_input_length,\n                                     batchxbeam,\n                                     vocab_size,\n                                     vocab_size_padded);\n        std::string tag    = param.toString() + (std::is_same<T, float>::value ? 
\" (fp32)\" : \" (fp16)\");\n        bool        passed = checkResult(tag.c_str(), d_cum_log_probs, expected_cum_log_probs, batchxbeam);\n        EXPECT_TRUE(passed);\n\n        delete[] expected_cum_log_probs;\n        delete[] h_input_lengths;\n        delete[] h_input_ids;\n        delete[] h_logits;\n    }\n};\n\nTYPED_TEST_SUITE(LogProbKernelTest, FloatAndHalfTypes);\n\nTYPED_TEST(LogProbKernelTest, SingleStep)\n{\n    this->runTest({1, 32, 16, 1});\n}\n\nTYPED_TEST(LogProbKernelTest, AccumLongStep129)\n{\n    this->runTest({129, 8, 50211, 1});\n}\n\nTYPED_TEST(LogProbKernelTest, AccumLongStep1023)\n{\n    this->runTest({1023, 8, 5001, 1});\n}\n\nTYPED_TEST(LogProbKernelTest, AccumLongStep4096)\n{\n    this->runTest({4096, 8, 5001, 1});\n}\n\nTYPED_TEST(LogProbKernelTest, BatchFirstSingleStep)\n{\n    this->runBatchFirstTest({1, 32, 16, 1});\n}\n\nTYPED_TEST(LogProbKernelTest, BatchFirstAccumLongStep129)\n{\n    this->runBatchFirstTest({129, 8, 50211, 1});\n}\n\nTYPED_TEST(LogProbKernelTest, BatchFirstAccumLongStep1023)\n{\n    this->runBatchFirstTest({1023, 8, 5001, 1});\n}\n\nTYPED_TEST(LogProbKernelTest, BatchFirstAccumLongStep4096)\n{\n    this->runBatchFirstTest({4096, 8, 5001, 1});\n}\n"
  },
  {
    "path": "tests/csrc/unittests/test_penalty_kernels.cu",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#include <algorithm>  // std::min, std::max\n#include <iostream>   // snprintf\n#include <math.h>     // expf, log\n#include <stdexcept>\n#include <stdlib.h>  // rand\n#include <string>    // std::string\n#include <unordered_map>\n#include <vector>  // std::vector\n\n#include <cublasLt.h>\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n\n#include \"gtest_utils.h\"\n#include \"src/turbomind/kernels/penalty_types.h\"\n#include \"src/turbomind/kernels/sampling_penalty_kernels.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\nusing namespace turbomind;\n\nstruct TemperatureTestParam {\n    size_t batch_size;\n    size_t vocab_size;\n    float* temperatures;\n    size_t temperatures_size;\n\n    std::string toString()\n    {\n        return fmtstr(\"TemperatureTestParam[batch=%ld, vocab=%ld, temperatures=%s]\",\n                      batch_size,\n                      vocab_size,\n                      arr2str(temperatures, temperatures_size).c_str());\n    }\n};\n\nsize_t pad_vocab_size(size_t vocab_size, size_t pad = 8)\n{\n    return (vocab_size + pad - 1) / pad * pad;\n}\n\ntemplate<typename T>\nvoid applyRepetitonPenalty(T*           logits,\n                           const int*   output_ids,\n                           const int*   input_lengths,\n 
                          const float  repetition_penalty,\n                           const size_t step,\n                           const size_t max_input_length,\n                           const size_t batch_size,\n                           const size_t vocab_size,\n                           const size_t vocab_size_padded)\n{\n    bool* penalized = new bool[vocab_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        std::fill_n(penalized, vocab_size, false);\n        size_t offset = i * vocab_size_padded;\n        for (size_t t = 0; t < step; ++t) {\n            // Skip padding positions between this sequence's input length and max_input_length.\n            if (t >= (size_t)input_lengths[i] && t < max_input_length) {\n                continue;\n            }\n            int token_id = output_ids[i + t * batch_size];\n            if (!penalized[token_id]) {\n                float logit = static_cast<float>(logits[offset + token_id]);\n                logits[offset + token_id] =\n                    static_cast<T>(logit < 0.0f ? 
logit * repetition_penalty : logit / repetition_penalty);\n                penalized[token_id] = true;\n            }\n        }\n    }\n    delete[] penalized;\n}\n\ntemplate<typename T>\nvoid batchApplyRepetitonPenalty(T*           logits,\n                                const int*   output_ids,\n                                const int*   input_lengths,\n                                const float* repetition_penalties,\n                                const size_t step,\n                                const size_t max_input_length,\n                                const size_t batch_size,\n                                const size_t vocab_size,\n                                const size_t vocab_size_padded)\n{\n    bool* penalized = new bool[vocab_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        float repetition_penalty = repetition_penalties[i];\n        std::fill_n(penalized, vocab_size, false);\n        size_t offset = i * vocab_size_padded;\n        for (size_t t = 0; t < step; ++t) {\n            if (t >= (size_t)input_lengths[i] && t < max_input_length) {\n                continue;\n            }\n            int token_id = output_ids[i + t * batch_size];\n            if (!penalized[token_id]) {\n                float logit = static_cast<float>(logits[offset + token_id]);\n                logits[offset + token_id] =\n                    static_cast<T>(logit < 0.0f ? 
logit * repetition_penalty : logit / repetition_penalty);\n                penalized[token_id] = true;\n            }\n        }\n    }\n    delete[] penalized;\n}\n\ntemplate<typename T>\nvoid initLogitsAndBias(\n    T* logits, T* bias, const size_t batch_size, const size_t vocab_size, const size_t vocab_size_padded)\n{\n    initRandom(logits, batch_size * vocab_size_padded, -5.0f, 5.0f);\n    if (bias != nullptr) {\n        initRandom(bias, vocab_size, -5.0f, 5.0f);\n    }\n    bool is_half = std::is_same<T, half>::value;\n    for (size_t i = 0; i < batch_size; ++i) {\n        for (size_t j = 0; j < vocab_size_padded; ++j) {\n            if (j >= vocab_size) {\n                logits[i * vocab_size_padded + j] = static_cast<T>(is_half ? -65504.f : -FLT_MAX);\n                if (bias != nullptr && i == 0) {\n                    bias[j] = (T)0.0f;\n                }\n            }\n        }\n    }\n}\n\n/////////////////////////////////// Tests //////////////////////////////////////////\n\ntemplate<typename T>\nclass TemperaturePenaltyTest: public FtTestBase {\nprotected:\n    // Set up test\n    size_t batch_size_;\n    size_t vocab_size_;\n    size_t vocab_size_padded_;\n\n    T* h_logits_;\n    T* h_bias_;\n    T* d_logits_;\n    T* d_bias_;\n\n    float* d_temperatures_;\n\n    void subsetup(TemperatureTestParam param)\n    {\n        batch_size_        = param.batch_size;\n        vocab_size_        = param.vocab_size;\n        vocab_size_padded_ = pad_vocab_size(vocab_size_);\n\n        h_logits_ = new T[batch_size_ * vocab_size_padded_];\n        h_bias_   = new T[vocab_size_padded_];\n        initLogitsAndBias(h_logits_, h_bias_, batch_size_, vocab_size_, vocab_size_padded_);\n\n        d_logits_ = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size_ * vocab_size_padded_));\n        d_bias_   = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * vocab_size_padded_));\n        cudaAutoCpy(d_logits_, h_logits_, batch_size_ * vocab_size_padded_, 
stream);\n        cudaAutoCpy(d_bias_, h_bias_, vocab_size_padded_, stream);\n        if (param.temperatures_size > 1) {\n            ASSERT_EQ(param.temperatures_size, param.batch_size) << \"Invalid test configuration.\";\n            d_temperatures_ = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * param.temperatures_size));\n            cudaAutoCpy(d_temperatures_, param.temperatures, batch_size_, stream);\n        }\n    }\n\n    void subteardown()\n    {\n        delete[] h_logits_;\n        delete[] h_bias_;\n    }\n\n    void computeReference(T*           logits,\n                          const T*     bias,\n                          const float* temperatures,\n                          const size_t temperatures_size,\n                          const size_t batch_size,\n                          const size_t vocab_size,\n                          const size_t vocab_size_padded)\n    {\n        for (size_t i = 0; i < batch_size; ++i) {\n            float temperature = temperatures_size > 1 ? 
temperatures[i] : temperatures[0];\n            ASSERT_GT(temperature, 0.0f) << \"temperature should be positive but got \" << temperature;\n            for (size_t j = 0; j < vocab_size; ++j) {\n                size_t index = i * vocab_size_padded + j;\n                float  logit = static_cast<float>(logits[index]);\n                if (bias != nullptr) {\n                    logit += static_cast<float>(bias[j]);\n                }\n                logits[index] = static_cast<T>(logit / temperature);\n            }\n        }\n    }\n\npublic:\n    void runTest(TemperatureTestParam param)\n    {\n        subsetup(param);\n        // Do test\n        if (param.temperatures_size == 1) {\n            invokeApplyTemperaturePenalty(\n                d_logits_, d_bias_, param.temperatures[0], batch_size_, vocab_size_, vocab_size_padded_, stream);\n        }\n        else {\n            invokeBatchApplyTemperaturePenalty(\n                d_logits_, d_bias_, d_temperatures_, batch_size_, vocab_size_, vocab_size_padded_, stream);\n        }\n        computeReference(h_logits_,\n                         h_bias_,\n                         param.temperatures,\n                         param.temperatures_size,\n                         batch_size_,\n                         vocab_size_,\n                         vocab_size_padded_);\n        bool passed = checkResult(param.toString(), d_logits_, h_logits_, batch_size_ * vocab_size_padded_);\n        EXPECT_TRUE(passed);\n        subteardown();\n    }\n\n    void runConsistencyTest(TemperatureTestParam param)\n    {\n        // Set up test\n        ASSERT_EQ(param.temperatures_size, 1) << \"A consistency test assumes temperatures_size=1\";\n        subsetup(param);\n\n        // Run a single runtime value case.\n        invokeApplyTemperaturePenalty(\n            d_logits_, d_bias_, param.temperatures[0], batch_size_, vocab_size_, vocab_size_padded_, stream);\n\n        float  temperature    = param.temperatures[0];\n        
float* h_temperatures = new float[batch_size_];\n        for (size_t i = 0; i < batch_size_; ++i) {\n            h_temperatures[i] = temperature;\n        }\n        d_temperatures_ = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size_));\n        cudaAutoCpy(d_temperatures_, h_temperatures, batch_size_, stream);\n\n        T* d_logits_batch = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size_ * vocab_size_padded_));\n        T* d_bias_batch   = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * vocab_size_padded_));\n        cudaAutoCpy(d_logits_batch, h_logits_, batch_size_ * vocab_size_padded_, stream);\n        cudaAutoCpy(d_bias_batch, h_bias_, vocab_size_padded_, stream);\n\n        invokeBatchApplyTemperaturePenalty(\n            d_logits_batch, d_bias_batch, d_temperatures_, batch_size_, vocab_size_, vocab_size_padded_, stream);\n        bool passed =\n            checkResult(param.toString(), d_logits_, d_logits_batch, batch_size_ * vocab_size_padded_, true, true);\n        EXPECT_TRUE(passed);\n\n        // Tear down test\n        delete[] h_temperatures;\n        subteardown();\n    }\n};\n\n// The compiler does not see that some variables are used inside gtest macros,\n// so suppress the resulting warning 177 (variable declared but never referenced).\n#pragma nv_diag_suppress 177\n\nTYPED_TEST_SUITE(TemperaturePenaltyTest, testing::Types<__nv_bfloat16>);\n\nTYPED_TEST(TemperaturePenaltyTest, NoPenalty)\n{\n    float temperature = 1.0f;\n    this->runTest({6, 4, &temperature, 1});\n}\n\nTYPED_TEST(TemperaturePenaltyTest, LessThanOne)\n{\n    float temperature = 0.53f;\n    this->runTest({6, 4, &temperature, 1});\n}\n\nTYPED_TEST(TemperaturePenaltyTest, GreaterThanOne)\n{\n    float temperature = 2.01f;\n    this->runTest({6, 4, &temperature, 1});\n}\n\nTYPED_TEST(TemperaturePenaltyTest, LargeVocab)\n{\n    float temperature = 2.01f;\n    this->runTest({6, 50001, &temperature, 1});\n}\n\nTYPED_TEST(TemperaturePenaltyTest, BatchNoPenalty)\n{\n    size_t 
batch_size   = 6;\n    float* temperatures = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        temperatures[i] = 1.0f;\n    }\n    this->runTest({batch_size, 4, temperatures, batch_size});\n}\n\nTYPED_TEST(TemperaturePenaltyTest, BatchLessThanOne)\n{\n    size_t batch_size   = 6;\n    float* temperatures = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        temperatures[i] = 0.53f;\n    }\n    this->runTest({batch_size, 4, temperatures, batch_size});\n}\n\nTYPED_TEST(TemperaturePenaltyTest, BatchGreaterThaneOne)\n{\n    size_t batch_size   = 6;\n    float* temperatures = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        temperatures[i] = 2.01f;\n    }\n    this->runTest({batch_size, 4, temperatures, batch_size});\n}\n\nTYPED_TEST(TemperaturePenaltyTest, BatchMixed)\n{\n    size_t batch_size   = 6;\n    float* temperatures = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        temperatures[i] = i % 2 == 0 ? 
2.01f : 0.53f;\n    }\n    this->runTest({batch_size, 4, temperatures, batch_size});\n}\n\nTYPED_TEST(TemperaturePenaltyTest, Consistency)\n{\n    float temperature = 2.01f;\n    this->runConsistencyTest({6, 4, &temperature, 1});\n}\n\nstruct RepetitionPenaltyTestCase {\n    size_t                batch_size;\n    size_t                vocab_size;\n    size_t                max_input_length;\n    float*                repetition_penalties;\n    size_t                repetition_penalties_size;\n    RepetitionPenaltyType repetition_penalty_type;\n\n    std::string toString()\n    {\n        static const std::unordered_map<RepetitionPenaltyType, std::string> typestr_map{\n            {RepetitionPenaltyType::Additive, \"additive\"},\n            {RepetitionPenaltyType::Multiplicative, \"multiplicative\"},\n            {RepetitionPenaltyType::None, \"none\"}};\n        return fmtstr(\"RepetitionPenaltyTestCase[batch=%ld, vocab=%ld, max_input_length=%ld, \"\n                      \"repetition_penalties=%s, repetition_penalty_type=%s]\",\n                      batch_size,\n                      vocab_size,\n                      max_input_length,\n                      arr2str(repetition_penalties, repetition_penalties_size).c_str(),\n                      typestr_map.at(repetition_penalty_type).c_str());\n    }\n};\n\ntemplate<typename T>\nclass RepetitionPenaltyTest: public FtTestBase {\nprotected:\n    // Set up test\n    size_t batch_size_;\n    size_t vocab_size_;\n    size_t vocab_size_padded_;\n    size_t max_input_length_;\n    size_t sequence_length_;\n    size_t step_;\n\n    T*   h_logits_;\n    T*   h_bias_;\n    int* h_output_ids_;\n    int* h_input_lengths_;\n\n    T*   d_logits_;\n    T*   d_bias_;\n    int* d_output_ids_;\n    int* d_input_lengths_;\n    int* d_penalty_workspace_;\n\n    float* d_repetition_penalties_;\n\n    void subsetup(RepetitionPenaltyTestCase param)\n    {\n        batch_size_        = param.batch_size;\n        vocab_size_        = 
param.vocab_size;\n        vocab_size_padded_ = pad_vocab_size(vocab_size_);\n        max_input_length_  = param.max_input_length;\n        sequence_length_   = 2 * max_input_length_;  // input + output\n        step_              = sequence_length_ * 0.7;\n\n        h_logits_        = new T[batch_size_ * vocab_size_padded_];\n        h_bias_          = new T[vocab_size_padded_];\n        h_output_ids_    = new int[sequence_length_ * batch_size_];\n        h_input_lengths_ = new int[batch_size_];\n        initLogitsAndBias(h_logits_, h_bias_, batch_size_, vocab_size_, vocab_size_padded_);\n        initRandomInt(h_output_ids_, sequence_length_ * batch_size_, 0, vocab_size_);\n        initRandomInt(h_input_lengths_, batch_size_, 1, max_input_length_);\n\n        d_logits_        = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size_ * vocab_size_padded_));\n        d_bias_          = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * vocab_size_padded_));\n        d_output_ids_    = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * sequence_length_ * batch_size_));\n        d_input_lengths_ = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size_));\n        d_penalty_workspace_ =\n            reinterpret_cast<int*>(allocator->malloc((sizeof(int) + sizeof(float)) * batch_size_ * step_));\n\n        cudaAutoCpy(d_logits_, h_logits_, batch_size_ * vocab_size_padded_, stream);\n        cudaAutoCpy(d_bias_, h_bias_, vocab_size_padded_, stream);\n        cudaAutoCpy(d_output_ids_, h_output_ids_, sequence_length_ * batch_size_, stream);\n        cudaAutoCpy(d_input_lengths_, h_input_lengths_, batch_size_, stream);\n        if (param.repetition_penalties_size > 1) {\n            ASSERT_EQ(param.repetition_penalties_size, param.batch_size) << \"Invalid test configuration.\";\n            d_repetition_penalties_ =\n                reinterpret_cast<float*>(allocator->malloc(sizeof(float) * param.repetition_penalties_size));\n            
cudaAutoCpy(d_repetition_penalties_, param.repetition_penalties, batch_size_, stream);\n        }\n    }\n\n    void subteardown()\n    {\n        delete[] h_logits_;\n        delete[] h_bias_;\n        delete[] h_output_ids_;\n        delete[] h_input_lengths_;\n    }\n\n    void computeReference(T*                          logits,\n                          const int*                  output_ids,\n                          const int*                  input_lengths,\n                          const float*                repetition_penalties,\n                          const size_t                repetition_penalties_size,\n                          const RepetitionPenaltyType repetition_penalty_type,\n                          const size_t                step,\n                          const size_t                max_input_length,\n                          const size_t                batch_size,\n                          const size_t                vocab_size,\n                          const size_t                vocab_size_padded)\n    {\n        bool* penalized = new bool[vocab_size];\n        for (size_t i = 0; i < batch_size; ++i) {\n            float repetition_penalty =\n                repetition_penalties_size > 1 ? 
repetition_penalties[i] : repetition_penalties[0];\n\n            std::fill_n(penalized, vocab_size, false);\n            size_t offset = i * vocab_size_padded;\n            for (size_t t = 0; t < step; ++t) {\n                if (t >= (size_t)input_lengths[i] && t < max_input_length) {\n                    continue;\n                }\n                int token_id = output_ids[i + t * batch_size];\n                if (!penalized[token_id]) {\n                    float logit = static_cast<float>(logits[offset + token_id]);\n                    switch (repetition_penalty_type) {\n                        case RepetitionPenaltyType::Additive:\n                            logits[offset + token_id] = static_cast<T>(logit - repetition_penalty);\n                            break;\n                        case RepetitionPenaltyType::Multiplicative:\n                            logits[offset + token_id] =\n                                static_cast<T>(logit < 0.0f ? logit * repetition_penalty : logit / repetition_penalty);\n                            break;\n                        case RepetitionPenaltyType::None:\n                            // None. 
do nothing.\n                            break;\n                        default:\n                            throw std::domain_error(\"Invalid repetition penalty type.\");\n                    }\n                    penalized[token_id] = true;\n                }\n            }\n        }\n        delete[] penalized;\n    }\n\npublic:\n    void runTest(RepetitionPenaltyTestCase param)\n    {\n        subsetup(param);\n        // Do test\n        if (param.repetition_penalties_size == 1) {\n            invokeApplyRepetitionPenalty(d_logits_,\n                                         param.repetition_penalties[0],\n                                         nullptr,\n                                         d_output_ids_,\n                                         batch_size_,\n                                         batch_size_,\n                                         vocab_size_,\n                                         vocab_size_padded_,\n                                         d_input_lengths_,\n                                         max_input_length_,\n                                         step_,\n                                         param.repetition_penalty_type,\n                                         stream);\n        }\n        else {\n            invokeBatchApplyRepetitionPenalty(d_logits_,\n                                              d_repetition_penalties_,\n                                              d_penalty_workspace_,\n                                              d_output_ids_,\n                                              batch_size_,\n                                              batch_size_,\n                                              vocab_size_padded_,\n                                              d_input_lengths_,\n                                              max_input_length_,\n                                              step_,\n                                              param.repetition_penalty_type,\n           
                                   stream);\n        }\n        computeReference(h_logits_,\n                         h_output_ids_,\n                         h_input_lengths_,\n                         param.repetition_penalties,\n                         param.repetition_penalties_size,\n                         param.repetition_penalty_type,\n                         step_,\n                         max_input_length_,\n                         batch_size_,\n                         vocab_size_,\n                         vocab_size_padded_);\n        bool passed = checkResult(param.toString(), d_logits_, h_logits_, batch_size_ * vocab_size_padded_);\n        EXPECT_TRUE(passed);\n        subteardown();\n    }\n\n    void runConsistencyTest(RepetitionPenaltyTestCase param)\n    {\n        // Set up test\n        ASSERT_EQ(param.repetition_penalties_size, 1) << \"A consistency test assumes repetition_penalties_size=1\";\n        subsetup(param);\n\n        // Run a single runtime value case.\n        invokeApplyRepetitionPenalty(d_logits_,\n                                     param.repetition_penalties[0],\n                                     nullptr,\n                                     d_output_ids_,\n                                     batch_size_,\n                                     batch_size_,\n                                     vocab_size_,\n                                     vocab_size_padded_,\n                                     d_input_lengths_,\n                                     max_input_length_,\n                                     step_,\n                                     param.repetition_penalty_type,\n                                     stream);\n\n        float* h_repetition_penalties = new float[batch_size_];\n        for (size_t i = 0; i < batch_size_; ++i) {\n            h_repetition_penalties[i] = param.repetition_penalties[0];\n        }\n        d_repetition_penalties_ = reinterpret_cast<float*>(allocator->malloc(sizeof(float) 
* batch_size_));\n        cudaAutoCpy(d_repetition_penalties_, h_repetition_penalties, batch_size_, stream);\n\n        T* d_logits_batch = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size_ * vocab_size_padded_));\n        cudaAutoCpy(d_logits_batch, h_logits_, batch_size_ * vocab_size_padded_, stream);\n        invokeBatchApplyRepetitionPenalty(d_logits_batch,\n                                          d_repetition_penalties_,\n                                          d_penalty_workspace_,\n                                          d_output_ids_,\n                                          batch_size_,\n                                          batch_size_,\n                                          vocab_size_padded_,\n                                          d_input_lengths_,\n                                          max_input_length_,\n                                          step_,\n                                          param.repetition_penalty_type,\n                                          stream);\n        bool passed =\n            checkResult(param.toString(), d_logits_, d_logits_batch, batch_size_ * vocab_size_padded_, true, true);\n        EXPECT_TRUE(passed);\n\n        // Tear down test\n        delete[] h_repetition_penalties;\n        subteardown();\n    }\n};\n\nTYPED_TEST_SUITE(RepetitionPenaltyTest, SamplingTypes);\n\nTYPED_TEST(RepetitionPenaltyTest, NoPenalty)\n{\n    float repetition_penalty = 1.0f;\n    this->runTest({6, 4, 5, &repetition_penalty, 1, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, LessThanOne)\n{\n    float repetition_penalty = 0.53f;\n    this->runTest({6, 4, 5, &repetition_penalty, 1, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, GreaterThaneOne)\n{\n    float repetition_penalty = 2.01f;\n    this->runTest({6, 4, 5, &repetition_penalty, 1, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, LargeVocab)\n{\n  
  float repetition_penalty = 2.01f;\n    this->runTest({6, 50001, 1003, &repetition_penalty, 1, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty)\n{\n    size_t batch_size           = 6;\n    float* repetition_penalties = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        repetition_penalties[i] = 1.0f;\n    }\n    this->runTest({batch_size, 4, 5, repetition_penalties, batch_size, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, BatchLessThanOne)\n{\n    size_t batch_size           = 6;\n    float* repetition_penalties = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        repetition_penalties[i] = 0.53f;\n    }\n    this->runTest({batch_size, 4, 5, repetition_penalties, batch_size, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, BatchGreaterThaneOne)\n{\n    size_t batch_size           = 6;\n    float* repetition_penalties = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        repetition_penalties[i] = 2.01f;\n    }\n    this->runTest({batch_size, 4, 5, repetition_penalties, batch_size, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, BatchMixed)\n{\n    size_t batch_size           = 6;\n    float* repetition_penalties = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        repetition_penalties[i] = i % 2 == 0 ? 
2.01f : 0.53f;\n    }\n    this->runTest({batch_size, 4, 5, repetition_penalties, batch_size, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, Consistency)\n{\n    float repetition_penalty = 2.01f;\n    this->runConsistencyTest({6, 4, 5, &repetition_penalty, 1, RepetitionPenaltyType::Multiplicative});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, PenaltyTypeAdditive)\n{\n    size_t batch_size           = 6;\n    float* repetition_penalties = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        repetition_penalties[i] = i % 2 == 0 ? 2.01f : 0.53f;\n    }\n    this->runTest({batch_size, 4, 5, repetition_penalties, batch_size, RepetitionPenaltyType::Additive});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, PenaltyTypeAdditiveHasDefaultValueZero)\n{\n    float repetition_penalty = 1.0f;\n    this->runTest({6, 4, 5, &repetition_penalty, 1, RepetitionPenaltyType::Additive});\n}\n\nTYPED_TEST(RepetitionPenaltyTest, PenaltyTypeAdditiveHasDefaultValueZero2)\n{\n    size_t batch_size           = 6;\n    float* repetition_penalties = new float[batch_size];\n    for (size_t i = 0; i < batch_size; ++i) {\n        repetition_penalties[i] = i % 2 == 0 ? 1.0f : 0.0f;\n    }\n    this->runTest({batch_size, 4, 5, repetition_penalties, batch_size, RepetitionPenaltyType::Additive});\n}\n\n// Re-enable warning 177 (variable declared but never referenced).\n#pragma nv_diag_default 177\n"
  },
  {
    "path": "tests/csrc/unittests/test_sampling_kernels.cu",
    "content": "#include <algorithm>  // std::fill_n\n#include <cstdio>     // printf\n#include <iostream>\n#include <math.h>     // expf, log\n#include <stdlib.h>   // rand\n#include <string>     // std::string\n#include <vector>     // std::vector\n\n#include <cublasLt.h>\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n#include <gtest/gtest.h>\n\n#include \"src/turbomind/kernels/sampling_kernels.h\"\n#include \"src/turbomind/kernels/sampling_topk_kernels.h\"\n#include \"src/turbomind/kernels/sampling_topp_kernels.h\"\n#include \"src/turbomind/layers/DynamicDecodeLayer.h\"\n#include \"src/turbomind/macro.h\"\n#include \"src/turbomind/utils/Tensor.h\"\n#include \"src/turbomind/utils/constant.h\"\n#include \"src/turbomind/utils/cublasMMWrapper.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\n#include \"gtest_utils.h\"\n\nusing namespace turbomind;\n\nnamespace {\n\n__global__ void get_curand_uniform(curandState_t* curandstate, float* output, int n)\n{\n    int   batch_id   = blockIdx.x;\n    float rand_num   = (float)curand_uniform(curandstate + batch_id);\n    output[batch_id] = rand_num;\n}\n\ntemplate<typename T>\nbool checkSorted(int  batch_size,\n                 T*   expected_logits,\n                 T*   output_logits,\n                 int* expected_indices,\n                 int* output_indices,\n                 int* expected_kept,\n                 int* output_kept,\n                 int  vocab_size)\n{\n    for (int i = 0; i < batch_size; i++) {\n        if (expected_kept[i] != output_kept[i]) {\n            printf(\"batch=%d, expected_kept[i]=%d, output_kept[i]=%d\\n\", i, expected_kept[i], output_kept[i]);\n            return false;\n        }\n\n        for (int j = 0; j < expected_kept[i]; j++) {\n            int index = i * vocab_size + j;\n            // soft check: fail only when both the value and the index differ (tolerates reordered ties)\n            if (std::abs((float)expected_logits[index] - (float)output_logits[index]) > 1e-6\n                && expected_indices[index] != 
output_indices[index]) {\n                printf(\"batch=%d, ith=%d, expected=(%d, %.5f), output=(%d, %.5f)\\n\",\n                       i,\n                       j,\n                       expected_indices[index],\n                       (float)expected_logits[index],\n                       output_indices[index],\n                       (float)output_logits[index]);\n                return false;\n            }\n        }\n    }\n    return true;\n}\n\ntemplate<typename T>\nbool checkSample(int* expected_output_ids,\n                 int* output_ids,\n                 int  batch_size,\n                 T*   expected_sampled_logprobs,\n                 int* expected_sampled_indices,\n                 int* expected_sampled_nums,\n                 T*   output_sampled_logprobs,\n                 int* output_sampled_indices,\n                 int* output_sampled_nums)\n{\n    for (int i = 0; i < batch_size; i++) {\n        if (expected_sampled_nums[i] != output_sampled_nums[i]) {\n            printf(\"batch=%d, sampled_nums, cpu=%d, gpu=%d\\n\", i, expected_sampled_nums[i], output_sampled_nums[i]);\n            return false;\n        }\n        if (expected_output_ids[i] != output_ids[i]) {\n            printf(\"batch=%d, expected_output_ids=%d, output_ids=%d\\n\", i, expected_output_ids[i], output_ids[i]);\n            return false;\n        }\n        for (int j = 0; j < expected_sampled_nums[i]; j++) {\n            int   offset  = i * kMaxLogProb + j;\n            float gpu_val = output_sampled_logprobs[offset];\n            float cpu_val = expected_sampled_logprobs[offset];\n            int   gpu_idx = output_sampled_indices[offset];\n            int   cpu_idx = expected_sampled_indices[offset];\n            if (std::abs(gpu_val - cpu_val) > 1e-5) {\n                if (gpu_idx != cpu_idx) {\n                    printf(\"%d %d\\n\", expected_output_ids[i], output_ids[i]);\n                    printf(\"batch=%d, ith=%d, idx cpu=%d, gpu=%d, val cpu=%.5f, 
gpu=%.5f\\n\",\n                           i,\n                           j,\n                           cpu_idx,\n                           gpu_idx,\n                           cpu_val,\n                           gpu_val);\n                    return false;\n                }\n            }\n        }\n    }\n    return true;\n}\n\ntemplate<typename T>\nvoid sampleCpu(int    batch_size,\n               int    vocab_size,\n               T*     logits,\n               int*   indices,\n               int*   kept,\n               float* uniforms,\n               int*   output_ids,\n               T*     sampled_logprobs,\n               int*   sampled_indices,\n               int*   sampled_nums)\n{\n\n    for (int i = 0; i < batch_size; i++) {\n        int   selected = -1;\n        float sum_val  = 0.f;\n        for (int j = 0; j < kept[i]; j++) {\n            sum_val += (float)logits[i * vocab_size + j];\n            if (sum_val > uniforms[i]) {\n                selected      = j;\n                output_ids[i] = indices[i * vocab_size + j];\n                break;\n            }\n        }\n\n        if (sampled_logprobs && sampled_indices && sampled_nums) {\n            for (int j = 0; j < min(kept[i], kMaxLogProb); ++j) {\n                sampled_logprobs[i * kMaxLogProb + j] = std::log((float)logits[i * vocab_size + j]);\n                sampled_indices[i * kMaxLogProb + j]  = indices[i * vocab_size + j];\n            }\n            if (kept[i] > kMaxLogProb && selected >= kMaxLogProb) {\n                sampled_logprobs[i * kMaxLogProb + kMaxLogProb - 1] =\n                    std::log((float)logits[i * vocab_size + selected]);\n                sampled_indices[i * kMaxLogProb + kMaxLogProb - 1] = indices[i * vocab_size + selected];\n            }\n            sampled_nums[i] = min(kept[i], kMaxLogProb);\n        }\n    }\n}\n\ntemplate<typename T>\nvoid softmax(T* input, int batch_size, int vocab_size, int* kept, T* output)\n{\n    for (int i = 0; i < 
batch_size; i++) {\n        int   offset  = i * vocab_size;\n        float max_val = input[offset];\n        for (int j = 0; j < kept[i]; j++) {\n            max_val = std::max((float)input[offset + j], max_val);\n        }\n        float sum_val{};\n        for (int j = 0; j < kept[i]; j++) {\n            output[offset + j] = std::exp((float)input[offset + j] - max_val);\n            sum_val += (float)output[offset + j];\n        }\n        for (int j = 0; j < kept[i]; j++) {\n            output[offset + j] = (float)output[offset + j] / sum_val;\n        }\n    }\n}\n\ntemplate<typename T>\nvoid filterCpu(int    batch_size,\n               int*   top_ks,\n               float* top_ps,\n               float* min_ps,\n               T*     logits,\n               T*     sorted_logits,\n               int*   sorted_indices,\n               int*   kept,\n               int    vocab_size,\n               bool   filter_topp = false,\n               bool   filter_minp = false)\n{\n    for (int i = 0; i < batch_size; i++) {\n        // fill\n        std::vector<std::pair<float, int>> work(vocab_size);\n        for (int j = 0; j < vocab_size; j++) {\n            work[j] = {logits[i * vocab_size + j], j};\n        }\n\n        // sort\n        if (top_ks && top_ks[i] != 0) {\n            std::partial_sort(work.begin(), work.begin() + top_ks[i], work.end(), std::greater{});\n            kept[i] = top_ks[i];\n        }\n        else {\n            std::sort(work.begin(), work.end(), std::greater{});\n            kept[i] = vocab_size;\n        }\n        for (int j = 0; j < kept[i]; j++) {\n            sorted_logits[i * vocab_size + j]  = work[j].first;\n            sorted_indices[i * vocab_size + j] = work[j].second;\n        }\n        // softmax\n        softmax(sorted_logits + i * vocab_size, 1, vocab_size, kept + i, sorted_logits + i * vocab_size);\n        if (top_ks && top_ks[i] == 0) {\n            if (top_ps && (float)sorted_logits[i * vocab_size] > top_ps[i]) {\n     
           sorted_logits[i * vocab_size] = 1.f;\n                kept[i]                       = 1;\n            }\n        }\n\n        // topp filter\n        if (filter_topp && top_ps[i] != 1.f) {\n            float topp    = top_ps[i];\n            float sum_val = 0;\n            int   n       = kept[i];\n            for (int j = 0; j < kept[i]; j++) {\n                sum_val += (float)sorted_logits[i * vocab_size + j];\n                if (sum_val > topp) {\n                    n = j + 1;\n                    break;\n                }\n            }\n            if (n != kept[i]) {\n                kept[i] = n;\n                for (int j = 0; j < n; j++) {\n                    sorted_logits[i * vocab_size + j] = (float)sorted_logits[i * vocab_size + j] / (sum_val + 1e-6f);\n                }\n            }\n        }\n\n        // minp filter\n        if (filter_minp && min_ps[i] != 0.f) {\n            float minp      = min_ps[i];\n            float threshold = (float)sorted_logits[i * vocab_size] * minp;\n            float sum_val   = 0;\n            int   n         = kept[i];\n            for (int j = 0; j < kept[i]; j++) {\n                if ((float)sorted_logits[i * vocab_size + j] < threshold) {\n                    n = j;\n                    break;\n                }\n                sum_val += (float)sorted_logits[i * vocab_size + j];\n            }\n            if (n != kept[i]) {\n                kept[i] = n;\n                for (int j = 0; j < n; j++) {\n                    sorted_logits[i * vocab_size + j] = (float)sorted_logits[i * vocab_size + j] / (sum_val + 1e-6f);\n                }\n            }\n        }\n    }\n}\n\ntemplate<typename T>\nclass SamplingKernelTest: public testing::Test {\npublic:\n    void SetUp() override\n    {\n        check_cuda_error(cudaStreamCreate(&stream));\n        allocator = new Allocator<AllocatorType::CUDA>(getDevice());\n        allocator->setStream(stream);\n    }\n    void TearDown() override\n    {\n   
     delete allocator;\n        check_cuda_error(cudaStreamDestroy(stream));\n    }\n\nprotected:\n    cudaStream_t                    stream;\n    Allocator<AllocatorType::CUDA>* allocator;\n    curandState_t*                  curand_states;\n};\n\ntemplate<typename T>\nclass TopKTopPSortTest: public SamplingKernelTest<T> {\nprotected:\n    using SamplingKernelTest<T>::stream;\n    using SamplingKernelTest<T>::allocator;\n\npublic:\n    void runTest(int batch_size, int* top_ks, float* top_ps, int vocab_size)\n    {\n\n        TopKSortFilterParams params1{};\n        params1.batch_size = batch_size;\n        int max_top_k      = *std::max_element(top_ks, top_ks + batch_size);\n        params1.max_top_k  = std::min(1024, std::max(0, max_top_k));\n        invokeTopKSortFilter<T>(params1, stream);\n\n        TopPSortParams params2{};\n        params2.batch_size        = batch_size;\n        params2.vocab_size        = vocab_size;\n        params2.vocab_size_padded = vocab_size;\n        invokeTopPSort<T>(params2, stream);\n\n        // host buffer\n        std::vector<T>   logits(batch_size * vocab_size);\n        std::vector<T>   expected_logits(batch_size * vocab_size);\n        std::vector<int> expected_indices(batch_size * vocab_size);\n        std::vector<int> expected_kept(batch_size);\n\n        std::vector<T>   output_logits(batch_size * vocab_size);\n        std::vector<int> output_indices(batch_size * vocab_size);\n        std::vector<int> output_kept(batch_size);\n\n        // device buffer\n        void*  d_ws_topk        = allocator->malloc(params1.workspace_size);\n        void*  d_ws_topp        = allocator->malloc(params2.workspace_size);\n        T*     d_logits         = (T*)allocator->malloc(sizeof(T) * batch_size * vocab_size);\n        T*     d_sorted_logits  = (T*)allocator->malloc(sizeof(T) * batch_size * vocab_size);\n        int*   d_sorted_indices = (int*)allocator->malloc(sizeof(int) * batch_size * vocab_size);\n        int*   d_kept         
  = (int*)allocator->malloc(sizeof(int) * batch_size);\n        int*   d_top_ks         = (int*)allocator->malloc(sizeof(int) * batch_size);\n        float* d_top_ps         = (float*)allocator->malloc(sizeof(float) * batch_size);\n\n        float boundary = 1.f;\n        for (int x = vocab_size; x >= 10; x /= 10) {\n            boundary *= 10;\n        }\n        initRandom(logits.data(), batch_size * vocab_size, -boundary, boundary);\n\n        std::fill_n(expected_kept.data(), batch_size, vocab_size);\n\n        cudaAutoCpy(d_logits, logits.data(), batch_size * vocab_size, stream);\n        cudaAutoCpy(d_top_ps, top_ps, batch_size, stream);\n        cudaAutoCpy(d_top_ks, top_ks, batch_size, stream);\n        cudaAutoCpy(d_kept, expected_kept.data(), batch_size, stream);\n\n        // gpu\n        params1.workspace         = d_ws_topk;\n        params1.logits            = d_logits;\n        params1.sorted_logits     = d_sorted_logits;\n        params1.sorted_indices    = d_sorted_indices;\n        params1.kept              = d_kept;\n        params1.top_ks            = d_top_ks;\n        params1.vocab_size        = vocab_size;\n        params1.vocab_size_padded = vocab_size;\n        invokeTopKSortFilter<T>(params1, stream);\n\n        invokeSoftmax<T>(d_logits, vocab_size, vocab_size, batch_size, d_kept, stream);\n        params2.workspace      = d_ws_topp;\n        params2.logits         = d_logits;\n        params2.sorted_logits  = d_sorted_logits;\n        params2.sorted_indices = d_sorted_indices;\n        params2.kept           = d_kept;\n        params2.top_ks         = d_top_ks;\n        params2.top_ps         = d_top_ps;\n        invokeTopPSort<T>(params2, stream);\n\n        // outputs\n        cudaAutoCpy(output_logits.data(), d_sorted_logits, batch_size * vocab_size);\n        cudaAutoCpy(output_indices.data(), d_sorted_indices, batch_size * vocab_size);\n        cudaAutoCpy(output_kept.data(), d_kept, batch_size, stream);\n        
cudaStreamSynchronize(stream);\n\n        // cpu\n        filterCpu(batch_size,\n                  top_ks,\n                  top_ps,\n                  nullptr,\n                  logits.data(),\n                  expected_logits.data(),\n                  expected_indices.data(),\n                  expected_kept.data(),\n                  vocab_size);\n\n        EXPECT_TRUE(checkSorted(batch_size,\n                                expected_logits.data(),\n                                output_logits.data(),\n                                expected_indices.data(),\n                                output_indices.data(),\n                                expected_kept.data(),\n                                output_kept.data(),\n                                vocab_size));\n    }\n};\n\nTYPED_TEST_SUITE(TopKTopPSortTest, SamplingTypes);\n\nTYPED_TEST(TopKTopPSortTest, OnlyTopKBatch)\n{\n    int   top_ks[] = {1, 2, 3, 4, 5, 6, 7, 8};\n    float top_ps[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};\n    this->runTest(sizeof(top_ks) / sizeof(int), top_ks, top_ps, 20);\n};\n\nTYPED_TEST(TopKTopPSortTest, OnlyTopKLargeVocab)\n{\n    int   top_ks[] = {1, 2, 4, 8, 16, 32, 64, 1024};\n    float top_ps[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};\n    this->runTest(sizeof(top_ks) / sizeof(int), top_ks, top_ps, 32000);\n};\n\nTYPED_TEST(TopKTopPSortTest, OnlyTopPBatch)\n{\n    int   top_ks[] = {0, 0, 0, 0, 0, 0, 0, 0};\n    float top_ps[] = {0.0f, 0.1f, 0.3f, 0.4f, 0.5f, 0.7f, 0.9f, 1.0f};\n    this->runTest(sizeof(top_ks) / sizeof(int), top_ks, top_ps, 20);\n};\n\nTYPED_TEST(TopKTopPSortTest, OnlyTopPLargeVocab)\n{\n    int   top_ks[] = {0, 0, 0, 0, 0, 0, 0, 0};\n    float top_ps[] = {0.0f, 0.1f, 0.3f, 0.4f, 0.5f, 0.7f, 0.9f, 1.0f};\n    this->runTest(sizeof(top_ks) / sizeof(int), top_ks, top_ps, 32000);\n};\n\nTYPED_TEST(TopKTopPSortTest, MixedTopKTopP)\n{\n    int   top_ks[] = {1, 0, 16, 0, 32, 0, 64, 1024};\n    float top_ps[] = {0.0f, 0.1f, 0.0f, 0.4f, 0.5f, 0.7f, 0.9f, 
1.0f};\n    this->runTest(sizeof(top_ks) / sizeof(int), top_ks, top_ps, 32000);\n};\n\ntemplate<typename T>\nclass TopPMinPFilterTest: public SamplingKernelTest<T> {\nprotected:\n    using SamplingKernelTest<T>::stream;\n    using SamplingKernelTest<T>::allocator;\n\npublic:\n    void runTest(int batch_size, float* top_ps, float* min_ps, int vocab_size)\n    {\n\n        // host buffer\n        std::vector<T>   logits(batch_size * vocab_size);\n        std::vector<T>   expected_logits(batch_size * vocab_size);\n        std::vector<int> expected_indices(batch_size * vocab_size);\n        std::vector<int> expected_kept(batch_size);\n\n        std::vector<T>   output_logits(batch_size * vocab_size);\n        std::vector<int> output_indices(batch_size * vocab_size);\n        std::vector<int> output_kept(batch_size);\n\n        // device buffer\n        T*     d_sorted_logits  = (T*)allocator->malloc(sizeof(T) * batch_size * vocab_size);\n        int*   d_sorted_indices = (int*)allocator->malloc(sizeof(int) * batch_size * vocab_size);\n        int*   d_kept           = (int*)allocator->malloc(sizeof(int) * batch_size);\n        float* d_top_ps         = (float*)allocator->malloc(sizeof(float) * batch_size);\n        float* d_min_ps         = (float*)allocator->malloc(sizeof(float) * batch_size);\n\n        float boundary = 1.f;\n        for (int x = vocab_size; x >= 10; x /= 10) {\n            boundary *= 10;\n        }\n        initRandom(logits.data(), batch_size * vocab_size, -boundary, boundary);\n        std::fill_n(expected_kept.data(), batch_size, vocab_size);\n\n        filterCpu(batch_size,\n                  nullptr,\n                  top_ps,\n                  min_ps,\n                  logits.data(),\n                  expected_logits.data(),\n                  expected_indices.data(),\n                  expected_kept.data(),\n                  vocab_size);\n\n        cudaAutoCpy(d_sorted_logits, expected_logits.data(), batch_size * vocab_size);\n        
cudaAutoCpy(d_sorted_indices, expected_indices.data(), batch_size * vocab_size);\n        cudaAutoCpy(d_kept, expected_kept.data(), batch_size, stream);\n        cudaAutoCpy(d_top_ps, top_ps, batch_size, stream);\n        cudaAutoCpy(d_min_ps, min_ps, batch_size, stream);\n\n        TopPMinPFilterParams params{};\n        params.sorted_logits     = d_sorted_logits;\n        params.sorted_indices    = d_sorted_indices;\n        params.kept              = d_kept;\n        params.top_ps            = d_top_ps;\n        params.min_ps            = d_min_ps;\n        params.batch_size        = batch_size;\n        params.vocab_size        = vocab_size;\n        params.vocab_size_padded = vocab_size;\n        invokeTopPMinPFilter<T>(params, stream);\n        cudaStreamSynchronize(stream);\n\n        // outputs\n        cudaAutoCpy(output_logits.data(), d_sorted_logits, batch_size * vocab_size);\n        cudaAutoCpy(output_indices.data(), d_sorted_indices, batch_size * vocab_size);\n        cudaAutoCpy(output_kept.data(), d_kept, batch_size, stream);\n        cudaStreamSynchronize(stream);\n\n        // cpu\n        filterCpu(batch_size,\n                  nullptr,\n                  top_ps,\n                  min_ps,\n                  logits.data(),\n                  expected_logits.data(),\n                  expected_indices.data(),\n                  expected_kept.data(),\n                  vocab_size,\n                  true,\n                  true);\n\n        EXPECT_TRUE(checkSorted(batch_size,\n                                expected_logits.data(),\n                                output_logits.data(),\n                                expected_indices.data(),\n                                output_indices.data(),\n                                expected_kept.data(),\n                                output_kept.data(),\n                                vocab_size));\n    }\n};\n\nTYPED_TEST_SUITE(TopPMinPFilterTest, 
SamplingTypes);\n\nTYPED_TEST(TopPMinPFilterTest, OnlyTopP)\n{\n    float top_ps[] = {0.8f, 0.82f, 0.84f, 0.86f, 0.88f, 0.90f, 0.92f, 0.94f, 0.96f, 0.98f, 1.0f};\n    float min_ps[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};\n    this->runTest(sizeof(top_ps) / sizeof(float), top_ps, min_ps, 200);\n};\n\nTYPED_TEST(TopPMinPFilterTest, OnlyMinP)\n{\n    float min_ps[] = {0.0f, 0.002f, 0.004f, 0.006f, 0.008f, 0.01f, 0.012f, 0.014f, 0.016f, 0.018f, 0.02f};\n    float top_ps[] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};\n    this->runTest(sizeof(top_ps) / sizeof(float), top_ps, min_ps, 200);\n};\n\nTYPED_TEST(TopPMinPFilterTest, MixedTopPMinP)\n{\n    float min_ps[] = {0.0f, 0.002f, 0.004f, 0.006f, 0.008f, 0.01f, 0.012f, 0.014f, 0.016f, 0.018f, 0.02f};\n    float top_ps[] = {0.8f, 0.82f, 0.84f, 0.86f, 0.88f, 0.90f, 0.92f, 0.94f, 0.96f, 0.98f, 1.0f};\n    this->runTest(sizeof(top_ps) / sizeof(float), top_ps, min_ps, 200);\n};\n\ntemplate<typename T>\nclass SamplingTest: public SamplingKernelTest<T> {\nprotected:\n    using SamplingKernelTest<T>::stream;\n    using SamplingKernelTest<T>::allocator;\n\npublic:\n    void runTest(int batch_size, int vocab_size, int top_logprobs)\n    {\n\n        // host buffer\n        std::vector<T>     logits(batch_size * vocab_size);\n        std::vector<T>     expected_logits(batch_size * vocab_size);\n        std::vector<int>   expected_indices(batch_size * vocab_size);\n        std::vector<int>   expected_kept(batch_size);\n        std::vector<int>   expected_output_ids(batch_size);\n        std::vector<float> uniforms(batch_size);\n\n        std::vector<T>   sampled_logprobs(batch_size * kMaxLogProb);\n        std::vector<int> sampled_indexes(batch_size * kMaxLogProb);\n        std::vector<int> sampled_nums(batch_size);\n\n        // std::vector<T>     output_logits(batch_size * vocab_size);\n        // std::vector<int>   output_indices(batch_size * vocab_size);\n        // std::vector<int>   
output_kept(batch_size);\n        std::vector<int> output_ids(batch_size);\n        std::vector<T>   output_sampled_logprobs(batch_size * kMaxLogProb);\n        std::vector<int> output_sampled_indexes(batch_size * kMaxLogProb);\n        std::vector<int> output_sampled_nums(batch_size);\n\n        // device buffer\n        T*             d_sorted_logits    = (T*)allocator->malloc(sizeof(T) * batch_size * vocab_size);\n        int*           d_sorted_indices   = (int*)allocator->malloc(sizeof(int) * batch_size * vocab_size);\n        int*           d_kept             = (int*)allocator->malloc(sizeof(int) * batch_size);\n        float*         d_top_ps           = (float*)allocator->malloc(sizeof(float) * batch_size);\n        float*         d_min_ps           = (float*)allocator->malloc(sizeof(float) * batch_size);\n        float*         d_uniforms         = (float*)(allocator->malloc(sizeof(float) * batch_size));\n        int*           d_output_ids       = (int*)(allocator->malloc(sizeof(int) * batch_size));\n        T*             d_sampled_logprobs = (T*)(allocator->malloc(sizeof(T) * batch_size * kMaxLogProb));\n        int*           d_sampled_indexes  = (int*)(allocator->malloc(sizeof(int) * batch_size * kMaxLogProb));\n        int*           d_sampled_nums     = (int*)(allocator->malloc(sizeof(int) * batch_size));\n        curandState_t* curand_states =\n            reinterpret_cast<curandState_t*>(allocator->malloc(sizeof(curandState_t) * batch_size, false));\n\n        float boundary = 1.f;\n        for (int x = vocab_size; x >= 10; x /= 10) {\n            boundary *= 10;\n        }\n        initRandom(logits.data(), batch_size * vocab_size, -boundary, boundary);\n        std::fill_n(expected_kept.data(), batch_size, vocab_size);\n\n        // sort & softmax\n        filterCpu(batch_size,\n                  nullptr,\n                  nullptr,\n                  nullptr,\n                  logits.data(),\n                  expected_logits.data(),\n         
         expected_indices.data(),\n                  expected_kept.data(),\n                  vocab_size);\n\n        cudaAutoCpy(d_sorted_logits, expected_logits.data(), batch_size * vocab_size);\n        cudaAutoCpy(d_sorted_indices, expected_indices.data(), batch_size * vocab_size);\n        cudaAutoCpy(d_kept, expected_kept.data(), batch_size, stream);\n\n        // uniforms\n        for (int i = 0; i < batch_size; i++) {\n            invokeCurandInitialize(curand_states + i, 1, i, stream);\n        }\n        get_curand_uniform<<<batch_size, 1, 0, stream>>>(curand_states, d_uniforms, batch_size);\n        cudaAutoCpy(uniforms.data(), d_uniforms, batch_size, stream);\n        for (int i = 0; i < batch_size; i++) {\n            invokeCurandInitialize(curand_states + i, 1, i, stream);\n        }\n\n        // sample\n        SamplingParams params{};\n        params.logits           = d_sorted_logits;\n        params.stride           = vocab_size;\n        params.indices          = d_sorted_indices;\n        params.kept             = d_kept;\n        params.curandstate      = curand_states;\n        params.batch_size       = batch_size;\n        params.output_ids       = d_output_ids;\n        params.sequence_length  = nullptr;\n        params.sampled_logprobs = d_sampled_logprobs;\n        params.sampled_indexes  = (uint32_t*)d_sampled_indexes;\n        params.sampled_nums     = (uint32_t*)d_sampled_nums;\n        invokeSampling<T>(params, stream);\n\n        // outputs\n        cudaAutoCpy(output_ids.data(), d_output_ids, batch_size, stream);\n        cudaAutoCpy(output_sampled_logprobs.data(), d_sampled_logprobs, batch_size * kMaxLogProb, stream);\n        cudaAutoCpy(output_sampled_indexes.data(), d_sampled_indexes, batch_size * kMaxLogProb, stream);\n        cudaAutoCpy(output_sampled_nums.data(), d_sampled_nums, batch_size, stream);\n        cudaStreamSynchronize(stream);\n\n        sampleCpu(batch_size,\n                  vocab_size,\n                  
expected_logits.data(),\n                  expected_indices.data(),\n                  expected_kept.data(),\n                  uniforms.data(),\n                  expected_output_ids.data(),\n                  sampled_logprobs.data(),\n                  sampled_indexes.data(),\n                  sampled_nums.data());\n\n        EXPECT_TRUE(checkSample(expected_output_ids.data(),\n                                output_ids.data(),\n                                batch_size,\n                                sampled_logprobs.data(),\n                                sampled_indexes.data(),\n                                sampled_nums.data(),\n                                output_sampled_logprobs.data(),\n                                output_sampled_indexes.data(),\n                                output_sampled_nums.data()));\n    }\n};\n\nTYPED_TEST_SUITE(SamplingTest, SamplingTypes);\n\nTYPED_TEST(SamplingTest, Single)\n{\n    this->runTest(1, 20, 5);\n};\n\nTYPED_TEST(SamplingTest, Batch)\n{\n    this->runTest(32, 9700, 1024);\n};\n\n}  // end of namespace\n"
  },
  {
    "path": "tests/csrc/unittests/test_sampling_layer.cu",
    "content": "#include <algorithm>  // std::min, std::max\n#include <iostream>   // snprintf\n#include <math.h>     // expf, log\n#include <stdlib.h>   // rand\n#include <string>     // std::string\n#include <vector>     // std::vector\n\n#include <cublasLt.h>\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n\n#include \"src/turbomind/kernels/sampling_topk_kernels.h\"\n#include \"src/turbomind/layers/DynamicDecodeLayer.h\"\n#include \"src/turbomind/macro.h\"\n#include \"src/turbomind/utils/Tensor.h\"\n#include \"src/turbomind/utils/cublasMMWrapper.h\"\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n\n#include \"gtest_utils.h\"\n\nusing namespace turbomind;\n\nstruct SamplingLayerTestParam {\n    size_t batch_size;\n    size_t vocab_size;\n    size_t beam_width;\n    size_t top_k;\n    float  top_p;\n    size_t output_len;\n\n    std::string toString()\n    {\n        return fmtstr(\"SamplingLayerTestParam[batch=%zu, vocab=%zu, beam=%zu, k=%zu, p=%3.1f, output_len=%zu]\",\n                      batch_size,\n                      vocab_size,\n                      beam_width,\n                      top_k,\n                      top_p,\n                      output_len);\n    }\n};\n\ntemplate<typename T>\nvoid computeProb(T* probs, T* logits, int batch_size, int vocab_size)\n{\n    // Compute the probability from logits.\n    //   logits = batch_size x vocab_size vector.\n    //   probs  = softmax(logits) (softmax along the vocab dimension)\n    for (int bidx = 0; bidx < batch_size; ++bidx) {\n        float sum = 0.0f;\n        for (int i = 0; i < vocab_size; ++i) {\n            sum += expf((float)logits[bidx * vocab_size + i]);\n        }\n        for (int i = 0; i < vocab_size; ++i) {\n            int idx    = bidx * vocab_size + i;\n            probs[idx] = static_cast<T>(expf((float)logits[idx]) / (sum + EPSILON));\n        }\n    }\n}\n\ntemplate<typename T>\nvoid computeLogProb(T* logprobs, T* 
logits, int batch_size, int vocab_size)\n{\n    // Compute the log probability from logits.\n    //   logits = batch_size x vocab_size vector.\n    //   logprobs = log(softmax(logits)) (softmax along the vocab dimension)\n    for (int bidx = 0; bidx < batch_size; ++bidx) {\n        float sum = 0.0f;\n        for (int i = 0; i < vocab_size; ++i) {\n            sum += expf((float)logits[bidx * vocab_size + i]);\n        }\n        for (int i = 0; i < vocab_size; ++i) {\n            int idx       = bidx * vocab_size + i;\n            logprobs[idx] = static_cast<T>(logf(expf((float)logits[idx]) / (sum + EPSILON) + EPSILON));\n        }\n    }\n}\n\ntemplate<typename T>\nclass SamplingDecodeTest: public testing::Test {\nprotected:\n    unsigned long long              seed           = 0;\n    const static unsigned long long max_seed       = 30;\n    const size_t                    batch_size     = 6;\n    const size_t                    beam_width     = 1;\n    const size_t                    batchxbeam     = batch_size * beam_width;\n    const size_t                    vocab_size     = 8;\n    const size_t                    max_input_len  = 0;  // has no effect.\n    const size_t                    max_output_len = 3;\n    const size_t                    max_seq_len    = max_input_len + max_output_len;\n    const int                       end_id         = vocab_size - 1;\n    const DataType                  data_type      = getTensorType<T>();\n\n    // vocab size 8 & length 3\n    T* test_input_logits;\n\n    cudaStream_t                            stream;\n    ft::Allocator<ft::AllocatorType::CUDA>* allocator;\n    cublasHandle_t                          cublas_handle;\n    cublasLtHandle_t                        cublaslt_handle;\n    std::mutex*                             cublas_wrapper_mutex;\n    cublasMMWrapper*                        cublas_wrapper;\n    DynamicDecodeLayer<T>*                  dynamic_decode_layer;\n\n    int*   h_output_ids;\n    T*     h_logits;\n    T*   
  h_probs;\n    T*     h_log_probs;\n    float* h_cum_log_probs;\n    float* h_output_log_probs;\n\n    T*                  d_logits;\n    int*                d_input_lengths;\n    float*              d_cum_log_probs;\n    float*              d_output_log_probs;\n    int*                d_output_ids;\n    int*                d_end_ids;\n    curandState_t*      d_curand_state;\n    unsigned long long* d_random_seed;\n\n    void setup(unsigned long long seed = 0)\n    {\n        this->seed = seed;\n\n        check_cuda_error(cudaStreamCreate(&stream));\n        allocator = new Allocator<AllocatorType::CUDA>(getDevice());\n        allocator->setStream(stream);\n\n        struct cudaDeviceProp prop;\n        check_cuda_error(cudaGetDeviceProperties(&prop, 0));\n        check_cuda_error(cublasCreate(&cublas_handle));\n        check_cuda_error(cublasLtCreate(&cublaslt_handle));\n        check_cuda_error(cublasSetStream(cublas_handle, stream));\n        cublasAlgoMap cublas_algo_map(GEMM_CONFIG);\n        cublas_wrapper_mutex = new std::mutex();\n\n        cublas_wrapper = new cublasMMWrapper(\n            cublas_handle, cublaslt_handle, stream, &cublas_algo_map, cublas_wrapper_mutex, allocator);\n\n        dynamic_decode_layer = new DynamicDecodeLayer<T>(vocab_size,\n                                                         vocab_size,\n                                                         stream,\n                                                         cublas_wrapper,\n                                                         allocator,\n                                                         false,   // is_free_buffer_after_forward\n                                                         &prop);  // cuda_device_prop\n\n        h_output_ids       = new int[batchxbeam];\n        h_logits           = new T[batchxbeam * vocab_size];\n        h_probs            = new T[batchxbeam * vocab_size];\n        h_log_probs        = new T[batchxbeam * vocab_size];\n        
h_cum_log_probs    = new float[batchxbeam];\n        h_output_log_probs = new float[max_output_len * batchxbeam];\n\n        // prob = (0.4, 0.3, 0.2, 0.1, ...)\n        test_input_logits = new T[24]{\n            -0.9163,  -1.2040,  -1.6094,  -2.3026,  -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX,  // step 0\n            -FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX, -0.9163,  -1.2040,  -1.6094,  -2.3026,   // step 1\n            -FLT_MAX, -FLT_MAX, -0.9163,  -1.2040,  -1.6094,  -2.3026,  -FLT_MAX, -FLT_MAX   // step 2\n        };\n\n        d_logits           = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batchxbeam * vocab_size, true));\n        d_input_lengths    = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));\n        d_cum_log_probs    = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batchxbeam));\n        d_output_log_probs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * max_output_len * batchxbeam));\n        d_output_ids       = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_seq_len * batchxbeam));\n        d_end_ids          = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));\n        d_curand_state     = reinterpret_cast<curandState_t*>(allocator->malloc(sizeof(curandState_t) * batch_size));\n        d_random_seed =\n            reinterpret_cast<unsigned long long*>(allocator->malloc(sizeof(unsigned long long) * batch_size));\n\n        // Init by zero.\n        cudaMemset(d_cum_log_probs, 0, sizeof(float) * batchxbeam);\n        cudaMemset(d_output_log_probs, 0, sizeof(float) * max_output_len * batchxbeam);\n        cudaMemset(d_output_ids, 0, sizeof(int) * max_seq_len * batchxbeam);\n        cudaMemset(d_random_seed, 0, sizeof(unsigned long long) * batch_size);\n        invokeCurandBatchInitialize(d_curand_state, batch_size, d_random_seed, stream);\n        deviceFill(d_end_ids, batchxbeam, end_id, stream);\n    }\n\n    void teardown()\n    {\n        delete[] test_input_logits;\n 
       delete[] h_output_ids;\n        delete[] h_logits;\n        delete[] h_probs;\n        delete[] h_log_probs;\n        delete[] h_cum_log_probs;\n        delete[] h_output_log_probs;\n        delete dynamic_decode_layer;\n        delete cublas_wrapper;\n        delete cublas_wrapper_mutex;\n        delete allocator;\n        check_cuda_error(cublasDestroy(cublas_handle));\n        check_cuda_error(cublasLtDestroy(cublaslt_handle));\n        check_cuda_error(cudaStreamDestroy(stream));\n    }\n\n    TensorMap* createInputTensors(\n        int* topk, size_t topk_size, float* topp, size_t topp_size, float* temperature, float* repetition_penalty)\n    {\n        // construct common input tensors\n        TensorMap* input_tensors = new TensorMap();\n        if (topk != nullptr) {\n            input_tensors->insert({\"runtime_top_k\", {MEMORY_CPU, TYPE_INT32, {topk_size}, topk}});\n        }\n        if (topp != nullptr) {\n            input_tensors->insert({\"runtime_top_p\", {MEMORY_CPU, TYPE_FP32, {topp_size}, topp}});\n        }\n        if (temperature != nullptr) {\n            input_tensors->insert({\"temperature\", Tensor{MEMORY_CPU, TYPE_FP32, {1}, temperature}});\n        }\n        if (repetition_penalty != nullptr) {\n            input_tensors->insert({\"repetition_penalty\", Tensor{MEMORY_CPU, TYPE_FP32, {1}, repetition_penalty}});\n        }\n        input_tensors->insert(\n            {\"logits\", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width, vocab_size}, d_logits}});\n        input_tensors->insert({\"embedding_bias\", Tensor{MEMORY_GPU, data_type, {vocab_size}, nullptr}});\n        input_tensors->insert({\"max_input_length\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_len}});\n        input_tensors->insert(\n            {\"input_lengths\", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, d_input_lengths}});\n        input_tensors->insert({\"end_id\", Tensor{MEMORY_GPU, TYPE_INT32, {batchxbeam}, d_end_ids}});\n        
input_tensors->insert({\"random_seed\", Tensor{MEMORY_CPU, TYPE_UINT64, {1}, &seed}});\n        return input_tensors;\n    }\n\n    TensorMap* createOutputTensors()\n    {\n        // construct common output tensors\n        TensorMap* output_tensors = new TensorMap();\n        output_tensors->insert(\n            {\"output_ids\", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, d_output_ids}});\n        output_tensors->insert({\"finished\", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, nullptr}});\n        output_tensors->insert(\n            {\"cum_log_probs\", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size * beam_width}, d_cum_log_probs}});\n        output_tensors->insert(\n            {\"output_log_probs\",\n             Tensor{MEMORY_GPU, TYPE_FP32, {max_seq_len, batch_size, beam_width}, d_output_log_probs}});\n        output_tensors->insert({\"sequence_length\", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, nullptr}});\n        output_tensors->insert({\"curand_state\"}, {MEMORY_GPU, TYPE_VOID, {batch_size}, d_curand_state});\n        return output_tensors;\n    }\n\n    void batchH2Dcpy(T* dst, T* src, size_t m, size_t n)\n    {\n        for (size_t i = 0; i < m; ++i) {\n            cudaH2Dcpy(dst + i * n, src, n);\n        }\n    }\n\n    bool checkResult(int* d_output_ids, std::vector<std::set<int>>& expected_ids)\n    {\n        assert(expected_ids.size() == max_seq_len * batchxbeam);\n        int* h_output_ids = new int[max_seq_len * batchxbeam];\n        cudaD2Hcpy(h_output_ids, d_output_ids, max_seq_len * batchxbeam);\n        int failures = 0;\n        for (size_t i = 0; i < max_seq_len * batchxbeam; ++i) {\n            size_t        s     = i / batchxbeam;\n            size_t        b     = i % batchxbeam;\n            std::set<int> expts = expected_ids.at(i);\n            if (expts.count(h_output_ids[i]) == 0) {\n                if (failures < 10) {\n                    std::stringstream ss;\n                    
ss << \" - Fail \"\n                       << \" (step=\" << s << \", batch=\" << b << \") \"\n                       << \"actual=\" << h_output_ids[i] << \", expected\";\n                    for (auto& expt : expts) {\n                        ss << \" \" << expt;\n                    }\n                    TM_LOG_DEBUG(\"%s\", ss.str().c_str());\n                }\n                ++failures;\n            }\n        }\n        TM_LOG_DEBUG(\"check...%6s : failures: %d / %zu\",\n                     failures == 0 ? \"....OK\" : \"FAILED\",\n                     failures,\n                     max_seq_len * batchxbeam);\n        delete[] h_output_ids;\n        return failures == 0;\n    }\n\npublic:\n    void runTest(std::vector<std::set<int>> expected_output_ids,\n                 int*                       top_ks,\n                 size_t                     top_k_size,\n                 float*                     top_ps,\n                 size_t                     top_p_size,\n                 float*                     temperature,\n                 float*                     repetition_penalty,\n                 bool                       use_local_batch = false)\n    {\n        size_t local_batch_size = use_local_batch ? batch_size / 3 : batch_size;\n        uint   ite              = use_local_batch ? 
1 : 0;\n        for (unsigned long long seed = 0; seed < max_seed; ++seed) {\n            this->setup(seed);\n            size_t     step = max_input_len;\n            TensorMap* input_tensors =\n                createInputTensors(top_ks, top_k_size, top_ps, top_p_size, temperature, repetition_penalty);\n            input_tensors->insert({\"step\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}});\n            input_tensors->insert({\"ite\", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}});\n            input_tensors->insert({\"local_batch_size\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}});\n            TensorMap* output_tensors = createOutputTensors();\n\n            dynamic_decode_layer->setup(batch_size, beam_width, input_tensors);\n            for (step = max_input_len; step < max_output_len; ++step) {\n                // Reset the logits to the test values since the sampling layer updates the logit buffer internally.\n                batchH2Dcpy(input_tensors->at(\"logits\").getPtr<T>(),\n                            test_input_logits + step * vocab_size,\n                            batchxbeam,\n                            vocab_size);\n                dynamic_decode_layer->forward(output_tensors, input_tensors);\n            }\n            bool passed = checkResult(d_output_ids, expected_output_ids);\n            EXPECT_TRUE(passed) << \"Failed at seed \" << seed;\n#ifndef NDEBUG\n            if (!passed) {\n                TM_LOG_ERROR(\"actual output ids\");\n                printMatrix(d_output_ids, max_seq_len, batch_size, batch_size, true);\n            }\n#endif\n            delete output_tensors;\n            delete input_tensors;\n            this->teardown();\n        }\n    }\n};\n\nTYPED_TEST_SUITE(SamplingDecodeTest, SamplingTypes);\n\nTYPED_TEST(SamplingDecodeTest, TopK)\n{\n    int                        top_k = 2;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        //  0       1       2       3       4       5\n        {0, 
1},\n        {0, 1},\n        {0, 1},\n        {0, 1},\n        {0, 1},\n        {0, 1},  // step 0\n        {4, 5},\n        {4, 5},\n        {4, 5},\n        {4, 5},\n        {4, 5},\n        {4, 5},  // step 1\n        {2, 3},\n        {2, 3},\n        {2, 3},\n        {2, 3},\n        {2, 3},\n        {2, 3}  // step 2\n    };\n    this->runTest(expected_output_ids, &top_k, 1, nullptr, 0, nullptr, nullptr);\n}\n\nTYPED_TEST(SamplingDecodeTest, BatchTopK)\n{\n    size_t                     batch_size = this->batch_size;\n    int*                       top_ks     = new int[batch_size]{2, 1, 1, 2, 1, 1};\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        //  0    1    2       3    4    5\n        {0, 1},\n        {0},\n        {0},\n        {0, 1},\n        {0},\n        {0},  // step 0\n        {4, 5},\n        {4},\n        {4},\n        {4, 5},\n        {4},\n        {4},  // step 1\n        {2, 3},\n        {2},\n        {2},\n        {2, 3},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, top_ks, batch_size, nullptr, 0, nullptr, nullptr);\n    delete[] top_ks;\n}\n\nTYPED_TEST(SamplingDecodeTest, TopP)\n{\n    float                      top_p = 0.3;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},  // step 0\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},  // step 1\n        {2},\n        {2},\n        {2},\n        {2},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, nullptr, 0, &top_p, 1, nullptr, nullptr);\n}\n\nTYPED_TEST(SamplingDecodeTest, BatchTopP)\n{\n    size_t                     batch_size = this->batch_size;\n    float*                     top_ps     = new float[batch_size]{0.3f, 0.5f, 0.5f, 0.3f, 0.5f, 0.5f};\n    std::vector<std::set<int>> expected_output_ids{\n        {0},\n        {0, 1},\n        
{0, 1},\n        {0},\n        {0, 1},\n        {0, 1},  // step 0\n        {4},\n        {4, 5},\n        {4, 5},\n        {4},\n        {4, 5},\n        {4, 5},  // step 1\n        {2},\n        {2, 3},\n        {2, 3},\n        {2},\n        {2, 3},\n        {2, 3}  // step 2\n    };\n    this->runTest(expected_output_ids, nullptr, 0, top_ps, batch_size, nullptr, nullptr);\n    delete[] top_ps;\n}\n\nTYPED_TEST(SamplingDecodeTest, TopKTopP)\n{\n    int                        top_k = 2;\n    float                      top_p = 0.3;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},  // step 0\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},  // step 1\n        {2},\n        {2},\n        {2},\n        {2},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, &top_k, 1, &top_p, 1, nullptr, nullptr);\n}\n\nTYPED_TEST(SamplingDecodeTest, BatchTopKTopP)\n{\n    size_t                     batch_size = this->batch_size;\n    int*                       top_ks     = new int[batch_size]{2, 2, 1, 2, 2, 1};\n    float                      top_p      = 0.3;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},  // step 0\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},  // step 1\n        {2},\n        {2},\n        {2},\n        {2},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, top_ks, batch_size, &top_p, 1, nullptr, nullptr);\n    delete[] top_ks;\n}\n\nTYPED_TEST(SamplingDecodeTest, TopKBatchTopP)\n{\n    size_t                     batch_size = this->batch_size;\n    int                        top_k      = 2;\n    float*                     top_ps     = new float[batch_size]{0.5, 0.3, 0.5, 0.5, 0.3, 0.5};\n    
std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0, 1},\n        {0},\n        {0, 1},\n        {0, 1},\n        {0},\n        {0, 1},  // step 0\n        {4, 5},\n        {4},\n        {4, 5},\n        {4, 5},\n        {4},\n        {4, 5},  // step 1\n        {2, 3},\n        {2},\n        {2, 3},\n        {2, 3},\n        {2},\n        {2, 3}  // step 2\n    };\n    this->runTest(expected_output_ids, &top_k, 1, top_ps, batch_size, nullptr, nullptr);\n    delete[] top_ps;\n}\n\nTYPED_TEST(SamplingDecodeTest, BatchTopKBatchTopP)\n{\n    size_t                     batch_size = this->batch_size;\n    int*                       top_ks     = new int[batch_size]{2, 2, 0, 2, 2, 0};\n    float*                     top_ps     = new float[batch_size]{0.0, 0.3, 0.5, 0.0, 0.3, 0.5};\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0, 1},\n        {0},\n        {0, 1},\n        {0, 1},\n        {0},\n        {0, 1},  // step 0\n        {4, 5},\n        {4},\n        {4, 5},\n        {4, 5},\n        {4},\n        {4, 5},  // step 1\n        {2, 3},\n        {2},\n        {2, 3},\n        {2, 3},\n        {2},\n        {2, 3}  // step 2\n    };\n    this->runTest(expected_output_ids, top_ks, batch_size, top_ps, batch_size, nullptr, nullptr);\n    delete[] top_ks;\n    delete[] top_ps;\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsZeroTopK)\n{\n    size_t                     batch_size = this->batch_size;\n    int                        top_k      = 0;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},  // step 0\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},  // step 1\n        {2},\n        {2},\n        {2},\n        {2},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, &top_k, 1, nullptr, 0, nullptr, 
nullptr);\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsZeroTopP)\n{\n    size_t                     batch_size = this->batch_size;\n    float                      top_p      = 0;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},  // step 0\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},  // step 1\n        {2},\n        {2},\n        {2},\n        {2},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, nullptr, 0, &top_p, 1, nullptr, nullptr);\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsZeroTopKTopP)\n{\n    size_t                     batch_size = this->batch_size;\n    int                        top_k      = 0;\n    float                      top_p      = 0;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},  // step 0\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},  // step 1\n        {2},\n        {2},\n        {2},\n        {2},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, &top_k, 1, &top_p, 1, nullptr, nullptr);\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsZeroBatchTopKTopP)\n{\n    size_t                     batch_size = this->batch_size;\n    int*                       top_ks     = new int[batch_size]{0, 0, 0, 0, 0, 0};\n    float                      top_p      = 0;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},  // step 0\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},  // step 1\n        {2},\n        {2},\n        {2},\n        {2},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, top_ks, batch_size, &top_p, 1, 
nullptr, nullptr);\n    delete[] top_ks;\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsZeroTopKBatchTopP)\n{\n    size_t                     batch_size = this->batch_size;\n    int                        top_k      = 0;\n    float*                     top_ps     = new float[batch_size]{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},\n        {0},  // step 0\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},\n        {4},  // step 1\n        {2},\n        {2},\n        {2},\n        {2},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, &top_k, 1, top_ps, batch_size, nullptr, nullptr);\n    delete[] top_ps;\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsBatchTopKContainZero)\n{\n    size_t                     batch_size = this->batch_size;\n    int*                       top_ks     = new int[batch_size]{2, 1, 0, 0, 2, 1};\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0, 1},\n        {0},\n        {0},\n        {0},\n        {0, 1},\n        {0},  // step 0\n        {4, 5},\n        {4},\n        {4},\n        {4},\n        {4, 5},\n        {4},  // step 1\n        {2, 3},\n        {2},\n        {2},\n        {2},\n        {2, 3},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, top_ks, batch_size, nullptr, 0, nullptr, nullptr);\n    delete[] top_ks;\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsBatchTopPContainZero)\n{\n    size_t                     batch_size = this->batch_size;\n    float*                     top_ps     = new float[batch_size]{0.5f, 0.5f, 0.0f, 0.5f, 0.0f, 0.3f};\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0, 1},\n        {0, 1},\n        {0},\n        {0, 1},\n        {0},\n        {0},  // step 0\n        {4, 5},\n        {4, 5},\n        {4},\n        {4, 5},\n    
    {4},\n        {4},  // step 1\n        {2, 3},\n        {2, 3},\n        {2},\n        {2, 3},\n        {2},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, nullptr, 0, top_ps, batch_size, nullptr, nullptr);\n    delete[] top_ps;\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsBatchTopKTopPContainZero)\n{\n    size_t                     batch_size = this->batch_size;\n    int*                       top_ks     = new int[batch_size]{2, 2, 1, 0, 2, 0};\n    float                      top_p      = 0.0;\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0, 1},\n        {0, 1},\n        {0},\n        {0},\n        {0, 1},\n        {0},  // step 0\n        {4, 5},\n        {4, 5},\n        {4},\n        {4},\n        {4, 5},\n        {4},  // step 1\n        {2, 3},\n        {2, 3},\n        {2},\n        {2},\n        {2, 3},\n        {2}  // step 2\n    };\n    this->runTest(expected_output_ids, top_ks, batch_size, &top_p, 1, nullptr, nullptr);\n    delete[] top_ks;\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsTopKBatchTopPContainZero)\n{\n    size_t                     batch_size = this->batch_size;\n    int                        top_k      = 0;\n    float*                     top_ps     = new float[batch_size]{0.0, 0.3, 0.5, 0.0, 0.3, 0.5};\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0, 1},\n        {0},\n        {0},\n        {0, 1},  // step 0\n        {4},\n        {4},\n        {4, 5},\n        {4},\n        {4},\n        {4, 5},  // step 1\n        {2},\n        {2},\n        {2, 3},\n        {2},\n        {2},\n        {2, 3}  // step 2\n    };\n    this->runTest(expected_output_ids, &top_k, 1, top_ps, batch_size, nullptr, nullptr);\n    delete[] top_ps;\n}\n\nTYPED_TEST(SamplingDecodeTest, InvalidArgsBatchTopKBatchTopPContainZero)\n{\n    size_t                     batch_size = this->batch_size;\n    int*                       top_ks    
 = new int[batch_size]{0, 2, 1, 2, 2, 0};\n    float*                     top_ps     = new float[batch_size]{0.0, 0.3, 0.9, 0.0, 0.3, 0.5};\n    std::vector<std::set<int>> expected_output_ids{\n        // batch\n        {0},\n        {0},\n        {0},\n        {0, 1},\n        {0},\n        {0, 1},  // step 0\n        {4},\n        {4},\n        {4},\n        {4, 5},\n        {4},\n        {4, 5},  // step 1\n        {2},\n        {2},\n        {2},\n        {2, 3},\n        {2},\n        {2, 3}  // step 2\n    };\n    this->runTest(expected_output_ids, top_ks, batch_size, top_ps, batch_size, nullptr, nullptr);\n    delete[] top_ks;\n    delete[] top_ps;\n}\n\ntemplate<typename T>\nclass SamplingDecodeTest2: public FtTestBase {\n\npublic:\n    void SetUp() override\n    {\n        FtTestBase::SetUp();\n        check_cuda_error(cudaGetDeviceProperties(&prop, 0));\n        check_cuda_error(cublasCreate(&cublas_handle));\n        check_cuda_error(cublasLtCreate(&cublaslt_handle));\n        check_cuda_error(cublasSetStream(cublas_handle, stream));\n        cublas_algo_map      = new cublasAlgoMap(\"\");\n        cublas_wrapper_mutex = new std::mutex();\n        cublas_wrapper       = new cublasMMWrapper(\n            cublas_handle, cublaslt_handle, stream, cublas_algo_map, cublas_wrapper_mutex, allocator);\n    }\n    void TearDown() override\n    {\n        delete cublas_wrapper;\n        delete cublas_wrapper_mutex;\n        delete cublas_algo_map;\n        check_cuda_error(cublasLtDestroy(cublaslt_handle));\n        check_cuda_error(cublasDestroy(cublas_handle));\n        FtTestBase::TearDown();\n    }\n\nprotected:\n    using FtTestBase::stream;\n    using FtTestBase::allocator;\n\n    struct cudaDeviceProp prop;\n    cublasHandle_t        cublas_handle;\n    cublasLtHandle_t      cublaslt_handle;\n    cublasAlgoMap*        cublas_algo_map;\n    std::mutex*           cublas_wrapper_mutex;\n    cublasMMWrapper*      cublas_wrapper;\n\n    DataType data_type = 
getTensorType<T>();\n\n    size_t batch_size;\n    size_t beam_width;\n    size_t batchxbeam;\n    size_t vocab_size;\n    size_t max_input_len;\n    size_t max_output_len;\n    size_t max_seq_len;\n\n    uint  top_k;\n    float top_p;\n    float temperature;\n    float repetition_penalty;\n    int   end_id;\n\n    T*     h_logits;\n    T*     h_probs;\n    T*     h_log_probs;\n    float* h_cum_log_probs;\n    float* h_output_log_probs;\n    int*   h_output_ids;\n\n    T*                  d_logits;\n    int*                d_input_lengths;\n    float*              d_cum_log_probs;\n    float*              d_output_log_probs;\n    int*                d_output_ids;\n    int*                d_end_ids;\n    curandState_t*      d_curand_state;\n    unsigned long long* d_random_seed;\n\n    void setup(SamplingLayerTestParam param)\n    {\n        batch_size     = param.batch_size;\n        beam_width     = param.beam_width;\n        batchxbeam     = batch_size * param.beam_width;\n        vocab_size     = param.vocab_size;\n        max_input_len  = 0;\n        max_output_len = param.output_len;\n        max_seq_len    = max_input_len + max_output_len;\n\n        top_k = param.top_k;\n        top_p = param.top_p;\n        // Use default values that have no effect on sampling.\n        temperature        = 1.0f;\n        repetition_penalty = 1.0f;\n        end_id             = 0;\n\n        h_logits     = new T[batchxbeam * vocab_size];\n        h_output_ids = new int[batchxbeam];\n\n        d_logits        = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batchxbeam * vocab_size));\n        d_input_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batchxbeam));\n        d_output_ids    = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_seq_len * batchxbeam));\n        d_end_ids       = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));\n        d_curand_state  = reinterpret_cast<curandState_t*>(allocator->malloc(sizeof(curandState_t) * 
batch_size));\n        d_random_seed =\n            reinterpret_cast<unsigned long long*>(allocator->malloc(sizeof(unsigned long long) * batch_size));\n\n        // Init by zero.\n        deviceFill(d_input_lengths, batchxbeam, 0, stream);\n        deviceFill(d_output_ids, max_seq_len * batchxbeam, 0, stream);\n        deviceFill(d_end_ids, batch_size, end_id);\n        cudaMemset(d_random_seed, 0, sizeof(unsigned long long) * batch_size);\n    }\n\n    void teardown()\n    {\n        delete[] h_logits;\n        delete[] h_output_ids;\n    }\n\n    void runCurandTest(SamplingLayerTestParam param, bool use_local_batch, bool use_single_random_seed)\n    {\n        setup(param);\n        const DataType data_type = getTensorType<T>();\n\n        const size_t local_batch_size = use_local_batch ? 3 : batch_size;\n        assert(batch_size % local_batch_size == 0);\n\n        DynamicDecodeLayer<T>* dynamic_decode_layer = new DynamicDecodeLayer<T>(vocab_size,\n                                                                                vocab_size,\n                                                                                stream,\n                                                                                cublas_wrapper,\n                                                                                allocator,\n                                                                                false,   // is_free_buffer_after_forward\n                                                                                &prop);  // cuda_device_prop\n\n        // Prepare decoding arguments\n        const size_t        random_seed_size = use_single_random_seed ? 
1 : batch_size;\n        const size_t        period_size      = 3;\n        unsigned long long* random_seed      = new unsigned long long[random_seed_size];\n        for (size_t i = 0; i < random_seed_size; ++i) {\n            random_seed[i] = i / period_size;\n        }\n        cudaH2Dcpy(d_random_seed, random_seed, random_seed_size);\n        if (use_single_random_seed) {\n            invokeCurandInitialize(d_curand_state, batch_size, random_seed[0], stream);\n        }\n        else {\n            invokeCurandBatchInitialize(d_curand_state, batch_size, d_random_seed, stream);\n        }\n        sync_check_cuda_error();\n\n        TensorMap runtime_args;\n        runtime_args.insert({\"random_seed\", Tensor(MEMORY_CPU, TYPE_UINT64, {random_seed_size}, random_seed)});\n        runtime_args.insert({\"runtime_top_k\", Tensor(MEMORY_CPU, TYPE_UINT32, {1}, &top_k)});\n        runtime_args.insert({\"runtime_top_p\", Tensor(MEMORY_CPU, TYPE_FP32, {1}, &top_p)});\n        dynamic_decode_layer->setup(batch_size, beam_width, &runtime_args);\n\n        for (size_t step = max_input_len; step < max_output_len; ++step) {\n            const size_t iteration_num = batch_size / local_batch_size;\n            initRandom(h_logits, beam_width * vocab_size, -3.0f, 3.0f);\n            tile(h_logits, batch_size, beam_width * vocab_size);\n            cudaH2Dcpy(d_logits, h_logits, batchxbeam * vocab_size);\n\n            for (uint ite = 0; ite < iteration_num; ++ite) {\n                TensorMap dynamic_decode_input_tensors(\n                    {{\"logits\", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width, vocab_size}, d_logits}},\n                     {\"embedding_bias\", Tensor{MEMORY_GPU, data_type, {vocab_size}, nullptr}},\n                     {\"step\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}},\n                     {\"max_input_length\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_len}},\n                     {\"input_lengths\", Tensor{MEMORY_GPU, TYPE_INT32, 
{batch_size, beam_width}, d_input_lengths}},\n                     {\"ite\", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}},\n                     {\"local_batch_size\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}},\n                     {\"end_id\", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, d_end_ids}},\n                     {\"random_seed\", {MEMORY_CPU, TYPE_UINT64, {random_seed_size}, random_seed}},\n                     {\"runtime_top_k\", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k}},\n                     {\"runtime_top_p\", {MEMORY_CPU, TYPE_FP32, {1}, &top_p}}});\n\n                // common outputs\n                TensorMap dynamic_decode_output_tensors(\n                    {{\"output_ids\",\n                      Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, d_output_ids}},\n                     {\"finished\", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, nullptr}},\n                     {\"sequence_length\", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, nullptr}},\n                     {\"curand_state\", {MEMORY_GPU, TYPE_VOID, {batch_size}, d_curand_state}}});\n\n                dynamic_decode_layer->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);\n                sync_check_cuda_error();\n\n                // check results.\n                cudaD2Hcpy(h_output_ids,\n                           dynamic_decode_output_tensors.at(\"output_ids\").getPtrWithOffset<int>(step * batchxbeam),\n                           batchxbeam);\n            }\n            // Entries in the same period share a random seed, so they must sample the same token.\n            for (size_t i = 0; i + period_size - 1 < batchxbeam; i += period_size) {\n                for (size_t j = 1; j < period_size; ++j) {\n                    EXPECT_TRUE(h_output_ids[i] == h_output_ids[i + j])\n                        << fmtstr(\"Fail at step %zu val[%zu]=%d <> val[%zu]=%d\",\n                                  step,\n                                  i,\n    
                              h_output_ids[i],\n                                  i + j,\n                                  h_output_ids[i + j]);\n                }\n            }\n        }\n        delete dynamic_decode_layer;\n        delete[] random_seed;\n        teardown();\n    }\n\n    void runCumLogProbTest(SamplingLayerTestParam param)\n    {\n        setup(param);\n        unsigned long long     seed                 = 43;\n        const DataType         data_type            = getTensorType<T>();\n        DynamicDecodeLayer<T>* dynamic_decode_layer = new DynamicDecodeLayer<T>(vocab_size,\n                                                                                vocab_size,\n                                                                                stream,\n                                                                                cublas_wrapper,\n                                                                                allocator,\n                                                                                false,   // is_free_buffer_after_forward\n                                                                                &prop);  // cuda_device_prop\n\n        // Logit values in the host of shape ((batch_size x beam) x vocab_size) where beam = 1.\n        // T* h_logits = new T[batch_size * beam_width * vocab_size];\n        T*     h_probs                = new T[batch_size * beam_width * vocab_size];\n        T*     h_log_probs            = new T[batch_size * beam_width * vocab_size];\n        float* h_cum_log_probs        = new float[batch_size * beam_width];\n        float* h_output_log_probs     = new float[max_output_len * batch_size * beam_width];\n        float* expected_cum_log_probs = new float[batch_size * beam_width];\n        initRandom(h_logits, batch_size * beam_width * vocab_size, -3.0f, 3.0f);\n        computeProb(h_probs, h_logits, batch_size * beam_width, vocab_size);\n        computeLogProb(h_log_probs, 
h_logits, batch_size * beam_width, vocab_size);\n        std::fill_n(expected_cum_log_probs, batch_size * beam_width, 0);\n\n        int* tiled_input_lengths_buf = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size * beam_width));\n        float* cum_log_probs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size * beam_width));\n        float* output_log_probs =\n            reinterpret_cast<float*>(allocator->malloc(sizeof(float) * max_output_len * batch_size * beam_width));\n\n        int* output_ids =\n            reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_seq_len * batch_size * beam_width));\n        int* h_output_ids = new int[batch_size * beam_width];\n\n        int* end_ids = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));\n        deviceFill(end_ids, batch_size, end_id);\n\n        // Init by zero.\n        cudaMemset(cum_log_probs, 0, sizeof(float) * batch_size * beam_width);\n        cudaMemset(output_log_probs, 0, sizeof(float) * max_output_len * batch_size * beam_width);\n        cudaMemset(output_ids, 0, sizeof(int) * max_seq_len * batch_size * beam_width);\n\n        TensorMap input_tensors({{\"random_seed\", {MEMORY_CPU, TYPE_UINT64, {1}, &seed}},\n                                 {\"runtime_top_k\", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k}},\n                                 {\"runtime_top_p\", {MEMORY_CPU, TYPE_FP32, {1}, &top_p}},\n                                 {\"temperature\", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature}},\n                                 {\"repetition_penalty\", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty}}});\n        dynamic_decode_layer->setup(batch_size, beam_width, &input_tensors);\n\n        for (size_t step = max_input_len; step < max_output_len; ++step) {\n            uint ite = 0;\n            // Reset the logits to the test values, since the sampling layer overwrites the logit buffer in place (turning it into log-probs).\n            cudaH2Dcpy(d_logits, 
h_logits, batch_size * beam_width * vocab_size);\n            TensorMap dynamic_decode_input_tensors(\n                {{\"logits\", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width, vocab_size}, d_logits}},\n                 {\"embedding_bias\", Tensor{MEMORY_GPU, data_type, {vocab_size}, nullptr}},\n                 {\"step\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}},\n                 {\"max_input_length\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_len}},\n                 {\"input_lengths\", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width}, tiled_input_lengths_buf}},\n                 {\"ite\", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &ite}},\n                 {\"local_batch_size\", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &batch_size}},\n                 {\"end_id\", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids}},\n                 {\"random_seed\", {MEMORY_CPU, TYPE_UINT64, {1}, &seed}},\n                 {\"runtime_top_k\", {MEMORY_CPU, TYPE_UINT32, {1}, &top_k}},\n                 {\"runtime_top_p\", {MEMORY_CPU, TYPE_FP32, {1}, &top_p}},\n                 {\"temperature\", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature}},\n                 {\"repetition_penalty\", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty}}});\n\n            // common outputs\n            TensorMap dynamic_decode_output_tensors(\n                {{\"output_ids\", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len, batch_size, beam_width}, output_ids}},\n                 {\"finished\", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width}, nullptr}},\n                 {\"cum_log_probs\", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size * beam_width}, cum_log_probs}},\n                 {\"output_log_probs\",\n                  Tensor{MEMORY_GPU, TYPE_FP32, {max_seq_len, batch_size, beam_width}, output_log_probs}},\n                 {\"sequence_length\", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, nullptr}},\n                 {\"curand_state\", {MEMORY_GPU, 
TYPE_VOID, {batch_size}, d_curand_state}}});\n\n            dynamic_decode_layer->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);\n\n            TM_LOG_DEBUG(\"Step %2d generated ids\", (int)step);\n            cudaD2Hcpy(\n                h_output_ids,\n                dynamic_decode_output_tensors.at(\"output_ids\").getPtrWithOffset<int>(step * (batch_size * beam_width)),\n                batch_size * beam_width);\n            cudaD2Hcpy(h_cum_log_probs, cum_log_probs, batch_size * beam_width);\n            cudaD2Hcpy(h_output_log_probs, output_log_probs, max_output_len * batch_size * beam_width);\n            for (size_t i = 0; i < batch_size * beam_width; ++i) {\n                int idx = i * vocab_size + h_output_ids[i];\n                expected_cum_log_probs[i] += (float)h_log_probs[idx];\n                TM_LOG_DEBUG(\"| step %2d batch %2d idx %7d id %6d | log-prob %9.4f (expt: %9.4f) \"\n                             \"| cum-log-prob %9.4f (expt: %9.4f) | prob %9.4e\",\n                             (int)step,\n                             (int)i,\n                             (int)idx,\n                             (int)h_output_ids[i],\n                             h_output_log_probs[step * batch_size * beam_width + i],\n                             (float)h_log_probs[idx],\n                             h_cum_log_probs[i],\n                             expected_cum_log_probs[i],\n                             (float)h_probs[idx]);\n            }\n            TM_LOG_DEBUG(\"\");\n        }\n\n        bool passed = checkResult(param.toString(), cum_log_probs, expected_cum_log_probs, batch_size * beam_width);\n        EXPECT_TRUE(passed);\n\n        delete[] expected_cum_log_probs;\n        delete[] h_output_log_probs;\n        delete[] h_cum_log_probs;\n        delete[] h_log_probs;\n        delete[] h_probs;\n        delete[] h_output_ids;\n\n        delete dynamic_decode_layer;\n        teardown();\n    }\n};\n\nTYPED_TEST_SUITE(SamplingDecodeTest2, 
SamplingTypes);\n\nTYPED_TEST(SamplingDecodeTest2, CorrectnessSingleRandTopK)\n{\n    // test TopKSampling\n    this->runCurandTest({113, 1201, 1, 3, 1.0f, 5}, false, true);\n}\n\nTYPED_TEST(SamplingDecodeTest2, CorrectnessSingleRandTopP)\n{\n    this->runCurandTest({113, 1201, 1, 0, 1.0f, 5}, false, true);\n}\n\nTYPED_TEST(SamplingDecodeTest2, CorrectnessBatchRandTopK)\n{\n    // test TopKSampling\n    this->runCurandTest({113, 1201, 1, 3, 1.0f, 5}, false, false);\n}\n\nTYPED_TEST(SamplingDecodeTest2, CorrectnessBatchRandTopP)\n{\n    this->runCurandTest({113, 1201, 1, 0, 1.0f, 5}, false, false);\n}\n"
  },
  {
    "path": "tests/csrc/unittests/unittest_utils.h",
    "content": "/*\n * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n#pragma once\n\n#include <algorithm>  // min, max\n#include <assert.h>   // assert\n#include <float.h>    // FLT_MAX\n#include <iostream>   // snprintf\n#include <limits>     // numeric_limits\n#include <math.h>     // expf, log\n#include <stdlib.h>   // rand\n#include <string>     // string\n#include <vector>     // vector\n\n#include \"src/turbomind/utils/cuda_utils.h\"\n#include \"src/turbomind/utils/memory_utils.h\"\n#include \"src/turbomind/utils/string_utils.h\"\n\n#define PRINT_LIMIT 16\n#define EPSILON (1e-20)\n#define EPSILON_FP16 (1e-10)\n\nusing namespace turbomind;\n\nclass TestFailureError: public std::exception {\nprivate:\n    std::string msg_;\n\npublic:\n    explicit TestFailureError() = default;\n    explicit TestFailureError(std::string name, std::string msg = \"\")\n    {\n        msg_ = fmtstr(\"TEST FAIL [%s] %s\", name.c_str(), msg.c_str());\n    }\n    const char* what() const throw()\n    {\n        return msg_.c_str();\n    }\n};\n\n#define EXPECT_TRUE(cond)                                                                                              \\\n    do {                                                                                                               \\\n        if (!(cond)) {                                                                                       
          \\\n            TM_LOG_ERROR(\"TEST FAIL [%s]: %s at %s:%d\", __func__, #cond, __FILE__, __LINE__);                          \\\n            throw TestFailureError(__func__);                                                                          \\\n        }                                                                                                              \\\n    } while (false)\n\n#define EXPECT_FALSE(cond)                                                                                             \\\n    do {                                                                                                               \\\n        if (cond) {                                                                                                    \\\n            TM_LOG_ERROR(\"TEST FAIL [%s]: %s at %s:%d\", __func__, #cond, __FILE__, __LINE__);                          \\\n            throw TestFailureError(__func__);                                                                          \\\n        }                                                                                                              \\\n    } while (false)\n\nbool almostEqual(float a, float b, float atol = 1e-5, float rtol = 1e-8)\n{\n    // Params: a = value to compare and b = reference\n    // This function follows implementation of numpy.isclose(), which checks\n    //   abs(a - b) <= (atol + rtol * abs(b)).\n    // Note that the inequality above is asymmetric where b is considered as\n    // a reference value. To account into both absolute/relative errors, it\n    // uses absolute tolerance and relative tolerance at the same time. The\n    // default values of atol and rtol borrowed from numpy.isclose(). 
For the\n    // case where both a and b are nan, the result is true.\n    if (isnan(a) && isnan(b)) {\n        return true;\n    }\n    return fabs(a - b) <= (atol + rtol * fabs(b));\n}\n\ntemplate<typename T>\nbool checkResult(std::string name, T* out, T* ref, size_t size, float atol, float rtol)\n{\n    size_t failures     = 0;\n    float  relative_gap = 0.0f;\n\n    for (size_t i = 0; i < size; ++i) {\n        // The values for the output and the reference.\n        float a = (float)out[i];\n        float b = (float)ref[i];\n\n        bool ok = almostEqual(a, b, atol, rtol);\n        // Report only the first few mismatches.\n        if (!ok && failures < 4) {\n            TM_LOG_ERROR(\">> invalid result for i=%zu:\", i);\n            TM_LOG_ERROR(\">>    found......: %10.6f\", a);\n            TM_LOG_ERROR(\">>    expected...: %10.6f\", b);\n            TM_LOG_ERROR(\">>    error......: %.6f\", fabsf(a - b));\n            TM_LOG_ERROR(\">>    tol........: %.6f\", atol + rtol * fabs(b));\n        }\n        // Update the number of failures.\n        failures += ok ? 0 : 1;\n        // Update the relative gap.\n        relative_gap += fabsf(a - b) / (fabsf(b) + EPSILON);\n    }\n\n    relative_gap /= size;\n\n    // Allow up to 1% of the elements to mismatch.\n    size_t tol_failures = (size_t)(0.01 * size);\n    TM_LOG_INFO(\"check...%6s : %-50s (failures: %.2f%% atol: %.2e rtol: %.2e rel_gap: %.2e%%)\",\n                failures <= tol_failures ? \"....OK\" : \"FAILED\",\n                name.c_str(),\n                100. * failures / size,\n                atol,\n                rtol,\n                100. * relative_gap);\n    return failures <= tol_failures;\n}\n\ntemplate<typename T>\nbool checkResult(std::string name, T* out, T* ref, size_t size, bool device_out = true, bool device_ref = false)\n{\n    bool  is_fp32 = sizeof(T) == 4;\n    float atol    = is_fp32 ? 1e-4f : 1e-3f;\n    float rtol    = is_fp32 ? 
1e-2f : 1e-1f;\n\n    T* h_out = nullptr;\n    if (device_out) {\n        h_out = new T[size];\n        cudaMemcpy(h_out, out, sizeof(T) * size, cudaMemcpyDeviceToHost);\n        out = h_out;\n    }\n    T* h_ref = nullptr;\n    if (device_ref) {\n        h_ref = new T[size];\n        cudaMemcpy(h_ref, ref, sizeof(T) * size, cudaMemcpyDeviceToHost);\n        ref = h_ref;\n    }\n    bool is_ok = checkResult(name, out, ref, size, atol, rtol);\n    if (h_out != nullptr) {\n        delete[] h_out;\n    }\n    if (h_ref != nullptr) {\n        delete[] h_ref;\n    }\n    return is_ok;\n}\n\ntemplate<typename T>\nvoid initRandom(T* ptr, size_t size, float minval, float maxval)\n{\n    for (size_t i = 0; i < size; ++i) {\n        float val = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);\n        val *= (maxval - minval);\n        ptr[i] = static_cast<T>(minval + val);\n    }\n}\n\nvoid initRandomInt(int* ptr, size_t size, int minval, int maxval)\n{\n    assert(minval < maxval);\n    int mod = maxval - minval;\n    for (size_t i = 0; i < size; ++i) {\n        ptr[i] = minval + rand() % mod;\n    }\n}\n\ntemplate<typename T>\nvoid tile(T* x, int m, int n)\n{\n    for (int i = 1; i < m; ++i) {\n        for (int j = 0; j < n; ++j) {\n            x[i * n + j] = x[j];\n        }\n    }\n}\n\ntemplate<typename T>\nvoid tile(T* dst, T* src, int m, int n)\n{\n    for (int i = 1; i < m; ++i) {\n        for (int j = 0; j < n; ++j) {\n            dst[i * n + j] = src[j];\n        }\n    }\n}\n\n#define HALF_FLT_MAX 65504.0f\n\ntemplate<typename T>\nbool isHalf()\n{\n    return std::is_same<T, half>::value;\n}\n\ntemplate<typename T>\nstatic inline void printMatrixWithLimit(T* ptr, int m, int k, int stride, bool is_device_ptr)\n{\n    printMatrix(ptr, std::min(PRINT_LIMIT, m), std::min(PRINT_LIMIT, k), stride, is_device_ptr);\n}\n"
  },
  {
    "path": "tests/pytorch/config/test_hf_overrides.py",
    "content": "import pytest\n\n\nclass TestHFOverrides:\n\n    @pytest.fixture\n    def hf_config(self):\n        from transformers.models.llava import LlavaConfig\n        yield LlavaConfig()\n\n    def test_hf_overrides(self, hf_config):\n        from lmdeploy.pytorch.config import override_hf_config\n\n        # update root\n        assert hf_config.model_type == 'llava'\n        overrides_dict = dict(model_type='llava_custom', )\n        override_hf_config(hf_config, overrides_dict)\n        assert hf_config.model_type == 'llava_custom'\n\n        # update rope_parameters (renamed from rope_scaling in newer transformers)\n        assert hf_config.text_config.model_type == 'llama'\n        assert hf_config.text_config.rope_parameters['rope_type'] == 'default'\n        overrides_dict = dict(text_config=dict(rope_parameters=dict(rope_type='yarn', )))\n        override_hf_config(hf_config, overrides_dict)\n        assert hf_config.text_config.model_type == 'llama'\n        assert hf_config.text_config.rope_parameters['rope_type'] == 'yarn'\n\n        # update both\n        overrides_dict = dict(model_type='llava_custom2', text_config=dict(rope_parameters=dict(rope_type='yarn2', )))\n        override_hf_config(hf_config, overrides_dict)\n        assert hf_config.model_type == 'llava_custom2'\n        assert hf_config.text_config.model_type == 'llama'\n        assert hf_config.text_config.rope_parameters['rope_type'] == 'yarn2'\n"
  },
  {
    "path": "tests/pytorch/engine/test_logits_process.py",
    "content": "# yapf: disable\nimport torch\nfrom transformers.generation.logits_process import (MinPLogitsWarper, RepetitionPenaltyLogitsProcessor,\n                                                    TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper)\n\n# yapf: enable\n\n\ndef test_process_temperature():\n    from lmdeploy.pytorch.engine.logits_process import _process_temperature_\n\n    batch_size = 4\n    num_tokens = 16\n    scores = torch.rand(batch_size, num_tokens)\n    temperatures = torch.rand(batch_size)\n\n    gt = []\n    for score, temperature in zip(scores, temperatures):\n        warper = TemperatureLogitsWarper(temperature.item())\n        gt.append(warper(None, score[None]))\n    gt = torch.cat(gt)\n\n    out = _process_temperature_(scores, temperatures)\n    torch.testing.assert_close(out, gt)\n\n\ndef test_process_bad_words():\n    from lmdeploy.pytorch.engine.logits_process import _process_bad_words_\n\n    filter_value: float = -float('inf')\n    batch_size = 4\n    num_tokens = 16\n    scores = torch.rand(batch_size, num_tokens)\n    bad_words = torch.tensor([\n        [0, 1],\n        [3, -1],\n        [4, 4],\n        [-1, -1],\n    ])\n    mask = bad_words >= 0\n\n    out_scores = _process_bad_words_(scores, bad_words, mask)\n\n    for score, bw in zip(out_scores, bad_words):\n        bw = bw.tolist()\n\n        for w in bw:\n            if w >= 0:\n                assert score[w] == filter_value\n\n\ndef test_processrepetition_penalty():\n    from lmdeploy.pytorch.engine.logits_process import _process_repetition_penalty_\n    batch_size = 4\n    num_tokens = 16\n    scores = torch.rand(batch_size, num_tokens)\n    input_ids = torch.tensor([\n        [0, 1],\n        [3, 6],\n        [4, 4],\n        [0, 0],\n    ])\n    penalties = 1 + torch.rand(batch_size)\n\n    gt = []\n    for score, ids, penalty in zip(scores, input_ids, penalties):\n        warper = RepetitionPenaltyLogitsProcessor(penalty.item())\n        
gt.append(warper(ids[None], score[None].clone()))\n    gt = torch.cat(gt)\n\n    out = _process_repetition_penalty_(scores, input_ids, penalties)\n    torch.testing.assert_close(out, gt)\n\n\ndef test_filter_topk_sorted():\n    from lmdeploy.pytorch.engine.logits_process import _filter_topk_sorted_\n\n    batch_size = 4\n    num_tokens = 16\n    scores = torch.rand(batch_size, num_tokens).sort(1, descending=True)[0]\n    top_k = torch.randint(4, num_tokens - 4, (batch_size, ))\n\n    gt = []\n    for score, k in zip(scores, top_k):\n        warper = TopKLogitsWarper(k.item())\n        gt.append(warper(None, score[None].clone()))\n    gt = torch.cat(gt)\n\n    out = _filter_topk_sorted_(scores, top_k)\n    torch.testing.assert_close(out, gt)\n\n\ndef test_filter_topp_sorted():\n    from lmdeploy.pytorch.engine.logits_process import _filter_topp_sorted_\n\n    batch_size = 4\n    num_tokens = 16\n    scores = torch.rand(batch_size, num_tokens).sort(1, descending=True)[0]\n    top_p = torch.rand(batch_size)\n\n    gt = []\n    for score, p in zip(scores, top_p):\n        warper = TopPLogitsWarper(p.item())\n        gt.append(warper(None, score[None].clone()))\n    gt = torch.cat(gt)\n\n    out = _filter_topp_sorted_(scores, top_p)\n    torch.testing.assert_close(out, gt)\n\n\ndef test_filter_minp_sorted():\n    from lmdeploy.pytorch.engine.logits_process import _filter_minp_sorted_\n\n    batch_size = 4\n    num_tokens = 16\n    scores = torch.rand(batch_size, num_tokens).sort(1, descending=True)[0]\n    min_p = torch.rand(batch_size)\n\n    gt = []\n    for score, p in zip(scores, min_p):\n        warper = MinPLogitsWarper(p.item())\n        gt.append(warper(None, score[None].clone()))\n    gt = torch.cat(gt)\n\n    out = _filter_minp_sorted_(scores, min_p)\n    torch.testing.assert_close(out, gt)\n\n\ndef test_filter_ngram():\n    from lmdeploy.pytorch.engine.logits_process import _filter_repetition_ngram_\n    vocab_size = 100\n\n    def _get_emtas(n, 
window_size):\n        batch_size = generated_ids.size(0)\n        max_n = int(n.max().item())\n        same_n = n.eq(max_n).all().item()\n        max_window_size = window_size\n        if same_n:\n            n = None\n        return batch_size, max_n, max_window_size, n\n\n    # base test\n    generated_ids = torch.tensor([\n        [2, 3, 4, 1, 2, 3, 4, 2, 3, 4],\n        [9, 8, 7, 3, 8, 7, 5, 9, 8, 7],\n        [9, 8, 7, 3, 8, 7, 5, 9, 8, 7],\n    ],\n                                 dtype=torch.int64)\n    n = torch.tensor([3, 3, 2], dtype=torch.int64)\n    threshold = torch.tensor([3, 3, 3], dtype=torch.int64)\n\n    batch_size, max_n, max_window_size, n = _get_emtas(n, 10)\n    scores = torch.rand(batch_size, vocab_size)\n    stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64)\n    _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size)\n\n    assert not scores[1].isinf().any().item()\n    assert scores[0].isinf().sum().item() == vocab_size - 1\n    assert scores[2].isinf().sum().item() == vocab_size - 1\n    assert scores[0, stop_words[0, 0]] == 0\n    assert scores[2, stop_words[2, 0]] == 0\n\n    # test no ngram\n    generated_ids = torch.tensor([\n        [2, 3, 4, 1, 2, 3, 4, 2, 3, 4],\n        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n    ])\n    n = torch.tensor([3, 0], dtype=torch.int64)\n    threshold = torch.tensor([3, 0], dtype=torch.int64)\n    batch_size, max_n, max_window_size, n = _get_emtas(n, 10)\n\n    scores = torch.rand(batch_size, vocab_size)\n    stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64)\n    _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size)\n    assert not scores[1].isinf().any().item()\n    assert scores[0].isinf().sum().item() == vocab_size - 1\n\n    # test ids all 0\n    generated_ids = torch.tensor([\n        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n    ])\n    n = torch.tensor([3], 
dtype=torch.int64)\n    threshold = torch.tensor([3], dtype=torch.int64)\n    batch_size, max_n, max_window_size, n = _get_emtas(n, 10)\n\n    scores = torch.rand(batch_size, vocab_size)\n    stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64)\n    _filter_repetition_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, max_window_size)\n    assert scores[0].isinf().sum().item() == vocab_size - 1\n"
  },
  {
    "path": "tests/pytorch/engine/test_request.py",
    "content": "# yapf: disable\nimport asyncio\n\nimport pytest\n\nfrom lmdeploy.pytorch.engine.request import RequestManager, RequestType, ResponseType\n\n# yapf: enable\n\n\nclass TestRequestHandler:\n\n    @pytest.fixture\n    def event_loop(self):\n        old_loop = asyncio.get_event_loop()\n        new_loop = asyncio.new_event_loop()\n        try:\n            asyncio.set_event_loop(new_loop)\n            yield new_loop\n        finally:\n            new_loop.stop()\n            asyncio.set_event_loop(old_loop)\n\n    @pytest.fixture\n    def manager(self):\n        yield RequestManager()\n\n    def test_bind(self, manager, event_loop):\n\n        def __stop_engine_callback(reqs, **kwargs):\n            for req in reqs:\n                resp = req.resp\n                resp.type = ResponseType.SUCCESS\n                resp.data = f'{req.data} success'\n                manager.response(resp)\n\n        async def __dummy_loop():\n            while True:\n                try:\n                    await manager.step()\n                except Exception:\n                    return\n\n        sender = manager.build_sender()\n        manager.set_main_loop_func(__dummy_loop)\n\n        # test request type with no bound handler\n        resp = sender.send_async(RequestType.STOP_ENGINE, None)\n        resp = sender.recv(resp)\n        assert resp.type == ResponseType.HANDLER_NOT_EXIST\n\n        assert manager.is_loop_alive()\n\n        # test bind success\n        sender.send_async(RequestType.STOP_ENGINE, None)\n        manager.bind_func(RequestType.STOP_ENGINE, __stop_engine_callback)\n        resp = sender.send_async(RequestType.STOP_ENGINE, 'test')\n        resp = sender.recv(resp)\n        assert resp.data == 'test success'\n\n        # cleanup, cancel main task\n        task_to_cancel = manager._loop_task\n        manager.stop_loop()\n        event_loop.run_until_complete(asyncio.gather(task_to_cancel, return_exceptions=True))\n"
  },
  {
    "path": "tests/pytorch/engine/test_zmq_rpc.py",
    "content": "import asyncio\nimport multiprocessing as mp\n\n\nclass TestZMQRPC:\n\n    def sub_proc(self, shared_dict=None, condition=None):\n        from lmdeploy.pytorch.engine.mp_engine.zmq_rpc import AsyncRPCServer\n        server = AsyncRPCServer()\n        with condition:\n            shared_dict['rpc_server_port'] = server.port\n            condition.notify()\n\n        async def streaming_method(name):\n            for i in range(3):\n                yield f'{name}: streaming method {i}'\n\n        def method(name):\n            return f'{name}: method'\n\n        async def async_method(name):\n            return f'{name}: async method'\n\n        def close():\n            print('close server...')\n            server.stop()\n\n        server.register_method('method', method)\n        server.register_method('async_method', async_method)\n        server.register_method('streaming_method', streaming_method)\n        server.register_method('close', close)\n\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n\n        asyncio.run(server.run())\n\n    async def async_main(self, port):\n        from lmdeploy.pytorch.engine.mp_engine.zmq_rpc import AsyncRPCClient\n        client = AsyncRPCClient(port=port)\n\n        loop = asyncio.get_event_loop()\n        _ = loop.create_task(client.listen())\n\n        # Example usage\n        result = client.call('async_method', 'test2')\n        assert result == 'test2: async method'\n        result = await client.async_call('method', 'test1')\n        assert result == 'test1: method'\n\n        async for result in client.async_stream_call('streaming_method', 'test3'):\n            pass\n        assert result == 'test3: streaming method 2'\n\n        await client.async_call('close')\n        client.stop()\n\n    def test_zmq_rpc(self):\n        with mp.Manager() as manager:\n            shared_dict = manager.dict()\n            condition = manager.Condition()\n            ctx = 
mp.get_context('spawn')\n            proc = ctx.Process(target=self.sub_proc, args=(shared_dict, condition), daemon=True)\n            proc.start()\n\n            with condition:\n                if 'rpc_server_port' not in shared_dict:\n                    condition.wait()\n            port = shared_dict['rpc_server_port']\n\n        asyncio.run(self.async_main(port))\n\n        proc.join()\n"
  },
  {
    "path": "tests/pytorch/kernel/test_activation.py",
    "content": "import pytest\nimport torch\n\n\nclass TestSiluAndMul:\n\n    @pytest.fixture\n    def seqlen(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def feat_size(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def x(self, seqlen, feat_size):\n        yield torch.rand(seqlen, feat_size, dtype=torch.float16, device='cuda')\n\n    @pytest.fixture\n    def gt(self, x):\n        gate, up = x.chunk(2, -1)\n        gate = torch.nn.functional.silu(gate)\n        yield gate * up\n\n    @pytest.mark.parametrize('seqlen', [65536, 256], indirect=True)\n    @pytest.mark.parametrize('feat_size', [4096, 768], indirect=True)\n    def test_silu_and_mul(self, x, gt):\n        from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul\n\n        out = silu_and_mul(x)\n        torch.testing.assert_close(out, gt)\n\n\nclass TestSiluAndMulMoEEP:\n\n    @pytest.fixture\n    def num_experts(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def seqlen(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def feat_size(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def x(self, num_experts, seqlen, feat_size, dtype):\n        yield torch.rand(num_experts, seqlen, feat_size, dtype=dtype, device='cuda')\n\n    @pytest.fixture\n    def mask_m(self, num_experts, seqlen):\n        mask_m = torch.randint(0, seqlen, (num_experts, ), device='cuda')\n        yield mask_m\n\n    @pytest.fixture\n    def elem_mask(self, mask_m, seqlen):\n        elem_mask = torch.arange(seqlen, device='cuda').unsqueeze(0) < mask_m.unsqueeze(1)\n        yield elem_mask[..., None]\n\n    @pytest.fixture\n    def gt(self, x):\n        gate, up = x.chunk(2, -1)\n        gate = torch.nn.functional.silu(gate)\n        yield gate * up\n\n    @pytest.mark.parametrize('num_experts', [4], indirect=True)\n    
@pytest.mark.parametrize('seqlen', [1024], indirect=True)\n    @pytest.mark.parametrize('feat_size', [4096, 768], indirect=True)\n    def test_silu_and_mul(self, x, mask_m, elem_mask, gt):\n        from lmdeploy.pytorch.kernels.cuda.activation import silu_and_mul_moe_ep\n\n        out = silu_and_mul_moe_ep(x, mask_m)\n        out.masked_fill_(~elem_mask, 0.0)\n        gt.masked_fill_(~elem_mask, 0.0)\n        torch.testing.assert_close(out, gt)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_apply_rotary.py",
    "content": "import pytest\nimport torch\n\nfrom lmdeploy.utils import is_bf16_supported\n\n\ndef _rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., :x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2:]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef _bf16_mark():\n    return pytest.mark.skipif(not is_bf16_supported(), reason='bf16 not supported.')\n\n\nclass TestApplyRotary:\n\n    @pytest.fixture\n    def dtype(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def batch_size(self):\n        yield 4\n\n    @pytest.fixture\n    def num_heads_q(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def num_heads_k(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def feature_dim(self):\n        yield 128\n\n    @pytest.fixture\n    def seq_length(self, batch_size):\n        yield torch.randint(8, 16, (batch_size, ), device='cuda')\n\n    @pytest.fixture\n    def max_seqlen(self, seq_length):\n        yield seq_length.max()\n\n    @pytest.fixture\n    def q_states(self, seq_length, num_heads_q, feature_dim, dtype):\n        yield torch.randn(seq_length.sum(), num_heads_q, feature_dim, dtype=dtype, device='cuda')\n\n    @pytest.fixture\n    def k_states(self, seq_length, num_heads_k, feature_dim, dtype):\n        yield torch.randn(seq_length.sum(), num_heads_k, feature_dim, dtype=dtype, device='cuda')\n\n    @pytest.fixture\n    def position_ids_1d(self, seq_length, max_seqlen):\n        yield torch.randint(0, max_seqlen.item(), (seq_length.sum().item(), ), device='cuda')\n\n    @pytest.fixture\n    def cached_cos(self, max_seqlen, feature_dim, dtype):\n        yield torch.randn(max_seqlen, feature_dim, dtype=dtype, device='cuda')\n\n    @pytest.fixture\n    def cached_sin(self, max_seqlen, feature_dim, dtype):\n        yield torch.randn(max_seqlen, feature_dim, dtype=dtype, device='cuda')\n\n    @pytest.fixture\n    def cos(self, cached_cos, 
position_ids_1d):\n        yield cached_cos[position_ids_1d, None, :]\n\n    @pytest.fixture\n    def sin(self, cached_sin, position_ids_1d):\n        yield cached_sin[position_ids_1d, None, :]\n\n    @pytest.fixture\n    def gt(self, q_states, k_states, cos, sin, position_ids_1d):\n\n        q_embed = q_states * cos + _rotate_half(q_states) * sin\n        k_embed = k_states * cos + _rotate_half(k_states) * sin\n\n        yield q_embed, k_embed\n\n    @pytest.mark.parametrize('dtype', [pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16, torch.float32],\n                             indirect=True)\n    @pytest.mark.parametrize(('num_heads_q', 'num_heads_k'), [(8, 8), (8, 4)], indirect=True)\n    def test_apply_rotary(self, q_states, k_states, cos, sin, gt):\n        from lmdeploy.pytorch.kernels.cuda import apply_rotary_pos_emb\n        q_embed, k_embed = apply_rotary_pos_emb(q_states, k_states, cos, sin)\n        q_gt, k_gt = gt\n\n        rtol = None\n        atol = None\n        torch.testing.assert_close(q_embed, q_gt, rtol=rtol, atol=atol)\n        torch.testing.assert_close(k_embed, k_gt, rtol=rtol, atol=atol)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_bitonic_topk.py",
    "content": "import pytest\nimport torch\n\n\nclass TestBitonicTopk:\n\n    @pytest.fixture\n    def device(self):\n        yield 'cuda'\n\n    @pytest.fixture\n    def k(self):\n        yield 2048\n\n    @pytest.fixture\n    def q_seqlens(self, device):\n        ret = [4, 16, 1, 32]\n        ret = torch.tensor(ret, dtype=torch.int32, device=device)\n        yield ret\n\n    @pytest.fixture\n    def kv_seqlens(self, device):\n        ret = [1024, 2048, 4096, 4096 + 133]\n        ret = torch.tensor(ret, dtype=torch.int32, device=device)\n        yield ret\n\n    @pytest.fixture\n    def batch_size(self, kv_seqlens):\n        return kv_seqlens.numel()\n\n    @pytest.fixture\n    def max_kv_len(self, kv_seqlens):\n        return kv_seqlens.max().item()\n\n    @pytest.fixture\n    def scores(self, q_seqlens, max_kv_len, device):\n        num_tokens = q_seqlens.sum().item()\n        yield torch.randn((num_tokens, max_kv_len), device=device)\n\n    @pytest.fixture\n    def gt(self, scores, q_seqlens, kv_seqlens, k):\n        batch_size = kv_seqlens.numel()\n        num_tokens, _ = scores.shape\n        topk_indices = torch.empty((num_tokens, k), dtype=torch.int32, device=scores.device)\n        topk_indices.fill_(-1)\n\n        start = 0\n        for i in range(batch_size):\n            q_seqlen = q_seqlens[i].item()\n            seqlen = kv_seqlens[i].item()\n            tmp_k = min(seqlen, k)\n            end = start + q_seqlen\n            _, topk_indices[start:end, :seqlen] = torch.topk(scores[start:end, :seqlen],\n                                                             tmp_k,\n                                                             largest=True,\n                                                             sorted=True)\n            start = end\n        return topk_indices\n\n    def test_bitonic_topk(self, scores, q_seqlens, kv_seqlens, k, gt):\n        from lmdeploy.pytorch.kernels.cuda.bitonic_topk import bitonic_topk\n        out = 
bitonic_topk(scores, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, k=k, fill=-1, sorted=True)\n        gt[gt < 0] = 0\n        out[out < 0] = 0\n        gt_score = torch.gather(scores, 1, gt.to(torch.int64))\n        out_score = torch.gather(scores, 1, out.to(torch.int64))\n        torch.testing.assert_close(gt_score, out_score)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_causal_conv1d.py",
    "content": "import pytest\nimport torch\n\n\ndef do_test():\n    try:\n        import causal_conv1d  # noqa: F401\n        import tilelang  # noqa: F401\n        causal_conv1d_fn = causal_conv1d.causal_conv1d_fn  # noqa: F841\n        causal_conv1d_update = causal_conv1d.causal_conv1d_update  # noqa: F841\n        return True\n    except Exception:\n        return False\n\n\n@pytest.mark.skipif(not do_test(), reason='tilelang or causal_conv1d is not available')\nclass TestCausalConv1dUpdate:\n\n    @pytest.fixture\n    def device(self):\n        yield 'cuda'\n\n    @pytest.fixture\n    def batch(self):\n        yield 512\n\n    @pytest.fixture\n    def hidden_size(self):\n        yield 2048\n\n    @pytest.fixture\n    def width(self):\n        yield 4\n\n    @pytest.fixture\n    def x(self, batch, hidden_size, device):\n        yield torch.randn(batch, hidden_size, 1, device=device)\n\n    @pytest.fixture\n    def weight(self, hidden_size, width, device):\n        yield torch.randn(hidden_size, width, device=device)\n\n    @pytest.fixture\n    def conv_state(self, batch, hidden_size, width, device):\n        conv_state = torch.randn(batch * 4, hidden_size, width, device=device)\n        conv_state = conv_state[::2]\n        yield conv_state\n\n    @pytest.fixture\n    def bias(self, hidden_size, device):\n        yield torch.randn(hidden_size, device=device)\n\n    @pytest.fixture\n    def conv_state_indices(self, batch, device):\n        conv_state_indices = batch * 2 - 1 - torch.arange(0, batch * 2, 2, device=device)\n        yield conv_state_indices.to(torch.int32)\n\n    @pytest.fixture(params=[None, 'silu'])\n    def activation(self, request):\n        yield request.param\n\n    def test_causal_conv1d_update(self, x, conv_state, weight, bias, activation, conv_state_indices):\n        from causal_conv1d import causal_conv1d_update as causal_conv1d_update_gt\n\n        from lmdeploy.pytorch.kernels.cuda.causal_conv1d import causal_conv1d_update\n\n        
conv_state_clone = conv_state.clone()\n        out = causal_conv1d_update(x=x,\n                                   conv_state=conv_state_clone,\n                                   weight=weight,\n                                   bias=bias,\n                                   activation=activation,\n                                   conv_state_indices=conv_state_indices)\n        out_gt = causal_conv1d_update_gt(x=x,\n                                         conv_state=conv_state,\n                                         weight=weight,\n                                         bias=bias,\n                                         activation=activation,\n                                         conv_state_indices=conv_state_indices)\n        torch.testing.assert_close(out, out_gt, rtol=1e-3, atol=1e-3)\n        torch.testing.assert_close(conv_state_clone, conv_state, rtol=1e-3, atol=1e-3)\n\n\n@pytest.mark.skipif(not do_test(), reason='tilelang or causal_conv1d is not available')\nclass TestCausalConv1dFn:\n\n    @pytest.fixture\n    def device(self):\n        yield 'cuda'\n\n    @pytest.fixture\n    def hidden_size(self):\n        yield 2048\n\n    @pytest.fixture\n    def seqlen(self):\n        yield 4096\n\n    @pytest.fixture\n    def seq_idx(self, seqlen, device):\n        seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=device)\n        seq_idx[seqlen // 4 * 3:] = 1\n        seq_idx = seq_idx.view(1, -1)\n        yield seq_idx\n\n    @pytest.fixture\n    def x(self, hidden_size, seqlen, device):\n        yield torch.randn(1, hidden_size, seqlen, device=device).transpose(1, 2).contiguous().transpose(1, 2)\n\n    @pytest.fixture\n    def weight(self, hidden_size, device):\n        yield torch.randn(hidden_size, 4, device=device)\n\n    @pytest.fixture\n    def bias(self, hidden_size, device):\n        yield torch.randn(hidden_size, device=device)\n\n    @pytest.fixture(params=[None, 'silu'])\n    def activation(self, request):\n        yield 
request.param\n\n    def test_causal_conv1d_fn(self, x, weight, bias, activation, seq_idx):\n        from causal_conv1d import causal_conv1d_fn as causal_conv1d_fn_gt\n\n        from lmdeploy.pytorch.kernels.cuda.causal_conv1d import causal_conv1d_fn\n\n        out = causal_conv1d_fn(x=x,\n                               weight=weight,\n                               bias=bias,\n                               activation=activation,\n                               return_final_states=False,\n                               seq_idx=seq_idx)\n        out_gt = causal_conv1d_fn_gt(x=x,\n                                     weight=weight,\n                                     bias=bias,\n                                     activation=activation,\n                                     return_final_states=False,\n                                     seq_idx=seq_idx)\n        torch.testing.assert_close(out, out_gt, rtol=1e-3, atol=1e-3)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_ds_index.py",
    "content": "import pytest\nimport torch\n\n\ndef _make_A(M, K, group_size, out_dtype, device):\n    quant_A = torch.randn(M, K // group_size, group_size, dtype=torch.float32, device=device)\n    # -1 ~ 1\n    quant_A = quant_A * 2 - 1\n    # scaling abs max to fmax\n    finfo = torch.finfo(out_dtype)\n    fmax = finfo.max\n    scaling = fmax / quant_A.abs().amax(-1, keepdim=True)\n    quant_A *= scaling\n    quant_A = quant_A.to(out_dtype).to(torch.float32)\n\n    # create scale and A\n    scale = torch.randn(M, K // group_size, dtype=torch.float32, device=device)\n    scale /= fmax\n    A = quant_A * scale[..., None]\n\n    A = A.reshape(M, K)\n    quant_A = quant_A.reshape(M, K).to(out_dtype)\n    scale = scale.T.contiguous().T\n    return A, quant_A, scale\n\n\n@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')\nclass TestDSIndex:\n\n    @pytest.fixture\n    def num_heads(self):\n        yield 64\n\n    @pytest.fixture\n    def head_dim(self):\n        yield 128\n\n    @pytest.fixture\n    def block_size(self):\n        yield 64\n\n    @pytest.fixture\n    def device(self):\n        yield 'cuda'\n\n    @pytest.fixture\n    def q_seqlens(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def kv_seqlens(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def k_seqlens(self, kv_seqlens, device):\n        yield torch.tensor(kv_seqlens, dtype=torch.int32, device=device)\n\n    @pytest.fixture\n    def cu_seqlen_q(self, q_seqlens, device):\n        yield torch.tensor([0] + list(q_seqlens), dtype=torch.int32, device=device).cumsum(0)\n\n    @pytest.fixture\n    def cu_seqlen_kv(self, kv_seqlens, device):\n        yield torch.tensor([0] + list(kv_seqlens), dtype=torch.int32, device=device).cumsum(0)\n\n    @pytest.fixture\n    def query(self, q_seqlens, num_heads, head_dim, device):\n        total_len = sum(q_seqlens)\n        fp_q, q, q_s = _make_A(total_len * num_heads, 
head_dim, head_dim, out_dtype=torch.float8_e4m3fn, device=device)\n        fp_q = fp_q.view(total_len, num_heads, head_dim)\n        q = q.view(total_len, num_heads, head_dim)\n        q_s = q_s.view(total_len, num_heads)\n        yield fp_q, q, q_s\n\n    @pytest.fixture\n    def q(self, query):\n        yield query[1]\n\n    @pytest.fixture\n    def q_s(self, query):\n        yield query[2]\n\n    @pytest.fixture\n    def key(self, kv_seqlens, head_dim):\n        total_len = sum(kv_seqlens)\n        fp_k, k, k_s = _make_A(total_len, head_dim, head_dim, out_dtype=torch.float8_e4m3fn, device='cuda')\n        fp_k = fp_k.view(total_len, head_dim)\n        k = k.view(total_len, head_dim)\n        k_s = k_s.view(total_len)\n        yield fp_k, k, k_s\n\n    @pytest.fixture\n    def k(self, key):\n        yield key[1]\n\n    @pytest.fixture\n    def k_s(self, key):\n        yield key[2]\n\n    @pytest.fixture\n    def cache_key(self, k, k_s, kv_seqlens, block_size, head_dim):\n        batch_size = len(kv_seqlens)\n        max_num_blocks = (max(kv_seqlens) + block_size - 1) // block_size\n\n        # get block offsets\n        batch_ids = torch.arange(batch_size, device='cuda') * max_num_blocks\n        block_ids = torch.arange(max_num_blocks, device='cuda')\n        block_offsets = (batch_ids[:, None] + block_ids[None, :])\n\n        k_cache = torch.zeros((max_num_blocks * batch_size * block_size, head_dim),\n                              dtype=torch.float8_e4m3fn,\n                              device='cuda')\n        k_s_cache = torch.zeros((max_num_blocks * batch_size * block_size), dtype=torch.float32, device='cuda')\n\n        k = k.split(kv_seqlens, dim=0)\n        k_s = k_s.split(kv_seqlens, dim=0)\n        for i in range(batch_size):\n            size = k[i].size(0)\n            start = i * max_num_blocks * block_size\n            end = start + size\n            k_cache[start:end] = k[i]\n            k_s_cache[start:end] = k_s[i]\n\n        k_cache = 
k_cache.view(batch_size * max_num_blocks, block_size, head_dim)\n        k_s_cache = k_s_cache.view(batch_size * max_num_blocks, block_size)\n\n        yield k_cache, k_s_cache, block_offsets\n\n    @pytest.fixture\n    def k_cache(self, cache_key):\n        yield cache_key[0]\n\n    @pytest.fixture\n    def k_s_cache(self, cache_key):\n        yield cache_key[1]\n\n    @pytest.fixture\n    def block_offset(self, cache_key):\n        yield cache_key[2]\n\n    @pytest.mark.parametrize('q_seqlens', [(1, 1, 1, 1), (1024, 2048, 1024, 1)], indirect=True)\n    @pytest.mark.parametrize('kv_seqlens', [(2048, 4096, 1024, 128)], indirect=True)\n    def test_fp8_index(self, q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset):\n        # the reference implementation requires tilelang, so this test only checks that the kernel runs\n        from lmdeploy.pytorch.kernels.cuda.ds_index import fp8_index\n        fp8_index(q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_fill_kv_cache.py",
    "content": "import pytest\nimport torch\n\n\ndef _div_up(a, b):\n    return (a + b - 1) // b\n\n\ndef quant(kv: torch.Tensor, nbits: int = 8):\n    \"\"\"Quant kv on the head_dim.\"\"\"\n    amax = kv.amax(dim=-1, keepdim=True)\n    amin = kv.amin(dim=-1, keepdim=True)\n    scales = (amax - amin) / (2**nbits - 1)\n    zeros = -amin / scales\n    q_kv = (kv / scales + zeros + 0.5).to(torch.uint8)\n    if nbits == 4:\n        q_kv1, q_kv2 = q_kv.split(q_kv.shape[-1] // 2, -1)\n        q_kv = q_kv1 + q_kv2 * 16\n    return q_kv, torch.cat([scales, zeros], dim=-1)\n\n\nclass TestFillKVCache:\n\n    @pytest.fixture\n    def num_heads(self):\n        yield 4\n\n    @pytest.fixture\n    def head_dim(self):\n        yield 32\n\n    @pytest.fixture\n    def block_size(self):\n        yield 16\n\n    @pytest.fixture\n    def seq_lens(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def history_lens(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def batch_size(self, seq_lens):\n        yield len(seq_lens)\n\n    @pytest.fixture\n    def kv_lens(self, seq_lens, history_lens):\n        yield [s + h for s, h in zip(seq_lens, history_lens)]\n\n    @pytest.fixture\n    def max_q_seq_length(self, seq_lens):\n        yield max(seq_lens)\n\n    @pytest.fixture\n    def num_tokens(self, seq_lens):\n        yield sum(seq_lens)\n\n    @pytest.fixture\n    def num_blocks_per_input(self, kv_lens, block_size):\n        yield [_div_up(kv_len, block_size) for kv_len in kv_lens]\n\n    @pytest.fixture\n    def max_num_blocks(self, num_blocks_per_input):\n        yield max(num_blocks_per_input)\n\n    @pytest.fixture\n    def q_seq_length(self, seq_lens):\n        yield torch.tensor(seq_lens).cuda()\n\n    @pytest.fixture\n    def q_start_loc(self, q_seq_length):\n        cum_seq_length = q_seq_length.cumsum(0)\n        yield cum_seq_length - q_seq_length\n\n    @pytest.fixture\n    def kv_seq_length(self, kv_lens):\n        yield 
torch.tensor(kv_lens).cuda()\n\n    @pytest.fixture\n    def k_states(self, num_tokens, num_heads, head_dim):\n        yield torch.randn(num_tokens, num_heads, head_dim).cuda()\n\n    @pytest.fixture\n    def v_states(self, k_states):\n        yield torch.randn_like(k_states)\n\n    @pytest.fixture\n    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim):\n        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)\n        yield torch.full(shape, 0.0).cuda()\n\n    @pytest.fixture\n    def v_caches(self, k_caches):\n        yield torch.rand_like(k_caches)\n\n    @pytest.fixture\n    def block_offsets(self, num_blocks_per_input):\n        batch_size = len(num_blocks_per_input)\n        max_num_blocks = max(num_blocks_per_input)\n        batch_ids = torch.arange(batch_size)\n        ret = torch.arange(max_num_blocks)\n        ret = batch_ids[:, None] + ret[None, :] * batch_size\n        yield ret.cuda()\n\n    @pytest.fixture\n    def gt(self, k_states, v_states, k_caches, v_caches, seq_lens, history_lens, block_offsets, block_size):\n        batch_size = len(seq_lens)\n        k_caches = k_caches.clone()\n        v_caches = v_caches.clone()\n        splited_k_states = k_states.split(seq_lens)\n        splited_v_states = v_states.split(seq_lens)\n        for bidx in range(batch_size):\n            k_state = splited_k_states[bidx]\n            v_state = splited_v_states[bidx]\n            h_len = history_lens[bidx]\n            b_offs = block_offsets[bidx]\n            block_id = _div_up(h_len + 1, block_size) - 1\n            fill_start = h_len % block_size\n            fill_size = min(block_size - fill_start, k_state.size(0))\n            while True:\n                boff = b_offs[block_id]\n                tmp_ks = k_state[:fill_size]\n                tmp_vs = v_state[:fill_size]\n                fill_end = fill_start + fill_size\n                k_caches[boff, fill_start:fill_end] = tmp_ks\n                
v_caches[boff, fill_start:fill_end] = tmp_vs\n                k_state = k_state[fill_size:]\n                v_state = v_state[fill_size:]\n                block_id += 1\n                fill_start = 0\n                fill_size = min(block_size, k_state.size(0))\n                if fill_size == 0:\n                    break\n\n        yield k_caches, v_caches\n\n    @pytest.mark.parametrize(['seq_lens', 'history_lens'], [\n        ((1, 1, 1, 1), (1, 16, 31, 24)),\n        ((1, 8, 16, 24), (1, 16, 31, 24)),\n    ],\n                             indirect=True)\n    def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, block_offsets, q_start_loc, q_seq_length,\n                           kv_seq_length, max_q_seq_length, gt):\n        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache\n        fill_kv_cache(k_states, v_states, k_caches, v_caches, q_start_loc, q_seq_length, kv_seq_length,\n                      max_q_seq_length, block_offsets)\n\n        torch.testing.assert_close(k_caches, gt[0])\n        torch.testing.assert_close(v_caches, gt[1])\n\n\nclass TestFillKVCacheInt8(TestFillKVCache):\n\n    @pytest.fixture\n    def head_dim(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim):\n        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)\n        yield torch.full(shape, 0, dtype=torch.uint8).cuda()\n\n    @pytest.fixture\n    def v_caches(self, k_caches):\n        yield torch.full_like(k_caches.to(torch.float32), 0).to(torch.uint8)\n\n    @pytest.fixture\n    def k_scales_zeros(self, batch_size, max_num_blocks, block_size, num_heads):\n        shape = (batch_size * max_num_blocks, block_size, num_heads, 2)\n        yield torch.full(shape, 0.0).cuda()\n\n    @pytest.fixture\n    def v_scales_zeros(self, k_scales_zeros):\n        yield torch.zeros_like(k_scales_zeros)\n\n    @pytest.fixture\n    def 
nbits(self):\n        yield 8\n\n    @pytest.fixture\n    def gt(self, k_states, v_states, k_caches, v_caches, seq_lens, history_lens, block_offsets, block_size,\n           k_scales_zeros, v_scales_zeros, nbits):\n        k_states, k_states_sz = quant(k_states, nbits)\n        v_states, v_states_sz = quant(v_states, nbits)\n        batch_size = len(seq_lens)\n        k_caches = k_caches.clone()\n        v_caches = v_caches.clone()\n        splited_k_states = k_states.split(seq_lens)\n        splited_v_states = v_states.split(seq_lens)\n        splited_k_states_sz = k_states_sz.split(seq_lens)\n        splited_v_states_sz = v_states_sz.split(seq_lens)\n        for bidx in range(batch_size):\n            k_state = splited_k_states[bidx]\n            v_state = splited_v_states[bidx]\n            k_state_sz = splited_k_states_sz[bidx]\n            v_state_sz = splited_v_states_sz[bidx]\n            h_len = history_lens[bidx]\n            b_offs = block_offsets[bidx]\n            block_id = _div_up(h_len + 1, block_size) - 1\n            fill_start = h_len % block_size\n            fill_size = min(block_size - fill_start, k_state.size(0))\n            while True:\n                boff = b_offs[block_id]\n                tmp_ks = k_state[:fill_size]\n                tmp_vs = v_state[:fill_size]\n                tmp_ks_sz = k_state_sz[:fill_size]\n                tmp_vs_sz = v_state_sz[:fill_size]\n                fill_end = fill_start + fill_size\n                k_caches[boff, fill_start:fill_end] = tmp_ks\n                v_caches[boff, fill_start:fill_end] = tmp_vs\n                k_scales_zeros[boff, fill_start:fill_end] = tmp_ks_sz\n                v_scales_zeros[boff, fill_start:fill_end] = tmp_vs_sz\n                k_state = k_state[fill_size:]\n                v_state = v_state[fill_size:]\n                k_state_sz = k_state_sz[fill_size:]\n                v_state_sz = v_state_sz[fill_size:]\n                block_id += 1\n                fill_start = 0\n    
            fill_size = min(block_size, k_state.size(0))\n                if fill_size == 0:\n                    break\n\n        yield k_caches, v_caches, k_scales_zeros, v_scales_zeros\n\n    @pytest.mark.parametrize('head_dim', [128, 96], indirect=True)\n    @pytest.mark.parametrize(['seq_lens', 'history_lens'], [\n        ((1, 1, 1, 1), (1, 16, 31, 24)),\n        ((1, 8, 16, 24), (1, 16, 31, 24)),\n    ],\n                             indirect=True)\n    def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, k_scales_zeros, v_scales_zeros, block_offsets,\n                           q_start_loc, q_seq_length, kv_seq_length, max_q_seq_length, gt):\n        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache\n        fill_kv_cache(k_states, v_states, k_caches, v_caches, q_start_loc, q_seq_length, kv_seq_length,\n                      max_q_seq_length, block_offsets, k_scales_zeros, v_scales_zeros, 8)\n\n        torch.testing.assert_close(k_caches / 256, gt[0] / 256, atol=1e-2, rtol=1e-2)\n        torch.testing.assert_close(v_caches / 256, gt[1] / 256, atol=1e-2, rtol=1e-2)\n        torch.testing.assert_close(k_scales_zeros, gt[2])\n        torch.testing.assert_close(v_scales_zeros, gt[3])\n\n\nclass TestFillKVCacheInt4(TestFillKVCacheInt8):\n\n    @pytest.fixture\n    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim):\n        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim // 2)\n        yield torch.full(shape, 0, dtype=torch.uint8).cuda()\n\n    @pytest.fixture\n    def nbits(self):\n        yield 4\n\n    @pytest.mark.parametrize('head_dim', [128], indirect=True)\n    @pytest.mark.parametrize(['seq_lens', 'history_lens'], [\n        ((1, 1, 1, 1), (1, 16, 31, 24)),\n        ((1, 8, 16, 24), (1, 16, 31, 24)),\n    ],\n                             indirect=True)\n    def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, k_scales_zeros, v_scales_zeros, 
block_offsets,\n                           q_start_loc, q_seq_length, kv_seq_length, max_q_seq_length, gt, nbits):\n        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache\n        k_scales_zeros = torch.zeros_like(k_scales_zeros)\n        v_scales_zeros = torch.zeros_like(v_scales_zeros)\n        fill_kv_cache(k_states, v_states, k_caches, v_caches, q_start_loc, q_seq_length, kv_seq_length,\n                      max_q_seq_length, block_offsets, k_scales_zeros, v_scales_zeros, nbits)\n\n        torch.testing.assert_close(k_scales_zeros, gt[2])\n        torch.testing.assert_close(v_scales_zeros, gt[3])\n        torch.testing.assert_close(k_caches, gt[0])\n        torch.testing.assert_close(v_caches, gt[1])\n\n\n@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')\nclass TestFillKVCacheBlockedFP8(TestFillKVCache):\n\n    @pytest.fixture(autouse=True, scope='class')\n    def initialize(self):\n        seed = 42\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed(seed)\n        yield\n\n    @pytest.fixture\n    def scale_fmt(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def quant_dtype(self):\n        yield torch.float8_e4m3fn\n\n    @pytest.fixture\n    def num_heads(self):\n        yield 4\n\n    @pytest.fixture\n    def head_dim(self):\n        yield 128\n\n    @pytest.fixture\n    def block_size(self):\n        yield 64\n\n    @pytest.fixture\n    def group_size(self):\n        yield 128\n\n    @pytest.fixture\n    def cu_seqlen_q(self, q_start_loc, q_seq_length):\n        batch_size = q_start_loc.size(0)\n        cu_seqlen = torch.zeros(batch_size + 1, dtype=torch.int32).cuda()\n        cu_seqlen[1:] = q_start_loc + q_seq_length\n        return cu_seqlen\n\n    @pytest.fixture\n    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, quant_dtype):\n        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)\n  
      yield torch.full(shape, 0, dtype=quant_dtype).cuda()\n\n    @pytest.fixture\n    def v_caches(self, k_caches):\n        yield torch.zeros_like(k_caches)\n\n    @pytest.fixture\n    def ks_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, group_size):\n        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim // group_size)\n        yield torch.full(shape, 0.0).cuda()\n\n    @pytest.fixture\n    def vs_caches(self, ks_caches):\n        yield torch.ones_like(ks_caches)\n\n    @pytest.fixture\n    def gt(self, k_states, v_states, group_size, quant_dtype, scale_fmt):\n        from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8\n        batch_size = k_states.size(0)\n        num_heads = k_states.size(1)\n        head_dim = k_states.size(2)\n\n        k_states = k_states.flatten(0, -2)\n        v_states = v_states.flatten(0, -2)\n        quant_k, quant_ks = quant_fp8(k_states, group_size=group_size, dtype=quant_dtype, scale_fmt=scale_fmt)\n        quant_v, quant_vs = quant_fp8(v_states, group_size=group_size, dtype=quant_dtype, scale_fmt=scale_fmt)\n\n        quant_k = quant_k.view(batch_size, num_heads, head_dim)\n        quant_ks = quant_ks.view(batch_size, num_heads, head_dim // group_size)\n        quant_v = quant_v.view(batch_size, num_heads, head_dim)\n        quant_vs = quant_vs.view(batch_size, num_heads, head_dim // group_size)\n\n        yield quant_k, quant_ks, quant_v, quant_vs\n\n    def uncache(self, k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q, kv_seqlens, block_offsets):\n        batch_size = block_offsets.size(0)\n        out_k = []\n        out_ks = []\n        out_v = []\n        out_vs = []\n        q_seqlens = cu_seqlen_q[1:] - cu_seqlen_q[:-1]\n        for bidx in range(batch_size):\n            seqlen = q_seqlens[bidx].item()\n            kv_len = kv_seqlens[bidx].item()\n            start = kv_len - seqlen\n            end = kv_len\n            k = 
k_caches[block_offsets[bidx]].reshape(-1, k_caches.size(-2), k_caches.size(-1))\n            ks = ks_caches[block_offsets[bidx]].reshape(-1, ks_caches.size(-2), ks_caches.size(-1))\n            v = v_caches[block_offsets[bidx]].reshape(-1, v_caches.size(-2), v_caches.size(-1))\n            vs = vs_caches[block_offsets[bidx]].reshape(-1, vs_caches.size(-2), vs_caches.size(-1))\n            out_k.append(k[start:end])\n            out_ks.append(ks[start:end])\n            out_v.append(v[start:end])\n            out_vs.append(vs[start:end])\n        out_k = torch.cat(out_k, dim=0)\n        out_ks = torch.cat(out_ks, dim=0)\n        out_v = torch.cat(out_v, dim=0)\n        out_vs = torch.cat(out_vs, dim=0)\n        return out_k, out_ks, out_v, out_vs\n\n    @pytest.mark.parametrize('scale_fmt', [None, 'ue8m0'], indirect=True)\n    @pytest.mark.parametrize(['seq_lens', 'history_lens'], [\n        ((1, 1, 1, 1), (1, 128, 256, 200)),\n        ((1, 64, 128, 50), (1, 128, 256, 200)),\n    ],\n                             indirect=True)\n    def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, ks_caches, vs_caches, block_offsets,\n                           cu_seqlen_q, kv_seq_length, max_q_seq_length, gt, group_size, scale_fmt):\n        from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8\n        fill_kv_cache_blocked_fp8(k_states,\n                                  v_states,\n                                  k_caches,\n                                  v_caches,\n                                  ks_caches,\n                                  vs_caches,\n                                  cu_seqlen_q,\n                                  kv_seq_length,\n                                  max_q_seq_length,\n                                  block_offsets=block_offsets,\n                                  group_size=group_size,\n                                  scale_fmt=scale_fmt)\n\n        gt_k, gt_ks, gt_v, gt_vs = gt\n\n        
# uncache\n        out_k, out_ks, out_v, out_vs = self.uncache(k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q,\n                                                    kv_seq_length, block_offsets)\n\n        out_k = out_k.float()\n        out_k = out_k / out_k.max()\n        gt_k = gt_k.float()\n        gt_k = gt_k / gt_k.max()\n        out_v = out_v.float()\n        out_v = out_v / out_v.max()\n        gt_v = gt_v.float()\n        gt_v = gt_v / gt_v.max()\n        torch.testing.assert_close(out_k, gt_k)\n        torch.testing.assert_close(out_ks, gt_ks)\n        torch.testing.assert_close(out_v, gt_v)\n        torch.testing.assert_close(out_vs, gt_vs)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_flash_attention.py",
    "content": "import math\n\nimport pytest\nimport torch\n\n\ndef _conti_input(data, q_seqlens):\n    data = [x[:l] for x, l in zip(data, q_seqlens)]\n    data = torch.cat(data, dim=0)\n    return data\n\n\ndef _make_bias(q_seqlens, history_lens, neg_val, causal):\n    batch_size = q_seqlens.shape[0]\n    kv_seqlens = q_seqlens + history_lens\n    max_seq_len = q_seqlens.max().item()\n    max_kv_len = kv_seqlens.max().item()\n    if causal:\n        seq_ranges = torch.arange(max_seq_len).cuda()\n        seq_ranges = seq_ranges.repeat(batch_size, 1)\n        seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)\n\n        kv_ranges = torch.arange(max_kv_len).cuda()\n        kv_ranges = kv_ranges.repeat(batch_size, 1)\n\n        mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])\n        return mask.float() * neg_val\n    else:\n        q_mask = torch.arange(max_seq_len)[None].cuda() < q_seqlens[:, None]\n        k_mask = torch.arange(max_kv_len)[None].cuda() < kv_seqlens[:, None]\n        mask = q_mask[:, :, None] & k_mask[:, None, :]\n\n        return (~mask).float() * neg_val\n\n\ndef _make_bias_alibi(q_seqlens, history_lens, neg_val, causal, alibi_slopes):\n\n    batch_size = q_seqlens.shape[0]\n    kv_seqlens = q_seqlens + history_lens\n    max_q_len = q_seqlens.max().item()\n    max_kv_len = kv_seqlens.max().item()\n\n    device = 'cuda'\n    q_ranges = torch.arange(max_q_len, device=device)\n    seq_ranges = q_ranges.repeat(batch_size, 1) + history_lens[:, None]\n\n    kv_ranges = torch.arange(max_kv_len, device=device)\n    kv_ranges = kv_ranges.repeat(batch_size, 1)\n\n    diff = (seq_ranges[:, :, None] - kv_ranges[:, None, :]).abs()\n    slope_diff = -diff[:, None] * alibi_slopes[None, :, None, None]\n\n    # add bias\n    bias = _make_bias(q_seqlens, history_lens, neg_val, causal)\n    bias = bias[:, None] + slope_diff\n    return bias\n\n\ndef _make_block_sparse_bias(q_seqlens: 
torch.Tensor, history_lens: torch.Tensor, neg_val: float,\n                            block_sparse_size: int):\n    \"\"\"Make block sparse bias.\"\"\"\n    batch_size = q_seqlens.shape[0]\n    kv_seqlens = q_seqlens + history_lens\n    max_seq_len = q_seqlens.max().item()\n    max_kv_len = kv_seqlens.max().item()\n\n    seq_ranges = torch.arange(max_seq_len).cuda()\n    seq_ranges = seq_ranges // block_sparse_size * block_sparse_size\n    seq_ranges = seq_ranges.repeat(batch_size, 1)\n    seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)\n\n    kv_ranges = torch.arange(max_kv_len).cuda()\n    kv_ranges = kv_ranges // block_sparse_size * block_sparse_size\n    kv_ranges = kv_ranges.repeat(batch_size, 1)\n\n    mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])\n    return mask.float() * neg_val\n\n\ndef _naive_attention(batched_q, batched_kv, bias, sinks=None):\n    batched_k, batched_v = batched_kv\n\n    num_heads_q = batched_q.shape[2]\n    num_heads_k = batched_k.shape[2]\n    head_dim = batched_q.shape[-1]\n    group = num_heads_q // num_heads_k\n\n    q = batched_q.transpose(1, 2)\n    k = batched_k.permute(0, 2, 3, 1)\n    v = batched_v.transpose(1, 2)\n\n    # expand group\n    k = k.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2)\n    v = v.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2)\n\n    qk = torch.matmul(q, k) / math.sqrt(head_dim)\n    if bias.dim() == 3:\n        bias = bias[:, None]\n    attn_weight = qk + bias\n    if sinks is None:\n        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)\n    else:\n        sinks = sinks[None, :, None, None].to(torch.float32)\n        sinks = sinks.expand(attn_weight.shape[0], -1, attn_weight.shape[2], -1)\n        attn_weight = attn_weight.to(torch.float32)\n        combined_logits = torch.cat([attn_weight, sinks], dim=-1)\n        combined_logits = combined_logits - combined_logits.max(dim=-1, 
keepdim=True).values\n        attn_weight = torch.softmax(combined_logits, dim=-1, dtype=torch.float32)\n        attn_weight = attn_weight[..., :-1]\n    attn_weight = attn_weight.to(q.dtype)\n    attn_output = torch.matmul(attn_weight, v)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output\n\n\ndef _naive_window_attention(q, k, v, seqlens_q, seqlens_k, window_size):\n    try:\n        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func\n    except Exception:\n        try:\n            from flash_attn import flash_attn_varlen_func\n        except Exception:\n            pytest.skip('Skip window attention test since flash attention is not available.')\n\n    def _make_cu_seqlens(seqlens):\n        cu_seqlens = seqlens.cumsum(0)\n        cu_zero = cu_seqlens.new_zeros(1)\n        cu_seqlens = torch.cat([cu_zero, cu_seqlens])\n        return cu_seqlens\n\n    max_seqlen_q = seqlens_q.max().item()\n    max_seqlen_k = seqlens_k.max().item()\n    cu_seqlens_q = _make_cu_seqlens(seqlens_q).int()\n    cu_seqlens_k = _make_cu_seqlens(seqlens_k).int()\n\n    output = flash_attn_varlen_func(q,\n                                    k,\n                                    v,\n                                    cu_seqlens_q,\n                                    cu_seqlens_k,\n                                    max_seqlen_q=max_seqlen_q,\n                                    max_seqlen_k=max_seqlen_k,\n                                    causal=True,\n                                    window_size=window_size)\n    return output\n\n\nclass TestFlashAttention:\n\n    @pytest.fixture\n    def dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def head_dim_k(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def head_dim_v(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def num_heads_q(self, request):\n        yield request.param\n\n    
@pytest.fixture\n    def num_heads_k(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def causal(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def q_seqlens(self, request):\n        yield torch.tensor(request.param, device='cuda')\n\n    @pytest.fixture\n    def cu_seqlens_q(self, q_seqlens):\n        cu_seqlens = q_seqlens.cumsum(0)\n        cu_zero = cu_seqlens.new_zeros(1)\n        yield torch.cat([cu_zero, cu_seqlens]).int()\n\n    @pytest.fixture\n    def history_lens(self, request):\n        yield torch.tensor(request.param, device='cuda')\n\n    @pytest.fixture\n    def kv_seqlens(self, q_seqlens, history_lens):\n        yield q_seqlens + history_lens\n\n    @pytest.fixture\n    def cu_seqlens_k(self, kv_seqlens):\n        cu_seqlens = kv_seqlens.cumsum(0)\n        cu_zero = cu_seqlens.new_zeros(1)\n        yield torch.cat([cu_zero, cu_seqlens]).int()\n\n    @pytest.fixture\n    def batched_q(self, q_seqlens, num_heads_q, head_dim_k, dtype):\n        torch.manual_seed(123)\n        batch_size = len(q_seqlens)\n        max_seq_len = q_seqlens.max().item()\n        inputs = torch.rand(batch_size, max_seq_len, num_heads_q, head_dim_k, dtype=dtype, device='cuda')\n        yield inputs\n\n    @pytest.fixture\n    def batched_kv(self, q_seqlens, history_lens, num_heads_k, head_dim_k, head_dim_v, dtype):\n        torch.manual_seed(123)\n        batch_size = len(q_seqlens)\n        kv_seqlens = q_seqlens + history_lens\n        max_seq_len = kv_seqlens.max().item()\n        k = torch.rand(batch_size, max_seq_len, num_heads_k, head_dim_k, dtype=dtype, device='cuda')\n        v = torch.rand(batch_size, max_seq_len, num_heads_k, head_dim_v, dtype=dtype, device='cuda')\n        yield k, v\n\n    @pytest.fixture\n    def conti_q(self, q_seqlens, batched_q):\n        yield _conti_input(batched_q, q_seqlens)\n\n    @pytest.fixture\n    def conti_kv(self, kv_seqlens, batched_kv):\n        conti_k = 
_conti_input(batched_kv[0], kv_seqlens)\n        conti_k = conti_k.transpose(0, 1).contiguous()\n        conti_v = _conti_input(batched_kv[1], kv_seqlens)\n        conti_v = conti_v.transpose(0, 1).contiguous()\n        yield (conti_k, conti_v)\n\n    @pytest.fixture\n    def mask(self, q_seqlens, history_lens, causal):\n        neg_val = -1e30\n        yield _make_bias(q_seqlens, history_lens, neg_val, causal)\n\n    @pytest.fixture\n    def gt(self, batched_q, batched_kv, mask):\n        yield _naive_attention(batched_q, batched_kv, mask)\n\n    @pytest.fixture\n    def conti_gt(self, gt, q_seqlens):\n        yield _conti_input(gt, q_seqlens)\n\n    @pytest.mark.parametrize('head_dim_k', [32, 48], indirect=True)\n    @pytest.mark.parametrize('head_dim_v', [32], indirect=True)\n    @pytest.mark.parametrize('num_heads_q', [8, 2], indirect=True)\n    @pytest.mark.parametrize('num_heads_k', [2], indirect=True)\n    @pytest.mark.parametrize('causal', [True, False], indirect=True)\n    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True)\n    def test_flash_attention(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, causal, conti_gt):\n        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func\n        max_seq_len = q_seqlens.max().item()\n\n        conti_k, conti_v = conti_kv\n        out = flash_attn_varlen_func(conti_q,\n                                     conti_k,\n                                     conti_v,\n                                     cu_seqlens_q,\n                                     cu_seqlens_k,\n                                     max_seqlen_q=max_seq_len,\n                                     causal=causal)\n        torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)\n\n    @pytest.fixture\n    def win_size(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def window_gt(self, conti_q, conti_kv, q_seqlens, 
kv_seqlens, win_size):\n        conti_k, conti_v = conti_kv\n        yield _naive_window_attention(conti_q,\n                                      conti_k.transpose(0, 1),\n                                      conti_v.transpose(0, 1),\n                                      q_seqlens,\n                                      kv_seqlens,\n                                      window_size=(win_size, win_size))\n\n    @pytest.mark.parametrize('head_dim_k', [16], indirect=True)\n    @pytest.mark.parametrize('head_dim_v', [16], indirect=True)\n    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], indirect=True)\n    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [\n        ([30, 50, 70, 90], [50, 40, 30, 90]),\n    ], indirect=True)\n    @pytest.mark.parametrize('win_size', (32, ), indirect=True)\n    def test_window_attention(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, win_size, window_gt):\n        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func\n        max_seq_len = q_seqlens.max().item()\n\n        conti_k, conti_v = conti_kv\n        out = flash_attn_varlen_func(conti_q,\n                                     conti_k,\n                                     conti_v,\n                                     cu_seqlens_q,\n                                     cu_seqlens_k,\n                                     max_seqlen_q=max_seq_len,\n                                     window_size=win_size,\n                                     causal=True)\n        torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5)\n\n    @pytest.fixture\n    def sinks(self, num_heads_q, dtype):\n        yield torch.rand(num_heads_q, dtype=dtype, device='cuda')\n\n    @pytest.fixture\n    def sink_gt(self, batched_q, batched_kv, mask, sinks):\n        yield _naive_attention(batched_q, batched_kv, mask, sinks)\n\n    @pytest.fixture\n    def conti_sink_gt(self, sink_gt, q_seqlens):\n        yield 
_conti_input(sink_gt, q_seqlens)\n\n    @pytest.mark.parametrize('head_dim_k', [32], indirect=True)\n    @pytest.mark.parametrize('head_dim_v', [32], indirect=True)\n    @pytest.mark.parametrize('num_heads_q', [8], indirect=True)\n    @pytest.mark.parametrize('num_heads_k', [2], indirect=True)\n    @pytest.mark.parametrize('causal', [True], indirect=True)\n    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True)\n    def test_sinks(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, causal, sinks, conti_sink_gt):\n        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func\n        max_seq_len = q_seqlens.max().item()\n\n        conti_k, conti_v = conti_kv\n        out = flash_attn_varlen_func(conti_q,\n                                     conti_k,\n                                     conti_v,\n                                     cu_seqlens_q,\n                                     cu_seqlens_k,\n                                     max_seqlen_q=max_seq_len,\n                                     sinks=sinks,\n                                     causal=causal)\n        torch.testing.assert_close(out, conti_sink_gt, atol=1e-3, rtol=1e-5)\n\n    # block sparse attention\n    @pytest.fixture\n    def block_sparse_size(self):\n        yield 4\n\n    @pytest.fixture\n    def block_sparse_mask(self, q_seqlens, history_lens, block_sparse_size):\n        neg_val = -1e30\n        yield _make_block_sparse_bias(q_seqlens, history_lens, neg_val, block_sparse_size)\n\n    @pytest.fixture\n    def block_sparse_gt(self, batched_q, batched_kv, block_sparse_mask):\n        yield _naive_attention(batched_q, batched_kv, block_sparse_mask)\n\n    @pytest.mark.parametrize('head_dim_k', [32], indirect=True)\n    @pytest.mark.parametrize('head_dim_v', [32], indirect=True)\n    @pytest.mark.parametrize('num_heads_q', [8], indirect=True)\n    @pytest.mark.parametrize('num_heads_k', [2], 
indirect=True)\n    @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([16, 32], [64, 8])], indirect=True)\n    def test_block_sparse_attention(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, block_sparse_size,\n                                    block_sparse_gt):\n        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func\n        max_seq_len = q_seqlens.max().item()\n\n        conti_k, conti_v = conti_kv\n        out = flash_attn_varlen_func(conti_q,\n                                     conti_k,\n                                     conti_v,\n                                     cu_seqlens_q,\n                                     cu_seqlens_k,\n                                     max_seqlen_q=max_seq_len,\n                                     block_sparse_size=block_sparse_size,\n                                     causal=True)\n        gt = _conti_input(block_sparse_gt, q_seqlens)\n        torch.testing.assert_close(out, gt, atol=1e-3, rtol=1e-5)\n\n    @pytest.fixture\n    def alibi_slopes(self, num_heads_q):\n        yield torch.rand(num_heads_q, dtype=torch.float32, device='cuda')\n\n    @pytest.fixture\n    def alibi_bias(self, q_seqlens, history_lens, causal, alibi_slopes):\n        neg_val = -1e30\n        yield _make_bias_alibi(q_seqlens, history_lens, neg_val, causal, alibi_slopes)\n\n    @pytest.fixture\n    def alibi_gt(self, batched_q, batched_kv, alibi_bias):\n        yield _naive_attention(batched_q, batched_kv, alibi_bias)\n\n    @pytest.fixture\n    def conti_alibi_gt(self, alibi_gt, q_seqlens):\n        yield _conti_input(alibi_gt, q_seqlens)\n\n    @pytest.mark.parametrize('head_dim_k', [128], indirect=True)\n    @pytest.mark.parametrize('head_dim_v', [128], indirect=True)\n    @pytest.mark.parametrize('num_heads_q', [40], indirect=True)\n    @pytest.mark.parametrize('num_heads_k', [8], indirect=True)\n    @pytest.mark.parametrize('causal', [True], indirect=True)\n    
@pytest.mark.parametrize(['q_seqlens', 'history_lens'], [\n        ([30, 50, 70, 90], [50, 40, 30, 20]),\n    ], indirect=True)\n    def test_alibi(self, conti_q, conti_kv, q_seqlens, cu_seqlens_q, cu_seqlens_k, causal, alibi_slopes,\n                   conti_alibi_gt):\n        from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attn_varlen_func\n        max_seq_len = q_seqlens.max().item()\n\n        conti_k, conti_v = conti_kv\n        out = flash_attn_varlen_func(conti_q,\n                                     conti_k,\n                                     conti_v,\n                                     cu_seqlens_q,\n                                     cu_seqlens_k,\n                                     max_seqlen_q=max_seq_len,\n                                     alibi_slopes=alibi_slopes,\n                                     causal=causal)\n        torch.testing.assert_close(out, conti_alibi_gt, atol=1e-3, rtol=1e-5)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_flatten_kv_cache.py",
    "content": "import pytest\nimport torch\n\n\ndef _div_up(a, b):\n    return (a + b - 1) // b\n\n\nclass TestFlattenKVCache:\n\n    @pytest.fixture\n    def out_dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def num_heads(self):\n        yield 4\n\n    @pytest.fixture\n    def head_dim(self):\n        yield 32\n\n    @pytest.fixture\n    def block_size(self):\n        yield 16\n\n    @pytest.fixture\n    def kv_lens(self):\n        yield [2, 24, 47, 48]\n\n    @pytest.fixture\n    def batch_size(self, kv_lens):\n        yield len(kv_lens)\n\n    @pytest.fixture\n    def num_blocks_per_input(self, kv_lens, block_size):\n        yield [_div_up(kv_len, block_size) for kv_len in kv_lens]\n\n    @pytest.fixture\n    def max_num_blocks(self, num_blocks_per_input):\n        yield max(num_blocks_per_input)\n\n    @pytest.fixture\n    def out_size(self, kv_lens):\n        yield sum(kv_lens)\n\n    @pytest.fixture\n    def kv_seqlens(self, kv_lens):\n        yield torch.tensor(kv_lens).cuda()\n\n    @pytest.fixture\n    def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, out_dtype):\n        shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)\n        yield torch.rand(shape, dtype=out_dtype, device='cuda')\n\n    @pytest.fixture\n    def v_caches(self, k_caches):\n        yield torch.rand_like(k_caches)\n\n    @pytest.fixture\n    def block_offsets(self, num_blocks_per_input):\n        batch_size = len(num_blocks_per_input)\n        max_num_blocks = max(num_blocks_per_input)\n        batch_ids = torch.arange(batch_size)\n        ret = torch.arange(max_num_blocks)\n        ret = batch_ids[:, None] + ret[None, :] * batch_size\n        yield ret.cuda()\n\n    @pytest.fixture\n    def gt(self, k_caches, v_caches, kv_lens, block_offsets, block_size, num_heads, out_size, head_dim):\n        k_states = k_caches.new_empty(num_heads, out_size, head_dim)\n        v_states = v_caches.new_empty(num_heads, out_size, 
head_dim)\n        start_loc = 0\n        for kv_len, block_offs in zip(kv_lens, block_offsets):\n            remain_len = kv_len\n            for idx, _ in enumerate(range(0, kv_len, block_size)):\n                b_off = block_offs[idx]\n                block_len = min(block_size, remain_len)\n                end_loc = start_loc + block_len\n                k_block = k_caches[b_off, :block_len]\n                v_block = v_caches[b_off, :block_len]\n                k_states[:, start_loc:end_loc] = k_block.transpose(0, 1)\n                v_states[:, start_loc:end_loc] = v_block.transpose(0, 1)\n                start_loc = end_loc\n                remain_len -= block_len\n\n        yield k_states, v_states\n\n    def test_flatten_kv_cache(self, k_caches, v_caches, kv_seqlens, block_offsets, out_size, gt):\n        from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache\n\n        k_states, v_states = flatten_kv_cache(k_caches, v_caches, kv_seqlens, block_offsets, out_size=out_size)\n        torch.testing.assert_close(k_states, gt[0])\n        torch.testing.assert_close(v_states, gt[1])\n\n\ndef precise_round(x: torch.Tensor):\n    return x.sign() * (x.abs() + 0.5).floor()\n\n\ndef quant(kv: torch.Tensor, nbits: int = 8):\n    \"\"\"Quant kv on the head_dim.\"\"\"\n    amax = kv.amax(dim=-1, keepdim=True)\n    amin = kv.amin(dim=-1, keepdim=True)\n    scales = (amax - amin) / (2**nbits - 1)\n    zeros = -amin / scales\n    q_kv = (kv / scales + zeros + 0.5).to(torch.uint8)\n    if nbits == 4:\n        q_kv1, q_kv2 = q_kv.split(q_kv.shape[-1] // 2, -1)\n        q_kv = q_kv1 + q_kv2 * 16\n    return q_kv, torch.cat([scales, zeros], dim=-1)\n\n\nclass TestFlattenKVCacheQuant8(TestFlattenKVCache):\n\n    @pytest.fixture\n    def nbits(self):\n        yield 8\n\n    @pytest.fixture\n    def atol(self):\n        yield 4e-3\n\n    @pytest.fixture\n    def rtol(self):\n        yield 1e-5\n\n    @pytest.fixture\n    def k_quant(self, k_caches, nbits):\n  
      yield quant(k_caches, nbits)\n\n    @pytest.fixture\n    def v_quant(self, v_caches, nbits):\n        yield quant(v_caches, nbits)\n\n    def test_flatten_kv_cache(self, k_quant, v_quant, kv_seqlens, block_offsets, out_size, out_dtype, nbits, gt, atol,\n                              rtol):\n        from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache\n\n        k_caches, k_sz = k_quant\n        v_caches, v_sz = v_quant\n\n        k_sz = k_sz.to(out_dtype)\n        v_sz = v_sz.to(out_dtype)\n\n        k_states, v_states = flatten_kv_cache(k_caches,\n                                              v_caches,\n                                              kv_seqlens,\n                                              block_offsets,\n                                              out_size=out_size,\n                                              out_dtype=out_dtype,\n                                              k_scales_zeros=k_sz,\n                                              v_scales_zeros=v_sz,\n                                              quant_policy=nbits)\n\n        torch.testing.assert_close(k_states, gt[0], atol=atol, rtol=rtol)\n        torch.testing.assert_close(v_states, gt[1], atol=atol, rtol=rtol)\n\n\nclass TestFlattenKVCacheQuant4(TestFlattenKVCacheQuant8):\n\n    @pytest.fixture\n    def nbits(self):\n        yield 4\n\n    @pytest.fixture\n    def atol(self):\n        yield 0.05\n\n    @pytest.fixture\n    def rtol(self):\n        yield 1e-3\n\n\n@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')\nclass TestFlattenKVCacheMLAFP8(TestFlattenKVCache):\n\n    @pytest.fixture\n    def out_dtype(self):\n        yield torch.bfloat16\n\n    @pytest.fixture\n    def num_heads(self):\n        yield 1\n\n    @pytest.fixture\n    def head_dim(self):\n        yield 576\n\n    @pytest.fixture\n    def block_size(self):\n        yield 64\n\n    @pytest.fixture\n    def k_cache_mla(self, 
k_caches):\n        from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8\n        num_blocks, block_size, num_heads, _ = k_caches.shape\n        k_cache_pe = k_caches[:, :, :, 512:]\n        k_cache_nope = k_caches[:, :, :, :512].flatten(0, -2)\n        k_cache_nope, k_cache_scale = quant_fp8(k_cache_nope, group_size=128)\n        k_cache_nope = k_cache_nope.view(num_blocks, block_size, num_heads, -1)\n        k_cache_scale = k_cache_scale.reshape(num_blocks, block_size, num_heads, -1).to(torch.float32)\n        dtype = k_cache_nope.dtype\n        out = torch.cat([k_cache_nope, k_cache_scale.view(dtype), k_cache_pe.view(dtype)], dim=-1)\n        yield out\n\n    def _dequant(self, k_cache_mla):\n        k_cache_nope = k_cache_mla[..., :512].to(torch.float32)\n        k_cache_scale = k_cache_mla[..., 512:512 + 16].view(torch.float32)\n        k_cache_pe = k_cache_mla[..., 512 + 16:].view(torch.bfloat16)\n        k_cache_nope = k_cache_nope.unflatten(-1, (-1, 128))\n        k_cache_scale = k_cache_scale[..., None]\n        k_cache_nope *= k_cache_scale\n        k_cache_nope = k_cache_nope.flatten(-2, -1).to(k_cache_pe.dtype)\n        k_cache = torch.cat([k_cache_nope, k_cache_pe], dim=-1)\n        return k_cache\n\n    @pytest.fixture\n    def gt(self, k_cache_mla, kv_lens, block_offsets, block_size, num_heads, out_size, head_dim):\n        k_caches = self._dequant(k_cache_mla)\n        k_states = k_caches.new_empty(num_heads, out_size, head_dim)\n        start_loc = 0\n        for kv_len, block_offs in zip(kv_lens, block_offsets):\n            remain_len = kv_len\n            for idx, _ in enumerate(range(0, kv_len, block_size)):\n                b_off = block_offs[idx]\n                block_len = min(block_size, remain_len)\n                end_loc = start_loc + block_len\n                k_block = k_caches[b_off, :block_len]\n                k_states[:, start_loc:end_loc] = k_block.transpose(0, 1)\n                start_loc = end_loc\n             
   remain_len -= block_len\n\n        yield k_states\n\n    def test_flatten_kv_cache(self, k_cache_mla, kv_seqlens, block_offsets, out_size, out_dtype, gt):\n        from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache_mla_fp8\n\n        k_states = flatten_kv_cache_mla_fp8(k_cache_mla,\n                                            kv_seqlens,\n                                            block_offsets,\n                                            out_size=out_size,\n                                            out_dtype=out_dtype)\n        torch.testing.assert_close(k_states, gt)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py",
    "content": "import pytest\nimport torch\n\n\ndef _make_A(M, K, group_size, out_dtype, device='cuda'):\n    quant_A = torch.rand(M, K // group_size, group_size, dtype=torch.float32, device=device)\n    # -1 ~ 1\n    quant_A = quant_A * 2 - 1\n    # scaling abs max to fmax\n    finfo = torch.finfo(out_dtype)\n    fmax = finfo.max\n    scaling = fmax / quant_A.abs().amax(-1, keepdim=True)\n    quant_A *= scaling\n    quant_A = quant_A.to(out_dtype).to(torch.float32)\n\n    # create scale and A\n    scale = torch.rand(M, K // group_size, dtype=torch.float32, device=device)\n    scale /= fmax\n    A = quant_A * scale[..., None]\n\n    A = A.reshape(M, K)\n    quant_A = quant_A.reshape(M, K).to(out_dtype)\n    return A, quant_A, scale\n\n\ndef _make_B(E, K, N, group_size, out_dtype, device='cuda'):\n    quant_B = torch.rand(E,\n                         N // group_size,\n                         group_size,\n                         K // group_size,\n                         group_size,\n                         dtype=torch.float32,\n                         device=device)\n    quant_B = quant_B * 2 - 1\n\n    # scaling abs max to fmax\n    finfo = torch.finfo(out_dtype)\n    fmax = finfo.max\n    scaling = fmax / quant_B.abs().amax((2, 4), keepdim=True)\n    quant_B *= scaling\n    quant_B = quant_B.to(out_dtype).to(torch.float32)\n\n    scale = torch.rand(E, N // group_size, 1, K // group_size, 1, dtype=torch.float32, device=device)\n    scale /= fmax\n\n    B = quant_B * scale\n\n    B = B.reshape(E, N, K)\n    quant_B = quant_B.reshape(E, N, K).to(out_dtype)\n    scale = scale.reshape(E, N // group_size, K // group_size)\n    bias = torch.rand(E, N, dtype=torch.float32, device=device) - 0.5\n    return B, quant_B, scale, bias\n\n\ndef _get_sorted_idx(topk_idx: torch.Tensor, num_experts: int):\n    flatten_topk_idx = topk_idx.flatten()\n    sorted_ids = flatten_topk_idx.argsort()\n    exp_range = torch.arange(0, num_experts, device=topk_idx.device)\n    exp_tok_cnt 
= (flatten_topk_idx[None, :] == exp_range[:, None]).sum(1)\n    return sorted_ids, exp_tok_cnt\n\n\n@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')\nclass TestFusedMoEFP8KernelLauncher:\n\n    @pytest.fixture\n    def dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def quant_dtype(self):\n        yield torch.float8_e4m3fn\n\n    @pytest.fixture\n    def device(self):\n        yield torch.device('cuda')\n\n    @pytest.fixture\n    def N(self):\n        yield 512\n\n    @pytest.fixture\n    def K(self):\n        yield 1024\n\n    @pytest.fixture\n    def M(self):\n        yield 256\n\n    @pytest.fixture\n    def num_experts(self):\n        yield 64\n\n    @pytest.fixture\n    def top_k(self):\n        yield 6\n\n    @pytest.fixture\n    def group_size(self):\n        yield 128\n\n    @pytest.fixture\n    def build_A(self, M, K, group_size, quant_dtype, device):\n        yield _make_A(M, K, group_size=group_size, out_dtype=quant_dtype, device=device)\n\n    @pytest.fixture\n    def A(self, build_A, dtype):\n        yield build_A[0].to(dtype)\n\n    @pytest.fixture\n    def A_quant(self, build_A):\n        yield build_A[1]\n\n    @pytest.fixture\n    def A_scale(self, build_A):\n        yield build_A[2]\n\n    @pytest.fixture\n    def build_B(self, num_experts, N, K, group_size, quant_dtype, device):\n        yield _make_B(num_experts, K, N, group_size=group_size, out_dtype=quant_dtype, device=device)\n\n    @pytest.fixture\n    def B(self, build_B, dtype):\n        yield build_B[0].to(dtype)\n\n    @pytest.fixture\n    def B_quant(self, build_B):\n        yield build_B[1]\n\n    @pytest.fixture\n    def B_scale(self, build_B):\n        yield build_B[2]\n\n    @pytest.fixture\n    def bias(self, build_B, dtype):\n        yield build_B[3].to(dtype)\n        # yield None\n\n    @pytest.fixture\n    def router_weights(self, M, num_experts, device, dtype):\n        yield torch.rand(M, num_experts, 
device=device, dtype=dtype)\n\n    @pytest.fixture\n    def topk_weights(self, router_weights, top_k):\n        yield router_weights.topk(top_k, dim=-1)\n\n    @pytest.fixture\n    def topk_idx(self, topk_weights):\n        yield topk_weights[1]\n\n    @pytest.fixture\n    def sort_and_cnt(self, topk_idx, num_experts):\n        yield _get_sorted_idx(topk_idx, num_experts)\n\n    @pytest.fixture\n    def sorted_idx(self, sort_and_cnt):\n        yield sort_and_cnt[0]\n\n    @pytest.fixture\n    def exp_tok_cnt(self, sort_and_cnt):\n        yield sort_and_cnt[1]\n\n    @pytest.fixture\n    def exp_end(self, exp_tok_cnt):\n        yield exp_tok_cnt.cumsum(0)\n\n    @pytest.fixture\n    def exp_start(self, exp_end, exp_tok_cnt):\n        yield exp_end - exp_tok_cnt\n\n    @pytest.fixture\n    def gt(self, A, B, bias, top_k, sorted_idx, exp_start, exp_end, M):\n        from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe_kernel_launcher\n        N = B.size(1)\n        C = B.new_empty(M * top_k, N)\n        fused_moe_kernel_launcher(\n            A,\n            B,\n            C,\n            sorted_idx,\n            exp_start,\n            exp_end,\n            bias=bias,\n            top_k=top_k,\n            num_tokens=M,\n        )\n\n        yield C\n\n    @torch.inference_mode()\n    def test_launcher(self, A_quant, A_scale, B, B_quant, B_scale, bias, sorted_idx, exp_start, exp_end, top_k, M, gt):\n        from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8_kernel_launcher\n        N = B.size(1)\n        C = B.new_empty(M * top_k, N)\n        fused_moe_blocked_fp8_kernel_launcher(\n            A=A_quant,\n            A_scale=A_scale,\n            B=B_quant,\n            B_scale=B_scale,\n            C=C,\n            sorted_idx=sorted_idx,\n            exp_start=exp_start,\n            exp_end=exp_end,\n            bias=bias,\n            top_k=top_k,\n            num_tokens=M,\n        )\n\n        gt_max = 
gt.abs().max()\n        C = C / gt_max\n        gt = gt / gt_max\n        torch.testing.assert_close(C, gt, atol=4e-3, rtol=1e-3)\n\n\n@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')\nclass TestFusedMoeBlockedFP8:\n\n    @pytest.fixture\n    def dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def quant_dtype(self):\n        yield torch.float8_e4m3fn\n\n    @pytest.fixture\n    def device(self):\n        yield torch.device('cuda')\n\n    @pytest.fixture\n    def in_size(self):\n        yield 512\n\n    @pytest.fixture\n    def seq_len(self):\n        yield 128\n\n    @pytest.fixture\n    def hidden_size(self):\n        yield 2048\n\n    @pytest.fixture\n    def out_size(self):\n        yield 1024\n\n    @pytest.fixture\n    def num_experts(self):\n        yield 4\n\n    @pytest.fixture\n    def top_k(self):\n        yield 2\n\n    @pytest.fixture\n    def group_size(self):\n        yield 128\n\n    @pytest.fixture\n    def renormalize(self):\n        yield True\n\n    @pytest.fixture\n    def build_hidden_states(self, seq_len, in_size, group_size, quant_dtype, device):\n        yield _make_A(seq_len, in_size, group_size=group_size, out_dtype=quant_dtype, device=device)\n\n    @pytest.fixture\n    def hidden_states(self, build_hidden_states, dtype):\n        yield build_hidden_states[0].to(dtype)\n\n    @pytest.fixture\n    def states_quanted(self, build_hidden_states):\n        yield build_hidden_states[1]\n\n    @pytest.fixture\n    def states_scale(self, build_hidden_states):\n        yield build_hidden_states[2]\n\n    @pytest.fixture\n    def build_w1(self, num_experts, hidden_size, in_size, group_size, quant_dtype, device):\n        yield _make_B(num_experts, in_size, hidden_size, group_size=group_size, out_dtype=quant_dtype, device=device)\n\n    @pytest.fixture\n    def w1(self, build_w1, dtype):\n        yield build_w1[0].to(dtype)\n\n    @pytest.fixture\n    def w1_quant(self, 
build_w1):\n        yield build_w1[1]\n\n    @pytest.fixture\n    def w1_scale(self, build_w1):\n        yield build_w1[2]\n\n    @pytest.fixture\n    def build_w2(self, num_experts, out_size, hidden_size, group_size, quant_dtype, device):\n        yield _make_B(num_experts,\n                      hidden_size // 2,\n                      out_size,\n                      group_size=group_size,\n                      out_dtype=quant_dtype,\n                      device=device)\n\n    @pytest.fixture\n    def w2(self, build_w2, dtype):\n        yield build_w2[0].to(dtype)\n\n    @pytest.fixture\n    def w2_quant(self, build_w2):\n        yield build_w2[1]\n\n    @pytest.fixture\n    def w2_scale(self, build_w2):\n        yield build_w2[2]\n\n    @pytest.fixture\n    def router_logits(self, seq_len, num_experts, dtype, device):\n        yield torch.rand(seq_len, num_experts, dtype=dtype, device=device)\n\n    @pytest.fixture\n    def topk_logits(self, router_logits, top_k):\n        routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)\n        yield torch.topk(routing_weights, top_k, dim=-1)\n\n    @pytest.fixture\n    def topk_weights(self, topk_logits):\n        yield topk_logits[0]\n\n    @pytest.fixture\n    def topk_idx(self, topk_logits):\n        yield topk_logits[1]\n\n    @pytest.fixture\n    def gt(self, hidden_states, w1, w2, topk_weights, topk_idx, top_k, renormalize):\n        from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe\n        output = fused_moe(hidden_states, w1, w2, topk_weights, topk_idx, topk=top_k, renormalize=renormalize)\n        yield output\n\n    @torch.inference_mode()\n    def test_fused_moe(self, states_quanted, states_scale, w1_quant, w1_scale, w2_quant, w2_scale, topk_weights,\n                       topk_idx, top_k, renormalize, gt):\n        from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8\n        output = fused_moe_blocked_fp8(states_quanted,\n                
                       states_scale,\n                                       w1_quant,\n                                       w1_scale,\n                                       w2_quant,\n                                       w2_scale,\n                                       topk_weights,\n                                       topk_idx,\n                                       topk=top_k,\n                                       renormalize=renormalize)\n        out_max = output.abs().max()\n        gt_max = gt.abs().max()\n        assert (out_max - gt_max).abs() / out_max < 0.05\n\n        norm_out = output / out_max\n        norm_gt = gt / gt_max\n        torch.testing.assert_close(norm_out, norm_gt, atol=0.05, rtol=1e-3)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_fused_lora.py",
    "content": "import pytest\nimport torch\n\nfrom lmdeploy.pytorch.kernels.cuda.fused_lora import fused_lora\n\n\nclass TestFusedLoRA:\n\n    @pytest.fixture\n    def dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def head_size(self):\n        yield 32\n\n    @pytest.fixture\n    def out_head_size(self):\n        yield 16\n\n    @pytest.fixture\n    def seq_lens(self, request):\n        yield torch.tensor(request.param).cuda()\n\n    @pytest.fixture\n    def ranks(self):\n        yield torch.tensor([2, 4]).cuda()\n\n    @pytest.fixture\n    def start_loc(self, seq_lens):\n        yield seq_lens.cumsum(0) - seq_lens\n\n    @pytest.fixture\n    def input(self, seq_lens, head_size, dtype):\n        total_len = seq_lens.sum()\n        yield torch.rand(total_len, head_size, dtype=dtype).cuda()\n\n    @pytest.fixture\n    def adapter_ids(self, seq_lens, ranks):\n        num_ranks = len(ranks)\n        num_seqs = len(seq_lens)\n        ret = torch.arange(0, num_seqs) % num_ranks\n        ret = ret.cuda()\n        yield ret\n\n    @pytest.fixture\n    def scaling(self, ranks):\n        yield torch.arange(ranks.size(0)).cuda() + 1\n\n    @pytest.fixture\n    def lora_a(self, ranks, head_size, dtype):\n        out = []\n        for rank in ranks:\n            w = torch.rand(head_size, rank, dtype=dtype).cuda()\n            out.append(w)\n        yield out\n\n    @pytest.fixture\n    def lora_b(self, ranks, out_head_size, dtype):\n        out = []\n        for rank in ranks:\n            w = torch.rand(rank, out_head_size, dtype=dtype).cuda()\n            out.append(w)\n        yield out\n\n    @pytest.fixture\n    def fused_lora_a(self, lora_a):\n        yield torch.cat(lora_a, dim=1).t().contiguous()\n\n    @pytest.fixture\n    def fused_lora_b(self, lora_b):\n        yield torch.cat(lora_b, dim=0).contiguous()\n\n    @pytest.fixture\n    def gt(self, input, start_loc, seq_lens, adapter_ids, lora_a, lora_b, scaling):\n        out = []\n        for 
loc, s_len, r_id in zip(start_loc, seq_lens, adapter_ids):\n            inp = input[loc:loc + s_len]\n            l_a = lora_a[r_id]\n            l_b = lora_b[r_id]\n            s = scaling[r_id]\n            out.append(inp @ l_a @ l_b * s)\n\n        yield torch.cat(out)\n\n    @pytest.mark.parametrize('seq_lens', [\n        (2, 4, 6, 8),\n        (1, 1, 1, 1),\n    ], indirect=True)\n    def test_fused_lora(self, input, fused_lora_a, fused_lora_b, start_loc, seq_lens, adapter_ids, scaling, ranks, gt):\n        max_seq_len = max(seq_lens).item()\n        max_rank = max(ranks).item()\n        rank_offset = ranks.cumsum(0) - ranks\n\n        output = fused_lora(\n            input,\n            fused_lora_a,\n            fused_lora_b,\n            scaling=scaling,\n            rank_start=rank_offset,\n            ranks=ranks,\n            seq_start=start_loc,\n            seq_lens=seq_lens,\n            adapter_ids=adapter_ids,\n            max_rank=max_rank,\n            max_seqlen=max_seq_len,\n        )\n\n        torch.testing.assert_close(gt, output)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_fused_moe.py",
    "content": "import pytest\nimport torch\nimport torch.nn.functional as F\n\n\ndef _get_sorted_idx(topk_idx: torch.Tensor, num_experts: int):\n    flatten_topk_idx = topk_idx.flatten()\n    sorted_ids = flatten_topk_idx.argsort()\n    exp_range = torch.arange(0, num_experts, device=topk_idx.device)\n    exp_tok_cnt = (flatten_topk_idx[None, :] == exp_range[:, None]).sum(1)\n    return sorted_ids, exp_tok_cnt\n\n\nclass TestFusedMoEKernelLauncher:\n\n    @pytest.fixture\n    def dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def device(self):\n        yield torch.device('cuda')\n\n    @pytest.fixture\n    def N(self):\n        yield 128\n\n    @pytest.fixture\n    def K(self):\n        yield 64\n\n    @pytest.fixture\n    def M(self):\n        yield 256\n\n    @pytest.fixture\n    def num_experts(self):\n        yield 64\n\n    @pytest.fixture\n    def top_k(self):\n        yield 6\n\n    @pytest.fixture\n    def A(self, M, K, device, dtype):\n        ret = torch.rand(M, K, device=device, dtype=dtype)\n        yield (ret - 0.5) / 2\n\n    @pytest.fixture\n    def B(self, num_experts, N, K, device, dtype):\n        ret = torch.rand(num_experts, N, K, device=device, dtype=dtype)\n        yield (ret - 0.5) / 2\n\n    @pytest.fixture\n    def bias(self, num_experts, N, device, dtype):\n        yield torch.rand(num_experts, N, device=device, dtype=dtype) - 0.5\n\n    @pytest.fixture\n    def router_weights(self, M, num_experts, device, dtype):\n        yield torch.rand(M, num_experts, device=device, dtype=dtype)\n\n    @pytest.fixture\n    def topk_weights(self, router_weights, top_k):\n        yield router_weights.topk(top_k, dim=-1)\n\n    @pytest.fixture\n    def topk_idx(self, topk_weights):\n        yield topk_weights[1]\n\n    @pytest.fixture\n    def sort_and_cnt(self, topk_idx, num_experts):\n        yield _get_sorted_idx(topk_idx, num_experts)\n\n    @pytest.fixture\n    def sorted_idx(self, sort_and_cnt):\n        yield 
sort_and_cnt[0]\n\n    @pytest.fixture\n    def exp_tok_cnt(self, sort_and_cnt):\n        yield sort_and_cnt[1]\n\n    @pytest.fixture\n    def exp_end(self, exp_tok_cnt):\n        yield exp_tok_cnt.cumsum(0)\n\n    @pytest.fixture\n    def exp_start(self, exp_end, exp_tok_cnt):\n        yield exp_end - exp_tok_cnt\n\n    @pytest.fixture\n    def gt(self, A, B, bias, top_k, topk_idx):\n        M = A.size(0)\n        N = B.size(1)\n        E = B.size(0)\n        C = B.new_empty(M, top_k, N)\n        for eid in range(E):\n            EB = B[eid].t()\n            Ebias = bias[eid]\n            token_idx, k_idx = torch.where(topk_idx == eid)\n            if len(token_idx) == 0:\n                continue\n            EC = A[token_idx] @ EB + Ebias\n            C[token_idx, k_idx] = EC\n        yield C.flatten(0, 1)\n\n    @torch.inference_mode()\n    def test_launcher(self, A, B, bias, sorted_idx, exp_start, exp_end, top_k, M, gt):\n        from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe_kernel_launcher\n        N = B.size(1)\n        C = B.new_empty(M * top_k, N)\n\n        fused_moe_kernel_launcher(\n            A,\n            B,\n            C,\n            sorted_idx,\n            exp_start,\n            exp_end,\n            bias=bias,\n            top_k=top_k,\n            num_tokens=M,\n        )\n        torch.testing.assert_close(C, gt, atol=1e-3, rtol=1e-3)\n\n\ndef _mlp_forward(hidden_states, gate_proj, up_proj, down_proj):\n    gate = F.linear(hidden_states, gate_proj)\n    up = F.linear(hidden_states, up_proj)\n    return F.linear(F.silu(gate) * up, down_proj)\n\n\nclass TestFusedMoe:\n\n    @pytest.fixture\n    def dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def device(self):\n        yield torch.device('cuda')\n\n    @pytest.fixture\n    def in_size(self):\n        yield 128\n\n    @pytest.fixture\n    def seq_len(self):\n        yield 128\n\n    @pytest.fixture\n    def hidden_size(self):\n        yield 256\n\n 
   @pytest.fixture\n    def out_size(self):\n        yield 128\n\n    @pytest.fixture\n    def num_experts(self):\n        yield 64\n\n    @pytest.fixture\n    def top_k(self):\n        yield 6\n\n    @pytest.fixture\n    def renormalize(self):\n        yield True\n\n    @pytest.fixture\n    def hidden_states(self, seq_len, in_size, dtype, device):\n        ret = torch.rand(seq_len, in_size, dtype=dtype, device=device)\n        yield (ret - 0.5) / 2\n\n    @pytest.fixture\n    def w1(self, num_experts, hidden_size, in_size, dtype, device):\n        ret = torch.rand(num_experts, hidden_size, in_size, dtype=dtype, device=device)\n        yield (ret - 0.5) / 2\n\n    @pytest.fixture\n    def w2(self, num_experts, out_size, hidden_size, dtype, device):\n        ret = torch.rand(num_experts, out_size, hidden_size // 2, dtype=dtype, device=device)\n        yield (ret - 0.5) / 2\n\n    @pytest.fixture\n    def router_logits(self, seq_len, num_experts, dtype, device):\n        yield torch.rand(seq_len, num_experts, dtype=dtype, device=device)\n\n    @pytest.fixture\n    def topk_logits(self, router_logits, top_k):\n        routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)\n        yield torch.topk(routing_weights, top_k, dim=-1)\n\n    @pytest.fixture\n    def topk_weights(self, topk_logits):\n        yield topk_logits[0]\n\n    @pytest.fixture\n    def topk_idx(self, topk_logits):\n        yield topk_logits[1]\n\n    @pytest.fixture\n    def gt(self, hidden_states, w1, w2, topk_weights, topk_idx, renormalize):\n        if renormalize:\n            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)\n\n        seq_len = hidden_states.size(0)\n        out_size = w2.size(1)\n        output = hidden_states.new_zeros(seq_len, out_size)\n        num_experts = w1.size(0)\n        for eid in range(num_experts):\n            token_idx, k_idx = torch.where(topk_idx == eid)\n            gate_proj, up_proj = w1[eid].chunk(2, dim=0)\n         
   down_proj = w2[eid]\n            tmp_out = _mlp_forward(hidden_states[token_idx], gate_proj, up_proj, down_proj)\n            tmp_out = tmp_out * topk_weights[token_idx, k_idx, None]\n            output.index_add_(0, token_idx, tmp_out.to(output.dtype))\n        yield output\n\n    @torch.inference_mode()\n    def test_fused_moe(self, hidden_states, w1, w2, topk_weights, topk_idx, top_k, renormalize, gt):\n        from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe\n        output = fused_moe(hidden_states, w1, w2, topk_weights, topk_idx, topk=top_k, renormalize=renormalize)\n        torch.testing.assert_close(output, gt, atol=1e-3, rtol=1e-3)\n\n\nclass TestFusedMoeW8A8(TestFusedMoe):\n\n    @pytest.fixture\n    def quant_states(self, hidden_states):\n        from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8\n        states_i8, states_scale = per_token_quant_int8(hidden_states, 1e-7)\n        yield states_i8, states_scale\n\n    def quant_weight(self, w):\n        from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_channel_quant\n        num_experts, num_outs, _ = w.shape\n        w = w.flatten(0, -2)\n        w_i8, w_scale = per_channel_quant(w, torch.int8)\n        w_i8 = w_i8.view(num_experts, num_outs, -1)\n        w_scale = w_scale.view(num_experts, num_outs, -1)\n        return w_i8, w_scale\n\n    @pytest.fixture\n    def quant_w1(self, w1):\n        w_i8, w_scale = self.quant_weight(w1)\n        yield w_i8, w_scale\n\n    @pytest.fixture\n    def quant_w2(self, w2):\n        w_i8, w_scale = self.quant_weight(w2)\n        yield w_i8, w_scale\n\n    @torch.inference_mode()\n    def test_fused_moe(self, quant_states, quant_w1, quant_w2, topk_weights, topk_idx, top_k, renormalize, gt):\n        from lmdeploy.pytorch.kernels.cuda.w8a8_fused_moe import fused_moe_w8a8\n        state_i8, state_scale = quant_states\n        w1_i8, w1_scale = quant_w1\n        w2_i8, w2_scale = quant_w2\n\n        output = 
fused_moe_w8a8(state_i8,\n                                state_scale,\n                                w1_i8,\n                                w1_scale,\n                                w2_i8,\n                                w2_scale,\n                                topk_weights=topk_weights,\n                                topk_ids=topk_idx,\n                                topk=top_k,\n                                out_dtype=torch.float16,\n                                renormalize=renormalize)\n        torch.testing.assert_close(output, gt, atol=5e-3, rtol=1e-3)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_gated_delta_rule.py",
    "content": "import pytest\nimport torch\n\n\ndef do_test():\n    try:\n        import tilelang  # noqa: F401\n        return torch.cuda.is_available()\n    except Exception:\n        return False\n\n\ndef naive_recurrent_gdr(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    beta: torch.Tensor,\n    g: torch.Tensor,\n    scale: float = None,\n    initial_state: torch.Tensor = None,\n    output_final_state: bool = False,\n    use_qk_l2norm_in_kernel: bool = False,\n):\n    dtype = q.dtype\n    if use_qk_l2norm_in_kernel:\n        q = torch.nn.functional.normalize(q, p=2, dim=-1)\n        k = torch.nn.functional.normalize(k, p=2, dim=-1)\n    q, k, v, beta, g = map(lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g])\n    B, H, T, K, V = *k.shape, v.shape[-1]\n    o = torch.zeros(B, H, T, V).to(v)\n    h = torch.zeros(B, H, K, V).to(v)\n    if initial_state is not None:\n        h = initial_state.to(torch.float32)\n    if scale is None:\n        scale = 1 / (q.shape[-1]**0.5)\n    q = q * scale\n\n    for i in range(T):\n        b_q = q[:, :, i]\n        b_k = k[:, :, i]\n        b_v = v[:, :, i].clone()\n        h = h.clone() * g[:, :, i].exp()[..., None, None]\n        b_beta = beta[:, :, i]\n        b_v = b_v - (h.clone() * b_k[..., None]).sum(-2)\n        b_v = b_v * b_beta[..., None]\n        h = h.clone() + b_k.unsqueeze(-1) * b_v.unsqueeze(-2)\n        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', b_q, h)\n\n    if not output_final_state:\n        h = None\n    o = o.transpose(1, 2).contiguous()\n    o = o.to(dtype)\n    if output_final_state:\n        h = h.to(dtype)\n    return o, h\n\n\n@pytest.mark.skipif(not do_test(), reason='tilelang is not available')\nclass TestRecurrentGatedDeltaRule:\n\n    @pytest.fixture(autouse=True)\n    def auto_context(self):\n        origin_dtype = torch.get_default_dtype()\n        origin_device = torch.get_default_device()\n        with torch.inference_mode():\n            
torch.set_default_dtype(torch.bfloat16)\n            torch.set_default_device('cuda')\n            try:\n                yield\n            finally:\n                torch.set_default_dtype(origin_dtype)\n                torch.set_default_device(origin_device)\n\n    @pytest.fixture\n    def batch(self):\n        yield 512\n\n    @pytest.fixture\n    def num_heads(self):\n        yield 16\n\n    @pytest.fixture\n    def seqlen(self):\n        yield 1\n\n    @pytest.fixture\n    def head_dim(self):\n        yield 128\n\n    @pytest.fixture(params=[True, False])\n    def use_qk_l2norm_in_kernel(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def q(self, batch, seqlen, num_heads, head_dim):\n        yield torch.rand(batch, seqlen, num_heads, head_dim) - 0.5\n\n    @pytest.fixture\n    def k(self, batch, seqlen, num_heads, head_dim):\n        yield torch.rand(batch, seqlen, num_heads, head_dim) - 0.5\n\n    @pytest.fixture\n    def v(self, batch, seqlen, num_heads, head_dim):\n        yield torch.rand(batch, seqlen, num_heads, head_dim) - 0.5\n\n    @pytest.fixture\n    def g(self, batch, seqlen, num_heads):\n        yield -2 * torch.rand(batch, seqlen, num_heads)\n\n    @pytest.fixture\n    def beta(self, batch, seqlen, num_heads):\n        yield torch.rand(batch, seqlen, num_heads)\n\n    @pytest.fixture\n    def initial_state(self, batch, num_heads, head_dim):\n        yield torch.rand(batch, num_heads, head_dim, head_dim) - 0.5\n\n    @pytest.fixture\n    def gt(self, q, k, v, g, beta, initial_state, use_qk_l2norm_in_kernel):\n        state_copy = initial_state.clone()\n        yield naive_recurrent_gdr(q,\n                                  k,\n                                  v,\n                                  beta,\n                                  g,\n                                  initial_state=state_copy,\n                                  output_final_state=True,\n                                  
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel)\n\n    def test_fused_gated_delta_rule(self, q, k, v, g, beta, initial_state, use_qk_l2norm_in_kernel, gt):\n        from lmdeploy.pytorch.kernels.cuda.gated_delta_rule import fused_recurrent_gated_delta_rule\n        state_copy = initial_state.clone()\n        out, out_h = fused_recurrent_gated_delta_rule(\n            q=q,\n            k=k,\n            v=v,\n            g=g,\n            beta=beta,\n            initial_state=state_copy,\n            output_final_state=True,\n            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,\n        )\n        gt_o, gt_h = gt\n        torch.testing.assert_close(out, gt_o, atol=1e-3, rtol=1e-4)\n        torch.testing.assert_close(out_h, gt_h, atol=1e-2, rtol=1e-3)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_gemm_fp8.py",
    "content": "import pytest\nimport torch\n\n\ndef _make_quant_val(shape, out_dtype):\n    x = torch.rand(shape, dtype=torch.float32, device='cuda')\n    # -1 ~ 1\n    x = x * 2 - 1\n    # scaling abs max to fmax\n    finfo = torch.finfo(out_dtype)\n    fmax = finfo.max\n    scaling = fmax / x.abs().amax(-1, keepdim=True)\n    x *= scaling\n    return x.to(out_dtype).to(torch.float32)\n\n\ndef fast_log2_ceil_torch(x: torch.Tensor) -> torch.Tensor:\n    bits_x = x.view(torch.int32)\n    exp_x = (bits_x >> 23) & 0xFF\n    man_bits = bits_x & ((1 << 23) - 1)\n    result = (exp_x - 127).to(torch.int32)\n    result = result + torch.where(man_bits != 0, 1, 0)\n\n    return result.to(torch.int32)\n\n\ndef fast_pow2_torch(x: torch.Tensor) -> torch.Tensor:\n    bits_x = (x + 127) << 23\n    return bits_x.view(torch.float32)\n\n\ndef fast_round_scale_torch(amax: torch.Tensor, fp8_max_inv: torch.Tensor) -> torch.Tensor:\n    return fast_pow2_torch(fast_log2_ceil_torch(amax * fp8_max_inv))\n\n\ndef _make_quant_scale_ue8m0(shape, out_dtype):\n    scale = torch.randn(shape, dtype=torch.float32, device='cuda')\n    finfo = torch.finfo(out_dtype)\n    fmax = finfo.max\n    scale = fast_round_scale_torch(scale, 1 / fmax)\n    return scale\n\n\ndef _make_quant_scale(shape, out_dtype, scale_fmt: str = None):\n    if scale_fmt == 'ue8m0':\n        return _make_quant_scale_ue8m0(shape, out_dtype)\n\n    # default\n    scale = torch.rand(shape, dtype=torch.float32, device='cuda')\n    finfo = torch.finfo(out_dtype)\n    fmax = finfo.max\n    scale /= fmax\n    return scale\n\n\ndef _make_A(M, K, group_size, out_dtype, scale_fmt: str = None):\n    quant_A = _make_quant_val((M, K // group_size, group_size), out_dtype)\n\n    # create scale and A\n    scale = _make_quant_scale((M, K // group_size), out_dtype, scale_fmt)\n    A = quant_A * scale[..., None]\n\n    A = A.reshape(M, K)\n    quant_A = quant_A.reshape(M, K).to(out_dtype)\n    scale = scale.T.contiguous().T\n    return A, 
quant_A, scale\n\n\ndef _aligned_size(a, b):\n    return (a + b - 1) // b * b\n\n\ndef _make_B(K, N, group_size, out_dtype, scale_fmt: str = None):\n    K_aligned = _aligned_size(K, group_size)\n    N_aligned = _aligned_size(N, group_size)\n\n    quant_B = _make_quant_val((K_aligned // group_size, group_size, N_aligned // group_size, group_size), out_dtype)\n\n    scale = _make_quant_scale((K_aligned // group_size, 1, N_aligned // group_size, 1), out_dtype, scale_fmt)\n\n    B = quant_B * scale\n\n    B = B.reshape(K_aligned, N_aligned)[:K, :N]\n    quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N]\n    scale = scale.reshape(K_aligned // group_size, N_aligned // group_size)\n    quant_B = quant_B.transpose(0, 1).contiguous().transpose(0, 1)\n    return B, quant_B, scale\n\n\n@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')\nclass TestQuantFP8:\n\n    @pytest.fixture\n    def M(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def K(self):\n        yield 512\n\n    @pytest.fixture\n    def group_size(self):\n        yield 128\n\n    @pytest.fixture\n    def out_dtype(self):\n        yield torch.float8_e4m3fn\n\n    @pytest.fixture\n    def scale_fmt(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def build_A(self, M, K, group_size, out_dtype, scale_fmt):\n        return _make_A(M, K, group_size, out_dtype, scale_fmt)\n\n    @pytest.fixture\n    def A(self, build_A):\n        return build_A[0]\n\n    @pytest.fixture\n    def quant_A(self, build_A):\n        return build_A[1]\n\n    @pytest.fixture\n    def scale(self, build_A):\n        return build_A[2]\n\n    @pytest.fixture\n    def gt(self, quant_A, scale):\n        yield quant_A, scale\n\n    @pytest.mark.parametrize('scale_fmt', [None, 'ue8m0'], indirect=True)\n    @pytest.mark.parametrize('M', [65536, 256], indirect=True)\n    def test_quant_fp8(self, A, group_size, out_dtype, scale_fmt, 
gt):\n        from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8\n        quant_A_gt, scale_gt = gt\n\n        quant_A, scale = quant_fp8(A, group_size=group_size, dtype=out_dtype, scale_fmt=scale_fmt)\n        torch.testing.assert_close(scale, scale_gt)\n        diff = (quant_A.to(torch.float16) - quant_A_gt.to(torch.float16)).abs()\n        diff_count = (diff > 1e-5).count_nonzero()\n        assert diff_count / diff.numel() < 1e-4\n\n\n@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')\nclass TestGemmFP8:\n\n    @pytest.fixture\n    def M(self):\n        yield 256\n\n    @pytest.fixture\n    def N(self):\n        # test non-aligned\n        yield 1024 + 64\n\n    @pytest.fixture\n    def K(self):\n        yield 512\n\n    @pytest.fixture\n    def group_size(self):\n        yield 128\n\n    @pytest.fixture\n    def quant_dtype(self):\n        yield torch.float8_e4m3fn\n\n    @pytest.fixture\n    def out_dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def build_A(self, M, K, group_size, quant_dtype):\n        return _make_A(M, K, group_size, quant_dtype)\n\n    @pytest.fixture\n    def A(self, build_A, out_dtype):\n        return build_A[0].to(out_dtype)\n\n    @pytest.fixture\n    def quant_A(self, build_A):\n        return build_A[1]\n\n    @pytest.fixture\n    def scale_A(self, build_A):\n        return build_A[2]\n\n    @pytest.fixture\n    def build_B(self, K, N, group_size, quant_dtype):\n        return _make_B(K, N, group_size, quant_dtype)\n\n    @pytest.fixture\n    def B(self, build_B, out_dtype):\n        return build_B[0].to(out_dtype)\n\n    @pytest.fixture\n    def quant_B(self, build_B):\n        return build_B[1]\n\n    @pytest.fixture\n    def scale_B(self, build_B):\n        return build_B[2]\n\n    @pytest.fixture\n    def gt(self, A, B):\n        yield A @ B\n\n    def test_gemm_fp8(self, quant_A, scale_A, quant_B, scale_B, out_dtype, gt):\n        from 
lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import blocked_gemm_fp8\n        C = blocked_gemm_fp8(quant_A, scale_A, quant_B, scale_B, out_dtype=out_dtype)\n        torch.testing.assert_close(C, gt, atol=0.5, rtol=1e-4)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_moe_route.py",
    "content": "import pytest\nimport torch\n\n\ndef reference_noaux_tc_routing(\n    logits: torch.Tensor,\n    bias: torch.Tensor,\n    num_experts: int = 256,\n    n_group: int = 8,\n    topk_group: int = 4,\n    top_k: int = 8,\n    renormalize: bool = True,\n    routed_scaling_factor: float = 2.5,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    batch_size = logits.shape[0]\n    scores = torch.sigmoid(logits.float())\n    scores_for_choice = scores + bias[None, :]\n\n    group_size = num_experts // n_group\n    grouped_scores = scores_for_choice.view(batch_size, n_group, group_size)\n    group_scores = grouped_scores.topk(2, dim=-1)[0].sum(dim=-1)\n\n    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]\n    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)\n\n    score_mask = group_mask.unsqueeze(-1).expand(batch_size, n_group, group_size).reshape(batch_size, -1)\n    # Note: Using 0.0 matches the actual inference code in deepseek_v2.py\n    # Works correctly because sigmoid scores are always in (0, 1)\n    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)\n\n    _, topk_idx = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)\n    topk_weight = scores.gather(1, topk_idx)\n\n    if renormalize:\n        topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)\n\n    return topk_weight * routed_scaling_factor, topk_idx\n\n\nclass TestNoauxTC:\n\n    @pytest.fixture(autouse=True)\n    def auto_context(self):\n        origin_dtype = torch.get_default_dtype()\n        origin_device = torch.get_default_device()\n        with torch.inference_mode():\n            torch.set_default_dtype(torch.float32)\n            torch.set_default_device('cuda')\n            try:\n                yield\n            finally:\n                torch.set_default_dtype(origin_dtype)\n                torch.set_default_device(origin_device)\n\n    @pytest.fixture\n    def batch_size(self):\n        yield 
32\n\n    @pytest.fixture\n    def num_experts(self):\n        yield 256\n\n    @pytest.fixture\n    def logits(self, batch_size, num_experts):\n        yield torch.randn(batch_size, num_experts)\n\n    @pytest.fixture\n    def bias(self, num_experts):\n        yield torch.randn(num_experts)\n\n    @pytest.fixture\n    def kwargs(self):\n        yield {\n            'num_experts': 256,\n            'n_group': 8,\n            'topk_group': 4,\n            'top_k': 8,\n            'renormalize': True,\n            'routed_scaling_factor': 2.5,\n        }\n\n    @pytest.fixture\n    def gt(self, logits, bias, kwargs):\n        yield reference_noaux_tc_routing(logits, bias, **kwargs)\n\n    def test_noaux_tc_router(self, logits, bias, kwargs, gt):\n        from lmdeploy.pytorch.kernels.cuda.fused_noaux_tc import fused_noaux_tc_routing\n\n        out_weights, out_ids = fused_noaux_tc_routing(logits, bias, **kwargs)\n        gt_weights, gt_ids = gt\n\n        torch.testing.assert_close(out_weights, gt_weights, rtol=1e-4, atol=1e-5)\n        # topk in torch is not stable, so we won't assert ids\n"
  },
  {
    "path": "tests/pytorch/kernel/test_multinomial_sampling.py",
    "content": "import pytest\nimport torch\n\nfrom lmdeploy.utils import is_bf16_supported\n\n\ndef _bf16_mark():\n    return pytest.mark.skipif(not is_bf16_supported(), reason='bf16 not supported.')\n\n\nclass TestMultinomialSampling:\n\n    @pytest.fixture\n    def num_tokens(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def select_ids(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def batch_size(self, select_ids):\n        yield len(select_ids)\n\n    @pytest.fixture\n    def dtype(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def scores(self, num_tokens, batch_size, select_ids, dtype):\n        ret = torch.zeros(batch_size, num_tokens).cuda()\n        batch_ids = torch.arange(batch_size).cuda()\n        ret[batch_ids, select_ids] = 1\n        ret = ret.to(dtype)\n        yield ret\n\n    @pytest.fixture\n    def seeds(self, batch_size):\n        yield torch.randint(1000, 2000, (batch_size, )).cuda()\n\n    @pytest.fixture\n    def offsets(self, batch_size):\n        yield torch.randint(1000, 2000, (batch_size, )).cuda()\n\n    @pytest.fixture\n    def indices(self, scores):\n        num_tokens = scores.size(1)\n        ret = [torch.randperm(num_tokens) for _ in scores]\n        ret = torch.stack(ret, 0).cuda()\n        yield ret\n\n    @pytest.fixture\n    def gt(self, batch_size, select_ids, indices):\n        batch_ids = torch.arange(batch_size).cuda()\n        yield indices[batch_ids, select_ids]\n\n    @pytest.mark.parametrize('dtype', [torch.float32, torch.half, pytest.param(torch.bfloat16, marks=_bf16_mark())])\n    @pytest.mark.parametrize(['num_tokens', 'select_ids'], [\n        (8, (4, 2) * 30),\n        (2000, (500, 1500)),\n    ], indirect=True)\n    def test_multinomial_sampling(self, scores, seeds, offsets, indices, gt):\n        from lmdeploy.pytorch.kernels.cuda import multinomial_sampling\n        output = multinomial_sampling(scores, seeds, offsets, indices)\n      
  torch.testing.assert_close(output, gt)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_paged_attention.py",
    "content": "import math\n\nimport pytest\nimport torch\n\n\ndef _conti_input(data, seq_lens):\n    data = [x[:l] for x, l in zip(data, seq_lens)]\n    data = torch.cat(data, dim=0)\n    return data\n\n\ndef _make_bias(q_seqlens, history_lens, neg_val):\n    batch_size = q_seqlens.shape[0]\n    full_seq_lens = q_seqlens + history_lens\n    max_seq_len = q_seqlens.max().item()\n    max_kv_len = full_seq_lens.max().item()\n    seq_ranges = torch.arange(max_seq_len).cuda()\n    seq_ranges = seq_ranges.repeat(batch_size, 1)\n    seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)\n\n    kv_ranges = torch.arange(max_kv_len).cuda()\n    kv_ranges = kv_ranges.repeat(batch_size, 1)\n    mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]\n    return mask.float() * neg_val\n\n\ndef _make_alibi_bias(q_seqlens, history_lens, neg_val, alibi_slopes):\n    batch_size = q_seqlens.shape[0]\n    kv_seqlens = q_seqlens + history_lens\n    max_seq_len = q_seqlens.max().item()\n    max_kv_len = kv_seqlens.max().item()\n\n    seq_ranges = torch.arange(max_seq_len).cuda()\n    seq_ranges = seq_ranges.repeat(batch_size, 1) + history_lens[:, None]\n\n    kv_ranges = torch.arange(max_kv_len).cuda()\n    kv_ranges = kv_ranges.repeat(batch_size, 1)\n\n    diff = (seq_ranges[:, :, None] - kv_ranges[:, None, :]).abs()\n    slope_diff = -diff[:, None] * alibi_slopes[None, :, None, None]\n\n    # add bias\n    bias = _make_bias(q_seqlens, history_lens, neg_val)\n    bias = bias[:, None] + slope_diff\n    return bias\n\n\ndef _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float,\n                            block_sparse_size: int):\n    \"\"\"Make block sparse bias.\"\"\"\n    batch_size = q_seqlens.shape[0]\n    kv_seqlens = q_seqlens + history_lens\n    max_seq_len = q_seqlens.max().item()\n    max_kv_len = kv_seqlens.max().item()\n\n    seq_ranges = torch.arange(max_seq_len).cuda()\n    
seq_ranges = seq_ranges // block_sparse_size * block_sparse_size\n    seq_ranges = seq_ranges.repeat(batch_size, 1)\n    seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)\n\n    kv_ranges = torch.arange(max_kv_len).cuda()\n    kv_ranges = kv_ranges // block_sparse_size * block_sparse_size\n    kv_ranges = kv_ranges.repeat(batch_size, 1)\n\n    mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])\n    return mask.float() * neg_val\n\n\ndef _make_blocked_cache(batched_k,\n                        batched_v,\n                        seq_lens,\n                        history_lens,\n                        block_offsets,\n                        block_size,\n                        num_heads_k,\n                        feat_dim,\n                        feat_dim_v,\n                        layout: str = 'bshd'):\n    max_blocks_nums = block_offsets.max() + 1\n    full_seq_lens = seq_lens + history_lens\n    blocked_k = batched_k.new_zeros(max_blocks_nums, block_size, num_heads_k, feat_dim)\n    blocked_v = batched_v.new_zeros(max_blocks_nums, block_size, num_heads_k, feat_dim_v)\n\n    for batch_id, offset in enumerate(block_offsets):\n        ori_k = batched_k[batch_id]\n        ori_v = batched_v[batch_id]\n        seq_len = full_seq_lens[batch_id]\n        for block_id, block_start in enumerate(range(0, seq_len, block_size)):\n            block_off = offset[block_id]\n            tmp_k = ori_k[block_start:block_start + block_size]\n            tmp_v = ori_v[block_start:block_start + block_size]\n            size = tmp_k.size(0)\n            blocked_k[block_off, :size] = tmp_k\n            blocked_v[block_off, :size] = tmp_v\n\n    if layout == 'bhsd':\n        blocked_k = blocked_k.transpose(1, 2).contiguous()\n        blocked_v = blocked_v.transpose(1, 2).contiguous()\n\n    return blocked_k, blocked_v\n\n\ndef _naive_attention(batched_q, batched_kv, bias, sinks=None):\n    batched_k, batched_v = 
batched_kv\n\n    num_heads_q = batched_q.shape[2]\n    num_heads_k = batched_k.shape[2]\n    head_dim = batched_q.shape[-1]\n    group = num_heads_q // num_heads_k\n\n    q = batched_q.transpose(1, 2)\n    k = batched_k.permute(0, 2, 3, 1)\n    v = batched_v.transpose(1, 2)\n\n    # expand group\n    k = k.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2)\n    v = v.unsqueeze(2).expand(-1, -1, group, -1, -1).flatten(1, 2)\n\n    qk = torch.matmul(q, k) / math.sqrt(head_dim)\n    if bias.dim() == 3:\n        bias = bias[:, None]\n    attn_weight = qk + bias\n    if sinks is None:\n        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)\n    else:\n        sinks = sinks[None, :, None, None].to(torch.float32)\n        sinks = sinks.expand(attn_weight.shape[0], -1, attn_weight.shape[2], -1)\n        attn_weight = attn_weight.to(torch.float32)\n        combined_logits = torch.cat([attn_weight, sinks], dim=-1)\n        combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values\n        attn_weight = torch.softmax(combined_logits, dim=-1, dtype=torch.float32)\n        attn_weight = attn_weight[..., :-1]\n    attn_weight = attn_weight.to(q.dtype)\n    attn_output = torch.matmul(attn_weight, v)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output\n\n\ndef _naive_window_attention(q, k, v, seqlens_q, seqlens_k, window_size):\n    try:\n        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func\n    except Exception:\n        try:\n            from flash_attn import flash_attn_varlen_func\n        except Exception:\n            pytest.skip('Skip window attention test since flash attention is not available.')\n\n    def _make_cu_seqlens(seqlens):\n        cu_seqlens = seqlens.cumsum(0)\n        cu_zero = cu_seqlens.new_zeros(1)\n        cu_seqlens = torch.cat([cu_zero, cu_seqlens])\n        return cu_seqlens\n\n    max_seqlen_q = seqlens_q.max().item()\n    
max_seqlen_k = seqlens_k.max().item()\n    cu_seqlens_q = _make_cu_seqlens(seqlens_q).int()\n    cu_seqlens_k = _make_cu_seqlens(seqlens_k).int()\n\n    output = flash_attn_varlen_func(q,\n                                    k,\n                                    v,\n                                    cu_seqlens_q,\n                                    cu_seqlens_k,\n                                    max_seqlen_q=max_seqlen_q,\n                                    max_seqlen_k=max_seqlen_k,\n                                    causal=True,\n                                    window_size=window_size)\n    return output\n\n\nclass TestPagedAttentionBase:\n\n    @pytest.fixture\n    def dtype(self):\n        yield torch.float16\n\n    @pytest.fixture\n    def feat_dim(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def feat_dim_v(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def num_heads_q(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def num_heads_k(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def block_size(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def layout(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def history_lens(self, request):\n        yield torch.tensor(request.param, device='cuda')\n\n    @pytest.fixture\n    def seq_len(self):\n        yield 1\n\n    @pytest.fixture\n    def seq_lens(self, seq_len, history_lens):\n        yield torch.ones_like(history_lens) * seq_len\n\n    @pytest.fixture\n    def kv_seqlens(self, seq_lens, history_lens):\n        yield seq_lens + history_lens\n\n    @pytest.fixture\n    def batched_q(self, seq_len, kv_seqlens, num_heads_q, feat_dim, dtype):\n        torch.manual_seed(123)\n        batch_size = len(kv_seqlens)\n        inputs = torch.rand(batch_size, seq_len, num_heads_q, feat_dim, dtype=dtype, device='cuda')\n        yield inputs\n\n    
@pytest.fixture\n    def batched_kv(self, kv_seqlens, num_heads_k, feat_dim, feat_dim_v, dtype):\n        torch.manual_seed(123)\n        batch_size = len(kv_seqlens)\n        max_seq_len = kv_seqlens.max().item()\n        k = torch.rand(batch_size, max_seq_len, num_heads_k, feat_dim, dtype=dtype, device='cuda')\n        v = torch.rand(batch_size, max_seq_len, num_heads_k, feat_dim_v, dtype=dtype, device='cuda')\n        yield k, v\n\n    @pytest.fixture\n    def conti_q(self, seq_lens, batched_q):\n        yield _conti_input(batched_q, seq_lens)\n\n    @pytest.fixture\n    def block_offsets(self, kv_seqlens, block_size):\n        batch_size = kv_seqlens.size(0)\n        num_blocks = (kv_seqlens + block_size - 1) // block_size\n\n        offset = [torch.arange(size) * batch_size + idx for idx, size in enumerate(num_blocks)]\n        max_len = max(len(o) for o in offset)\n        new_offset = offset[0].new_zeros(batch_size, max_len)\n        for o, no in zip(offset, new_offset):\n            len_o = o.size(0)\n            no[:len_o] = o\n\n        yield new_offset.cuda()\n\n    @pytest.fixture\n    def conti_kv(self, batched_kv, history_lens):\n        full_seq_lens = 1 + history_lens\n        conti_k = _conti_input(batched_kv[0], full_seq_lens)\n        conti_v = _conti_input(batched_kv[1], full_seq_lens)\n        yield (conti_k, conti_v)\n\n    @pytest.fixture\n    def blocked_kv(self, batched_kv, kv_seqlens, history_lens, block_offsets, block_size, num_heads_k, feat_dim,\n                   feat_dim_v, layout):\n        batched_k, batched_v = batched_kv\n        seq_lens = torch.ones_like(kv_seqlens)\n        yield _make_blocked_cache(batched_k, batched_v, seq_lens, history_lens, block_offsets, block_size, num_heads_k,\n                                  feat_dim, feat_dim_v, layout)\n\n    @pytest.fixture\n    def mask(self, history_lens):\n        neg_val = -1e30\n        seq_lens = torch.ones_like(history_lens)\n        yield _make_bias(seq_lens, history_lens, 
neg_val)\n\n    @pytest.fixture\n    def gt(self, batched_q, batched_kv, mask):\n        yield _naive_attention(batched_q, batched_kv, mask)\n\n    @pytest.fixture\n    def conti_gt(self, gt, seq_lens):\n        yield _conti_input(gt, seq_lens)\n\n\nclass TestPagedAttention(TestPagedAttentionBase):\n\n    @pytest.mark.parametrize('feat_dim', [32], indirect=True)\n    @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)\n    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True)\n    @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True)\n    @pytest.mark.parametrize('block_size', [16], indirect=True)\n    @pytest.mark.parametrize('layout', ['bshd', 'bhsd'], indirect=True)\n    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, layout, conti_gt):\n        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache\n\n        blocked_k, blocked_v = blocked_kv\n        out = flash_attn_with_kvcache(conti_q,\n                                      blocked_k,\n                                      blocked_v,\n                                      page_table=block_offsets,\n                                      cache_seqlens=kv_seqlens,\n                                      kv_layout=layout)\n        torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)\n\n    @pytest.fixture\n    def win_size(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def window_gt(self, conti_q, conti_kv, seq_lens, history_lens, win_size):\n        kv_lens = seq_lens + history_lens\n        yield _naive_window_attention(conti_q,\n                                      conti_kv[0],\n                                      conti_kv[1],\n                                      seq_lens,\n                                      kv_lens,\n                                      window_size=(win_size, win_size))\n\n    @pytest.mark.parametrize('feat_dim', [16], 
indirect=True)\n    @pytest.mark.parametrize('feat_dim_v', [16], indirect=True)\n    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], indirect=True)\n    @pytest.mark.parametrize('history_lens', [\n        (50, 40, 30, 20),\n    ], indirect=True)\n    @pytest.mark.parametrize('win_size', (32, ), indirect=True)\n    @pytest.mark.parametrize('block_size', [16], indirect=True)\n    @pytest.mark.parametrize('layout', ['bshd'], indirect=True)\n    def test_window_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, win_size, layout, window_gt):\n        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache\n\n        blocked_k, blocked_v = blocked_kv\n        out = flash_attn_with_kvcache(conti_q,\n                                      blocked_k,\n                                      blocked_v,\n                                      page_table=block_offsets,\n                                      cache_seqlens=kv_seqlens,\n                                      window_size=win_size,\n                                      kv_layout=layout)\n        torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5)\n\n\nclass TestPagedAttentionSink(TestPagedAttentionBase):\n\n    @pytest.fixture\n    def sinks(self, num_heads_q, dtype):\n        yield torch.rand(num_heads_q, dtype=dtype, device='cuda')\n\n    @pytest.fixture\n    def sink_gt(self, batched_q, batched_kv, mask, sinks):\n        yield _naive_attention(batched_q, batched_kv, mask, sinks)\n\n    @pytest.fixture\n    def conti_sink_gt(self, sink_gt, seq_lens):\n        yield _conti_input(sink_gt, seq_lens)\n\n    @pytest.mark.parametrize('feat_dim', [32], indirect=True)\n    @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)\n    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(8, 2)], indirect=True)\n    @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True)\n    @pytest.mark.parametrize('block_size', [16], indirect=True)\n    
@pytest.mark.parametrize('layout', ['bshd'], indirect=True)\n    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, layout, sinks, conti_sink_gt):\n        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache\n\n        blocked_k, blocked_v = blocked_kv\n\n        out = flash_attn_with_kvcache(conti_q,\n                                      blocked_k,\n                                      blocked_v,\n                                      page_table=block_offsets,\n                                      cache_seqlens=kv_seqlens,\n                                      sinks=sinks,\n                                      kv_layout=layout)\n        torch.testing.assert_close(out, conti_sink_gt, atol=1e-3, rtol=1e-5)\n\n\ndef quant(kv: torch.Tensor, nbits: int = 8):\n    \"\"\"Quant kv on the head_dim.\"\"\"\n    amax = kv.amax(dim=-1, keepdim=True)\n    amin = kv.amin(dim=-1, keepdim=True)\n    scales = (amax - amin) / (2**nbits - 1)\n    zeros = -amin / scales\n    q_kv = (kv / scales + zeros + 0.5).to(torch.uint8)\n    if nbits == 4:\n        q_kv1, q_kv2 = q_kv.split(q_kv.shape[-1] // 2, -1)\n        q_kv = q_kv1 + q_kv2 * 16\n    return q_kv, torch.cat([scales, zeros], dim=-1)\n\n\ndef _make_blocked_cache_quant(batched_k, batched_v, seq_lens, history_lens, block_offsets, block_size, num_heads_k,\n                              feat_dim, feat_dim_v, nbits):\n    max_blocks_nums = block_offsets.max() + 1\n    full_seq_lens = seq_lens + history_lens\n    batched_k, k_scales_zeros = quant(batched_k, nbits)\n    batched_v, v_scales_zeros = quant(batched_v, nbits)\n    if nbits == 4:\n        feat_dim //= 2\n        feat_dim_v //= 2\n    blocked_k = batched_k.new_zeros(max_blocks_nums, block_size, num_heads_k, feat_dim)\n    blocked_v = batched_v.new_zeros(max_blocks_nums, block_size, num_heads_k, feat_dim_v)\n    blocked_ksz = k_scales_zeros.new_zeros(max_blocks_nums, block_size, num_heads_k, 2)\n    blocked_vsz = 
v_scales_zeros.new_zeros(max_blocks_nums, block_size, num_heads_k, 2)\n\n    for batch_id, offset in enumerate(block_offsets):\n        ori_k = batched_k[batch_id]\n        ori_v = batched_v[batch_id]\n        ori_ksz = k_scales_zeros[batch_id]\n        ori_vsz = v_scales_zeros[batch_id]\n        seq_len = full_seq_lens[batch_id]\n        for block_id, block_start in enumerate(range(0, seq_len, block_size)):\n            block_off = offset[block_id]\n            tmp_k = ori_k[block_start:block_start + block_size]\n            tmp_v = ori_v[block_start:block_start + block_size]\n            tmp_ksz = ori_ksz[block_start:block_start + block_size]\n            tmp_vsz = ori_vsz[block_start:block_start + block_size]\n            size = tmp_k.size(0)\n            blocked_k[block_off, :size] = tmp_k\n            blocked_v[block_off, :size] = tmp_v\n            blocked_ksz[block_off, :size] = tmp_ksz\n            blocked_vsz[block_off, :size] = tmp_vsz\n\n    return blocked_k, blocked_v, blocked_ksz, blocked_vsz\n\n\nclass TestPagedAttentionInt8(TestPagedAttention):\n\n    @pytest.fixture\n    def nbits(self):\n        yield 8\n\n    @pytest.fixture\n    def blocked_kv(self, batched_kv, seq_lens, history_lens, block_offsets, block_size, num_heads_k, feat_dim,\n                   feat_dim_v, nbits):\n        batched_k, batched_v = batched_kv\n        yield _make_blocked_cache_quant(batched_k, batched_v, seq_lens, history_lens, block_offsets, block_size,\n                                        num_heads_k, feat_dim, feat_dim_v, nbits)\n\n    @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)\n    @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)\n    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(8, 2), (2, 2)], indirect=True)\n    @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True)\n    @pytest.mark.parametrize('block_size', [16], indirect=True)\n    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, 
kv_seqlens, conti_gt, nbits):\n        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache\n\n        blocked_k, blocked_v, blocked_ksz, blocked_vsz = blocked_kv\n\n        out = flash_attn_with_kvcache(conti_q,\n                                      blocked_k,\n                                      blocked_v,\n                                      k_scales_zeros=blocked_ksz,\n                                      v_scales_zeros=blocked_vsz,\n                                      quant_policy=nbits,\n                                      page_table=block_offsets,\n                                      cache_seqlens=kv_seqlens)\n        if nbits == 4:\n            torch.testing.assert_close(out, conti_gt, atol=0.05, rtol=0.01)\n        else:\n            torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)\n\n    @pytest.mark.parametrize('feat_dim', [16], indirect=True)\n    @pytest.mark.parametrize('feat_dim_v', [16], indirect=True)\n    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], indirect=True)\n    @pytest.mark.parametrize('history_lens', [\n        (50, 40, 30, 20),\n    ], indirect=True)\n    @pytest.mark.parametrize('win_size', (32, ), indirect=True)\n    @pytest.mark.parametrize('block_size', [16], indirect=True)\n    def test_window_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, win_size, window_gt, nbits):\n        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache\n\n        blocked_k, blocked_v, blocked_ksz, blocked_vsz = blocked_kv\n        out = flash_attn_with_kvcache(conti_q,\n                                      blocked_k,\n                                      blocked_v,\n                                      k_scales_zeros=blocked_ksz,\n                                      v_scales_zeros=blocked_vsz,\n                                      quant_policy=nbits,\n                                      page_table=block_offsets,\n                                      
cache_seqlens=kv_seqlens,\n                                      window_size=win_size)\n        if nbits == 4:\n            torch.testing.assert_close(out, window_gt, atol=0.05, rtol=0.01)\n        else:\n            torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5)\n\n\nclass TestPagedAttentionInt4(TestPagedAttentionInt8):\n\n    @pytest.fixture\n    def nbits(self):\n        yield 4\n\n\nclass TestPagedAttentionBlockDecoding(TestPagedAttentionBase):\n\n    @pytest.fixture\n    def seq_len(self):\n        yield 4\n\n    @pytest.fixture\n    def mask(self, seq_lens, history_lens, seq_len):\n        neg_val = -1e30\n        yield _make_block_sparse_bias(seq_lens, history_lens, neg_val, seq_len)\n\n    @pytest.fixture\n    def gt(self, batched_q, batched_kv, mask):\n        yield _naive_attention(batched_q, batched_kv, mask)\n\n    @pytest.fixture\n    def conti_gt(self, gt, seq_lens):\n        yield _conti_input(gt, seq_lens)\n\n    @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)\n    @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)\n    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True)\n    @pytest.mark.parametrize('history_lens', [(52, 40, 32, 20)], indirect=True)\n    @pytest.mark.parametrize('block_size', [16], indirect=True)\n    @pytest.mark.parametrize('layout', ['bshd'], indirect=True)\n    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, layout, conti_gt):\n        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache\n\n        blocked_k, blocked_v = blocked_kv\n\n        out = flash_attn_with_kvcache(conti_q,\n                                      blocked_k,\n                                      blocked_v,\n                                      page_table=block_offsets,\n                                      cache_seqlens=kv_seqlens,\n                                      kv_layout=layout)\n        
torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)\n\n\nclass TestPagedAttentionAlibi(TestPagedAttentionBase):\n\n    @pytest.fixture\n    def alibi_slopes(self, num_heads_q):\n        yield torch.rand(num_heads_q, dtype=torch.float32, device='cuda')\n\n    @pytest.fixture\n    def mask(self, seq_lens, history_lens, alibi_slopes):\n        neg_val = -1e30\n        yield _make_alibi_bias(seq_lens, history_lens, neg_val, alibi_slopes)\n\n    @pytest.mark.parametrize('feat_dim', [128], indirect=True)\n    @pytest.mark.parametrize('feat_dim_v', [128], indirect=True)\n    @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(40, 8)], indirect=True)\n    @pytest.mark.parametrize('history_lens', [(52, 40, 32, 20)], indirect=True)\n    @pytest.mark.parametrize('layout', ['bshd'], indirect=True)\n    @pytest.mark.parametrize('block_size', [16], indirect=True)\n    def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, layout, alibi_slopes, conti_gt):\n        from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache\n\n        blocked_k, blocked_v = blocked_kv\n\n        out = flash_attn_with_kvcache(conti_q,\n                                      blocked_k,\n                                      blocked_v,\n                                      page_table=block_offsets,\n                                      cache_seqlens=kv_seqlens,\n                                      alibi_slopes=alibi_slopes,\n                                      kv_layout=layout)\n        torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)\n"
  },
  {
    "path": "tests/pytorch/kernel/test_rms_norm.py",
    "content": "import pytest\nimport torch\n\nfrom lmdeploy.utils import is_bf16_supported\n\n\ndef _bf16_mark():\n    return pytest.mark.skipif(not is_bf16_supported(), reason='bf16 not supported.')\n\n\nclass TestRMSNorm:\n\n    @pytest.fixture(autouse=True, scope='class')\n    def initialize(self):\n        seed = 42\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed(seed)\n        yield\n\n    @pytest.fixture(scope='class')\n    def dtype(self, request):\n        yield request.param\n\n    @pytest.fixture(scope='class')\n    def input_shape(self, request):\n        yield request.param\n\n    @pytest.fixture(scope='class')\n    def hidden_size(self, input_shape):\n        yield input_shape[-1]\n\n    @pytest.fixture(scope='class')\n    def input(self, dtype, input_shape):\n        yield torch.randn(input_shape, dtype=dtype, device='cuda')\n\n    @pytest.fixture(scope='class')\n    def weight(self, dtype, hidden_size):\n        yield torch.randn(hidden_size, dtype=dtype, device='cuda')\n\n    @pytest.fixture(scope='class')\n    def eps(self):\n        yield 1e-6\n\n    @pytest.fixture(scope='class')\n    def gt(self, input, weight, eps):\n        input_dtype = input.dtype\n        input = input.to(torch.float32)\n        variance = (input * input).mean(-1, keepdim=True)\n        input = input * torch.rsqrt(variance + eps)\n        return weight * input.to(input_dtype)\n\n    @pytest.mark.parametrize('input_shape', [(2, 4, 4096), (4, 4096), (4096, )], indirect=True)\n    @pytest.mark.parametrize('dtype', [pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16], indirect=True)\n    def test_rms_norm(self, input, weight, eps, gt):\n        from lmdeploy.pytorch.kernels.cuda import rms_norm\n\n        out = rms_norm(input, weight, eps)\n        torch.testing.assert_close(out, gt)\n\n    @pytest.fixture(scope='class')\n    def residual(self, dtype, input_shape):\n        yield torch.randn(input_shape, dtype=dtype, device='cuda')\n\n    
@pytest.fixture(scope='class')\n    def gt_residual(self, input, residual, weight, eps):\n\n        input = input + residual\n        out_res = input\n        input_dtype = input.dtype\n        input = input.to(torch.float32)\n        variance = (input * input).mean(-1, keepdim=True)\n        input = input * torch.rsqrt(variance + eps)\n        return weight * input.to(input_dtype), out_res\n\n    @pytest.mark.parametrize('input_shape', [(2, 4, 4096), (4, 4096), (4096, )], indirect=True)\n    @pytest.mark.parametrize('dtype', [pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16], indirect=True)\n    def test_rms_norm_residual(self, input, residual, weight, eps, gt_residual):\n        from lmdeploy.pytorch.kernels.cuda import rms_norm\n\n        out, out_res = rms_norm(input, weight, eps, residual=residual)\n        gt, gt_res = gt_residual\n        torch.testing.assert_close(out, gt)\n        torch.testing.assert_close(out_res, gt_res)\n"
  },
  {
    "path": "tests/pytorch/nn/test_embedding.py",
    "content": "import os\nimport time\n\nimport pytest\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nfrom torch import nn\n\nfrom lmdeploy.pytorch.distributed import DefaultContext\nfrom lmdeploy.pytorch.nn import ParallelEmbedding\n\n\ndef parallel_emb(rank: int, world_size: int, vocab_size: int, feat_size: int, padding_idx: int, dtype: torch.dtype,\n                 x: torch.Tensor, weight: torch.Tensor, result_queue: mp.Queue):\n    dist.init_process_group('nccl', rank=rank, world_size=world_size)\n    gpu_group = dist.new_group(ranks=list(range(world_size)), backend='nccl')\n\n    DefaultContext.attn_tp_group.rank = rank\n    DefaultContext.dist_config.attn_tp = world_size\n    DefaultContext.attn_tp_group.gpu_group = gpu_group\n\n    model = ParallelEmbedding(vocab_size=vocab_size,\n                              hidden_size=feat_size,\n                              padding_idx=padding_idx,\n                              dtype=dtype,\n                              is_tp=True,\n                              device=torch.device(type='cuda', index=rank))\n\n    weight = weight.to(torch.device(type='cuda', index=rank))\n    model.weight_loader(model.weight, weight)\n\n    input = x.to(torch.device(type='cuda', index=rank))\n\n    with torch.inference_mode():\n        out = model(input)\n\n    if rank == 0:\n        result_queue.put(mp.reductions.reduce_tensor(out))\n\n    if dist.is_initialized():\n        dist.destroy_process_group()\n\n\nclass TestEmbedding:\n\n    @pytest.fixture\n    def vocab_size(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def feat_size(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def padding_idx(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def dtype(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def tp(self, request):\n        yield request.param\n\n    @pytest.fixture\n    def seqlen(self, 
request):\n        yield request.param\n\n    @pytest.fixture\n    def weight(self, vocab_size, feat_size, dtype):\n        yield torch.rand(vocab_size, feat_size, dtype=dtype)\n\n    @pytest.fixture\n    def x(self, seqlen, vocab_size):\n        yield torch.randint(low=0, high=vocab_size, size=(seqlen, ), dtype=torch.int32)\n\n    @pytest.fixture\n    def gt(self, x, vocab_size, feat_size, padding_idx, dtype, weight):\n        token_emb = nn.Embedding(vocab_size,\n                                 feat_size,\n                                 padding_idx=padding_idx,\n                                 dtype=dtype,\n                                 device=torch.device(type='cuda', index=0))\n        token_emb.weight.data.copy_(weight)\n        token_emb._fill_padding_idx_with_zero()\n        input = x.to(torch.device(type='cuda', index=0))\n        yield token_emb(input)\n\n    @pytest.mark.parametrize('vocab_size', [65576, 65533, 3333], indirect=True)\n    @pytest.mark.parametrize('feat_size', [4096, 768], indirect=True)\n    @pytest.mark.parametrize('padding_idx', [None], indirect=True)\n    @pytest.mark.parametrize('seqlen', [1024, 1011, 128], indirect=True)\n    @pytest.mark.parametrize('tp', [2], indirect=True)\n    @pytest.mark.parametrize('dtype', [torch.bfloat16], indirect=True)\n    def test_embedding(self, vocab_size, feat_size, padding_idx, seqlen, tp, dtype, x, weight, gt):\n        os.environ['MASTER_ADDR'] = 'localhost'\n        os.environ['MASTER_PORT'] = '29500'\n        os.environ['NCCL_SOCKET_IFNAME'] = 'lo'\n\n        world_size = tp\n        processes = []\n        mp.set_start_method('spawn', force=True)\n        result_queue = mp.Queue()\n\n        for rank in range(world_size):\n            p = mp.Process(target=parallel_emb,\n                           args=(rank, world_size, vocab_size, feat_size, padding_idx, dtype, x, weight, result_queue))\n            processes.append(p)\n            p.start()\n            time.sleep(0.5)\n\n        func, 
args = result_queue.get()\n        out = func(*args)\n\n        for p in processes:\n            p.join(timeout=10)\n            if p.is_alive():\n                p.terminate()\n                p.join(timeout=5)\n                if p.is_alive():\n                    p.kill()\n\n        torch.testing.assert_close(out, gt)\n"
  },
  {
    "path": "tests/pytorch/paging/test_block_manager.py",
    "content": "# yapf: disable\nimport pytest\nimport torch\n\nfrom lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\nfrom lmdeploy.pytorch.messages import SequenceMeta\nfrom lmdeploy.pytorch.paging.block_manager.base_block_manager import LogicalAllocator\nfrom lmdeploy.pytorch.paging.scheduler import Scheduler\n\n# yapf: enable\n\n\nclass TestAllocator:\n\n    @pytest.fixture\n    def num_gpu_blocks(self):\n        yield 16\n\n    @pytest.fixture\n    def num_cpu_blocks(self):\n        yield 4\n\n    @pytest.fixture\n    def allocator(self, num_cpu_blocks, num_gpu_blocks):\n        yield LogicalAllocator(num_cpu_blocks, num_gpu_blocks)\n\n    def test_alloc(self, allocator, num_cpu_blocks, num_gpu_blocks):\n\n        # initialize\n        num_blocks = num_cpu_blocks + num_gpu_blocks\n        gpu_allocator = allocator.get_phy_allocator('gpu')\n        cpu_allocator = allocator.get_phy_allocator('cpu')\n        assert allocator.get_num_free_blocks() == num_blocks\n        assert cpu_allocator.get_num_free_blocks() == num_cpu_blocks\n        assert gpu_allocator.get_num_free_blocks() == num_gpu_blocks\n\n        # test allocate\n        block_size = 4\n        blocks = allocator.allocate(block_size, 'gpu')\n        assert len(blocks) == block_size\n        assert allocator.get_num_free_blocks() == num_blocks - block_size\n        assert gpu_allocator.get_num_free_blocks() == num_gpu_blocks - block_size\n\n        # test free\n        allocator.add_ref_count(blocks, 1)\n        allocator.free(blocks)\n        assert allocator.get_num_free_blocks() == num_blocks - block_size\n        allocator.free(blocks)\n        assert allocator.get_num_free_blocks() == num_blocks\n        assert gpu_allocator.get_num_free_blocks() == num_gpu_blocks\n        assert cpu_allocator.get_num_free_blocks() == num_cpu_blocks\n\n    def test_full(self, allocator, num_cpu_blocks, num_gpu_blocks):\n\n        num_blocks = num_cpu_blocks + num_gpu_blocks\n        gpu_allocator = 
allocator.get_phy_allocator('gpu')\n        cpu_allocator = allocator.get_phy_allocator('cpu')\n\n        # no free blocks\n        gpu_block_size = num_gpu_blocks\n        gpu_blocks = allocator.allocate(gpu_block_size, 'gpu')\n        cpu_block_size = num_cpu_blocks\n        cpu_blocks = allocator.allocate(cpu_block_size, 'cpu')\n        assert cpu_allocator.get_num_free_blocks() == 0\n        assert gpu_allocator.get_num_free_blocks() == 0\n        with pytest.raises(MemoryError):\n            allocator.allocate(1, 'gpu')\n        allocator.free(gpu_blocks)\n        allocator.free(cpu_blocks)\n        assert allocator.get_num_free_blocks() == num_blocks\n        assert gpu_allocator.get_num_free_blocks() == num_gpu_blocks\n        assert cpu_allocator.get_num_free_blocks() == num_cpu_blocks\n\n\nclass TestDefaultBlockManager:\n\n    @pytest.fixture\n    def block_size(self):\n        yield 16\n\n    @pytest.fixture\n    def num_cpu_blocks(self):\n        yield 4\n\n    @pytest.fixture\n    def num_gpu_blocks(self):\n        yield 4\n\n    @pytest.fixture\n    def max_batch_size(self):\n        yield 4\n\n    @pytest.fixture\n    def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size):\n        yield CacheConfig(max_batches=max_batch_size,\n                          block_size=block_size,\n                          num_cpu_blocks=num_cpu_blocks,\n                          num_gpu_blocks=num_gpu_blocks)\n\n    @pytest.fixture\n    def scheduler_config(self, max_batch_size):\n        yield SchedulerConfig(max_batches=max_batch_size,\n                              max_session_len=128,\n                              max_request_output_len=64,\n                              eviction_type='recompute')\n\n    @pytest.fixture\n    def seq_meta(self, block_size):\n        from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy\n        strategy = ARSequenceStrategy()\n        yield SequenceMeta(block_size, strategy=strategy)\n\n   
 @pytest.fixture\n    def scheduler(self, cache_config, scheduler_config, seq_meta):\n        yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)\n\n    @pytest.fixture\n    def block_mgr(self, scheduler):\n        yield scheduler.block_manager\n\n    def test_alloc(self, scheduler, block_mgr, num_gpu_blocks):\n        sess = scheduler.add_session(0)\n        block_size = sess.seq_meta.block_size\n\n        # test alloc\n        token_ids = torch.tensor([1])\n        msg = sess.add_sequence(token_ids)\n        assert block_mgr.can_allocate(msg)\n        block_mgr.allocate(msg)\n        block_table = block_mgr.get_block_table(msg)\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 1\n        assert block_table is not None\n        assert len(block_table) == 1\n\n        # test free\n        block_mgr.free(msg)\n        block_table = block_mgr.get_block_table(msg)\n        assert block_table is None or len(block_table) == 0\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks\n\n        # alloc over limit\n        token_ids = torch.zeros((num_gpu_blocks * block_size + 1, ), dtype=torch.int64)\n        msg = sess.add_sequence(token_ids)\n        assert not block_mgr.can_allocate(msg)\n\n    def test_num_required_blocks(self, scheduler, block_mgr):\n        from lmdeploy.pytorch.messages import InputEmbeddings\n        sess = scheduler.add_session(0)\n        block_size = sess.seq_meta.block_size\n\n        token_ids = torch.tensor([1])\n        msg = sess.add_sequence(token_ids)\n        num_required = block_mgr.num_required_blocks(msg)\n        assert num_required == 1\n\n        embedding = InputEmbeddings(None, 0, block_size * 2)\n        msg = sess.add_sequence(token_ids, input_embeddings=[embedding])\n        num_required = block_mgr.num_required_blocks(msg)\n        assert num_required == 1\n\n        token_ids = torch.tensor([1] * block_size * 3)\n        embedding = 
InputEmbeddings(None, 0, block_size * 2)\n        msg = sess.add_sequence(token_ids, input_embeddings=[embedding])\n        num_required = block_mgr.num_required_blocks(msg)\n        assert num_required == 3\n\n    def test_append_slot(self, scheduler, block_mgr, num_gpu_blocks):\n        sess = scheduler.add_session(0)\n        block_size = sess.seq_meta.block_size\n\n        # test append\n        token_ids = torch.tensor([1])\n        msg = sess.add_sequence(token_ids)\n        block_mgr.allocate(msg)\n        block_table = block_mgr.get_block_table(msg)\n        assert len(block_table) == 1\n\n        # no new logical block\n        msg.update_token_ids(torch.tensor([1] * (block_size - 1)))\n        assert block_mgr.can_allocate(msg)\n        block_mgr.allocate(msg)\n        block_table = block_mgr.get_block_table(msg)\n        assert len(block_table) == 1\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 1\n\n        # with new logical block\n        msg.update_token_ids(torch.tensor([1]))\n        block_mgr.allocate(msg)\n        block_table = block_mgr.get_block_table(msg)\n        assert len(block_table) == 2\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2\n\n    def test_swap(self, scheduler, block_mgr, num_gpu_blocks, num_cpu_blocks):\n        sess = scheduler.add_session(0)\n        block_size = sess.seq_meta.block_size\n\n        token_ids = torch.tensor([1] * (block_size + 1))\n        msg = sess.add_sequence(token_ids)\n        block_mgr.allocate(msg)\n\n        old_phy_blocks = block_mgr.get_block_table(msg)\n        success, swap_map = block_mgr.try_swap_out(msg)\n        new_phy_blocks = block_mgr.get_block_table(msg)\n        assert success\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks\n        assert block_mgr.get_num_free_cpu_blocks() == num_cpu_blocks - 2\n        assert len(swap_map) == 2\n        for block_id in old_phy_blocks:\n            assert block_id in swap_map\n        for 
block_id in new_phy_blocks:\n            assert block_id - num_gpu_blocks in swap_map.values()\n\n        old_phy_blocks = block_mgr.get_block_table(msg)\n        success, swap_map = block_mgr.try_swap_in(msg)\n        new_phy_blocks = block_mgr.get_block_table(msg)\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2\n        assert block_mgr.get_num_free_cpu_blocks() == num_cpu_blocks\n        assert len(swap_map) == 2\n        for block_id in old_phy_blocks:\n            assert block_id - num_gpu_blocks in swap_map\n        for block_id in new_phy_blocks:\n            assert block_id in swap_map.values()\n\n        success, swap_map = block_mgr.try_swap_out(msg)\n        assert success\n        token_ids = torch.tensor([1] * (block_size * 4))\n        msg_full = sess.add_sequence(token_ids)\n        block_mgr.allocate(msg_full)\n        success, swap_map = block_mgr.try_swap_out(msg)\n        assert not success\n\n\nclass TestWindowBlockManager:\n\n    @pytest.fixture\n    def window_size(self):\n        yield 32\n\n    @pytest.fixture\n    def block_size(self):\n        yield 16\n\n    @pytest.fixture\n    def num_cpu_blocks(self):\n        yield 4\n\n    @pytest.fixture\n    def num_gpu_blocks(self):\n        yield 4\n\n    @pytest.fixture\n    def max_batch_size(self):\n        yield 4\n\n    @pytest.fixture\n    def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size, window_size):\n        yield CacheConfig(max_batches=max_batch_size,\n                          block_size=block_size,\n                          num_cpu_blocks=num_cpu_blocks,\n                          num_gpu_blocks=num_gpu_blocks,\n                          window_size=window_size)\n\n    @pytest.fixture\n    def scheduler_config(self, max_batch_size):\n        yield SchedulerConfig(max_batches=max_batch_size,\n                              max_session_len=128,\n                              max_request_output_len=64,\n                          
    eviction_type='recompute')\n\n    @pytest.fixture\n    def seq_meta(self, block_size):\n        from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy\n        strategy = ARSequenceStrategy()\n        yield SequenceMeta(block_size, strategy=strategy)\n\n    @pytest.fixture\n    def scheduler(self, cache_config, scheduler_config, seq_meta):\n        yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)\n\n    @pytest.fixture\n    def block_mgr(self, scheduler):\n        yield scheduler.block_manager\n\n    def test_alloc(self, scheduler, block_mgr, num_gpu_blocks):\n        sess = scheduler.add_session(0)\n        block_size = sess.seq_meta.block_size\n\n        # test alloc\n        token_ids = torch.tensor([1])\n        msg = sess.add_sequence(token_ids)\n        assert block_mgr.can_allocate(msg)\n        block_mgr.allocate(msg)\n        block_table = block_mgr.get_block_table(msg)\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 1\n        assert block_table is not None\n        assert len(block_table) == 1\n\n        # test free\n        block_mgr.free(msg)\n        block_table = block_mgr.get_block_table(msg)\n        assert block_table is None or len(block_table) == 0\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks\n\n        # alloc over limit\n        token_ids = torch.zeros((num_gpu_blocks * block_size + 1, ), dtype=torch.int64)\n        msg = sess.add_sequence(token_ids)\n        assert not block_mgr.can_allocate(msg)\n\n    def test_win_alloc(self, scheduler, block_mgr, num_gpu_blocks, window_size):\n        sess = scheduler.add_session(0)\n\n        # tokens exactly fill the window (2 blocks)\n        token_ids = torch.tensor([1] * window_size)\n        msg = sess.add_sequence(token_ids)\n        block_mgr.allocate(msg)\n        msg.update_token_ids(torch.tensor([1]))\n        block_mgr.allocate(msg)\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3\n   
     block_table = block_mgr.get_block_table(msg)\n        assert block_table is None or len(block_table) == 3\n        block_mgr.free(msg)\n\n        # tokens already span 3 blocks\n        token_ids = torch.tensor([1] * (window_size + 2))\n        msg = sess.add_sequence(token_ids)\n        block_mgr.allocate(msg)\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3\n        msg.update_token_ids(torch.tensor([1]))\n        block_mgr.allocate(msg)\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 3\n        block_table = block_mgr.get_block_table(msg)\n        assert block_table is None or len(block_table) == 3\n        block_mgr.free(msg)\n\n        # window not yet full\n        token_ids = torch.tensor([1] * (window_size - 2))\n        msg = sess.add_sequence(token_ids)\n        block_mgr.allocate(msg)\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2\n        msg.update_token_ids(torch.tensor([1]))\n        block_mgr.allocate(msg)\n        assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2\n        block_table = block_mgr.get_block_table(msg)\n        assert block_table is None or len(block_table) == 2\n        block_mgr.free(msg)\n"
  },
  {
    "path": "tests/pytorch/paging/test_block_trie.py",
    "content": "import numpy as np\nimport pytest\n\nfrom lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\nfrom lmdeploy.pytorch.messages import SequenceMeta\nfrom lmdeploy.pytorch.paging import Scheduler\n\n\nclass TestBlockTire:\n\n    @pytest.fixture\n    def block_size(self):\n        yield 16\n\n    @pytest.fixture\n    def num_cpu_blocks(self):\n        yield 4\n\n    @pytest.fixture\n    def num_gpu_blocks(self):\n        yield 16\n\n    @pytest.fixture\n    def max_batch_size(self):\n        yield 4\n\n    @pytest.fixture\n    def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size):\n        yield CacheConfig(max_batches=max_batch_size,\n                          block_size=block_size,\n                          num_cpu_blocks=num_cpu_blocks,\n                          num_gpu_blocks=num_gpu_blocks,\n                          enable_prefix_caching=True)\n\n    @pytest.fixture\n    def scheduler_config(self, max_batch_size):\n        yield SchedulerConfig(max_batches=max_batch_size,\n                              max_session_len=128,\n                              max_request_output_len=64,\n                              eviction_type='recompute')\n\n    @pytest.fixture\n    def seq_meta(self, block_size):\n        from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy\n        strategy = ARSequenceStrategy()\n        yield SequenceMeta(block_size, strategy=strategy)\n\n    @pytest.fixture\n    def scheduler(self, cache_config, scheduler_config, seq_meta):\n        yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)\n\n    @pytest.fixture\n    def block_mgr(self, scheduler):\n        yield scheduler.block_manager\n\n    @pytest.fixture\n    def block_trie(self, scheduler):\n        yield scheduler.block_trie\n\n    def test_allocate(self, block_trie, block_mgr, scheduler):\n        allocator = block_trie.allocator\n        sess = scheduler.add_session(0)\n     
   block_size = sess.seq_meta.block_size\n        token_ids = ([1] * block_size + [2] * block_size)\n        token_ids += [3] * (block_size // 2)\n        seq = sess.add_sequence(token_ids)\n\n        # first allocate\n        block_mgr.allocate(seq)\n        block_trie.allocate(seq)\n        logical_blocks = seq.logical_blocks\n        assert len(logical_blocks) == 3\n        ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())\n        assert np.array_equal(ref_cnt, [2, 2, 1])\n        node = getattr(seq.logical_blocks, 'last_shared_node', None)\n        assert node is not None\n        assert node.num_matched == block_size * 2\n        assert np.array_equal(node.tokens, [2] * block_size)\n        assert np.array_equal(node.parent.tokens, [1] * block_size)\n        assert node in block_trie.leaves\n        assert node.parent not in block_trie.leaves\n\n        # append\n        seq.update_token_ids([4] * block_size)\n        block_mgr.allocate(seq)\n        block_trie.allocate(seq)\n        logical_blocks = seq.logical_blocks\n        assert len(logical_blocks) == 4\n        ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())\n        assert np.array_equal(ref_cnt, [2, 2, 2, 1])\n        node = getattr(seq.logical_blocks, 'last_shared_node', None)\n        assert node is not None\n        assert node.num_matched == block_size * 3\n        expect_tokens = [3] * (block_size // 2) + [4] * (block_size // 2)\n        assert np.array_equal(node.tokens, expect_tokens)\n        assert node in block_trie.leaves\n        assert len(block_trie.leaves) == 1\n\n    def test_match(self, block_trie, block_mgr, scheduler):\n        allocator = block_trie.allocator\n        sess = scheduler.add_session(0)\n        block_size = sess.seq_meta.block_size\n\n        # initialize cache\n        token_ids = ([1] * block_size + [2] * block_size)\n        token_ids += [3] * (block_size // 2)\n        seq = sess.add_sequence(token_ids)\n        
block_mgr.allocate(seq)\n        block_trie.allocate(seq)\n\n        # test1\n        token_ids = ([1] * block_size + [3] * block_size)\n        seq = sess.add_sequence(token_ids)\n        block_trie.match(seq)\n        logical_blocks = seq.logical_blocks\n        assert len(logical_blocks) == 1\n        ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())\n        assert np.array_equal(ref_cnt, [3])\n        node = getattr(seq.logical_blocks, 'last_shared_node', None)\n        assert node is not None\n        assert node.num_matched == block_size\n        assert np.array_equal(node.tokens, [1] * block_size)\n        block_mgr.allocate(seq)\n        block_trie.allocate(seq)\n        assert len(block_trie.leaves) == 2\n\n        # test2\n        token_ids = ([1] * block_size + [2] * block_size)\n        token_ids += [4] * (block_size // 2)\n        seq = sess.add_sequence(token_ids)\n        block_trie.match(seq)\n        logical_blocks = seq.logical_blocks\n        assert len(logical_blocks) == 2\n        ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())\n        assert np.array_equal(ref_cnt, [4, 3])\n\n    def test_evict(self, block_trie, scheduler, num_gpu_blocks):\n        block_mgr = block_trie.block_manager\n        sess = scheduler.add_session(0)\n        block_size = sess.seq_meta.block_size\n        token_ids = ([1] * block_size * (num_gpu_blocks - 1))\n        token_ids += [2] * (block_size // 2)\n        seq = sess.add_sequence(token_ids)\n        block_mgr.allocate(seq)\n        block_trie.allocate(seq)\n        assert block_mgr.get_num_free_gpu_blocks() == 0\n\n        # test free\n        block_mgr.free(seq)\n        seq.set_step(0)\n        assert block_mgr.get_num_free_gpu_blocks() == 1\n\n        # test evict\n        leaf = next(iter(block_trie.leaves))\n        block_trie.evict(4)\n        new_leaf = next(iter(block_trie.leaves))\n        assert leaf != new_leaf\n        assert block_mgr.get_num_free_gpu_blocks() 
== 5\n"
  },
  {
    "path": "tests/pytorch/paging/test_scheduler.py",
    "content": "import pytest\nimport torch\n\nfrom lmdeploy.pytorch.config import CacheConfig, SchedulerConfig\nfrom lmdeploy.pytorch.messages import MessageStatus, SequenceMeta\nfrom lmdeploy.pytorch.paging.scheduler import Scheduler\n\n\nclass TestScheduler:\n\n    @pytest.fixture\n    def block_size(self):\n        yield 16\n\n    @pytest.fixture\n    def num_cpu_blocks(self):\n        yield 4\n\n    @pytest.fixture\n    def num_gpu_blocks(self):\n        yield 4\n\n    @pytest.fixture\n    def max_batch_size(self):\n        yield 4\n\n    @pytest.fixture\n    def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size):\n        yield CacheConfig(max_batches=max_batch_size,\n                          block_size=block_size,\n                          num_cpu_blocks=num_cpu_blocks,\n                          num_gpu_blocks=num_gpu_blocks)\n\n    @pytest.fixture\n    def scheduler_config(self, max_batch_size):\n        yield SchedulerConfig(max_batches=max_batch_size,\n                              max_session_len=128,\n                              max_request_output_len=64,\n                              eviction_type='recompute')\n\n    @pytest.fixture\n    def seq_meta(self, block_size):\n        from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy\n        strategy = ARSequenceStrategy()\n        yield SequenceMeta(block_size, strategy=strategy)\n\n    @pytest.fixture\n    def scheduler(self, cache_config, scheduler_config, seq_meta):\n        yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)\n\n    def test_schedule_base(self, scheduler, block_size, num_gpu_blocks):\n        block_manager = scheduler.block_manager\n        session_id = 0\n        session = scheduler.add_session(session_id)\n        assert session_id in scheduler.sessions\n        assert scheduler.sessions[session_id] == session\n\n        num_blocks = 2\n        token_ids = torch.tensor([0] * block_size * 
num_blocks)\n        seq = session.add_sequence(token_ids)\n\n        assert seq.status == MessageStatus.WAITING\n        assert seq in scheduler.waiting\n\n        output = scheduler.schedule(is_prefill=True)\n        block_tables = scheduler.get_block_tables(output.running)\n\n        assert seq.status == MessageStatus.READY\n        assert seq in output.running\n        assert len(block_tables) == 1\n        assert len(block_tables[0]) == num_blocks\n        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - num_blocks\n\n        assert scheduler.has_unfinished()\n\n    def test_update(self, scheduler, block_size, num_gpu_blocks):\n        block_manager = scheduler.block_manager\n        session_id1 = 0\n        session1 = scheduler.add_session(session_id1)\n        token_ids1 = torch.tensor([0] * block_size * 1)\n        seq1 = session1.add_sequence(token_ids1)\n\n        session_id2 = 1\n        session2 = scheduler.add_session(session_id2)\n        token_ids2 = torch.tensor([0] * block_size * 2)\n        seq2 = session2.add_sequence(token_ids2)\n        token_ids3 = torch.tensor([0] * block_size * 3)\n        seq3 = session2.add_sequence(token_ids3)\n\n        scheduler.schedule(is_prefill=True)\n        assert seq1.status == MessageStatus.READY\n        assert seq2.status == MessageStatus.READY\n        assert seq3.status == MessageStatus.WAITING\n\n        # stop seq\n        seq1.state.stop()\n        assert len(scheduler.ready) == 1\n        assert seq1 in scheduler.hanging\n\n        # end seq\n        seq1.session.remove_sequence(seq1)\n        assert session_id1 in scheduler.sessions\n        assert seq1 not in scheduler.ready\n        assert seq1 not in scheduler.hanging\n        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 2\n\n        # stop session\n        scheduler.stop_session(session_id2)\n        assert len(scheduler.ready) == 0\n        assert len(scheduler.waiting) == 0\n        assert 
len(scheduler.hanging) == 2\n\n        # end session\n        scheduler.end_session(session_id2)\n        assert session_id2 not in scheduler.sessions\n        assert len(scheduler.hanging) == 0\n        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks\n\n    def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks):\n        block_manager = scheduler.block_manager\n        session_id = 0\n        session = scheduler.add_session(session_id)\n\n        # test: add 3 seq\n        token_ids1 = torch.tensor([0] * block_size * 1)\n        seq1 = session.add_sequence(token_ids1)\n        token_ids2 = torch.tensor([0] * block_size * 2)\n        seq2 = session.add_sequence(token_ids2)\n        token_ids3 = torch.tensor([0] * block_size * 3)\n        seq3 = session.add_sequence(token_ids3)\n        scheduler.schedule(is_prefill=True)\n        # seq1: 1 running gpu\n        # seq2: 2 running gpu\n        # seq3: 3 waiting empty\n        assert seq1.status == MessageStatus.READY\n        assert seq2.status == MessageStatus.READY\n        assert seq3.status == MessageStatus.WAITING\n        assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 3\n\n        # test: waiting alloc\n        seq2.state.stop()\n        assert len(scheduler.ready) == 1\n        assert len(scheduler.waiting) == 1\n        assert len(scheduler.hanging) == 1\n\n        scheduler.schedule(is_prefill=True)\n        # seq1: 1 running gpu\n        # seq2: 2 hanging cpu\n        # seq3: 3 running gpu\n        assert seq1.status == MessageStatus.READY\n        assert seq2.status == MessageStatus.STOPPED\n        assert seq3.status == MessageStatus.READY\n        assert block_manager.get_num_free_gpu_blocks() == 0\n\n        # test: waiting append token\n        seq2.state.activate()\n        seq3.session.remove_sequence(seq3)\n        seq2.update_token_ids(torch.tensor([1] * block_size))\n        assert len(scheduler.ready) == 1\n        assert 
len(scheduler.waiting) == 1\n        assert len(scheduler.hanging) == 0\n\n        scheduler.schedule(is_prefill=True)\n        # seq1: 1 running gpu\n        # seq2: 3 running gpu\n        # seq3: 3 nan\n        assert seq1.status == MessageStatus.READY\n        assert seq2.status == MessageStatus.READY\n        assert block_manager.get_num_free_gpu_blocks() == 0\n\n        # test running append\n        seq1.update_token_ids(torch.tensor([1] * block_size))\n        seq2.update_token_ids(torch.tensor([1] * block_size))\n        assert len(scheduler.ready) == 2\n        scheduler.schedule(is_prefill=False)\n        # seq1: 2 running gpu\n        # seq2: 4 waiting cpu\n        # seq3: 3 nan\n        assert seq1.status == MessageStatus.READY\n        assert seq2.status == MessageStatus.WAITING\n        assert block_manager.get_num_free_gpu_blocks() == 2\n"
  },
  {
    "path": "tests/test_lmdeploy/test_auto_backend.py",
    "content": "import os\nimport tempfile\n\nimport numpy as np\nimport pytest\n\n\nclass TestAutoBackend:\n\n    @pytest.fixture\n    def turbomind_workspace(self):\n        workspace = tempfile.TemporaryDirectory('internlm-chat-7b-turbomind').name\n        os.makedirs(os.path.join(workspace, 'triton_models'), exist_ok=True)\n        return workspace\n\n    @pytest.fixture\n    def models(self):\n        # example models to test\n        # format (model_path, is_turbomind_supported)\n        models = [\n            ('baichuan-inc/Baichuan-7B', True),\n            ('baichuan-inc/Baichuan2-7B-Chat', True),\n            ('baichuan-inc/Baichuan-13B-Chat', False),\n            ('baichuan-inc/Baichuan2-13B-Chat', False),\n            ('internlm/internlm-chat-7b', True),\n            ('internlm/internlm2-chat-7b', True),\n            ('internlm/internlm-xcomposer2-7b', True),\n            ('internlm/internlm-xcomposer-7b', False),\n            ('THUDM/chatglm2-6b', False),\n            ('THUDM/chatglm3-6b', False),\n            ('deepseek-ai/deepseek-moe-16b-chat', False),\n            ('01-ai/Yi-34B-Chat', True),\n            ('codellama/CodeLlama-7b-Instruct-hf', True),\n            ('Qwen/Qwen-7B-Chat', True),\n            ('Qwen/Qwen-VL-Chat', True),\n            ('Qwen/Qwen1.5-4B-Chat', True),\n            ('Qwen/Qwen1.5-0.5B-Chat', True),\n        ]\n        return models\n\n    def test_turbomind_is_supported(self, turbomind_workspace, models):\n        from lmdeploy.turbomind.supported_models import is_supported\n        assert is_supported(turbomind_workspace) is True\n        for m, flag in models:\n            assert is_supported(m) is flag\n\n    def test_autoget_backend(self, turbomind_workspace, models):\n        from lmdeploy.archs import autoget_backend\n        assert autoget_backend(turbomind_workspace) == 'turbomind'\n        n = len(models)\n        choices = np.random.choice(n, n // 2, replace=False)\n        for i in choices:\n            model, 
is_support_turbomind = models[i]\n            target = 'turbomind' if is_support_turbomind else 'pytorch'\n            backend = autoget_backend(model)\n            assert backend == target\n"
  },
  {
    "path": "tests/test_lmdeploy/test_content_merge.py",
    "content": "import pytest\n\nfrom lmdeploy.serve.processors import MultimodalProcessor\n\n\nclass TestMergeMessageContent:\n    \"\"\"Test suite for merge_message_content function.\"\"\"\n\n    def test_missing_content_field(self):\n        \"\"\"Test that missing content field is added with empty string.\n\n        This case occurs with assistant messages that only have tool_calls.\n        \"\"\"\n        msg = {\n            'role':\n            'assistant',\n            'tool_calls': [{\n                'id': 'chatcmpl-tool-123',\n                'type': 'function',\n                'function': {\n                    'name': 'get_weather',\n                    'arguments': '{\"city\": \"Paris\"}'\n                }\n            }]\n        }\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        assert 'content' in result\n        assert result['content'] == ''\n        assert 'tool_calls' in result\n        assert result['tool_calls'] == msg['tool_calls']\n\n    def test_explicit_none_content(self):\n        \"\"\"Test that explicit None content is converted to empty string.\n\n        This matches vLLM's behavior: None → [] → ''.join([]) → ''.\n        \"\"\"\n        msg = {\n            'role':\n            'assistant',\n            'content':\n            None,\n            'tool_calls': [{\n                'id': 'chatcmpl-tool-456',\n                'type': 'function',\n                'function': {\n                    'name': 'Bash',\n                    'arguments': '{\"command\": \"ls\"}'\n                }\n            }]\n        }\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        assert result['content'] == ''\n        assert 'tool_calls' in result\n\n    def test_string_content_unchanged(self):\n        \"\"\"Test that string content remains unchanged.\"\"\"\n        msg = {'role': 'user', 'content': 'Hello, world!'}\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        
assert result['content'] == 'Hello, world!'\n        assert result is msg  # Should return the same object\n\n    def test_single_text_block(self):\n        \"\"\"Test extraction of single text block from list content.\"\"\"\n        msg = {'role': 'user', 'content': [{'type': 'text', 'text': 'Single block'}]}\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        assert result['content'] == 'Single block'\n\n    def test_multiple_text_blocks_newline_join(self):\n        \"\"\"Test that multiple text blocks are merged with newline separator.\n\n        This matches vLLM's behavior: text_prompt = \"\\\\n\".join(texts)\n        \"\"\"\n        msg = {\n            'role':\n            'user',\n            'content': [{\n                'type': 'text',\n                'text': 'First block'\n            }, {\n                'type': 'text',\n                'text': 'Second block'\n            }, {\n                'type': 'text',\n                'text': 'Third block'\n            }]\n        }\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        assert result['content'] == 'First block\\nSecond block\\nThird block'\n\n    def test_mixed_content_types(self):\n        \"\"\"Test that only text blocks are extracted from mixed content.\n\n        Non-text blocks (like image_url) should be filtered out.\n        \"\"\"\n        msg = {\n            'role':\n            'user',\n            'content': [{\n                'type': 'text',\n                'text': 'Analyze this image:'\n            }, {\n                'type': 'image_url',\n                'image_url': {\n                    'url': 'http://example.com/img.jpg'\n                }\n            }, {\n                'type': 'text',\n                'text': 'What do you see?'\n            }]\n        }\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        assert result['content'] == 'Analyze this image:\\nWhat do you see?'\n\n    def 
test_empty_list_content(self):\n        \"\"\"Test that empty list content produces empty string.\"\"\"\n        msg = {'role': 'user', 'content': []}\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        assert result['content'] == ''\n\n    def test_list_with_non_text_blocks_only(self):\n        \"\"\"Test content with only non-text blocks (e.g., only images).\"\"\"\n        msg = {\n            'role':\n            'user',\n            'content': [{\n                'type': 'image_url',\n                'image_url': {\n                    'url': 'http://example.com/img1.jpg'\n                }\n            }, {\n                'type': 'image_url',\n                'image_url': {\n                    'url': 'http://example.com/img2.jpg'\n                }\n            }]\n        }\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        assert result['content'] == ''\n\n    def test_preserve_all_message_fields(self):\n        \"\"\"Test that all message fields are preserved during content merge.\"\"\"\n        msg = {\n            'role': 'assistant',\n            'content': [{\n                'type': 'text',\n                'text': 'Response'\n            }],\n            'tool_calls': [{\n                'id': '123',\n                'type': 'function'\n            }],\n            'name': 'assistant',\n            'custom_field': 'custom_value'\n        }\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        assert result['content'] == 'Response'\n        assert result['tool_calls'] == msg['tool_calls']\n        assert result['name'] == 'assistant'\n        assert result['custom_field'] == 'custom_value'\n        assert set(result.keys()) == set(msg.keys())\n\n    def test_text_block_with_missing_text_field(self):\n        \"\"\"Test handling of text block without 'text' field.\"\"\"\n        msg = {\n            'role':\n            'user',\n            'content': [\n                {\n        
            'type': 'text',\n                    'text': 'First'\n                },\n                {\n                    'type': 'text'\n                },  # Missing 'text' field\n                {\n                    'type': 'text',\n                    'text': 'Third'\n                }\n            ]\n        }\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        # Missing text field should be treated as empty string\n        assert result['content'] == 'First\\n\\nThird'\n\n    def test_gpt_oss_tool_call_scenario(self):\n        \"\"\"Test the specific GPT-OSS tool call scenario from the bug report.\n\n        When GPT-OSS assistant returns tool calls, content is empty/missing.\n        \"\"\"\n        msg = {\n            'role':\n            'assistant',\n            'tool_calls': [{\n                'id': 'chatcmpl-tool-UK9rkwzMAyxt9DxBezk7E2',\n                'type': 'function',\n                'function': {\n                    'name': 'Bash',\n                    'arguments': '{\"command\": \"ls\", \"description\": \"List files in current directory\"}'\n                }\n            }]\n        }\n        result = MultimodalProcessor.merge_message_content(msg)\n\n        # Should add content field with empty string\n        assert 'content' in result\n        assert result['content'] == ''\n        # Should preserve tool_calls\n        assert len(result['tool_calls']) == 1\n        assert result['tool_calls'][0]['function']['name'] == 'Bash'\n\n\n@pytest.mark.parametrize(\n    'msg,expected_content',\n    [\n        # Basic cases\n        ({\n            'role': 'user',\n            'content': 'test'\n        }, 'test'),\n        ({\n            'role': 'user',\n            'content': None\n        }, ''),\n        ({\n            'role': 'assistant'\n        }, ''),\n\n        # List content cases\n        ({\n            'role': 'user',\n            'content': [{\n                'type': 'text',\n                'text': 
'a'\n            }]\n        }, 'a'),\n        ({\n            'role': 'user',\n            'content': [{\n                'type': 'text',\n                'text': 'a'\n            }, {\n                'type': 'text',\n                'text': 'b'\n            }]\n        }, 'a\\nb'),\n\n        # Empty cases\n        ({\n            'role': 'user',\n            'content': []\n        }, ''),\n        ({\n            'role': 'user',\n            'content': [{\n                'type': 'image_url'\n            }]\n        }, ''),\n    ])\ndef test_merge_message_content_parametrized(msg, expected_content):\n    \"\"\"Parametrized test for various message content scenarios.\"\"\"\n    result = MultimodalProcessor.merge_message_content(msg)\n    assert result['content'] == expected_content\n\n\ndef test_batch_message_processing():\n    \"\"\"Test processing multiple messages in a batch (typical usage pattern).\"\"\"\n    messages = [{\n        'role': 'user',\n        'content': 'Hello'\n    }, {\n        'role': 'assistant',\n        'tool_calls': [{\n            'id': '123',\n            'type': 'function'\n        }]\n    }, {\n        'role': 'user',\n        'content': [{\n            'type': 'text',\n            'text': 'Block 1'\n        }, {\n            'type': 'text',\n            'text': 'Block 2'\n        }]\n    }]\n\n    processed = [MultimodalProcessor.merge_message_content(msg) for msg in messages]\n\n    # Verify all messages have content field\n    assert all('content' in msg for msg in processed)\n\n    # Verify content values\n    assert processed[0]['content'] == 'Hello'\n    assert processed[1]['content'] == ''\n    assert processed[2]['content'] == 'Block 1\\nBlock 2'\n\n    # Should pass model.py assertion\n    assert all(isinstance(m, dict) and 'role' in m and 'content' in m for m in processed)\n"
  },
  {
    "path": "tests/test_lmdeploy/test_grammar.py",
    "content": "import json\nimport re\n\nimport pytest\nfrom jsonschema import validate\n\nfrom lmdeploy import pipeline\nfrom lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig\n\nMODEL_IDS = [\n    'Qwen/Qwen3-0.6B',\n    'OpenGVLab/InternVL3_5-1B',\n]\n\nBACKEND_FACTORIES = [\n    ('tm', lambda: TurbomindEngineConfig(max_batch_size=2, session_len=1024)),\n    ('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),\n]\n\nSCHEMA_MAP = {\n    'json_schema': {\n        'type': 'object',\n        'properties': {\n            'name': {\n                'type': 'string'\n            },\n            'skills': {\n                'type': 'array',\n                'items': {\n                    'type': 'string',\n                    'maxLength': 10\n                },\n                'minItems': 3,\n                'maxItems': 10,\n            },\n            'work history': {\n                'type': 'array',\n                'items': {\n                    'type': 'object',\n                    'properties': {\n                        'company': {\n                            'type': 'string'\n                        },\n                        'duration': {\n                            'type': 'string'\n                        },\n                    },\n                    'required': ['company'],\n                },\n            },\n        },\n        'required': ['name', 'skills', 'work history'],\n    },\n    'regex_schema': 'call me [A-Za-z]{1,10}',\n    'json_object': None,\n}\n\n\n@pytest.mark.parametrize('model_id', MODEL_IDS)\n@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)\n@pytest.mark.parametrize('schema_type', list(SCHEMA_MAP.keys()) + [None])\ndef test_guided_matrix(model_id, backend_name, backend_factory, schema_type):\n    pipe = pipeline(\n        model_id,\n        backend_config=backend_factory(),\n        log_level='INFO',\n    )\n\n    if schema_type is None:\n       
 enable_guide = False\n    else:\n        enable_guide = True\n        response_format = {'type': schema_type}\n        schema = SCHEMA_MAP[schema_type]\n        if schema_type == 'json_schema':\n            response_format[schema_type] = dict(name='test', schema=schema)\n        elif schema_type == 'regex_schema':\n            response_format[schema_type] = schema\n\n    try:\n        if enable_guide:\n            gen_config = GenerationConfig(response_format=response_format)\n        else:\n            gen_config = GenerationConfig()\n\n        response = pipe(['Make a self introduction please.'] * 3, gen_config=gen_config)\n        assert response and response[0].text\n\n        if enable_guide:\n            if schema_type == 'json_schema':\n                validate(instance=json.loads(response[0].text), schema=schema)\n            elif schema_type == 'json_object':\n                validate(instance=json.loads(response[0].text), schema={'type': 'object', 'additionalProperties': True})\n            elif schema_type == 'regex_schema':\n                assert re.fullmatch(schema, response[0].text)\n    finally:\n        pipe.close()\n\n\n@pytest.mark.parametrize('model_id', MODEL_IDS)\n@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)\ndef test_mix_guided_matrix(model_id, backend_name, backend_factory):\n    pipe = pipeline(\n        model_id,\n        backend_config=backend_factory(),\n        log_level='INFO',\n    )\n\n    schema_type = 'json_schema'\n    response_format = {'type': schema_type}\n    schema = SCHEMA_MAP[schema_type]\n    response_format[schema_type] = dict(name='test', schema=schema)\n\n    prompts = ['Make a self introduction please.'] * 4\n    try:\n        config = GenerationConfig(response_format=response_format)\n\n        gen_config = [None if idx % 3 else config for idx in range(4)]\n\n        responses = pipe.batch_infer(prompts, gen_config=gen_config)\n\n        for resp, c in zip(responses, gen_config):\n      
      if c is None:\n                # Unguided generation: ensure we get some text, and that it does not\n                # accidentally produce JSON that conforms to the guided schema.\n                assert resp and resp.text\n                try:\n                    data = json.loads(resp.text)\n                except json.JSONDecodeError:\n                    # Not valid JSON, so it cannot conform to the schema.\n                    continue\n                else:\n                    try:\n                        validate(instance=data, schema=schema)\n                    except Exception:\n                        # JSON is present but does not satisfy the schema.\n                        continue\n                    else:\n                        pytest.fail('Unguided generation unexpectedly produced schema-conformant JSON')\n            else:\n                validate(instance=json.loads(resp.text), schema=schema)\n    finally:\n        pipe.close()\n"
  },
  {
    "path": "tests/test_lmdeploy/test_harmony_gpt_oss_parser.py",
    "content": "import collections\nimport json\nimport os\nimport sys\nimport time\nimport types\nfrom typing import Generator, List\n\nimport pytest\nimport shortuuid\n\n# Ensure local package is imported (not any site-packages installation)\nREPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\nif REPO_ROOT not in sys.path:\n    sys.path.insert(0, REPO_ROOT)\n\n\ndef _install_openai_harmony_stub():\n    \"\"\"Install a minimal stub for `openai_harmony` so the module imports\n    without the real dependency.\n\n    The GptOssChatParser test injects its own dummy parser, so the stub is sufficient.\n    \"\"\"\n    if 'openai_harmony' in sys.modules:\n        return\n    m = types.ModuleType('openai_harmony')\n\n    class HarmonyEncodingName:\n        HARMONY_GPT_OSS = 'HARMONY_GPT_OSS'\n\n    class Role:\n        ASSISTANT = 'assistant'\n\n    class StreamableParser:  # pragma: no cover - constructor only used\n\n        def __init__(self, encoding, role=None):\n            self.encoding = encoding\n            self.role = role\n\n    def load_harmony_encoding(name):  # pragma: no cover - not used in test\n        return object()\n\n    m.HarmonyEncodingName = HarmonyEncodingName\n    m.Role = Role\n    m.StreamableParser = StreamableParser\n    m.load_harmony_encoding = load_harmony_encoding\n    sys.modules['openai_harmony'] = m\n\n\nTestExpects = collections.namedtuple('TestExpects', 'func_name location')\n\n\nclass DummyParser:\n    \"\"\"A minimal stand-in for Harmony's StreamableParser with channels.\n\n    Control tokens:\n      -1: start functions.get_weather (commentary)\n      -4: start functions.get_time (commentary)\n      -6: start functions.get_weather (again)\n      -9: end current tool call, append to `messages`\n      -2: switch to final (visible) content\n      -3: switch to analysis (reasoning)\n    Other tokens are interpreted as chr(token).\n    \"\"\"\n\n    class _Msg:\n\n        def __init__(self, channel, recipient):\n   
         self.channel = channel\n            self.recipient = recipient\n\n    def __init__(self):\n        self.current_channel = None\n        self.current_recipient = None\n        self.last_content_delta = ''\n        self.messages = []\n\n    def process(self, token):\n        if token == -1:\n            self.current_channel = 'commentary'\n            self.current_recipient = 'functions.get_weather'\n            self.last_content_delta = ''\n            return\n        if token == -4:\n            self.current_channel = 'commentary'\n            self.current_recipient = 'functions.get_time'\n            self.last_content_delta = ''\n            return\n        if token == -6:\n            self.current_channel = 'commentary'\n            self.current_recipient = 'functions.get_weather'\n            self.last_content_delta = ''\n            return\n        if token == -9:\n            if self.current_channel == 'commentary' and self.current_recipient and self.current_recipient.startswith(\n                    'functions.'):\n                self.messages.append(self._Msg(self.current_channel, self.current_recipient))\n            # reset recipient to signal end of current tool call\n            self.current_recipient = None\n            self.current_channel = None\n            self.last_content_delta = ''\n            return\n        if token == -2:\n            self.current_channel = 'final'\n            self.current_recipient = None\n            self.last_content_delta = ''\n            return\n        if token == -3:\n            self.current_channel = 'analysis'\n            self.current_recipient = None\n            self.last_content_delta = ''\n            return\n        # regular character token\n        self.last_content_delta = chr(token)\n\n\ndef _chat_completion_v1(request, token_chunks: List[List[int]]):\n    from lmdeploy.serve.openai.harmony_utils import GptOssChatParser\n    from lmdeploy.serve.openai.protocol import (ChatCompletionResponse, 
ChatCompletionResponseChoice,\n                                                ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,\n                                                UsageInfo)\n\n    request_id = f'chat-{shortuuid.random()}'\n    created_time = int(time.time())\n    model_name = request.model\n\n    parser = GptOssChatParser()\n    parser.parser = DummyParser()\n\n    if request.stream:\n\n        def completion_stream_generator() -> Generator['ChatCompletionStreamResponse', None, None]:\n            finish_reason = 'stop'\n            for chunk in token_chunks:\n                delta_message = parser.parse_streaming(chunk)\n                choice_data = ChatCompletionResponseStreamChoice(index=0,\n                                                                 delta=delta_message,\n                                                                 finish_reason=finish_reason,\n                                                                 logprobs=None)\n                response = ChatCompletionStreamResponse(id=request_id,\n                                                        created=created_time,\n                                                        model=model_name,\n                                                        choices=[choice_data],\n                                                        usage=None)\n                yield response\n\n        return completion_stream_generator()\n\n    # Non-stream path: parse all tokens at once using parse_full\n    tokens: List[int] = []\n    for c in token_chunks:\n        tokens.extend(c)\n    message = parser.parse_full(tokens)\n    finish_reason = 'tool_calls' if message.tool_calls else 'stop'\n    choice_data = ChatCompletionResponseChoice(index=0, message=message, finish_reason=finish_reason)\n    return ChatCompletionResponse(id=request_id,\n                                  created=created_time,\n                                  model=model_name,\n                       
           choices=[choice_data],\n                                  usage=UsageInfo())\n\n\ndef _stream_parse(request, token_chunks: List[List[int]]):\n    from lmdeploy.serve.openai.protocol import DeltaMessage\n\n    content = ''\n    reasoning_content = ''\n    tool_calls_by_index = {}\n\n    for i, stream_resp in enumerate(_chat_completion_v1(request, token_chunks)):\n        delta_message: DeltaMessage = stream_resp.choices[0].delta\n        if delta_message.content:\n            content += delta_message.content\n        if delta_message.reasoning_content:\n            reasoning_content += delta_message.reasoning_content\n        if delta_message.tool_calls:\n            for c in delta_message.tool_calls:\n                idx = c.index\n                existing_call = tool_calls_by_index.get(idx, None)\n                if not existing_call:\n                    tool_calls_by_index[idx] = c\n                    continue\n                if c.function.name:\n                    existing_call.function.name = c.function.name\n                if c.function.arguments:\n                    existing_call.function.arguments = existing_call.function.arguments or ''\n                    existing_call.function.arguments += c.function.arguments\n    # sorted list for stable order\n    tool_calls = [tool_calls_by_index[i] for i in sorted(tool_calls_by_index.keys())]\n    return content, reasoning_content, tool_calls\n\n\ndef _t(s: str) -> List[int]:\n    return [ord(c) for c in s]\n\n\n# Basic: single function call split across two chunks (bug repro scenario)\nTOKENS_SINGLE_CALL_TWO_CHUNKS = [\n    [-1] + _t('{\"location\": \"Paris'),\n    _t(', France\"}'),\n]\n\n# Multiple calls with indices and different function names\nTOKENS_TWO_CALLS_DIFFERENT_FUNCS = [\n    [-1] + _t('{\"location\": \"Berlin\"}') + [-9] + [-4] + _t('{\"city\": \"New'),\n    _t(' York\"}') + [-9],\n]\n\n# Interleaved channels: analysis, tool call, final content\nTOKENS_INTERLEAVED = [\n    [-3] + 
_t('Thinking about the weather. ') + [-1] + _t('{\"location\": \"Par'),\n    _t('is, France\"}') + [-9] + [-2] + _t('Fetching the weather now.'),\n]\n\n# Two calls, same function name, indices increment\nTOKENS_TWO_CALLS_SAME_FUNC = [\n    [-1] + _t('{\"location\": \"Tokyo\"}') + [-9],\n    [-6] + _t('{\"location\": \"Ky'),\n    _t('oto\"}') + [-9],\n]\n\n\n@pytest.mark.parametrize(('token_chunks', 'expects'), [\n    (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]),\n])\ndef test_parser_stream_basic(token_chunks: List[List[int]], expects: List[TestExpects]):\n    from lmdeploy.serve.openai.protocol import ChatCompletionRequest\n\n    _install_openai_harmony_stub()\n    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)\n    content, reasoning_content, tool_calls = _stream_parse(request, token_chunks)\n\n    assert len(tool_calls) == len(expects)\n    for parsed_call, expected_call in zip(tool_calls, expects):\n        assert parsed_call.function.name == expected_call.func_name\n        args = json.loads(parsed_call.function.arguments)\n        assert args['location'] == expected_call.location\n    assert content.strip() == ''\n    assert (reasoning_content or '').strip() == ''\n\n\ndef test_parser_stream_multiple_calls_indices():\n    from lmdeploy.serve.openai.protocol import ChatCompletionRequest\n\n    _install_openai_harmony_stub()\n    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)\n    content, reasoning_content, tool_calls = _stream_parse(request, TOKENS_TWO_CALLS_DIFFERENT_FUNCS)\n\n    assert len(tool_calls) == 2\n    # tool_calls sorted by index ensures stable order\n    tc0, tc1 = tool_calls\n    assert tc0.index == 0 and tc1.index == 1\n    assert tc0.function.name == 'get_weather'\n    assert json.loads(tc0.function.arguments)['location'] == 'Berlin'\n    assert tc1.function.name == 'get_time'\n    assert json.loads(tc1.function.arguments)['city'] == 'New York'\n    
assert (content or '').strip() == ''\n    assert (reasoning_content or '').strip() == ''\n\n\ndef test_parser_stream_interleaved_channels():\n    from lmdeploy.serve.openai.protocol import ChatCompletionRequest\n\n    _install_openai_harmony_stub()\n    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)\n    content, reasoning_content, tool_calls = _stream_parse(request, TOKENS_INTERLEAVED)\n\n    assert json.loads(tool_calls[0].function.arguments)['location'] == 'Paris, France'\n    assert reasoning_content == 'Thinking about the weather. '\n    assert content == 'Fetching the weather now.'\n\n\n@pytest.mark.parametrize(('token_chunks', 'expects'), [\n    (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'),\n                                  TestExpects('get_weather', 'Kyoto')]),\n])\ndef test_parser_stream_two_calls_same_func(token_chunks: List[List[int]], expects: List[TestExpects]):\n    from lmdeploy.serve.openai.protocol import ChatCompletionRequest\n\n    _install_openai_harmony_stub()\n    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)\n    _, _, tool_calls = _stream_parse(request, token_chunks)\n\n    assert len(tool_calls) == len(expects)\n    for parsed_call, expected_call in zip(tool_calls, expects):\n        assert parsed_call.function.name == expected_call.func_name\n        args = json.loads(parsed_call.function.arguments)\n        assert args['location'] == expected_call.location\n\n\ndef test_open_tool_call_no_args():\n    from lmdeploy.serve.openai.protocol import ChatCompletionRequest\n\n    _install_openai_harmony_stub()\n    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)\n    content, reasoning_content, tool_calls = _stream_parse(request, [[-1]])\n\n    assert len(tool_calls) == 1\n    assert tool_calls[0].function.name == 'get_weather'\n    assert (tool_calls[0].function.arguments or '') == ''\n    assert (content or '') == ''\n    assert 
(reasoning_content or '') == ''\n\n\n@pytest.mark.parametrize(('token_chunks', 'expects'), [\n    (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]),\n    (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'),\n                                  TestExpects('get_weather', 'Kyoto')]),\n])\ndef test_parser_nonstream(token_chunks: List[List[int]], expects: List[TestExpects]):\n    from lmdeploy.serve.openai.protocol import ChatCompletionRequest\n\n    _install_openai_harmony_stub()\n    resp = _chat_completion_v1(ChatCompletionRequest(model='gpt-oss', messages=[], stream=False), token_chunks)\n\n    assert len(resp.choices) == 1\n    first_message = resp.choices[0].message\n    assert first_message.content is None\n    assert (first_message.reasoning_content or '') == ''\n    assert len(first_message.tool_calls) == len(expects)\n    for parsed_call, expected_call in zip(first_message.tool_calls, expects):\n        assert parsed_call.function.name == expected_call.func_name\n        args = json.loads(parsed_call.function.arguments)\n        assert args['location'] == expected_call.location\n"
  },
  {
    "path": "tests/test_lmdeploy/test_lite/test_quantization/test_utils/test_cal_qparams.py",
    "content": "# yapf: disable\nimport torch\n\nfrom lmdeploy.lite.utils import (cal_qparams_per_channel_absmax, cal_qparams_per_channel_minmax,\n                                 cal_qparams_per_group_absmax, cal_qparams_per_group_minmax,\n                                 cal_qparams_per_tensor_absmax, cal_qparams_per_tensor_minmax)\n\n\n# yapf: enable\ndef test_cal_qparams():\n    \"\"\"Test function for quantization parameter calculation.\"\"\"\n\n    # Create a dummy tensor\n    w = torch.randn(64, 64)\n\n    # Test per-channel absmax method\n    qparams = cal_qparams_per_channel_absmax(w, 8)\n    assert qparams.scales.shape == (64, 1)\n    assert qparams.zero_points is None\n\n    # Test per-channel minmax method\n    qparams = cal_qparams_per_channel_minmax(w, 8)\n    assert qparams.scales.shape == (64, 1)\n    assert qparams.zero_points.shape == (64, 1)\n\n    # Test per-group absmax method\n    qparams = cal_qparams_per_group_absmax(w, 8, 16)\n    assert qparams.scales.shape == (64, 4, 1)\n    assert qparams.zero_points is None\n\n    # Test per-group minmax method\n    qparams = cal_qparams_per_group_minmax(w, 8, 16)\n    assert qparams.scales.shape == (64, 4, 1)\n    assert qparams.zero_points.shape == (64, 4, 1)\n\n    # Test per-tensor absmax method\n    qparams = cal_qparams_per_tensor_absmax(w, 8)\n    assert qparams.scales.shape == ()\n    assert qparams.zero_points is None\n\n    # Test per-tensor minmax method\n    qparams = cal_qparams_per_tensor_minmax(w, 8)\n    assert qparams.scales.shape == ()\n    assert qparams.zero_points.shape == ()\n"
  },
  {
    "path": "tests/test_lmdeploy/test_messages.py",
    "content": "from typing import List\n\nimport pytest\n\nfrom lmdeploy import GenerationConfig, Tokenizer\nfrom lmdeploy.utils import get_hf_gen_cfg\n\n\ndef test_engine_generation_config():\n    tokenizer = Tokenizer('internlm/internlm-chat-7b')\n    config = GenerationConfig(n=3, stop_words=['<eoa>'])\n    stop_token_ids = tokenizer.encode('<eoa>', add_bos=False)\n    config.convert_stop_bad_words_to_ids(tokenizer)\n    assert stop_token_ids == config.stop_token_ids\n    assert isinstance(config.stop_token_ids, List) and \\\n        isinstance(config.stop_token_ids[0], int)\n\n\n@pytest.mark.parametrize('model_path', [\n    'deepseek-ai/DeepSeek-V3',\n    'Qwen/Qwen2.5-32B-Instruct',\n    'internlm/internlm3-8b-instruct',\n])\ndef test_update_from_hf_gen_cfg(model_path):\n    tokenizer = Tokenizer(model_path)\n    model_cfg = get_hf_gen_cfg(model_path)\n\n    generation_config = GenerationConfig()\n    generation_config.update_from_hf_gen_cfg(model_cfg, tokenizer.eos_token_id)\n    assert generation_config.stop_token_ids is not None\n"
  },
  {
    "path": "tests/test_lmdeploy/test_model.py",
    "content": "import pytest\n\nfrom lmdeploy.model import MODELS\n\nHF_MODELS_WITH_CHAT_TEMPLATES = [\n    'Qwen/Qwen1.5-7B-Chat',\n    'Qwen/Qwen2.5-7B-Instruct',\n    'Qwen/Qwen3-8B',\n    'Qwen/QwQ-32B',\n    'Qwen/QwQ-32B-Preview',\n    'Qwen/QwQ-32B-AWQ',\n    'Qwen/Qwen2.5-VL-7B-Instruct',\n    'Qwen/Qwen2-VL-7B-Instruct',\n    'internlm/internlm2-chat-7b',\n    'internlm/internlm2_5-7b-chat',\n    'internlm/internlm3-8b-instruct',\n    # 'internlm/Intern-S1',\n    # 'internlm/Intern-S1-mini',\n    'OpenGVLab/InternVL-Chat-V1-2',\n    'OpenGVLab/InternVL-Chat-V1-5',\n    'OpenGVLab/Mini-InternVL-Chat-2B-V1-5',\n    'OpenGVLab/InternVL2-2B',\n    'OpenGVLab/InternVL2-4B',\n    'OpenGVLab/InternVL2-8B',\n    'OpenGVLab/InternVL2_5-2B',\n    'OpenGVLab/InternVL2_5-4B',\n    'OpenGVLab/InternVL2_5-8B',\n    'OpenGVLab/InternVL3-2B',\n    'OpenGVLab/InternVL3-8B',\n    'OpenGVLab/InternVL3-9B',\n    'OpenGVLab/InternVL3_5-1B',\n    'OpenGVLab/InternVL3_5-4B',\n    'OpenGVLab/InternVL3_5-8B',\n    'OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview',\n    'deepseek-ai/DeepSeek-V2-Lite',\n    'deepseek-ai/DeepSeek-V3',\n    'deepseek-ai/DeepSeek-R1',\n    'deepseek-ai/DeepSeek-R1-Zero',\n    'deepseek-ai/DeepSeek-V3.1',\n    'deepseek-ai/deepseek-coder-1.3b-instruct',\n    'deepseek-ai/DeepSeek-R1-Distill-Llama-8B',\n    'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',\n    'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',\n    'zai-org/chatglm3-6b',\n    'zai-org/glm-4-9b-chat',\n    'zai-org/codegeex4-all-9b',\n    'zai-org/cogvlm2-llama3-chat-19B',\n    'microsoft/Phi-3-mini-128k-instruct',\n    'microsoft/Phi-3-vision-128k-instruct',\n    'microsoft/Phi-3.5-mini-instruct',\n    'microsoft/Phi-3.5-vision-instruct',\n    'microsoft/Phi-3.5-MoE-instruct',\n    '01-ai/Yi-1.5-34B-Chat',\n    # Accessing the following models is supposed to be authenticated\n    # 'openbmb/MiniCPM-V-2_6',\n    # 'google/gemma-3-4b-it',\n]\n\n\n@pytest.mark.parametrize('model_path', 
HF_MODELS_WITH_CHAT_TEMPLATES)\ndef test_HFChatTemplate_get_prompt_sequence_start_True(model_path):\n    model = MODELS.get('hf')(model_path=model_path)\n    prompt = 'How to apply chat template using transformers?'\n    messages = [{'role': 'user', 'content': prompt}]\n\n    from transformers import AutoTokenizer\n    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    expected = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n    assert model.get_prompt(prompt, sequence_start=True) == expected\n\n\n@pytest.mark.parametrize('model_path', HF_MODELS_WITH_CHAT_TEMPLATES)\ndef test_HFChatTemplate_message2prompt_sequence_start_True(model_path):\n    model = MODELS.get('hf')(model_path=model_path)\n    prompt = 'How to apply chat template using transformers?'\n    messages = [{'role': 'user', 'content': prompt}]\n\n    from transformers import AutoTokenizer\n    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    expected = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n    assert model.messages2prompt(prompt, sequence_start=True) == expected\n    assert model.messages2prompt(messages, sequence_start=True) == expected\n\n\ndef test_base_model():\n    model = MODELS.get('internlm')(capability='completion')\n    assert model.capability == 'completion'\n    assert model.get_prompt('hi') == 'hi'\n    assert model.messages2prompt('test') == 'test'\n\n\ndef test_vicuna():\n    prompt = 'hello, can u introduce yourself'\n    model = MODELS.get('vicuna')(capability='completion')\n    assert model.get_prompt(prompt, sequence_start=True) == prompt\n    assert model.get_prompt(prompt, sequence_start=False) == prompt\n\n    model = MODELS.get('vicuna')(capability='chat', system='Provide answers in Python')\n    assert model.get_prompt(prompt, sequence_start=True) != prompt\n    assert model.get_prompt(prompt, sequence_start=False) != prompt\n    
assert model.system == 'Provide answers in Python'\n\n    model = MODELS.get('vicuna')(capability='voice')\n    _prompt = None\n    with pytest.raises(AssertionError):\n        _prompt = model.get_prompt(prompt, sequence_start=True)\n        assert _prompt is None\n\n\ndef test_prefix_response():\n    model = MODELS.get('hf')(model_path='Qwen/Qwen3-8B')\n    messages = [dict(role='assistant', content='prefix test')]\n    prompt = model.messages2prompt(messages)\n    assert prompt[-len('prefix test'):] == 'prefix test'\n\n\ndef test_internlm_chat():\n    prompt = 'hello, can u introduce yourself'\n    model = MODELS.get('internlm')(capability='completion')\n    assert model.get_prompt(prompt, sequence_start=True) == prompt\n    assert model.get_prompt(prompt, sequence_start=False) == prompt\n    assert model.stop_words is not None\n    assert model.system == '<|System|>:'\n\n    model = MODELS.get('internlm')(capability='chat', system='Provide answers in Python')\n    assert model.get_prompt(prompt, sequence_start=True) != prompt\n    assert model.get_prompt(prompt, sequence_start=False) != prompt\n    assert model.system == 'Provide answers in Python'\n\n    model = MODELS.get('internlm')(capability='voice')\n    _prompt = None\n    with pytest.raises(AssertionError):\n        _prompt = model.get_prompt(prompt, sequence_start=True)\n        assert _prompt is None\n\n\ndef test_baichuan():\n    prompt = 'hello, can u introduce yourself'\n    model = MODELS.get('baichuan2')(capability='completion')\n    assert model.get_prompt(prompt, sequence_start=True) == prompt\n    assert model.get_prompt(prompt, sequence_start=False) == prompt\n    assert model.stop_words is None\n\n    model = MODELS.get('baichuan2')(capability='chat')\n    _prompt = model.get_prompt(prompt, sequence_start=True)\n    assert _prompt == '<reserved_106>' + prompt + '<reserved_107>'\n\n\ndef test_llama2():\n    prompt = 'hello, can u introduce yourself'\n    model = 
MODELS.get('llama2')(capability='completion')\n    assert model.get_prompt(prompt, sequence_start=True) == prompt\n    assert model.get_prompt(prompt, sequence_start=False) == prompt\n    assert model.stop_words is None\n    assert model.meta_instruction is not None\n\n    model = MODELS.get('llama2')(capability='chat', meta_instruction='Provide answers in Python')\n    assert model.get_prompt(prompt, sequence_start=True) != prompt\n    assert model.get_prompt(prompt, sequence_start=False) != prompt\n    assert model.meta_instruction == 'Provide answers in Python'\n\n    model = MODELS.get('llama2')(capability='voice')\n    _prompt = None\n    with pytest.raises(AssertionError):\n        _prompt = model.get_prompt(prompt, sequence_start=True)\n        assert _prompt is None\n\n\ndef test_codellama_completion():\n    model = MODELS.get('codellama')(capability='completion')\n    prompt = \"\"\"\\\nimport socket\n\ndef ping_exponential_backoff(host: str):\"\"\"\n    assert model.get_prompt(prompt) == prompt\n    assert model.get_prompt(prompt, sequence_start=False) == prompt\n    assert model.stop_words is None\n\n\ndef test_codellama_infilling():\n    model = MODELS.get('codellama')(capability='infilling')\n    prompt = '''def remove_non_ascii(s: str) -> str:\n    \"\"\" <FILL>\n    return result\n'''\n    _prompt = model.get_prompt(prompt)\n    assert _prompt.find('<FILL>') == -1\n    assert model.stop_words == ['<EOT>']\n\n    model = MODELS.get('codellama')(capability='infilling', suffix_first=True)\n    _prompt = model.get_prompt(prompt)\n    assert _prompt.find('<FILL>') == -1\n\n\ndef test_codellama_chat():\n    model = MODELS.get('codellama')(capability='chat', system='Provide answers in Python')\n    prompt = 'Write a function that computes the set of sums of all contiguous sublists of a given list.'  
# noqa: E501\n    _prompt = model.get_prompt(prompt, sequence_start=True)\n    assert _prompt.find('Provide answers in Python') != -1\n\n    _prompt = model.get_prompt(prompt, sequence_start=False)\n    assert _prompt.find('Provide answers in Python') == -1\n    assert model.stop_words is None\n\n\ndef test_codellama_python_specialist():\n    model = MODELS.get('codellama')(capability='python')\n    prompt = \"\"\"\n    def remove_non_ascii(s: str) -> str:\n\"\"\"\n    assert model.get_prompt(prompt, sequence_start=True) == prompt\n    assert model.get_prompt(prompt, sequence_start=False) == prompt\n    assert model.stop_words is None\n\n\ndef test_codellama_others():\n    model = None\n    with pytest.raises(AssertionError):\n        model = MODELS.get('codellama')(capability='java')\n    assert model is None\n\n\n@pytest.mark.parametrize(\n    'model_path_or_name',\n    ['deepseek-ai/deepseek-vl2-tiny', 'deepseek-ai/deepseek-vl2-small', 'deepseek-ai/deepseek-vl2'])\ndef test_deepseek_vl2(model_path_or_name):\n    chat_template = MODELS.get('deepseek-vl2')()\n    messages = [{\n        'role': 'user',\n        'content': 'This is image_1: <image>\\n'\n        'This is image_2: <image>\\n'\n        'This is image_3: <image>\\n Can you tell me what are in the images?',\n        'images': [\n            'images/multi_image_1.jpeg',\n            'images/multi_image_2.jpeg',\n            'images/multi_image_3.jpeg',\n        ],\n    }, {\n        'role': 'assistant',\n        'content': ''\n    }]\n\n    ref = '<|User|>: This is image_1: <image>\\nThis is image_2: <image>\\nThis is image_3: <image>' + \\\n          '\\n Can you tell me what are in the images?\\n\\n<|Assistant|>:'\n    lm_res = chat_template.messages2prompt(messages)\n    assert ref == lm_res\n\n\n@pytest.mark.parametrize('model_path', ['Qwen/Qwen3-30B-A3B', 'Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen3.5-35B-A3B'])\n@pytest.mark.parametrize('enable_thinking', [True, False, None])\ndef test_qwen3(model_path, 
enable_thinking):\n    from transformers import AutoTokenizer\n\n    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    chat_template = MODELS.get('hf')(model_path)\n\n    messages = [{\n        'role': 'system',\n        'content': 'you are a helpful assistant'\n    }, {\n        'role': 'user',\n        'content': 'who are you'\n    }, {\n        'role': 'assistant',\n        'content': 'I am an AI'\n    }, {\n        'role': 'user',\n        'content': 'AGI is?'\n    }]\n    if enable_thinking is None:\n        ref = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n    else:\n        ref = tokenizer.apply_chat_template(messages,\n                                            tokenize=False,\n                                            add_generation_prompt=True,\n                                            enable_thinking=enable_thinking)\n    lm_res = chat_template.messages2prompt(messages, enable_thinking=enable_thinking)\n    assert ref == lm_res\n\n\n# TODO(lvhan): bring this case back when internlm/Intern-S1 fix tokenizer\n# @pytest.mark.parametrize('model_path', ['internlm/Intern-S1'])\n# @pytest.mark.parametrize('enable_thinking', [None, True, False])\n# @pytest.mark.parametrize('has_user_sys', [True, False])\n# def test_interns1(model_path, enable_thinking, has_user_sys):\n#     from transformers import AutoTokenizer\n#     try:\n#         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n#     except OSError:\n#         pytest.skip(reason=f'{model_path} not exists')\n\n#     chat_template = MODELS.get('hf')(model_path)\n\n#     messages = [{\n#         'role': 'system',\n#         'content': 'you are a helpful assistant'\n#     }, {\n#         'role': 'user',\n#         'content': 'who are you'\n#     }, {\n#         'role': 'assistant',\n#         'content': 'I am an AI'\n#     }, {\n#         'role': 'user',\n#         'content': 'AGI is?'\n#     }]\n#     
if not has_user_sys:\n#         messages = messages[1:]\n\n#     if enable_thinking is None:\n#         ref = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n#     else:\n#         ref = tokenizer.apply_chat_template(messages,\n#                                             tokenize=False,\n#                                             add_generation_prompt=True,\n#                                             enable_thinking=enable_thinking)\n#     lm_res = chat_template.messages2prompt(messages, enable_thinking=enable_thinking)\n#     assert ref == lm_res\n\n\n@pytest.mark.parametrize('model_path', ['Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen3-8B'])\ndef test_HFChatTemplate_get_prompt_sequence_start_False_Qwen(model_path):\n    model = MODELS.get('hf')(model_path=model_path)\n    assert model.stop_words == ['<|im_end|>']\n\n    prompt = 'How to apply chat template using transformers?'\n    assert model.get_prompt(prompt,\n                            sequence_start=False) == f'<|im_start|>user\\n{prompt}<|im_end|>\\n<|im_start|>assistant\\n'\n\n\n@pytest.mark.parametrize('model_path', ['Qwen/Qwen3.5-35B-A3B'])\ndef test_HFChatTemplate_get_prompt_sequence_start_False_Qwen3_5(model_path):\n    model = MODELS.get('hf')(model_path=model_path)\n    assert model.stop_words == ['<|im_end|>']\n\n    prompt = 'How to apply chat template using transformers?'\n    assert model.get_prompt(\n        prompt, sequence_start=False) == f'<|im_start|>user\\n{prompt}<|im_end|>\\n<|im_start|>assistant\\n<think>\\n'\n\n\n@pytest.mark.parametrize('model_path', ['deepseek-ai/DeepSeek-V3'])\ndef test_HFChatTemplate_DeepSeek_V3(model_path):\n    model = MODELS.get('hf')(model_path=model_path)\n    assert model.stop_words == ['<｜end▁of▁sentence｜>']\n\n    prompt = 'How to apply chat template using transformers?'\n    assert model.get_prompt(prompt, sequence_start=False) == 
f'<｜User｜>{prompt}<｜Assistant｜>'\n\n\n@pytest.mark.parametrize('model_path', ['deepseek-ai/DeepSeek-R1'])\ndef test_HFChatTemplate_DeepSeek_thinking(model_path):\n    model = MODELS.get('hf')(model_path=model_path)\n    assert model.stop_words == ['<｜end▁of▁sentence｜>']\n\n    prompt = 'How to apply chat template using transformers?'\n    assert model.get_prompt(prompt, sequence_start=False) == f'<｜User｜>{prompt}<｜Assistant｜><think>\\n'\n\n\n@pytest.mark.parametrize('model_path', ['Qwen/Qwen3-VL-8B-Instruct', 'Qwen/Qwen3.5-35B-A3B'])\ndef test_HFChatTemplate_Qwen3_VL_with_vision_id(model_path):\n    model = MODELS.get('hf')(model_path=model_path)\n\n    # testcase from https://github.com/QwenLM/Qwen3-VL\n    messages = [\n        {\n            'role': 'user',\n            'content': [{\n                'type': 'image'\n            }, {\n                'type': 'text',\n                'text': 'Hello, how are you?'\n            }],\n        },\n        {\n            'role': 'assistant',\n            'content': \"I'm doing well, thank you for asking. 
How can I assist you today?\",\n        },\n        {\n            'role':\n            'user',\n            'content': [\n                {\n                    'type': 'text',\n                    'text': 'Can you describe these images and video?'\n                },\n                {\n                    'type': 'image'\n                },\n                {\n                    'type': 'image'\n                },\n                {\n                    'type': 'video'\n                },\n                {\n                    'type': 'text',\n                    'text': 'These are from my vacation.'\n                },\n            ],\n        },\n        {\n            'role':\n            'assistant',\n            'content':\n            \"\"\"I'd be happy to describe the images and video for you.\n                Could you please provide more context about your vacation?\"\"\",\n        },\n        {\n            'role': 'user',\n            'content': 'It was a trip to the mountains. 
Can you see the details in the images and video?',\n        },\n    ]\n\n    from transformers import AutoTokenizer\n    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    expected = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, add_vision_id=True)\n    chat_template_kwargs = dict(add_vision_id=True)\n    lm_res = model.messages2prompt(messages, **chat_template_kwargs)\n    assert expected == lm_res\n\n\n@pytest.mark.parametrize('model_path', ['google/gemma-2-9b-it', 'google/gemma-3-12b-it'])\ndef test_gemma_chat_template(model_path):\n    messages = [{'role': 'user', 'content': 'who are you'}]\n\n    from transformers import AutoTokenizer\n    tokenizer = AutoTokenizer.from_pretrained(model_path)\n    expected = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n\n    model = MODELS.get('hf')(model_path=model_path)\n    lm_res = model.messages2prompt(messages)\n    assert expected == lm_res\n\n    messages += [{'role': 'assistant', 'content': 'I am an AI'}, {'role': 'user', 'content': 'AGI is?'}]\n    lm_res = model.messages2prompt(messages, sequence_start=False)\n    assert lm_res == \"\"\"<start_of_turn>user\nwho are you<end_of_turn>\n<start_of_turn>model\nI am an AI<end_of_turn>\n<start_of_turn>user\nAGI is?<end_of_turn>\n<start_of_turn>model\n\"\"\"\n"
  },
  {
    "path": "tests/test_lmdeploy/test_pipeline.py",
    "content": "import gc\n\nimport pytest\nimport torch\n\nfrom lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline\nfrom lmdeploy.messages import Response\n\nMODEL_ID = 'Qwen/Qwen3-8B'\n\n\n@pytest.mark.parametrize('backend', ['pytorch', 'turbomind'], scope='class')\nclass TestBackendInference:\n    \"\"\"Test class grouping all tests for each backend.\"\"\"\n\n    @pytest.fixture(scope='class', autouse=True)\n    def backend_config(self, backend):\n        \"\"\"Parametrized backend configuration for all tests.\"\"\"\n\n        if backend == 'pytorch':\n            return PytorchEngineConfig(session_len=4096, max_batch_size=4, tp=1)\n        elif backend == 'turbomind':\n            return TurbomindEngineConfig(session_len=4096, max_batch_size=4, tp=1)\n        else:\n            raise ValueError(f'Unknown backend type: {backend}')\n\n    @pytest.fixture(scope='class', autouse=True)\n    def pipe(self, backend_config):\n        \"\"\"Shared pipeline instance across all tests in class.\"\"\"\n        pipe = pipeline(MODEL_ID, backend_config=backend_config)\n        yield pipe\n        pipe.close()\n        del pipe\n        gc.collect()\n        if torch.cuda.is_available() and torch.cuda.device_count() > 0:\n            torch.cuda.reset_peak_memory_stats()\n            torch.cuda.synchronize()\n\n    def test_infer_single_string(self, pipe):\n        \"\"\"Test infer with single string prompt.\"\"\"\n        prompt = 'Hello, how are you?'\n        response = pipe.infer(prompt)\n\n        assert isinstance(response, Response)\n        assert hasattr(response, 'text')\n        assert hasattr(response, 'generate_token_len')\n        assert hasattr(response, 'input_token_len')\n        assert len(response.text) > 0\n\n    def test_infer_batch_strings(self, pipe):\n        \"\"\"Test infer with batch of string prompts.\"\"\"\n        prompts = ['What is AI?', 'Explain quantum computing', 'Tell me a joke']\n        responses = 
pipe.infer(prompts)\n\n        assert isinstance(responses, list)\n        assert len(responses) == len(prompts)\n        for resp in responses:\n            assert isinstance(resp, Response)\n            assert len(resp.text) > 0\n\n    def test_infer_openai_format(self, pipe):\n        \"\"\"Test infer with OpenAI-style message format.\"\"\"\n        prompts = [[{\n            'role': 'user',\n            'content': 'What is machine learning?'\n        }], [{\n            'role': 'user',\n            'content': 'Define deep learning'\n        }]]\n        responses = pipe.infer(prompts)\n\n        assert len(responses) == 2\n        for resp in responses:\n            assert isinstance(resp, Response)\n\n    def test_infer_with_generation_config(self, pipe):\n        \"\"\"Test infer with custom GenerationConfig.\"\"\"\n        gen_config = GenerationConfig(max_new_tokens=50, temperature=0.5, top_p=0.9, top_k=40, do_sample=True)\n        prompt = 'Write a haiku about nature'\n        response = pipe.infer(prompt, gen_config=gen_config)\n\n        assert isinstance(response, Response)\n        assert response.generate_token_len <= 50\n\n    def test_call_method(self, pipe):\n        \"\"\"Test __call__ method as shortcut for infer.\"\"\"\n        prompt = 'What is Python?'\n        response = pipe(prompt)\n\n        assert isinstance(response, Response)\n        assert len(response.text) > 0\n\n    def test_stream_infer_single(self, pipe):\n        \"\"\"Test stream_infer with single prompt.\"\"\"\n        prompt = 'Count from 1 to 5'\n        generator = pipe.stream_infer(prompt)\n\n        chunks = []\n        for chunk in generator:\n            chunks.append(chunk)\n            assert isinstance(chunk, Response)\n\n        assert len(chunks) > 0\n        full_text = ''.join([c.text for c in chunks])\n        assert len(full_text) > 0\n\n    def test_stream_infer_batch(self, pipe):\n        \"\"\"Test stream_infer with batch prompts.\"\"\"\n        prompts = 
['First prompt', 'Second prompt']\n        generator = pipe.stream_infer(prompts)\n\n        responses = {}\n        for chunk in generator:\n            chunks = responses.setdefault(chunk.index, [])\n            chunks.append(chunk)\n            assert isinstance(chunk, Response)\n\n        assert len(responses) == len(prompts)\n        for chunks in responses.values():\n            full_text = ''.join([c.text for c in chunks])\n            assert len(full_text) > 0\n\n    def test_stream_infer_with_session(self, pipe):\n        \"\"\"Test stream_infer with session for multi-turn context.\"\"\"\n        session = pipe.session()\n        prompt1 = 'Hello! My name is Alice.'\n        step = 0\n\n        # First turn\n        generator = pipe.stream_infer(prompts=prompt1,\n                                      sessions=session,\n                                      gen_config=GenerationConfig(max_new_tokens=30),\n                                      sequence_start=True,\n                                      sequence_end=False,\n                                      enable_thinking=False)\n        resp = None\n        for out in generator:\n            resp = resp.extend(out) if resp else out\n\n        step += resp.generate_token_len + resp.input_token_len\n\n        response1 = resp.text\n\n        assert response1\n\n        # Second turn should remember context\n        prompt2 = 'What is my name?'\n        session.step = step\n        generator = pipe.stream_infer(prompts=prompt2,\n                                      sessions=session,\n                                      gen_config=GenerationConfig(max_new_tokens=30),\n                                      sequence_start=False,\n                                      sequence_end=False,\n                                      enable_thinking=False)\n\n        resp = None\n        for out in generator:\n            resp = resp.extend(out) if resp else out\n\n        step += out.generate_token_len + 
out.input_token_len\n\n        response2 = resp.text\n\n        assert 'alice' in response2.lower()\n\n    def test_chat_streaming(self, pipe):\n        \"\"\"Test chat method with streaming output.\"\"\"\n        prompt = 'Tell me a short story'\n        session = pipe.session()\n\n        generator = pipe.chat(prompt=prompt,\n                              session=session,\n                              stream_response=True,\n                              gen_config=GenerationConfig(max_new_tokens=50))\n\n        chunks = []\n        for chunk in generator:\n            chunks.append(chunk)\n            assert isinstance(chunk, Response)\n\n        assert len(chunks) > 0\n        assert session.response is not None\n        assert session.step > 0\n\n    def test_chat_non_streaming(self, pipe):\n        \"\"\"Test chat method with non-streaming output.\"\"\"\n        prompt = 'What is 2+2?'\n        session = pipe.chat(prompt=prompt,\n                            stream_response=False,\n                            gen_config=GenerationConfig(max_new_tokens=20),\n                            enable_thinking=False)\n\n        assert session is not None\n        assert hasattr(session, 'response')\n        assert hasattr(session, 'history')\n        assert len(session.history) == 1\n        assert '4' in session.response.text or 'four' in session.response.text.lower()\n\n    def test_chat_multi_turn(self, pipe):\n        \"\"\"Test chat method with multi-turn conversation.\"\"\"\n        # First turn\n        session = pipe.chat(prompt='My favorite color is blue.',\n                            stream_response=False,\n                            gen_config=GenerationConfig(max_new_tokens=30),\n                            enable_thinking=False)\n\n        # Second turn should remember context\n        session = pipe.chat(prompt='What is my favorite color?',\n                            session=session,\n                            stream_response=False,\n                 
           gen_config=GenerationConfig(max_new_tokens=30),\n                            enable_thinking=False)\n\n        assert 'blue' in session.response.text.lower()\n        assert len(session.history) == 2\n\n    def test_session_creation(self, pipe):\n        \"\"\"Test session method to create new sessions.\"\"\"\n        session1 = pipe.session()\n        session2 = pipe.session()\n\n        assert session1 is not None\n        assert session2 is not None\n        assert session1 != session2\n\n    def test_get_ppl_single(self, pipe):\n        \"\"\"Test get_ppl with single input.\"\"\"\n        from transformers import AutoTokenizer\n        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n\n        text = 'This is a test sentence.'\n        input_ids = tokenizer.encode(text, return_tensors='pt')[0].tolist()\n\n        ppl = pipe.get_ppl(input_ids)\n\n        assert isinstance(ppl, list)\n        assert len(ppl) == 1\n        assert isinstance(ppl[0], float)\n        assert ppl[0] > 0\n\n    def test_get_ppl_batch(self, pipe):\n        \"\"\"Test get_ppl with batch inputs.\"\"\"\n        from transformers import AutoTokenizer\n        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n\n        texts = ['First text.', 'Second text.']\n        input_ids_list = [tokenizer.encode(text, return_tensors='pt')[0].tolist() for text in texts]\n\n        ppl = pipe.get_ppl(input_ids_list)\n\n        assert isinstance(ppl, list)\n        assert len(ppl) == len(texts)\n        for score in ppl:\n            assert isinstance(score, float)\n            assert score > 0\n\n    def test_stream_infer_stream_response_parameter(self, pipe):\n        \"\"\"Test stream_infer stream_response parameter.\"\"\"\n        prompt = 'Test'\n        gen = pipe.stream_infer(prompt, stream_response=True)\n        assert hasattr(gen, '__iter__')\n\n        results = list(gen)\n        assert len(results) > 0\n\n    
@pytest.mark.parametrize('max_new_tokens', [10, 50, 100])\n    def test_infer_different_max_tokens(self, pipe, max_new_tokens):\n        \"\"\"Parametrized test for different max_new_tokens values.\"\"\"\n        gen_config = GenerationConfig(max_new_tokens=max_new_tokens)\n        prompt = 'Continue: Once upon a time'\n        response = pipe.infer(prompt, gen_config=gen_config)\n\n        assert response.generate_token_len <= max_new_tokens + 5\n\n    def test_batch_infer_different_gen_configs(self, pipe):\n        \"\"\"Test batch infer with different GenerationConfig per prompt.\"\"\"\n        prompts = ['Short answer: What is AI?', 'Long answer: Explain ML']\n        gen_configs = [GenerationConfig(max_new_tokens=20), GenerationConfig(max_new_tokens=50)]\n\n        responses = pipe.infer(prompts, gen_config=gen_configs)\n\n        assert len(responses) == 2\n        assert responses[0].generate_token_len <= responses[1].generate_token_len + 10\n\n    def test_infer_zero_tokens(self, pipe):\n        \"\"\"Test infer with max_new_tokens=0 to end generation immediately\n        without producing tokens.\"\"\"\n        gen_config = GenerationConfig(max_new_tokens=0)\n        prompt = 'This prompt should not generate any response'\n        response = pipe.infer(prompt, gen_config=gen_config, enable_thinking=False)\n        assert isinstance(response, Response)\n        assert response.generate_token_len == 0\n"
  },
  {
    "path": "tests/test_lmdeploy/test_qwen3_parser.py",
    "content": "import collections\nimport json\nimport time\nfrom typing import Generator, List, Tuple, Union\n\nimport pytest\nimport shortuuid\n\nfrom lmdeploy.serve.openai.api_server import VariableInterface\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,\n                                            ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,\n                                            ChatMessage, DeltaMessage, DeltaToolCall, UsageInfo)\nfrom lmdeploy.serve.openai.reasoning_parser.qwen_qwq_reasoning_parser import QwenQwQReasoningParser\nfrom lmdeploy.serve.openai.tool_parser.qwen3_parser import Qwen3ToolParser\n\nTestExpects = collections.namedtuple('TestExpects', 'func_name location')\n\n\nclass DummyTokenizer:\n\n    def decode(self, token_ids: List[int]) -> str:\n        return ' '.join(map(str, token_ids))\n\n    def encode(self, text: str) -> List[int]:\n        return [ord(c) for c in text]\n\n\nDELTA_TEXT_SEQUENCE = [\n    '<think>',\n    '\\n',\n    '好的',\n    '，',\n    '用户',\n    '问',\n    '的是',\n    '北京',\n    '的',\n    '天气',\n    '怎么样',\n    '。',\n    '我',\n    '需要',\n    '调',\n    '用',\n    'get',\n    '_weather',\n    '这个',\n    '工具',\n    '来',\n    '获取',\n    '信息',\n    '。',\n    '首先',\n    '，',\n    '确认',\n    '用户',\n    '提供的',\n    '地点',\n    '是',\n    '北京',\n    '，',\n    '参数',\n    '正确',\n    '。',\n    '然后',\n    '检查',\n    '工具',\n    '的',\n    '参数',\n    '要求',\n    '，',\n    '只需要',\n    'location',\n    '，',\n    '类型',\n    '是',\n    '字符串',\n    '。',\n    '于是',\n    '构造',\n    '参数',\n    '对象',\n    '，',\n    '调',\n    '用',\n    '函数',\n    '，',\n    '返回',\n    '结果',\n    '。',\n    '确保',\n    '没有',\n    '遗漏',\n    '必要',\n    '参数',\n    '，',\n    '比如',\n    'location',\n    '是',\n    '必须',\n    '的',\n    '，',\n    '这里',\n    '已经',\n    '提供',\n    '，',\n    '所以',\n    '没问题',\n    '。',\n    '最后',\n    '将',\n    '结果',\n    '以',\n    '自然',\n    '语言',\n  
  '回复',\n    '用户',\n    '。\\n',\n    '</think>',\n    '\\n\\n',\n    '<tool_call>',\n    '\\n',\n    '{\"',\n    'name',\n    '\":',\n    ' \"',\n    'get',\n    '_weather',\n    '\",',\n    ' \"',\n    'arguments',\n    '\":',\n    ' {\"',\n    'location',\n    '\":',\n    ' \"',\n    '北京',\n    '\"}}\\n',\n    '</tool_call>',\n]\n\nDELTA_TEXT_SEQUENCE_MULTIPLE_CALLS = DELTA_TEXT_SEQUENCE + [\n    '\\n\\n',\n    '<tool_call>',\n    '\\n',\n    '{\"',\n    'name',\n    '\":',\n    ' \"',\n    'get',\n    '_weather',\n    '\",',\n    ' \"',\n    'arguments',\n    '\":',\n    ' {\"',\n    'location',\n    '\":',\n    ' \"',\n    '上海',\n    '\"}}\\n',\n    '</tool_call>',\n]\n\nEXPECTED_CONTENT = ''\nEXPECTED_REASONING_CONTENT = ''.join((\n    '好的，用户问的是北京的天气怎么样。我需要调用get_weather这个工具来获取信息。',\n    '首先，确认用户提供的地点是北京，参数正确。然后检查工具的参数要求，',\n    '只需要location，类型是字符串。于是构造参数对象，调用函数，返回结果。',\n    '确保没有遗漏必要参数，比如location是必须的，这里已经提供，所以没问题。',\n    '最后将结果以自然语言回复用户。',\n))\n\n\ndef _chat_completion_v1(\n        request: ChatCompletionRequest,\n        text_sequence: List[str]) -> Union[ChatCompletionResponse, Generator[ChatCompletionStreamResponse, None, None]]:\n    request_id = f'chat-{shortuuid.random()}'\n    created_time = int(time.time())\n    model_name = request.model\n    if request.stream:\n\n        def completion_stream_generator() -> Generator[ChatCompletionStreamResponse, None, None]:\n            previous_text = ''\n            current_text = ''\n            finish_reason = 'stop'\n            has_parser = VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None\n            for text in text_sequence:\n                logprobs, usage = None, None\n                delta_message = DeltaMessage(role='assistant', content=text)\n                if has_parser:\n                    current_text = current_text + text\n                if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:\n                    tool_delta = 
VariableInterface.tool_parser.extract_tool_calls_streaming(\n                        previous_text=previous_text,\n                        current_text=current_text,\n                        delta_text=delta_message.content,\n                        previous_token_ids=[],\n                        current_token_ids=[],\n                        delta_token_ids=[],\n                        request=request)\n                    if tool_delta is not None:\n                        delta_message.tool_calls = tool_delta.tool_calls\n                        delta_message.content = tool_delta.content or ''\n                if VariableInterface.reasoning_parser is not None:\n                    reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming(\n                        previous_text=previous_text,\n                        current_text=current_text,\n                        delta_text=delta_message.content,\n                        previous_token_ids=[],\n                        current_token_ids=[],\n                        delta_token_ids=[])\n                    if reasoning_delta is not None:\n                        delta_message.reasoning_content = reasoning_delta.reasoning_content\n                        delta_message.content = reasoning_delta.content or ''\n                if has_parser:\n                    previous_text = current_text\n\n                choice_data = ChatCompletionResponseStreamChoice(index=0,\n                                                                 delta=delta_message,\n                                                                 finish_reason=finish_reason,\n                                                                 logprobs=logprobs)\n                response = ChatCompletionStreamResponse(\n                    id=request_id,\n                    created=created_time,\n                    model=model_name,\n                    choices=[choice_data],\n                    usage=usage,\n       
         )\n                yield response\n\n        return completion_stream_generator()\n\n    # copied and simplified from api_server.py:chat_completions_v1\n    text = ''.join(text_sequence)\n    tool_calls = None\n    reasoning_content = None\n    finish_reason = 'stop'\n    if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:\n        tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)\n        text, tool_calls = tool_call_info.content, tool_call_info.tool_calls\n        if isinstance(tool_calls, List) and len(tool_calls):\n            if finish_reason == 'stop':\n                finish_reason = 'tool_calls'\n\n    if VariableInterface.reasoning_parser is not None:\n        reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request)\n\n    choices = []\n    choice_data = ChatCompletionResponseChoice(\n        index=0,\n        message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),\n        finish_reason=finish_reason,\n    )\n    choices.append(choice_data)\n\n    return ChatCompletionResponse(\n        id=request_id,\n        created=created_time,\n        model=model_name,\n        choices=choices,\n        usage=UsageInfo(),\n    )\n\n\ndef _stream_parse(request: ChatCompletionRequest, text_sequence: List[str]) -> Tuple[str, str, List[DeltaToolCall]]:\n    # Call parser.extract_tool_calls_streaming with delta_text specified in `DELTA_TEXT_SEQUENCE`.\n    # `current_text` and `previous_text` init values and update logic\n    # can be found in lmdeploy/serve/openai/api_server.py:455-523.\n    content = ''\n    reasoning_content = ''\n    tool_calls = {}\n\n    for stream_resp in _chat_completion_v1(request, text_sequence):\n        delta_message: DeltaMessage = stream_resp.choices[0].delta\n        if delta_message.content:\n            content += delta_message.content\n        if 
delta_message.reasoning_content:\n            reasoning_content += delta_message.reasoning_content\n        if delta_message.tool_calls:\n            for c in delta_message.tool_calls:\n                existing_call = tool_calls.get(c.id, None)\n                if not existing_call:\n                    tool_calls[c.id] = c\n                    continue\n                # merge with existing\n                if c.function.name:\n                    existing_call.function.name = c.function.name\n                if c.function.arguments:\n                    existing_call.function.arguments = existing_call.function.arguments or ''\n                    existing_call.function.arguments += c.function.arguments\n    return content, reasoning_content, list(sorted(tool_calls.values(), key=lambda x: x.index))\n\n\n@pytest.mark.parametrize(('text_sequence', 'expects'), [\n    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', '北京')]),\n    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [TestExpects('get_weather', '北京'),\n                                          TestExpects('get_weather', '上海')]),\n])\ndef test_parser_stream(text_sequence: List[str], expects: List[TestExpects]):\n    tokenizer = DummyTokenizer()\n    VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer)\n    VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer)\n    request = ChatCompletionRequest(model='qwen', messages=[], stream=True)\n    content, reasoning_content, tool_calls = _stream_parse(request, text_sequence)\n    assert len(tool_calls) == len(expects)\n    for parsed_call, expected_call in zip(tool_calls, expects):\n        assert parsed_call.function.name == expected_call.func_name\n        args = json.loads(parsed_call.function.arguments)\n        assert args['location'] == expected_call.location\n        assert content.strip() == EXPECTED_CONTENT\n        assert reasoning_content.strip() == EXPECTED_REASONING_CONTENT\n\n\n@pytest.mark.parametrize(('text_sequence', 
'expects'), [\n    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', '北京')]),\n    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [TestExpects('get_weather', '北京'),\n                                          TestExpects('get_weather', '上海')]),\n])\ndef test_parser_nonstream(text_sequence: List[str], expects: List[TestExpects]):\n    tokenizer = DummyTokenizer()\n    VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer)\n    VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer)\n    resp: ChatCompletionResponse = _chat_completion_v1(ChatCompletionRequest(model='qwen', messages=[], stream=False),\n                                                       text_sequence)\n\n    assert len(resp.choices) == 1\n    first_message = resp.choices[0].message\n    assert first_message.content is None\n    assert first_message.reasoning_content == EXPECTED_REASONING_CONTENT\n    assert len(first_message.tool_calls) == len(expects)\n    for parsed_call, expected_call in zip(first_message.tool_calls, expects):\n        assert parsed_call.function.name == expected_call.func_name\n        args = json.loads(parsed_call.function.arguments)\n        assert args['location'] == expected_call.location\n\n\ndef test_no_think_nonstream():\n    text_sequence = [\n        '你好',\n        '呀',\n        '！',\n        '✨',\n        '',\n        ' 很',\n        '高兴',\n        '见到',\n        '你',\n        '！',\n    ]\n    tokenizer = DummyTokenizer()\n    VariableInterface.tool_parser = Qwen3ToolParser(tokenizer=tokenizer)\n    VariableInterface.reasoning_parser = QwenQwQReasoningParser(tokenizer=tokenizer)\n    resp: ChatCompletionResponse = _chat_completion_v1(ChatCompletionRequest(model='qwen', messages=[], stream=False),\n                                                       text_sequence)\n\n    assert len(resp.choices) == 1\n    first_message = resp.choices[0].message\n    assert first_message.content == '你好呀！✨ 很高兴见到你！'\n    assert first_message.reasoning_content 
is None\n"
  },
  {
    "path": "tests/test_lmdeploy/test_qwen3coder_parser.py",
    "content": "import collections\nimport json\nimport time\nfrom typing import Generator, List, Tuple, Union\n\nimport pytest\nimport shortuuid\n\nfrom lmdeploy.serve.openai.api_server import VariableInterface\nfrom lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,\n                                            ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,\n                                            ChatMessage, DeltaMessage, DeltaToolCall, UsageInfo)\nfrom lmdeploy.serve.openai.tool_parser.qwen3coder_parser import Qwen3CoderToolParser\n\nTestExpects = collections.namedtuple('TestExpects', 'func_name kwargs')\n\n\nclass DummyTokenizer:\n\n    def decode(self, token_ids: List[int]) -> str:\n        return ' '.join(map(str, token_ids))\n\n    def encode(self, text: str) -> List[int]:\n        return [ord(c) for c in text]\n\n\nDELTA_TEXT_SEQUENCE = [\n    '好的，我现在帮你调用工具。\\n',\n    '<tool_call>',\n    '\\n',\n    '<function=get_wea',\n    'ther>\\n',\n    '<parameter=loca',\n    'tion>',\n    '北京</par',\n    'ameter>\\n',\n    '<parameter=uni',\n    't>celsius</parameter>\\n',\n    '</function>\\n',\n    '</tool_call>',\n]\n\nDELTA_TEXT_SEQUENCE_MULTIPLE_CALLS = DELTA_TEXT_SEQUENCE + [\n    '\\n\\n',\n    '<tool_call>',\n    '\\n<function=get_weather',\n    '>\\n',\n    '<parameter=location>上海</parameter>\\n',\n    '</function>\\n',\n    '</tool_call>',\n]\n\nEXPECTED_CONTENT = '好的，我现在帮你调用工具。'\n\n\ndef _chat_completion_v1(\n        request: ChatCompletionRequest,\n        text_sequence: List[str]) -> Union[ChatCompletionResponse, Generator[ChatCompletionStreamResponse, None, None]]:\n    request_id = f'chat-{shortuuid.random()}'\n    created_time = int(time.time())\n    model_name = request.model\n    if request.stream:\n\n        def completion_stream_generator() -> Generator[ChatCompletionStreamResponse, None, None]:\n            previous_text = ''\n            current_text = ''\n  
          finish_reason = 'stop'\n            has_parser = (VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None)\n            for text in text_sequence:\n                logprobs, usage = None, None\n                delta_message = DeltaMessage(role='assistant', content=text)\n                if has_parser:\n                    current_text = current_text + text\n                has_tool = VariableInterface.tool_parser is not None\n                if request.tool_choice != 'none' and has_tool:\n                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(\n                        previous_text=previous_text,\n                        current_text=current_text,\n                        delta_text=delta_message.content,\n                        previous_token_ids=[],\n                        current_token_ids=[],\n                        delta_token_ids=[],\n                        request=request)\n                    if tool_delta is not None:\n                        delta_message.tool_calls = tool_delta.tool_calls\n                        delta_message.content = tool_delta.content or ''\n                if VariableInterface.reasoning_parser is not None:\n                    parser = VariableInterface.reasoning_parser\n                    reasoning_delta = parser.extract_reasoning_content_streaming(previous_text=previous_text,\n                                                                                 current_text=current_text,\n                                                                                 delta_text=delta_message.content,\n                                                                                 previous_token_ids=[],\n                                                                                 current_token_ids=[],\n                                                                                 delta_token_ids=[])\n                    if reasoning_delta is 
not None:\n                        delta_message.reasoning_content = (reasoning_delta.reasoning_content)\n                        delta_message.content = reasoning_delta.content or ''\n                if has_parser:\n                    previous_text = current_text\n\n                choice_data = ChatCompletionResponseStreamChoice(index=0,\n                                                                 delta=delta_message,\n                                                                 finish_reason=finish_reason,\n                                                                 logprobs=logprobs)\n                response = ChatCompletionStreamResponse(\n                    id=request_id,\n                    created=created_time,\n                    model=model_name,\n                    choices=[choice_data],\n                    usage=usage,\n                )\n                yield response\n\n        return completion_stream_generator()\n\n    text = ''.join(text_sequence)\n    tool_calls = None\n    reasoning_content = None\n    finish_reason = 'stop'\n    has_tool = VariableInterface.tool_parser is not None\n    if request.tool_choice != 'none' and has_tool:\n        tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)\n        text, tool_calls = tool_call_info.content, tool_call_info.tool_calls\n        if isinstance(tool_calls, List) and len(tool_calls):\n            if finish_reason == 'stop':\n                finish_reason = 'tool_calls'\n\n    if VariableInterface.reasoning_parser is not None:\n        parser = VariableInterface.reasoning_parser\n        reasoning_content, text = parser.extract_reasoning_content(text, request)\n\n    choices = []\n    choice_data = ChatCompletionResponseChoice(\n        index=0,\n        message=ChatMessage(role='assistant', content=text, tool_calls=tool_calls, reasoning_content=reasoning_content),\n        finish_reason=finish_reason,\n    )\n    
choices.append(choice_data)\n\n    return ChatCompletionResponse(\n        id=request_id,\n        created=created_time,\n        model=model_name,\n        choices=choices,\n        usage=UsageInfo(),\n    )\n\n\ndef _stream_parse(request: ChatCompletionRequest, text_sequence: List[str]) -> Tuple[str, str, List[DeltaToolCall]]:\n    content = ''\n    reasoning_content = ''\n    tool_calls = {}\n\n    for stream_resp in _chat_completion_v1(request, text_sequence):\n        delta_message: DeltaMessage = stream_resp.choices[0].delta\n        if delta_message.content:\n            content += delta_message.content\n        if delta_message.reasoning_content:\n            reasoning_content += delta_message.reasoning_content\n        if delta_message.tool_calls:\n            for c in delta_message.tool_calls:\n                existing_call = tool_calls.get(c.id, None)\n                if not existing_call:\n                    tool_calls[c.id] = c\n                    continue\n                # merge with existing\n                if c.function.name:\n                    existing_call.function.name = c.function.name\n                if c.function.arguments:\n                    existing_call.function.arguments = (existing_call.function.arguments or '')\n                    existing_call.function.arguments += c.function.arguments\n    return content, reasoning_content, list(sorted(tool_calls.values(), key=lambda x: x.index))\n\n\n@pytest.mark.parametrize(('text_sequence', 'expects'), [\n    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {\n        'location': '北京',\n        'unit': 'celsius'\n    })]),\n    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [\n        TestExpects('get_weather', {\n            'location': '北京',\n            'unit': 'celsius'\n        }),\n        TestExpects('get_weather', {'location': '上海'})\n    ]),\n])\ndef test_parser_stream(text_sequence: List[str], expects: List[TestExpects]):\n    tokenizer = DummyTokenizer()\n    
VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)\n    VariableInterface.reasoning_parser = None\n    request = ChatCompletionRequest(model='qwen3coder', messages=[], stream=True)\n    content, reasoning_content, tool_calls = _stream_parse(request, text_sequence)\n    assert len(tool_calls) == len(expects)\n    for parsed_call, expected_call in zip(tool_calls, expects):\n        assert parsed_call.function.name == expected_call.func_name\n        args = json.loads(parsed_call.function.arguments)\n        assert args == expected_call.kwargs\n        assert content.strip() == EXPECTED_CONTENT\n\n\n@pytest.mark.parametrize(('text_sequence', 'expects'), [\n    (DELTA_TEXT_SEQUENCE, [TestExpects('get_weather', {\n        'location': '北京',\n        'unit': 'celsius'\n    })]),\n    (DELTA_TEXT_SEQUENCE_MULTIPLE_CALLS, [\n        TestExpects('get_weather', {\n            'location': '北京',\n            'unit': 'celsius'\n        }),\n        TestExpects('get_weather', {'location': '上海'})\n    ]),\n])\ndef test_parser_nonstream(text_sequence: List[str], expects: List[TestExpects]):\n    tokenizer = DummyTokenizer()\n    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)\n    VariableInterface.reasoning_parser = None\n    resp: ChatCompletionResponse = _chat_completion_v1(\n        ChatCompletionRequest(model='qwen3coder', messages=[], stream=False), text_sequence)\n\n    assert len(resp.choices) == 1\n    first_message = resp.choices[0].message\n    assert first_message.content.strip() == EXPECTED_CONTENT\n    assert first_message.reasoning_content is None\n    assert len(first_message.tool_calls) == len(expects)\n    for parsed_call, expected_call in zip(first_message.tool_calls, expects):\n        assert parsed_call.function.name == expected_call.func_name\n        args = json.loads(parsed_call.function.arguments)\n        assert args == expected_call.kwargs\n\n\ndef test_no_think_nonstream():\n    text_sequence = [\n        
'你好',\n        '呀',\n        '！',\n        '✨',\n        '',\n        ' 很',\n        '高兴',\n        '见到',\n        '你',\n        '！',\n    ]\n    tokenizer = DummyTokenizer()\n    VariableInterface.tool_parser = Qwen3CoderToolParser(tokenizer=tokenizer)\n    VariableInterface.reasoning_parser = None\n    resp: ChatCompletionResponse = _chat_completion_v1(\n        ChatCompletionRequest(model='qwen3coder', messages=[], stream=False), text_sequence)\n\n    assert len(resp.choices) == 1\n    first_message = resp.choices[0].message\n    assert first_message.content == '你好呀！✨ 很高兴见到你！'\n    assert first_message.reasoning_content is None\n"
  },
  {
    "path": "tests/test_lmdeploy/test_tokenizer.py",
    "content": "import random\n\nimport pytest\n\nfrom lmdeploy.tokenizer import DetokenizeState, HuggingFaceTokenizer, Tokenizer\n\n\n@pytest.mark.parametrize('model_path', [\n    'internlm/internlm-chat-7b', 'Qwen/Qwen-7B-Chat', 'baichuan-inc/Baichuan2-7B-Chat', 'upstage/SOLAR-0-70b-16bit',\n    'baichuan-inc/Baichuan-7B', 'codellama/CodeLlama-7b-hf', 'THUDM/chatglm2-6b', '01-ai/Yi-6B-200k',\n    '01-ai/Yi-34B-Chat', '01-ai/Yi-6B-Chat', 'WizardLM/WizardLM-70B-V1.0', 'codellama/CodeLlama-34b-Instruct-hf'\n])\n@pytest.mark.parametrize('input', [' hi, this is a test 😆😆! 為什麼我還在用繁體字 😆😆       ' * 5])\n@pytest.mark.parametrize('interval', [1, 3])\n@pytest.mark.parametrize('add_special_tokens', [True, False])\n@pytest.mark.parametrize('skip_special_tokens', [True, False])\ndef test_tokenizer(model_path, input, interval, add_special_tokens, skip_special_tokens):\n    tokenizer = Tokenizer(model_path).model\n    encoded = tokenizer.encode(input, False, add_special_tokens=add_special_tokens)\n    output = ''\n    input = tokenizer.decode(encoded, skip_special_tokens=skip_special_tokens)\n    state = DetokenizeState()\n    for i in range(0, len(encoded), interval):\n        offset = i + interval\n        if offset < len(encoded):\n            # lmdeploy may decode nothing when concurrency is high\n            if random.randint(1, 10) < 4:\n                offset -= interval\n        decoded, state = tokenizer.detokenize_incrementally(encoded[:offset], state, skip_special_tokens)\n        output += decoded\n    assert input == output, 'input string should equal to output after enc-dec'\n\n\n@pytest.mark.parametrize('model_path', [\n    'internlm/internlm-chat-7b', 'Qwen/Qwen-7B-Chat', 'baichuan-inc/Baichuan2-7B-Chat', 'codellama/CodeLlama-7b-hf',\n    'upstage/SOLAR-0-70b-16bit'\n])\n@pytest.mark.parametrize('stop_words', ['.', ' ', '?', ''])\ndef test_tokenizer_with_stop_words(model_path, stop_words):\n    tokenizer = HuggingFaceTokenizer(model_path)\n    indexes = 
tokenizer.indexes_containing_token(stop_words)\n    assert indexes is not None\n\n\ndef test_qwen_vl_decode_special():\n    from lmdeploy.tokenizer import Tokenizer\n    tok = Tokenizer('Qwen/Qwen-VL-Chat')\n    try:\n        tok.decode([151857])\n        assert (0)\n    except Exception as e:\n        assert str(e) == 'Unclosed image token'\n\n\ndef test_glm4_special_token():\n    from lmdeploy.tokenizer import ChatGLM4Tokenizer, Tokenizer\n    model_path = 'THUDM/glm-4-9b-chat'\n    tokenizer = Tokenizer(model_path)\n    assert isinstance(tokenizer.model, ChatGLM4Tokenizer)\n    special_tokens = [\n        '<|endoftext|>', '[MASK]', '[gMASK]', '[sMASK]', '<sop>', '<eop>', '<|system|>', '<|user|>', '<|assistant|>',\n        '<|observation|>', '<|begin_of_image|>', '<|end_of_image|>', '<|begin_of_video|>', '<|end_of_video|>'\n    ]\n    special_token_ids = list(range(151329, 151343))\n\n    for token, token_id in zip(special_tokens, special_token_ids):\n        _token_id = tokenizer.encode(token, add_bos=False)\n        assert len(_token_id) == 1 and _token_id[0] == token_id\n\n\n@pytest.mark.parametrize('model_path',\n                         ['Qwen/Qwen2-7B-Instruct', 'deepseek-ai/deepseek-vl-1.3b-chat', 'OpenGVLab/InternVL2-1B'])\ndef test_check_transformers_version(model_path):\n    tokenizer = HuggingFaceTokenizer(model_path)\n    assert tokenizer is not None\n"
  },
  {
    "path": "tests/test_lmdeploy/test_turbomind/test_converter.py",
    "content": "# yapf: disable\nfrom lmdeploy import TurbomindEngineConfig\nfrom lmdeploy.turbomind import update_parallel_config\nfrom lmdeploy.turbomind.deploy.converter import (get_input_model_registered_name,\n                                                 get_output_model_registered_name_and_config)\nfrom lmdeploy.turbomind.deploy.source_model.base import INPUT_MODELS\n\n# yapf: enable\n\n\ndef test_torch_dtype_fallback():\n    \"\"\"torch_dtype is deprecated in transformers v5+; dtype should be\n    preferred.\n\n    This test ensures get_output_model_registered_name_and_config still works\n    for models whose config exposes either `dtype` or `torch_dtype`.\n    \"\"\"\n    _, config = get_output_model_registered_name_and_config(\n        'internlm/internlm2-chat-7b',\n        model_format='hf',\n        dtype='auto',\n        group_size=0,\n    )\n    assert config.weight_type in ('float16', 'bfloat16')\n\n\ndef test_ffn_reader_kind_none():\n    \"\"\"FFN readers must handle kind=None (returns filter list, not tensors).\n\n    This is the probe call from Ffn.apply() to discover parameter keys before loading actual tensor data. 
A missing\n    guard causes KeyError with 'None' in the key string (regression test for InternLM2Reader._ffn bug).\n    \"\"\"\n    import re\n\n    from lmdeploy.turbomind.deploy.source_model.internlm2 import InternLM2Reader\n    from lmdeploy.turbomind.deploy.source_model.llama import LlamaReader\n\n    # Create minimal readers with fake params that match ffn patterns\n    fake_params = {\n        'model.layers.0.mlp.gate_proj.weight': None,\n        'model.layers.0.mlp.down_proj.weight': None,\n        'model.layers.0.mlp.up_proj.weight': None,\n        'model.layers.0.feed_forward.w1.weight': None,\n        'model.layers.0.feed_forward.w2.weight': None,\n        'model.layers.0.feed_forward.w3.weight': None,\n    }\n\n    # LlamaReader with kind=None should return filtered key list\n    reader = LlamaReader.__new__(LlamaReader)\n    reader.params = dict(fake_params)\n    reader.ffn_pattern = r'mlp'\n    result = reader._ffn(0, None)\n    assert isinstance(result, list)\n    assert len(result) > 0\n    assert all(isinstance(k, str) for k in result)\n    assert all(re.search(r'mlp', k) for k in result)\n\n    # InternLM2Reader with kind=None should also return filtered key list\n    reader2 = InternLM2Reader.__new__(InternLM2Reader)\n    reader2.params = dict(fake_params)\n    reader2.fp8_quant = None\n    reader2.ffn_pattern = r'feed_forward'\n    result2 = reader2._ffn(0, None)\n    assert isinstance(result2, list)\n    assert len(result2) > 0\n    assert all(isinstance(k, str) for k in result2)\n    assert all(re.search(r'feed_forward', k) for k in result2)\n\n\ndef test_registered_models():\n    for model, model_format, group_size, weight_type, register_name in [\n        ('internlm/internlm2-7b', 'hf', 0, 'bfloat16', 'tm'), ('baichuan-inc/Baichuan-7B', 'hf', 0, 'float16', 'tm'),\n        ('baichuan-inc/Baichuan2-7B-Chat', 'hf', 0, 'bfloat16', 'tm'),\n        ('baichuan-inc/Baichuan-13B-Chat', 'hf', 0, 'bfloat16', 'tm'),\n        
('baichuan-inc/Baichuan2-13B-Chat', 'hf', 0, 'bfloat16', 'tm'),\n        ('internlm/internlm-chat-7b', 'hf', 0, 'float16', 'tm'),\n        ('internlm/internlm2-chat-7b', 'hf', 0, 'bfloat16', 'tm'),\n        ('internlm/internlm-xcomposer2-4khd-7b', 'hf', 0, 'bfloat16', 'tm'),\n        ('internlm/internlm-xcomposer2-vl-7b', 'hf', 0, 'bfloat16', 'tm'),\n        ('internlm/internlm-xcomposer2-7b', 'hf', 0, 'bfloat16', 'tm'),\n        ('lmsys/vicuna-7b-v1.5', 'hf', 0, 'float16', 'tm'), ('01-ai/Yi-1.5-9B', 'hf', 0, 'bfloat16', 'tm'),\n        ('deepseek-ai/deepseek-coder-6.7b-instruct', 'hf', 0, 'bfloat16', 'tm'),\n        ('deepseek-ai/deepseek-llm-7b-chat', 'hf', 0, 'bfloat16', 'tm'),\n        ('Qwen/Qwen-7B-Chat', 'hf', 0, 'bfloat16', 'tm'), ('Qwen/Qwen1.5-7B-Chat', 'hf', 0, 'bfloat16', 'tm'),\n        ('Qwen/Qwen2-7B-Instruct', 'hf', 0, 'bfloat16', 'tm'), ('Qwen/Qwen-VL-Chat', 'hf', 0, 'bfloat16', 'tm'),\n        ('liuhaotian/llava-v1.6-34b', 'hf', 0, 'bfloat16', 'tm'),\n        ('liuhaotian/llava-v1.6-mistral-7b', 'hf', 0, 'bfloat16', 'tm'),\n        ('liuhaotian/llava-v1.6-vicuna-13b', 'hf', 0, 'bfloat16', 'tm'),\n        ('OpenGVLab/InternVL-Chat-V1-5', 'hf', 0, 'bfloat16', 'tm'),\n        ('deepseek-ai/deepseek-vl-7b-chat', 'hf', 0, 'float16', 'tm'),\n        ('Qwen/Qwen1.5-4B-Chat-AWQ', 'awq', 128, 'int4', 'tm'),\n        ('solidrust/Meta-Llama-3-8B-Instruct-hf-AWQ', 'awq', 128, 'int4', 'tm'),\n        ('internlm/internlm2-chat-20b-4bits', 'awq', 128, 'int4', 'tm'),\n        ('internlm/internlm-xcomposer2-vl-7b-4bit', 'awq', 128, 'int4', 'tm')\n    ]:\n        input_name = get_input_model_registered_name(model, model_format=model_format)\n        assert input_name in list(INPUT_MODELS.module_dict.keys())\n\n        output_name, config = get_output_model_registered_name_and_config(model,\n                                                                          model_format=model_format,\n                                                                          
dtype='auto',\n                                                                          group_size=0)\n        assert output_name == register_name\n        assert config.model_config.group_size == group_size\n        assert config.session_len > 0\n        assert config.model_config.model_arch is not None\n\n\ndef test_update_from_engine_config():\n    import copy\n    _, _config = get_output_model_registered_name_and_config('internlm/internlm2-chat-7b',\n                                                             model_format='hf',\n                                                             dtype='auto',\n                                                             group_size=0)\n    config = copy.deepcopy(_config)\n    config.update_from_engine_config(None)\n    assert (config == _config)\n\n    config = copy.deepcopy(_config)\n    engine_config = TurbomindEngineConfig()\n    update_parallel_config(engine_config)\n    config.update_from_engine_config(engine_config)\n    assert config.model_config.attn_tp_size == 1\n    assert config.session_len == 32768\n\n    config = copy.deepcopy(_config)\n    engine_config = TurbomindEngineConfig(model_format='hf',\n                                          tp=2,\n                                          device_num=2,\n                                          session_len=4000,\n                                          max_batch_size=100,\n                                          cache_max_entry_count=0.5,\n                                          quant_policy=8,\n                                          rope_scaling_factor=3.0,\n                                          use_logn_attn=True,\n                                          max_prefill_iters=64,\n                                          num_tokens_per_iter=256)\n    update_parallel_config(engine_config)\n    config.update_from_engine_config(engine_config)\n\n    assert (config.model_config.attn_tp_size == engine_config.attn_tp_size)\n    assert 
(config.session_len == engine_config.session_len)\n    assert (config.attention_config.rope_param.type == 'dynamic')\n    assert (config.attention_config.rope_param.factor == engine_config.rope_scaling_factor)\n    assert (config.attention_config.use_logn_attn == engine_config.use_logn_attn)\n\n\ndef test_dtype():\n    testsets = [('auto', 'bfloat16'), ('float16', 'float16'), ('bfloat16', 'bfloat16')]\n    for specified_dtype, expected_dtype in testsets:\n        _, _config = get_output_model_registered_name_and_config('internlm/internlm2-chat-7b',\n                                                                 model_format='hf',\n                                                                 dtype=specified_dtype,\n                                                                 group_size=0)\n        assert _config.weight_type == expected_dtype\n    for specified_dtype in ['auto', 'float16', 'bfloat16']:\n        _, _config = get_output_model_registered_name_and_config('internlm/internlm2_5-20b-chat-4bit-awq',\n                                                                 model_format='awq',\n                                                                 dtype=specified_dtype,\n                                                                 group_size=128)\n        assert _config.weight_type == 'int4'\n"
  },
  {
    "path": "tests/test_lmdeploy/test_utils.py",
    "content": "from transformers import AutoConfig\n\nfrom lmdeploy.utils import _get_and_verify_max_len\n\n\ndef test_get_and_verify_max_len():\n    # with PretrainedConfig\n    config = AutoConfig.from_pretrained('OpenGVLab/InternVL-Chat-V1-5-AWQ', trust_remote_code=True)\n    assert (_get_and_verify_max_len(config, None) == 32768)\n    assert (_get_and_verify_max_len(config, 1024) == 1024)\n    assert (_get_and_verify_max_len(config, 102400) == 102400)\n\n    # with PretrainedConfig\n    config = AutoConfig.from_pretrained('internlm/internlm2-chat-7b', trust_remote_code=True)\n    assert (_get_and_verify_max_len(config, None) == 32768)\n    assert (_get_and_verify_max_len(config, 1024) == 1024)\n    assert (_get_and_verify_max_len(config, 102400) == 102400)\n"
  },
  {
    "path": "tests/test_lmdeploy/test_vl/test_hf_chat_template.py",
    "content": "import os\n\nimport pytest\n\nfrom lmdeploy.model import MODELS\nfrom lmdeploy.vl.model.builder import load_vl_model\n\n\ndef get_model_and_chat_template(model_path):\n    if os.getenv('LMDEPLOY_USE_MODELSCOPE', 'False').lower() == 'true':\n        from modelscope import snapshot_download\n    elif os.getenv('LMDEPLOY_USE_OPENMIND_HUB', 'False').lower() == 'true':\n        from openmind_hub import snapshot_download\n    else:\n        from huggingface_hub import snapshot_download\n    model_path = snapshot_download(model_path, allow_patterns=['*.json', '*.py', '*.txt', '*.model', '*.jinja'])\n    model = load_vl_model(model_path=model_path, with_llm=False, backend='pytorch')\n    chat_template = MODELS.module_dict['hf'](model_path=model_path)\n    return model, chat_template\n\n\n@pytest.fixture(scope='module')\ndef mock_messages():\n    return [\n        dict(role='user',\n             content=[\n                 dict(type='text', text='Describe the following images in detail'),\n                 dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),\n                 dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),\n                 dict(type='text', text='How many cats are there in total?')\n             ]),\n    ]\n\n\n@pytest.fixture(scope='module')\ndef mock_pure_img_messages():\n    return [\n        dict(role='user',\n             content=[\n                 dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),\n             ]),\n    ]\n\n\n@pytest.fixture(scope='module')\ndef mock_pure_text_messages():\n    return [\n        dict(role='user',\n             content=[\n                 dict(type='text', text='Describe the following images in detail'),\n                 dict(type='text', text='How many cats are there in total?'),\n             ]),\n    ]\n\n\nclass TestInternVLHFChatTemplate:\n\n    
@pytest.fixture(scope='module')\n    def models(self):\n        model_list = [\n            'OpenGVLab/InternVL3_5-1B-HF',\n            'OpenGVLab/InternVL3_5-2B-HF',\n            'OpenGVLab/InternVL3_5-4B-HF',\n            'OpenGVLab/InternVL3_5-8B-HF',\n            'OpenGVLab/InternVL3_5-14B-HF',\n            'OpenGVLab/InternVL3_5-38B-HF',\n            'OpenGVLab/InternVL3_5-30B-A3B-HF',\n            'OpenGVLab/InternVL3_5-241B-A28B-HF',\n        ]\n        models = [get_model_and_chat_template(model_path) for model_path in model_list]\n        return models\n\n    def test_proc_messages(self, models, mock_messages):\n        for model, chat_template in models:\n            model.build_preprocessor()\n            reference = model.processor.apply_chat_template(mock_messages,\n                                                            add_generation_prompt=True,\n                                                            tokenize=False,\n                                                            return_dict=True)\n            # InternVL-HF and InternS1 models pad <img> and </img> internally\n            reference = reference.replace('<IMG_CONTEXT>', '<img><IMG_CONTEXT></img>')\n            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_proc_pure_img_messages(self, models, mock_pure_img_messages):\n        for model, chat_template in models:\n            model.build_preprocessor()\n            reference = model.processor.apply_chat_template(mock_pure_img_messages,\n                                                            add_generation_prompt=True,\n                                                            tokenize=False,\n                                                            return_dict=True)\n            # InternVL-HF and InternS1 models pad <img> and </img> internally\n            reference = reference.replace('<IMG_CONTEXT>', '<img><IMG_CONTEXT></img>')\n  
          prompt, _ = model.proc_messages(mock_pure_img_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_proc_pure_text_messages(self, models, mock_pure_text_messages):\n        for model, chat_template in models:\n            model.build_preprocessor()\n            reference = model.processor.apply_chat_template(mock_pure_text_messages,\n                                                            add_generation_prompt=True,\n                                                            tokenize=False,\n                                                            return_dict=True)\n            prompt, _ = model.proc_messages(mock_pure_text_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n\nclass TestQwenVLChatTemplate:\n\n    @pytest.fixture(scope='module')\n    def models(self):\n        model_list = [\n            'Qwen/Qwen2-VL-2B-Instruct',\n            'Qwen/Qwen2-VL-7B-Instruct',\n            'Qwen/Qwen2-VL-72B-Instruct',\n            'Qwen/Qwen2.5-VL-3B-Instruct',\n            'Qwen/Qwen2.5-VL-7B-Instruct',\n            'Qwen/Qwen2.5-VL-32B-Instruct',\n            'Qwen/Qwen2.5-VL-72B-Instruct',\n            'Qwen/Qwen3-VL-2B-Instruct',\n            'Qwen/Qwen3-VL-2B-Thinking',\n            'Qwen/Qwen3-VL-4B-Instruct',\n            'Qwen/Qwen3-VL-4B-Thinking',\n            'Qwen/Qwen3-VL-8B-Instruct',\n            'Qwen/Qwen3-VL-8B-Thinking',\n            'Qwen/Qwen3-VL-32B-Instruct',\n            'Qwen/Qwen3-VL-32B-Thinking',\n            'Qwen/Qwen3-VL-30B-A3B-Instruct',\n            'Qwen/Qwen3-VL-30B-A3B-Thinking',\n            'Qwen/Qwen3-VL-235B-A22B-Instruct',\n            'Qwen/Qwen3-VL-235B-A22B-Thinking',\n        ]\n        models = [get_model_and_chat_template(model_path) for model_path in model_list]\n        return models\n\n    def test_proc_messages(self, models, mock_messages):\n        for model, chat_template in models:\n            
model.build_preprocessor()\n            reference = model.processor.apply_chat_template(mock_messages,\n                                                            add_generation_prompt=True,\n                                                            tokenize=False,\n                                                            return_dict=True)\n            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_pure_img_messages(self, models, mock_pure_img_messages):\n        for model, chat_template in models:\n            model.build_preprocessor()\n            reference = model.processor.apply_chat_template(mock_pure_img_messages,\n                                                            add_generation_prompt=True,\n                                                            tokenize=False,\n                                                            return_dict=True)\n            prompt, _ = model.proc_messages(mock_pure_img_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_pure_text_messages(self, models, mock_pure_text_messages):\n        for model, chat_template in models:\n            model.build_preprocessor()\n            reference = model.processor.apply_chat_template(mock_pure_text_messages,\n                                                            add_generation_prompt=True,\n                                                            tokenize=False,\n                                                            return_dict=True)\n            prompt, _ = model.proc_messages(mock_pure_text_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n"
  },
  {
    "path": "tests/test_lmdeploy/test_vl/test_nonhf_chat_template.py",
    "content": "import os\n\nimport pytest\n\nfrom lmdeploy.model import MODELS\nfrom lmdeploy.vl.model.builder import load_vl_model\n\n\ndef get_model_and_chat_template(model_path):\n    if os.getenv('LMDEPLOY_USE_MODELSCOPE', 'False').lower() == 'true':\n        from modelscope import snapshot_download\n    elif os.getenv('LMDEPLOY_USE_OPENMIND_HUB', 'False').lower() == 'true':\n        from openmind_hub import snapshot_download\n    else:\n        from huggingface_hub import snapshot_download\n    model_path = snapshot_download(model_path, allow_patterns=['*.json', '*.py', '*.txt', '*.model', '*.jinja'])\n    model = load_vl_model(model_path=model_path, with_llm=False, backend='pytorch')\n    chat_template = MODELS.module_dict['hf'](model_path=model_path)\n    return model, chat_template\n\n\nclass TestInternVLChatTemplate:\n\n    @pytest.fixture(scope='module')\n    def internvl3_5(self):\n        model_list = [\n            'OpenGVLab/InternVL3_5-241B-A28B',\n            'OpenGVLab/InternVL3_5-30B-A3B',\n            'OpenGVLab/InternVL3_5-38B',\n            'OpenGVLab/InternVL3_5-14B',\n            'OpenGVLab/InternVL3_5-8B',\n            'OpenGVLab/InternVL3_5-4B',\n            'OpenGVLab/InternVL3_5-2B',\n            'OpenGVLab/InternVL3_5-1B',\n        ]\n        models = [get_model_and_chat_template(model_path) for model_path in model_list]\n        return models\n\n    @pytest.fixture(scope='module')\n    def internvl3(self):\n        model_list = [\n            'OpenGVLab/InternVL3-78B',\n            'OpenGVLab/InternVL3-38B',\n            'OpenGVLab/InternVL3-14B',\n            'OpenGVLab/InternVL3-8B',\n            # \"OpenGVLab/InternVL3-9B\",  # <s>\n            'OpenGVLab/InternVL3-2B',\n            'OpenGVLab/InternVL3-1B',\n        ]\n        models = [get_model_and_chat_template(model_path) for model_path in model_list]\n        return models\n\n    @pytest.fixture(scope='module')\n    def internvl2_5(self):\n        model_list = [\n            
'OpenGVLab/InternVL2_5-78B',\n            'OpenGVLab/InternVL2_5-38B',\n            # \"OpenGVLab/InternVL2_5-26B\",  # <s>\n            # \"OpenGVLab/InternVL2_5-8B\",  # <s>\n            'OpenGVLab/InternVL2_5-4B',\n            # \"OpenGVLab/InternVL2_5-2B\",  # <s>\n            'OpenGVLab/InternVL2_5-1B',\n        ]\n        models = [get_model_and_chat_template(model_path) for model_path in model_list]\n        return models\n\n    @pytest.fixture(scope='module')\n    def internvl2(self):\n        model_list = [\n            'OpenGVLab/InternVL2-Llama3-76B',\n            'OpenGVLab/InternVL2-40B',\n            'OpenGVLab/InternVL2-26B',\n            'OpenGVLab/InternVL2-8B',\n            # \"OpenGVLab/InternVL2-4B\",  # <|user|> not <|im_start|>\n            'OpenGVLab/InternVL2-2B',\n            'OpenGVLab/InternVL2-1B',\n        ]\n        models = [get_model_and_chat_template(model_path) for model_path in model_list]\n        return models\n\n    @pytest.fixture(scope='module')\n    def mock_messages(self):\n        return [\n            dict(role='user',\n                 content=[\n                     dict(type='text', text='Describe the following images in detail'),\n                     dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),\n                     dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),\n                     dict(type='text', text='How many cats are there in total?')\n                 ]),\n        ]\n\n    @pytest.fixture(scope='module')\n    def mock_IMAGE_TOKEN_messages(self):\n        return [\n            dict(role='system', content='你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。'),\n            dict(role='user',\n                 content=[\n                     dict(type='text', text='<IMAGE_TOKEN>\\nDescribe the following images in detail'),\n                     dict(type='image', 
url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg'))\n                 ]),\n        ]\n\n    def test_internvl3_5(self, internvl3_5, mock_messages):\n        reference = \"\"\"<|im_start|>user\nDescribe the following images in detail<img><IMG_CONTEXT></img>\n<img><IMG_CONTEXT></img>\nHow many cats are there in total?<|im_end|>\n<|im_start|>assistant\n\"\"\"\n        for model, chat_template in internvl3_5:\n            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True)\n\n            assert prompt == reference\n\n    def test_internvl3_5_backward_compatibility(self, internvl3_5, mock_IMAGE_TOKEN_messages):\n        reference = \"\"\"<|im_start|>system\n你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>\n<|im_start|>user\n<img><IMG_CONTEXT></img>\nDescribe the following images in detail<|im_end|>\n<|im_start|>assistant\n\"\"\"\n        for model, chat_template in internvl3_5:\n            prompt, _ = model.proc_messages(mock_IMAGE_TOKEN_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_internvl3(self, internvl3, mock_messages):\n        reference = \"\"\"<|im_start|>system\n你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>\n<|im_start|>user\nDescribe the following images in detail<img><IMG_CONTEXT></img>\n<img><IMG_CONTEXT></img>\nHow many cats are there in total?<|im_end|>\n<|im_start|>assistant\n\"\"\"\n        for model, chat_template in internvl3:\n            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_internvl3_backward_compatibility(self, internvl3, mock_IMAGE_TOKEN_messages):\n        reference = \"\"\"<|im_start|>system\n你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>\n<|im_start|>user\n<img><IMG_CONTEXT></img>\nDescribe the following images in detail<|im_end|>\n<|im_start|>assistant\n\"\"\"\n        for 
model, chat_template in internvl3:\n            prompt, _ = model.proc_messages(mock_IMAGE_TOKEN_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_internvl2_5(self, internvl2_5, mock_messages):\n        reference = \"\"\"<|im_start|>system\n你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>\n<|im_start|>user\nDescribe the following images in detail<img><IMG_CONTEXT></img>\n<img><IMG_CONTEXT></img>\nHow many cats are there in total?<|im_end|>\n<|im_start|>assistant\n\"\"\"\n        for model, chat_template in internvl2_5:\n            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_internvl2_5_backward_compatibility(self, internvl2_5, mock_IMAGE_TOKEN_messages):\n        reference = \"\"\"<|im_start|>system\n你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>\n<|im_start|>user\n<img><IMG_CONTEXT></img>\nDescribe the following images in detail<|im_end|>\n<|im_start|>assistant\n\"\"\"\n        for model, chat_template in internvl2_5:\n            prompt, _ = model.proc_messages(mock_IMAGE_TOKEN_messages, chat_template, sequence_start=True)\n            assert prompt == reference\n\n    def test_internvl2(self, internvl2, mock_messages):\n        reference = \"\"\"<|im_start|>user\nDescribe the following images in detail<img><IMG_CONTEXT></img>\n<img><IMG_CONTEXT></img>\nHow many cats are there in total?<|im_end|>\n<|im_start|>assistant\n\"\"\"\n        for model, chat_template in internvl2:\n            # Let sequence_start=False to avoid the begin-of-prompt token, such as <|begin_of_text|>, <s>\n            prompt, _ = model.proc_messages(mock_messages, chat_template, sequence_start=False)\n            assert prompt == reference\n\n    def test_internvl2_backward_compatibility(self, internvl2, mock_IMAGE_TOKEN_messages):\n        reference = 
\"\"\"<|im_start|>system\n你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。<|im_end|>\n<|im_start|>user\n<img><IMG_CONTEXT></img>\nDescribe the following images in detail<|im_end|>\n<|im_start|>assistant\n\"\"\"\n        for model, chat_template in internvl2:\n            # Let sequence_start=False to avoid the begin-of-prompt token, such as <|begin_of_text|>, <s>\n            prompt, _ = model.proc_messages(mock_IMAGE_TOKEN_messages, chat_template, sequence_start=False)\n            assert prompt == reference\n"
  },
  {
    "path": "tests/test_lmdeploy/test_vl/test_qwen3vl_processor.py",
    "content": "import copy\n\nimport pytest\n\nfrom lmdeploy.vl import load_image\nfrom lmdeploy.vl.model.qwen3 import Qwen3VLModel\n\nQWEN3VL_MODELS = [\n    'Qwen/Qwen3-VL-4B-Instruct',\n]\n\nIMAGE_URL = ('https://raw.githubusercontent.com/open-mmlab/'\n             'mmdeploy/main/tests/data/tiger.jpeg')\n\n\n@pytest.fixture(scope='module', params=QWEN3VL_MODELS)\ndef qwen3vl_model(request):\n    \"\"\"Initialize Qwen3VLModel with a real model path.\"\"\"\n    model = Qwen3VLModel(model_path=request.param)\n    model.build_preprocessor()\n    return model\n\n\n@pytest.fixture\ndef sample_messages():\n    \"\"\"Create sample messages for preprocessing, using a PIL image loaded from IMAGE_URL.\"\"\"\n    pil_image = load_image(IMAGE_URL)\n    return [{\n        'role':\n        'user',\n        'content': [\n            {\n                'type': 'text',\n                'text': 'Can you describe this image?'\n            },\n            {\n                'type': 'image',\n                'data': pil_image\n            },\n        ]\n    }]\n\n\ndef test_qwen3vl_preprocess_with_custom_pixels(qwen3vl_model, sample_messages):\n    \"\"\"Test that mm_processor_kwargs with min/max pixels takes effect.\"\"\"\n\n    # compression ratio for qwen3vl is 32 = patch_size * spatial_merge_size = 16 * 2\n    # qwen3vl_model.processor.image_processor.size['shortest_edge'] = 65536\n    # 65536 = 64 * 32 * 32, indicates 64 image token budget\n    # qwen3vl_model.processor.image_processor.size['longest_edge'] = 16777216\n    # 16777216 = 16384 * 32 * 32, indicates 16384 image token budget\n\n    # Default processing without custom arguments\n    default_processed_messages = qwen3vl_model.preprocess(messages=copy.deepcopy(sample_messages))\n    default_content = default_processed_messages[-1]['content']\n    default_shape = default_content[0]['pixel_values'].shape  # [280, 1536]\n\n    # Processing with smaller pixel range\n    mm_processor_kwargs = {'min_pixels': 10 * 32 * 32, 'max_pixels': 20 * 32 * 
32}\n    custom_processed_messages = qwen3vl_model.preprocess(messages=copy.deepcopy(sample_messages),\n                                                         mm_processor_kwargs=mm_processor_kwargs)\n    custom_content = custom_processed_messages[-1]['content']\n    custom_shape = custom_content[0]['pixel_values'].shape  # [60, 1536]\n\n    assert default_shape != custom_shape, \\\n        'Default and custom processing should result in different shapes.'\n    assert default_shape[0] > custom_shape[0], \\\n        'Custom processing with smaller pixel range should result in smaller image size.'\n\n    # Processing with larger pixel range\n    mm_processor_kwargs = {'min_pixels': 100 * 32 * 32, 'max_pixels': 20000 * 32 * 32}\n    custom_processed_messages = qwen3vl_model.preprocess(messages=copy.deepcopy(sample_messages),\n                                                         mm_processor_kwargs=mm_processor_kwargs)\n    custom_content = custom_processed_messages[-1]['content']\n    custom_shape = custom_content[0]['pixel_values'].shape  # [468, 1536]\n\n    assert default_shape != custom_shape, \\\n        'Default and custom processing should result in different shapes.'\n    assert default_shape[0] < custom_shape[0], \\\n        'Custom processing with larger pixel range should result in larger image size.'\n"
  },
  {
    "path": "tests/test_lmdeploy/test_vl/test_vl_encode.py",
    "content": "import math\n\nimport numpy as np\n\nfrom lmdeploy.vl import (encode_image_base64, encode_time_series_base64, encode_video_base64, load_image,\n                         load_time_series, load_video)\n\n\ndef test_image_encode_decode():\n    url = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'\n\n    img1 = load_image(url)\n    # use PNG for lossless pixel-perfect comparison\n    b64 = encode_image_base64(url, format='PNG')\n    img2 = load_image(f'data:image/png;base64,{b64}')\n\n    assert img1.size == img2.size\n    assert img1.mode == img2.mode\n    assert img1.tobytes() == img2.tobytes()\n\n\ndef test_video_encode_decode():\n    # url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/space_woaudio.mp4'\n    url = 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/clip_3_removed.mp4'\n\n    # num_frames=4 to keep test fast\n    vid1, meta1 = load_video(url, num_frames=4)\n    b64 = encode_video_base64(url, num_frames=4, format='JPEG')\n    vid2, meta2 = load_video(f'data:video/jpeg;base64,{b64}')\n\n    gt_meta = {\n        'total_num_frames': 498,\n        'fps': 29.97002997002997,\n        'duration': 16.616600000000002,\n        'video_backend': 'opencv',\n        'frames_indices': [0, 165, 331, 497]\n    }\n\n    assert vid1.shape == vid2.shape\n    assert np.mean(np.abs(vid1.astype(float) - vid2.astype(float))) < 2.0  # JPEG is lossy\n    assert meta1['total_num_frames'] == gt_meta['total_num_frames']\n    assert meta1['frames_indices'] == gt_meta['frames_indices']\n\n\ndef test_time_series_encode_decode():\n    # url = \"https://huggingface.co/internlm/Intern-S1-Pro/raw/main/0092638_seism.npy\"\n    url = 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/0092638_seism.npy'\n\n    ts1 = load_time_series(url)\n    b64 = encode_time_series_base64(url)\n    ts2 = load_time_series(f'data:time_series/npy;base64,{b64}')\n\n    assert ts1.shape == ts2.shape\n    assert 
np.allclose(ts1, ts2)\n\n\ndef test_image_modes():\n    import numpy as np\n    from PIL import Image\n\n    grayscale_img = Image.fromarray(np.zeros((100, 100), dtype=np.uint8)).convert('L')\n    b64 = encode_image_base64(grayscale_img)  # should convert L -> RGB internally\n\n    img_out = load_image(f'data:image/png;base64,{b64}')\n    assert img_out.mode == 'RGB'\n\n\ndef test_truncated_image():\n    url = 'https://github.com/irexyc/lmdeploy/releases/download/v0.0.1/tr.jpeg'\n    im = load_image(url)\n    assert im.width == 1638\n    assert im.height == 2048\n\n\ndef test_single_frame_video():\n    url = 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/clip_3_removed.mp4'\n    vid, meta = load_video(url, num_frames=1)\n    assert vid.shape[0] == 1\n\n    b64 = encode_video_base64(vid)\n    assert isinstance(b64, str)\n    assert ',' not in b64  # should only be one JPEG block, no commas\n\n\ndef test_video_sampling_params():\n    url = 'https://raw.githubusercontent.com/CUHKSZzxy/Online-Data/main/clip_3_removed.mp4'\n\n    # 1. test num_frames constraint\n    num_frames = 5\n    vid, meta = load_video(url, num_frames=num_frames)\n    assert vid.shape[0] == num_frames\n    assert len(meta['frames_indices']) == num_frames\n\n    # 2. test fps constraint (original fps is ~29.97, duration ~16.6s)\n    fps = 1\n    vid, meta = load_video(url, fps=fps)\n    expected_frames = max(1, int(math.floor(meta['duration'] * fps)))\n    assert vid.shape[0] == expected_frames\n\n    # 3. 
test both constraints (should take the minimum)\n    # 10 fps x 16.6s ~= 166 frames > 10 frames, so will be limited by num_frames\n    num_frames = 10\n    fps = 10\n    vid, meta = load_video(url, num_frames=num_frames, fps=fps)\n    assert vid.shape[0] == num_frames\n\n    # 1 fps x 16.6s ~= 16 frames < 100 frames, so will be limited by fps\n    num_frames = 100\n    fps = 1\n    vid, meta = load_video(url, num_frames=num_frames, fps=fps)\n    expected_frames = max(1, int(math.floor(meta['duration'] * fps)))\n    assert vid.shape[0] == expected_frames\n\n\ndef test_invalid_inputs():\n    # non-existent local path\n    import pytest\n    with pytest.raises(Exception):\n        load_image('/non_existent/path/image.jpg')\n    with pytest.raises(Exception):\n        load_video('/non_existent/path/video.mp4')\n    with pytest.raises(Exception):\n        load_time_series('/non_existent/path/data.npy')\n"
  }
]